00001 00002 // $Id: lab4_vit.H,v 1.7 2009/11/05 14:25:21 stanchen Exp stanchen $ 00003 00004 00005 /** * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ** 00006 * @file lab4_vit.H 00007 * @brief Main loop for Lab 4 Viterbi decoder. 00008 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 00009 00010 #ifndef _LAB4_VIT_H 00011 #define _LAB4_VIT_H 00012 00013 00014 #include <functional> 00015 #include <utility> 00016 #include "util.H" 00017 #include "front_end.H" 00018 00019 00020 /** * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ** 00021 * CPU timer. 00022 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 00023 class Timer 00024 { 00025 public: 00026 /** Ctor; if @p doStart is true, starts timer. **/ 00027 Timer(bool doStart = false) : m_cumSecs(0.0), m_start(-1.0) 00028 { if (doStart) start(); } 00029 00030 /** Returns whether timer is currently on. **/ 00031 bool is_on() const { return m_start != -1.0; } 00032 00033 /** Starts timer. **/ 00034 void start(); 00035 00036 /** Stops timer. Returns cumulative time on so far. **/ 00037 double stop(); 00038 00039 /** Returns cumulative seconds timer has been on so far. 00040 * If timer currently on, doesn't include time since last 00041 * time was started. 00042 **/ 00043 double get_cum_secs() const { return m_cumSecs; } 00044 00045 private: 00046 /** Cumulative seconds timer has been on so far. **/ 00047 double m_cumSecs; 00048 00049 /** If timer on, last time timer was started; -1 otherwise. **/ 00050 double m_start; 00051 }; 00052 00053 00054 /** * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ** 00055 * Struct for holding a backtrace word tree. 00056 * 00057 * This object can be used to hold a list of word sequences in 00058 * the form of a tree. Each node in the tree is assigned an 00059 * integer index, and each arc in the tree is labeled with an 00060 * integer index corresponding to a word label. Each node in 00061 * the tree can be viewed as representing the word sequence 00062 * labeling the path from the root to that node. 00063 * 00064 * To get the index of the root node, use #get_root_node(). 00065 * To find/create the node reached by extending a node with a word, 00066 * use #insert_node(). To recover the word sequence a node 00067 * corresponds to, you can use #get_parent_node() and #get_last_word(). 00068 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 00069 class WordTree 00070 { 00071 private: 00072 typedef pair<int, unsigned> Node; 00073 00074 public: 00075 /** Ctor; initializes object to just contain root node. **/ 00076 WordTree() { clear(); } 00077 00078 /** Clears object except for root node. **/ 00079 void clear() { m_nodeArray.clear(); m_nodeHash.clear(); 00080 insert_node((unsigned) -1, (unsigned) -1); } 00081 00082 /** Returns number of nodes in tree. **/ 00083 unsigned size() const { return m_nodeArray.size(); } 00084 00085 /** Returns index of root node. **/ 00086 unsigned get_root_node() const { return 0; } 00087 00088 /** Given an existing node @p parentIdx, returns index of 00089 * child node reached when traversing arc labeled with 00090 * word index @p lastWord. If node doesn't exist, it is created. 00091 **/ 00092 unsigned insert_node(unsigned parentIdx, unsigned lastWord) 00093 { 00094 Node key(parentIdx, lastWord); 00095 map<Node, unsigned>::const_iterator itemPtr = m_nodeHash.find(key); 00096 if (itemPtr != m_nodeHash.end()) 00097 return itemPtr->second; 00098 00099 m_nodeArray.push_back(Node(parentIdx, lastWord)); 00100 unsigned nodeIdx = m_nodeArray.size() - 1; 00101 m_nodeHash[key] = nodeIdx; 00102 return nodeIdx; 00103 } 00104 00105 /** Returns index of parent node for node with index @p nodeIdx. **/ 00106 unsigned get_parent_node(unsigned nodeIdx) const 00107 { return m_nodeArray[nodeIdx].first; } 00108 00109 /** Returns index of word labeling arc from node @p nodeIdx 00110 * to its parent node. 00111 **/ 00112 unsigned get_last_word(unsigned nodeIdx) const 00113 { return m_nodeArray[nodeIdx].second; } 00114 00115 private: 00116 /** Array of nodes in tree. */ 00117 vector<Node> m_nodeArray; 00118 00119 /** Hash table, for fast node lookup. */ 00120 map<Node, unsigned> m_nodeHash; 00121 }; 00122 00123 00124 /** * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ** 00125 * Cell in dynamic programming chart for Viterbi algorithm. 00126 * 00127 * Holds Viterbi log prob; and arc ID of best incoming arc for backtrace. 00128 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 00129 class FrameCell 00130 { 00131 public: 00132 /** Ctor; inits log prob to g_zeroLogProb and node index to 0. **/ 00133 FrameCell() : m_logProb(g_zeroLogProb), m_nodeIdx(0) { } 00134 00135 #ifndef SWIG 00136 #ifndef DOXYGEN 00137 // Hack; for bug in matrix<> class in boost 1.32. 00138 explicit FrameCell(int) : m_logProb(g_zeroLogProb), m_nodeIdx(0) { } 00139 #endif 00140 #endif 00141 00142 /** Sets associated log prob and WordTree node index. **/ 00143 void assign(double logProb, unsigned nodeIdx) 00144 { m_logProb = logProb; m_nodeIdx = nodeIdx; } 00145 00146 /** Returns log prob of cell. **/ 00147 double get_log_prob() const { return m_logProb; } 00148 00149 /** Returns node index in WordTree for best incoming word sequence. **/ 00150 unsigned get_node_index() const { return m_nodeIdx; } 00151 00152 private: 00153 /** Forward Viterbi logprob. **/ 00154 float m_logProb; 00155 00156 /** Holds node index in WordTree for best incoming word sequence. **/ 00157 unsigned m_nodeIdx; 00158 }; 00159 00160 00161 /** * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ** 00162 * Struct for holding (active) cells at a frame in the DP chart. 00163 * 00164 * Stores a list of cells of type FrameCell. 00165 * 00166 * To find a cell and create it if it doesn't exist, use #insert_cell(). 00167 * To look up cells by state index, use #get_cell_by_state() and #has_cell(). 00168 * 00169 * To loop through all cells in increasing state order, use 00170 * #reset_iteration() and #get_next_state(). 00171 * 00172 * To loop through all cells (in no particular order), use 00173 * #get_cell_by_index() (and #size() to determine how many cells there are). 00174 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 00175 class FrameData 00176 { 00177 public: 00178 /** Ctor; initializes object to be empty. 00179 * The argument @p stateCnt should be the number of states 00180 * in the graph. 00181 **/ 00182 FrameData(unsigned stateCnt) : 00183 m_stateMap(stateCnt, -1), m_heapSize(-1) { } 00184 00185 /** Clears object. **/ 00186 void clear() 00187 { 00188 vector<unsigned>::const_iterator endPtr = m_activeStates.end(); 00189 for (vector<unsigned>::const_iterator curPtr = 00190 m_activeStates.begin(); curPtr != endPtr; ++curPtr) 00191 { 00192 assert(m_stateMap[*curPtr] >= 0); 00193 m_stateMap[*curPtr] = -1; 00194 } 00195 m_activeStates.clear(); 00196 m_cellArray.clear(); 00197 m_heapSize = -1; 00198 } 00199 00200 /** Returns number of active cells. **/ 00201 unsigned size() const { return m_cellArray.size(); } 00202 00203 /** Returns whether no active cells. **/ 00204 bool empty() const { return m_cellArray.empty(); } 00205 00206 /** Returns number of states in corresponding graph. **/ 00207 unsigned get_state_count() const { return m_stateMap.size(); } 00208 00209 00210 /** Returns cell corresponding to state with index @p stateIdx. 00211 * The cell must already exist. You can check whether a cell 00212 * exists using #has_cell(). 00213 **/ 00214 const FrameCell& get_cell_by_state(unsigned stateIdx) const 00215 { assert(m_stateMap[stateIdx] >= 0); 00216 return m_cellArray[m_stateMap[stateIdx]]; } 00217 00218 /** Returns whether cell exists for state @p stateIdx. **/ 00219 bool has_cell(unsigned stateIdx) const 00220 { return m_stateMap[stateIdx] >= 0; } 00221 00222 /** Returns cell for state @p stateIdx, creating it if absent. 00223 * If called in the middle of looping through states 00224 * (see #reset_iteration(), #get_next_state()), 00225 * will be added into list of states not yet looped through. 00226 **/ 00227 FrameCell& insert_cell(unsigned stateIdx) 00228 { 00229 int cellIdx = m_stateMap[stateIdx]; 00230 if (cellIdx >= 0) 00231 return m_cellArray[cellIdx]; 00232 m_activeStates.push_back(stateIdx); 00233 if (m_heapSize >= 0) 00234 { 00235 std::swap(m_activeStates.back(), m_activeStates[m_heapSize]); 00236 ++m_heapSize; 00237 push_heap(m_activeStates.begin(), m_activeStates.begin() + 00238 m_heapSize, greater<unsigned>()); 00239 } 00240 m_cellArray.push_back(FrameCell()); 00241 m_stateMap[stateIdx] = m_cellArray.size() - 1; 00242 return m_cellArray.back(); 00243 } 00244 00245 00246 /** Returns cell with index @p cellIdx, where cells are numbered 00247 * in an arbitrary order. 00248 * Cells are numbered upwards from 0. There is no easy way to 00249 * recover the state index corresponding to a cell retrieved 00250 * by this method. However, this method may be useful for 00251 * computing pruning thresholds. 00252 **/ 00253 const FrameCell& get_cell_by_index(unsigned cellIdx) const 00254 { return m_cellArray[cellIdx]; } 00255 00256 /** Returns state index for @p idx-th active state, where states 00257 * are numbered in no particular order. 00258 * If any non-read-only methods are called, the numbering 00259 * of states may change. 00260 **/ 00261 unsigned get_state_by_index(unsigned idx) const 00262 { return m_activeStates[idx]; } 00263 00264 00265 /** Prepares object for iterating through states in upward order. 00266 * See #get_next_state() to do actual iteration. Specifically, 00267 * puts all active states in list of states not yet iterated through. 00268 **/ 00269 void reset_iteration() 00270 { 00271 make_heap(m_activeStates.begin(), m_activeStates.end(), 00272 greater<unsigned>()); 00273 m_heapSize = m_activeStates.size(); 00274 } 00275 00276 /** Returns lowest-numbered state not yet iterated through, 00277 * or -1 if no more active states. 00278 **/ 00279 int get_next_state() 00280 { 00281 assert(m_heapSize >= 0); 00282 if (!m_heapSize) 00283 return -1; 00284 pop_heap(m_activeStates.begin(), m_activeStates.begin() + 00285 m_heapSize, greater<unsigned>()); 00286 --m_heapSize; 00287 return m_activeStates[m_heapSize]; 00288 } 00289 00290 /** Swap operation. **/ 00291 void swap(FrameData& frmData) 00292 { 00293 m_activeStates.swap(frmData.m_activeStates); 00294 m_cellArray.swap(frmData.m_cellArray); 00295 m_stateMap.swap(frmData.m_stateMap); 00296 std::swap(m_heapSize, frmData.m_heapSize); 00297 } 00298 00299 private: 00300 /** The states that are active, in no particular order. **/ 00301 vector<unsigned> m_activeStates; 00302 00303 /** Array of DP cells for active states. **/ 00304 vector<FrameCell> m_cellArray; 00305 00306 /** For each state, location in m_cellArray if active, -1 otherwise. 00307 * That is, for an inactive state @c stateIdx, @c m_stateMap[stateIdx] 00308 * will be -1. For an active state @c stateIdx, its DP cell can be 00309 * found at @c m_cellArray[m_stateMap[stateIdx]]. 00310 **/ 00311 vector<int> m_stateMap; 00312 00313 /** If nonnegative, how many states in heap in m_activeStates. 00314 * That is, we loop through states in a frame in order by 00315 * keeping the states in a "heap". A "heap" is formed 00316 * by ordering the elements in an array in such a way that 00317 * it is easy to keep track of the lowest-numbered element. 00318 * Before we begin looping through states in order, 00319 * we arrange all states in m_activeStates as a heap. 00320 * We then repeatedly grab the lowest-numbered state in the heap, 00321 * and then remove the state from the heap. 00322 * The array m_activeStates stays the same size 00323 * during this, but the part of m_activeStates 00324 * arranged as a heap becomes smaller and smaller. 00325 **/ 00326 int m_heapSize; 00327 }; 00328 00329 00330 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 00331 * 00332 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 00333 00334 /** Routine for copying debugging info from @p curFrame into @p chart. **/ 00335 void copy_frame_to_chart(const FrameData& curFrame, unsigned frmIdx, 00336 matrix<FrameCell>& chart); 00337 00338 /** Routine for Viterbi backtrace; token passing. **/ 00339 double viterbi_backtrace_word_tree(const Graph& graph, 00340 const FrameData& lastFrame, const WordTree& wordTree, 00341 vector<int>& outLabelList); 00342 00343 00344 /** * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ** 00345 * Encapsulation of main loop for Viterbi decoding. 00346 * 00347 * Holds global variables and has routines for initializing variables 00348 * and updating them for each utterance. 00349 * We do this so that we can call this code from Java as well. 00350 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 00351 class Lab4VitMain 00352 { 00353 public: 00354 /** Initialize all data given parameters. **/ 00355 Lab4VitMain(const map<string, string>& params); 00356 00357 /** Called at the beginning of processing each utterance. 00358 * Returns whether at EOF. 00359 **/ 00360 bool init_utt(); 00361 00362 /** Called at the end of processing each utterance. **/ 00363 void finish_utt(double logProb); 00364 00365 /** Called at end of program. **/ 00366 void finish(); 00367 00368 00369 /** Returns decoding graph/HMM. **/ 00370 const Graph& get_graph() const { return m_graph; } 00371 00372 /** Returns matrix of GMM log probs for each frame. **/ 00373 const matrix<double>& get_gmm_probs() const { return m_gmmProbs; } 00374 00375 /** Returns vector to place decoded labels in. **/ 00376 vector<int>& get_label_list() { return m_labelList; } 00377 00378 /** Returns acoustic weight. **/ 00379 double get_acous_wgt() const { return m_acousWgt; } 00380 00381 /** Returns beam width, log base e. **/ 00382 double get_log_prob_beam() const { return m_logProbBeam; } 00383 00384 /** Returns rank beam; 0 signals no rank pruning. **/ 00385 unsigned get_state_count_beam() const { return m_stateCntBeam; } 00386 00387 /** Returns full DP chart; only used for storing diagnostic info. **/ 00388 matrix<FrameCell>& get_chart() { return m_chart; } 00389 00390 private: 00391 /** Program parameters. **/ 00392 map<string, string> m_params; 00393 00394 /** Front end. **/ 00395 FrontEnd m_frontEnd; 00396 00397 /** Acoustic model. **/ 00398 shared_ptr<GmmScorer> m_gmmScorerPtr; 00399 00400 /** Stream for reading audio data. **/ 00401 ifstream m_audioStrm; 00402 00403 /** Graph/HMM. **/ 00404 Graph m_graph; 00405 00406 /** Stream for writing decoding output to. **/ 00407 ofstream m_outStrm; 00408 00409 /** Acoustic weight. **/ 00410 double m_acousWgt; 00411 00412 /** Beam width, log base e. **/ 00413 double m_logProbBeam; 00414 00415 /** Rank beam; 0 signals no rank pruning. **/ 00416 unsigned m_stateCntBeam; 00417 00418 /** ID string for current utterance. **/ 00419 string m_idStr; 00420 00421 /** Input audio for current utterance. **/ 00422 matrix<double> m_inAudio; 00423 00424 /** Feature vectors for current utterance. **/ 00425 matrix<double> m_feats; 00426 00427 /** GMM probs for current utterance. **/ 00428 matrix<double> m_gmmProbs; 00429 00430 /** Decoded output. **/ 00431 vector<int> m_labelList; 00432 00433 /** DP chart for current utterance, for returning diagnostic info. **/ 00434 matrix<FrameCell> m_chart; 00435 00436 /** Total frames processed so far. **/ 00437 int m_totFrmCnt; 00438 00439 /** Total log prob of utterances processed so far. **/ 00440 double m_totLogProb; 00441 00442 /** Timer for front end processing. **/ 00443 Timer m_frontEndTimer; 00444 00445 /** Timer for GMM prob computation. **/ 00446 Timer m_gmmTimer; 00447 00448 /** Timer for search computation. **/ 00449 Timer m_searchTimer; 00450 }; 00451 00452 00453 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 00454 * 00455 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 00456 00457 #endif 00458 00459