• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #ifndef __FST_IO_H__
2 #define __FST_IO_H__
3 
4 // fst-io.h
5 // This is a copy of the OPENFST SDK application sample files ...
6 // except for the main functions ifdef'ed out
7 // 2007, 2008 Nuance Communications
8 //
9 // print-main.h compile-main.h
10 //
11 // Licensed under the Apache License, Version 2.0 (the "License");
12 // you may not use this file except in compliance with the License.
13 // You may obtain a copy of the License at
14 //
15 //      http://www.apache.org/licenses/LICENSE-2.0
16 //
17 // Unless required by applicable law or agreed to in writing, software
18 // distributed under the License is distributed on an "AS IS" BASIS,
19 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20 // See the License for the specific language governing permissions and
21 // limitations under the License.
22 //
23 //
24 // \file
25 // Classes and functions to compile a binary Fst from textual input.
26 // Includes helper function for fstcompile.cc that templates the main
27 // on the arc type to support multiple and extensible arc types.
28 
29 #include <fstream>
30 #include <sstream>
31 
32 #include "fst/lib/fst.h"
33 #include "fst/lib/fstlib.h"
34 #include "fst/lib/fst-decl.h"
35 #include "fst/lib/vector-fst.h"
36 #include "fst/lib/arcsort.h"
37 #include "fst/lib/invert.h"
38 
39 namespace fst {
40 
41   template <class A> class FstPrinter {
42   public:
43     typedef A Arc;
44     typedef typename A::StateId StateId;
45     typedef typename A::Label Label;
46     typedef typename A::Weight Weight;
47 
FstPrinter(const Fst<A> & fst,const SymbolTable * isyms,const SymbolTable * osyms,const SymbolTable * ssyms,bool accep)48     FstPrinter(const Fst<A> &fst,
49 	       const SymbolTable *isyms,
50 	       const SymbolTable *osyms,
51 	       const SymbolTable *ssyms,
52 	       bool accep)
53       : fst_(fst), isyms_(isyms), osyms_(osyms), ssyms_(ssyms),
54       accep_(accep && fst.Properties(kAcceptor, true)), ostrm_(0) {}
55 
56     // Print Fst to an output strm
Print(ostream * ostrm,const string & dest)57     void Print(ostream *ostrm, const string &dest) {
58       ostrm_ = ostrm;
59       dest_ = dest;
60       StateId start = fst_.Start();
61       if (start == kNoStateId)
62 	return;
63       // initial state first
64       PrintState(start);
65       for (StateIterator< Fst<A> > siter(fst_);
66 	   !siter.Done();
67 	   siter.Next()) {
68 	StateId s = siter.Value();
69 	if (s != start)
70 	  PrintState(s);
71       }
72     }
73 
74   private:
75     // Maximum line length in text file.
76     static const int kLineLen = 8096;
77 
PrintId(int64 id,const SymbolTable * syms,const char * name)78     void PrintId(int64 id, const SymbolTable *syms,
79 		 const char *name) const {
80       if (syms) {
81 	string symbol = syms->Find(id);
82 	if (symbol == "") {
83 	  LOG(ERROR) << "FstPrinter: Integer " << id
84 		     << " is not mapped to any textual symbol"
85 		     << ", symbol table = " << syms->Name()
86 		     << ", destination = " << dest_;
87 	  exit(1);
88 	}
89 	*ostrm_ << symbol;
90       } else {
91 	*ostrm_ << id;
92       }
93     }
94 
PrintStateId(StateId s)95     void PrintStateId(StateId s) const {
96       PrintId(s, ssyms_, "state ID");
97     }
98 
PrintILabel(Label l)99     void PrintILabel(Label l) const {
100       PrintId(l, isyms_, "arc input label");
101     }
102 
PrintOLabel(Label l)103     void PrintOLabel(Label l) const {
104       PrintId(l, osyms_, "arc output label");
105     }
106 
PrintState(StateId s)107     void PrintState(StateId s) const {
108       bool output = false;
109       for (ArcIterator< Fst<A> > aiter(fst_, s);
110 	   !aiter.Done();
111 	   aiter.Next()) {
112 	Arc arc = aiter.Value();
113 	PrintStateId(s);
114 	*ostrm_ << "\t";
115 	PrintStateId(arc.nextstate);
116 	*ostrm_ << "\t";
117 	PrintILabel(arc.ilabel);
118 	if (!accep_) {
119 	  *ostrm_ << "\t";
120 	  PrintOLabel(arc.olabel);
121 	}
122 	if (arc.weight != Weight::One())
123 	  *ostrm_ << "\t" << arc.weight;
124 	*ostrm_ << "\n";
125 	output = true;
126       }
127       Weight final = fst_.Final(s);
128       if (final != Weight::Zero() || !output) {
129 	PrintStateId(s);
130 	if (final != Weight::One()) {
131 	  *ostrm_ << "\t" << final;
132 	}
133 	*ostrm_ << "\n";
134       }
135     }
136 
137     const Fst<A> &fst_;
138     const SymbolTable *isyms_;     // ilabel symbol table
139     const SymbolTable *osyms_;     // olabel symbol table
140     const SymbolTable *ssyms_;     // slabel symbol table
141     bool accep_;                   // print as acceptor when possible
142     ostream *ostrm_;                // binary FST destination
143     string dest_;                  // binary FST destination name
144     DISALLOW_EVIL_CONSTRUCTORS(FstPrinter);
145   };
146 
#if 0
  // Main function for fstprint templated on the arc type.
  //
  // Reads an Fst from 'istrm' and prints it as text to the file argv[2]
  // when argc == 3, and to standard output otherwise.  Returns 0 on success
  // and 1 if the Fst cannot be read or the output file cannot be opened.
  // A symbol table that fails to load terminates the process via exit(1).
  template <class Arc>
    int PrintMain(int argc, char **argv, istream &istrm,
                  const FstReadOptions &opts) {
    Fst<Arc> *fst = Fst<Arc>::Read(istrm, opts);
    if (!fst) return 1;

    string dest = "standard output";
    ostream *ostrm = &std::cout;
    if (argc == 3) {
      dest = argv[2];
      ostrm = new ofstream(argv[2]);
      if (!*ostrm) {
        LOG(ERROR) << argv[0] << ": Open failed, file = " << argv[2];
        // Report the failure to the caller (the original returned 0 here)
        // and release what has been allocated so far.
        delete ostrm;
        delete fst;
        return 1;
      }
    }
    ostrm->precision(9);

    const SymbolTable *isyms = 0, *osyms = 0, *ssyms = 0;

    if (!FLAGS_isymbols.empty() && !FLAGS_numeric) {
      isyms = SymbolTable::ReadText(FLAGS_isymbols);
      if (!isyms) exit(1);
    }

    if (!FLAGS_osymbols.empty() && !FLAGS_numeric) {
      osyms = SymbolTable::ReadText(FLAGS_osymbols);
      if (!osyms) exit(1);
    }

    if (!FLAGS_ssymbols.empty() && !FLAGS_numeric) {
      ssyms = SymbolTable::ReadText(FLAGS_ssymbols);
      if (!ssyms) exit(1);
    }

    // Fall back to the symbol tables stored in the Fst itself.
    if (!isyms && !FLAGS_numeric)
      isyms = fst->InputSymbols();
    if (!osyms && !FLAGS_numeric)
      osyms = fst->OutputSymbols();

    FstPrinter<Arc> fstprinter(*fst, isyms, osyms, ssyms, FLAGS_acceptor);
    fstprinter.Print(ostrm, dest);

    if (isyms && !FLAGS_save_isymbols.empty())
      isyms->WriteText(FLAGS_save_isymbols);

    if (osyms && !FLAGS_save_osymbols.empty())
      osyms->WriteText(FLAGS_save_osymbols);

    if (ostrm != &std::cout)
      delete ostrm;
    // isyms/osyms may point into the Fst, so it is released only after the
    // last use of the symbol tables above.
    delete fst;
    return 0;
  }
#endif
203 
204 
205   template <class A> class FstReader {
206   public:
207     typedef A Arc;
208     typedef typename A::StateId StateId;
209     typedef typename A::Label Label;
210     typedef typename A::Weight Weight;
211 
FstReader(istream & istrm,const string & source,const SymbolTable * isyms,const SymbolTable * osyms,const SymbolTable * ssyms,bool accep,bool ikeep,bool okeep,bool nkeep)212     FstReader(istream &istrm, const string &source,
213 	      const SymbolTable *isyms, const SymbolTable *osyms,
214 	      const SymbolTable *ssyms, bool accep, bool ikeep,
215 	      bool okeep, bool nkeep)
216       : nline_(0), source_(source),
217       isyms_(isyms), osyms_(osyms), ssyms_(ssyms),
218       nstates_(0), keep_state_numbering_(nkeep) {
219       char line[kLineLen];
220       while (istrm.getline(line, kLineLen)) {
221 	++nline_;
222 	vector<char *> col;
223 	SplitToVector(line, "\n\t ", &col, true);
224 	if (col.size() == 0 || col[0][0] == '\0')  // empty line
225 	  continue;
226 	if (col.size() > 5 ||
227 	    (col.size() > 4 && accep) ||
228 	    (col.size() == 3 && !accep)) {
229 	  LOG(ERROR) << "FstReader: Bad number of columns, source = " << source_
230 		     << ", line = " << nline_;
231 	  exit(1);
232 	}
233 	StateId s = StrToStateId(col[0]);
234 	while (s >= fst_.NumStates())
235 	  fst_.AddState();
236 	if (nline_ == 1)
237 	  fst_.SetStart(s);
238 
239 	Arc arc;
240 	StateId d = s;
241 	switch (col.size()) {
242 	case 1:
243 	  fst_.SetFinal(s, Weight::One());
244 	  break;
245 	case 2:
246 	  fst_.SetFinal(s, StrToWeight(col[1], true));
247 	  break;
248 	case 3:
249 	  arc.nextstate = d = StrToStateId(col[1]);
250 	  arc.ilabel = StrToILabel(col[2]);
251 	  arc.olabel = arc.ilabel;
252 	  arc.weight = Weight::One();
253 	  fst_.AddArc(s, arc);
254 	  break;
255 	case 4:
256 	  arc.nextstate = d = StrToStateId(col[1]);
257 	  arc.ilabel = StrToILabel(col[2]);
258 	  if (accep) {
259 	    arc.olabel = arc.ilabel;
260 	    arc.weight = StrToWeight(col[3], false);
261 	  } else {
262 	    arc.olabel = StrToOLabel(col[3]);
263 	    arc.weight = Weight::One();
264 	  }
265 	  fst_.AddArc(s, arc);
266 	  break;
267 	case 5:
268 	  arc.nextstate = d = StrToStateId(col[1]);
269 	  arc.ilabel = StrToILabel(col[2]);
270 	  arc.olabel = StrToOLabel(col[3]);
271 	  arc.weight = StrToWeight(col[4], false);
272 	  fst_.AddArc(s, arc);
273 	}
274 	while (d >= fst_.NumStates())
275 	  fst_.AddState();
276       }
277       if (ikeep)
278 	fst_.SetInputSymbols(isyms);
279       if (okeep)
280 	fst_.SetOutputSymbols(osyms);
281     }
282 
Fst()283     const VectorFst<A> &Fst() const { return fst_; }
284 
285   private:
286     // Maximum line length in text file.
287     static const int kLineLen = 8096;
288 
StrToId(const char * s,const SymbolTable * syms,const char * name)289     int64 StrToId(const char *s, const SymbolTable *syms,
290 		  const char *name) const {
291       int64 n;
292 
293       if (syms) {
294 	n = syms->Find(s);
295 	if (n < 0) {
296 	  LOG(ERROR) << "FstReader: Symbol \"" << s
297 		     << "\" is not mapped to any integer " << name
298 		     << ", symbol table = " << syms->Name()
299 		     << ", source = " << source_ << ", line = " << nline_;
300 	  exit(1);
301 	}
302       } else {
303 	char *p;
304 	n = strtoll(s, &p, 10);
305 	if (p < s + strlen(s) || n < 0) {
306 	  LOG(ERROR) << "FstReader: Bad " << name << " integer = \"" << s
307 		     << "\", source = " << source_ << ", line = " << nline_;
308 	  exit(1);
309 	}
310       }
311       return n;
312     }
313 
StrToStateId(const char * s)314     StateId StrToStateId(const char *s) {
315       StateId n = StrToId(s, ssyms_, "state ID");
316 
317       if (keep_state_numbering_)
318 	return n;
319 
320       // remap state IDs to make dense set
321       typename std::unordered_map<StateId, StateId>::const_iterator it =
322          states_.find(n);
323       if (it == states_.end()) {
324 	states_[n] = nstates_;
325 	return nstates_++;
326       } else {
327 	return it->second;
328       }
329     }
330 
StrToILabel(const char * s)331     StateId StrToILabel(const char *s) const {
332       return StrToId(s, isyms_, "arc ilabel");
333     }
334 
StrToOLabel(const char * s)335     StateId StrToOLabel(const char *s) const {
336       return StrToId(s, osyms_, "arc olabel");
337     }
338 
StrToWeight(const char * s,bool allow_zero)339     Weight StrToWeight(const char *s, bool allow_zero) const {
340       Weight w;
341       istringstream strm(s);
342       strm >> w;
343       if (strm.fail() || (!allow_zero && w == Weight::Zero())) {
344 	LOG(ERROR) << "FstReader: Bad weight = \"" << s
345 		   << "\", source = " << source_ << ", line = " << nline_;
346 	exit(1);
347       }
348       return w;
349     }
350 
351     VectorFst<A> fst_;
352     size_t nline_;
353     string source_;                      // text FST source name
354     const SymbolTable *isyms_;           // ilabel symbol table
355     const SymbolTable *osyms_;           // olabel symbol table
356     const SymbolTable *ssyms_;           // slabel symbol table
357     std::unordered_map<StateId, StateId> states_;  // state ID map
358     StateId nstates_;                    // number of seen states
359     bool keep_state_numbering_;
360     DISALLOW_EVIL_CONSTRUCTORS(FstReader);
361   };
362 
363 #if 0
364   // Main function for fstcompile templated on the arc type.  Last two
365   // arguments unneeded since fstcompile passes the arc type as a flag
366   // unlike the other mains, which infer the arc type from an input Fst.
367   template <class Arc>
368     int CompileMain(int argc, char **argv, istream& /* strm */,
369 		    const FstReadOptions & /* opts */) {
370     char *ifilename = "standard input";
371     istream *istrm = &std::cin;
372     if (argc > 1 && strcmp(argv[1], "-") != 0) {
373       ifilename = argv[1];
374       istrm = new ifstream(ifilename);
375       if (!*istrm) {
376 	LOG(ERROR) << argv[0] << ": Open failed, file = " << ifilename;
377 	return 1;
378       }
379     }
380     const SymbolTable *isyms = 0, *osyms = 0, *ssyms = 0;
381 
382     if (!FLAGS_isymbols.empty()) {
383       isyms = SymbolTable::ReadText(FLAGS_isymbols);
384       if (!isyms) exit(1);
385     }
386 
387     if (!FLAGS_osymbols.empty()) {
388       osyms = SymbolTable::ReadText(FLAGS_osymbols);
389       if (!osyms) exit(1);
390     }
391 
392     if (!FLAGS_ssymbols.empty()) {
393       ssyms = SymbolTable::ReadText(FLAGS_ssymbols);
394       if (!ssyms) exit(1);
395     }
396 
397     FstReader<Arc> fstreader(*istrm, ifilename, isyms, osyms, ssyms,
398 			     FLAGS_acceptor, FLAGS_keep_isymbols,
399 			     FLAGS_keep_osymbols, FLAGS_keep_state_numbering);
400 
401     const Fst<Arc> *fst = &fstreader.Fst();
402     if (FLAGS_fst_type != "vector") {
403       fst = Convert<Arc>(*fst, FLAGS_fst_type);
404       if (!fst) return 1;
405     }
406     fst->Write(argc > 2 ? argv[2] : "");
407     if (istrm != &std::cin)
408       delete istrm;
409     return 0;
410   }
411 #endif
412 
413 }  // namespace fst
414 
415 #endif /* __FST_IO_H__ */
416 
417