function [cpts] = learnNet(inputs,parents)
%
% function [cpts] = learnNet(inputs,parents)
%
% Maximum-likelihood estimation (counting) of the conditional probability
% tables of a discrete Bayesian network whose structure is given.
%
% parents is an MxM matrix. A nonzero entry parents(i,j) means that node j
% (column) is a parent of node i (row); zero means it is not a parent.
% The caller must ensure that:
%   - nodes are in topological order (parents come before children),
%   - the graph is a valid DAG,
%   - no node is its own parent.
%
% inputs is an NxM matrix of samples where the values in the m'th column
% are integers in 0..D_m. Node m's cardinality is max(inputs(:,m))+1.
%
% cpts is an M-by-max(cptsize) matrix. Row i holds node i's conditional
% probability hypercube rasterized as a vector. Node i's own value varies
% fastest, then its last parent, ..., and its first parent varies slowest.
% Each consecutive run of cards(i) entries is therefore P(node i | one
% parent configuration). A parent configuration that never occurs in the
% data gets an all-zero run. Entries beyond node i's CPT size stay at -1.
%

[N,M] = size(inputs);
if N == 0
  error('learnNet: inputs must contain at least one sample.');
end

% Cardinality of each node. Reduce along dimension 1 explicitly so that a
% single sample (N==1) still yields one cardinality per column.
cards = max(inputs,[],1) + 1;

% CPT size of each node = product of the cardinalities of the node and its
% parents. Uses the same parent set as the counting loop below.
cptsize = zeros(M,1);
for i=1:M
  cptsize(i) = prod(cards([find(parents(i,:)) i]));
end

% Initialize to -1; a negative value marks an entry that is not part of
% the node's CPT (rows shorter than max(cptsize) keep this padding).
cpts = -1.0*ones(M,max(cptsize));

% Max likelihood estimation is just counting.
for i=1:M
  pa = find(parents(i,:));   % parent indices ('pa' avoids shadowing builtin pi)
  phi = [pa i];

  % Enumerate all joint configurations of [parents node], one per row.
  % Each pass prepends a slower-varying column, so the node's own value
  % (processed first) varies fastest. Works for any cardinalities.
  configs = [];
  c = 1;
  for k=fliplr(phi)
    x = repmat(0:(cards(k)-1), c, 1); % assume discrete values run from 0 up
    configs = [x(:) repmat(configs,cards(k),1)];
    c = c*cards(k);
  end

  for k=1:size(configs,1)
    % Samples matching the full configuration of [parents node].
    mphi = sum(all(inputs(:,phi) == repmat(configs(k,:),N,1), 2));
    % Samples matching only the parent part (all N samples if no parents).
    mpa  = sum(all(inputs(:,pa) == repmat(configs(k,1:(end-1)),N,1), 2));

    if mpa>0
      cpts(i,k) = mphi/mpa;
    else
      cpts(i,k) = 0;   % parent configuration never observed
    end
  end
end

% Normalize the cpts, one block of cards(i) entries (one parent
% configuration) at a time. Blocks whose parent configuration was never
% observed sum to zero and are left as zeros instead of becoming NaN.
for i=1:M
  for index=1:cards(i):cptsize(i)
    block = index:(index+cards(i)-1);
    if (min(cpts(i,block))<0.0)
      fprintf(1,'BAD: Found negative for cpt for node %d at index [%d,%d].\n',i,index,index+cards(i)-1);
    end
    tot = sum(cpts(i,block));
    if (tot>0)
      cpts(i,block) = cpts(i,block)/tot;
    end
  end
end

