function [ new_mvn_model, all_log_lkld ] = do_em_gmm( mvn_model, data, log_domain_flag, max_iter, rel_tol )
% function [ new_mvn_model, all_log_lkld ] = do_em_gmm( mvn_model, data [ , log_domain_flag [ , max_iter [ , rel_tol ] ] ] )
%
% EM optimization of a GMM (full covariance matrix).
% The GMM has M components mvn_model( 1:M ).
%
% Inputs:
%   mvn_model       struct array of M components, each with fields
%                   w (prior weight), mu (mean) and Sigma (covariance
%                   matrix). The optional field diag_Sigma (0/1, default 0)
%                   marks components restricted to a diagonal covariance.
%   data            D x N matrix, one sample per column.
%   log_domain_flag 0 (default): M-step computed in the linear domain.
%                   1: M-step computed in the log domain (slower; so far
%                   the log domain implementation does not bring much).
%   max_iter        maximum number of EM iterations (default 1000).
%   rel_tol         EM stops when the total log likelihood improves by
%                   less than rel_tol * |log likelihood| (default 1e-8).
%
% Outputs:
%   new_mvn_model   updated GMM. Components that become numerically
%                   unstable in the E-step are dropped (the remaining
%                   priors are renormalized), so it may have fewer than M
%                   components. A covariance matrix that stops being
%                   positive definite is permanently restricted to its
%                   diagonal (diag_Sigma = 1, Sigma stored as sparse).
%   all_log_lkld    total log likelihood of data computed by the last
%                   E-step (i.e. with the parameters before the last
%                   M-step); -Inf if the last M-step forced a covariance
%                   matrix to become diagonal.
%
% Errors if mvn_model is empty or if every component becomes unstable.
%
% NOTE(review): my_logsum_fast (defined elsewhere) is assumed to compute
% log( sum( exp( x ) ) ) along the first non-singleton dimension -- confirm.
%
% See also: estimate_K_gmm.m, compute_post_gmm.m.

if nargin < 2
    error( [ mfilename '() needs at least 2 input arguments.' ] );
end
if nargin < 3 || isempty( log_domain_flag )
    % By default, M-step implemented in the linear domain (faster)
    log_domain_flag = 0;
end
if nargin < 4 || isempty( max_iter )
    max_iter = 1000;
end
if nargin < 5 || isempty( rel_tol )
    rel_tol = 1e-8;
end

% Dimensionality
D = size( data, 1 );
% Number of samples
N = size( data, 2 );
% Number of components,
% each component is a multivariate normal distribution
M = numel( mvn_model );
if M == 0
    error( [ mfilename ': the input GMM has no component.' ] );
end

% Binary flag to mark which covariance matrix has turned diagonal
if ~isfield( mvn_model, 'diag_Sigma' )
    mvn_model( 1 ).diag_Sigma = [];
end
for m = 1:M
    if isempty( mvn_model( m ).diag_Sigma )
        % By default, full covariance matrix
        mvn_model( m ).diag_Sigma = 0;
    end
end

% EM
finished = false;
iter = 0;
previous_all_log_lkld = -Inf;
while ~finished
    disp( ' ' );

    % First make sure that all covariance matrices are positive definite.
    % If not, fall back on a diagonal covariance matrix (change of model)
    % for that particular matrix -> this prevents many crashes.
    [ mvn_model, changed ] = enforce_pd_or_diagonal( mvn_model );
    if changed
        % The convergence test is not valid across a change of model
        previous_all_log_lkld = -Inf;
    end

    diag_models = find( [ mvn_model.diag_Sigma ] );
    if ~isempty( diag_models )
        disp( [ mfilename ': E-step: using diagonal only for models [' sprintf( ' %d', diag_models ) ' ]' ] );
    end

    % E step
    % Compute joint log likelihoods. Unstable components are dropped
    % (instead of flooring variances) and EM restarts from the present
    % state of the remaining parameters.
    [ mvn_model, comp_log_joint, dropped ] = compute_stable_log_joints( mvn_model, data );
    if dropped
        % The parameter space has changed -> restart EM
        previous_all_log_lkld = -Inf;
    end
    M = numel( mvn_model );

    % Compute joints and posteriors
    if M > 1
        log_lkld = my_logsum_fast( comp_log_joint );
    else
        log_lkld = comp_log_joint;
    end
    all_log_lkld = sum( log_lkld );
    comp_log_post = comp_log_joint - repmat( log_lkld, M, 1 );

    % M step
    if log_domain_flag
        % DEBUG
        disp( [ mfilename ': M-step, log domain' ] )
        % M step: implementation in log domain
        % ( more precise but slower than linear domain )
        new_log_w = my_logsum_fast( comp_log_post.' );
        new_log_w = new_log_w - my_logsum_fast( new_log_w );
        new_w = exp( new_log_w );
        new_w = new_w / sum( new_w );
        for m = 1:M
            % Update weight of this component
            mvn_model( m ).w = new_w( m );
            % Log weights: posteriors scaled so that they sum to one over all samples
            tmp_log_w = comp_log_post( m,: ) - my_logsum_fast( comp_log_post( m,: ) );
            % Mean
            for i = 1:D
                mvn_model( m ).mu( i ) = log_domain_weighted_sum( data( i,: ), tmp_log_w );
            end
            % Covariance matrix (upper triangle, or diagonal only)
            centered = data - repmat( mvn_model( m ).mu(:), 1, N );
            mvn_model( m ).Sigma = zeros( D, D );
            for i = 1:D
                if mvn_model( m ).diag_Sigma
                    j_max = i;
                else
                    j_max = D;
                end
                for j = i:j_max
                    mvn_model( m ).Sigma( i, j ) = log_domain_weighted_sum( centered( i,: ) .* centered( j,: ), tmp_log_w );
                end
            end
            % Covariance matrix is symmetric
            mvn_model( m ).Sigma = mvn_model( m ).Sigma + triu( mvn_model( m ).Sigma, 1 ).';
        end
    else
        % DEBUG
        disp( [ mfilename ': M-step, linear domain' ] );
        % M step: implementation in linear domain
        % ( less precise but faster than log domain )
        new_w = exp( my_logsum_fast( comp_log_post.' ) );
        new_w = new_w / sum( new_w );
        for m = 1:M
            % Update weight of this component
            mvn_model( m ).w = new_w( m );
            % Posterior weights of the samples for this component
            tmp_w = exp( comp_log_post( m,: ) );
            tmp_sum_w = sum( tmp_w );
            % Mean
            mvn_model( m ).mu = full( ( data * tmp_w(:) ) / tmp_sum_w );
            % Covariance matrix
            centered = data - repmat( mvn_model( m ).mu(:), 1, N );
            if ~mvn_model( m ).diag_Sigma
                % Update full covariance matrix
                tmp_term = full( centered * diag( sparse( sqrt( tmp_w(:) / tmp_sum_w ) ) ) );
                mvn_model( m ).Sigma = tmp_term * tmp_term.';
            else
                % Update diagonal covariance matrix
                tmp_var = ( centered .^ 2 ) * tmp_w(:) / tmp_sum_w;
                mvn_model( m ).Sigma = diag( sparse( tmp_var ) );
            end
        end
    end

    % If a new Sigma is not positive definite, restrict it to its diagonal
    [ mvn_model, changed ] = enforce_pd_or_diagonal( mvn_model );
    if changed
        % Then we need to compute the lkld again
        % ( the convergence test won't be valid )
        previous_all_log_lkld = -Inf;
        all_log_lkld = -Inf;
    end

    disp( [ mfilename sprintf( ': iter:%d all_lkld:%.10g', iter, all_log_lkld ) ] );

    % Termination test
    iter = iter + 1;
    % When both "previous_all_log_lkld" and "all_log_lkld" are -Inf,
    % -Inf - -Inf = NaN and NaN < abs(...) * rel_tol is false, so
    % convergence is not checked; this is what we need after a change of
    % parameter space (= restart EM).
    finished = ( iter >= max_iter ) || ( ( all_log_lkld - previous_all_log_lkld ) < abs( all_log_lkld ) * rel_tol );
    previous_all_log_lkld = all_log_lkld;
end

new_mvn_model = mvn_model;


function [ mvn_model, changed ] = enforce_pd_or_diagonal( mvn_model )
% Switch every full covariance matrix that fails chol() to a diagonal one.
%
% For each component not yet flagged diagonal, test positive definiteness
% with chol(); on failure set diag_Sigma = 1 (a permanent change of model)
% and print a message. Every component flagged diagonal gets its Sigma
% reduced to a sparse diagonal matrix.
% CHANGED is true if at least one component was switched to diagonal.
changed = false;
for m = 1:numel( mvn_model )
    if ~mvn_model( m ).diag_Sigma
        [ R, p ] = chol( mvn_model( m ).Sigma ); %#ok<ASGLU>
        mvn_model( m ).diag_Sigma = ( p ~= 0 );
        if mvn_model( m ).diag_Sigma
            fprintf( [ mfilename ': from now on, forcing covariance matrix of model %d to be diagonal.\n' ], m );
            changed = true;
        end
    end
    if mvn_model( m ).diag_Sigma
        % use diagonal only
        mvn_model( m ).Sigma = diag( sparse( diag( mvn_model( m ).Sigma ) ) );
    end
end


function [ mvn_model, comp_log_joint, dropped ] = compute_stable_log_joints( mvn_model, data )
% E-step core: log( w_m ) + log_mvnpdf for every component m (rows) and sample (columns).
%
% A component is unstable if log_mvnpdf warns or errors, or if its row
% contains NaN or +Inf (+Inf can happen when Sigma collapses while staying
% well-conditioned). Unstable components are dropped, the remaining priors
% renormalized, and the computation redone until all components are stable.
% DROPPED is true if at least one component was removed.
% Errors if every component is unstable.
N = size( data, 2 );
dropped = false;
is_stable = false;
while ~is_stable
    M = numel( mvn_model );
    is_ok = true( 1, M );
    comp_log_joint = NaN( M, N );

    % Temporarily disable the debugger's "stop on warning/error":
    % failures of log_mvnpdf are expected here and handled below.
    dbg = dbstatus;
    warn_stop = ismember( 'warning', { dbg.cond } );
    dbclear warning;
    error_stop = ismember( 'error', { dbg.cond } );
    dbclear error;
    for m = 1:M
        try
            lastwarn( '' );
            % log_mvnpdf may throw a warning or an error
            comp_log_joint( m,: ) = log( mvn_model( m ).w ) + reshape( log_mvnpdf( data.', mvn_model( m ).mu(:).', full( mvn_model( m ).Sigma ) ), 1, [] );
            row = comp_log_joint( m,: );
            is_ok( m ) = isempty( lastwarn ) && ~any( isnan( row ) ) && ~any( row == +Inf );
        catch
            % log_mvnpdf threw an error
            is_ok( m ) = false;
        end
    end
    if warn_stop, dbstop warning; end
    if error_stop, dbstop error; end

    is_stable = all( is_ok );
    if ~is_stable
        % Drop unstable components
        fprintf( [ mfilename ': %d unstable components, we''re going to drop them.\n' ], sum( ~is_ok ) );
        mvn_model = mvn_model( is_ok );
        dropped = true;
        if isempty( mvn_model )
            error( [ mfilename ': all components are unstable, none left to estimate.' ] );
        end
        % Rescale priors
        sum_w = sum( [ mvn_model.w ] );
        for m = 1:numel( mvn_model )
            mvn_model( m ).w = mvn_model( m ).w / sum_w;
        end
    end
end


function s = log_domain_weighted_sum( x, log_w )
% Return sum( x .* exp( log_w ) ) for real row vectors x and log_w.
%
% Positive and negative terms are accumulated separately in the log domain
% (via my_logsum_fast), which avoids reaching the limits of the machine in
% most cases. An empty positive or negative part contributes exactly 0.
s = 0;
pos = ( x > 0 );
if any( pos )
    s = s + exp( my_logsum_fast( log( x( pos ) ) + log_w( pos ) ) );
end
neg = ( x < 0 );
if any( neg )
    s = s - exp( my_logsum_fast( log( -x( neg ) ) + log_w( neg ) ) );
end