function [centres, options, post, errlog] = kmeans_radians(ncentres, data, options, centres)
%KMEANS_RADIANS Trains a k means cluster model on angular data.
%
%   function [centres, options, post, errlog] = kmeans_radians(ncentres, data, options, centres)
%
%   Description
%   CENTRES = KMEANS_RADIANS(NCENTRES, DATA, OPTIONS) uses the batch K-means
%   algorithm to set the centres of a cluster model. DATA is a vector of
%   angles in radians; each element is one data point. The error function
%   is the sum over clusters of the standard deviation of the wrapped
%   angular differences between each point and its centre. The point at
%   which a local minimum is achieved is returned as CENTRES. The error
%   value at that point is returned in OPTIONS(8). Clusters that become
%   empty are dropped, so CENTRES may hold fewer than NCENTRES values.
%
%   [CENTRES, OPTIONS, POST, ERRLOG] = KMEANS_RADIANS(NCENTRES, DATA, OPTIONS)
%   also returns the cluster number (in a one-of-N encoding) for each
%   data point in POST and a log of the error values after each cycle in
%   ERRLOG.
%
%   KMEANS_RADIANS(NCENTRES, DATA, OPTIONS, CENTRES) starts from the given
%   initial CENTRES. If CENTRES is omitted, the initial centres are the
%   data points closest to NCENTRES equally spaced angles in [0, 2*pi).
%
%   If OPTIONS is omitted, the defaults from FOPTIONS are used. The
%   optional parameters have the following interpretations.
%
%   OPTIONS(1) is set to 1 to display the initial centres and the error
%   value after each cycle. If OPTIONS(1) is set to 0, then only warning
%   messages are displayed. If OPTIONS(1) is -1, then nothing is
%   displayed. Error values are logged in ERRLOG whenever that output is
%   requested, regardless of OPTIONS(1).
%
%   OPTIONS(2) is a measure of the absolute precision required for the
%   value of CENTRES at the solution. If the absolute difference between
%   the values of CENTRES between two successive steps is less than
%   OPTIONS(2), then this condition is satisfied.
%
%   OPTIONS(3) is a measure of the precision required of the error
%   function at the solution. If the absolute difference between the
%   error functions between two successive steps is less than OPTIONS(3),
%   then this condition is satisfied. Both this and the previous
%   condition must be satisfied for termination.
%
%   OPTIONS(14) is the maximum number of iterations; default 100.
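%
%   Example
%   A minimal usage sketch, not part of the original documentation. It
%   assumes FOPTIONS, DIST2_RADIANS and DELTA_ANGLE are on the path; the
%   data values are made up for illustration.
%
%       data = [0.1 0.2 6.2 3.0 3.2 3.1];      % angles in radians
%       options = foptions;
%       options(1) = -1;                       % display nothing
%       options(14) = 50;                      % at most 50 iterations
%       [centres, options, post] = kmeans_radians(2, data, options);
%       [tmp, labels] = max(post, [], 2);      % cluster index per point
%
%   The angles 0.1 and 6.2 are about 0.18 rad apart on the circle, which
%   is the wrap-around case that K-means on the raw values would miss.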
%
%   See also
%   GMMINIT, GMMEM
%
%   Copyright (c) Ian T Nabney (1996-2001)

% DATA must be a vector; work with it as a column
siz = size(data);
if sum( siz > 1 ) > 1
  error('"data" dimension must be one' );
end
data = data(:);
ndata = length( data );

siz = size( ncentres );
if sum( siz > 1 ) > 0
  error( '"ncentres" must be a scalar' );
end

if (ncentres > ndata)
  error('More centres than data')
end

if ~exist( 'options', 'var' )
  options = foptions;
end

% Sort out the options
if (options(14))
  niters = options(14);
else
  niters = 100;
end

store = 0;
if (nargout > 3)
  store = 1;
  errlog = [];
end

if ~exist( 'centres', 'var' )
  % Initialize centres at equally spaced angles
  centres = (0:ncentres-1) / ncentres * 2 * pi;

  % Make sure no cluster is empty
  % -> for each centre, pick the closest data point not already taken
  ind = [];
  for a = 1:length( centres )
    remaining_ind = setdiff( 1:length( data ), ind );
    [b, c] = min( abs( delta_angle( data( remaining_ind ), centres( a ) ) ) );
    ind( end+1 ) = remaining_ind( c );
  end
  centres = sort( mod( data( ind ), 2*pi ) ); % For readability
  centres = centres(:).';

  if options(1) > 0
    disp( 'kmeans_radians: centres initialized to:' );
    disp( [ 'kmeans_radians: centres = [ ' sprintf( ' %6.2f', centres/pi*180 ) ' ] degrees' ] );
  end
end

% This is not necessary, just for readability
centres = sort( mod( centres, 2*pi ) );

% Matrix to make unit vectors easy to construct
id = eye(ncentres);

% Main loop of algorithm
for n = 1:niters

  % Save old centres to check for termination
  old_centres = centres;

  % Calculate posteriors based on existing centres
  [d2, dd] = dist2_radians( data, centres );
  % Assign each point to nearest centre
  [minvals, index] = min(d2', [], 1);
  post = id(index,:);

  num_points = sum(post, 1);
  % Adjust the centres based on new posteriors
  the_std = repmat( NaN, size( centres ) );
  for j = 1:ncentres
    if (num_points(j) > 0)
      cluster_points = find(post(:,j));
      % Shift the centre by the mean signed angular offset of its points
      centres(j) = centres(j) + mean( dd( cluster_points, j ) );
      % Residuals about the updated centre, wrapped to [-pi, pi)
      new_dd = -pi + mod( data( cluster_points ) - centres(j) + pi, 2*pi );
      the_std( j ) = std( new_dd );
    end
  end
  centres = mod( centres, 2 * pi );

  % Keep only non-empty clusters
  if ~all( num_points > 0 )
    ind = find( num_points > 0 );
    num_points = num_points( ind );
    centres = centres( ind );
    post = post( :, ind );
    the_std = the_std( ind );
  end

  % Error value is the sum over clusters of the std of the wrapped
  % residuals (the sum-of-squares alternative is left commented out)
  % e = sum(minvals);
  e = sum( the_std );
  if store
    errlog(n) = e;
  end
  if options(1) > 0
    fprintf(1, 'Cycle %4d  Error %11.6f\n', n, e);
  end

  % Only compare centres when no cluster was dropped in this cycle
  if (n > 1) && (length( centres ) == length( old_centres ))
    % Test for termination
    if max(max(abs(centres - old_centres))) < options(2) && ...
        abs(old_e - e) < options(3)
      options(8) = e;
      return;
    end
  end
  old_e = e;
end

% If we get here, then we haven't terminated in the given number of
% iterations.
options(8) = e;
if (options(1) >= 0)
  disp('Warning: Maximum number of iterations has been exceeded');
  %DEBUG keyboard
end
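
% Notes on the helper functions (assumptions; neither is defined in this file)
%
%   The main loop relies on [D2, DD] = DIST2_RADIANS(DATA, CENTRES). From
%   the way the results are used above, DD(i,j) is presumably the signed
%   angular difference DATA(i) - CENTRES(j) wrapped to [-pi, pi), and
%   D2(i,j) its square. A hypothetical sketch of that contract, using the
%   same wrapping expression as the centre update:
%
%       dd = -pi + mod( bsxfun(@minus, data(:), centres(:).') + pi, 2*pi );
%       d2 = dd.^2;
%
%   DELTA_ANGLE(A, B) is likewise assumed to return the wrapped difference
%   between the angles A and B; only its absolute value is used here.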