#ifndef __FST_IO_H__
#define __FST_IO_H__

// fst-io.h
// A copy of the OpenFst SDK sample application headers print-main.h and
// compile-main.h, except that their main functions are disabled with #if 0.
// 2007, 2008 Nuance Communications
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//
// \file
// Classes and functions to print an Fst as text and to compile a binary Fst
// from textual input. Also includes helper functions for fstprint.cc and
// fstcompile.cc that template the main on the arc type to support multiple
// and extensible arc types.
28 29 #include <fstream> 30 #include <sstream> 31 32 #include "fst/lib/fst.h" 33 #include "fst/lib/fstlib.h" 34 #include "fst/lib/fst-decl.h" 35 #include "fst/lib/vector-fst.h" 36 #include "fst/lib/arcsort.h" 37 #include "fst/lib/invert.h" 38 39 namespace fst { 40 41 template <class A> class FstPrinter { 42 public: 43 typedef A Arc; 44 typedef typename A::StateId StateId; 45 typedef typename A::Label Label; 46 typedef typename A::Weight Weight; 47 FstPrinter(const Fst<A> & fst,const SymbolTable * isyms,const SymbolTable * osyms,const SymbolTable * ssyms,bool accep)48 FstPrinter(const Fst<A> &fst, 49 const SymbolTable *isyms, 50 const SymbolTable *osyms, 51 const SymbolTable *ssyms, 52 bool accep) 53 : fst_(fst), isyms_(isyms), osyms_(osyms), ssyms_(ssyms), 54 accep_(accep && fst.Properties(kAcceptor, true)), ostrm_(0) {} 55 56 // Print Fst to an output strm Print(ostream * ostrm,const string & dest)57 void Print(ostream *ostrm, const string &dest) { 58 ostrm_ = ostrm; 59 dest_ = dest; 60 StateId start = fst_.Start(); 61 if (start == kNoStateId) 62 return; 63 // initial state first 64 PrintState(start); 65 for (StateIterator< Fst<A> > siter(fst_); 66 !siter.Done(); 67 siter.Next()) { 68 StateId s = siter.Value(); 69 if (s != start) 70 PrintState(s); 71 } 72 } 73 74 private: 75 // Maximum line length in text file. 
76 static const int kLineLen = 8096; 77 PrintId(int64 id,const SymbolTable * syms,const char * name)78 void PrintId(int64 id, const SymbolTable *syms, 79 const char *name) const { 80 if (syms) { 81 string symbol = syms->Find(id); 82 if (symbol == "") { 83 LOG(ERROR) << "FstPrinter: Integer " << id 84 << " is not mapped to any textual symbol" 85 << ", symbol table = " << syms->Name() 86 << ", destination = " << dest_; 87 exit(1); 88 } 89 *ostrm_ << symbol; 90 } else { 91 *ostrm_ << id; 92 } 93 } 94 PrintStateId(StateId s)95 void PrintStateId(StateId s) const { 96 PrintId(s, ssyms_, "state ID"); 97 } 98 PrintILabel(Label l)99 void PrintILabel(Label l) const { 100 PrintId(l, isyms_, "arc input label"); 101 } 102 PrintOLabel(Label l)103 void PrintOLabel(Label l) const { 104 PrintId(l, osyms_, "arc output label"); 105 } 106 PrintState(StateId s)107 void PrintState(StateId s) const { 108 bool output = false; 109 for (ArcIterator< Fst<A> > aiter(fst_, s); 110 !aiter.Done(); 111 aiter.Next()) { 112 Arc arc = aiter.Value(); 113 PrintStateId(s); 114 *ostrm_ << "\t"; 115 PrintStateId(arc.nextstate); 116 *ostrm_ << "\t"; 117 PrintILabel(arc.ilabel); 118 if (!accep_) { 119 *ostrm_ << "\t"; 120 PrintOLabel(arc.olabel); 121 } 122 if (arc.weight != Weight::One()) 123 *ostrm_ << "\t" << arc.weight; 124 *ostrm_ << "\n"; 125 output = true; 126 } 127 Weight final = fst_.Final(s); 128 if (final != Weight::Zero() || !output) { 129 PrintStateId(s); 130 if (final != Weight::One()) { 131 *ostrm_ << "\t" << final; 132 } 133 *ostrm_ << "\n"; 134 } 135 } 136 137 const Fst<A> &fst_; 138 const SymbolTable *isyms_; // ilabel symbol table 139 const SymbolTable *osyms_; // olabel symbol table 140 const SymbolTable *ssyms_; // slabel symbol table 141 bool accep_; // print as acceptor when possible 142 ostream *ostrm_; // binary FST destination 143 string dest_; // binary FST destination name 144 DISALLOW_EVIL_CONSTRUCTORS(FstPrinter); 145 }; 146 147 #if 0 148 // Main function for fstprint templated 
on the arc type. 149 template <class Arc> 150 int PrintMain(int argc, char **argv, istream &istrm, 151 const FstReadOptions &opts) { 152 Fst<Arc> *fst = Fst<Arc>::Read(istrm, opts); 153 if (!fst) return 1; 154 155 string dest = "standard output"; 156 ostream *ostrm = &std::cout; 157 if (argc == 3) { 158 dest = argv[2]; 159 ostrm = new ofstream(argv[2]); 160 if (!*ostrm) { 161 LOG(ERROR) << argv[0] << ": Open failed, file = " << argv[2]; 162 return 0; 163 } 164 } 165 ostrm->precision(9); 166 167 const SymbolTable *isyms = 0, *osyms = 0, *ssyms = 0; 168 169 if (!FLAGS_isymbols.empty() && !FLAGS_numeric) { 170 isyms = SymbolTable::ReadText(FLAGS_isymbols); 171 if (!isyms) exit(1); 172 } 173 174 if (!FLAGS_osymbols.empty() && !FLAGS_numeric) { 175 osyms = SymbolTable::ReadText(FLAGS_osymbols); 176 if (!osyms) exit(1); 177 } 178 179 if (!FLAGS_ssymbols.empty() && !FLAGS_numeric) { 180 ssyms = SymbolTable::ReadText(FLAGS_ssymbols); 181 if (!ssyms) exit(1); 182 } 183 184 if (!isyms && !FLAGS_numeric) 185 isyms = fst->InputSymbols(); 186 if (!osyms && !FLAGS_numeric) 187 osyms = fst->OutputSymbols(); 188 189 FstPrinter<Arc> fstprinter(*fst, isyms, osyms, ssyms, FLAGS_acceptor); 190 fstprinter.Print(ostrm, dest); 191 192 if (isyms && !FLAGS_save_isymbols.empty()) 193 isyms->WriteText(FLAGS_save_isymbols); 194 195 if (osyms && !FLAGS_save_osymbols.empty()) 196 osyms->WriteText(FLAGS_save_osymbols); 197 198 if (ostrm != &std::cout) 199 delete ostrm; 200 return 0; 201 } 202 #endif 203 204 205 template <class A> class FstReader { 206 public: 207 typedef A Arc; 208 typedef typename A::StateId StateId; 209 typedef typename A::Label Label; 210 typedef typename A::Weight Weight; 211 FstReader(istream & istrm,const string & source,const SymbolTable * isyms,const SymbolTable * osyms,const SymbolTable * ssyms,bool accep,bool ikeep,bool okeep,bool nkeep)212 FstReader(istream &istrm, const string &source, 213 const SymbolTable *isyms, const SymbolTable *osyms, 214 const SymbolTable 
*ssyms, bool accep, bool ikeep, 215 bool okeep, bool nkeep) 216 : nline_(0), source_(source), 217 isyms_(isyms), osyms_(osyms), ssyms_(ssyms), 218 nstates_(0), keep_state_numbering_(nkeep) { 219 char line[kLineLen]; 220 while (istrm.getline(line, kLineLen)) { 221 ++nline_; 222 vector<char *> col; 223 SplitToVector(line, "\n\t ", &col, true); 224 if (col.size() == 0 || col[0][0] == '\0') // empty line 225 continue; 226 if (col.size() > 5 || 227 (col.size() > 4 && accep) || 228 (col.size() == 3 && !accep)) { 229 LOG(ERROR) << "FstReader: Bad number of columns, source = " << source_ 230 << ", line = " << nline_; 231 exit(1); 232 } 233 StateId s = StrToStateId(col[0]); 234 while (s >= fst_.NumStates()) 235 fst_.AddState(); 236 if (nline_ == 1) 237 fst_.SetStart(s); 238 239 Arc arc; 240 StateId d = s; 241 switch (col.size()) { 242 case 1: 243 fst_.SetFinal(s, Weight::One()); 244 break; 245 case 2: 246 fst_.SetFinal(s, StrToWeight(col[1], true)); 247 break; 248 case 3: 249 arc.nextstate = d = StrToStateId(col[1]); 250 arc.ilabel = StrToILabel(col[2]); 251 arc.olabel = arc.ilabel; 252 arc.weight = Weight::One(); 253 fst_.AddArc(s, arc); 254 break; 255 case 4: 256 arc.nextstate = d = StrToStateId(col[1]); 257 arc.ilabel = StrToILabel(col[2]); 258 if (accep) { 259 arc.olabel = arc.ilabel; 260 arc.weight = StrToWeight(col[3], false); 261 } else { 262 arc.olabel = StrToOLabel(col[3]); 263 arc.weight = Weight::One(); 264 } 265 fst_.AddArc(s, arc); 266 break; 267 case 5: 268 arc.nextstate = d = StrToStateId(col[1]); 269 arc.ilabel = StrToILabel(col[2]); 270 arc.olabel = StrToOLabel(col[3]); 271 arc.weight = StrToWeight(col[4], false); 272 fst_.AddArc(s, arc); 273 } 274 while (d >= fst_.NumStates()) 275 fst_.AddState(); 276 } 277 if (ikeep) 278 fst_.SetInputSymbols(isyms); 279 if (okeep) 280 fst_.SetOutputSymbols(osyms); 281 } 282 Fst()283 const VectorFst<A> &Fst() const { return fst_; } 284 285 private: 286 // Maximum line length in text file. 
287 static const int kLineLen = 8096; 288 StrToId(const char * s,const SymbolTable * syms,const char * name)289 int64 StrToId(const char *s, const SymbolTable *syms, 290 const char *name) const { 291 int64 n; 292 293 if (syms) { 294 n = syms->Find(s); 295 if (n < 0) { 296 LOG(ERROR) << "FstReader: Symbol \"" << s 297 << "\" is not mapped to any integer " << name 298 << ", symbol table = " << syms->Name() 299 << ", source = " << source_ << ", line = " << nline_; 300 exit(1); 301 } 302 } else { 303 char *p; 304 n = strtoll(s, &p, 10); 305 if (p < s + strlen(s) || n < 0) { 306 LOG(ERROR) << "FstReader: Bad " << name << " integer = \"" << s 307 << "\", source = " << source_ << ", line = " << nline_; 308 exit(1); 309 } 310 } 311 return n; 312 } 313 StrToStateId(const char * s)314 StateId StrToStateId(const char *s) { 315 StateId n = StrToId(s, ssyms_, "state ID"); 316 317 if (keep_state_numbering_) 318 return n; 319 320 // remap state IDs to make dense set 321 typename std::unordered_map<StateId, StateId>::const_iterator it = 322 states_.find(n); 323 if (it == states_.end()) { 324 states_[n] = nstates_; 325 return nstates_++; 326 } else { 327 return it->second; 328 } 329 } 330 StrToILabel(const char * s)331 StateId StrToILabel(const char *s) const { 332 return StrToId(s, isyms_, "arc ilabel"); 333 } 334 StrToOLabel(const char * s)335 StateId StrToOLabel(const char *s) const { 336 return StrToId(s, osyms_, "arc olabel"); 337 } 338 StrToWeight(const char * s,bool allow_zero)339 Weight StrToWeight(const char *s, bool allow_zero) const { 340 Weight w; 341 istringstream strm(s); 342 strm >> w; 343 if (strm.fail() || (!allow_zero && w == Weight::Zero())) { 344 LOG(ERROR) << "FstReader: Bad weight = \"" << s 345 << "\", source = " << source_ << ", line = " << nline_; 346 exit(1); 347 } 348 return w; 349 } 350 351 VectorFst<A> fst_; 352 size_t nline_; 353 string source_; // text FST source name 354 const SymbolTable *isyms_; // ilabel symbol table 355 const SymbolTable *osyms_; 
// olabel symbol table 356 const SymbolTable *ssyms_; // slabel symbol table 357 std::unordered_map<StateId, StateId> states_; // state ID map 358 StateId nstates_; // number of seen states 359 bool keep_state_numbering_; 360 DISALLOW_EVIL_CONSTRUCTORS(FstReader); 361 }; 362 363 #if 0 364 // Main function for fstcompile templated on the arc type. Last two 365 // arguments unneeded since fstcompile passes the arc type as a flag 366 // unlike the other mains, which infer the arc type from an input Fst. 367 template <class Arc> 368 int CompileMain(int argc, char **argv, istream& /* strm */, 369 const FstReadOptions & /* opts */) { 370 char *ifilename = "standard input"; 371 istream *istrm = &std::cin; 372 if (argc > 1 && strcmp(argv[1], "-") != 0) { 373 ifilename = argv[1]; 374 istrm = new ifstream(ifilename); 375 if (!*istrm) { 376 LOG(ERROR) << argv[0] << ": Open failed, file = " << ifilename; 377 return 1; 378 } 379 } 380 const SymbolTable *isyms = 0, *osyms = 0, *ssyms = 0; 381 382 if (!FLAGS_isymbols.empty()) { 383 isyms = SymbolTable::ReadText(FLAGS_isymbols); 384 if (!isyms) exit(1); 385 } 386 387 if (!FLAGS_osymbols.empty()) { 388 osyms = SymbolTable::ReadText(FLAGS_osymbols); 389 if (!osyms) exit(1); 390 } 391 392 if (!FLAGS_ssymbols.empty()) { 393 ssyms = SymbolTable::ReadText(FLAGS_ssymbols); 394 if (!ssyms) exit(1); 395 } 396 397 FstReader<Arc> fstreader(*istrm, ifilename, isyms, osyms, ssyms, 398 FLAGS_acceptor, FLAGS_keep_isymbols, 399 FLAGS_keep_osymbols, FLAGS_keep_state_numbering); 400 401 const Fst<Arc> *fst = &fstreader.Fst(); 402 if (FLAGS_fst_type != "vector") { 403 fst = Convert<Arc>(*fst, FLAGS_fst_type); 404 if (!fst) return 1; 405 } 406 fst->Write(argc > 2 ? argv[2] : ""); 407 if (istrm != &std::cin) 408 delete istrm; 409 return 0; 410 } 411 #endif 412 413 } // namespace fst 414 415 #endif /* __FST_IO_H__ */ 416 417