• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // fst.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Finite-State Transducer (FST) - abstract base class definition,
20 // state and arc iterator interface, and suggested base implementation.
21 //
22 
23 #ifndef FST_LIB_FST_H__
24 #define FST_LIB_FST_H__
25 
26 #include <stddef.h>
27 #include <sys/types.h>
28 #include <cmath>
29 #include <string>
30 
31 #include <fst/compat.h>
32 #include <fst/types.h>
33 
34 #include <fst/arc.h>
35 #include <fst/properties.h>
36 #include <fst/register.h>
37 #include <iostream>
38 #include <fstream>
39 #include <fst/symbol-table.h>
40 #include <fst/util.h>
41 
42 
43 DECLARE_bool(fst_align);
44 
45 namespace fst {
46 
47 bool IsFstHeader(istream &, const string &);
48 
49 class FstHeader;
50 template <class A> class StateIteratorData;
51 template <class A> class ArcIteratorData;
52 template <class A> class MatcherBase;
53 
54 struct FstReadOptions {
55   string source;                // Where you're reading from
56   const FstHeader *header;      // Pointer to Fst header. If non-zero, use
57                                 // this info (don't read a stream header)
58   const SymbolTable* isymbols;  // Pointer to input symbols. If non-zero, use
59                                 // this info (read and skip stream isymbols)
60   const SymbolTable* osymbols;  // Pointer to output symbols. If non-zero, use
61                                 // this info (read and skip stream osymbols)
62 
63   explicit FstReadOptions(const string& src = "<unspecfied>",
64                           const FstHeader *hdr = 0,
65                           const SymbolTable* isym = 0,
66                           const SymbolTable* osym = 0)
sourceFstReadOptions67       : source(src), header(hdr), isymbols(isym), osymbols(osym) {}
68 
69   explicit FstReadOptions(const string& src,
70                           const SymbolTable* isym,
71                           const SymbolTable* osym = 0)
sourceFstReadOptions72       : source(src), header(0), isymbols(isym), osymbols(osym) {}
73 };
74 
75 
76 struct FstWriteOptions {
77   string source;                 // Where you're writing to
78   bool write_header;             // Write the header?
79   bool write_isymbols;           // Write input symbols?
80   bool write_osymbols;           // Write output symbols?
81   bool align;                    // Write data aligned where appropriate;
82                                  // this may fail on pipes
83 
84   explicit FstWriteOptions(const string& src = "<unspecifed>",
85                            bool hdr = true, bool isym = true,
86                            bool osym = true, bool alig = FLAGS_fst_align)
sourceFstWriteOptions87       : source(src), write_header(hdr),
88         write_isymbols(isym), write_osymbols(osym), align(alig) {}
89 };
90 
91 //
92 // Fst HEADER CLASS
93 //
94 // This is the recommended Fst file header representation.
95 //
96 class FstHeader {
97  public:
98   enum {
99     HAS_ISYMBOLS = 0x1,          // Has input symbol table
100     HAS_OSYMBOLS = 0x2,          // Has output symbol table
101     IS_ALIGNED   = 0x4,          // Memory-aligned (where appropriate)
102   } Flags;
103 
FstHeader()104   FstHeader() : version_(0), flags_(0), properties_(0), start_(-1),
105                 numstates_(0), numarcs_(0) {}
FstType()106   const string &FstType() const { return fsttype_; }
ArcType()107   const string &ArcType() const { return arctype_; }
Version()108   int32 Version() const { return version_; }
GetFlags()109   int32 GetFlags() const { return flags_; }
Properties()110   uint64 Properties() const { return properties_; }
Start()111   int64 Start() const { return start_; }
NumStates()112   int64 NumStates() const { return numstates_; }
NumArcs()113   int64 NumArcs() const { return numarcs_; }
114 
SetFstType(const string & type)115   void SetFstType(const string& type) { fsttype_ = type; }
SetArcType(const string & type)116   void SetArcType(const string& type) { arctype_ = type; }
SetVersion(int32 version)117   void SetVersion(int32 version) { version_ = version; }
SetFlags(int32 flags)118   void SetFlags(int32 flags) { flags_ = flags; }
SetProperties(uint64 properties)119   void SetProperties(uint64 properties) { properties_ = properties; }
SetStart(int64 start)120   void SetStart(int64 start) { start_ = start; }
SetNumStates(int64 numstates)121   void SetNumStates(int64 numstates) { numstates_ = numstates; }
SetNumArcs(int64 numarcs)122   void SetNumArcs(int64 numarcs) { numarcs_ = numarcs; }
123 
124   bool Read(istream &strm, const string &source, bool rewind = false);
125   bool Write(ostream &strm, const string &source) const;
126 
127  private:
128 
129   string fsttype_;                   // E.g. "vector"
130   string arctype_;                   // E.g. "standard"
131   int32 version_;                    // Type version #
132   int32 flags_;                      // File format bits
133   uint64 properties_;                // FST property bits
134   int64 start_;                      // Start state
135   int64 numstates_;                  // # of states
136   int64 numarcs_;                    // # of arcs
137 };
138 
139 
// Selects which arc label(s) a matcher compares against.
enum MatchType {
  MATCH_INPUT,    // Match input label.
  MATCH_OUTPUT,   // Match output label.
  MATCH_BOTH,     // Match input or output label.
  MATCH_NONE,     // Match nothing.
  MATCH_UNKNOWN   // Match type unknown.
};
146 
147 //
148 // Fst INTERFACE CLASS DEFINITION
149 //
150 
151 // A generic FST, templated on the arc definition, with
152 // common-demoninator methods (use StateIterator and ArcIterator to
153 // iterate over its states and arcs).
154 template <class A>
155 class Fst {
156  public:
157   typedef A Arc;
158   typedef typename A::Weight Weight;
159   typedef typename A::StateId StateId;
160 
~Fst()161   virtual ~Fst() {}
162 
163   virtual StateId Start() const = 0;          // Initial state
164 
165   virtual Weight Final(StateId) const = 0;    // State's final weight
166 
167   virtual size_t NumArcs(StateId) const = 0;  // State's arc count
168 
169   virtual size_t NumInputEpsilons(StateId)
170       const = 0;                              // State's input epsilon count
171 
172   virtual size_t NumOutputEpsilons(StateId)
173       const = 0;                              // State's output epsilon count
174 
175   // If test=false, return stored properties bits for mask (some poss. unknown)
176   // If test=true, return property bits for mask (computing o.w. unknown)
177   virtual uint64 Properties(uint64 mask, bool test)
178       const = 0;  // Property bits
179 
180   virtual const string& Type() const = 0;    // Fst type name
181 
182   // Get a copy of this Fst. The copying behaves as follows:
183   //
184   // (1) The copying is constant time if safe = false or if safe = true
185   // and is on an otherwise unaccessed Fst.
186   //
187   // (2) If safe = true, the copy is thread-safe in that the original
188   // and copy can be safely accessed (but not necessarily mutated) by
189   // separate threads. For some Fst types, 'Copy(true)' should only be
190   // called on an Fst that has not otherwise been accessed. Its behavior
191   // is undefined otherwise.
192   //
193   // (3) If a MutableFst is copied and then mutated, then the original is
194   // unmodified and vice versa (often by a copy-on-write on the initial
195   // mutation, which may not be constant time).
196   virtual Fst<A> *Copy(bool safe = false) const = 0;
197 
198   // Read an Fst from an input stream; returns NULL on error
Read(istream & strm,const FstReadOptions & opts)199   static Fst<A> *Read(istream &strm, const FstReadOptions &opts) {
200     FstReadOptions ropts(opts);
201     FstHeader hdr;
202     if (ropts.header)
203       hdr = *opts.header;
204     else {
205       if (!hdr.Read(strm, opts.source))
206         return 0;
207       ropts.header = &hdr;
208     }
209     FstRegister<A> *registr = FstRegister<A>::GetRegister();
210     const typename FstRegister<A>::Reader reader =
211       registr->GetReader(hdr.FstType());
212     if (!reader) {
213       LOG(ERROR) << "Fst::Read: Unknown FST type \"" << hdr.FstType()
214                  << "\" (arc type = \"" << A::Type()
215                  << "\"): " << ropts.source;
216       return 0;
217     }
218     return reader(strm, ropts);
219   };
220 
221   // Read an Fst from a file; return NULL on error
222   // Empty filename reads from standard input
Read(const string & filename)223   static Fst<A> *Read(const string &filename) {
224     if (!filename.empty()) {
225       ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
226       if (!strm) {
227         LOG(ERROR) << "Fst::Read: Can't open file: " << filename;
228         return 0;
229       }
230       return Read(strm, FstReadOptions(filename));
231     } else {
232       return Read(std::cin, FstReadOptions("standard input"));
233     }
234   }
235 
236   // Write an Fst to an output stream; return false on error
Write(ostream & strm,const FstWriteOptions & opts)237   virtual bool Write(ostream &strm, const FstWriteOptions &opts) const {
238     LOG(ERROR) << "Fst::Write: No write stream method for " << Type()
239                << " Fst type";
240     return false;
241   }
242 
243   // Write an Fst to a file; return false on error
244   // Empty filename writes to standard output
Write(const string & filename)245   virtual bool Write(const string &filename) const {
246     LOG(ERROR) << "Fst::Write: No write filename method for " << Type()
247                << " Fst type";
248     return false;
249   }
250 
251   // Return input label symbol table; return NULL if not specified
252   virtual const SymbolTable* InputSymbols() const = 0;
253 
254   // Return output label symbol table; return NULL if not specified
255   virtual const SymbolTable* OutputSymbols() const = 0;
256 
257   // For generic state iterator construction; not normally called
258   // directly by users.
259   virtual void InitStateIterator(StateIteratorData<A> *) const = 0;
260 
261   // For generic arc iterator construction; not normally called
262   // directly by users.
263   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *) const = 0;
264 
265   // For generic matcher construction; not normally called
266   // directly by users.
267   virtual MatcherBase<A> *InitMatcher(MatchType match_type) const;
268 
269  protected:
270 
WriteFile(const string & filename)271   bool WriteFile(const string &filename) const {
272     if (!filename.empty()) {
273       ofstream strm(filename.c_str(), ofstream::out | ofstream::binary);
274       if (!strm) {
275         LOG(ERROR) << "Fst::Write: Can't open file: " << filename;
276         return false;
277       }
278       return Write(strm, FstWriteOptions(filename));
279     } else {
280       return Write(std::cout, FstWriteOptions("standard output"));
281     }
282   }
283 };
284 
285 
286 //
287 // STATE and ARC ITERATOR DEFINITIONS
288 //
289 
// Interface for state iterators, templated on the Arc definition.
// Specialized StateIterators handed back by an Fst's InitStateIterator
// method derive from this class.
template <class A>
class StateIteratorBase {
 public:
  typedef A Arc;
  typedef typename A::StateId StateId;

  virtual ~StateIteratorBase() {}

  // End of iterator?
  bool Done() const {
    return Done_();
  }

  // Current state (when !Done).
  StateId Value() const {
    return Value_();
  }

  // Advance to next state (when !Done).
  void Next() {
    Next_();
  }

  // Return to initial condition.
  void Reset() {
    Reset_();
  }

 private:
  // The public methods forward to these virtuals, which lets the base
  // class reach non-virtual derived-class members of the same name.
  // That makes the derived class more efficient to use, but unsafe to
  // derive from further.
  virtual bool Done_() const = 0;
  virtual StateId Value_() const = 0;
  virtual void Next_() = 0;
  virtual void Reset_() = 0;
};
315 
316 
317 // StateIterator initialization data
318 
319 template <class A> struct StateIteratorData {
320   StateIteratorBase<A> *base;   // Specialized iterator if non-zero
321   typename A::StateId nstates;  // O.w. total # of states
322 };
323 
324 
325 // Generic state iterator, templated on the FST definition
326 // - a wrapper around pointer to specific one.
327 // Here is a typical use: \code
328 //   for (StateIterator<StdFst> siter(fst);
329 //        !siter.Done();
330 //        siter.Next()) {
331 //     StateId s = siter.Value();
332 //     ...
333 //   } \endcode
334 template <class F>
335 class StateIterator {
336  public:
337   typedef F FST;
338   typedef typename F::Arc Arc;
339   typedef typename Arc::StateId StateId;
340 
StateIterator(const F & fst)341   explicit StateIterator(const F &fst) : s_(0) {
342     fst.InitStateIterator(&data_);
343   }
344 
~StateIterator()345   ~StateIterator() { if (data_.base) delete data_.base; }
346 
Done()347   bool Done() const {
348     return data_.base ? data_.base->Done() : s_ >= data_.nstates;
349   }
350 
Value()351   StateId Value() const { return data_.base ? data_.base->Value() : s_; }
352 
Next()353   void Next() {
354     if (data_.base)
355       data_.base->Next();
356     else
357       ++s_;
358   }
359 
Reset()360   void Reset() {
361     if (data_.base)
362       data_.base->Reset();
363     else
364       s_ = 0;
365   }
366 
367  private:
368   StateIteratorData<Arc> data_;
369   StateId s_;
370 
371   DISALLOW_COPY_AND_ASSIGN(StateIterator);
372 };
373 
374 
375 // Flags to control the behavior on an arc iterator:
376 static const uint32 kArcILabelValue    = 0x0001;  // Value() gives valid ilabel
377 static const uint32 kArcOLabelValue    = 0x0002;  //  "       "     "    olabel
378 static const uint32 kArcWeightValue    = 0x0004;  //  "       "     "    weight
379 static const uint32 kArcNextStateValue = 0x0008;  //  "       "     " nextstate
380 static const uint32 kArcNoCache   = 0x0010;       // No need to cache arcs
381 
382 static const uint32 kArcValueFlags =
383                   kArcILabelValue | kArcOLabelValue |
384                   kArcWeightValue | kArcNextStateValue;
385 
386 static const uint32 kArcFlags = kArcValueFlags | kArcNoCache;
387 
388 
389 // Arc iterator interface, templated on the Arc definition; used
390 // for Arc iterator specializations that are returned by the InitArcIterator
391 // Fst method.
392 template <class A>
393 class ArcIteratorBase {
394  public:
395   typedef A Arc;
396   typedef typename A::StateId StateId;
397 
~ArcIteratorBase()398   virtual ~ArcIteratorBase() {}
399 
Done()400   bool Done() const { return Done_(); }            // End of iterator?
Value()401   const A& Value() const { return Value_(); }      // Current arc (when !Done)
Next()402   void Next() { Next_(); }           // Advance to next arc (when !Done)
Position()403   size_t Position() const { return Position_(); }  // Return current position
Reset()404   void Reset() { Reset_(); }         // Return to initial condition
Seek(size_t a)405   void Seek(size_t a) { Seek_(a); }  // Random arc access by position
Flags()406   uint32 Flags() const { return Flags_(); }  // Return current behavorial flags
SetFlags(uint32 flags,uint32 mask)407   void SetFlags(uint32 flags, uint32 mask) {  // Set behavorial flags
408     SetFlags_(flags, mask);
409   }
410 
411  private:
412   // This allows base class virtual access to non-virtual derived-
413   // class members of the same name. It makes the derived class more
414   // efficient to use but unsafe to further derive.
415   virtual bool Done_() const = 0;
416   virtual const A& Value_() const = 0;
417   virtual void Next_() = 0;
418   virtual size_t Position_() const = 0;
419   virtual void Reset_() = 0;
420   virtual void Seek_(size_t a) = 0;
421   virtual uint32 Flags_() const = 0;
422   virtual void SetFlags_(uint32 flags, uint32 mask) = 0;
423 };
424 
425 
426 // ArcIterator initialization data
427 template <class A> struct ArcIteratorData {
428   ArcIteratorBase<A> *base;  // Specialized iterator if non-zero
429   const A *arcs;             // O.w. arcs pointer
430   size_t narcs;              // ... and arc count
431   int *ref_count;            // ... and reference count if non-zero
432 };
433 
434 
435 // Generic arc iterator, templated on the FST definition
436 // - a wrapper around pointer to specific one.
437 // Here is a typical use: \code
438 //   for (ArcIterator<StdFst> aiter(fst, s));
439 //        !aiter.Done();
440 //         aiter.Next()) {
441 //     StdArc &arc = aiter.Value();
442 //     ...
443 //   } \endcode
444 template <class F>
445 class ArcIterator {
446    public:
447   typedef F FST;
448   typedef typename F::Arc Arc;
449   typedef typename Arc::StateId StateId;
450 
ArcIterator(const F & fst,StateId s)451   ArcIterator(const F &fst, StateId s) : i_(0) {
452     fst.InitArcIterator(s, &data_);
453   }
454 
ArcIterator(const ArcIteratorData<Arc> & data)455   explicit ArcIterator(const ArcIteratorData<Arc> &data) : data_(data), i_(0) {
456     if (data_.ref_count)
457       ++(*data_.ref_count);
458   }
459 
~ArcIterator()460   ~ArcIterator() {
461     if (data_.base)
462       delete data_.base;
463     else if (data_.ref_count)
464       --(*data_.ref_count);
465   }
466 
Done()467   bool Done() const {
468     return data_.base ?  data_.base->Done() : i_ >= data_.narcs;
469   }
470 
Value()471   const Arc& Value() const {
472     return data_.base ? data_.base->Value() : data_.arcs[i_];
473   }
474 
Next()475   void Next() {
476     if (data_.base)
477       data_.base->Next();
478     else
479       ++i_;
480   }
481 
Reset()482   void Reset() {
483     if (data_.base)
484       data_.base->Reset();
485     else
486       i_ = 0;
487   }
488 
Seek(size_t a)489   void Seek(size_t a) {
490     if (data_.base)
491       data_.base->Seek(a);
492     else
493       i_ = a;
494   }
495 
Position()496   size_t Position() const {
497     return data_.base ? data_.base->Position() : i_;
498   }
499 
Flags()500   uint32 Flags() const {
501     if (data_.base)
502       return data_.base->Flags();
503     else
504       return kArcValueFlags;
505   }
506 
SetFlags(uint32 flags,uint32 mask)507   void SetFlags(uint32 flags, uint32 mask) {
508     if (data_.base)
509       data_.base->SetFlags(flags, mask);
510   }
511 
512  private:
513   ArcIteratorData<Arc> data_;
514   size_t i_;
515   DISALLOW_COPY_AND_ASSIGN(ArcIterator);
516 };
517 
518 //
519 // MATCHER DEFINITIONS
520 //
521 
522 template <class A>
InitMatcher(MatchType match_type)523 MatcherBase<A> *Fst<A>::InitMatcher(MatchType match_type) const {
524   return 0;  // Use the default matcher
525 }
526 
527 
528 //
529 // FST ACCESSORS - Useful functions in high-performance cases.
530 //
531 
532 namespace internal {
533 
// General case - requires non-abstract, 'final' methods. Use for inlining.
// The call is qualified with F so that F's own Final is invoked directly.
template <class F>
inline typename F::Arc::Weight Final(const F &fst,
                                     typename F::Arc::StateId s) {
  return fst.F::Final(s);
}
539 
// General case: F's own NumArcs is invoked directly via a qualified call.
template <class F>
inline ssize_t NumArcs(const F &fst,
                       typename F::Arc::StateId s) {
  return fst.F::NumArcs(s);
}
544 
// General case: F's own NumInputEpsilons is invoked directly via a
// qualified call.
template <class F>
inline ssize_t NumInputEpsilons(const F &fst,
                                typename F::Arc::StateId s) {
  return fst.F::NumInputEpsilons(s);
}
549 
// General case: F's own NumOutputEpsilons is invoked directly via a
// qualified call.
template <class F>
inline ssize_t NumOutputEpsilons(const F &fst,
                                 typename F::Arc::StateId s) {
  return fst.F::NumOutputEpsilons(s);
}
554 
555 
556 //  Fst<A> case - abstract methods.
557 template <class A> inline
Final(const Fst<A> & fst,typename A::StateId s)558 typename A::Weight Final(const Fst<A> &fst, typename A::StateId s) {
559   return fst.Final(s);
560 }
561 
562 template <class A> inline
NumArcs(const Fst<A> & fst,typename A::StateId s)563 ssize_t NumArcs(const Fst<A> &fst, typename A::StateId s) {
564   return fst.NumArcs(s);
565 }
566 
567 template <class A> inline
NumInputEpsilons(const Fst<A> & fst,typename A::StateId s)568 ssize_t NumInputEpsilons(const Fst<A> &fst, typename A::StateId s) {
569   return fst.NumInputEpsilons(s);
570 }
571 
572 template <class A> inline
NumOutputEpsilons(const Fst<A> & fst,typename A::StateId s)573 ssize_t NumOutputEpsilons(const Fst<A> &fst, typename A::StateId s) {
574   return fst.NumOutputEpsilons(s);
575 }
576 
577 }  // namespace internal
578 
579 // A useful alias when using StdArc.
580 typedef Fst<StdArc> StdFst;
581 
582 
583 //
584 //  CONSTANT DEFINITIONS
585 //
586 
// Not a valid state ID.
const int kNoStateId = -1;
// Not a valid label.
const int kNoLabel = -1;
589 
590 //
591 // Fst IMPLEMENTATION BASE
592 //
593 // This is the recommended Fst implementation base class. It will
594 // handle reference counts, property bits, type information and symbols.
595 //
596 
597 template <class A> class FstImpl {
598  public:
599   typedef typename A::Weight Weight;
600   typedef typename A::StateId StateId;
601 
FstImpl()602   FstImpl()
603       : properties_(0), type_("null"), isymbols_(0), osymbols_(0) {}
604 
FstImpl(const FstImpl<A> & impl)605   FstImpl(const FstImpl<A> &impl)
606       : properties_(impl.properties_), type_(impl.type_),
607         isymbols_(impl.isymbols_ ? impl.isymbols_->Copy() : 0),
608         osymbols_(impl.osymbols_ ? impl.osymbols_->Copy() : 0) {}
609 
~FstImpl()610   virtual ~FstImpl() {
611     delete isymbols_;
612     delete osymbols_;
613   }
614 
Type()615   const string& Type() const { return type_; }
616 
SetType(const string & type)617   void SetType(const string &type) { type_ = type; }
618 
Properties()619   virtual uint64 Properties() const { return properties_; }
620 
Properties(uint64 mask)621   virtual uint64 Properties(uint64 mask) const { return properties_ & mask; }
622 
SetProperties(uint64 props)623   void SetProperties(uint64 props) {
624     properties_ &= kError;          // kError can't be cleared
625     properties_ |= props;
626   }
627 
SetProperties(uint64 props,uint64 mask)628   void SetProperties(uint64 props, uint64 mask) {
629     properties_ &= ~mask | kError;  // kError can't be cleared
630     properties_ |= props & mask;
631   }
632 
633   // Allows (only) setting error bit on const FST impls
SetProperties(uint64 props,uint64 mask)634   void SetProperties(uint64 props, uint64 mask) const {
635     if (mask != kError)
636       FSTERROR() << "FstImpl::SetProperties() const: can only set kError";
637     properties_ |= kError;
638   }
639 
InputSymbols()640   const SymbolTable* InputSymbols() const { return isymbols_; }
641 
OutputSymbols()642   const SymbolTable* OutputSymbols() const { return osymbols_; }
643 
InputSymbols()644   SymbolTable* InputSymbols() { return isymbols_; }
645 
OutputSymbols()646   SymbolTable* OutputSymbols() { return osymbols_; }
647 
SetInputSymbols(const SymbolTable * isyms)648   void SetInputSymbols(const SymbolTable* isyms) {
649     if (isymbols_) delete isymbols_;
650     isymbols_ = isyms ? isyms->Copy() : 0;
651   }
652 
SetOutputSymbols(const SymbolTable * osyms)653   void SetOutputSymbols(const SymbolTable* osyms) {
654     if (osymbols_) delete osymbols_;
655     osymbols_ = osyms ? osyms->Copy() : 0;
656   }
657 
RefCount()658   int RefCount() const {
659     return ref_count_.count();
660   }
661 
IncrRefCount()662   int IncrRefCount() {
663     return ref_count_.Incr();
664   }
665 
DecrRefCount()666   int DecrRefCount() {
667     return ref_count_.Decr();
668   }
669 
670   // Read-in header and symbols from input stream, initialize Fst, and
671   // return the header.  If opts.header is non-null, skip read-in and
672   // use the option value.  If opts.[io]symbols is non-null, read-in
673   // (if present), but use the option value.
674   bool ReadHeader(istream &strm, const FstReadOptions& opts,
675                   int min_version, FstHeader *hdr);
676 
677   // Write-out header and symbols from output stream.
678   // If a opts.header is false, skip writing header.
679   // If opts.[io]symbols is false, skip writing those symbols.
680   // This method is needed for Impl's that implement Write methods.
WriteHeader(ostream & strm,const FstWriteOptions & opts,int version,FstHeader * hdr)681   void WriteHeader(ostream &strm, const FstWriteOptions& opts,
682                    int version, FstHeader *hdr) const {
683     if (opts.write_header) {
684       hdr->SetFstType(type_);
685       hdr->SetArcType(A::Type());
686       hdr->SetVersion(version);
687       hdr->SetProperties(properties_);
688       int32 file_flags = 0;
689       if (isymbols_ && opts.write_isymbols)
690         file_flags |= FstHeader::HAS_ISYMBOLS;
691       if (osymbols_ && opts.write_osymbols)
692         file_flags |= FstHeader::HAS_OSYMBOLS;
693       if (opts.align)
694         file_flags |= FstHeader::IS_ALIGNED;
695       hdr->SetFlags(file_flags);
696       hdr->Write(strm, opts.source);
697     }
698     if (isymbols_ && opts.write_isymbols) isymbols_->Write(strm);
699     if (osymbols_ && opts.write_osymbols) osymbols_->Write(strm);
700   }
701 
702   // Write-out header and symbols to output stream.
703   // If a opts.header is false, skip writing header.
704   // If opts.[io]symbols is false, skip writing those symbols.
705   // type is the Fst type being written.
706   // This method is used in the cross-type serialization methods Fst::WriteFst.
WriteFstHeader(const Fst<A> & fst,ostream & strm,const FstWriteOptions & opts,int version,const string & type,FstHeader * hdr)707   static void WriteFstHeader(const Fst<A> &fst, ostream &strm,
708                              const FstWriteOptions& opts, int version,
709                              const string &type, FstHeader *hdr) {
710     if (opts.write_header) {
711       hdr->SetFstType(type);
712       hdr->SetArcType(A::Type());
713       hdr->SetVersion(version);
714       hdr->SetProperties(fst.Properties(kFstProperties, false));
715       int32 file_flags = 0;
716       if (fst.InputSymbols() && opts.write_isymbols)
717         file_flags |= FstHeader::HAS_ISYMBOLS;
718       if (fst.OutputSymbols() && opts.write_osymbols)
719         file_flags |= FstHeader::HAS_OSYMBOLS;
720       if (opts.align)
721         file_flags |= FstHeader::IS_ALIGNED;
722       hdr->SetFlags(file_flags);
723       hdr->Write(strm, opts.source);
724     }
725     if (fst.InputSymbols() && opts.write_isymbols) {
726       fst.InputSymbols()->Write(strm);
727     }
728     if (fst.OutputSymbols() && opts.write_osymbols) {
729       fst.OutputSymbols()->Write(strm);
730     }
731   }
732 
733   // In serialization routines where the header cannot be written until after
734   // the machine has been serialized, this routine can be called to seek to
735   // the beginning of the file an rewrite the header with updated fields.
736   // It repositions the file pointer back at the end of the file.
737   // returns true on success, false on failure.
UpdateFstHeader(const Fst<A> & fst,ostream & strm,const FstWriteOptions & opts,int version,const string & type,FstHeader * hdr,size_t header_offset)738   static bool UpdateFstHeader(const Fst<A> &fst, ostream &strm,
739                               const FstWriteOptions& opts, int version,
740                               const string &type, FstHeader *hdr,
741                               size_t header_offset) {
742     strm.seekp(header_offset);
743     if (!strm) {
744       LOG(ERROR) << "Fst::UpdateFstHeader: write failed: " << opts.source;
745       return false;
746     }
747     WriteFstHeader(fst, strm, opts, version, type, hdr);
748     if (!strm) {
749       LOG(ERROR) << "Fst::UpdateFstHeader: write failed: " << opts.source;
750       return false;
751     }
752     strm.seekp(0, ios_base::end);
753     if (!strm) {
754       LOG(ERROR) << "Fst::UpdateFstHeader: write failed: " << opts.source;
755       return false;
756     }
757     return true;
758   }
759 
760  protected:
761   mutable uint64 properties_;           // Property bits
762 
763  private:
764   string type_;                 // Unique name of Fst class
765   SymbolTable *isymbols_;       // Ilabel symbol table
766   SymbolTable *osymbols_;       // Olabel symbol table
767   RefCounter ref_count_;        // Reference count
768 
769   void operator=(const FstImpl<A> &impl);  // disallow
770 };
771 
772 template <class A> inline
ReadHeader(istream & strm,const FstReadOptions & opts,int min_version,FstHeader * hdr)773 bool FstImpl<A>::ReadHeader(istream &strm, const FstReadOptions& opts,
774                             int min_version, FstHeader *hdr) {
775   if (opts.header)
776     *hdr = *opts.header;
777   else if (!hdr->Read(strm, opts.source))
778     return false;
779 
780   if (FLAGS_v >= 2) {
781     LOG(INFO) << "FstImpl::ReadHeader: source: " << opts.source
782               << ", fst_type: " << hdr->FstType()
783               << ", arc_type: " << A::Type()
784               << ", version: " << hdr->Version()
785               << ", flags: " << hdr->GetFlags();
786   }
787 
788   if (hdr->FstType() != type_) {
789     LOG(ERROR) << "FstImpl::ReadHeader: Fst not of type \"" << type_
790                << "\": " << opts.source;
791     return false;
792   }
793   if (hdr->ArcType() != A::Type()) {
794     LOG(ERROR) << "FstImpl::ReadHeader: Arc not of type \"" << A::Type()
795                << "\": " << opts.source;
796     return false;
797   }
798   if (hdr->Version() < min_version) {
799     LOG(ERROR) << "FstImpl::ReadHeader: Obsolete " << type_
800                << " Fst version: " << opts.source;
801     return false;
802   }
803   properties_ = hdr->Properties();
804   if (hdr->GetFlags() & FstHeader::HAS_ISYMBOLS)
805     isymbols_ = SymbolTable::Read(strm, opts.source);
806   if (hdr->GetFlags() & FstHeader::HAS_OSYMBOLS)
807     osymbols_ =SymbolTable::Read(strm, opts.source);
808 
809   if (opts.isymbols) {
810     delete isymbols_;
811     isymbols_ = opts.isymbols->Copy();
812   }
813   if (opts.osymbols) {
814     delete osymbols_;
815     osymbols_ = opts.osymbols->Copy();
816   }
817   return true;
818 }
819 
820 
821 template<class Arc>
822 uint64 TestProperties(const Fst<Arc> &fst, uint64 mask, uint64 *known);
823 
824 
825 // This is a helper class template useful for attaching an Fst interface to
826 // its implementation, handling reference counting.
827 template < class I, class F = Fst<typename I::Arc> >
828 class ImplToFst : public F {
829  public:
830   typedef typename I::Arc Arc;
831   typedef typename Arc::Weight Weight;
832   typedef typename Arc::StateId StateId;
833 
~ImplToFst()834   virtual ~ImplToFst() { if (!impl_->DecrRefCount()) delete impl_;  }
835 
Start()836   virtual StateId Start() const { return impl_->Start(); }
837 
Final(StateId s)838   virtual Weight Final(StateId s) const { return impl_->Final(s); }
839 
NumArcs(StateId s)840   virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
841 
NumInputEpsilons(StateId s)842   virtual size_t NumInputEpsilons(StateId s) const {
843     return impl_->NumInputEpsilons(s);
844   }
845 
NumOutputEpsilons(StateId s)846   virtual size_t NumOutputEpsilons(StateId s) const {
847     return impl_->NumOutputEpsilons(s);
848   }
849 
Properties(uint64 mask,bool test)850   virtual uint64 Properties(uint64 mask, bool test) const {
851     if (test) {
852       uint64 knownprops, testprops = TestProperties(*this, mask, &knownprops);
853       impl_->SetProperties(testprops, knownprops);
854       return testprops & mask;
855     } else {
856       return impl_->Properties(mask);
857     }
858   }
859 
Type()860   virtual const string& Type() const { return impl_->Type(); }
861 
InputSymbols()862   virtual const SymbolTable* InputSymbols() const {
863     return impl_->InputSymbols();
864   }
865 
OutputSymbols()866   virtual const SymbolTable* OutputSymbols() const {
867     return impl_->OutputSymbols();
868   }
869 
870  protected:
ImplToFst()871   ImplToFst() : impl_(0) {}
872 
ImplToFst(I * impl)873   ImplToFst(I *impl) : impl_(impl) {}
874 
ImplToFst(const ImplToFst<I,F> & fst)875   ImplToFst(const ImplToFst<I, F> &fst) {
876     impl_ = fst.impl_;
877     impl_->IncrRefCount();
878   }
879 
880   // This constructor presumes there is a copy constructor for the
881   // implementation.
ImplToFst(const ImplToFst<I,F> & fst,bool safe)882   ImplToFst(const ImplToFst<I, F> &fst, bool safe) {
883     if (safe) {
884       impl_ = new I(*(fst.impl_));
885     } else {
886       impl_ = fst.impl_;
887       impl_->IncrRefCount();
888     }
889   }
890 
GetImpl()891   I *GetImpl() const { return impl_; }
892 
893   // Change Fst implementation pointer. If 'own_impl' is true,
894   // ownership of the input implementation is given to this
895   // object; otherwise, the input implementation's reference count
896   // should be incremented.
897   void SetImpl(I *impl, bool own_impl = true) {
898     if (!own_impl)
899       impl->IncrRefCount();
900     if (impl_ && !impl_->DecrRefCount()) delete impl_;
901     impl_ = impl;
902   }
903 
904  private:
905   // Disallow
906   ImplToFst<I, F> &operator=(const ImplToFst<I, F> &fst);
907 
908   ImplToFst<I, F> &operator=(const Fst<Arc> &fst) {
909     FSTERROR() << "ImplToFst: Assignment operator disallowed";
910     GetImpl()->SetProperties(kError, kError);
911     return *this;
912   }
913 
914   I *impl_;
915 };
916 
917 
// Converts between Fst types by reinterpreting the implementation
// pointer, where this makes sense (which excludes implementations with
// weight-dependent virtual methods). The implementation ends up shared
// by both Fsts. Must be a friend of the Fst classes involved (currently
// the concrete Fsts: VectorFst, ConstFst, CompactFst).
template <class F, class G>
void Cast(const F &ifst, G *ofst) {
  typedef typename G::Impl TargetImpl;
  TargetImpl *shared_impl = reinterpret_cast<TargetImpl *>(ifst.GetImpl());
  // own_impl = false: the reference count is incremented, not adopted.
  ofst->SetImpl(shared_impl, false);
}
925 
926 // Fst Serialization
927 template <class A>
FstToString(const Fst<A> & fst,string * result)928 void FstToString(const Fst<A> &fst, string *result) {
929   ostringstream ostrm;
930   fst.Write(ostrm, FstWriteOptions("FstToString"));
931   *result = ostrm.str();
932 }
933 
934 template <class A>
StringToFst(const string & s)935 Fst<A> *StringToFst(const string &s) {
936   istringstream istrm(s);
937   return Fst<A>::Read(istrm, FstReadOptions("StringToFst"));
938 }
939 
940 }  // namespace fst
941 
942 #endif  // FST_LIB_FST_H__
943