• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // replace.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 //
16 // \file
17 // Functions and classes for the recursive replacement of Fsts.
18 //
19 
20 #ifndef FST_LIB_REPLACE_H__
21 #define FST_LIB_REPLACE_H__
22 
23 #include <ext/hash_map>
24 using __gnu_cxx::hash_map;
25 
26 #include "fst/lib/fst.h"
27 #include "fst/lib/cache.h"
28 #include "fst/lib/test-properties.h"
29 
30 namespace fst {
31 
32 // By default ReplaceFst will copy the input label of the 'replace arc'.
33 // For acceptors we do not want this behaviour. Instead we need to
34 // create an epsilon arc when recursing into the appropriate Fst.
35 // The epsilon_on_replace option can be used to toggle this behaviour.
36 struct ReplaceFstOptions : CacheOptions {
37   int64 root;    // root rule for expansion
38   bool  epsilon_on_replace;
39 
ReplaceFstOptionsReplaceFstOptions40   ReplaceFstOptions(const CacheOptions &opts, int64 r)
41       : CacheOptions(opts), root(r), epsilon_on_replace(false) {}
ReplaceFstOptionsReplaceFstOptions42   explicit ReplaceFstOptions(int64 r)
43       : root(r), epsilon_on_replace(false) {}
ReplaceFstOptionsReplaceFstOptions44   ReplaceFstOptions(int64 r, bool epsilon_replace_arc)
45       : root(r), epsilon_on_replace(epsilon_replace_arc) {}
ReplaceFstOptionsReplaceFstOptions46   ReplaceFstOptions()
47       : root(kNoLabel), epsilon_on_replace(false) {}
48 };
49 
50 //
51 // \class ReplaceFstImpl
52 // \brief Implementation class for replace class Fst
53 //
54 // The replace implementation class supports a dynamic
55 // expansion of a recursive transition network represented as Fst
56 // with dynamic replacable arcs.
57 //
58 template <class A>
59 class ReplaceFstImpl : public CacheImpl<A> {
60  public:
61   using FstImpl<A>::SetType;
62   using FstImpl<A>::SetProperties;
63   using FstImpl<A>::Properties;
64   using FstImpl<A>::SetInputSymbols;
65   using FstImpl<A>::SetOutputSymbols;
66   using FstImpl<A>::InputSymbols;
67   using FstImpl<A>::OutputSymbols;
68 
69   using CacheImpl<A>::HasStart;
70   using CacheImpl<A>::HasArcs;
71   using CacheImpl<A>::SetStart;
72 
73   typedef typename A::Label   Label;
74   typedef typename A::Weight  Weight;
75   typedef typename A::StateId StateId;
76   typedef CacheState<A> State;
77   typedef A Arc;
78   typedef hash_map<Label, Label> NonTerminalHash;
79 
80 
81   // \struct StateTuple
82   // \brief Tuple of information that uniquely defines a state
83   struct StateTuple {
84     typedef int PrefixId;
85 
StateTupleStateTuple86     StateTuple() {}
StateTupleStateTuple87     StateTuple(PrefixId p, StateId f, StateId s) :
88         prefix_id(p), fst_id(f), fst_state(s) {}
89 
90     PrefixId prefix_id;  // index in prefix table
91     StateId fst_id;      // current fst being walked
92     StateId fst_state;   // current state in fst being walked, not to be
93                          // confused with the state_id of the combined fst
94   };
95 
96   // constructor for replace class implementation.
97   // \param fst_tuples array of label/fst tuples, one for each non-terminal
ReplaceFstImpl(const vector<pair<Label,const Fst<A> * >> & fst_tuples,const ReplaceFstOptions & opts)98   ReplaceFstImpl(const vector< pair<Label, const Fst<A>* > >& fst_tuples,
99                  const ReplaceFstOptions &opts)
100       : CacheImpl<A>(opts), opts_(opts) {
101     SetType("replace");
102     if (fst_tuples.size() > 0) {
103       SetInputSymbols(fst_tuples[0].second->InputSymbols());
104       SetOutputSymbols(fst_tuples[0].second->OutputSymbols());
105     }
106 
107     fst_array_.push_back(0);
108     for (size_t i = 0; i < fst_tuples.size(); ++i)
109       AddFst(fst_tuples[i].first, fst_tuples[i].second);
110 
111     SetRoot(opts.root);
112   }
113 
ReplaceFstImpl(const ReplaceFstOptions & opts)114   explicit ReplaceFstImpl(const ReplaceFstOptions &opts)
115       : CacheImpl<A>(opts), opts_(opts), root_(kNoLabel) {
116     fst_array_.push_back(0);
117   }
118 
ReplaceFstImpl(const ReplaceFstImpl & impl)119   ReplaceFstImpl(const ReplaceFstImpl& impl)
120       : opts_(impl.opts_), state_tuples_(impl.state_tuples_),
121         state_hash_(impl.state_hash_),
122         prefix_hash_(impl.prefix_hash_),
123         stackprefix_array_(impl.stackprefix_array_),
124         nonterminal_hash_(impl.nonterminal_hash_),
125         root_(impl.root_) {
126     SetType("replace");
127     SetProperties(impl.Properties(), kCopyProperties);
128     SetInputSymbols(InputSymbols());
129     SetOutputSymbols(OutputSymbols());
130     fst_array_.reserve(impl.fst_array_.size());
131     fst_array_.push_back(0);
132     for (size_t i = 1; i < impl.fst_array_.size(); ++i)
133       fst_array_.push_back(impl.fst_array_[i]->Copy());
134   }
135 
~ReplaceFstImpl()136   ~ReplaceFstImpl() {
137     for (size_t i = 1; i < fst_array_.size(); ++i) {
138       delete fst_array_[i];
139     }
140   }
141 
142   // Add to Fst array
AddFst(Label label,const Fst<A> * fst)143   void AddFst(Label label, const Fst<A>* fst) {
144     nonterminal_hash_[label] = fst_array_.size();
145     fst_array_.push_back(fst->Copy());
146     if (fst_array_.size() > 1) {
147       vector<uint64> inprops(fst_array_.size());
148 
149       for (size_t i = 1; i < fst_array_.size(); ++i) {
150         inprops[i] = fst_array_[i]->Properties(kCopyProperties, false);
151       }
152       SetProperties(ReplaceProperties(inprops));
153 
154       const SymbolTable* isymbols = fst_array_[1]->InputSymbols();
155       const SymbolTable* osymbols = fst_array_[1]->OutputSymbols();
156       for (size_t i = 2; i < fst_array_.size(); ++i) {
157         if (!CompatSymbols(isymbols, fst_array_[i]->InputSymbols())) {
158           LOG(FATAL) << "ReplaceFst::AddFst input symbols of Fst " << i-1
159                      << " does not match input symbols of base Fst (0'th fst)";
160         }
161         if (!CompatSymbols(osymbols, fst_array_[i]->OutputSymbols())) {
162           LOG(FATAL) << "ReplaceFst::AddFst output symbols of Fst " << i-1
163                      << " does not match output symbols of base Fst "
164                      << "(0'th fst)";
165         }
166       }
167     }
168   }
169 
170   // Computes the dependency graph of the replace class and returns
171   // true if the dependencies are cyclic. Cyclic dependencies will result
172   // in an un-expandable replace fst.
CyclicDependencies()173   bool CyclicDependencies() const {
174     StdVectorFst depfst;
175 
176     // one state for each fst
177     for (size_t i = 1; i < fst_array_.size(); ++i)
178       depfst.AddState();
179 
180     // an arc from each state (representing the fst) to the
181     // state representing the fst being replaced
182     for (size_t i = 1; i < fst_array_.size(); ++i) {
183       for (StateIterator<Fst<A> > siter(*(fst_array_[i]));
184            !siter.Done(); siter.Next()) {
185         for (ArcIterator<Fst<A> > aiter(*(fst_array_[i]), siter.Value());
186              !aiter.Done(); aiter.Next()) {
187           const A& arc = aiter.Value();
188 
189           typename NonTerminalHash::const_iterator it =
190               nonterminal_hash_.find(arc.olabel);
191           if (it != nonterminal_hash_.end()) {
192             Label j = it->second - 1;
193             depfst.AddArc(i - 1, A(arc.olabel, arc.olabel, Weight::One(), j));
194           }
195         }
196       }
197     }
198 
199     depfst.SetStart(root_ - 1);
200     depfst.SetFinal(root_ - 1, Weight::One());
201     return depfst.Properties(kCyclic, true);
202   }
203 
204   // set root rule for expansion
SetRoot(Label root)205   void SetRoot(Label root) {
206     Label nonterminal = nonterminal_hash_[root];
207     root_ = (nonterminal > 0) ? nonterminal : 1;
208   }
209 
210   // Change Fst array
SetFst(Label label,const Fst<A> * fst)211   void SetFst(Label label, const Fst<A>* fst) {
212     Label nonterminal = nonterminal_hash_[label];
213     delete fst_array_[nonterminal];
214     fst_array_[nonterminal] = fst->Copy();
215   }
216 
217   // Return or compute start state of replace fst
Start()218   StateId Start() {
219     if (!HasStart()) {
220       if (fst_array_.size() == 1) {      // no fsts defined for replace
221         SetStart(kNoStateId);
222         return kNoStateId;
223       } else {
224         const Fst<A>* fst = fst_array_[root_];
225         StateId fst_start = fst->Start();
226         if (fst_start == kNoStateId)  // root Fst is empty
227           return kNoStateId;
228 
229         int prefix = PrefixId(StackPrefix());
230         StateId start = FindState(StateTuple(prefix, root_, fst_start));
231         SetStart(start);
232         return start;
233       }
234     } else {
235       return CacheImpl<A>::Start();
236     }
237   }
238 
239   // return final weight of state (kInfWeight means state is not final)
Final(StateId s)240   Weight Final(StateId s) {
241     if (!HasFinal(s)) {
242       const StateTuple& tuple  = state_tuples_[s];
243       const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
244       const Fst<A>* fst = fst_array_[tuple.fst_id];
245       StateId fst_state = tuple.fst_state;
246 
247       if (fst->Final(fst_state) != Weight::Zero() && stack.Depth() == 0)
248         SetFinal(s, fst->Final(fst_state));
249       else
250         SetFinal(s, Weight::Zero());
251     }
252     return CacheImpl<A>::Final(s);
253   }
254 
NumArcs(StateId s)255   size_t NumArcs(StateId s) {
256     if (!HasArcs(s))
257       Expand(s);
258     return CacheImpl<A>::NumArcs(s);
259   }
260 
NumInputEpsilons(StateId s)261   size_t NumInputEpsilons(StateId s) {
262     if (!HasArcs(s))
263       Expand(s);
264     return CacheImpl<A>::NumInputEpsilons(s);
265   }
266 
NumOutputEpsilons(StateId s)267   size_t NumOutputEpsilons(StateId s) {
268     if (!HasArcs(s))
269       Expand(s);
270     return CacheImpl<A>::NumOutputEpsilons(s);
271   }
272 
273   // return the base arc iterator, if arcs have not been computed yet,
274   // extend/recurse for new arcs.
InitArcIterator(StateId s,ArcIteratorData<A> * data)275   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
276     if (!HasArcs(s))
277       Expand(s);
278     CacheImpl<A>::InitArcIterator(s, data);
279   }
280 
281   // Find/create an Fst state given a StateTuple.  Only create a new
282   // state if StateTuple is not found in the state hash.
FindState(const StateTuple & tuple)283   StateId FindState(const StateTuple& tuple) {
284     typename StateTupleHash::iterator it = state_hash_.find(tuple);
285     if (it == state_hash_.end()) {
286       StateId new_state_id = state_tuples_.size();
287       state_tuples_.push_back(tuple);
288       state_hash_[tuple] = new_state_id;
289       return new_state_id;
290     } else {
291       return it->second;
292     }
293   }
294 
295   // extend current state (walk arcs one level deep)
Expand(StateId s)296   void Expand(StateId s) {
297     StateTuple tuple  = state_tuples_[s];
298     const Fst<A>* fst = fst_array_[tuple.fst_id];
299     StateId fst_state = tuple.fst_state;
300     if (fst_state == kNoStateId) {
301       SetArcs(s);
302       return;
303     }
304 
305     // if state is final, pop up stack
306     const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
307     if (fst->Final(fst_state) != Weight::Zero() && stack.Depth()) {
308       int prefix_id = PopPrefix(stack);
309       const PrefixTuple& top = stack.Top();
310 
311       StateId nextstate =
312         FindState(StateTuple(prefix_id, top.fst_id, top.nextstate));
313       AddArc(s, A(0, 0, fst->Final(fst_state), nextstate));
314     }
315 
316     // extend arcs leaving the state
317     for (ArcIterator< Fst<A> > aiter(*fst, fst_state);
318          !aiter.Done(); aiter.Next()) {
319       const Arc& arc = aiter.Value();
320       if (arc.olabel == 0) {  // expand local fst
321         StateId nextstate =
322           FindState(StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate));
323         AddArc(s, A(arc.ilabel, arc.olabel, arc.weight, nextstate));
324       } else {
325         // check for non terminal
326         typename NonTerminalHash::const_iterator it =
327             nonterminal_hash_.find(arc.olabel);
328         if (it != nonterminal_hash_.end()) {  // recurse into non terminal
329           Label nonterminal = it->second;
330           const Fst<A>* nt_fst = fst_array_[nonterminal];
331           int nt_prefix = PushPrefix(stackprefix_array_[tuple.prefix_id],
332                                      tuple.fst_id, arc.nextstate);
333 
334           // if start state is valid replace, else arc is implicitly
335           // deleted
336           StateId nt_start = nt_fst->Start();
337           if (nt_start != kNoStateId) {
338             StateId nt_nextstate = FindState(
339                 StateTuple(nt_prefix, nonterminal, nt_start));
340             Label ilabel = (opts_.epsilon_on_replace) ? 0 : arc.ilabel;
341             AddArc(s, A(ilabel, 0, arc.weight, nt_nextstate));
342           }
343         } else {
344           StateId nextstate =
345             FindState(
346                 StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate));
347           AddArc(s, A(arc.ilabel, arc.olabel, arc.weight, nextstate));
348         }
349       }
350     }
351 
352     SetArcs(s);
353   }
354 
355 
356   // private helper classes
357  private:
358   static const int kPrime0 = 7853;
359   static const int kPrime1 = 7867;
360 
361   // \class StateTupleEqual
362   // \brief Compare two StateTuples for equality
363   class StateTupleEqual {
364    public:
operator()365     bool operator()(const StateTuple& x, const StateTuple& y) const {
366       return ((x.prefix_id == y.prefix_id) && (x.fst_id == y.fst_id) &&
367               (x.fst_state == y.fst_state));
368     }
369   };
370 
371   // \class StateTupleKey
372   // \brief Hash function for StateTuple to Fst states
373   class StateTupleKey {
374    public:
operator()375     size_t operator()(const StateTuple& x) const {
376       return static_cast<size_t>(x.prefix_id +
377                                  x.fst_id * kPrime0 +
378                                  x.fst_state * kPrime1);
379     }
380   };
381 
382   typedef hash_map<StateTuple, StateId, StateTupleKey, StateTupleEqual>
383   StateTupleHash;
384 
385   // \class PrefixTuple
386   // \brief Tuple of fst_id and destination state (entry in stack prefix)
387   struct PrefixTuple {
PrefixTuplePrefixTuple388     PrefixTuple(Label f, StateId s) : fst_id(f), nextstate(s) {}
389 
390     Label   fst_id;
391     StateId nextstate;
392   };
393 
394   // \class StackPrefix
395   // \brief Container for stack prefix.
396   class StackPrefix {
397    public:
StackPrefix()398     StackPrefix() {}
399 
400     // copy constructor
StackPrefix(const StackPrefix & x)401     StackPrefix(const StackPrefix& x) :
402         prefix_(x.prefix_) {
403     }
404 
Push(int fst_id,StateId nextstate)405     void Push(int fst_id, StateId nextstate) {
406       prefix_.push_back(PrefixTuple(fst_id, nextstate));
407     }
408 
Pop()409     void Pop() {
410       prefix_.pop_back();
411     }
412 
Top()413     const PrefixTuple& Top() const {
414       return prefix_[prefix_.size()-1];
415     }
416 
Depth()417     size_t Depth() const {
418       return prefix_.size();
419     }
420 
421    public:
422     vector<PrefixTuple> prefix_;
423   };
424 
425 
426   // \class StackPrefixEqual
427   // \brief Compare two stack prefix classes for equality
428   class StackPrefixEqual {
429    public:
operator()430     bool operator()(const StackPrefix& x, const StackPrefix& y) const {
431       if (x.prefix_.size() != y.prefix_.size()) return false;
432       for (size_t i = 0; i < x.prefix_.size(); ++i) {
433         if (x.prefix_[i].fst_id    != y.prefix_[i].fst_id ||
434            x.prefix_[i].nextstate != y.prefix_[i].nextstate) return false;
435       }
436       return true;
437     }
438   };
439 
440   //
441   // \class StackPrefixKey
442   // \brief Hash function for stack prefix to prefix id
443   class StackPrefixKey {
444    public:
operator()445     size_t operator()(const StackPrefix& x) const {
446       int sum = 0;
447       for (size_t i = 0; i < x.prefix_.size(); ++i) {
448         sum += x.prefix_[i].fst_id + x.prefix_[i].nextstate*kPrime0;
449       }
450       return (size_t) sum;
451     }
452   };
453 
454   typedef hash_map<StackPrefix, int, StackPrefixKey, StackPrefixEqual>
455   StackPrefixHash;
456 
457   // private methods
458  private:
459   // hash stack prefix (return unique index into stackprefix array)
PrefixId(const StackPrefix & prefix)460   int PrefixId(const StackPrefix& prefix) {
461     typename StackPrefixHash::iterator it = prefix_hash_.find(prefix);
462     if (it == prefix_hash_.end()) {
463       int prefix_id = stackprefix_array_.size();
464       stackprefix_array_.push_back(prefix);
465       prefix_hash_[prefix] = prefix_id;
466       return prefix_id;
467     } else {
468       return it->second;
469     }
470   }
471 
472   // prefix id after a stack pop
PopPrefix(StackPrefix prefix)473   int PopPrefix(StackPrefix prefix) {
474     prefix.Pop();
475     return PrefixId(prefix);
476   }
477 
478   // prefix id after a stack push
PushPrefix(StackPrefix prefix,Label fst_id,StateId nextstate)479   int PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
480     prefix.Push(fst_id, nextstate);
481     return PrefixId(prefix);
482   }
483 
484 
485   // private data
486  private:
487   // runtime options
488   ReplaceFstOptions opts_;
489 
490   // maps from StateId to StateTuple
491   vector<StateTuple> state_tuples_;
492 
493   // hashes from StateTuple to StateId
494   StateTupleHash state_hash_;
495 
496   // cross index of unique stack prefix
497   // could potentially have one copy of prefix array
498   StackPrefixHash prefix_hash_;
499   vector<StackPrefix> stackprefix_array_;
500 
501   NonTerminalHash nonterminal_hash_;
502   vector<const Fst<A>*> fst_array_;
503 
504   Label root_;
505 
506   void operator=(const ReplaceFstImpl<A> &);  // disallow
507 };
508 
509 
510 //
511 // \class ReplaceFst
512 // \brief Recursivively replaces arcs in the root Fst with other Fsts.
513 // This version is a delayed Fst.
514 //
515 // ReplaceFst supports dynamic replacement of arcs in one Fst with
516 // another Fst. This replacement is recursive.  ReplaceFst can be used
517 // to support a variety of delayed constructions such as recursive
518 // transition networks, union, or closure.  It is constructed with an
519 // array of Fst(s). One Fst represents the root (or topology)
520 // machine. The root Fst refers to other Fsts by recursively replacing
521 // arcs labeled as non-terminals with the matching non-terminal
522 // Fst. Currently the ReplaceFst uses the output symbols of the arcs
523 // to determine whether the arc is a non-terminal arc or not. A
524 // non-terminal can be any label that is not a non-zero terminal label
525 // in the output alphabet.
526 //
527 // Note that the constructor uses a vector of pair<>. These correspond
528 // to the tuple of non-terminal Label and corresponding Fst. For example
529 // to implement the closure operation we need 2 Fsts. The first root
530 // Fst is a single Arc on the start State that self loops, it references
531 // the particular machine for which we are performing the closure operation.
532 //
533 template <class A>
534 class ReplaceFst : public Fst<A> {
535  public:
536   friend class ArcIterator< ReplaceFst<A> >;
537   friend class CacheStateIterator< ReplaceFst<A> >;
538   friend class CacheArcIterator< ReplaceFst<A> >;
539 
540   typedef A Arc;
541   typedef typename A::Label   Label;
542   typedef typename A::Weight  Weight;
543   typedef typename A::StateId StateId;
544   typedef CacheState<A> State;
545 
ReplaceFst(const vector<pair<Label,const Fst<A> * >> & fst_array,Label root)546   ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
547              Label root)
548       : impl_(new ReplaceFstImpl<A>(fst_array, ReplaceFstOptions(root))) {}
549 
ReplaceFst(const vector<pair<Label,const Fst<A> * >> & fst_array,const ReplaceFstOptions & opts)550   ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
551              const ReplaceFstOptions &opts)
552       : impl_(new ReplaceFstImpl<A>(fst_array, opts)) {}
553 
ReplaceFst(const ReplaceFst<A> & fst)554   ReplaceFst(const ReplaceFst<A>& fst) :
555       impl_(new ReplaceFstImpl<A>(*(fst.impl_))) {}
556 
~ReplaceFst()557   virtual ~ReplaceFst() {
558     delete impl_;
559   }
560 
Start()561   virtual StateId Start() const {
562     return impl_->Start();
563   }
564 
Final(StateId s)565   virtual Weight Final(StateId s) const {
566     return impl_->Final(s);
567   }
568 
NumArcs(StateId s)569   virtual size_t NumArcs(StateId s) const {
570     return impl_->NumArcs(s);
571   }
572 
NumInputEpsilons(StateId s)573   virtual size_t NumInputEpsilons(StateId s) const {
574     return impl_->NumInputEpsilons(s);
575   }
576 
NumOutputEpsilons(StateId s)577   virtual size_t NumOutputEpsilons(StateId s) const {
578     return impl_->NumOutputEpsilons(s);
579   }
580 
Properties(uint64 mask,bool test)581   virtual uint64 Properties(uint64 mask, bool test) const {
582     if (test) {
583       uint64 known, test = TestProperties(*this, mask, &known);
584       impl_->SetProperties(test, known);
585       return test & mask;
586     } else {
587       return impl_->Properties(mask);
588     }
589   }
590 
Type()591   virtual const string& Type() const {
592     return impl_->Type();
593   }
594 
Copy()595   virtual ReplaceFst<A>* Copy() const {
596     return new ReplaceFst<A>(*this);
597   }
598 
InputSymbols()599   virtual const SymbolTable* InputSymbols() const {
600     return impl_->InputSymbols();
601   }
602 
OutputSymbols()603   virtual const SymbolTable* OutputSymbols() const {
604     return impl_->OutputSymbols();
605   }
606 
607   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
608 
InitArcIterator(StateId s,ArcIteratorData<A> * data)609   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
610     impl_->InitArcIterator(s, data);
611   }
612 
CyclicDependencies()613   bool CyclicDependencies() const {
614     return impl_->CyclicDependencies();
615   }
616 
617  private:
618   ReplaceFstImpl<A>* impl_;
619 };
620 
621 
622 // Specialization for ReplaceFst.
623 template<class A>
624 class StateIterator< ReplaceFst<A> >
625     : public CacheStateIterator< ReplaceFst<A> > {
626  public:
StateIterator(const ReplaceFst<A> & fst)627   explicit StateIterator(const ReplaceFst<A> &fst)
628       : CacheStateIterator< ReplaceFst<A> >(fst) {}
629 
630  private:
631   DISALLOW_EVIL_CONSTRUCTORS(StateIterator);
632 };
633 
634 // Specialization for ReplaceFst.
635 template <class A>
636 class ArcIterator< ReplaceFst<A> >
637     : public CacheArcIterator< ReplaceFst<A> > {
638  public:
639   typedef typename A::StateId StateId;
640 
ArcIterator(const ReplaceFst<A> & fst,StateId s)641   ArcIterator(const ReplaceFst<A> &fst, StateId s)
642       : CacheArcIterator< ReplaceFst<A> >(fst, s) {
643     if (!fst.impl_->HasArcs(s))
644       fst.impl_->Expand(s);
645   }
646 
647  private:
648   DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
649 };
650 
651 template <class A> inline
InitStateIterator(StateIteratorData<A> * data)652 void ReplaceFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
653   data->base = new StateIterator< ReplaceFst<A> >(*this);
654 }
655 
// Convenience alias: ReplaceFst instantiated for StdArc.
typedef ReplaceFst<StdArc> StdReplaceFst;
657 
658 
659 // // Recursivively replaces arcs in the root Fst with other Fsts.
660 // This version writes the result of replacement to an output MutableFst.
661 //
662 // Replace supports replacement of arcs in one Fst with another
663 // Fst. This replacement is recursive.  Replace takes an array of
664 // Fst(s). One Fst represents the root (or topology) machine. The root
665 // Fst refers to other Fsts by recursively replacing arcs labeled as
666 // non-terminals with the matching non-terminal Fst. Currently Replace
667 // uses the output symbols of the arcs to determine whether the arc is
668 // a non-terminal arc or not. A non-terminal can be any label that is
669 // not a non-zero terminal label in the output alphabet.  Note that
670 // input argument is a vector of pair<>. These correspond to the tuple
671 // of non-terminal Label and corresponding Fst.
672 template<class Arc>
Replace(const vector<pair<typename Arc::Label,const Fst<Arc> * >> & ifst_array,MutableFst<Arc> * ofst,typename Arc::Label root)673 void Replace(const vector<pair<typename Arc::Label,
674              const Fst<Arc>* > >& ifst_array,
675              MutableFst<Arc> *ofst, typename Arc::Label root) {
676   ReplaceFstOptions opts(root);
677   opts.gc_limit = 0;  // Cache only the last state for fastest copy.
678   *ofst = ReplaceFst<Arc>(ifst_array, opts);
679 }
680 
681 }
682 
683 #endif  // FST_LIB_REPLACE_H__
684