Main Page   Packages   Class Hierarchy   Compound List   File List   Compound Members  

HMMComposer.java

00001 /*
00002  *  Copyright 2006-2007 Columbia University.
00003  *
00004  *  This file is part of MEAPsoft.
00005  *
00006  *  MEAPsoft is free software; you can redistribute it and/or modify
00007  *  it under the terms of the GNU General Public License version 2 as
00008  *  published by the Free Software Foundation.
00009  *
00010  *  MEAPsoft is distributed in the hope that it will be useful, but
00011  *  WITHOUT ANY WARRANTY; without even the implied warranty of
00012  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00013  *  General Public License for more details.
00014  *
00015  *  You should have received a copy of the GNU General Public License
00016  *  along with MEAPsoft; if not, write to the Free Software
00017  *  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
00018  *  02110-1301 USA
00019  *
00020  *  See the file "COPYING" for the text of the license.
00021  */
00022 
00023 package com.meapsoft.composers;
00024 
00025 import gnu.getopt.Getopt;
00026 
00027 import java.io.IOException;
00028 import java.util.Iterator;
00029 import java.util.Vector;
00030 import java.util.Arrays;
00031 import java.util.Random;
00032 
00033 import com.meapsoft.ChunkDist;
00034 import com.meapsoft.EuclideanDist;
00035 import com.meapsoft.DSP;
00036 import com.meapsoft.EDLChunk;
00037 import com.meapsoft.EDLFile;
00038 import com.meapsoft.FeatChunk;
00039 import com.meapsoft.FeatFile;
00040 import com.meapsoft.MinHeap;
00041 import com.meapsoft.ParserException;
00042 
00043 
00058 public class HMMComposer extends VQComposer
00059 {
00060     public static String description = "HMMComposer uses a features file to train a simple statistical model of a song and uses it to randomly generate a new sequence of chunks.  This works best when used with chunks created by the beat detector.";
00061 
00062     private int sequenceLength = 50;
00063     // prior probability of starting in a given state
00064     private double[] startProbs;
00065     // probability of transitioning from one state to another
00066     private double[][] transitionMatrix;
00067 
00068     public HMMComposer(FeatFile trainFN, EDLFile outFN)
00069     {
00070                 super(trainFN, outFN);
00071         }
00072 
00073         public void printUsageAndExit() 
00074         {
00075                 System.out.println("Usage: HMMComposer [-options] features.feat \n\n" + 
00076                            "  where options include:\n" + 
00077                            "    -o output_file   the file to write the output to (defaults to "+outFileName+")\n" +
00078                "    -g               debug mode\n" +
00079                "    -q codebook_size number of states in the HMM (defaults to "+cbSize+")\n" + 
00080                "    -b nbeats        number of beats each HMM state should contain (defaults to "+beatsPerCodeword+")\n" + 
00081                "    -s sequence_len  length of chunk sequence to generate (defaults to "+sequenceLength+")."); 
00082         printCommandLineOptions('i');
00083         printCommandLineOptions('d');
00084         printCommandLineOptions('c');
00085                 System.out.println();
00086                 System.exit(0);
00087         }
00088 
00089         public HMMComposer(String[] args) 
00090         {
00091         // java demands that we do this
00092         super(null, null);
00093 
00094                 if(args.length == 0)
00095                         printUsageAndExit();
00096 
00097                 Vector features = new Vector();
00098 
00099                 // Parse arguments
00100                 String argString = "o:c:q:i:gd:s:b:";
00101         featdim = parseFeatDim(args, argString);
00102         dist = parseChunkDist(args, argString, featdim);
00103         parseCommands(args, argString);
00104 
00105                 Getopt opt = new Getopt("HMMComposer", args, argString);
00106                 opt.setOpterr(false);
00107         
00108                 int c = -1;
00109                 while ((c =opt.getopt()) != -1) 
00110                 {
00111                         switch(c) 
00112                         {
00113                         case 'o':
00114                                 outFileName = opt.getOptarg();
00115                                 break;
00116                         case 'g':
00117                                 debug = true;
00118                                 break;
00119                         case 'q':
00120                                 cbSize = Integer.parseInt(opt.getOptarg());
00121                                 break;
00122                         case 'b':
00123                                 beatsPerCodeword = Integer.parseInt(opt.getOptarg());
00124                                 break;
00125                         case 's':
00126                                 sequenceLength = Integer.parseInt(opt.getOptarg());
00127                                 break;
00128             case 'c':  // already handled above
00129                 break;
00130             case 'd':  // already handled above
00131                 break;
00132             case 'i':  // already handled above
00133                 break;
00134                         case '?':
00135                                 printUsageAndExit();
00136                                 break;
00137                         default:
00138                                 System.out.print("getopt() returned " + c + "\n");
00139                         }
00140                 }
00141         
00142                 // parse arguments
00143                 int ind = opt.getOptind();
00144                 if(ind > args.length)
00145                         printUsageAndExit();
00146         
00147                 trainFile = new FeatFile(args[args.length-1]);
00148                 outFile = new EDLFile(outFileName);
00149 
00150                 System.out.println("Composing " + outFileName + 
00151                                                    " from " +  args[args.length-1] + ".");
00152         }
00153 
00154     public void setSequenceLength(int len)
00155     {
00156         sequenceLength = len;
00157     }
00158 
00159     private void learnTransitionMatrix(FeatFile trainFile)
00160     {
00161         startProbs = new double[cbSize];
00162         Arrays.fill(startProbs, 0);
00163 
00164         transitionMatrix = new double[cbSize][cbSize];
00165         for(int x = 0; x < cbSize; x++)
00166             Arrays.fill(transitionMatrix[x], 0);
00167 
00168         // sort the chunks in order of increasing startTime, while
00169         // keeping all chunks from the same srcFile together
00170         trainFile = (FeatFile)trainFile.clone();
00171         trainFile.chunks = new MinHeap(trainFile.chunks);
00172         ((MinHeap)trainFile.chunks).sort();
00173 
00174         int ndat = trainFile.chunks.size();
00175         int prevState = -1;
00176         String lastSrcFile = "";
00177         for(int n = 0; n < ndat; n++) 
00178         {
00179             FeatChunk ch = (FeatChunk)trainFile.chunks.get(n);
00180             
00181             int currState = quantizeChunk(ch);
00182 
00183             // is this the beginning of a srcFile?
00184             if(!lastSrcFile.equals(ch.srcFile))
00185             {
00186                 lastSrcFile = ch.srcFile;
00187 
00188                 startProbs[currState] += 1.0;
00189             }
00190             else
00191                 transitionMatrix[prevState][currState] += 1.0;
00192 
00193             prevState = currState;
00194         }
00195 
00196         // normalize probabilities
00197         double s = DSP.sum(startProbs);
00198         for(int x = 0; x < startProbs.length; x++)
00199             startProbs[x] /= s;
00200 
00201         for(int x = 0; x < transitionMatrix.length; x++)
00202         {
00203             s = DSP.sum(transitionMatrix[x]);
00204 
00205             for(int y = 0; y < transitionMatrix[x].length; y++)
00206                 transitionMatrix[x][y] /= s;
00207         }
00208 
00209         if(debug)
00210         {
00211             FeatFile f = new FeatFile("tmp");
00212             f.chunks = templateChunks;
00213             DSP.imagesc(f.getFeatures(), "codebook");
00214             DSP.imagesc(transitionMatrix, "transitionMatrix");
00215             DSP.imagesc(startProbs, "startProbs");
00216         }
00217     }
00218 
00219     private int multinomialSample(double uniformRV, double[] prob)
00220     {
00221         if(uniformRV <= prob[0])
00222             return 0;
00223 
00224         double[] cdf = DSP.cumsum(prob);
00225 
00226         for(int x = 1; x < cdf.length; x++)
00227             if(uniformRV > cdf[x-1] && uniformRV <= cdf[x])
00228                 return x;
00229 
00230         return prob.length;
00231     }
00232 
00233     public EDLFile compose()
00234     {
00235         learnCodebook(trainFile);
00236 
00237         learnTransitionMatrix(trainFile);
00238 
00239         // generate a sequence of chunks from the codebook and
00240         // transition matrix
00241 
00242         Random rand = new Random();
00243         double currTime = 0;
00244 
00245         // get first chunk from startProbs
00246         int lastIdx = multinomialSample(rand.nextDouble(), startProbs);
00247         EDLChunk nc = new EDLChunk((FeatChunk)templateChunks.get(lastIdx),
00248                                    currTime);
00249         outFile.chunks.add(nc);
00250         currTime += nc.length;
00251         progress.setValue(progress.getValue()+1);
00252 
00253         // use transitionMatrix for the remaining chunks
00254         for(int x = 1; x < sequenceLength; x++)
00255         {
00256             int currIdx = multinomialSample(rand.nextDouble(), 
00257                                             transitionMatrix[lastIdx]);
00258             
00259             nc = new EDLChunk((FeatChunk)templateChunks.get(currIdx),
00260                               currTime);
00261             outFile.chunks.add(nc);
00262             currTime += nc.length;
00263             progress.setValue(progress.getValue()+1);
00264             
00265             lastIdx = currIdx;
00266         }
00267 
00268         return outFile;
00269     }
00270 
00271 
00272         public static void main(String[] args) 
00273         {
00274                 HMMComposer m = new HMMComposer(args);
00275                 long startTime = System.currentTimeMillis();
00276                 m.run();
00277                 System.out.println("Done. Took " +
00278                                                    ((System.currentTimeMillis() - startTime)/1000.0)
00279                                                    + "s");
00280                 System.exit(0);
00281         }
00282 }

Generated on Tue Feb 6 19:02:26 2007 for MEAPsoft by doxygen1.2.18