function [stateseq, loglik, lattice, tb] = decode_fhmm(hmm, seq, maxRank, beamLogProb, verb)
%function [stateseq, loglik, lattice, tb] = decode_fhmm(hmm, seq, maxRank, beamLogProb, verb)
%
% Factorial Viterbi decode of seq.  Does rank and beam pruning.
%
% hmm is a cell array of hmm structures, stateseq is a cell array of
% state sequences.  Right now this function only works with exactly two
% simultaneous hmms.  Each hmm needs the fields priors (vector of length
% nstates), transmat (nstates x nstates), mu (ndim x nstates) and covar
% (ndim x nstates, diagonal variances).  Assume all hmm params (priors,
% transmat) are logprobs.
%
% seq is an ndim x nobs observation matrix, or a cell array of them.
%
% Optional arguments:
%   maxRank     - rank pruning: only state pairs of the previous frame
%                 scoring at least as well as the maxRank-th best pair
%                 survive.  0 (the default) disables it.
%   beamLogProb - beam pruning: only state pairs of the previous frame
%                 within beamLogProb of the best pair survive, so it
%                 should be <= 0.  -Inf (the default) disables it.
%   verb        - if nonzero, print per-frame pruning statistics.
%   Every hmm1 state and every hmm2 state occurring in a surviving pair
%   is expanded in all combinations.  The best pair is never pruned.
%
% Outputs:
%   stateseq - if seq is a cell array, stateseq{x,h} is the state
%              sequence of hmm{h} for seq{x}; otherwise stateseq{h}.
%   loglik   - Viterbi log likelihood of each sequence (1 x nseq).
%   lattice  - nstates1 x nstates2 x (nobs+1) Viterbi lattice of the last
%              sequence decoded; lattice(:,:,1) holds the log prior of
%              each state pair and lattice(:,:,o) covers observation o-1.
%   tb       - traceback pointers matching lattice: tb(i,j,o) is the
%              linear index (into an nstates1 x nstates2 grid) of the
%              best predecessor pair in lattice(:,:,o-1).
%
% The likelihood of an observation given a state pair is
% p(y|s1,s2) = N1(y)C2(y) + C1(y)N2(y), where N is the diagonal Gaussian
% density and C its CDF (see the Varga Moore paper on factorial HMMs).
%
% 2006-06-06 ronw@ee.columbia.edu

% Copyright (C) 2006-2007 Ron J. Weiss
%
% This program is free software: you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation, either version 3 of the License, or
% (at your option) any later version.
%
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with this program.  If not, see <http://www.gnu.org/licenses/>.

zeroLogProb = -1e200;

% no rank pruning by default
if nargin < 3 || isempty(maxRank)
  maxRank = 0;
end

% no beam pruning by default
if nargin < 4 || isempty(beamLogProb)
  beamLogProb = -Inf;
end

if nargin < 5 || isempty(verb)
  verb = 0;
end

seqnotcell = ~iscell(seq);
if seqnotcell
  seq = {seq};
end
nseq = numel(seq);

if ~iscell(hmm)
  hmm = {hmm};
end
nhmm = length(hmm);

if nhmm ~= 2
  error('length(hmm) must be 2');
end

% Quantities that depend only on the hmms, shared by every sequence.
nstates = [numel(hmm{1}.priors), numel(hmm{2}.priors)];
latsize = nstates;
% Log prior of every state pair: hmm1 states along rows, hmm2 states
% along columns.  (:) accepts priors stored as row or column vectors.
priorLat = repmat(hmm{1}.priors(:), 1, nstates(2)) ...
    + repmat(hmm{2}.priors(:)', nstates(1), 1);

stateseq = cell(nseq, nhmm);
loglik = zeros(1, nseq);

for x = 1:nseq
  [ndim, nobs] = size(seq{x});

  for h = 1:nhmm
    stateseq{x,h} = zeros(1, nobs);
  end

  % Precompute, for every state and frame, the Gaussian log density
  % (llik) and log CDF (lcdf).  Reallocated per sequence because nobs
  % may differ between sequences.
  llik = cell(1, nhmm);
  lcdf = cell(1, nhmm);
  for h = 1:nhmm
    llik{h} = zeros(nstates(h), nobs);
    lcdf{h} = zeros(nstates(h), nobs);
    for s = 1:nstates(h)
      cv = hmm{h}.covar(:,s);
      dzm = seq{x} - repmat(hmm{h}.mu(:,s), 1, nobs);
      llik{h}(s,:) = -.5*((1./cv')*dzm.^2 + ndim*log(2*pi) + sum(log(cv)));

      % The CDF of a diagonal Gaussian is the product of the
      % per-dimension CDFs: scale each dimension elementwise (a matrix
      % product would sum across dimensions before erf) and add the
      % per-dimension log CDFs.
      z = repmat(1./sqrt(2*cv), 1, nobs) .* dzm;
      lcdf{h}(s,:) = sum(log(.5*(1+erf(z)) + eps), 1);
    end
  end

  lattice = zeros(nstates(1), nstates(2), nobs+1) + zeroLogProb;
  % combinations of hmm state priors is first entry of lattice
  % NOTE(review): the first observation is reached from this slice
  % through transmat, so the priors act as a state distribution one step
  % before the first observation -- confirm this is intended.
  lattice(:,:,1) = priorLat;
  tb = zeros(size(lattice));

  % fill in the lattice...
  for o = 2:nobs+1
    if verb
      tic;
    end

    % o is index into lattice, obs is actual observation index
    obs = o - 1;
    prevLat = lattice(:,:,o-1);
    prevMax = max(prevLat(:));

    % beam pruning relative to the best pair of the previous frame
    % (taken from the lattice itself, so it also applies to frame 1)
    threshLogProb = prevMax + beamLogProb;

    % rank pruning: exact log prob of the maxRank-th best pair of the
    % previous frame; only tightens the beam threshold
    if maxRank > 0
      sortedLat = sort(prevLat(:), 'descend');
      rankThresh = sortedLat(min(ceil(maxRank), numel(sortedLat)));
      threshLogProb = max(threshLogProb, rankThresh);
    end

    % never prune the best pair (e.g. if beamLogProb > 0)
    threshLogProb = min(threshLogProb, prevMax);

    % which states are active?  (:)' forces row vectors so the for
    % loops below iterate element by element even for 1-state hmms.
    [s1_tmp, s2_tmp] = find(prevLat >= threshLogProb);
    s1_active = unique(s1_tmp(:))';
    s2_active = unique(s2_tmp(:))';

    % likelihood of observation obs given each possible state combination:
    % likelihood = cdf(s1)N(s2) + N(s1)cdf(s2), summed in the log domain
    % as log(exp(a) + exp(b)) with the max factored out for stability.
    a = repmat(llik{1}(:,obs), 1, nstates(2)) ...
        + repmat(lcdf{2}(:,obs)', nstates(1), 1);
    b = repmat(lcdf{1}(:,obs), 1, nstates(2)) ...
        + repmat(llik{2}(:,obs)', nstates(1), 1);
    m = max(a, b);
    m(isinf(m)) = 0;  % avoid Inf - Inf = NaN
    currllik = m + log(exp(a - m) + exp(b - m));

    % viterbi score in the current frame is the max over all active prev
    % pairs of p(prev pair) * transprob(prev pair, curr pair) * p(obs|curr)
    curLat = lattice(:,:,o);
    curTB = tb(:,:,o);
    for s1 = s1_active
      transprob_s1 = repmat(hmm{1}.transmat(s1,:)', 1, nstates(2));
      for s2 = s2_active
        transprob = transprob_s1 + repmat(hmm{2}.transmat(s2,:), nstates(1), 1);
        tmpLat = prevLat(s1,s2) + transprob + currllik;

        better = tmpLat > curLat;
        curLat(better) = tmpLat(better);
        % traceback
        curTB(better) = sub2ind(latsize, s1, s2);
      end
    end
    lattice(:,:,o) = curLat;
    tb(:,:,o) = curTB;

    if verb
      T = toc;
      disp(['frame ' num2str(obs), ...
            ': active states in hmm1: ' num2str(numel(s1_active)) ...
            ', active states in hmm2: ' num2str(numel(s2_active)) ...
            ' (' num2str(T) ' sec)']);
    end
  end

  % do the traceback:
  finalLat = lattice(:,:,end);
  [loglik(x), best] = max(finalLat(:));
  [s1, s2] = ind2sub(latsize, best);
  for o = nobs+1:-1:2
    % o is index into lattice, obs is actual observation index
    obs = o - 1;

    stateseq{x,1}(obs) = s1;
    stateseq{x,2}(obs) = s2;

    ptr = tb(s1,s2,o);
    if ptr == 0
      % this pair was never reached from any active predecessor
      error('decode_fhmm: no surviving path at frame %d', obs);
    end
    [s1, s2] = ind2sub(latsize, ptr);
  end
end

if seqnotcell
  % keep the state sequences of both hmms
  stateseq = stateseq(1,:);
end

