function [ raylsherl, par, iter ] = fit_raylsherl( data, par )
% function [ raylsherl, par, iter ] = fit_raylsherl( data, par )
%
% Fit a Rayleigh + Shifted Erlang model on strictly positive data.
% This implements unsupervised EM training.
% The model was introduced in the article:
%
% "Unsupervised Spectral Subtraction",
% by G. Lathoud, M. Magimai.-Doss, B. Mesot and H. Bourlard,
% to appear in Proceedings of ASRU 2005.
%
% The pdf of this model can be briefly written as:
% p( data | raylsherl ) = raylsherl.p0 * raylpdf( data, raylsherl.sigma )
%     + ( 1 - raylsherl.p0 ) * erlangpdf( data - raylsherl.erlang_t, raylsherl.lambda )
% where the shift is tied to the Rayleigh parameter:
%     raylsherl.erlang_t = par.erlang_t_factor * raylsherl.sigma
%
%
% INPUT ARGUMENTS
%
% data: a non-empty matrix of real, strictly positive values (e.g. from
%       a magnitude spectrogram). Its shape is ignored: it is flattened.
%
% par: a structure containing parameters.
%
%      par.erlang_t_factor: scalar > 0, mandatory parameter.
%                           It defines the shift of the "Shifted Erlang",
%                           as written above.
%
%      par.do_fminsearch: select which implementation to use for the M step
%           par.do_fminsearch = 0: moment method, fast but approximate.
%           par.do_fminsearch = 1: ML method, slower but exact (default).
%
%      par.fminsearch_options: if ML selected, a structure of options
%                              for "fminsearch" (see OPTIMSET, FMINSEARCH).
%                              Default: [].
%
%      par.min_iter: minimum number of iterations to accomplish (default 0).
%      par.max_iter: maximum number of iterations to accomplish (default +Inf).
%
%      par.sigma_floor:  lower bound on raylsherl.sigma  (default 1e-20,
%                        use -Inf to disable).
%      par.lambda_floor: lower bound on raylsherl.lambda (default 1e-20,
%                        use -Inf to disable).
%
%      par.verbose:      if nonzero, display parameters and the model at
%                        each iteration (default 1).
%      par.sanity_check: flag forwarded to "compute_post_raylsherl"
%                        (default 1).
%
%
% OUTPUT ARGUMENTS
%
% raylsherl: a structure that defines the "Rayleigh + Shifted Erlang" model,
%            with the following fields:
%               raylsherl.p0 is the prior probability of silence.
%               raylsherl.sigma is the Rayleigh parameter.
%               raylsherl.erlang_t_factor is a copy of par.erlang_t_factor.
%               raylsherl.erlang_t defines the shift of the Shifted Erlang.
%               raylsherl.lambda is the parameter of the Shifted Erlang.
%            The returned model is the one with the highest overall
%            log-likelihood seen during training.
%            raylsherl is [] if no data could be found to initialize or
%            update the Shifted Erlang (a message is displayed).
%
% par: the input argument "par", after default values have been filled.
%
% iter: how many iterations were done before convergence or
%       par.max_iter was reached.
%
%
% SEE ALSO: compute_post_raylsherl
%
%
% By Guillaume Lathoud, 2005-08-26
% lathoud@idiap.ch

if nargin < 2
    error( [ mfilename ' needs at least 2 input argument.' ] );
end

% Reject empty input explicitly: "all( [] )" is true, so the positivity
% check below would let it through and training would fail later with
% an uninformative message.
if isempty( data )
    error( [ mfilename ' requires non-empty data.' ] );
end

% Check that all data is >0
if ~all( data(:) > 0 )
    error( [ mfilename ' requires all( data(:) > 0 ).' ] )
end

% Check the main parameter
check_param( { 'erlang_t_factor' }, fieldnames( par ) );
if ~( par.erlang_t_factor > 0 )
    error( [ mfilename ' needs par.erlang_t_factor > 0.' ] );
end

%%%
% Fill in default parameters

par_default = [];

par_default.min_iter = 0;
par_default.max_iter = +Inf;

% Flooring of the Rayleigh: use "-Inf" to disable it
par_default.sigma_floor = 1e-20;

% Flooring of the Shifted Erlang: use "-Inf" to disable it
par_default.lambda_floor = 1e-20;

% Verbosity flag
par_default.verbose = 1;

% Sanity check flag: mostly used by "compute_post_raylsherl"
par_default.sanity_check = 1;

% Select which method for M step:
% 0 means moment method: quick, fits mostly all real speech data,
%   but certainly not every possible artificial data.
% 1 means numerical optimization through fminsearch: the safest
par_default.do_fminsearch = 1;

% If "do_fminsearch = 1", then you can optionally pass some options
% ( see OPTIMSET and FMINSEARCH )
par_default.fminsearch_options = [];

par = fill_default( par, par_default );

if par.verbose
    disp( [ mfilename ' parameters.' ] );
    disp( par );
end

%%%
% Precompute what every iteration needs.
% These are plain local variables: they are handed to "negloglike" as
% extra fminsearch arguments. MATLAB only copies an input array when the
% callee modifies it, so this costs no extra memory and, unlike globals,
% leaves nothing behind in (and clobbers nothing of) the caller's
% global workspace.

% Don't care about dimensionality
x = data(:);
clear data;

x2    = x .^ 2;
log_x = log( x );
nx    = numel( x );

%%%%%%%%%%%%%%%
% Everything is ready,
% we can do a rough fit
% to initialize the model.

raylsherl = [];
raylsherl.erlang_t_factor = par.erlang_t_factor;

raylsherl.sigma    = max( par.sigma_floor, sqrt( 0.5 .* mean( x2 ) ) );
raylsherl.erlang_t = raylsherl.erlang_t_factor * raylsherl.sigma;

% Redo it to ensure safe initialization:
% make sure we'll use:
% - data with small magnitudes only for the Rayleigh.
% - data with large magnitudes only for the shifted Erlang.
a_threshold = max( 2 * raylsherl.sigma, raylsherl.erlang_t );

% Init the Rayleigh
is_small = ( x <= a_threshold );
if ~any( is_small )
    error( [ mfilename ' is insane.' ] );
end

raylsherl.sigma    = max( par.sigma_floor, sqrt( 0.5 .* mean( x2( is_small ) ) ) );
raylsherl.erlang_t = raylsherl.erlang_t_factor * raylsherl.sigma;

% Init the Shifted Erlang
is_large = ( x > a_threshold );
if ~any( is_large )
    % Failure
    % Note: most likely this is an artificial or almost all-zero signal
    disp( [ mfilename ' could not find any data to initialize the Shifted Erlang.' ] );
    raylsherl = [];
    iter      = 0;
    return;
end

% Moment estimate: mean of 1 / ( x - erlang_t ) over the large magnitudes
raylsherl.lambda = max( par.lambda_floor, mean( ( x( is_large ) - raylsherl.erlang_t ) .^ -1 ) );

% Init priors (kept away from 0 and 1)
raylsherl.p0 = max( 0.1, min( 0.9, nnz( is_small ) / nx ) );

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% EM fitting

% Keep track of the best model in terms of overall lkld
best = [];

% Start EM
iter     = 0;
finished = 0;

while ~finished

    if par.verbose
        disp( ' ' );
        disp( sprintf( [ mfilename ': model at iteration %d:' ], iter ) );
        disp( raylsherl );
    end

    %%% E: compute posteriors
    [ log_post, raylsherl.alldata_log_lkld ] = compute_post_raylsherl( raylsherl, x, par.sanity_check );

    %%% Did we improve?
    if isempty( best )
        is_improving = 1;
    else
        is_improving = best.alldata_log_lkld < raylsherl.alldata_log_lkld;
    end

    % Update
    if is_improving
        best = raylsherl;
    end

    % Termination test
    finished = ( iter >= par.max_iter ) | ( ( iter >= par.min_iter ) & ~is_improving );

    if ~finished

        %%% M: fit the parameters

        % Prepare weights: one column per mixture component
        % (column 1: Rayleigh, column 2: Shifted Erlang)
        log_w = [ reshape( log_post( :,:,1 ), [], 1 )   reshape( log_post( :,:,2 ), [], 1 ) ];

        % Update priors
        tmp = my_logsum_fast( log_w );
        raylsherl.p0 = min( 1, exp( tmp( 1 ) - my_logsum_fast( tmp ) ) );

        % Update distributions' parameters
        w = exp( log_w );

        if par.do_fminsearch

            % ML method
            % Numerical optimization: in this case we are sure to increase the likelihood
            theta_0 = [ raylsherl.sigma; raylsherl.lambda ];

            % Everything after the options is forwarded by fminsearch
            % to "negloglike" after "theta".
            theta = fminsearch( @negloglike, theta_0, par.fminsearch_options, ...
                                raylsherl.erlang_t_factor, x, x2, log_x, w, nx );

            % Apply the same floors as the initialization and the
            % moment method do.
            raylsherl.sigma    = max( par.sigma_floor,  theta( 1 ) );
            raylsherl.lambda   = max( par.lambda_floor, theta( 2 ) );
            raylsherl.erlang_t = raylsherl.erlang_t_factor * raylsherl.sigma;

        else

            % Moment method: in this case we are NOT sure to increase the likelihood

            % Rayleigh
            raylsherl.sigma = max( par.sigma_floor, sqrt( 0.5 * w(:,1).' * x2(:) / sum( w(:,1) ) ) );

            % Shifted Erlang
            raylsherl.erlang_t = raylsherl.erlang_t_factor * raylsherl.sigma;

            is_large = ( x(:) > raylsherl.erlang_t );
            if ~any( is_large )
                % Failure
                % Note: most likely this is an artificial or almost all-zero signal
                disp( [ mfilename ' could not find any data to update the Shifted Erlang.' ] );
                raylsherl = [];
                return;
            end

            raylsherl.lambda = max( par.lambda_floor, ...
                w( is_large, 2 ).' * ( ( x( is_large ) - raylsherl.erlang_t ) .^ -1 ) / sum( w( is_large, 2 ) ) );

        end

        %%%
        % Prepare for next iteration
        iter = iter + 1;

    end

end

% Verbosity
if par.verbose
    disp( ' ' );
    disp( [ mfilename ': training finished. Best model:' ] );
    disp( best );
end

% Return the best model in a ML sense
raylsherl = rmfield( best, 'alldata_log_lkld' );


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function y = negloglike( theta, erlang_t_factor, x, x2, log_x, w, nx )
% Objective minimized by fminsearch in the ML M-step: the weighted
% negative log-likelihood of the two components, divided by nx.
%
% theta:           [ sigma; lambda ], the parameters being optimized.
% erlang_t_factor: the Shifted Erlang shift is erlang_t_factor * sigma.
% x, x2, log_x:    the data (column vector), its square and its log.
% w:               weights, column 1 for the Rayleigh, column 2 for the
%                  Shifted Erlang.
% nx:              numel( x ), used for normalization only.
%
% Returns +Inf outside the domain of definition of the parameters.
%
% Note: the data is received as input arguments that are never modified
% here, so MATLAB does not duplicate it in memory.

sigma  = theta( 1 );
lambda = theta( 2 );

% Define the domain of definition of all parameters
if ( sigma <= 0 ) | ( lambda <= 0 ) | ( erlang_t_factor <= 0 )
    y = +Inf;
    return;
end

erlang_t = erlang_t_factor * sigma;

% Rayleigh
y = -w( :, 1 ).' * ( log_x(:) - 2 * log( sigma ) - x2(:) / ( 2 * sigma ^ 2 ) );

% Shifted Erlang: only the data strictly above the shift contributes
is_large = ( x > erlang_t );
if any( is_large )
    tmp = x( is_large ) - erlang_t;
    y   = y - w( is_large, 2 ).' * ( log( tmp ) + 2 * log( lambda ) - lambda * tmp );
end

% Normalization (this is purely for visual comfort)
y = y / nx;