function [acc,confusm,lhood,models] = do_expt_chroma(trainset, testset, ngmm, nsamp, twin, alig, verb)
% [acc,confusm,lhood,models] = do_expt_chroma(trainset, testset, ngmm, nsamp, twin, alig, verb)
% Example baseline artist ID task using chroma features.
%
% Inputs (all optional):
%   trainset ('tracks-train.txt') is the name of a file listing all the
%            training examples.  The ground-truth labels are read from
%            the file of the same name prefixed with 'labels-'.
%   testset  ('tracks-test.txt') lists all the test files (with a
%            corresponding 'labels-' file).
%   ngmm     (1) is the number of gaussians to use in each mixture model
%            (ngmm==1 is a special case for full-covariance single
%            gaussians).
%   nsamp    (1000) is the maximum number of frames to use in training;
%            test tracks with more frames are randomly subsampled to
%            nsamp frames too.
%   twin     (1) is the number of successive beats to concatenate into
%            a single model frame.
%   alig     (0); alig = 1 means to attempt aligning (transposing)
%            tracks on train and test.
%   verb     (0); nonzero prints progress messages, seeds the RNG and
%            plots the likelihood and confusion matrices.
%
% Outputs:
%   acc      the overall accuracy (0..1)
%   confusm  confusion count matrix; rows are true class, columns are
%            reported (model) class
%   lhood    the raw scores for all tracks across all models
%            (rows are models, columns are test tracks)
%   models   the trained models, one per unique training label
%
% 2007-04-04 Dan Ellis dpwe@ee.columbia.edu
% $Header: /homes/drspeech/data/uspop2002/baseline/RCS/do_expt.m,v 1.1 2007/04/06 15:44:05 dpwe Exp dpwe $

%%%% Input parameters
if nargin < 1; trainset = 'tracks-train.txt'; end
if nargin < 2; testset = 'tracks-test.txt'; end
% How many mixture components?
if nargin < 3; ngmm = 1; end
% max number of frames to train on
if nargin < 4; nsamp = 1000; end
% default twin
if nargin < 5; twin = 1; end
% should we transpose-align?
if nargin < 6; alig = 0; end
% progress messages?
if nargin < 7; verb = 0; end

% plot confusion matrix?
doplot = 0;
if verb
  doplot = 1;
  % seed the RNG when doing a verbose run for consistent results
  rand('state', 0);
end

% where to find the list files
listdir = '../mandelset';

datapath = '../chromfeats/';
dataext = '.mat';

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Train models
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

train_files = listfileread(fullfile(listdir, trainset));
train_labs = listfileread(fullfile(listdir, ['labels-',trainset]));

% The label list is used as a mask over the file list, so the two must
% line up one-to-one.
if length(train_files) ~= length(train_labs)
  error('do_expt_chroma:listMismatch', ...
        'Training list has %d files but %d labels', ...
        length(train_files), length(train_labs));
end

% List of all unique labels
ulabs = unique(train_labs);
nlabs = length(ulabs);

if nlabs == 0
  error('do_expt_chroma:noTrainingData', ...
        'No training labels found for %s', trainset);
end

% Train models one by one

for model = 1:nlabs

  if verb; disp(['training for ',ulabs{model},' ...']); end

  % Select filenames that have this label as ground truth
  files = train_files(strcmp(train_labs, ulabs{model}));
    
  models(model) = model_train_chroma(buildpath(datapath,files,dataext), ...
                                     ngmm, nsamp, twin, alig);
  
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Test models
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

test_files = listfileread(fullfile(listdir,testset));
test_labs = listfileread(fullfile(listdir, ['labels-',testset]));

ntest = length(test_files);

if ntest == 0
  error('do_expt_chroma:noTestData', ...
        'No test files found for %s', testset);
end
if length(test_labs) ~= ntest
  error('do_expt_chroma:listMismatch', ...
        'Test list has %d files but %d labels', ntest, length(test_labs));
end

% Which model should each track have been according to its ground truth
% label?  Resolve this before the (slow) evaluation so that a test label
% without a trained model is reported clearly instead of failing on an
% empty find() result.
[known, gtidx] = ismember(test_labs, ulabs);
if ~all(known)
  unknown = test_labs(~known);
  error('do_expt_chroma:unknownLabel', ...
        'Test label ''%s'' has no trained model', unknown{1});
end
% force a row so it compares elementwise with the row of winners below
gtidx = gtidx(:)';

% The two scoring modes differ only in the 4th argument passed to
% model_match_chroma.
% NOTE(review): presumably [] requests a transposition search and 0
% means "no transposition" - confirm against model_match_chroma.
if alig
  alignarg = [];
else
  alignarg = 0;
end

% Eval likelihood of each dataset on each model

lhood = zeros(nlabs, ntest);

for file = 1:ntest

  if verb; disp(['testing ',test_files{file},'...']); end
  
  dd = readdatafile(buildpath(datapath,test_files{file},dataext));
  
  if size(dd,2) > nsamp
    % random subselection
    % NOTE(review): this also shuffles the column order; if
    % model_match_chroma stacks successive columns when twin > 1 the
    % stacked frames would no longer be adjacent beats - TODO confirm.
    ri = randperm(size(dd,2));
    dd = dd(:,ri(1:nsamp));
  end

  % Evaluate likelihood of this data under every model
  for model = 1:nlabs
    lhood(model,file) = model_match_chroma(models(model), dd, twin, ...
                                           alignarg, models(model).rmodel);
  end
  
end

% Which is the most likely model for each test item?
% (explicit dimension so that a single-model run, where lhood is a row
% vector, still yields one winner per track)
[maxv,indx] = max(lhood, [], 1);

% Overall accuracy
acc = mean(indx==gtidx);
if verb; disp(['Classification accuracy = ', num2str(100*acc),'%']); end

% Matrix of which track was classified to which class, and ground truth;
% (zeros() rather than 0*lhood, which turns Inf/NaN scores into NaN)
cm = zeros(size(lhood));
gtm = zeros(size(lhood));
for i = 1:ntest
  cm(indx(i),i) = 1;
  gtm(gtidx(i),i) = 1;
end
% So confusion matrix
confusm = gtm*cm';
% rows are true class, columns are reported (model) class

if doplot
  % Plot full l/hood matrix and confusion matrix
  subplot(121)
  imagesc(lhood); axis xy  % cols are tracks, rows are models
  ylabel('model');
  xlabel('track');

  % Label Y axis with model names
  set(gca,'YTick',1:nlabs);
  set(gca,'YTickLabel',ulabs);

  % Short (up to 2 character) tags for each label
  ulb2 = cell(1, nlabs);
  for i = 1:nlabs
    lab = ulabs{i};
    ulb2{i} = lab(1:min(2, length(lab)));
  end

  % Divide X axis into blocks coming from each artist & label.
  % Each block starts at the first test track carrying that label;
  % labels with no test tracks get no block.
  firsts = zeros(1, 0);
  present = zeros(1, 0);
  for i = 1:nlabs
    pos = find(strcmp(test_labs, ulabs{i}), 1);
    if ~isempty(pos)
      firsts(end+1) = pos;
      present(end+1) = i;
    end
  end
  % tick positions must be increasing, whatever order the list is in
  [firsts, order] = sort(firsts);
  present = present(order);
  xx = [firsts, ntest];
  % one tick in the middle of each block
  set(gca,'XTick',mean([xx(1:end-1);xx(2:end)], 1));
  set(gca,'XTickLabel',ulb2(present));
  % white dividers at the interior block boundaries (if there are any)
  if length(xx) > 2
    hold on; plot([1;1]*xx(2:end-1),[0.5 nlabs+0.5],'-w'); hold off
  end
  colorbar
  
  subplot(122)
  imagesc(confusm); axis xy
  xlabel('recog');
  ylabel('true');
  set(gca,'YTick',1:nlabs);
  set(gca,'YTickLabel',ulb2);
  set(gca,'XTick',1:nlabs);
  set(gca,'XTickLabel',ulb2);  
  % NOTE(review): gcolor is not defined in this file; presumably a
  % project helper that sets the colormap - confirm it is on the path.
  gcolor(gray)
  colorbar
  
end
