• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // fst.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 //
16 // \file
17 // Finite-State Transducer (FST) - abstract base class definition,
18 // state and arc iterator interface, and suggested base implementation.
19 
20 #ifndef FST_LIB_FST_H__
21 #define FST_LIB_FST_H__
22 
23 #include "fst/lib/arc.h"
24 #include "fst/lib/compat.h"
25 #include "fst/lib/properties.h"
26 #include "fst/lib/register.h"
27 #include "fst/lib/symbol-table.h"
28 #include "fst/lib/util.h"
29 
30 namespace fst {
31 
32 class FstHeader;
33 template <class A> class StateIteratorData;
34 template <class A> class ArcIteratorData;
35 
36 struct FstReadOptions  {
37   string source;                // Where you're reading from
38   const FstHeader *header;      // Pointer to Fst header (if non-zero)
39   const SymbolTable* isymbols;  // Pointer to input symbols (if non-zero)
40   const SymbolTable* osymbols;  // Pointer to output symbols (if non-zero)
41 
42   explicit FstReadOptions(const string& src = "<unspecified>",
43                           const FstHeader *hdr = 0,
44                           const SymbolTable* isym = 0,
45                           const SymbolTable* osym = 0)
sourceFstReadOptions46       : source(src), header(hdr), isymbols(isym), osymbols(osym) {}
47 };
48 
49 
// Options passed to the Fst write routines.
struct FstWriteOptions {
  string source;                    // Where you're writing to
  bool write_header;                // Write the header?
  bool write_isymbols;              // Write input symbols?
  bool write_osymbols;              // Write output symbols?

  // By default everything (header and both symbol tables) is written.
  // The default source name is spelled the same way as in FstReadOptions
  // (it used to read "<unspecifed>").
  explicit FstWriteOptions(const string& src = "<unspecified>",
                           bool hdr = true, bool isym = true,
                           bool osym = true)
      : source(src), write_header(hdr),
        write_isymbols(isym),  write_osymbols(osym) {}
};
62 
63 //
64 // Fst HEADER CLASS
65 //
66 // This is the recommended Fst file header representation.
67 //
68 
69 class FstHeader {
70  public:
71   enum {
72     HAS_ISYMBOLS = 1,                           // Has input symbol table
73     HAS_OSYMBOLS = 2                            // Has output symbol table
74   } Flags;
75 
FstHeader()76   FstHeader() : version_(0), flags_(0), properties_(0), start_(-1),
77                 numstates_(0), numarcs_(0) {}
FstType()78   const string &FstType() const { return fsttype_; }
ArcType()79   const string &ArcType() const { return arctype_; }
Version()80   int32 Version() const { return version_; }
GetFlags()81   int32 GetFlags() const { return flags_; }
Properties()82   uint64 Properties() const { return properties_; }
Start()83   int64 Start() const { return start_; }
NumStates()84   int64 NumStates() const { return numstates_; }
NumArcs()85   int64 NumArcs() const { return numarcs_; }
86 
SetFstType(const string & type)87   void SetFstType(const string& type) { fsttype_ = type; }
SetArcType(const string & type)88   void SetArcType(const string& type) { arctype_ = type; }
SetVersion(int32 version)89   void SetVersion(int32 version) { version_ = version; }
SetFlags(int32 flags)90   void SetFlags(int32 flags) { flags_ = flags; }
SetProperties(uint64 properties)91   void SetProperties(uint64 properties) { properties_ = properties; }
SetStart(int64 start)92   void SetStart(int64 start) { start_ = start; }
SetNumStates(int64 numstates)93   void SetNumStates(int64 numstates) { numstates_ = numstates; }
SetNumArcs(int64 numarcs)94   void SetNumArcs(int64 numarcs) { numarcs_ = numarcs; }
95 
96   bool Read(istream &strm, const string &source);
97   bool Write(ostream &strm, const string &source) const;
98 
99  private:
100   string fsttype_;                   // E.g. "vector"
101   string arctype_;                   // E.g. "standard"
102   int32 version_;                    // Type version #
103   int32 flags_;                      // File format bits
104   uint64 properties_;                // FST property bits
105   int64 start_;                      // Start state
106   int64 numstates_;                  // # of states
107   int64 numarcs_;                    // # of arcs
108 };
109 
110 //
111 // Fst INTERFACE CLASS DEFINITION
112 //
113 
114 // A generic FST, templated on the arc definition, with
115 // common-demoninator methods (use StateIterator and ArcIterator to
116 // iterate over its states and arcs).
117 template <class A>
118 class Fst {
119  public:
120   typedef A Arc;
121   typedef typename A::Weight Weight;
122   typedef typename A::StateId StateId;
123 
~Fst()124   virtual ~Fst() {}
125 
126   virtual StateId Start() const = 0;          // Initial state
127 
128   virtual Weight Final(StateId) const = 0;    // State's final weight
129 
130   virtual size_t NumArcs(StateId) const = 0;  // State's arc count
131 
132   virtual size_t NumInputEpsilons(StateId)
133       const = 0;                              // State's input epsilon count
134 
135   virtual size_t NumOutputEpsilons(StateId)
136       const = 0;                              // State's output epsilon count
137 
138   // If test=false, return stored properties bits for mask (some poss. unknown)
139   // If test=true, return property bits for mask (computing o.w. unknown)
140   virtual uint64 Properties(uint64 mask, bool test)
141       const = 0;  // Property bits
142 
143   virtual const string& Type() const = 0;    // Fst type name
144 
145   // Get a copy of this Fst.
146   virtual Fst<A> *Copy() const = 0;
147   // Read an Fst from an input stream; returns NULL on error
148 
Read(istream & strm,const FstReadOptions & opts)149   static Fst<A> *Read(istream &strm, const FstReadOptions &opts) {
150     FstReadOptions ropts(opts);
151     FstHeader hdr;
152     if (ropts.header)
153       hdr = *opts.header;
154     else {
155       if (!hdr.Read(strm, opts.source))
156         return 0;
157       ropts.header = &hdr;
158     }
159     FstRegister<A> *registr = FstRegister<A>::GetRegister();
160     const typename FstRegister<A>::Reader reader =
161         registr->GetReader(hdr.FstType());
162     if (!reader) {
163       LOG(ERROR) << "Fst::Read: Unknown FST type \"" << hdr.FstType()
164                  << "\" (arc type = \"" << A::Type()
165                  << "\"): " << ropts.source;
166       return 0;
167     }
168     return reader(strm, ropts);
169   };
170 
171   // Read an Fst from a file; return NULL on error
Read(const string & filename)172   static Fst<A> *Read(const string &filename) {
173     ifstream strm(filename.c_str());
174     if (!strm) {
175       LOG(ERROR) << "Fst::Read: Can't open file: " << filename;
176       return 0;
177     }
178     return Read(strm, FstReadOptions(filename));
179   }
180 
181   // Write an Fst to an output stream; return false on error
Write(ostream & strm,const FstWriteOptions & opts)182   virtual bool Write(ostream &strm, const FstWriteOptions &opts) const {
183     LOG(ERROR) << "Fst::Write: No write method for " << Type() << " Fst type";
184     return false;
185   }
186 
187   // Write an Fst to a file; return false on error
Write(const string & filename)188   virtual bool Write(const string &filename) const {
189     LOG(ERROR) << "Fst::Write: No write method for "
190                << Type() << " Fst type: "
191                << (filename.empty() ? "standard output" : filename);
192     return false;
193   }
194 
195   // Return input label symbol table; return NULL if not specified
196   virtual const SymbolTable* InputSymbols() const = 0;
197 
198   // Return output label symbol table; return NULL if not specified
199   virtual const SymbolTable* OutputSymbols() const = 0;
200 
201   // For generic state iterator construction; not normally called
202   // directly by users.
203   virtual void InitStateIterator(StateIteratorData<A> *) const = 0;
204 
205   // For generic arc iterator construction; not normally called
206   // directly by users.
207   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *) const = 0;
208 };
209 
210 
211 //
212 // STATE and ARC ITERATOR DEFINITIONS
213 //
214 
// Abstract state-iterator interface, parameterized by the arc type.
// An Fst's InitStateIterator may hand back an object implementing this
// interface (through StateIteratorData::base) when it needs a
// specialized iterator.
template <class A>
class StateIteratorBase {
 public:
  typedef A Arc;
  typedef typename A::StateId StateId;

  virtual ~StateIteratorBase() {}

  // True once the iteration is exhausted.
  virtual bool Done() const = 0;

  // The state currently pointed at; call only while !Done().
  virtual StateId Value() const = 0;

  // Move on to the following state; call only while !Done().
  virtual void Next() = 0;

  // Put the iterator back into its initial condition.
  virtual void Reset() = 0;
};
229 
230 
231 // StateIterator initialization data
232 template <class A> struct StateIteratorData {
233   StateIteratorBase<A> *base;   // Specialized iterator if non-zero
234   typename A::StateId nstates;  // O.w. total # of states
235 };
236 
237 
238 // Generic state iterator, templated on the FST definition
239 // - a wrapper around pointer to specific one.
240 // Here is a typical use: \code
241 //   for (StateIterator<StdFst> siter(fst);
242 //        !siter.Done();
243 //        siter.Next()) {
244 //     StateId s = siter.Value();
245 //     ...
246 //   } \endcode
247 template <class F>
248 class StateIterator {
249  public:
250   typedef typename F::Arc Arc;
251   typedef typename Arc::StateId StateId;
252 
StateIterator(const F & fst)253   explicit StateIterator(const F &fst) : s_(0) {
254     fst.InitStateIterator(&data_);
255   }
256 
~StateIterator()257   ~StateIterator() { if (data_.base) delete data_.base; }
258 
Done()259   bool Done() const {
260     return data_.base ? data_.base->Done() : s_ >= data_.nstates;
261   }
262 
Value()263   StateId Value() const { return data_.base ? data_.base->Value() : s_; }
264 
Next()265   void Next() {
266     if (data_.base)
267       data_.base->Next();
268     else
269       ++s_;
270   }
271 
Reset()272   void Reset() {
273     if (data_.base)
274       data_.base->Reset();
275     else
276       s_ = 0;
277   }
278 
279  private:
280   StateIteratorData<Arc> data_;
281   StateId s_;
282   DISALLOW_EVIL_CONSTRUCTORS(StateIterator);
283 };
284 
285 
// Abstract arc-iterator interface, parameterized by the arc type.
// An Fst's InitArcIterator may hand back an object implementing this
// interface (through ArcIteratorData::base) when it needs a
// specialized iterator.
template <class A>
class ArcIteratorBase {
 public:
  typedef A Arc;
  typedef typename A::StateId StateId;

  virtual ~ArcIteratorBase() {}

  // True once the iteration is exhausted.
  virtual bool Done() const = 0;

  // The arc currently pointed at; call only while !Done().
  virtual const A& Value() const = 0;

  // Move on to the following arc; call only while !Done().
  virtual void Next() = 0;

  // Put the iterator back into its initial condition.
  virtual void Reset() = 0;

  // Jump directly to the arc at position 'a'.
  virtual void Seek(size_t a) = 0;
};
301 
302 
303 // ArcIterator initialization data
304 template <class A> struct ArcIteratorData {
305   ArcIteratorBase<A> *base;  // Specialized iterator if non-zero
306   const A *arcs;             // O.w. arcs pointer
307   size_t narcs;              // ... and arc count
308   int *ref_count;            // ... and reference count if non-zero
309 };
310 
311 
312 // Generic arc iterator, templated on the FST definition
313 // - a wrapper around pointer to specific one.
314 // Here is a typical use: \code
315 //   for (ArcIterator<StdFst> aiter(fst, s));
316 //        !aiter.Done();
317 //         aiter.Next()) {
318 //     StdArc &arc = aiter.Value();
319 //     ...
320 //   } \endcode
321 template <class F>
322 class ArcIterator {
323    public:
324   typedef typename F::Arc Arc;
325   typedef typename Arc::StateId StateId;
326 
ArcIterator(const F & fst,StateId s)327   ArcIterator(const F &fst, StateId s) : i_(0) {
328     fst.InitArcIterator(s, &data_);
329   }
330 
~ArcIterator()331   ~ArcIterator() {
332     if (data_.base)
333       delete data_.base;
334     else if (data_.ref_count)
335     --(*data_.ref_count);
336   }
337 
Done()338   bool Done() const {
339     return data_.base ?  data_.base->Done() : i_ >= data_.narcs;
340   }
341 
Value()342   const Arc& Value() const {
343     return data_.base ? data_.base->Value() : data_.arcs[i_];
344   }
345 
Next()346   void Next() {
347     if (data_.base)
348       data_.base->Next();
349     else
350       ++i_;
351   }
352 
Reset()353   void Reset() {
354     if (data_.base)
355       data_.base->Reset();
356     else
357       i_ = 0;
358   }
359 
Seek(size_t a)360   void Seek(size_t a) {
361     if (data_.base)
362       data_.base->Seek(a);
363     else
364       i_ = a;
365   }
366 
367  private:
368   ArcIteratorData<Arc> data_;
369   size_t i_;
370   DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
371 };
372 
373 
374 // A useful alias when using StdArc.
375 typedef Fst<StdArc> StdFst;
376 
377 
378 //
379 //  CONSTANT DEFINITIONS
380 //
381 
// Reserved negative values for state IDs and labels.
const int kNoStateId = -1;    // Not a valid state ID
const int kNoLabel = -1;      // Not a valid label
const int kPhiLabel = -2;     // Failure transition label
const int kRhoLabel = -3;     // Matches o.w. unmatched labels (lib. internal)
const int kSigmaLabel = -4;   // Matches all labels in alphabet.
387 
388 
389 //
390 // Fst IMPLEMENTATION BASE
391 //
392 // This is the recommended Fst implementation base class. It will
393 // handle reference counts, property bits, type information and symbols.
394 //
395 
396 template <class A> class FstImpl {
397  public:
398   typedef typename A::Weight Weight;
399   typedef typename A::StateId StateId;
400 
FstImpl()401   FstImpl()
402       : properties_(0), type_("null"), isymbols_(0), osymbols_(0),
403         ref_count_(1) {}
404 
FstImpl(const FstImpl<A> & impl)405   FstImpl(const FstImpl<A> &impl)
406       : properties_(impl.properties_), type_(impl.type_),
407         isymbols_(impl.isymbols_ ? new SymbolTable(impl.isymbols_) : 0),
408         osymbols_(impl.osymbols_ ? new SymbolTable(impl.osymbols_) : 0),
409         ref_count_(1) {}
410 
~FstImpl()411   ~FstImpl() {
412     delete isymbols_;
413     delete osymbols_;
414   }
415 
Type()416   const string& Type() const { return type_; }
417 
SetType(const string & type)418   void SetType(const string &type) { type_ = type; }
419 
Properties()420   uint64 Properties() const { return properties_; }
421 
Properties(uint64 mask)422   uint64 Properties(uint64 mask) const { return properties_ & mask; }
423 
SetProperties(uint64 props)424   void SetProperties(uint64 props) { properties_ = props; }
425 
SetProperties(uint64 props,uint64 mask)426   void SetProperties(uint64 props, uint64 mask) {
427     properties_ &= ~mask;
428     properties_ |= props & mask;
429   }
430 
InputSymbols()431   const SymbolTable* InputSymbols() const { return isymbols_; }
432 
OutputSymbols()433   const SymbolTable* OutputSymbols() const { return osymbols_; }
434 
InputSymbols()435   SymbolTable* InputSymbols() { return isymbols_; }
436 
OutputSymbols()437   SymbolTable* OutputSymbols() { return osymbols_; }
438 
SetInputSymbols(const SymbolTable * isyms)439   void SetInputSymbols(const SymbolTable* isyms) {
440     if (isymbols_) delete isymbols_;
441     isymbols_ = isyms ? isyms->Copy() : 0;
442   }
443 
SetOutputSymbols(const SymbolTable * osyms)444   void SetOutputSymbols(const SymbolTable* osyms) {
445     if (osymbols_) delete osymbols_;
446     osymbols_ = osyms ? osyms->Copy() : 0;
447   }
448 
RefCount()449   int RefCount() const { return ref_count_; }
450 
IncrRefCount()451   int IncrRefCount() { return ++ref_count_; }
452 
DecrRefCount()453   int DecrRefCount() { return --ref_count_; }
454 
455   // Read-in header and symbols, initialize Fst, and return the header.
456   // If opts.header is non-null, skip read-in and use the option value.
457   // If opts.[io]symbols is non-null, read-in but use the option value.
ReadHeaderAndSymbols(istream & strm,const FstReadOptions & opts,int min_version,FstHeader * hdr)458   bool ReadHeaderAndSymbols(istream &strm, const FstReadOptions& opts,
459                   int min_version, FstHeader *hdr) {
460     if (opts.header)
461       *hdr = *opts.header;
462     else if (!hdr->Read(strm, opts.source))
463       return false;
464     if (hdr->FstType() != type_) {
465       LOG(ERROR) << "FstImpl::ReadHeaderAndSymbols: Fst not of type \""
466                  << type_ << "\": " << opts.source;
467       return false;
468     }
469     if (hdr->ArcType() != A::Type()) {
470       LOG(ERROR) << "FstImpl::ReadHeaderAndSymbols: Arc not of type \""
471                  << A::Type()
472                  << "\": " << opts.source;
473       return false;
474     }
475     if (hdr->Version() < min_version) {
476       LOG(ERROR) << "FstImpl::ReadHeaderAndSymbols: Obsolete "
477                  << type_ << " Fst version: " << opts.source;
478       return false;
479     }
480     properties_ = hdr->Properties();
481     if (hdr->GetFlags() & FstHeader::HAS_ISYMBOLS)
482       isymbols_ = SymbolTable::Read(strm, opts.source);
483     if (hdr->GetFlags() & FstHeader::HAS_OSYMBOLS)
484       osymbols_ =SymbolTable::Read(strm, opts.source);
485 
486     if (opts.isymbols) {
487       delete isymbols_;
488       isymbols_ = opts.isymbols->Copy();
489     }
490     if (opts.osymbols) {
491       delete osymbols_;
492       osymbols_ = opts.osymbols->Copy();
493     }
494     return true;
495   }
496 
497   // Write-out header and symbols.
498   // If a opts.header is false, skip writing header.
499   // If opts.[io]symbols is false, skip writing those symbols.
WriteHeaderAndSymbols(ostream & strm,const FstWriteOptions & opts,int version,FstHeader * hdr)500   void WriteHeaderAndSymbols(ostream &strm, const FstWriteOptions& opts,
501                              int version, FstHeader *hdr) const {
502     if (opts.write_header) {
503       hdr->SetFstType(type_);
504       hdr->SetArcType(A::Type());
505       hdr->SetVersion(version);
506       hdr->SetProperties(properties_);
507       int32 file_flags = 0;
508       if (isymbols_ && opts.write_isymbols)
509         file_flags |= FstHeader::HAS_ISYMBOLS;
510       if (osymbols_ && opts.write_osymbols)
511         file_flags |= FstHeader::HAS_OSYMBOLS;
512       hdr->SetFlags(file_flags);
513       hdr->Write(strm, opts.source);
514     }
515     if (isymbols_ && opts.write_isymbols) isymbols_->Write(strm);
516     if (osymbols_ && opts.write_osymbols) osymbols_->Write(strm);
517   }
518 
519  protected:
520   uint64 properties_;           // Property bits
521 
522  private:
523   string type_;                 // Unique name of Fst class
524   SymbolTable *isymbols_;       // Ilabel symbol table
525   SymbolTable *osymbols_;       // Olabel symbol table
526   int ref_count_;               // Reference count
527 
528   void operator=(const FstImpl<A> &impl);  // disallow
529 };
530 
531 }  // namespace fst;
532 
533 #endif  // FST_LIB_FST_H__
534