00001 00002 // $Id: util.H,v 1.59 2009/10/29 15:28:54 stanchen Exp $ 00003 00004 /** * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ** 00005 * @file util.H 00006 * @brief I/O routines, GmmSet, and Graph classes. 00007 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 00008 00009 #ifndef _UTIL_H 00010 #define _UTIL_H 00011 00012 00013 #include <cassert> 00014 #include <cfloat> 00015 #include <cmath> 00016 #include <algorithm> 00017 #include <fstream> 00018 #include <iostream> 00019 #include <map> 00020 #include <stdexcept> 00021 #include <string> 00022 #include <vector> 00023 #include <boost/format.hpp> 00024 #include <boost/numeric/ublas/matrix.hpp> 00025 #include <boost/shared_ptr.hpp> 00026 00027 00028 using namespace std; 00029 00030 #ifndef SWIG 00031 using boost::format; 00032 using boost::str; 00033 using boost::numeric::ublas::matrix; 00034 using boost::shared_ptr; 00035 #endif 00036 00037 00038 /** * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ** 00039 * @name Math stuff. 00040 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 00041 00042 //@{ 00043 00044 /** This value can be used to represent the logprob of a "zero" prob. 00045 * Theoretically, log 0 is negative infinity which we can't store, 00046 * so we can use this very large negative value instead. 00047 **/ 00048 const double g_zeroLogProb = -FLT_MAX / 2.0; 00049 00050 /** Adds the log probs held in @p logProbList, returning answer as log prob. 00051 * That is, let's say we have a list of probability values, the logs of 00052 * which are stored in @p logProbList. Then, this routine returns the 00053 * log of the sum of those probability values. 00054 * Logarithms are base <i>e</i>. 00055 **/ 00056 double add_log_probs(const vector<double>& logProbList); 00057 00058 /** Does in-place real FFT. 00059 * For inputs <tt>vals[i]</tt>, i = 0, ..., N-1 with sample period T, 00060 * on return the real and imaginary parts of the FFT value for frequency 00061 * i/NT are held in the outputs <tt>vals[2*i]</tt> and <tt>vals[2*i+1]</tt>. 00062 **/ 00063 void real_fft(vector<double>& vals); 00064 00065 /** Sets @p vec to be equal to the @p rowIdx-th row of @p mat. 00066 * Rows are numbered starting from 0. 00067 **/ 00068 void copy_matrix_row_to_vector(const matrix<double>& mat, unsigned rowIdx, 00069 vector<double>& vec); 00070 00071 /** Sets the @p rowIdx-th row of @p mat to @p vec; sizes must match. 00072 * Rows are numbered starting from 0. 00073 **/ 00074 void copy_vector_to_matrix_row(const vector<double>& vec, 00075 matrix<double>& mat, unsigned rowIdx); 00076 00077 //@} 00078 00079 00080 /** * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ** 00081 * @name Command-line parsing and parameter lookup routines. 00082 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 00083 00084 //@{ 00085 00086 /** Type of object used for holding program parameters. 00087 * Declaration needed for hack to get default arguments to work. 00088 **/ 00089 typedef map<string, string> ParamsType; 00090 00091 #ifndef SWIG 00092 /** Given cmd line arguments @p argv, parses flags of the 00093 * form <tt>--<flag> <val></tt> and places in flag-to-value 00094 * map @p params. 00095 * Expects same @p argv value as passed to <tt>main()</tt>. 00096 * Existing values in @p params are not erased (unless overriden 00097 * in @p argv). 00098 **/ 00099 void process_cmd_line(const char** argv, map<string, string>& params); 00100 #endif 00101 00102 /** Like process_cmd_line(), but expects arguments as string vector. **/ 00103 void process_cmd_line(const vector<string>& argList, 00104 map<string, string>& params); 00105 00106 /** Like process_cmd_line(), but expects space-separated arguments 00107 * in single string. 00108 **/ 00109 void process_cmd_line(const string& argStr, map<string, string>& params); 00110 00111 /** Returns value of boolean parameter @p name from parameter map @p params. 00112 * If not present, returns @p defaultVal. 00113 **/ 00114 bool get_bool_param(const map<string, string>& params, const string& name, 00115 bool defaultVal = false); 00116 00117 /** Like get_bool_param(), but for integer parameters. **/ 00118 int get_int_param(const map<string, string>& params, const string& name, 00119 int defaultVal = 0); 00120 00121 /** Like get_bool_param(), but for floating-point parameters. **/ 00122 double get_float_param(const map<string, string>& params, const string& name, 00123 double defaultVal = 0.0); 00124 00125 /** Like get_bool_param(), but for string parameters. **/ 00126 string get_string_param(const map<string, string>& params, const string& name, 00127 const string& defaultVal = string()); 00128 00129 /** Like get_string_param(), but throws exception if parameter absent. **/ 00130 string get_required_string_param(const map<string, string>& params, 00131 const string& name); 00132 00133 //@} 00134 00135 00136 /** * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ** 00137 * @name Vector/matrix I/O routines. 00138 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 00139 00140 //@{ 00141 00142 /** Splits @p inStr into space-separated tokens; places in @p outList. **/ 00143 void split_string(const string& inStr, vector<string>& outList); 00144 00145 /** Reads a list of strings, one to a line, from file @p fileName and 00146 * places in @p strList. 00147 **/ 00148 void read_string_list(const string& fileName, vector<string>& strList); 00149 00150 /** Reads matrix of floating-point numbers from stream @p inStrm in Matlab 00151 * text format and places in @p mat. Expects optional matrix header, and 00152 * then one row per line. If argument @p name is provided, checks 00153 * name associated with matrix matches and throws exception if doesn't. 00154 * Returns name given in matrix header, or empty string if none provided. 00155 **/ 00156 string read_float_matrix(istream& inStrm, matrix<double>& mat, 00157 const string& name = string()); 00158 00159 /** Like read_float_matrix(), but for float vectors. **/ 00160 string read_float_vector(istream& inStrm, vector<double>& vec, 00161 const string& name = string()); 00162 00163 /** Like read_float_matrix(), but for integer matrices. **/ 00164 string read_int_matrix(istream& inStrm, matrix<int>& mat, 00165 const string& name = string()); 00166 00167 /** Like read_float_matrix(), but for integer vectors. **/ 00168 string read_int_vector(istream& inStrm, vector<int>& vec, 00169 const string& name = string()); 00170 00171 /** Like read_float_matrix(), but reads from file @p fileName instead 00172 * of stream. 00173 **/ 00174 void read_float_matrix(const string& fileName, matrix<double>& mat); 00175 00176 /** Like read_float_vector(), but reads from file @p fileName instead 00177 * of stream. 00178 **/ 00179 void read_float_vector(const string& fileName, vector<double>& vec); 00180 00181 /** Like read_int_matrix(), but reads from file @p fileName instead 00182 * of stream. 00183 **/ 00184 void read_int_matrix(const string& fileName, matrix<int>& mat); 00185 00186 /** Like read_int_vector(), but reads from file @p fileName instead 00187 * of stream. 00188 **/ 00189 void read_int_vector(const string& fileName, vector<int>& vec); 00190 00191 /** Writes floating-point matrix @p mat to stream @p outStrm in Matlab text 00192 * format. If the argument @p name is provided, this name will 00193 * be written in the matrix header (and is the name 00194 * that will be assigned to the matrix if loaded in octave). 00195 **/ 00196 void write_float_matrix(ostream& outStrm, const matrix<double>& mat, 00197 const string& name = string()); 00198 00199 /** Like write_float_matrix(), but for float vectors. **/ 00200 void write_float_vector(ostream& outStrm, const vector<double>& vec, 00201 const string& name = string()); 00202 00203 /** Like write_float_matrix(), but for integer matrices. **/ 00204 void write_int_matrix(ostream& outStrm, const matrix<int>& mat, 00205 const string& name = string()); 00206 00207 /** Like write_float_matrix(), but for integer vectors. **/ 00208 void write_int_vector(ostream& outStrm, const vector<int>& vec, 00209 const string& name = string()); 00210 00211 /** Like write_float_matrix(), but writes to file @p fileName instead 00212 * of stream. 00213 **/ 00214 void write_float_matrix(const string& fileName, const matrix<double>& mat); 00215 00216 /** Like write_float_vector(), but writes to file @p fileName instead 00217 * of stream. 00218 **/ 00219 void write_float_vector(const string& fileName, const vector<double>& vec); 00220 00221 /** Like write_int_matrix(), but writes to file @p fileName instead 00222 * of stream. 00223 **/ 00224 void write_int_matrix(const string& fileName, const matrix<int>& mat); 00225 00226 /** Like write_int_vector(), but writes to file @p fileName instead 00227 * of stream. 00228 **/ 00229 void write_int_vector(const string& fileName, const vector<int>& vec); 00230 00231 //@} 00232 00233 00234 /** * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ** 00235 * Abstract base class, interface for object computing GMM probs. 00236 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 00237 class GmmScorer 00238 { 00239 public: 00240 /** Virtual destructor. **/ 00241 virtual ~GmmScorer() { } 00242 00243 /** Returns number of GMM's in model. **/ 00244 virtual unsigned get_gmm_count() const = 0; 00245 00246 /** Returns dimension of Gaussians. **/ 00247 virtual unsigned get_dim_count() const = 0; 00248 00249 /** Given input feature vectors, computes log prob (base e) of each GMM 00250 * for each frame; i.e., on exit @p logProbs will have the 00251 * same number of rows as @p feats and one column for each GMM. 00252 **/ 00253 virtual void calc_gmm_probs(const matrix<double>& feats, 00254 matrix<double>& logProbs) const = 0; 00255 }; 00256 00257 00258 /** * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ** 00259 * Class holding set of diagonal covariance GMM's. 00260 * 00261 * Here, we summarize the key routines for accessing the parameters of 00262 * each GMM. Each GMM has a set of component Gaussians. To find the 00263 * number of components of a GMM, use #get_component_count(). To find the 00264 * mixture weight of a particular component of a GMM, 00265 * use #get_component_weight(). 00266 * To get the means and variances of a particular component, first call 00267 * #get_gaussian_index() to find the index of the corresponding Gaussian. 00268 * Then, one can call #get_gaussian_mean() and #get_gaussian_var() 00269 * with this index to find the means and variances. 00270 * The reason for this indirection is to support the sharing of 00271 * Gaussians between GMM's. 00272 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 00273 class GmmSet: public GmmScorer 00274 { 00275 public: 00276 /** Ctor; loads from file @p fileName if argument present. **/ 00277 GmmSet(const string& fileName = string()); 00278 00279 /** Reads GMM parameters from file @p fileName. **/ 00280 void read(const string& fileName); 00281 00282 /** Writes GMM parameters to file @p fileName. **/ 00283 void write(const string& fileName) const; 00284 00285 /** Initializes object, where the vector @p gmmCompCounts holds the 00286 * number of component Gaussians of each GMM, and @p dimCnt 00287 * is the dimension of 00288 * the Gaussians. Component weights are initialized to be 00289 * uniform; and means and variances for each dimension are set 00290 * to 0 and 1, respectively. 00291 * If the optional argument @p compMap is not provided, then 00292 * no Gaussians are shared between GMM's. Otherwise, the vector 00293 * @p compMap holds the index of the Gaussian for each 00294 * component of each GMM, in order; i.e., it contains first 00295 * the indices of each component of the 0th GMM, then the 00296 * 1st GMM, etc. 00297 **/ 00298 void init(const vector<int>& gmmCompCounts, unsigned dimCnt, 00299 const vector<int>& compMap = vector<int>()); 00300 00301 /** Clears object, i.e., deletes all GMM's. **/ 00302 void clear(); 00303 00304 /** Returns whether object is empty, i.e., has no GMM's. **/ 00305 bool empty() const { return m_gmmMap.empty(); } 00306 00307 /** Returns number of GMM's in model. **/ 00308 unsigned get_gmm_count() const { return m_gmmMap.size(); } 00309 00310 /** Returns total number of individual Gaussians in model. **/ 00311 unsigned get_gaussian_count() const { return m_gaussParams.size1(); } 00312 00313 /** Returns dimension of Gaussians. **/ 00314 unsigned get_dim_count() const 00315 { assert(!(m_gaussParams.size2() & 1)); 00316 return m_gaussParams.size2() / 2; } 00317 00318 /** Returns number of component Gaussians in @p gmmIdx-th GMM. 00319 * GMM's are numbered starting from 0. 00320 **/ 00321 unsigned get_component_count(unsigned gmmIdx) const 00322 { return get_max_comp_index(gmmIdx) - get_min_comp_index(gmmIdx); } 00323 00324 /** Returns Gaussian index for @p compIdx-th component of 00325 * @p gmmIdx-th GMM. This index is needed for looking up 00326 * or setting Gaussian parameters. 00327 * GMM's and components are numbered starting from 0. 00328 * @see #get_gaussian_mean(), #get_gaussian_var(), etc. 00329 **/ 00330 unsigned get_gaussian_index(unsigned gmmIdx, unsigned compIdx) const 00331 { 00332 assert((gmmIdx < m_gmmMap.size()) && 00333 (compIdx < get_component_count(gmmIdx))); 00334 return m_compMap[get_min_comp_index(gmmIdx) + compIdx]; 00335 } 00336 00337 /** Returns mixture weight of @p compIdx-th component of 00338 * @p gmmIdx-th GMM. 00339 * GMM's and components are numbered starting from 0. 00340 **/ 00341 double get_component_weight(unsigned gmmIdx, unsigned compIdx) const 00342 { 00343 assert((gmmIdx < m_gmmMap.size()) && 00344 (compIdx < get_component_count(gmmIdx))); 00345 return m_compWeights[get_min_comp_index(gmmIdx) + compIdx]; 00346 } 00347 00348 /** Sets mixture weight of @p compIdx-th component of 00349 * @p gmmIdx-th GMM to @p wgt. 00350 * GMM's and components are numbered starting from 0. 00351 **/ 00352 void set_component_weight(unsigned gmmIdx, unsigned compIdx, 00353 double wgt) 00354 { 00355 assert((gmmIdx < m_gmmMap.size()) && 00356 (compIdx < get_component_count(gmmIdx))); 00357 m_compWeights[get_min_comp_index(gmmIdx) + compIdx] = wgt; 00358 m_normsUpToDate = false; 00359 } 00360 00361 /** Returns mean for dimension @p dimIdx for Gaussian with index 00362 * @p gaussIdx. 00363 * Gaussians and dimensions are numbered starting from 0. 00364 **/ 00365 double get_gaussian_mean(unsigned gaussIdx, unsigned dimIdx) const 00366 { 00367 assert((gaussIdx < m_gaussParams.size1()) && 00368 (2 * dimIdx < m_gaussParams.size2())); 00369 return m_gaussParams(gaussIdx, 2 * dimIdx); 00370 } 00371 00372 /** Returns variance for dimension @p dimIdx for Gaussian with index 00373 * @p gaussIdx. 00374 * Gaussians and dimensions are numbered starting from 0. 00375 **/ 00376 double get_gaussian_var(unsigned gaussIdx, unsigned dimIdx) const 00377 { 00378 assert((gaussIdx < m_gaussParams.size1()) && 00379 (2 * dimIdx + 1 < m_gaussParams.size2())); 00380 return m_gaussParams(gaussIdx, 2 * dimIdx + 1); 00381 } 00382 00383 /** Sets mean for dimension @p dimIdx for Gaussian @p gaussIdx. 00384 * Gaussians and dimensions are numbered starting from 0. 00385 **/ 00386 void set_gaussian_mean(unsigned gaussIdx, unsigned dimIdx, double val) 00387 { 00388 assert((gaussIdx < m_gaussParams.size1()) && 00389 (2 * dimIdx < m_gaussParams.size2())); 00390 m_gaussParams(gaussIdx, 2 * dimIdx) = val; 00391 m_normsUpToDate = false; 00392 } 00393 00394 /** Sets variance for dimension @p dimIdx for Gaussian @p gaussIdx. 00395 * Gaussians and dimensions are numbered starting from 0. 00396 **/ 00397 void set_gaussian_var(unsigned gaussIdx, unsigned dimIdx, double val) 00398 { 00399 assert((gaussIdx < m_gaussParams.size1()) && 00400 (2 * dimIdx + 1 < m_gaussParams.size2())); 00401 m_gaussParams(gaussIdx, 2 * dimIdx + 1) = val; 00402 m_normsUpToDate = false; 00403 } 00404 00405 /** Copies the means and variances of Gaussian @p srcGaussIdx 00406 * in @p srcGmmSet into Gaussian @p dstGaussIdx in this object. 00407 * Gaussians are numbered starting from 0. 00408 **/ 00409 void copy_gaussian(unsigned dstGaussIdx, const GmmSet& srcGmmSet, 00410 unsigned srcGaussIdx); 00411 00412 /** Given input feature vectors, computes log prob (base e) of each GMM 00413 * for each frame; i.e., on exit @p logProbs will have the 00414 * same number of rows as @p feats and one column for each GMM. 00415 **/ 00416 void calc_gmm_probs(const matrix<double>& feats, 00417 matrix<double>& logProbs) const; 00418 00419 /** Computes log prob (base e) of each component Gaussian for 00420 * GMM @p gmmIdx for feature vector @p feats. 00421 * Places result in @p logProbs. Returns total log prob of GMM. 00422 * GMM's are numbered starting from 0. 00423 **/ 00424 double calc_component_probs(const vector<double>& feats, 00425 unsigned gmmIdx, vector<double>& logProbs) const; 00426 00427 private: 00428 /** Returns index of first component for GMM @p gmmIdx. **/ 00429 unsigned get_min_comp_index(unsigned gmmIdx) const 00430 { 00431 assert(gmmIdx < m_gmmMap.size()); 00432 return m_gmmMap[gmmIdx]; 00433 } 00434 00435 /** Returns one past index of last component for GMM @p gmmIdx. **/ 00436 unsigned get_max_comp_index(unsigned gmmIdx) const 00437 { 00438 assert(gmmIdx < m_gmmMap.size()); 00439 return (gmmIdx + 1 < m_gmmMap.size()) ? 00440 m_gmmMap[gmmIdx + 1] : m_compMap.size(); 00441 } 00442 00443 /** Recomputes normalization constants useful for log prob 00444 * computation. 00445 **/ 00446 void compute_norms() const; 00447 00448 /** Returns log norm constant + log weight (base e) for 00449 * @p compIdx-th component of @p gmmIdx-th GMM. 00450 * GMM's and components are numbered starting from 0. 00451 **/ 00452 double get_component_norm(unsigned gmmIdx, unsigned compIdx) const 00453 { 00454 assert(m_normsUpToDate && (gmmIdx < m_gmmMap.size()) && 00455 (compIdx < get_component_count(gmmIdx))); 00456 return m_logNorms[get_min_comp_index(gmmIdx) + compIdx]; 00457 } 00458 00459 private: 00460 /** For each GMM, index of first component in m_compMap, etc. **/ 00461 vector<unsigned> m_gmmMap; 00462 00463 /** For each component of each GMM, its index in m_gaussParams. **/ 00464 vector<unsigned> m_compMap; 00465 00466 /** For each component of each GMM, its mixture weight. **/ 00467 vector<double> m_compWeights; 00468 00469 /** For each Gaussian, alternating mean + var for each dim. **/ 00470 matrix<double> m_gaussParams; 00471 00472 /** Whether m_logNorms is up to date. **/ 00473 mutable bool m_normsUpToDate; 00474 00475 /** For each component of each GMM, log norm constant + log weight. **/ 00476 mutable vector<double> m_logNorms; 00477 }; 00478 00479 00480 /** * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ** 00481 * GMM count class. 00482 * 00483 * Holds a posterior count for a GMM at a frame. Includes the 00484 * GMM index, the frame, and the posterior count. This is 00485 * used to facilitate Forward-Backward and Viterbi EM training. 00486 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 00487 class GmmCount 00488 { 00489 public: 00490 /** Ctor; initializes fields to default values. **/ 00491 GmmCount() : m_gmmIdx(0), m_frmIdx(0), m_count(0.0) { } 00492 00493 /** Ctor; explicitly initializes all fields. **/ 00494 GmmCount(unsigned gmmIdx, unsigned frmIdx, double count) : 00495 m_gmmIdx(gmmIdx), m_frmIdx(frmIdx), m_count(count) { } 00496 00497 /** Sets all values in object. **/ 00498 void assign(unsigned gmmIdx, unsigned frmIdx, double count) 00499 { m_gmmIdx = gmmIdx; m_frmIdx = frmIdx; m_count = count; } 00500 00501 /** Returns the associated GMM index. **/ 00502 unsigned get_gmm_index() const { return m_gmmIdx; } 00503 00504 /** Returns the associated frame index. **/ 00505 unsigned get_frame_index() const { return m_frmIdx; } 00506 00507 /** Returns the posterior count. **/ 00508 double get_count() const { return m_count; } 00509 00510 private: 00511 /** The index of the GMM. **/ 00512 unsigned m_gmmIdx; 00513 00514 /** Which frame the count occurred at. **/ 00515 unsigned m_frmIdx; 00516 00517 /** The posterior count. **/ 00518 float m_count; 00519 }; 00520 00521 #ifndef SWIG 00522 00523 /** Orders GmmCount objects first by frame, then GMM index, then 00524 * by decreasing count. 00525 **/ 00526 inline bool operator<(const GmmCount& cnt1, const GmmCount& cnt2) 00527 { 00528 if (cnt1.get_frame_index() != cnt2.get_frame_index()) 00529 return cnt1.get_frame_index() < cnt2.get_frame_index(); 00530 if (cnt1.get_gmm_index() != cnt2.get_gmm_index()) 00531 return cnt1.get_gmm_index() < cnt2.get_gmm_index(); 00532 return cnt1.get_count() > cnt2.get_count(); 00533 } 00534 00535 #endif 00536 00537 00538 /** * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ** 00539 * Class holding symbol table for a graph/FSM. 00540 * 00541 * In graphs/FSM's, word labels are stored internally as integers. 00542 * This object holds the mapping from word spellings to their 00543 * integer representations, and vice versa. Word indices are 00544 * constrained to be nonnegative, though they need not be consecutive. 00545 * By convention, the word 00546 * index corresponding to epsilon (i.e., the representation of 00547 * the empty string) is 0. 00548 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 00549 class SymbolTable 00550 { 00551 public: 00552 /** Ctor; loads from file @p fileName if argument present. **/ 00553 SymbolTable(const string& fileName = string()) 00554 { if (!fileName.empty()) read(fileName); } 00555 00556 /** Reads symbols from file @p fileName. **/ 00557 void read(const string& fileName); 00558 00559 /** Clears object. **/ 00560 void clear(); 00561 00562 /** Returns number of symbols in table. **/ 00563 unsigned size() const { return m_strToIdxMap.size(); } 00564 00565 /** Returns whether symbol table is empty. **/ 00566 bool empty() const { return m_strToIdxMap.empty(); } 00567 00568 /** Maps from a string to its index. 00569 * If not in table, returns -1. 00570 **/ 00571 int get_index(const string& theStr) const 00572 { 00573 map<string, unsigned>::const_iterator lookup = 00574 m_strToIdxMap.find(theStr); 00575 return (lookup != m_strToIdxMap.end()) ? (int) lookup->second : -1; 00576 } 00577 00578 /** Maps from an index to its string. 00579 * If not in table, returns empty string. 00580 **/ 00581 string get_str(unsigned theIdx) const 00582 { 00583 map<unsigned, string>::const_iterator lookup = 00584 m_idxToStrMap.find(theIdx); 00585 return (lookup != m_idxToStrMap.end()) ? lookup->second : 00586 string(); 00587 } 00588 00589 private: 00590 /** Map from strings to integer indices. **/ 00591 map<string, unsigned> m_strToIdxMap; 00592 00593 /** Map from integer indices to strings. **/ 00594 map<unsigned, string> m_idxToStrMap; 00595 }; 00596 00597 /** For converting a vector of strings @p wordList to a vector of 00598 * ints @p wordIdxList using a SymbolTable, for n-gram model processing. 00599 * Words not in the symbol table are converted 00600 * to the value @p unkIdx. The beginning of the output sequence 00601 * is padded with @p n - 1 @p bosIdx values, and a single @p eosIdx 00602 * value is added to the end. 00603 **/ 00604 void convert_words_to_indices(const vector<string>& wordList, 00605 vector<int>& wordIdxList, const SymbolTable& symTable, 00606 int n, int bosIdx, int eosIdx, int unkIdx); 00607 00608 00609 /** * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ** 00610 * Arc class. 00611 * 00612 * Holds a single arc in a graph/FSM. With each graph, there is 00613 * an implicitly associated GmmSet and an explicitly associated 00614 * SymbolTable (see Graph::get_word_sym_table()). 00615 * An arc holds a destination state; 00616 * an optional GMM index (corresponding to a GMM in the GmmSet); 00617 * an optional word index (corresponding to a word in the SymbolTable); 00618 * and a log prob (base e). 00619 * Source state information is not present; if you have the arc ID 00620 * (see Graph::get_first_arc_id(), Graph::get_arc()), you can look this 00621 * up using Graph::get_src_state(). 00622 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 00623 class Arc 00624 { 00625 public: 00626 /** Ctor; initializes fields to default values. **/ 00627 Arc() : m_dst(0), m_gmmIdx(-1), m_wordIdx(0), m_logProb(0.0) { } 00628 00629 /** Ctor; explicitly initializes all fields. 00630 * @see #assign(). 00631 **/ 00632 Arc(unsigned dst, int gmmIdx, int wordIdx, double logProb) : 00633 m_dst(dst), m_gmmIdx(gmmIdx), m_wordIdx(wordIdx), 00634 m_logProb(logProb) { } 00635 00636 /** Sets all values in arc. 00637 * The argument @p dst should be the destination state; 00638 * @p gmmIdx should be the index of the associated GMM (or -1 00639 * if not present); @p wordIdx should be the index of the 00640 * associated word (or -1 if not present); and @p logProb 00641 * should be the log prob base e of the arc. 00642 **/ 00643 void assign(unsigned dst, int gmmIdx, int wordIdx, double logProb) 00644 { m_dst = dst; m_gmmIdx = gmmIdx; m_wordIdx = wordIdx; 00645 m_logProb = logProb; } 00646 00647 /** Returns dest state index. 00648 * To find src state, see Graph::get_src_state(). 00649 **/ 00650 unsigned get_dst_state() const { return m_dst; } 00651 00652 /** Returns assoc GMM index, or -1 if not present. **/ 00653 int get_gmm() const { return m_gmmIdx; } 00654 00655 /** Returns assoc word index, or 0 if not present/epsilon. 00656 * To find the corresponding word spelling, see 00657 * Graph::get_word_sym_table(). 00658 **/ 00659 unsigned get_word() const { return m_wordIdx; } 00660 00661 /** Returns assoc log prob base e. **/ 00662 double get_log_prob() const { return m_logProb; } 00663 00664 private: 00665 /** Destination state. **/ 00666 unsigned m_dst; 00667 00668 /** GMM index, or -1 if not present. **/ 00669 int m_gmmIdx; 00670 00671 /** Word index, or 0 if not present/epsilon. **/ 00672 unsigned m_wordIdx; 00673 00674 /** Log prob base e. **/ 00675 float m_logProb; 00676 }; 00677 00678 00679 /** * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ** 00680 * Graph/FSM class. 00681 * 00682 * This object holds a graph as needed for training or decoding. 00683 * A graph has a set of states numbered starting from 0 00684 * (see #get_state_count()), one of 00685 * which is a start state (see #get_start_state()), 00686 * and many of which may be final states (see #is_final_state(), 00687 * #get_final_state_list()). Final states have an associated final 00688 * log prob (see #get_final_log_prob()). 00689 * 00690 * For each state, you can access the list of outgoing arcs using 00691 * #get_arc_count(), #get_first_arc_id(), and #get_arc(). 00692 * Each arc has an associated ID, which is used for iterating 00693 * through arcs. Here is an example of how to iterate through 00694 * the outgoing arcs of state @p stateIdx of graph @p graph. 00695 * 00696 * @code 00697 * // Get number of outgoing arcs. 00698 * int arcCnt = graph.get_arc_count(stateIdx); 00699 * // Get arc ID of first outgoing arc. 00700 * int arcId = graph.get_first_arc_id(stateIdx); 00701 * for (int arcIdx = 0; arcIdx < arcCnt; ++arcIdx) 00702 * { 00703 * Arc arc; 00704 * // Place arc with ID "arcId" in "arc"; set "arcId" 00705 * // to arc ID of the next outgoing arc. 00706 * arcId = graph.get_arc(arcId, arc); 00707 * 00708 * // You can now access elements of the Arc "arc", e.g., 00709 * int dstState = arc.get_dst_state(); 00710 * ... 00711 * } 00712 * @endcode 00713 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 00714 class Graph 00715 { 00716 public: 00717 /** Ctor; loads from file @p fileName if argument present, 00718 * and loads sym table from file @p symFile if argument present. 00719 **/ 00720 Graph(const string& fileName = string(), 00721 const string& symFile = string()); 00722 00723 /** Reads graph from file @p fileName. 00724 * Reads symbol table from file @p symFile if not empty string. 00725 **/ 00726 void read(const string& fileName, const string& symFile = string()); 00727 00728 /** Reads graph from stream @p inStrm. 00729 * If argument @p name is provided, checks that name associated with 00730 * graph matches. Returns name given in graph header, 00731 * or empty string if none provided. 00732 **/ 00733 string read(istream& inStrm, const string& name = string()); 00734 00735 /** Writes graph to file @p fileName. **/ 00736 void write(const string& fileName) const; 00737 00738 /** Writes graph to stream @p outStrm. 00739 * If the argument @p name is provided, this name will be written 00740 * in the graph header. 00741 **/ 00742 void write(ostream& outStrm, const string& name = string()) const; 00743 00744 /** Reads word symbol table from file @p symFile. 00745 * Pass in the empty string to load an empty symbol table. 00746 **/ 00747 void read_word_sym_table(const string& symFile); 00748 00749 /** Clears object except for symbol table; i.e., delete all states 00750 * and arcs. 00751 **/ 00752 void clear(); 00753 00754 /** Returns whether there are no states. **/ 00755 bool empty() const { return m_stateMap.empty(); } 00756 00757 00758 /** Returns a reference to the word symbol table. **/ 00759 const SymbolTable& get_word_sym_table() const 00760 { return *m_symTable.get(); } 00761 00762 /** Returns one above highest GMM index in graph. **/ 00763 unsigned get_gmm_count() const; 00764 00765 /** Returns total number of states. **/ 00766 unsigned get_state_count() const { return m_stateMap.size(); } 00767 00768 /** Returns index of start state, or -1 if unset. **/ 00769 int get_start_state() const { return m_start; } 00770 00771 /** Returns number of outgoing arcs for state @p stateIdx. 00772 * States are numbered from 0. 00773 **/ 00774 unsigned get_arc_count(unsigned stateIdx) const 00775 { 00776 return get_max_arc_index(stateIdx) - get_min_arc_index(stateIdx); 00777 } 00778 00779 /** Returns arc ID of first outgoing arc of state @p stateIdx. 00780 * States are numbered from 0. 00781 **/ 00782 unsigned get_first_arc_id(unsigned stateIdx) const 00783 { 00784 assert(stateIdx < m_stateMap.size()); 00785 return get_min_arc_index(stateIdx); 00786 } 00787 00788 /** Places arc with ID @p arcId in @p arc. 00789 * Returns ID of next outgoing arc of same state. 00790 **/ 00791 unsigned get_arc(unsigned arcId, Arc& arc) const 00792 { 00793 assert(arcId < m_arcList.size()); 00794 arc = m_arcList[arcId]; 00795 return arcId + 1; 00796 } 00797 00798 /** Returns src state of an arc given its ID. **/ 00799 unsigned get_src_state(unsigned arcId) const; 00800 00801 /** Returns wheter state @p stateIdx is final state. 00802 * States are numbered from 0. 00803 **/ 00804 bool is_final_state(unsigned stateIdx) const 00805 { return m_finalLogProbs.find(stateIdx) != m_finalLogProbs.end(); } 00806 00807 /** Returns final log prob (base e) of state @p stateIdx or 00808 * g_zeroLogProb if not final. 00809 * States are numbered from 0. 00810 **/ 00811 double get_final_log_prob(unsigned stateIdx) const 00812 { 00813 map<unsigned, float>::const_iterator lookup = 00814 m_finalLogProbs.find(stateIdx); 00815 return (lookup != m_finalLogProbs.end()) ? lookup->second : 00816 g_zeroLogProb; 00817 } 00818 00819 /** Returns number of final states; places list in @p stateList. **/ 00820 unsigned get_final_state_list(vector<int>& stateList) const 00821 { 00822 stateList.clear(); 00823 for (map<unsigned, float>::const_iterator elemPtr = 00824 m_finalLogProbs.begin(); elemPtr != m_finalLogProbs.end(); 00825 ++elemPtr) 00826 stateList.push_back(elemPtr->first); 00827 sort(stateList.begin(), stateList.end()); 00828 return stateList.size(); 00829 } 00830 00831 private: 00832 /** Returns index of first arc for state @p stateIdx. **/ 00833 unsigned get_min_arc_index(unsigned stateIdx) const 00834 { 00835 assert(stateIdx < m_stateMap.size()); 00836 return m_stateMap[stateIdx]; 00837 } 00838 00839 /** Returns one past index of last arc for state @p stateIdx. **/ 00840 unsigned get_max_arc_index(unsigned stateIdx) const 00841 { 00842 assert(stateIdx < m_stateMap.size()); 00843 return (stateIdx + 1 < m_stateMap.size()) ? 00844 m_stateMap[stateIdx + 1] : m_arcList.size(); 00845 } 00846 00847 private: 00848 /** Map from words to integer indices. **/ 00849 shared_ptr<SymbolTable> m_symTable; 00850 00851 /** Index of start state. **/ 00852 int m_start; 00853 00854 /** Map from final states to their final log probs. **/ 00855 map<unsigned, float> m_finalLogProbs; 00856 00857 /** For each state, index of first arc in m_arcList. 00858 * Assumes that arcs for each state are contiguous. 00859 **/ 00860 vector<unsigned> m_stateMap; 00861 00862 /** List of arcs in graph, grouped by source state. **/ 00863 vector<Arc> m_arcList; 00864 }; 00865 00866 00867 /** * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ** 00868 * Class for storing counts for a set of n-grams. 00869 * 00870 * N-grams are represented as vectors of integers. The main access 00871 * methods are #incr_count(), #set_count(), and #get_count(). 00872 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 00873 class NGramCounter 00874 { 00875 public: 00876 /** Ctor; initializes object to be empty. **/ 00877 NGramCounter() { } 00878 00879 /** Writes all counts to stream @p outStrm in a text format. 00880 * Uses @p symTable to map integer indices to strings, if present. 00881 **/ 00882 void write(ostream& outStrm, 00883 const SymbolTable& symTable = SymbolTable()) const; 00884 00885 /** Clears object; deletes all n-grams in table. **/ 00886 void clear() { m_countMap.clear(); } 00887 00888 /** Returns number of n-grams in table. **/ 00889 unsigned size() const { return m_countMap.size(); } 00890 00891 /** Returns whether object is empty. **/ 00892 bool empty() const { return m_countMap.empty(); } 00893 00894 /** Increments count of an n-gram; returns new count. **/ 00895 unsigned incr_count(const vector<int>& ngram) 00896 { return ++m_countMap[ngram]; } 00897 00898 /** Sets count of an n-gram to @p val. 00899 * If @p val is 0, n-gram is removed from table. 00900 **/ 00901 void set_count(const vector<int>& ngram, unsigned val) 00902 { 00903 if (val != 0) 00904 m_countMap[ngram] = val; 00905 else 00906 m_countMap.erase(ngram); 00907 } 00908 00909 /** Returns count of an n-gram, or 0 if not present. **/ 00910 unsigned get_count(const vector<int>& ngram) const 00911 { 00912 map<vector<int>, int>::const_iterator iter = 00913 m_countMap.find(ngram); 00914 return (iter != m_countMap.end()) ? iter->second : 0; 00915 } 00916 00917 private: 00918 /** Map from n-grams to counts. **/ 00919 map<vector<int>, int> m_countMap; 00920 }; 00921 00922 00923 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 00924 * 00925 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 00926 00927 #endif 00928 00929