• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // replace.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: johans@google.com (Johan Schalkwyk)
17 //
18 // \file
19 // Functions and classes for the recursive replacement of Fsts.
20 //
21 
22 #ifndef FST_LIB_REPLACE_H__
23 #define FST_LIB_REPLACE_H__
24 
25 #include <unordered_map>
26 using std::tr1::unordered_map;
27 using std::tr1::unordered_multimap;
28 #include <set>
29 #include <string>
30 #include <utility>
31 using std::pair; using std::make_pair;
32 #include <vector>
33 using std::vector;
34 
35 #include <fst/cache.h>
36 #include <fst/expanded-fst.h>
37 #include <fst/fst.h>
38 #include <fst/matcher.h>
39 #include <fst/replace-util.h>
40 #include <fst/state-table.h>
41 #include <fst/test-properties.h>
42 
43 namespace fst {
44 
45 //
46 // REPLACE STATE TUPLES AND TABLES
47 //
48 // The replace state table has the form
49 //
50 // template <class A, class P>
51 // class ReplaceStateTable {
52 //  public:
53 //   typedef A Arc;
54 //   typedef P PrefixId;
55 //   typedef typename A::StateId StateId;
56 //   typedef ReplaceStateTuple<StateId, PrefixId> StateTuple;
57 //   typedef typename A::Label Label;
58 //
59 //   // Required constructor
60 //   ReplaceStateTable(const vector<pair<Label, const Fst<A>*> > &fst_tuples,
61 //                     Label root);
62 //
63 //   // Required copy constructor that does not copy state
64 //   ReplaceStateTable(const ReplaceStateTable<A,P> &table);
65 //
66 //   // Lookup state ID by tuple. If it doesn't exist, then add it.
67 //   StateId FindState(const StateTuple &tuple);
68 //
69 //   // Lookup state tuple by ID.
70 //   const StateTuple &Tuple(StateId id) const;
71 // };
72 
73 
74 // \struct ReplaceStateTuple
75 // \brief Tuple of information that uniquely defines a state in replace
76 template <class S, class P>
77 struct ReplaceStateTuple {
78   typedef S StateId;
79   typedef P PrefixId;
80 
ReplaceStateTupleReplaceStateTuple81   ReplaceStateTuple()
82       : prefix_id(-1), fst_id(kNoStateId), fst_state(kNoStateId) {}
83 
ReplaceStateTupleReplaceStateTuple84   ReplaceStateTuple(PrefixId p, StateId f, StateId s)
85       : prefix_id(p), fst_id(f), fst_state(s) {}
86 
87   PrefixId prefix_id;  // index in prefix table
88   StateId fst_id;      // current fst being walked
89   StateId fst_state;   // current state in fst being walked, not to be
90                        // confused with the state_id of the combined fst
91 };
92 
93 
94 // Equality of replace state tuples.
95 template <class S, class P>
96 inline bool operator==(const ReplaceStateTuple<S, P>& x,
97                        const ReplaceStateTuple<S, P>& y) {
98   return x.prefix_id == y.prefix_id &&
99       x.fst_id == y.fst_id &&
100       x.fst_state == y.fst_state;
101 }
102 
103 
104 // \class ReplaceRootSelector
105 // Functor returning true for tuples corresponding to states in the root FST
106 template <class S, class P>
107 class ReplaceRootSelector {
108  public:
operator()109   bool operator()(const ReplaceStateTuple<S, P> &tuple) const {
110     return tuple.prefix_id == 0;
111   }
112 };
113 
114 
115 // \class ReplaceFingerprint
116 // Fingerprint for general replace state tuples.
117 template <class S, class P>
118 class ReplaceFingerprint {
119  public:
ReplaceFingerprint(const vector<uint64> * size_array)120   ReplaceFingerprint(const vector<uint64> *size_array)
121       : cumulative_size_array_(size_array) {}
122 
operator()123   uint64 operator()(const ReplaceStateTuple<S, P> &tuple) const {
124     return tuple.prefix_id * (cumulative_size_array_->back()) +
125         cumulative_size_array_->at(tuple.fst_id - 1) +
126         tuple.fst_state;
127   }
128 
129  private:
130   const vector<uint64> *cumulative_size_array_;
131 };
132 
133 
134 // \class ReplaceFstStateFingerprint
135 // Useful when the fst_state uniquely define the tuple.
136 template <class S, class P>
137 class ReplaceFstStateFingerprint {
138  public:
operator()139   uint64 operator()(const ReplaceStateTuple<S, P>& tuple) const {
140     return tuple.fst_state;
141   }
142 };
143 
144 
145 // \class ReplaceHash
146 // A generic hash function for replace state tuples.
147 template <typename S, typename P>
148 class ReplaceHash {
149  public:
operator()150   size_t operator()(const ReplaceStateTuple<S, P>& t) const {
151     return t.prefix_id + t.fst_id * kPrime0 + t.fst_state * kPrime1;
152   }
153  private:
154   static const size_t kPrime0;
155   static const size_t kPrime1;
156 };
157 
158 template <typename S, typename P>
159 const size_t ReplaceHash<S, P>::kPrime0 = 7853;
160 
161 template <typename S, typename P>
162 const size_t ReplaceHash<S, P>::kPrime1 = 7867;
163 
// Forward declaration: ReplaceFstImpl grants ReplaceFstMatcher friend access.
template <typename A, typename T> class ReplaceFstMatcher;
165 
166 
167 // \class VectorHashReplaceStateTable
168 // A two-level state table for replace.
169 // Warning: calls CountStates to compute the number of states of each
170 // component Fst.
171 template <class A, class P = ssize_t>
172 class VectorHashReplaceStateTable {
173  public:
174   typedef A Arc;
175   typedef typename A::StateId StateId;
176   typedef typename A::Label Label;
177   typedef P PrefixId;
178   typedef ReplaceStateTuple<StateId, P> StateTuple;
179   typedef VectorHashStateTable<ReplaceStateTuple<StateId, P>,
180                                ReplaceRootSelector<StateId, P>,
181                                ReplaceFstStateFingerprint<StateId, P>,
182                                ReplaceFingerprint<StateId, P> > StateTable;
183 
VectorHashReplaceStateTable(const vector<pair<Label,const Fst<A> * >> & fst_tuples,Label root)184   VectorHashReplaceStateTable(
185       const vector<pair<Label, const Fst<A>*> > &fst_tuples,
186       Label root) : root_size_(0) {
187     cumulative_size_array_.push_back(0);
188     for (size_t i = 0; i < fst_tuples.size(); ++i) {
189       if (fst_tuples[i].first == root) {
190         root_size_ = CountStates(*(fst_tuples[i].second));
191         cumulative_size_array_.push_back(cumulative_size_array_.back());
192       } else {
193         cumulative_size_array_.push_back(cumulative_size_array_.back() +
194                                          CountStates(*(fst_tuples[i].second)));
195       }
196     }
197     state_table_ = new StateTable(
198         new ReplaceRootSelector<StateId, P>,
199         new ReplaceFstStateFingerprint<StateId, P>,
200         new ReplaceFingerprint<StateId, P>(&cumulative_size_array_),
201         root_size_,
202         root_size_ + cumulative_size_array_.back());
203   }
204 
VectorHashReplaceStateTable(const VectorHashReplaceStateTable<A,P> & table)205   VectorHashReplaceStateTable(const VectorHashReplaceStateTable<A, P> &table)
206       : root_size_(table.root_size_),
207         cumulative_size_array_(table.cumulative_size_array_) {
208     state_table_ = new StateTable(
209         new ReplaceRootSelector<StateId, P>,
210         new ReplaceFstStateFingerprint<StateId, P>,
211         new ReplaceFingerprint<StateId, P>(&cumulative_size_array_),
212         root_size_,
213         root_size_ + cumulative_size_array_.back());
214   }
215 
~VectorHashReplaceStateTable()216   ~VectorHashReplaceStateTable() {
217     delete state_table_;
218   }
219 
FindState(const StateTuple & tuple)220   StateId FindState(const StateTuple &tuple) {
221     return state_table_->FindState(tuple);
222   }
223 
Tuple(StateId id)224   const StateTuple &Tuple(StateId id) const {
225     return state_table_->Tuple(id);
226   }
227 
228  private:
229   StateId root_size_;
230   vector<uint64> cumulative_size_array_;
231   StateTable *state_table_;
232 };
233 
234 
235 // \class DefaultReplaceStateTable
236 // Default replace state table
237 template <class A, class P = ssize_t>
238 class DefaultReplaceStateTable : public CompactHashStateTable<
239   ReplaceStateTuple<typename A::StateId, P>,
240   ReplaceHash<typename A::StateId, P> > {
241  public:
242   typedef A Arc;
243   typedef typename A::StateId StateId;
244   typedef typename A::Label Label;
245   typedef P PrefixId;
246   typedef ReplaceStateTuple<StateId, P> StateTuple;
247   typedef CompactHashStateTable<StateTuple,
248                                 ReplaceHash<StateId, PrefixId> > StateTable;
249 
250   using StateTable::FindState;
251   using StateTable::Tuple;
252 
DefaultReplaceStateTable(const vector<pair<Label,const Fst<A> * >> & fst_tuples,Label root)253   DefaultReplaceStateTable(
254       const vector<pair<Label, const Fst<A>*> > &fst_tuples,
255       Label root) {}
256 
DefaultReplaceStateTable(const DefaultReplaceStateTable<A,P> & table)257   DefaultReplaceStateTable(const DefaultReplaceStateTable<A, P> &table)
258       : StateTable() {}
259 };
260 
261 //
262 // REPLACE FST CLASS
263 //
264 
265 // By default ReplaceFst will copy the input label of the 'replace arc'.
266 // For acceptors we do not want this behaviour. Instead we need to
267 // create an epsilon arc when recursing into the appropriate Fst.
268 // The 'epsilon_on_replace' option can be used to toggle this behaviour.
269 template <class A, class T = DefaultReplaceStateTable<A> >
270 struct ReplaceFstOptions : CacheOptions {
271   int64 root;    // root rule for expansion
272   bool  epsilon_on_replace;
273   bool  take_ownership;  // take ownership of input Fst(s)
274   T*    state_table;
275 
ReplaceFstOptionsReplaceFstOptions276   ReplaceFstOptions(const CacheOptions &opts, int64 r)
277       : CacheOptions(opts),
278         root(r),
279         epsilon_on_replace(false),
280         take_ownership(false),
281         state_table(0) {}
ReplaceFstOptionsReplaceFstOptions282   explicit ReplaceFstOptions(int64 r)
283       : root(r),
284         epsilon_on_replace(false),
285         take_ownership(false),
286         state_table(0) {}
ReplaceFstOptionsReplaceFstOptions287   ReplaceFstOptions(int64 r, bool epsilon_replace_arc)
288       : root(r),
289         epsilon_on_replace(epsilon_replace_arc),
290         take_ownership(false),
291         state_table(0) {}
ReplaceFstOptionsReplaceFstOptions292   ReplaceFstOptions()
293       : root(kNoLabel),
294         epsilon_on_replace(false),
295         take_ownership(false),
296         state_table(0) {}
297 };
298 
299 
300 // \class ReplaceFstImpl
301 // \brief Implementation class for replace class Fst
302 //
303 // The replace implementation class supports a dynamic
304 // expansion of a recursive transition network represented as Fst
305 // with dynamically replaceable arcs.
306 //
307 template <class A, class T>
308 class ReplaceFstImpl : public CacheImpl<A> {
309   friend class ReplaceFstMatcher<A, T>;
310 
311  public:
312   using FstImpl<A>::SetType;
313   using FstImpl<A>::SetProperties;
314   using FstImpl<A>::WriteHeader;
315   using FstImpl<A>::SetInputSymbols;
316   using FstImpl<A>::SetOutputSymbols;
317   using FstImpl<A>::InputSymbols;
318   using FstImpl<A>::OutputSymbols;
319 
320   using CacheImpl<A>::PushArc;
321   using CacheImpl<A>::HasArcs;
322   using CacheImpl<A>::HasFinal;
323   using CacheImpl<A>::HasStart;
324   using CacheImpl<A>::SetArcs;
325   using CacheImpl<A>::SetFinal;
326   using CacheImpl<A>::SetStart;
327 
328   typedef typename A::Label   Label;
329   typedef typename A::Weight  Weight;
330   typedef typename A::StateId StateId;
331   typedef CacheState<A> State;
332   typedef A Arc;
333   typedef unordered_map<Label, Label> NonTerminalHash;
334 
335   typedef T StateTable;
336   typedef typename T::PrefixId PrefixId;
337   typedef ReplaceStateTuple<StateId, PrefixId> StateTuple;
338 
339   // constructor for replace class implementation.
340   // \param fst_tuples array of label/fst tuples, one for each non-terminal
ReplaceFstImpl(const vector<pair<Label,const Fst<A> * >> & fst_tuples,const ReplaceFstOptions<A,T> & opts)341   ReplaceFstImpl(const vector< pair<Label, const Fst<A>* > >& fst_tuples,
342                  const ReplaceFstOptions<A, T> &opts)
343       : CacheImpl<A>(opts),
344         epsilon_on_replace_(opts.epsilon_on_replace),
345         state_table_(opts.state_table ? opts.state_table :
346                      new StateTable(fst_tuples, opts.root)) {
347 
348     SetType("replace");
349 
350     if (fst_tuples.size() > 0) {
351       SetInputSymbols(fst_tuples[0].second->InputSymbols());
352       SetOutputSymbols(fst_tuples[0].second->OutputSymbols());
353     }
354 
355     bool all_negative = true;  // all nonterminals are negative?
356     bool dense_range = true;   // all nonterminals are positive
357                                // and form a dense range containing 1?
358     for (size_t i = 0; i < fst_tuples.size(); ++i) {
359       Label nonterminal = fst_tuples[i].first;
360       if (nonterminal >= 0)
361         all_negative = false;
362       if (nonterminal > fst_tuples.size() || nonterminal <= 0)
363         dense_range = false;
364     }
365 
366     vector<uint64> inprops;
367     bool all_ilabel_sorted = true;
368     bool all_olabel_sorted = true;
369     bool all_non_empty = true;
370     fst_array_.push_back(0);
371     for (size_t i = 0; i < fst_tuples.size(); ++i) {
372       Label label = fst_tuples[i].first;
373       const Fst<A> *fst = fst_tuples[i].second;
374       nonterminal_hash_[label] = fst_array_.size();
375       nonterminal_set_.insert(label);
376       fst_array_.push_back(opts.take_ownership ? fst : fst->Copy());
377       if (fst->Start() == kNoStateId)
378         all_non_empty = false;
379       if(!fst->Properties(kILabelSorted, false))
380         all_ilabel_sorted = false;
381       if(!fst->Properties(kOLabelSorted, false))
382         all_olabel_sorted = false;
383       inprops.push_back(fst->Properties(kCopyProperties, false));
384       if (i) {
385         if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) {
386           FSTERROR() << "ReplaceFstImpl: input symbols of Fst " << i
387                      << " does not match input symbols of base Fst (0'th fst)";
388           SetProperties(kError, kError);
389         }
390         if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) {
391           FSTERROR() << "ReplaceFstImpl: output symbols of Fst " << i
392                      << " does not match output symbols of base Fst "
393                      << "(0'th fst)";
394           SetProperties(kError, kError);
395         }
396       }
397     }
398     Label nonterminal = nonterminal_hash_[opts.root];
399     if ((nonterminal == 0) && (fst_array_.size() > 1)) {
400       FSTERROR() << "ReplaceFstImpl: no Fst corresponding to root label '"
401                  << opts.root << "' in the input tuple vector";
402       SetProperties(kError, kError);
403     }
404     root_ = (nonterminal > 0) ? nonterminal : 1;
405 
406     SetProperties(ReplaceProperties(inprops, root_ - 1, epsilon_on_replace_,
407                                     all_non_empty));
408     // We assume that all terminals are positive.  The resulting
409     // ReplaceFst is known to be kILabelSorted when all sub-FSTs are
410     // kILabelSorted and one of the 3 following conditions is satisfied:
411     //  1. 'epsilon_on_replace' is false, or
412     //  2. all non-terminals are negative, or
413     //  3. all non-terninals are positive and form a dense range containing 1.
414     if (all_ilabel_sorted &&
415         (!epsilon_on_replace_ || all_negative || dense_range))
416       SetProperties(kILabelSorted, kILabelSorted);
417     // Similarly, the resulting ReplaceFst is known to be
418     // kOLabelSorted when all sub-FSTs are kOLabelSorted and one of
419     // the 2 following conditions is satisfied:
420     //  1. all non-terminals are negative, or
421     //  2. all non-terninals are positive and form a dense range containing 1.
422     if (all_olabel_sorted && (all_negative || dense_range))
423       SetProperties(kOLabelSorted, kOLabelSorted);
424 
425     // Enable optional caching as long as sorted and all non empty.
426     if (Properties(kILabelSorted | kOLabelSorted) && all_non_empty)
427       always_cache_ = false;
428     else
429       always_cache_ = true;
430     VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = "
431             << (always_cache_ ? "true" : "false");
432   }
433 
ReplaceFstImpl(const ReplaceFstImpl & impl)434   ReplaceFstImpl(const ReplaceFstImpl& impl)
435       : CacheImpl<A>(impl),
436         epsilon_on_replace_(impl.epsilon_on_replace_),
437         always_cache_(impl.always_cache_),
438         state_table_(new StateTable(*(impl.state_table_))),
439         nonterminal_set_(impl.nonterminal_set_),
440         nonterminal_hash_(impl.nonterminal_hash_),
441         root_(impl.root_) {
442     SetType("replace");
443     SetProperties(impl.Properties(), kCopyProperties);
444     SetInputSymbols(impl.InputSymbols());
445     SetOutputSymbols(impl.OutputSymbols());
446     fst_array_.reserve(impl.fst_array_.size());
447     fst_array_.push_back(0);
448     for (size_t i = 1; i < impl.fst_array_.size(); ++i) {
449       fst_array_.push_back(impl.fst_array_[i]->Copy(true));
450     }
451   }
452 
// Destructor. Logs cache statistics at VLOG(2), then deletes the state table
  // (including one supplied through the options) and every component FST
  // from slot 1 on; slot 0 is the null placeholder.
  ~ReplaceFstImpl() {
    VLOG(2) << "~ReplaceFstImpl: gc = "
            << (CacheImpl<A>::GetCacheGc() ? "true" : "false")
            << ", gc_size = " << CacheImpl<A>::GetCacheSize()
            << ", gc_limit = " << CacheImpl<A>::GetCacheLimit();

    delete state_table_;
    const size_t num_slots = fst_array_.size();
    for (size_t k = 1; k < num_slots; ++k)
      delete fst_array_[k];
  }
464 
465   // Computes the dependency graph of the replace class and returns
466   // true if the dependencies are cyclic. Cyclic dependencies will result
467   // in an un-expandable replace fst.
CyclicDependencies()468   bool CyclicDependencies() const {
469     ReplaceUtil<A> replace_util(fst_array_, nonterminal_hash_, root_);
470     return replace_util.CyclicDependencies();
471   }
472 
473   // Return or compute start state of replace fst
Start()474   StateId Start() {
475     if (!HasStart()) {
476       if (fst_array_.size() == 1) {      // no fsts defined for replace
477         SetStart(kNoStateId);
478         return kNoStateId;
479       } else {
480         const Fst<A>* fst = fst_array_[root_];
481         StateId fst_start = fst->Start();
482         if (fst_start == kNoStateId)  // root Fst is empty
483           return kNoStateId;
484 
485         PrefixId prefix = GetPrefixId(StackPrefix());
486         StateId start = state_table_->FindState(
487             StateTuple(prefix, root_, fst_start));
488         SetStart(start);
489         return start;
490       }
491     } else {
492       return CacheImpl<A>::Start();
493     }
494   }
495 
496   // return final weight of state (kInfWeight means state is not final)
Final(StateId s)497   Weight Final(StateId s) {
498     if (!HasFinal(s)) {
499       const StateTuple& tuple  = state_table_->Tuple(s);
500       const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
501       const Fst<A>* fst = fst_array_[tuple.fst_id];
502       StateId fst_state = tuple.fst_state;
503 
504       if (fst->Final(fst_state) != Weight::Zero() && stack.Depth() == 0)
505         SetFinal(s, fst->Final(fst_state));
506       else
507         SetFinal(s, Weight::Zero());
508     }
509     return CacheImpl<A>::Final(s);
510   }
511 
NumArcs(StateId s)512   size_t NumArcs(StateId s) {
513     if (HasArcs(s)) {  // If state cached, use the cached value.
514       return CacheImpl<A>::NumArcs(s);
515     } else if (always_cache_) {  // If always caching, expand and cache state.
516       Expand(s);
517       return CacheImpl<A>::NumArcs(s);
518     } else {  // Otherwise compute the number of arcs without expanding.
519       StateTuple tuple  = state_table_->Tuple(s);
520       if (tuple.fst_state == kNoStateId)
521         return 0;
522 
523       const Fst<A>* fst = fst_array_[tuple.fst_id];
524       size_t num_arcs = fst->NumArcs(tuple.fst_state);
525       if (ComputeFinalArc(tuple, 0))
526         num_arcs++;
527 
528       return num_arcs;
529     }
530   }
531 
532   // Returns whether a given label is a non terminal
IsNonTerminal(Label l)533   bool IsNonTerminal(Label l) const {
534     // TODO(allauzen): be smarter and take advantage of
535     // all_dense or all_negative.
536     // Use also in ComputeArc, this would require changes to replace
537     // so that recursing into an empty fst lead to a non co-accessible
538     // state instead of deleting the arc as done currently.
539     // Current use correct, since i/olabel sorted iff all_non_empty.
540     typename NonTerminalHash::const_iterator it =
541         nonterminal_hash_.find(l);
542     return it != nonterminal_hash_.end();
543   }
544 
NumInputEpsilons(StateId s)545   size_t NumInputEpsilons(StateId s) {
546     if (HasArcs(s)) {
547       // If state cached, use the cached value.
548       return CacheImpl<A>::NumInputEpsilons(s);
549     } else if (always_cache_ || !Properties(kILabelSorted)) {
550       // If always caching or if the number of input epsilons is too expensive
551       // to compute without caching (i.e. not ilabel sorted),
552       // then expand and cache state.
553       Expand(s);
554       return CacheImpl<A>::NumInputEpsilons(s);
555     } else {
556       // Otherwise, compute the number of input epsilons without caching.
557       StateTuple tuple  = state_table_->Tuple(s);
558       if (tuple.fst_state == kNoStateId)
559         return 0;
560       const Fst<A>* fst = fst_array_[tuple.fst_id];
561       size_t num  = 0;
562       if (!epsilon_on_replace_) {
563         // If epsilon_on_replace is false, all input epsilon arcs
564         // are also input epsilons arcs in the underlying machine.
565         fst->NumInputEpsilons(tuple.fst_state);
566       } else {
567         // Otherwise, one need to consider that all non-terminal arcs
568         // in the underlying machine also become input epsilon arc.
569         ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state);
570         for (; !aiter.Done() &&
571                  ((aiter.Value().ilabel == 0) ||
572                   IsNonTerminal(aiter.Value().olabel));
573              aiter.Next())
574           ++num;
575       }
576       if (ComputeFinalArc(tuple, 0))
577         num++;
578       return num;
579     }
580   }
581 
NumOutputEpsilons(StateId s)582   size_t NumOutputEpsilons(StateId s) {
583     if (HasArcs(s)) {
584       // If state cached, use the cached value.
585       return CacheImpl<A>::NumOutputEpsilons(s);
586     } else if(always_cache_ || !Properties(kOLabelSorted)) {
587       // If always caching or if the number of output epsilons is too expensive
588       // to compute without caching (i.e. not olabel sorted),
589       // then expand and cache state.
590       Expand(s);
591       return CacheImpl<A>::NumOutputEpsilons(s);
592     } else {
593       // Otherwise, compute the number of output epsilons without caching.
594       StateTuple tuple  = state_table_->Tuple(s);
595       if (tuple.fst_state == kNoStateId)
596         return 0;
597       const Fst<A>* fst = fst_array_[tuple.fst_id];
598       size_t num  = 0;
599       ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state);
600       for (; !aiter.Done() &&
601                ((aiter.Value().olabel == 0) ||
602                 IsNonTerminal(aiter.Value().olabel));
603            aiter.Next())
604         ++num;
605       if (ComputeFinalArc(tuple, 0))
606         num++;
607       return num;
608     }
609   }
610 
Properties()611   uint64 Properties() const { return Properties(kFstProperties); }
612 
613   // Set error if found; return FST impl properties.
Properties(uint64 mask)614   uint64 Properties(uint64 mask) const {
615     if (mask & kError) {
616       for (size_t i = 1; i < fst_array_.size(); ++i) {
617         if (fst_array_[i]->Properties(kError, false))
618           SetProperties(kError, kError);
619       }
620     }
621     return FstImpl<Arc>::Properties(mask);
622   }
623 
624   // return the base arc iterator, if arcs have not been computed yet,
625   // extend/recurse for new arcs.
InitArcIterator(StateId s,ArcIteratorData<A> * data)626   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
627     if (!HasArcs(s))
628       Expand(s);
629     CacheImpl<A>::InitArcIterator(s, data);
630     // TODO(allauzen): Set behaviour of generic iterator
631     // Warning: ArcIterator<ReplaceFst<A> >::InitCache()
632     // relies on current behaviour.
633   }
634 
635 
636   // Extend current state (walk arcs one level deep)
Expand(StateId s)637   void Expand(StateId s) {
638     StateTuple tuple = state_table_->Tuple(s);
639 
640     // If local fst is empty
641     if (tuple.fst_state == kNoStateId) {
642       SetArcs(s);
643       return;
644     }
645 
646     ArcIterator< Fst<A> > aiter(
647         *(fst_array_[tuple.fst_id]), tuple.fst_state);
648     Arc arc;
649 
650     // Create a final arc when needed
651     if (ComputeFinalArc(tuple, &arc))
652       PushArc(s, arc);
653 
654     // Expand all arcs leaving the state
655     for (;!aiter.Done(); aiter.Next()) {
656       if (ComputeArc(tuple, aiter.Value(), &arc))
657         PushArc(s, arc);
658     }
659 
660     SetArcs(s);
661   }
662 
Expand(StateId s,const StateTuple & tuple,const ArcIteratorData<A> & data)663   void Expand(StateId s, const StateTuple &tuple,
664               const ArcIteratorData<A> &data) {
665      // If local fst is empty
666     if (tuple.fst_state == kNoStateId) {
667       SetArcs(s);
668       return;
669     }
670 
671     ArcIterator< Fst<A> > aiter(data);
672     Arc arc;
673 
674     // Create a final arc when needed
675     if (ComputeFinalArc(tuple, &arc))
676       AddArc(s, arc);
677 
678     // Expand all arcs leaving the state
679     for (; !aiter.Done(); aiter.Next()) {
680       if (ComputeArc(tuple, aiter.Value(), &arc))
681         AddArc(s, arc);
682     }
683 
684     SetArcs(s);
685   }
686 
687   // If arcp == 0, only returns if a final arc is required, does not
688   // actually compute it.
689   bool ComputeFinalArc(const StateTuple &tuple, A* arcp,
690                        uint32 flags = kArcValueFlags) {
691     const Fst<A>* fst = fst_array_[tuple.fst_id];
692     StateId fst_state = tuple.fst_state;
693     if (fst_state == kNoStateId)
694       return false;
695 
696    // if state is final, pop up stack
697     const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
698     if (fst->Final(fst_state) != Weight::Zero() && stack.Depth()) {
699       if (arcp) {
700         arcp->ilabel = 0;
701         arcp->olabel = 0;
702         if (flags & kArcNextStateValue) {
703           PrefixId prefix_id = PopPrefix(stack);
704           const PrefixTuple& top = stack.Top();
705           arcp->nextstate = state_table_->FindState(
706               StateTuple(prefix_id, top.fst_id, top.nextstate));
707         }
708         if (flags & kArcWeightValue)
709           arcp->weight = fst->Final(fst_state);
710       }
711       return true;
712     } else {
713       return false;
714     }
715   }
716 
717   // Compute the arc in the replace fst corresponding to a given
718   // in the underlying machine. Returns false if the underlying arc
719   // corresponds to no arc in the replace.
720   bool ComputeArc(const StateTuple &tuple, const A &arc, A* arcp,
721                   uint32 flags = kArcValueFlags) {
722     if (!epsilon_on_replace_ &&
723         (flags == (flags & (kArcILabelValue | kArcWeightValue)))) {
724       *arcp = arc;
725       return true;
726     }
727 
728     if (arc.olabel == 0) {  // expand local fst
729       StateId nextstate = flags & kArcNextStateValue
730           ? state_table_->FindState(
731               StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
732           : kNoStateId;
733       *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate);
734     } else {
735       // check for non terminal
736       typename NonTerminalHash::const_iterator it =
737           nonterminal_hash_.find(arc.olabel);
738       if (it != nonterminal_hash_.end()) {  // recurse into non terminal
739         Label nonterminal = it->second;
740         const Fst<A>* nt_fst = fst_array_[nonterminal];
741         PrefixId nt_prefix = PushPrefix(stackprefix_array_[tuple.prefix_id],
742                                         tuple.fst_id, arc.nextstate);
743 
744         // if start state is valid replace, else arc is implicitly
745         // deleted
746         StateId nt_start = nt_fst->Start();
747         if (nt_start != kNoStateId) {
748           StateId nt_nextstate =  flags & kArcNextStateValue
749               ? state_table_->FindState(
750                   StateTuple(nt_prefix, nonterminal, nt_start))
751               : kNoStateId;
752           Label ilabel = (epsilon_on_replace_) ? 0 : arc.ilabel;
753           *arcp = A(ilabel, 0, arc.weight, nt_nextstate);
754         } else {
755           return false;
756         }
757       } else {
758         StateId nextstate = flags & kArcNextStateValue
759             ? state_table_->FindState(
760                 StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
761             : kNoStateId;
762         *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate);
763       }
764     }
765     return true;
766   }
767 
768   // Returns the arc iterator flags supported by this Fst.
ArcIteratorFlags()769   uint32 ArcIteratorFlags() const {
770     uint32 flags = kArcValueFlags;
771     if (!always_cache_)
772       flags |= kArcNoCache;
773     return flags;
774   }
775 
GetStateTable()776   T* GetStateTable() const {
777     return state_table_;
778   }
779 
GetFst(Label fst_id)780   const Fst<A>* GetFst(Label fst_id) const {
781     return fst_array_[fst_id];
782   }
783 
EpsilonOnReplace()784   bool EpsilonOnReplace() const { return epsilon_on_replace_; }
785 
786   // private helper classes
787  private:
788   static const size_t kPrime0;
789 
790   // \class PrefixTuple
791   // \brief Tuple of fst_id and destination state (entry in stack prefix)
792   struct PrefixTuple {
PrefixTuplePrefixTuple793     PrefixTuple(Label f, StateId s) : fst_id(f), nextstate(s) {}
794 
795     Label   fst_id;
796     StateId nextstate;
797   };
798 
799   // \class StackPrefix
800   // \brief Container for stack prefix.
801   class StackPrefix {
802    public:
StackPrefix()803     StackPrefix() {}
804 
805     // copy constructor
StackPrefix(const StackPrefix & x)806     StackPrefix(const StackPrefix& x) :
807         prefix_(x.prefix_) {
808     }
809 
Push(StateId fst_id,StateId nextstate)810     void Push(StateId fst_id, StateId nextstate) {
811       prefix_.push_back(PrefixTuple(fst_id, nextstate));
812     }
813 
Pop()814     void Pop() {
815       prefix_.pop_back();
816     }
817 
Top()818     const PrefixTuple& Top() const {
819       return prefix_[prefix_.size()-1];
820     }
821 
Depth()822     size_t Depth() const {
823       return prefix_.size();
824     }
825 
826    public:
827     vector<PrefixTuple> prefix_;
828   };
829 
830 
831   // \class StackPrefixEqual
832   // \brief Compare two stack prefix classes for equality
833   class StackPrefixEqual {
834    public:
operator()835     bool operator()(const StackPrefix& x, const StackPrefix& y) const {
836       if (x.prefix_.size() != y.prefix_.size()) return false;
837       for (size_t i = 0; i < x.prefix_.size(); ++i) {
838         if (x.prefix_[i].fst_id    != y.prefix_[i].fst_id ||
839            x.prefix_[i].nextstate != y.prefix_[i].nextstate) return false;
840       }
841       return true;
842     }
843   };
844 
845   //
846   // \class StackPrefixKey
847   // \brief Hash function for stack prefix to prefix id
848   class StackPrefixKey {
849    public:
operator()850     size_t operator()(const StackPrefix& x) const {
851       size_t sum = 0;
852       for (size_t i = 0; i < x.prefix_.size(); ++i) {
853         sum += x.prefix_[i].fst_id + x.prefix_[i].nextstate*kPrime0;
854       }
855       return sum;
856     }
857   };
858 
859   typedef unordered_map<StackPrefix, PrefixId, StackPrefixKey, StackPrefixEqual>
860   StackPrefixHash;
861 
862   // private methods
863  private:
864   // hash stack prefix (return unique index into stackprefix array)
GetPrefixId(const StackPrefix & prefix)865   PrefixId GetPrefixId(const StackPrefix& prefix) {
866     typename StackPrefixHash::iterator it = prefix_hash_.find(prefix);
867     if (it == prefix_hash_.end()) {
868       PrefixId prefix_id = stackprefix_array_.size();
869       stackprefix_array_.push_back(prefix);
870       prefix_hash_[prefix] = prefix_id;
871       return prefix_id;
872     } else {
873       return it->second;
874     }
875   }
876 
877   // prefix id after a stack pop
PopPrefix(StackPrefix prefix)878   PrefixId PopPrefix(StackPrefix prefix) {
879     prefix.Pop();
880     return GetPrefixId(prefix);
881   }
882 
883   // prefix id after a stack push
PushPrefix(StackPrefix prefix,Label fst_id,StateId nextstate)884   PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
885     prefix.Push(fst_id, nextstate);
886     return GetPrefixId(prefix);
887   }
888 
889 
890   // private data
891  private:
892   // runtime options
893   bool epsilon_on_replace_;
894   bool always_cache_;  // Optionally caching arc iterator disabled when true
895 
896   // state table
897   StateTable *state_table_;
898 
899   // cross index of unique stack prefix
900   // could potentially have one copy of prefix array
901   StackPrefixHash prefix_hash_;
902   vector<StackPrefix> stackprefix_array_;
903 
904   set<Label> nonterminal_set_;
905   NonTerminalHash nonterminal_hash_;
906   vector<const Fst<A>*> fst_array_;
907   Label root_;
908 
909   void operator=(const ReplaceFstImpl<A, T> &);  // disallow
910 };
911 
912 
913 template <class A, class T>
914 const size_t ReplaceFstImpl<A, T>::kPrime0 = 7853;
915 
916 //
917 // \class ReplaceFst
918 // \brief Recursivively replaces arcs in the root Fst with other Fsts.
919 // This version is a delayed Fst.
920 //
921 // ReplaceFst supports dynamic replacement of arcs in one Fst with
922 // another Fst. This replacement is recursive.  ReplaceFst can be used
923 // to support a variety of delayed constructions such as recursive
924 // transition networks, union, or closure.  It is constructed with an
925 // array of Fst(s). One Fst represents the root (or topology)
926 // machine. The root Fst refers to other Fsts by recursively replacing
927 // arcs labeled as non-terminals with the matching non-terminal
928 // Fst. Currently the ReplaceFst uses the output symbols of the arcs
929 // to determine whether the arc is a non-terminal arc or not. A
930 // non-terminal can be any label that is not a non-zero terminal label
931 // in the output alphabet.
932 //
933 // Note that the constructor uses a vector of pair<>. These correspond
934 // to the tuple of non-terminal Label and corresponding Fst. For example
935 // to implement the closure operation we need 2 Fsts. The first root
936 // Fst is a single Arc on the start State that self loops, it references
937 // the particular machine for which we are performing the closure operation.
938 //
939 // The ReplaceFst class supports an optionally caching arc iterator:
940 //    ArcIterator< ReplaceFst<A> >
941 // The ReplaceFst need to be built such that it is known to be ilabel
942 // or olabel sorted (see usage below).
943 //
944 // Observe that Matcher<Fst<A> > will use the optionally caching arc
945 // iterator when available (Fst is ilabel sorted and matching on the
946 // input, or Fst is olabel sorted and matching on the output).
947 // In order to obtain the most efficient behaviour, it is recommended
948 // to set 'epsilon_on_replace' to false (this means constructing acceptors
949 // as transducers with epsilons on the input side of nonterminal arcs)
950 // and matching on the input side.
951 //
952 // This class attaches interface to implementation and handles
953 // reference counting, delegating most methods to ImplToFst.
954 template <class A, class T = DefaultReplaceStateTable<A> >
955 class ReplaceFst : public ImplToFst< ReplaceFstImpl<A, T> > {
956  public:
957   friend class ArcIterator< ReplaceFst<A, T> >;
958   friend class StateIterator< ReplaceFst<A, T> >;
959   friend class ReplaceFstMatcher<A, T>;
960 
961   typedef A Arc;
962   typedef typename A::Label   Label;
963   typedef typename A::Weight  Weight;
964   typedef typename A::StateId StateId;
965   typedef CacheState<A> State;
966   typedef ReplaceFstImpl<A, T> Impl;
967 
968   using ImplToFst<Impl>::Properties;
969 
ReplaceFst(const vector<pair<Label,const Fst<A> * >> & fst_array,Label root)970   ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
971              Label root)
972       : ImplToFst<Impl>(new Impl(fst_array, ReplaceFstOptions<A, T>(root))) {}
973 
ReplaceFst(const vector<pair<Label,const Fst<A> * >> & fst_array,const ReplaceFstOptions<A,T> & opts)974   ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
975              const ReplaceFstOptions<A, T> &opts)
976       : ImplToFst<Impl>(new Impl(fst_array, opts)) {}
977 
978   // See Fst<>::Copy() for doc.
979   ReplaceFst(const ReplaceFst<A, T>& fst, bool safe = false)
980       : ImplToFst<Impl>(fst, safe) {}
981 
982   // Get a copy of this ReplaceFst. See Fst<>::Copy() for further doc.
983   virtual ReplaceFst<A, T> *Copy(bool safe = false) const {
984     return new ReplaceFst<A, T>(*this, safe);
985   }
986 
987   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
988 
InitArcIterator(StateId s,ArcIteratorData<A> * data)989   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
990     GetImpl()->InitArcIterator(s, data);
991   }
992 
InitMatcher(MatchType match_type)993   virtual MatcherBase<A> *InitMatcher(MatchType match_type) const {
994     if ((GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
995         ((match_type == MATCH_INPUT && Properties(kILabelSorted, false)) ||
996          (match_type == MATCH_OUTPUT && Properties(kOLabelSorted, false)))) {
997       return new ReplaceFstMatcher<A, T>(*this, match_type);
998     }
999     else {
1000       VLOG(2) << "Not using replace matcher";
1001       return 0;
1002     }
1003   }
1004 
CyclicDependencies()1005   bool CyclicDependencies() const {
1006     return GetImpl()->CyclicDependencies();
1007   }
1008 
1009  private:
1010   // Makes visible to friends.
GetImpl()1011   Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
1012 
1013   void operator=(const ReplaceFst<A> &fst);  // disallow
1014 };
1015 
1016 
1017 // Specialization for ReplaceFst.
1018 template<class A, class T>
1019 class StateIterator< ReplaceFst<A, T> >
1020     : public CacheStateIterator< ReplaceFst<A, T> > {
1021  public:
StateIterator(const ReplaceFst<A,T> & fst)1022   explicit StateIterator(const ReplaceFst<A, T> &fst)
1023       : CacheStateIterator< ReplaceFst<A, T> >(fst, fst.GetImpl()) {}
1024 
1025  private:
1026   DISALLOW_COPY_AND_ASSIGN(StateIterator);
1027 };
1028 
1029 
1030 // Specialization for ReplaceFst.
1031 // Implements optional caching. It can be used as follows:
1032 //
1033 //   ReplaceFst<A> replace;
1034 //   ArcIterator< ReplaceFst<A> > aiter(replace, s);
1035 //   // Note: ArcIterator< Fst<A> > is always a caching arc iterator.
1036 //   aiter.SetFlags(kArcNoCache, kArcNoCache);
1037 //   // Use the arc iterator, no arc will be cached, no state will be expanded.
1038 //   // The varied 'kArcValueFlags' can be used to decide which part
1039 //   // of arc values needs to be computed.
1040 //   aiter.SetFlags(kArcILabelValue, kArcValueFlags);
1041 //   // Only want the ilabel for this arc
1042 //   aiter.Value();  // Does not compute the destination state.
1043 //   aiter.Next();
1044 //   aiter.SetFlags(kArcNextStateValue, kArcNextStateValue);
1045 //   // Want both ilabel and nextstate for that arc
1046 //   aiter.Value();  // Does compute the destination state and inserts it
1047 //                   // in the replace state table.
1048 //   // No Arc has been cached at that point.
1049 //
1050 template <class A, class T>
1051 class ArcIterator< ReplaceFst<A, T> > {
1052  public:
1053   typedef A Arc;
1054   typedef typename A::StateId StateId;
1055 
ArcIterator(const ReplaceFst<A,T> & fst,StateId s)1056   ArcIterator(const ReplaceFst<A, T> &fst, StateId s)
1057       : fst_(fst), state_(s), pos_(0), offset_(0), flags_(0), arcs_(0),
1058         data_flags_(0), final_flags_(0) {
1059     cache_data_.ref_count = 0;
1060     local_data_.ref_count = 0;
1061 
1062     // If FST does not support optional caching, force caching.
1063     if(!(fst_.GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
1064        !(fst_.GetImpl()->HasArcs(state_)))
1065        fst_.GetImpl()->Expand(state_);
1066 
1067     // If state is already cached, use cached arcs array.
1068     if (fst_.GetImpl()->HasArcs(state_)) {
1069       (fst_.GetImpl())->template CacheImpl<A>::InitArcIterator(state_,
1070                                                                &cache_data_);
1071       num_arcs_ = cache_data_.narcs;
1072       arcs_ = cache_data_.arcs;      // 'arcs_' is a ptr to the cached arcs.
1073       data_flags_ = kArcValueFlags;  // All the arc member values are valid.
1074     } else {  // Otherwise delay decision until Value() is called.
1075       tuple_ = fst_.GetImpl()->GetStateTable()->Tuple(state_);
1076       if (tuple_.fst_state == kNoStateId) {
1077         num_arcs_ = 0;
1078       } else {
1079         // The decision to cache or not to cache has been defered
1080         // until Value() or SetFlags() is called. However, the arc
1081         // iterator is set up now to be ready for non-caching in order
1082         // to keep the Value() method simple and efficient.
1083         const Fst<A>* fst = fst_.GetImpl()->GetFst(tuple_.fst_id);
1084         fst->InitArcIterator(tuple_.fst_state, &local_data_);
1085         // 'arcs_' is a pointer to the arcs in the underlying machine.
1086         arcs_ = local_data_.arcs;
1087         // Compute the final arc (but not its destination state)
1088         // if a final arc is required.
1089         bool has_final_arc = fst_.GetImpl()->ComputeFinalArc(
1090             tuple_,
1091             &final_arc_,
1092             kArcValueFlags & ~kArcNextStateValue);
1093         // Set the arc value flags that hold for 'final_arc_'.
1094         final_flags_ = kArcValueFlags & ~kArcNextStateValue;
1095         // Compute the number of arcs.
1096         num_arcs_ = local_data_.narcs;
1097         if (has_final_arc)
1098           ++num_arcs_;
1099         // Set the offset between the underlying arc positions and
1100         // the positions in the arc iterator.
1101         offset_ = num_arcs_ - local_data_.narcs;
1102         // Defers the decision to cache or not until Value() or
1103         // SetFlags() is called.
1104         data_flags_ = 0;
1105       }
1106     }
1107   }
1108 
~ArcIterator()1109   ~ArcIterator() {
1110     if (cache_data_.ref_count)
1111       --(*cache_data_.ref_count);
1112     if (local_data_.ref_count)
1113       --(*local_data_.ref_count);
1114   }
1115 
ExpandAndCache()1116   void ExpandAndCache() const   {
1117     // TODO(allauzen): revisit this
1118     // fst_.GetImpl()->Expand(state_, tuple_, local_data_);
1119     // (fst_.GetImpl())->CacheImpl<A>*>::InitArcIterator(state_,
1120     //                                               &cache_data_);
1121     //
1122     fst_.InitArcIterator(state_, &cache_data_);  // Expand and cache state.
1123     arcs_ = cache_data_.arcs;  // 'arcs_' is a pointer to the cached arcs.
1124     data_flags_ = kArcValueFlags;  // All the arc member values are valid.
1125     offset_ = 0;  // No offset
1126 
1127   }
1128 
Init()1129   void Init() {
1130     if (flags_ & kArcNoCache) {  // If caching is disabled
1131       // 'arcs_' is a pointer to the arcs in the underlying machine.
1132       arcs_ = local_data_.arcs;
1133       // Set the arcs value flags that hold for 'arcs_'.
1134       data_flags_ = kArcWeightValue;
1135       if (!fst_.GetImpl()->EpsilonOnReplace())
1136           data_flags_ |= kArcILabelValue;
1137       // Set the offset between the underlying arc positions and
1138       // the positions in the arc iterator.
1139       offset_ = num_arcs_ - local_data_.narcs;
1140     } else {  // Otherwise, expand and cache
1141       ExpandAndCache();
1142     }
1143   }
1144 
Done()1145   bool Done() const { return pos_ >= num_arcs_; }
1146 
Value()1147   const A& Value() const {
1148     // If 'data_flags_' was set to 0, non-caching was not requested
1149     if (!data_flags_) {
1150       // TODO(allauzen): revisit this.
1151       if (flags_ & kArcNoCache) {
1152         // Should never happen.
1153         FSTERROR() << "ReplaceFst: inconsistent arc iterator flags";
1154       }
1155       ExpandAndCache();  // Expand and cache.
1156     }
1157 
1158     if (pos_ - offset_ >= 0) {  // The requested arc is not the 'final' arc.
1159       const A& arc = arcs_[pos_ - offset_];
1160       if ((data_flags_ & flags_) == (flags_ & kArcValueFlags)) {
1161         // If the value flags for 'arc' match the recquired value flags
1162         // then return 'arc'.
1163         return arc;
1164       } else {
1165         // Otherwise, compute the corresponding arc on-the-fly.
1166         fst_.GetImpl()->ComputeArc(tuple_, arc, &arc_, flags_ & kArcValueFlags);
1167         return arc_;
1168       }
1169     } else {  // The requested arc is the 'final' arc.
1170       if ((final_flags_ & flags_) != (flags_ & kArcValueFlags)) {
1171         // If the arc value flags that hold for the final arc
1172         // do not match the requested value flags, then
1173         // 'final_arc_' needs to be updated.
1174         fst_.GetImpl()->ComputeFinalArc(tuple_, &final_arc_,
1175                                     flags_ & kArcValueFlags);
1176         final_flags_ = flags_ & kArcValueFlags;
1177       }
1178       return final_arc_;
1179     }
1180   }
1181 
Next()1182   void Next() { ++pos_; }
1183 
Position()1184   size_t Position() const { return pos_; }
1185 
Reset()1186   void Reset() { pos_ = 0;  }
1187 
Seek(size_t pos)1188   void Seek(size_t pos) { pos_ = pos; }
1189 
Flags()1190   uint32 Flags() const { return flags_; }
1191 
SetFlags(uint32 f,uint32 mask)1192   void SetFlags(uint32 f, uint32 mask) {
1193     // Update the flags taking into account what flags are supported
1194     // by the Fst.
1195     flags_ &= ~mask;
1196     flags_ |= (f & fst_.GetImpl()->ArcIteratorFlags());
1197     // If non-caching is not requested (and caching has not already
1198     // been performed), then flush 'data_flags_' to request caching
1199     // during the next call to Value().
1200     if (!(flags_ & kArcNoCache) && data_flags_ != kArcValueFlags) {
1201       if (!fst_.GetImpl()->HasArcs(state_))
1202          data_flags_ = 0;
1203     }
1204     // If 'data_flags_' has been flushed but non-caching is requested
1205     // before calling Value(), then set up the iterator for non-caching.
1206     if ((f & kArcNoCache) && (!data_flags_))
1207       Init();
1208   }
1209 
1210  private:
1211   const ReplaceFst<A, T> &fst_;           // Reference to the FST
1212   StateId state_;                         // State in the FST
1213   mutable typename T::StateTuple tuple_;  // Tuple corresponding to state_
1214 
1215   ssize_t pos_;             // Current position
1216   mutable ssize_t offset_;  // Offset between position in iterator and in arcs_
1217   ssize_t num_arcs_;        // Number of arcs at state_
1218   uint32 flags_;            // Behavorial flags for the arc iterator
1219   mutable Arc arc_;         // Memory to temporarily store computed arcs
1220 
1221   mutable ArcIteratorData<Arc> cache_data_;  // Arc iterator data in cache
1222   mutable ArcIteratorData<Arc> local_data_;  // Arc iterator data in local fst
1223 
1224   mutable const A* arcs_;       // Array of arcs
1225   mutable uint32 data_flags_;   // Arc value flags valid for data in arcs_
1226   mutable Arc final_arc_;       // Final arc (when required)
1227   mutable uint32 final_flags_;  // Arc value flags valid for final_arc_
1228 
1229   DISALLOW_COPY_AND_ASSIGN(ArcIterator);
1230 };
1231 
1232 
1233 template <class A, class T>
1234 class ReplaceFstMatcher : public MatcherBase<A> {
1235  public:
1236   typedef A Arc;
1237   typedef typename A::StateId StateId;
1238   typedef typename A::Label Label;
1239   typedef MultiEpsMatcher<Matcher<Fst<A> > > LocalMatcher;
1240 
ReplaceFstMatcher(const ReplaceFst<A,T> & fst,fst::MatchType match_type)1241   ReplaceFstMatcher(const ReplaceFst<A, T> &fst, fst::MatchType match_type)
1242       : fst_(fst),
1243         impl_(fst_.GetImpl()),
1244         s_(fst::kNoStateId),
1245         match_type_(match_type),
1246         current_loop_(false),
1247         final_arc_(false),
1248         loop_(fst::kNoLabel, 0, A::Weight::One(), fst::kNoStateId) {
1249     if (match_type_ == fst::MATCH_OUTPUT)
1250       swap(loop_.ilabel, loop_.olabel);
1251     InitMatchers();
1252   }
1253 
1254   ReplaceFstMatcher(const ReplaceFstMatcher<A, T> &matcher, bool safe = false)
1255       : fst_(matcher.fst_),
1256         impl_(fst_.GetImpl()),
1257         s_(fst::kNoStateId),
1258         match_type_(matcher.match_type_),
1259         current_loop_(false),
1260         loop_(fst::kNoLabel, 0, A::Weight::One(), fst::kNoStateId) {
1261     if (match_type_ == fst::MATCH_OUTPUT)
1262       swap(loop_.ilabel, loop_.olabel);
1263     InitMatchers();
1264   }
1265 
1266   // Create a local matcher for each component Fst of replace.
1267   // LocalMatcher is a multi epsilon wrapper matcher. MultiEpsilonMatcher
1268   // is used to match each non-terminal arc, since these non-terminal
1269   // turn into epsilons on recursion.
InitMatchers()1270   void InitMatchers() {
1271     const vector<const Fst<A>*>& fst_array = impl_->fst_array_;
1272     matcher_.resize(fst_array.size(), 0);
1273     for (size_t i = 0; i < fst_array.size(); ++i) {
1274       if (fst_array[i]) {
1275         matcher_[i] =
1276             new LocalMatcher(*fst_array[i], match_type_, kMultiEpsList);
1277 
1278         typename set<Label>::iterator it = impl_->nonterminal_set_.begin();
1279         for (; it != impl_->nonterminal_set_.end(); ++it) {
1280           matcher_[i]->AddMultiEpsLabel(*it);
1281         }
1282       }
1283     }
1284   }
1285 
1286   virtual ReplaceFstMatcher<A, T> *Copy(bool safe = false) const {
1287     return new ReplaceFstMatcher<A, T>(*this, safe);
1288   }
1289 
~ReplaceFstMatcher()1290   virtual ~ReplaceFstMatcher() {
1291     for (size_t i = 0; i < matcher_.size(); ++i)
1292       delete matcher_[i];
1293   }
1294 
Type(bool test)1295   virtual MatchType Type(bool test) const {
1296     if (match_type_ == MATCH_NONE)
1297       return match_type_;
1298 
1299     uint64 true_prop =  match_type_ == MATCH_INPUT ?
1300         kILabelSorted : kOLabelSorted;
1301     uint64 false_prop = match_type_ == MATCH_INPUT ?
1302         kNotILabelSorted : kNotOLabelSorted;
1303     uint64 props = fst_.Properties(true_prop | false_prop, test);
1304 
1305     if (props & true_prop)
1306       return match_type_;
1307     else if (props & false_prop)
1308       return MATCH_NONE;
1309     else
1310       return MATCH_UNKNOWN;
1311   }
1312 
GetFst()1313   virtual const Fst<A> &GetFst() const {
1314     return fst_;
1315   }
1316 
Properties(uint64 props)1317   virtual uint64 Properties(uint64 props) const {
1318     return props;
1319   }
1320 
1321  private:
1322   // Set the sate from which our matching happens.
SetState_(StateId s)1323   virtual void SetState_(StateId s) {
1324     if (s_ == s) return;
1325 
1326     s_ = s;
1327     tuple_ = impl_->GetStateTable()->Tuple(s_);
1328     if (tuple_.fst_state == kNoStateId) {
1329       done_ = true;
1330       return;
1331     }
1332     // Get current matcher. Used for non epsilon matching
1333     current_matcher_ = matcher_[tuple_.fst_id];
1334     current_matcher_->SetState(tuple_.fst_state);
1335     loop_.nextstate = s_;
1336 
1337     final_arc_ = false;
1338   }
1339 
1340   // Search for label, from previous set state. If label == 0, first
1341   // hallucinate and epsilon loop, else use the underlying matcher to
1342   // search for the label or epsilons.
1343   // - Note since the ReplaceFST recursion on non-terminal arcs causes
1344   //   epsilon transitions to be created we use the MultiEpsilonMatcher
1345   //   to search for possible matches of non terminals.
1346   // - If the component Fst reaches a final state we also need to add
1347   //   the exiting final arc.
Find_(Label label)1348   virtual bool Find_(Label label) {
1349     bool found = false;
1350     label_ = label;
1351     if (label_ == 0 || label_ == kNoLabel) {
1352       // Compute loop directly, saving Replace::ComputeArc
1353       if (label_ == 0) {
1354         current_loop_ = true;
1355         found = true;
1356       }
1357       // Search for matching multi epsilons
1358       final_arc_ = impl_->ComputeFinalArc(tuple_, 0);
1359       found = current_matcher_->Find(kNoLabel) || final_arc_ || found;
1360     } else {
1361       // Search on sub machine directly using sub machine matcher.
1362       found = current_matcher_->Find(label_);
1363     }
1364     return found;
1365   }
1366 
Done_()1367   virtual bool Done_() const {
1368     return !current_loop_ && !final_arc_ && current_matcher_->Done();
1369   }
1370 
Value_()1371   virtual const Arc& Value_() const {
1372     if (current_loop_) {
1373       return loop_;
1374     }
1375     if (final_arc_) {
1376       impl_->ComputeFinalArc(tuple_, &arc_);
1377       return arc_;
1378     }
1379     const Arc& component_arc = current_matcher_->Value();
1380     impl_->ComputeArc(tuple_, component_arc, &arc_);
1381     return arc_;
1382   }
1383 
Next_()1384   virtual void Next_() {
1385     if (current_loop_) {
1386       current_loop_ = false;
1387       return;
1388     }
1389     if (final_arc_) {
1390       final_arc_ = false;
1391       return;
1392     }
1393     current_matcher_->Next();
1394   }
1395 
1396   const ReplaceFst<A, T>& fst_;
1397   ReplaceFstImpl<A, T> *impl_;
1398   LocalMatcher* current_matcher_;
1399   vector<LocalMatcher*> matcher_;
1400 
1401   StateId s_;                        // Current state
1402   Label label_;                      // Current label
1403 
1404   MatchType match_type_;             // Supplied by caller
1405   mutable bool done_;
1406   mutable bool current_loop_;        // Current arc is the implicit loop
1407   mutable bool final_arc_;           // Current arc for exiting recursion
1408   mutable typename T::StateTuple tuple_;  // Tuple corresponding to state_
1409   mutable Arc arc_;
1410   Arc loop_;
1411 };
1412 
1413 template <class A, class T> inline
InitStateIterator(StateIteratorData<A> * data)1414 void ReplaceFst<A, T>::InitStateIterator(StateIteratorData<A> *data) const {
1415   data->base = new StateIterator< ReplaceFst<A, T> >(*this);
1416 }
1417 
1418 typedef ReplaceFst<StdArc> StdReplaceFst;
1419 
1420 
1421 // // Recursivively replaces arcs in the root Fst with other Fsts.
1422 // This version writes the result of replacement to an output MutableFst.
1423 //
1424 // Replace supports replacement of arcs in one Fst with another
1425 // Fst. This replacement is recursive.  Replace takes an array of
1426 // Fst(s). One Fst represents the root (or topology) machine. The root
1427 // Fst refers to other Fsts by recursively replacing arcs labeled as
1428 // non-terminals with the matching non-terminal Fst. Currently Replace
1429 // uses the output symbols of the arcs to determine whether the arc is
1430 // a non-terminal arc or not. A non-terminal can be any label that is
1431 // not a non-zero terminal label in the output alphabet.  Note that
1432 // input argument is a vector of pair<>. These correspond to the tuple
1433 // of non-terminal Label and corresponding Fst.
1434 template<class Arc>
Replace(const vector<pair<typename Arc::Label,const Fst<Arc> * >> & ifst_array,MutableFst<Arc> * ofst,typename Arc::Label root,bool epsilon_on_replace)1435 void Replace(const vector<pair<typename Arc::Label,
1436              const Fst<Arc>* > >& ifst_array,
1437              MutableFst<Arc> *ofst, typename Arc::Label root,
1438              bool epsilon_on_replace) {
1439   ReplaceFstOptions<Arc> opts(root, epsilon_on_replace);
1440   opts.gc_limit = 0;  // Cache only the last state for fastest copy.
1441   *ofst = ReplaceFst<Arc>(ifst_array, opts);
1442 }
1443 
1444 template<class Arc>
Replace(const vector<pair<typename Arc::Label,const Fst<Arc> * >> & ifst_array,MutableFst<Arc> * ofst,typename Arc::Label root)1445 void Replace(const vector<pair<typename Arc::Label,
1446              const Fst<Arc>* > >& ifst_array,
1447              MutableFst<Arc> *ofst, typename Arc::Label root) {
1448   Replace(ifst_array, ofst, root, false);
1449 }
1450 
1451 }  // namespace fst
1452 
1453 #endif  // FST_LIB_REPLACE_H__
1454