function [e,ins,del,tru,hh,tt] = beat_score(beats,truth,collar,VERBOSE)
% [e,ins,del,tru,hh,tt] = beat_score(beats,truth,collar,VERBOSE)
%    Compare beat times to ground truth.
%    <beats> is a vector of system-generated beat times (row or column).
%    <truth> is a set of truth tap times: either a cell array holding
%    one vector per subject, or a numeric matrix with one subject per
%    row (non-positive entries are treated as padding and dropped).
%    <collar> (default 0.2) is the hit tolerance as a fraction of the
%    median inter-beat interval of each truth set.
%    <VERBOSE> (default 0; forced to 1 when called with no outputs)
%    prints a one-line summary.
%    Return <e> as the average error rate of <beats> against full
%    ground truth, a simple average of error rate against each
%    individual truth set, where error rate = (inserts + deletions)/true
%    and a hit means the beat time was within (collar * median true
%    period) of a true beat.  <ins>, <del>, <tru> return counts of
%    individual beats inserted, deleted, and true, summed over all
%    truth sets.
%    <hh> returns a histogram of (system - true) beat time differences,
%    quantized to 10ms bins, with bin centers given in <tt>.
% 2012-03-27 Dan Ellis [email protected]

if nargin < 3; collar = 0.2; end
if nargin < 4; VERBOSE = 0; end

if nargout == 0; VERBOSE = 1; end

% Work with a row vector of system beats whatever the input orientation,
% so the pairwise difference matrix below is always ntrue x nsys.
beats = beats(:)';

if isnumeric(truth)
  % make sure truth is always a cell array.  Convert from rows, if any.
  trutharray = truth;
  truth = cell(1, size(trutharray,1));
  for i = 1:size(trutharray,1)
    % keep only positive values in each row (zeros act as padding)
    truth{i} = trutharray(i, trutharray(i,:) > 0);
  end
end

% parameters for histogram
maxt = 0.5; % cover for +/- 0.5 sec around true beats
tres = 0.01; % in 10ms bins
tt = (-maxt):tres:maxt;  % bin centers
h = zeros(1,length(tt));

ntruth = length(truth);
ntrue   = zeros(1, ntruth);
nsys    = zeros(1, ntruth);
inserts = zeros(1, ntruth);
deletes = zeros(1, ntruth);
errrate = zeros(1, ntruth);   % named so as not to shadow builtin error()

for i = 1:ntruth
  % row vector of this subject's true beat times
  truebeats = truth{i}(:)';
  % NOTE(review): with fewer than 2 true beats the median period is NaN,
  % so no comparison succeeds and no inserts/deletes are counted.
  medianperiod = median(diff(truebeats));
  collartime = collar*medianperiod;

  ntrue(i) = length(truebeats);
  nsys(i) = length(beats);

  if ntrue(i) > 0 && nsys(i) > 0
    % We're working with reported beats - true, so late tracking is positive.
    % tdiffs(j,k) = beats(k) - truebeats(j)
    tdiffs = repmat(beats,ntrue(i),1) - repmat(truebeats',1,nsys(i));
    absdiffs = abs(tdiffs);

    % insertions are any system-generated beats more than collartime
    % away from the nearest true beat
    inserts(i) = sum(min(absdiffs,[],1) > collartime);
    % deletions are true beats more than collartime away from nearest
    % system-generated beat
    deletes(i) = sum(min(absdiffs,[],2) > collartime);

    % update the histogram with *all* time differences; values beyond
    % +/- maxt land in the extreme bins, which are trimmed below
    h = h + hist(tdiffs(:),tt);
  else
    % min() over an empty dimension returns empty, which would silently
    % report zero errors.  With no system beats every true beat is
    % deleted; with no true beats every system beat is an insertion.
    inserts(i) = nsys(i);
    deletes(i) = ntrue(i);
  end

  % So error = (insertions + deletes)/ntrue  (Inf/NaN if ntrue is 0)
  errrate(i) = ( inserts(i) + deletes(i) )/ntrue(i);
end

% Average the error
e = mean(errrate);

% Combine error counts
ins = sum(inserts);
del = sum(deletes);
tru = sum(ntrue);

% trim the extreme (overflow) bins from the histogram
hh = h(2:end-1);
tt = tt(2:end-1);

if VERBOSE
  fprintf(1,'Overall error= %5.1f%% (%4d ins, %4d del, %4d true)\n', ...
          100*e, ins, del, tru);
end