function beats = beatOnset(y, fs)

% BEATONSET  Estimate beat locations in the sound y (sampled at fs Hz),
% following BeatOnsetDetector.java in the mashup system.
%
% BEATS is a row vector of indices into the onset-frame sequence (the
% columns of the mel spectrogram S computed below), NOT into y.  S is
% taken with FRAME-sample windows and no overlap, so beat index b covers
% roughly samples (b-1)*frame+1 .. b*frame of y.  Column k of the
% periodicity spectrogram analyses onset frames k .. k+beatFrame-1, and
% the reported index is the start of that window; no half-window offset
% is applied.  NOTE(review): confirm which convention callers expect.
%
% Called with no output arguments, it plots intermediate results instead.

frame = 256;       % FFT size and hop (overlap 0) of the onset spectrogram, in samples
history = 4;       % loop count for the lagged frame differences (see below)
onsAlpha = 0.01;   % coefficient of the one-pole moving-average threshold
onsThresh = 1.5;   % onset when strength exceeds onsThresh times the moving average
beatFrame = 512;   % window length, in onset frames, of the periodicity analysis

% Take a dB spectrogram and map it onto mel bands.  Magnitudes are
% floored at eps so exactly-zero bins (e.g. digital silence) give a finite
% level.  Otherwise log10(0) = -Inf becomes NaN in the mel product
% (0 * -Inf), and that NaN would propagate through the recursive moving
% average below and suppress every later onset.
S = 20*log10(max(abs(specgram(y, frame, fs, [], 0)), eps));
wts = fft2melmx(frame, fs); wts = wts(:,1:end/2+1);
S = wts * S;

% Onset strength: sum over i = 1..history of the positive (rising) part
% of the difference between each frame and the frame i-1 columns earlier
% (columns before the start compare against zeros).
% NOTE(review): for i = 1 the shift is zero, so that term is identically
% 0 and only history-1 lags contribute -- confirm against the Java code.
onsD = zeros(size(S));
for i=1:history
  td = S - [zeros(size(S,1),i-1) S(:,1:end-i+1)];
  td = td .* (td > 0);
  onsD = onsD + td;
end

% One-pole moving average along time (dim 2):
% onsMean(:,n) = onsAlpha*onsD(:,n) + (1-onsAlpha)*onsMean(:,n-1).
onsMean = filter(onsAlpha, [1 -(1-onsAlpha)], onsD, [], 2);

% Binary onset map, per mel band and frame.
ons = onsD > onsMean*onsThresh;

% Periodicity of the per-frame onset count.  Hop is 1 (overlap
% beatFrame-1), so column k of Sons analyses onsum(k:k+beatFrame-1).
onsum = sum(ons, 1);   % explicit dim: stays a per-frame count even with one mel band
Sons = specgram(onsum, beatFrame, 1, [], beatFrame-1);
% $$$ SonsR = real(Sons);
SonsPh = angle(Sons);
Sons = log(max(abs(Sons), eps));   % floored like S so all-zero windows stay finite

% Strongest periodicity bin per column, searched over rows 3..20 of Sons
% (+2 maps the sub-range index back to a row of Sons).
[mx, imx] = max(Sons(3:20,:), [], 1);
imx = imx+2;
N = size(Sons,2);
ind = sub2ind(size(Sons), [imx-1; imx; imx+1], [1:N; 1:N; 1:N]);
% $$$ imx = imx + 2*(imx<=0);      % reflect around 1, turn 0 into 2
% Refine the peak from its two neighbours using parabMax (defined
% elsewhere; presumably a parabolic fit).  mx and pimx are not used
% below -- pimx only appears in a commented-out plot.
[mx, pimx] = parabMax(Sons(ind));
pimx = pimx + imx;

% Within each run of columns sharing the same periodicity bin, mark a beat
% wherever that bin's phase goes from <= 0 to > 0.  A crossing between
% columns k and k+1 is reported at column k.  Crossings that straddle two
% runs are not detected, because each run is differenced separately.
beats = [];
cont = contiguous(imx);
for i=1:size(cont,1)
  ph = SonsPh(imx(cont(i,1)), cont(i,1):cont(i,2));
  phd = diff(ph > 0);
  beat = find(phd > 0) + cont(i,1)-1;
  beats = [beats beat];
end
beatBin = zeros(size(onsum));
beatBin(beats) = 1;


if(nargout <= 0)
  clear beats;
  
  % Draw pictures
  subplot 511, imagesc(S), axis xy, colorbar;
  subplot 512, imagesc(ons), axis xy, colorbar;
  subplot 513, plot(onsum), axis tight, colorbar;
  % $$$ subplot 514, imagesc(SonsPh(1:20,:)>=0), axis xy, colorbar;
  % $$$ subplot 514, imagesc(SonsR(1:20,:)), axis xy, colorbar;
  % $$$ subplot 515, plot(Sons(:,500)), axis xy, colorbar;
  subplot 514, plot(beatBin), axis tight, colorbar;
  % $$$ subplot 515, plot(pimx), axis tight, colorbar;
  subplot 515, imagesc(Sons(1:20,:)), axis xy, colorbar;
end



function c = contiguous(ind)

% List the start and end points of all maximal runs of equal values in
% IND.  IND may be a row or a column vector (it is typically a vector of
% integers with long runs of the same value).  C is an Nx2 matrix:
% ind(c(i,1):c(i,2)) is all the same value, and consecutive rows tile
% IND from first to last element.  An empty IND yields a 0x2 matrix.

% Force a row so the horizontal concatenations below also work for
% column input (a column made [1 brk+1] a dimension mismatch).
ind = ind(:).';
n = length(ind);
if(n == 0)
  % No runs at all.  Previously this returned the bogus region [1 0].
  c = zeros(0, 2);
  return;
end

% brk(j) = k means ind(k) ~= ind(k+1), i.e. a run ends at k and the next
% one starts at k+1.  With no breaks this gives the single run [1 n].
brk = find(diff(ind) ~= 0);
c = [1 brk+1; brk n]';
