function [acc,confusm,lhood,models] = do_expt(trainset, testset, ngmm, nsamp, dims, verb)
% [acc,confusm,lhood,models] = do_expt(trainset, testset, ngmm, nsamp, dims, verb)
%
% Example baseline artist ID task: train one model per label on the
% training tracks, score every test track against every model, and
% report how often the best-scoring model matches the true label.
%
% Inputs (all optional; defaults in parentheses):
%   trainset ('tracks-train.txt') is the name of a file listing all the
%     training examples; labels are read from the file 'labels-<trainset>'.
%   testset ('tracks-test.txt') is the name of a file listing all the
%     test examples; labels are read from the file 'labels-<testset>'.
%   ngmm (10) is the number of gaussians to use in each mixture model
%     (ngmm==1 is a special case for full-covariance single gaussians)
%   nsamp (2000) is the number of randomly-selected samples to train on;
%     the same number of random frames is drawn from each test track.
%   dims (1:12) is a vector indicating which columns of the feature
%     vectors to use.
%   verb (0) nonzero prints progress messages, plots the likelihood and
%     confusion matrices, and seeds rand for repeatable runs.
%
% Outputs:
%   acc     the overall accuracy (0..1)
%   confusm confusion count matrix; rows are true class, columns are
%           reported (model) class
%   lhood   raw scores for all tracks across all models; rows are
%           models, columns are test tracks
%   models  the trained models, one per unique training label, in the
%           order returned by unique() on the training labels
%
% Raises an error if the feature type is unknown, if the test list is
% empty, if a test label does not occur among the training labels, or
% if a test feature file contains no frames.
%
% 2007-04-04 Dan Ellis dpwe@ee.columbia.edu
% $Header: /homes/drspeech/data/uspop2002/baseline/RCS/do_expt.m,v 1.1 2007/04/06 15:44:05 dpwe Exp dpwe $

%%%% Input parameters
if nargin < 1; trainset = 'tracks-train.txt'; end
if nargin < 2; testset = 'tracks-test.txt'; end
% How many mixture components?
if nargin < 3; ngmm = 10; end
% Train on how many samples per file?
if nargin < 4; nsamp = 2000; end
% Using which cepstral dimensions?
if nargin < 5; dims = 1:12; end
% progress messages?
if nargin < 6; verb = 0; end

% plot confusion matrix?
doplot = 0;
if verb
  doplot = 1;
  % seed the RNG when doing a verbose run for consistent results
  % (legacy syntax kept on purpose: changing it would change the
  % random sample selection of existing runs)
  rand('state', 0);
end

% where to find the list files
listdir = '../mandelset';

% which kind of features to use
features = 'mfcc';
%features = 'chroma';
disp(['Features = ',features]);

%%%% Configuration
if strcmp(features,'mfcc')
  datapath = '../artists/';
  dataext = '.htk';
  ndims = 20;
elseif strcmp(features,'chroma')
  datapath = '../chromfeats/';
  dataext = '.mat';
  ndims = 12;
else
  error(['Unknown features ',features]);
end

% These are the global normalization constants; 
% Initial empty values means they are set from the data in the first 
% call to model_train
%mn = [];
%st = [];

% no norm for chroma
% NOTE(review): these fixed values are used for every feature type,
% including mfcc, not only chroma - confirm that is intended.
mn = zeros(ndims,1);
st = ones(ndims,1);
% no - kills gaussian modelling .. not any more? log flag to gauss_prob?

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Train models
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

train_files = listfileread(fullfile(listdir, trainset));
train_labs = listfileread(fullfile(listdir, ['labels-',trainset]));

% Make the train files into full relative paths
for i = 1:length(train_files)
  train_files{i} = fullfile(datapath, [train_files{i},dataext]);
end

% List of all unique labels
ulabs = unique(train_labs);
nlabs = length(ulabs);

% Train models one by one

%if exist('models') == 1
%  disp('Using already-trained models ("clear models" to retrain)...');
%else

  for model = 1:nlabs

    if verb; disp(['training for ',ulabs{model},' ...']); end

    % Select filenames that have this label as ground truth
    files = train_files(strcmp(train_labs, ulabs{model}));
    
    % mn and st are threaded through successive calls so that every
    % model sees the same normalization constants
    [models(model),mn,st] = model_train(files, mn, st, ngmm, nsamp, dims);
  
  end

%end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Test models
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

test_files = listfileread(fullfile(listdir, testset));
test_labs = listfileread(fullfile(listdir, ['labels-',testset]));

ntest = length(test_files);

if ntest == 0
  error(['Test set ',testset,' lists no tracks']);
end

% Which model should each track have matched according to its ground
% truth label?  Resolved before any scoring so that a label with no
% trained model is reported immediately and by name.
gt = zeros(1,ntest);
for file = 1:ntest
  truth = find(strcmp(ulabs, test_labs{file}));
  if isempty(truth)
    error(['Test label ',test_labs{file},' (track ',test_files{file}, ...
           ') does not occur in the training labels']);
  end
  gt(file) = truth;
end

% Eval likelihood of each dataset on each model, or do KL between
% models
match_models = 0;

if match_models
  disp('** Matching as KL between models');
else
  disp('** Matching by max l/hood of test samples');
end

% rows are models, columns are test tracks
lhood = zeros(nlabs,ntest);

for file = 1:ntest

  if verb; disp(['testing ',test_files{file},'...']); end
  
  filename = fullfile(datapath,[test_files{file},dataext]);
  d = standardize(readdatafile(filename),mn,st)';
  
  nframes = size(d,1);
  if nframes == 0
    error(['No feature frames found in ',filename]);
  end

  % choose random time samples
  ntsamp = nsamp;
  % ri will consist of ntsamp random indices (with replacement) at least 
  % <guard> samples away from either end, where guard is the smaller of
  % 1000 and one quarter of the dataset length (so at least half of the
  % frames always remain eligible)
  guard = min(round(nframes/4),1000); % how far to stay away from ends
  ri = guard + ceil((nframes-2*guard)*rand(1,ntsamp));
  
  d = d(ri,dims);

  if match_models
    tmodel = model_train_data(d, ngmm);
  end
    
  % Evaluate likelihood of this data under every model
  for model = 1:nlabs
    if match_models
      lhood(model,file) = model_match_models(models(model), tmodel);
    else
      lhood(model,file) = model_match(models(model), d);
    end
  end

end

% Which is the most likely model for each test item?
[maxv,indx] = max(lhood);

% Overall accuracy: fraction of tracks whose best model is the true one
acc = mean(indx==gt);
if verb; disp(['Classification accuracy = ', num2str(100*acc),'%']); end

% Confusion matrix: count each track at (true class, reported class).
% Counts are accumulated directly rather than derived from 0*lhood,
% because 0*Inf and 0*NaN are NaN and would corrupt the counts whenever
% a score is not finite.
confusm = zeros(nlabs,nlabs);
for i = 1:ntest
  confusm(gt(i),indx(i)) = confusm(gt(i),indx(i)) + 1;
end
% rows are true class, columns are reported (model) class

if doplot
  % Plot full l/hood matrix and confusion matrix
  subplot(121)
  imagesc(lhood); axis xy  % cols are tracks, rows are models
  ylabel('model');
  xlabel('track');

  % Label Y axis with model names
  set(gca,'YTick',1:nlabs);
  set(gca,'YTickLabel',ulabs);

  % Short tags for the axes: the first two characters of each label
  % (or the whole label if it is shorter than that)
  ulb2 = cell(1,nlabs);
  for i = 1:nlabs
    lab = ulabs{i};
    ulb2{i} = lab(1:min(2,length(lab)));
  end

  % Divide X axis into blocks coming from each artist & label.
  % first(i) is the index of the first test track carrying label i, or
  % 0 when no test track has that label (such labels get no X tick).
  % NOTE(review): the tick positions assume the test list is grouped by
  % label in the same order as ulabs - confirm against the list files.
  first = zeros(1,nlabs);
  for i = 1:nlabs
    hit = find(strcmp(test_labs, ulabs{i}), 1);
    if ~isempty(hit); first(i) = hit; end
  end
  present = first > 0;
  xx = [first(present),length(test_labs)];
  set(gca,'XTick',mean([xx(1:end-1);xx(2:end)]));
  set(gca,'XTickLabel',ulb2(present));
  % White vertical lines at the start of every block but the first
  bounds = xx(2:end-1);
  hold on
  if ~isempty(bounds)
    plot([1;1]*bounds,[0.5 nlabs+0.5],'-w');
  end
  hold off
  colorbar
  
  subplot(122)
  imagesc(confusm); axis xy
  xlabel('recog');
  ylabel('true');
  set(gca,'YTick',1:nlabs);
  set(gca,'YTickLabel',ulb2);
  set(gca,'XTick',1:nlabs);
  set(gca,'XTickLabel',ulb2);  
  % NOTE(review): gcolor is not defined in this file; presumably a
  % project helper that applies the colormap - confirm it is on the
  % path, otherwise colormap(gray) may have been intended.
  gcolor(gray)
  colorbar
  
end
