• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // compose.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 //
16 // \file
17 // Class to compute the composition of two FSTs
18 
19 #ifndef FST_LIB_COMPOSE_H__
20 #define FST_LIB_COMPOSE_H__
21 
22 #include <algorithm>
23 
24 #include <ext/hash_map>
25 using __gnu_cxx::hash_map;
26 
27 #include "fst/lib/cache.h"
28 #include "fst/lib/test-properties.h"
29 
30 namespace fst {
31 
// Bit flags (uint64-valued) describing the user-selectable variants
// of FST composition; they are passed as the template argument of
// ComposeFstOptions<T>. Each bit enables an extension of generic
// composition, and ComposeFstOptions<> (no bit set) selects "plain"
// composition without any extension.
enum ComposeTypes {
  // ---------------------------------------------------------------
  // Special-symbol bits.
  //
  // RHO: a special "rest" symbol. NB: at most one of the bits
  // COMPOSE_FST1_RHO, COMPOSE_FST2_RHO may be set.

  // "Rest" symbol on the output side of fst1.
  COMPOSE_FST1_RHO = 1ULL << 0,
  // "Rest" symbol on the input side of fst2.
  COMPOSE_FST2_RHO = 1ULL << 1,
  // "Failure" symbol on the output side of fst1.
  COMPOSE_FST1_PHI = 1ULL << 2,
  // "Failure" symbol on the input side of fst2.
  COMPOSE_FST2_PHI = 1ULL << 3,
  // "Any" symbol on the output side of fst1.
  COMPOSE_FST1_SIGMA = 1ULL << 4,
  // "Any" symbol on the input side of fst2.
  COMPOSE_FST2_SIGMA = 1ULL << 5,

  // ---------------------------------------------------------------
  // Optimization related bits.

  // Disables optimizations and applies the generic version of the
  // composition algorithm. Used for internal testing only.
  COMPOSE_GENERIC = 1ULL << 32,

  // ---------------------------------------------------------------
  // Auxiliary values naming specific combinations of the bits above.
  // Internal use only.
  COMPOSE_RHO = COMPOSE_FST1_RHO | COMPOSE_FST2_RHO,
  COMPOSE_PHI = COMPOSE_FST1_PHI | COMPOSE_FST2_PHI,
  COMPOSE_SIGMA = COMPOSE_FST1_SIGMA | COMPOSE_FST2_SIGMA,
  COMPOSE_SPECIAL_SYMBOLS = COMPOSE_RHO | COMPOSE_PHI | COMPOSE_SIGMA,

  // ---------------------------------------------------------------
  // Bits denoting specific optimizations. They are typically set
  // *internally* by the composition algorithm.

  // fst1 is a string.
  COMPOSE_FST1_STRING = 1ULL << 33,
  // fst2 is a string.
  COMPOSE_FST2_STRING = 1ULL << 34,
  // fst1 is deterministic.
  COMPOSE_FST1_DET = 1ULL << 35,
  // fst2 is deterministic.
  COMPOSE_FST2_DET = 1ULL << 36,

  // Selects the upper 32 bits, i.e. the optimization-related flags.
  COMPOSE_INTERNAL_MASK = 0xffffffff00000000ULL
};
75 
76 
77 template <uint64 T = 0ULL>
78 struct ComposeFstOptions : public CacheOptions {
ComposeFstOptionsComposeFstOptions79   explicit ComposeFstOptions(const CacheOptions &opts) : CacheOptions(opts) {}
ComposeFstOptionsComposeFstOptions80   ComposeFstOptions() { }
81 };
82 
83 
84 // Abstract base for the implementation of delayed ComposeFst. The
85 // concrete specializations are templated on the (uint64-valued)
86 // properties of the FSTs being composed.
87 template <class A>
88 class ComposeFstImplBase : public CacheImpl<A> {
89  public:
90   using FstImpl<A>::SetType;
91   using FstImpl<A>::SetProperties;
92   using FstImpl<A>::Properties;
93   using FstImpl<A>::SetInputSymbols;
94   using FstImpl<A>::SetOutputSymbols;
95 
96   using CacheBaseImpl< CacheState<A> >::HasStart;
97   using CacheBaseImpl< CacheState<A> >::HasFinal;
98   using CacheBaseImpl< CacheState<A> >::HasArcs;
99 
100   typedef typename A::Label Label;
101   typedef typename A::Weight Weight;
102   typedef typename A::StateId StateId;
103   typedef CacheState<A> State;
104 
ComposeFstImplBase(const Fst<A> & fst1,const Fst<A> & fst2,const CacheOptions & opts)105   ComposeFstImplBase(const Fst<A> &fst1,
106                      const Fst<A> &fst2,
107                      const CacheOptions &opts)
108       :CacheImpl<A>(opts), fst1_(fst1.Copy()), fst2_(fst2.Copy()) {
109     SetType("compose");
110     uint64 props1 = fst1.Properties(kFstProperties, false);
111     uint64 props2 = fst2.Properties(kFstProperties, false);
112     SetProperties(ComposeProperties(props1, props2), kCopyProperties);
113 
114     if (!CompatSymbols(fst2.InputSymbols(), fst1.OutputSymbols()))
115       LOG(FATAL) << "ComposeFst: output symbol table of 1st argument "
116                  << "does not match input symbol table of 2nd argument";
117 
118     SetInputSymbols(fst1.InputSymbols());
119     SetOutputSymbols(fst2.OutputSymbols());
120   }
121 
~ComposeFstImplBase()122   virtual ~ComposeFstImplBase() {
123     delete fst1_;
124     delete fst2_;
125   }
126 
Start()127   StateId Start() {
128     if (!HasStart()) {
129       StateId start = ComputeStart();
130       if (start != kNoStateId) {
131         this->SetStart(start);
132       }
133     }
134     return CacheImpl<A>::Start();
135   }
136 
Final(StateId s)137   Weight Final(StateId s) {
138     if (!HasFinal(s)) {
139       Weight final = ComputeFinal(s);
140       this->SetFinal(s, final);
141     }
142     return CacheImpl<A>::Final(s);
143   }
144 
145   virtual void Expand(StateId s) = 0;
146 
NumArcs(StateId s)147   size_t NumArcs(StateId s) {
148     if (!HasArcs(s))
149       Expand(s);
150     return CacheImpl<A>::NumArcs(s);
151   }
152 
NumInputEpsilons(StateId s)153   size_t NumInputEpsilons(StateId s) {
154     if (!HasArcs(s))
155       Expand(s);
156     return CacheImpl<A>::NumInputEpsilons(s);
157   }
158 
NumOutputEpsilons(StateId s)159   size_t NumOutputEpsilons(StateId s) {
160     if (!HasArcs(s))
161       Expand(s);
162     return CacheImpl<A>::NumOutputEpsilons(s);
163   }
164 
InitArcIterator(StateId s,ArcIteratorData<A> * data)165   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
166     if (!HasArcs(s))
167       Expand(s);
168     CacheImpl<A>::InitArcIterator(s, data);
169   }
170 
171   // Access to flags encoding compose options/optimizations etc.  (for
172   // debugging).
173   virtual uint64 ComposeFlags() const = 0;
174 
175  protected:
176   virtual StateId ComputeStart() = 0;
177   virtual Weight ComputeFinal(StateId s) = 0;
178 
179   const Fst<A> *fst1_;            // first input Fst
180   const Fst<A> *fst2_;            // second input Fst
181 };
182 
183 
184 // The following class encapsulates implementation-dependent details
185 // of state tuple lookup, i.e. a bijective mapping from triples of two
186 // FST states and an epsilon filter state to the corresponding state
187 // IDs of the fst resulting from composition. The mapping must
188 // implement the [] operator in the style of STL associative
189 // containers (map, hash_map), i.e. table[x] must return a reference
190 // to the value associated with x. If x is an unassigned tuple, the
191 // operator must automatically associate x with value 0.
192 //
193 // NB: "table[x] == 0" for unassigned tuples x is required by the
194 // following off-by-one device used in the implementation of
195 // ComposeFstImpl. The value stored in the table is equal to tuple ID
196 // plus one, i.e. it is always a strictly positive number. Therefore,
197 // table[x] is equal to 0 if and only if x is an unassigned tuple (in
198 // which the algorithm assigns a new ID to x, and sets table[x] -
199 // stored in a reference - to "new ID + 1"). This form of lookup is
200 // more efficient than calling "find(x)" and "insert(make_pair(x, new
201 // ID))" if x is an unassigned tuple.
202 //
203 // The generic implementation is a wrapper around a hash_map.
204 template <class A, uint64 T>
205 class ComposeStateTable {
206  public:
207   typedef typename A::StateId StateId;
208 
209   struct StateTuple {
StateTupleStateTuple210     StateTuple() {}
StateTupleStateTuple211     StateTuple(StateId s1, StateId s2, int f)
212         : state_id1(s1), state_id2(s2), filt(f) {}
213     StateId state_id1;  // state Id on fst1
214     StateId state_id2;  // state Id on fst2
215     int filt;           // epsilon filter state
216   };
217 
ComposeStateTable()218   ComposeStateTable() {
219     StateTuple empty_tuple(kNoStateId, kNoStateId, 0);
220   }
221 
222   // NB: if 'tuple' is not in 'table_', the pair (tuple, StateId()) is
223   // inserted into 'table_' (standard STL container semantics). Since
224   // StateId is a built-in type, the explicit default constructor call
225   // StateId() returns 0.
226   StateId &operator[](const StateTuple &tuple) {
227     return table_[tuple];
228   }
229 
230  private:
231   // Comparison object for hashing StateTuple(s).
232   class StateTupleEqual {
233    public:
operator()234     bool operator()(const StateTuple& x, const StateTuple& y) const {
235       return x.state_id1 == y.state_id1 &&
236              x.state_id2 == y.state_id2 &&
237              x.filt == y.filt;
238     }
239   };
240 
241   static const int kPrime0 = 7853;
242   static const int kPrime1 = 7867;
243 
244   // Hash function for StateTuple to Fst states.
245   class StateTupleKey {
246    public:
operator()247     size_t operator()(const StateTuple& x) const {
248       return static_cast<size_t>(x.state_id1 +
249                                  x.state_id2 * kPrime0 +
250                                  x.filt * kPrime1);
251     }
252   };
253 
254   // Lookup table mapping state tuples to state IDs.
255   typedef hash_map<StateTuple,
256                          StateId,
257                          StateTupleKey,
258                          StateTupleEqual> StateTable;
259  // Actual table data.
260   StateTable table_;
261 
262   DISALLOW_EVIL_CONSTRUCTORS(ComposeStateTable);
263 };
264 
265 
266 // State tuple lookup table for the composition of a string FST with a
267 // deterministic FST.  The class maps state tuples to their unique IDs
268 // (i.e. states of the ComposeFst). Main optimization: due to the
269 // 1-to-1 correspondence between the states of the input string FST
270 // and those of the resulting (string) FST, a state tuple (s1, s2) is
271 // simply mapped to StateId s1. Hence, we use an STL vector as a
272 // lookup table. Template argument Fst1IsString specifies which FST is
273 // a string (this determines whether or not we index the lookup table
274 // by the first or by the second state).
275 template <class A, bool Fst1IsString>
276 class StringDetComposeStateTable {
277  public:
278   typedef typename A::StateId StateId;
279 
280   struct StateTuple {
281     typedef typename A::StateId StateId;
StateTupleStateTuple282     StateTuple() {}
StateTupleStateTuple283     StateTuple(StateId s1, StateId s2, int /* f */)
284         : state_id1(s1), state_id2(s2) {}
285     StateId state_id1;  // state Id on fst1
286     StateId state_id2;  // state Id on fst2
287     static const int filt = 0;  // 'fake' epsilon filter - only needed
288                                 // for API compatibility
289   };
290 
StringDetComposeStateTable()291   StringDetComposeStateTable() {}
292 
293   // Subscript operator. Behaves in a way similar to its map/hash_map
294   // counterpart, i.e. returns a reference to the value associated
295   // with 'tuple', inserting a 0 value if 'tuple' is unassigned.
296   StateId &operator[](const StateTuple &tuple) {
297     StateId index = Fst1IsString ? tuple.state_id1 : tuple.state_id2;
298     if (index >= (StateId)data_.size()) {
299       // NB: all values in [old_size; index] are initialized to 0.
300       data_.resize(index + 1);
301     }
302     return data_[index];
303   }
304 
305  private:
306   vector<StateId> data_;
307 
308   DISALLOW_EVIL_CONSTRUCTORS(StringDetComposeStateTable);
309 };
310 
311 
312 // Specializations of ComposeStateTable for the string/det case.
313 // Both inherit from StringDetComposeStateTable.
314 template <class A>
315 class ComposeStateTable<A, COMPOSE_FST1_STRING | COMPOSE_FST2_DET>
316     : public StringDetComposeStateTable<A, true> { };
317 
318 template <class A>
319 class ComposeStateTable<A, COMPOSE_FST2_STRING | COMPOSE_FST1_DET>
320     : public StringDetComposeStateTable<A, false> { };
321 
322 
323 // Parameterized implementation of FST composition for a pair of FSTs
324 // matching the property bit vector T. If possible,
325 // instantiation-specific switches in the code are based on the values
326 // of the bits in T, which are known at compile time, so unused code
327 // should be optimized away by the compiler.
328 template <class A, uint64 T>
329 class ComposeFstImpl : public ComposeFstImplBase<A> {
330   typedef typename A::StateId StateId;
331   typedef typename A::Label   Label;
332   typedef typename A::Weight  Weight;
333   using FstImpl<A>::SetType;
334   using FstImpl<A>::SetProperties;
335 
336   enum FindType { FIND_INPUT  = 1,          // find input label on fst2
337                   FIND_OUTPUT = 2,          // find output label on fst1
338                   FIND_BOTH   = 3 };        // find choice state dependent
339 
340   typedef ComposeStateTable<A, T & COMPOSE_INTERNAL_MASK> StateTupleTable;
341   typedef typename StateTupleTable::StateTuple StateTuple;
342 
343  public:
ComposeFstImpl(const Fst<A> & fst1,const Fst<A> & fst2,const CacheOptions & opts)344   ComposeFstImpl(const Fst<A> &fst1,
345                  const Fst<A> &fst2,
346                  const CacheOptions &opts)
347       :ComposeFstImplBase<A>(fst1, fst2, opts) {
348 
349     bool osorted = fst1.Properties(kOLabelSorted, false);
350     bool isorted = fst2.Properties(kILabelSorted, false);
351 
352     switch (T & COMPOSE_SPECIAL_SYMBOLS) {
353       case COMPOSE_FST1_RHO:
354       case COMPOSE_FST1_PHI:
355       case COMPOSE_FST1_SIGMA:
356         if (!osorted || FLAGS_fst_verify_properties)
357           osorted = fst1.Properties(kOLabelSorted, true);
358         if (!osorted)
359           LOG(FATAL) << "ComposeFst: 1st argument not output label "
360                      << "sorted (special symbols present)";
361         break;
362       case COMPOSE_FST2_RHO:
363       case COMPOSE_FST2_PHI:
364       case COMPOSE_FST2_SIGMA:
365         if (!isorted || FLAGS_fst_verify_properties)
366           isorted = fst2.Properties(kILabelSorted, true);
367         if (!isorted)
368           LOG(FATAL) << "ComposeFst: 2nd argument not input label "
369                      << "sorted (special symbols present)";
370         break;
371       case 0:
372         if (!isorted && !osorted || FLAGS_fst_verify_properties) {
373           osorted = fst1.Properties(kOLabelSorted, true);
374           if (!osorted)
375             isorted = fst2.Properties(kILabelSorted, true);
376         }
377         break;
378       default:
379         LOG(FATAL)
380           << "ComposeFst: More than one special symbol used in composition";
381     }
382 
383     if (isorted && (T & COMPOSE_FST2_SIGMA)) {
384       find_type_ = FIND_INPUT;
385     } else if (osorted && (T & COMPOSE_FST1_SIGMA)) {
386       find_type_ = FIND_OUTPUT;
387     } else if (isorted && (T & COMPOSE_FST2_PHI)) {
388       find_type_ = FIND_INPUT;
389     } else if (osorted && (T & COMPOSE_FST1_PHI)) {
390       find_type_ = FIND_OUTPUT;
391     } else if (isorted && (T & COMPOSE_FST2_RHO)) {
392       find_type_ = FIND_INPUT;
393     } else if (osorted && (T & COMPOSE_FST1_RHO)) {
394       find_type_ = FIND_OUTPUT;
395     } else if (isorted && (T & COMPOSE_FST1_STRING)) {
396       find_type_ = FIND_INPUT;
397     } else if(osorted && (T & COMPOSE_FST2_STRING)) {
398       find_type_ = FIND_OUTPUT;
399     } else if (isorted && osorted) {
400       find_type_ = FIND_BOTH;
401     } else if (isorted) {
402       find_type_ = FIND_INPUT;
403     } else if (osorted) {
404       find_type_ = FIND_OUTPUT;
405     } else {
406       LOG(FATAL) << "ComposeFst: 1st argument not output label sorted "
407                  << "and 2nd argument is not input label sorted";
408     }
409   }
410 
411   // Finds/creates an Fst state given a StateTuple.  Only creates a new
412   // state if StateTuple is not found in the state hash.
413   //
414   // The method exploits the following device: all pairs stored in the
415   // associative container state_tuple_table_ are of the form (tuple,
416   // id(tuple) + 1), i.e. state_tuple_table_[tuple] > 0 if tuple has
417   // been stored previously. For unassigned tuples, the call to
418   // state_tuple_table_[tuple] creates a new pair (tuple, 0). As a
419   // result, state_tuple_table_[tuple] == 0 iff tuple is new.
FindState(const StateTuple & tuple)420   StateId FindState(const StateTuple& tuple) {
421     StateId &assoc_value = state_tuple_table_[tuple];
422     if (assoc_value == 0) {  // tuple wasn't present in lookup table:
423                              // assign it a new ID.
424       state_tuples_.push_back(tuple);
425       assoc_value = state_tuples_.size();
426     }
427     return assoc_value - 1;  // NB: assoc_value = ID + 1
428   }
429 
430   // Generates arc for composition state s from matched input Fst arcs.
AddArc(StateId s,const A & arca,const A & arcb,int f,bool find_input)431   void AddArc(StateId s, const A &arca, const A &arcb, int f,
432               bool find_input) {
433     A arc;
434     if (find_input) {
435       arc.ilabel = arcb.ilabel;
436       arc.olabel = arca.olabel;
437       arc.weight = Times(arcb.weight, arca.weight);
438       StateTuple tuple(arcb.nextstate, arca.nextstate, f);
439       arc.nextstate = FindState(tuple);
440     } else {
441       arc.ilabel = arca.ilabel;
442       arc.olabel = arcb.olabel;
443       arc.weight = Times(arca.weight, arcb.weight);
444       StateTuple tuple(arca.nextstate, arcb.nextstate, f);
445       arc.nextstate = FindState(tuple);
446     }
447     CacheImpl<A>::AddArc(s, arc);
448   }
449 
450   // Arranges it so that the first arg to OrderedExpand is the Fst
451   // that will be passed to FindLabel.
Expand(StateId s)452   void Expand(StateId s) {
453     StateTuple &tuple = state_tuples_[s];
454     StateId s1 = tuple.state_id1;
455     StateId s2 = tuple.state_id2;
456     int f = tuple.filt;
457     if (find_type_ == FIND_INPUT)
458       OrderedExpand(s, ComposeFstImplBase<A>::fst2_, s2,
459                     ComposeFstImplBase<A>::fst1_, s1, f, true);
460     else
461       OrderedExpand(s, ComposeFstImplBase<A>::fst1_, s1,
462                     ComposeFstImplBase<A>::fst2_, s2, f, false);
463   }
464 
465   // Access to flags encoding compose options/optimizations etc.  (for
466   // debugging).
ComposeFlags()467   virtual uint64 ComposeFlags() const { return T; }
468 
469  private:
470   // This does that actual matching of labels in the composition. The
471   // arguments are ordered so FindLabel is called with state SA of
472   // FSTA for each arc leaving state SB of FSTB. The FIND_INPUT arg
473   // determines whether the input or output label of arcs at SB is
474   // the one to match on.
OrderedExpand(StateId s,const Fst<A> * fsta,StateId sa,const Fst<A> * fstb,StateId sb,int f,bool find_input)475   void OrderedExpand(StateId s, const Fst<A> *fsta, StateId sa,
476                      const Fst<A> *fstb, StateId sb, int f, bool find_input) {
477 
478     size_t numarcsa = fsta->NumArcs(sa);
479     size_t numepsa = find_input ? fsta->NumInputEpsilons(sa) :
480                      fsta->NumOutputEpsilons(sa);
481     bool finala = fsta->Final(sa) != Weight::Zero();
482     ArcIterator< Fst<A> > aitera(*fsta, sa);
483     // First handle special epsilons and sigmas on FSTA
484     for (; !aitera.Done(); aitera.Next()) {
485       const A &arca = aitera.Value();
486       Label match_labela = find_input ? arca.ilabel : arca.olabel;
487       if (match_labela > 0) {
488         break;
489       }
490       if ((T & COMPOSE_SIGMA) != 0 &&  match_labela == kSigmaLabel) {
491         // Found a sigma? Match it against all (non-special) symbols
492         // on side b.
493         for (ArcIterator< Fst<A> > aiterb(*fstb, sb);
494              !aiterb.Done();
495              aiterb.Next()) {
496           const A &arcb = aiterb.Value();
497           Label labelb = find_input ? arcb.olabel : arcb.ilabel;
498           if (labelb <= 0) continue;
499           AddArc(s, arca, arcb, 0, find_input);
500         }
501       } else if (f == 0 && match_labela == 0) {
502         A earcb(0, 0, Weight::One(), sb);
503         AddArc(s, arca, earcb, 0, find_input);  // move forward on epsilon
504       }
505     }
506     // Next handle non-epsilon matches, rho labels, and epsilons on FSTB
507     for (ArcIterator< Fst<A> > aiterb(*fstb, sb);
508          !aiterb.Done();
509          aiterb.Next()) {
510       const A &arcb = aiterb.Value();
511       Label match_labelb = find_input ? arcb.olabel : arcb.ilabel;
512       if (match_labelb) {  // Consider non-epsilon match
513         if (FindLabel(&aitera, numarcsa, match_labelb, find_input)) {
514           for (; !aitera.Done(); aitera.Next()) {
515             const A &arca = aitera.Value();
516             Label match_labela = find_input ? arca.ilabel : arca.olabel;
517             if (match_labela != match_labelb)
518               break;
519             AddArc(s, arca, arcb, 0, find_input);  // move forward on match
520           }
521         } else if ((T & COMPOSE_SPECIAL_SYMBOLS) != 0) {
522           // If there is no transition labelled 'match_labelb' in
523           // fsta, try matching 'match_labelb' against special symbols
524           // (Phi, Rho,...).
525           for (aitera.Reset(); !aitera.Done(); aitera.Next()) {
526             A arca = aitera.Value();
527             Label labela = find_input ? arca.ilabel : arca.olabel;
528             if (labela >= 0) {
529               break;
530             } else if (((T & COMPOSE_PHI) != 0) && (labela == kPhiLabel)) {
531               // Case 1: if a failure transition exists, follow its
532               // transitive closure until a) a transition labelled
533               // 'match_labelb' is found, or b) the initial state of
534               // fsta is reached.
535 
536               StateId sf = sa;  // Start of current failure transition.
537               while (labela == kPhiLabel && sf != arca.nextstate) {
538                 sf = arca.nextstate;
539 
540                 size_t numarcsf = fsta->NumArcs(sf);
541                 ArcIterator< Fst<A> > aiterf(*fsta, sf);
542                 if (FindLabel(&aiterf, numarcsf, match_labelb, find_input)) {
543                   // Sub-case 1a: there exists a transition starting
544                   // in sf and consuming symbol 'match_labelb'.
545                   AddArc(s, aiterf.Value(), arcb, 0, find_input);
546                   break;
547                 } else {
548                   // No transition labelled 'match_labelb' found: try
549                   // next failure transition (starting at 'sf').
550                   for (aiterf.Reset(); !aiterf.Done(); aiterf.Next()) {
551                     arca = aiterf.Value();
552                     labela = find_input ? arca.ilabel : arca.olabel;
553                     if (labela >= kPhiLabel) break;
554                   }
555                 }
556               }
557               if (labela == kPhiLabel && sf == arca.nextstate) {
558                 // Sub-case 1b: failure transitions lead to start
559                 // state without finding a matching
560                 // transition. Therefore, we generate a loop in start
561                 // state of fsta.
562                 A loop(match_labelb, match_labelb, Weight::One(), sf);
563                 AddArc(s, loop, arcb, 0, find_input);
564               }
565             } else if (((T & COMPOSE_RHO) != 0) && (labela == kRhoLabel)) {
566               // Case 2: 'match_labelb' can be matched against a
567               // "rest" (rho) label in fsta.
568               if (find_input) {
569                 arca.ilabel = match_labelb;
570                 if (arca.olabel == kRhoLabel)
571                   arca.olabel = match_labelb;
572               } else {
573                 arca.olabel = match_labelb;
574                 if (arca.ilabel == kRhoLabel)
575                   arca.ilabel = match_labelb;
576               }
577               AddArc(s, arca, arcb, 0, find_input);  // move fwd on match
578             }
579           }
580         }
581       } else if (numepsa != numarcsa || finala) {  // Handle FSTB epsilon
582         A earca(0, 0, Weight::One(), sa);
583         AddArc(s, earca, arcb, numepsa > 0, find_input);  // move on epsilon
584       }
585     }
586     this->SetArcs(s);
587    }
588 
589 
590   // Finds matches to MATCH_LABEL in arcs given by AITER
591   // using FIND_INPUT to determine whether to look on input or output.
FindLabel(ArcIterator<Fst<A>> * aiter,size_t numarcs,Label match_label,bool find_input)592   bool FindLabel(ArcIterator< Fst<A> > *aiter, size_t numarcs,
593                  Label match_label, bool find_input) {
594     // binary search for match
595     size_t low = 0;
596     size_t high = numarcs;
597     while (low < high) {
598       size_t mid = (low + high) / 2;
599       aiter->Seek(mid);
600       Label label = find_input ?
601                     aiter->Value().ilabel : aiter->Value().olabel;
602       if (label > match_label) {
603         high = mid;
604       } else if (label < match_label) {
605         low = mid + 1;
606       } else {
607         // find first matching label (when non-determinism)
608         for (size_t i = mid; i > low; --i) {
609           aiter->Seek(i - 1);
610           label = find_input ? aiter->Value().ilabel : aiter->Value().olabel;
611           if (label != match_label) {
612             aiter->Seek(i);
613             return true;
614           }
615         }
616         return true;
617       }
618     }
619     return false;
620   }
621 
ComputeStart()622   StateId ComputeStart() {
623     StateId s1 = ComposeFstImplBase<A>::fst1_->Start();
624     StateId s2 = ComposeFstImplBase<A>::fst2_->Start();
625     if (s1 == kNoStateId || s2 == kNoStateId)
626       return kNoStateId;
627     StateTuple tuple(s1, s2, 0);
628     return FindState(tuple);
629   }
630 
ComputeFinal(StateId s)631   Weight ComputeFinal(StateId s) {
632     StateTuple &tuple = state_tuples_[s];
633     Weight final = Times(ComposeFstImplBase<A>::fst1_->Final(tuple.state_id1),
634                          ComposeFstImplBase<A>::fst2_->Final(tuple.state_id2));
635     return final;
636   }
637 
638 
639   FindType find_type_;            // find label on which side?
640 
641   // Maps from StateId to StateTuple.
642   vector<StateTuple> state_tuples_;
643 
644   // Maps from StateTuple to StateId.
645   StateTupleTable state_tuple_table_;
646 
647   DISALLOW_EVIL_CONSTRUCTORS(ComposeFstImpl);
648 };
649 
650 
651 // Computes the composition of two transducers. This version is a
652 // delayed Fst. If FST1 transduces string x to y with weight a and FST2
653 // transduces y to z with weight b, then their composition transduces
// string x to z with weight Times(a, b).
655 //
656 // The output labels of the first transducer or the input labels of
657 // the second transducer must be sorted.  The weights need to form a
658 // commutative semiring (valid for TropicalWeight and LogWeight).
659 //
660 // Complexity:
661 // Assuming the first FST is unsorted and the second is sorted:
662 // - Time: O(v1 v2 d1 (log d2 + m2)),
663 // - Space: O(v1 v2)
664 // where vi = # of states visited, di = maximum out-degree, and mi the
665 // maximum multiplicity of the states visited for the ith
666 // FST. Constant time and space to visit an input state or arc is
667 // assumed and exclusive of caching.
668 //
669 // Caveats:
670 // - ComposeFst does not trim its output (since it is a delayed operation).
671 // - The efficiency of composition can be strongly affected by several factors:
//   - the choice of which transducer is sorted - prefer sorting the FST
673 //     that has the greater average out-degree.
674 //   - the amount of non-determinism
675 //   - the presence and location of epsilon transitions - avoid epsilon
676 //     transitions on the output side of the first transducer or
677 //     the input side of the second transducer or prefer placing
678 //     them later in a path since they delay matching and can
679 //     introduce non-coaccessible states and transitions.
680 template <class A>
681 class ComposeFst : public Fst<A> {
682  public:
683   friend class ArcIterator< ComposeFst<A> >;
684   friend class CacheStateIterator< ComposeFst<A> >;
685   friend class CacheArcIterator< ComposeFst<A> >;
686 
687   typedef A Arc;
688   typedef typename A::Weight Weight;
689   typedef typename A::StateId StateId;
690   typedef CacheState<A> State;
691 
ComposeFst(const Fst<A> & fst1,const Fst<A> & fst2)692   ComposeFst(const Fst<A> &fst1, const Fst<A> &fst2)
693       : impl_(Init(fst1, fst2, ComposeFstOptions<>())) { }
694 
695   template <uint64 T>
ComposeFst(const Fst<A> & fst1,const Fst<A> & fst2,const ComposeFstOptions<T> & opts)696   ComposeFst(const Fst<A> &fst1,
697              const Fst<A> &fst2,
698              const ComposeFstOptions<T> &opts)
699       : impl_(Init(fst1, fst2, opts)) { }
700 
ComposeFst(const ComposeFst<A> & fst)701   ComposeFst(const ComposeFst<A> &fst) : Fst<A>(fst), impl_(fst.impl_) {
702     impl_->IncrRefCount();
703   }
704 
~ComposeFst()705   virtual ~ComposeFst() { if (!impl_->DecrRefCount()) delete impl_;  }
706 
Start()707   virtual StateId Start() const { return impl_->Start(); }
708 
Final(StateId s)709   virtual Weight Final(StateId s) const { return impl_->Final(s); }
710 
NumArcs(StateId s)711   virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
712 
NumInputEpsilons(StateId s)713   virtual size_t NumInputEpsilons(StateId s) const {
714     return impl_->NumInputEpsilons(s);
715   }
716 
NumOutputEpsilons(StateId s)717   virtual size_t NumOutputEpsilons(StateId s) const {
718     return impl_->NumOutputEpsilons(s);
719   }
720 
Properties(uint64 mask,bool test)721   virtual uint64 Properties(uint64 mask, bool test) const {
722     if (test) {
723       uint64 known, test = TestProperties(*this, mask, &known);
724       impl_->SetProperties(test, known);
725       return test & mask;
726     } else {
727       return impl_->Properties(mask);
728     }
729   }
730 
Type()731   virtual const string& Type() const { return impl_->Type(); }
732 
Copy()733   virtual ComposeFst<A> *Copy() const {
734     return new ComposeFst<A>(*this);
735   }
736 
InputSymbols()737   virtual const SymbolTable* InputSymbols() const {
738     return impl_->InputSymbols();
739   }
740 
OutputSymbols()741   virtual const SymbolTable* OutputSymbols() const {
742     return impl_->OutputSymbols();
743   }
744 
745   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
746 
InitArcIterator(StateId s,ArcIteratorData<A> * data)747   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
748     impl_->InitArcIterator(s, data);
749   }
750 
751   // Access to flags encoding compose options/optimizations etc.  (for
752   // debugging).
ComposeFlags()753   uint64 ComposeFlags() const { return impl_->ComposeFlags(); }
754 
755  protected:
Impl()756   ComposeFstImplBase<A> *Impl() { return impl_; }
757 
758  private:
759   ComposeFstImplBase<A> *impl_;
760 
761   // Auxiliary method encapsulating the creation of a ComposeFst
762   // implementation that is appropriate for the properties of fst1 and
763   // fst2.
764   template <uint64 T>
Init(const Fst<A> & fst1,const Fst<A> & fst2,const ComposeFstOptions<T> & opts)765   static ComposeFstImplBase<A> *Init(
766       const Fst<A> &fst1,
767       const Fst<A> &fst2,
768       const ComposeFstOptions<T> &opts) {
769 
770     // Filter for sort properties (forces a property check).
771     uint64 sort_props_mask = kILabelSorted | kOLabelSorted;
772     // Filter for optimization-related properties (does not force a
773     // property-check).
774     uint64 opt_props_mask =
775       kString | kIDeterministic | kODeterministic | kNoIEpsilons |
776       kNoOEpsilons;
777 
778     uint64 props1 = fst1.Properties(sort_props_mask, true);
779     uint64 props2 = fst2.Properties(sort_props_mask, true);
780 
781     props1 |= fst1.Properties(opt_props_mask, false);
782     props2 |= fst2.Properties(opt_props_mask, false);
783 
784     if (!(Weight::Properties() & kCommutative)) {
785       props1 |= fst1.Properties(kUnweighted, true);
786       props2 |= fst2.Properties(kUnweighted, true);
787       if (!(props1 & kUnweighted) && !(props2 & kUnweighted))
788         LOG(FATAL) << "ComposeFst: Weight needs to be a commutative semiring: "
789                    << Weight::Type();
790     }
791 
792     // Case 1: flag COMPOSE_GENERIC disables optimizations.
793     if (T & COMPOSE_GENERIC) {
794       return new ComposeFstImpl<A, T>(fst1, fst2, opts);
795     }
796 
797     const uint64 kStringDetOptProps =
798       kIDeterministic | kILabelSorted | kNoIEpsilons;
799     const uint64 kDetStringOptProps =
800       kODeterministic | kOLabelSorted | kNoOEpsilons;
801 
802     // Case 2: fst1 is a string, fst2 is deterministic and epsilon-free.
803     if ((props1 & kString) &&
804         !(T & (COMPOSE_FST1_RHO | COMPOSE_FST1_PHI | COMPOSE_FST1_SIGMA)) &&
805         ((props2 & kStringDetOptProps) == kStringDetOptProps)) {
806       return new ComposeFstImpl<A, T | COMPOSE_FST1_STRING | COMPOSE_FST2_DET>(
807           fst1, fst2, opts);
808     }
809     // Case 3: fst1 is deterministic and epsilon-free, fst2 is string.
810     if ((props2 & kString) &&
811         !(T & (COMPOSE_FST1_RHO | COMPOSE_FST1_PHI | COMPOSE_FST1_SIGMA)) &&
812         ((props1 & kDetStringOptProps) == kDetStringOptProps)) {
813       return new ComposeFstImpl<A, T | COMPOSE_FST2_STRING | COMPOSE_FST1_DET>(
814           fst1, fst2, opts);
815     }
816 
817     // Default case: no optimizations.
818     return new ComposeFstImpl<A, T>(fst1, fst2, opts);
819   }
820 
821   void operator=(const ComposeFst<A> &fst);  // disallow
822 };
823 
824 
825 // Specialization for ComposeFst.
826 template<class A>
827 class StateIterator< ComposeFst<A> >
828     : public CacheStateIterator< ComposeFst<A> > {
829  public:
StateIterator(const ComposeFst<A> & fst)830   explicit StateIterator(const ComposeFst<A> &fst)
831       : CacheStateIterator< ComposeFst<A> >(fst) {}
832 };
833 
834 
835 // Specialization for ComposeFst.
836 template <class A>
837 class ArcIterator< ComposeFst<A> >
838     : public CacheArcIterator< ComposeFst<A> > {
839  public:
840   typedef typename A::StateId StateId;
841 
ArcIterator(const ComposeFst<A> & fst,StateId s)842   ArcIterator(const ComposeFst<A> &fst, StateId s)
843       : CacheArcIterator< ComposeFst<A> >(fst, s) {
844     if (!fst.impl_->HasArcs(s))
845       fst.impl_->Expand(s);
846   }
847 
848  private:
849   DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
850 };
851 
852 template <class A> inline
InitStateIterator(StateIteratorData<A> * data)853 void ComposeFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
854   data->base = new StateIterator< ComposeFst<A> >(*this);
855 }
856 
857 // Useful alias when using StdArc.
858 typedef ComposeFst<StdArc> StdComposeFst;
859 
860 
// Options accepted by Compose().
struct ComposeOptions {
  // When true, Compose() calls Connect() on its output.
  bool connect;

  // Default options: connect the output.
  ComposeOptions()
      : connect(true) {}

  // Not marked explicit, so a bare bool converts implicitly to
  // ComposeOptions; left that way so existing call sites that pass a
  // bool keep compiling.
  ComposeOptions(bool c)
      : connect(c) {}
};
867 
868 
869 // Computes the composition of two transducers. This version writes
870 // the composed FST into a MurableFst. If FST1 transduces string x to
871 // y with weight a and FST2 transduces y to z with weight b, then
872 // their composition transduces string x to z with weight
873 // Times(x, z).
874 //
875 // The output labels of the first transducer or the input labels of
876 // the second transducer must be sorted.  The weights need to form a
877 // commutative semiring (valid for TropicalWeight and LogWeight).
878 //
879 // Complexity:
880 // Assuming the first FST is unsorted and the second is sorted:
881 // - Time: O(V1 V2 D1 (log D2 + M2)),
882 // - Space: O(V1 V2 D1 M2)
883 // where Vi = # of states, Di = maximum out-degree, and Mi is
884 // the maximum multiplicity for the ith FST.
885 //
886 // Caveats:
887 // - Compose trims its output.
888 // - The efficiency of composition can be strongly affected by several factors:
889 //   - the choice of which tnansducer is sorted - prefer sorting the FST
890 //     that has the greater average out-degree.
891 //   - the amount of non-determinism
892 //   - the presence and location of epsilon transitions - avoid epsilon
893 //     transitions on the output side of the first transducer or
894 //     the input side of the second transducer or prefer placing
895 //     them later in a path since they delay matching and can
896 //     introduce non-coaccessible states and transitions.
897 template<class Arc>
898 void Compose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
899              MutableFst<Arc> *ofst,
900              const ComposeOptions &opts = ComposeOptions()) {
901   ComposeFstOptions<> nopts;
902   nopts.gc_limit = 0;  // Cache only the last state for fastest copy.
903   *ofst = ComposeFst<Arc>(ifst1, ifst2, nopts);
904   if (opts.connect)
905     Connect(ofst);
906 }
907 
908 }  // namespace fst
909 
910 #endif  // FST_LIB_COMPOSE_H__
911