• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // lookahead-matcher.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Classes to add lookahead to FST matchers, useful e.g. for improving
20 // composition efficiency with certain inputs.
21 
22 #ifndef FST_LIB_LOOKAHEAD_MATCHER_H__
23 #define FST_LIB_LOOKAHEAD_MATCHER_H__
24 
25 #include <fst/add-on.h>
26 #include <fst/const-fst.h>
27 #include <fst/fst.h>
28 #include <fst/label-reachable.h>
29 #include <fst/matcher.h>
30 
31 
32 DECLARE_string(save_relabel_ipairs);
33 DECLARE_string(save_relabel_opairs);
34 
35 namespace fst {
36 
37 // LOOKAHEAD MATCHERS - these have the interface of Matchers (see
38 // matcher.h) and these additional methods:
39 //
40 // template <class F>
41 // class LookAheadMatcher {
42 //  public:
43 //   typedef F FST;
//   typedef typename F::Arc Arc;
45 //   typedef typename Arc::StateId StateId;
46 //   typedef typename Arc::Label Label;
47 //   typedef typename Arc::Weight Weight;
48 //
49 //  // Required constructors.
50 //  LookAheadMatcher(const F &fst, MatchType match_type);
//   // If safe=true, the copy is thread-safe, except that the copy keeps the
//   // original's lookahead FST. See Fst<>::Copy() for further details.
53 //  LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false);
54 //
55 //  Below are methods for looking ahead for a match to a label and
56 //  more generally, to a rational set. Each returns false if there is
57 //  definitely not a match and returns true if there possibly is a
58 //  match.
59 
60 //  // LABEL LOOKAHEAD: Can 'label' be read from the current matcher state
61 //  // after possibly following epsilon transitions?
62 //  bool LookAheadLabel(Label label) const;
63 //
64 //  // RATIONAL LOOKAHEAD: The next methods allow looking ahead for an
65 //  // arbitrary rational set of strings, specified by an FST and a state
66 //  // from which to begin the matching. If the lookahead FST is a
67 //  // transducer, this looks on the side different from the matcher
68 //  // 'match_type' (cf. composition).
69 //
//  // Are there paths P from 's' in the lookahead FST that can be read from
//  // the current matcher state?
72 //  bool LookAheadFst(const Fst<Arc>& fst, StateId s);
73 //
74 //  // Gives an estimate of the combined weight of the paths P in the
75 //  // lookahead and matcher FSTs for the last call to LookAheadFst.
76 //  // A trivial implementation returns Weight::One(). Non-trivial
77 //  // implementations are useful for weight-pushing in composition.
78 //  Weight LookAheadWeight() const;
79 //
//  // Is there a single non-epsilon arc found in the lookahead FST
//  // that begins P (after possibly following any epsilons) in the last
//  // call to LookAheadFst? If so, return true and copy it to '*arc';
//  // otherwise, return false. A trivial implementation returns false.
//  // Non-trivial implementations are useful for label-pushing in composition.
85 //  bool LookAheadPrefix(Arc *arc);
86 //
87 //  // Optionally pre-specifies the lookahead FST that will be passed
88 //  // to LookAheadFst() for possible precomputation. If copy is true,
89 //  // then 'fst' is a copy of the FST used in the previous call to
90 //  // this method (useful to avoid unnecessary updates).
91 //  void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false);
92 //
93 // };
94 
95 //
96 // LOOK-AHEAD FLAGS (see also kMatcherFlags in matcher.h):
97 //
98 // Matcher is a lookahead matcher when 'match_type' is MATCH_INPUT.
99 const uint32 kInputLookAheadMatcher =      0x00000001;
100 
101 // Matcher is a lookahead matcher when 'match_type' is MATCH_OUTPUT.
102 const uint32 kOutputLookAheadMatcher =     0x00000002;
103 
104 // A non-trivial implementation of LookAheadWeight() method defined and
105 // should be used?
106 const uint32 kLookAheadWeight =            0x00000004;
107 
108 // A non-trivial implementation of LookAheadPrefix() method defined and
109 // should be used?
110 const uint32 kLookAheadPrefix =            0x00000008;
111 
112 // Look-ahead of matcher FST non-epsilon arcs?
113 const uint32 kLookAheadNonEpsilons =       0x00000010;
114 
115 // Look-ahead of matcher FST epsilon arcs?
116 const uint32 kLookAheadEpsilons =          0x00000020;
117 
118 // Ignore epsilon paths for the lookahead prefix? Note this gives
119 // correct results in composition only with an appropriate composition
120 // filter since it depends on the filter blocking the ignored paths.
121 const uint32 kLookAheadNonEpsilonPrefix =  0x00000040;
122 
123 // For LabelLookAheadMatcher, save relabeling data to file
124 const uint32 kLookAheadKeepRelabelData =  0x00000080;
125 
126 // Flags used for lookahead matchers.
127 const uint32 kLookAheadFlags =            0x000000ff;
128 
129 // LookAhead Matcher interface, templated on the Arc definition; used
130 // for lookahead matcher specializations that are returned by the
131 // InitMatcher() Fst method.
132 template <class A>
133 class LookAheadMatcherBase : public MatcherBase<A> {
134  public:
135   typedef A Arc;
136   typedef typename A::StateId StateId;
137   typedef typename A::Label Label;
138   typedef typename A::Weight Weight;
139 
LookAheadMatcherBase()140   LookAheadMatcherBase()
141   : weight_(Weight::One()),
142     prefix_arc_(kNoLabel, kNoLabel, Weight::One(), kNoStateId) {}
143 
~LookAheadMatcherBase()144   virtual ~LookAheadMatcherBase() {}
145 
LookAheadLabel(Label label)146   bool LookAheadLabel(Label label) const { return LookAheadLabel_(label); }
147 
LookAheadFst(const Fst<Arc> & fst,StateId s)148   bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
149     return LookAheadFst_(fst, s);
150   }
151 
LookAheadWeight()152   Weight LookAheadWeight() const { return weight_; }
153 
LookAheadPrefix(Arc * arc)154   bool LookAheadPrefix(Arc *arc) const {
155     if (prefix_arc_.nextstate != kNoStateId) {
156       *arc = prefix_arc_;
157       return true;
158     } else {
159       return false;
160     }
161   }
162 
163   virtual void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) = 0;
164 
165  protected:
SetLookAheadWeight(const Weight & w)166   void SetLookAheadWeight(const Weight &w) { weight_ = w; }
167 
SetLookAheadPrefix(const Arc & arc)168   void SetLookAheadPrefix(const Arc &arc) { prefix_arc_ = arc; }
169 
ClearLookAheadPrefix()170   void ClearLookAheadPrefix() { prefix_arc_.nextstate = kNoStateId; }
171 
172  private:
173   virtual bool LookAheadLabel_(Label label) const = 0;
174   virtual bool LookAheadFst_(const Fst<Arc> &fst,
175                              StateId s) = 0;  // This must set l.a. weight and
176                                               // prefix if non-trivial.
177   Weight weight_;                             // Look-ahead weight
178   Arc prefix_arc_;                            // Look-ahead prefix arc
179 };
180 
181 
182 // Don't really lookahead, just declare future looks good regardless.
183 template <class M>
184 class TrivialLookAheadMatcher
185     : public LookAheadMatcherBase<typename M::FST::Arc> {
186  public:
187   typedef typename M::FST FST;
188   typedef typename M::Arc Arc;
189   typedef typename Arc::StateId StateId;
190   typedef typename Arc::Label Label;
191   typedef typename Arc::Weight Weight;
192 
TrivialLookAheadMatcher(const FST & fst,MatchType match_type)193   TrivialLookAheadMatcher(const FST &fst, MatchType match_type)
194       : matcher_(fst, match_type) {}
195 
196   TrivialLookAheadMatcher(const TrivialLookAheadMatcher<M> &lmatcher,
197                           bool safe = false)
198       : matcher_(lmatcher.matcher_, safe) {}
199 
200   // General matcher methods
201   TrivialLookAheadMatcher<M> *Copy(bool safe = false) const {
202     return new TrivialLookAheadMatcher<M>(*this, safe);
203   }
204 
Type(bool test)205   MatchType Type(bool test) const { return matcher_.Type(test); }
SetState(StateId s)206   void SetState(StateId s) { return matcher_.SetState(s); }
Find(Label label)207   bool Find(Label label) { return matcher_.Find(label); }
Done()208   bool Done() const { return matcher_.Done(); }
Value()209   const Arc& Value() const { return matcher_.Value(); }
Next()210   void Next() { matcher_.Next(); }
GetFst()211   virtual const FST &GetFst() const { return matcher_.GetFst(); }
Properties(uint64 props)212   uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
Flags()213   uint32 Flags() const {
214     return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher;
215   }
216 
217   // Look-ahead methods.
LookAheadLabel(Label label)218   bool LookAheadLabel(Label label) const { return true;  }
LookAheadFst(const Fst<Arc> & fst,StateId s)219   bool LookAheadFst(const Fst<Arc> &fst, StateId s) {return true; }
LookAheadWeight()220   Weight LookAheadWeight() const { return Weight::One(); }
LookAheadPrefix(Arc * arc)221   bool LookAheadPrefix(Arc *arc) const { return false; }
222   void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {}
223 
224  private:
225   // This allows base class virtual access to non-virtual derived-
226   // class members of the same name. It makes the derived class more
227   // efficient to use but unsafe to further derive.
SetState_(StateId s)228   virtual void SetState_(StateId s) { SetState(s); }
Find_(Label label)229   virtual bool Find_(Label label) { return Find(label); }
Done_()230   virtual bool Done_() const { return Done(); }
Value_()231   virtual const Arc& Value_() const { return Value(); }
Next_()232   virtual void Next_() { Next(); }
233 
LookAheadLabel_(Label l)234   bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }
235 
LookAheadFst_(const Fst<Arc> & fst,StateId s)236   bool LookAheadFst_(const Fst<Arc> &fst, StateId s) {
237     return LookAheadFst(fst, s);
238   }
239 
LookAheadWeight_()240   Weight LookAheadWeight_() const { return LookAheadWeight(); }
LookAheadPrefix_(Arc * arc)241   bool LookAheadPrefix_(Arc *arc) const { return LookAheadPrefix(arc); }
242 
243   M matcher_;
244 };
245 
246 // Look-ahead of one transition. Template argument F accepts flags to
247 // control behavior.
248 template <class M, uint32 F = kLookAheadNonEpsilons | kLookAheadEpsilons |
249           kLookAheadWeight | kLookAheadPrefix>
250 class ArcLookAheadMatcher
251     : public LookAheadMatcherBase<typename M::FST::Arc> {
252  public:
253   typedef typename M::FST FST;
254   typedef typename M::Arc Arc;
255   typedef typename Arc::StateId StateId;
256   typedef typename Arc::Label Label;
257   typedef typename Arc::Weight Weight;
258   typedef NullAddOn MatcherData;
259 
260   using LookAheadMatcherBase<Arc>::LookAheadWeight;
261   using LookAheadMatcherBase<Arc>::SetLookAheadPrefix;
262   using LookAheadMatcherBase<Arc>::SetLookAheadWeight;
263   using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix;
264 
265   ArcLookAheadMatcher(const FST &fst, MatchType match_type,
266                       MatcherData *data = 0)
matcher_(fst,match_type)267       : matcher_(fst, match_type),
268         fst_(matcher_.GetFst()),
269         lfst_(0),
270         s_(kNoStateId) {}
271 
272   ArcLookAheadMatcher(const ArcLookAheadMatcher<M, F> &lmatcher,
273                       bool safe = false)
274       : matcher_(lmatcher.matcher_, safe),
275         fst_(matcher_.GetFst()),
276         lfst_(lmatcher.lfst_),
277         s_(kNoStateId) {}
278 
279   // General matcher methods
280   ArcLookAheadMatcher<M, F> *Copy(bool safe = false) const {
281     return new ArcLookAheadMatcher<M, F>(*this, safe);
282   }
283 
Type(bool test)284   MatchType Type(bool test) const { return matcher_.Type(test); }
285 
SetState(StateId s)286   void SetState(StateId s) {
287     s_ = s;
288     matcher_.SetState(s);
289   }
290 
Find(Label label)291   bool Find(Label label) { return matcher_.Find(label); }
Done()292   bool Done() const { return matcher_.Done(); }
Value()293   const Arc& Value() const { return matcher_.Value(); }
Next()294   void Next() { matcher_.Next(); }
GetFst()295   const FST &GetFst() const { return fst_; }
Properties(uint64 props)296   uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
Flags()297   uint32 Flags() const {
298     return matcher_.Flags() | kInputLookAheadMatcher |
299         kOutputLookAheadMatcher | F;
300   }
301 
302   // Writable matcher methods
GetData()303   MatcherData *GetData() const { return 0; }
304 
305   // Look-ahead methods.
LookAheadLabel(Label label)306   bool LookAheadLabel(Label label) const { return matcher_.Find(label); }
307 
308   // Checks if there is a matching (possibly super-final) transition
309   // at (s_, s).
310   bool LookAheadFst(const Fst<Arc> &fst, StateId s);
311 
312   void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {
313     lfst_ = &fst;
314   }
315 
316  private:
317   // This allows base class virtual access to non-virtual derived-
318   // class members of the same name. It makes the derived class more
319   // efficient to use but unsafe to further derive.
SetState_(StateId s)320   virtual void SetState_(StateId s) { SetState(s); }
Find_(Label label)321   virtual bool Find_(Label label) { return Find(label); }
Done_()322   virtual bool Done_() const { return Done(); }
Value_()323   virtual const Arc& Value_() const { return Value(); }
Next_()324   virtual void Next_() { Next(); }
325 
LookAheadLabel_(Label l)326   bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }
LookAheadFst_(const Fst<Arc> & fst,StateId s)327   bool LookAheadFst_(const Fst<Arc> &fst, StateId s) {
328     return LookAheadFst(fst, s);
329   }
330 
331   mutable M matcher_;
332   const FST &fst_;         // Matcher FST
333   const Fst<Arc> *lfst_;   // Look-ahead FST
334   StateId s_;              // Matcher state
335 };
336 
337 template <class M, uint32 F>
LookAheadFst(const Fst<Arc> & fst,StateId s)338 bool ArcLookAheadMatcher<M, F>::LookAheadFst(const Fst<Arc> &fst, StateId s) {
339   if (&fst != lfst_)
340     InitLookAheadFst(fst);
341 
342   bool ret = false;
343   ssize_t nprefix = 0;
344   if (F & kLookAheadWeight)
345     SetLookAheadWeight(Weight::Zero());
346   if (F & kLookAheadPrefix)
347     ClearLookAheadPrefix();
348   if (fst_.Final(s_) != Weight::Zero() &&
349       lfst_->Final(s) != Weight::Zero()) {
350     if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
351       return true;
352     ++nprefix;
353     if (F & kLookAheadWeight)
354       SetLookAheadWeight(Plus(LookAheadWeight(),
355                               Times(fst_.Final(s_), lfst_->Final(s))));
356     ret = true;
357   }
358   if (matcher_.Find(kNoLabel)) {
359     if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
360       return true;
361     ++nprefix;
362     if (F & kLookAheadWeight)
363       for (; !matcher_.Done(); matcher_.Next())
364         SetLookAheadWeight(Plus(LookAheadWeight(), matcher_.Value().weight));
365     ret = true;
366   }
367   for (ArcIterator< Fst<Arc> > aiter(*lfst_, s);
368        !aiter.Done();
369        aiter.Next()) {
370     const Arc &arc = aiter.Value();
371     Label label = kNoLabel;
372     switch (matcher_.Type(false)) {
373       case MATCH_INPUT:
374         label = arc.olabel;
375         break;
376       case MATCH_OUTPUT:
377         label = arc.ilabel;
378         break;
379       default:
380         FSTERROR() << "ArcLookAheadMatcher::LookAheadFst: bad match type";
381         return true;
382     }
383     if (label == 0) {
384       if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
385         return true;
386       if (!(F & kLookAheadNonEpsilonPrefix))
387         ++nprefix;
388       if (F & kLookAheadWeight)
389         SetLookAheadWeight(Plus(LookAheadWeight(), arc.weight));
390       ret = true;
391     } else if (matcher_.Find(label)) {
392       if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
393         return true;
394       for (; !matcher_.Done(); matcher_.Next()) {
395         ++nprefix;
396         if (F & kLookAheadWeight)
397           SetLookAheadWeight(Plus(LookAheadWeight(),
398                                   Times(arc.weight,
399                                         matcher_.Value().weight)));
400         if ((F & kLookAheadPrefix) && nprefix == 1)
401           SetLookAheadPrefix(arc);
402       }
403       ret = true;
404     }
405   }
406   if (F & kLookAheadPrefix) {
407     if (nprefix == 1)
408       SetLookAheadWeight(Weight::One());  // Avoids double counting.
409     else
410       ClearLookAheadPrefix();
411   }
412   return ret;
413 }
414 
415 
416 // Template argument F accepts flags to control behavior.
417 // It must include precisely one of KInputLookAheadMatcher or
418 // KOutputLookAheadMatcher.
419 template <class M, uint32 F = kLookAheadEpsilons | kLookAheadWeight |
420           kLookAheadPrefix | kLookAheadNonEpsilonPrefix |
421           kLookAheadKeepRelabelData,
422           class S = DefaultAccumulator<typename M::Arc> >
423 class LabelLookAheadMatcher
424     : public LookAheadMatcherBase<typename M::FST::Arc> {
425  public:
426   typedef typename M::FST FST;
427   typedef typename M::Arc Arc;
428   typedef typename Arc::StateId StateId;
429   typedef typename Arc::Label Label;
430   typedef typename Arc::Weight Weight;
431   typedef LabelReachableData<Label> MatcherData;
432 
433   using LookAheadMatcherBase<Arc>::LookAheadWeight;
434   using LookAheadMatcherBase<Arc>::SetLookAheadPrefix;
435   using LookAheadMatcherBase<Arc>::SetLookAheadWeight;
436   using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix;
437 
438   LabelLookAheadMatcher(const FST &fst, MatchType match_type,
439                         MatcherData *data = 0, S *s = 0)
matcher_(fst,match_type)440       : matcher_(fst, match_type),
441         lfst_(0),
442         label_reachable_(0),
443         s_(kNoStateId),
444         error_(false) {
445     if (!(F & (kInputLookAheadMatcher | kOutputLookAheadMatcher))) {
446       FSTERROR() << "LabelLookaheadMatcher: bad matcher flags: " << F;
447       error_ = true;
448     }
449     bool reach_input = match_type == MATCH_INPUT;
450     if (data) {
451       if (reach_input == data->ReachInput())
452         label_reachable_ = new LabelReachable<Arc, S>(data, s);
453     } else if ((reach_input && (F & kInputLookAheadMatcher)) ||
454                (!reach_input && (F & kOutputLookAheadMatcher))) {
455       label_reachable_ = new LabelReachable<Arc, S>(
456           fst, reach_input, s, F & kLookAheadKeepRelabelData);
457     }
458   }
459 
460   LabelLookAheadMatcher(const LabelLookAheadMatcher<M, F, S> &lmatcher,
461                         bool safe = false)
462       : matcher_(lmatcher.matcher_, safe),
463         lfst_(lmatcher.lfst_),
464         label_reachable_(
465             lmatcher.label_reachable_ ?
466             new LabelReachable<Arc, S>(*lmatcher.label_reachable_) : 0),
467         s_(kNoStateId),
468         error_(lmatcher.error_) {}
469 
~LabelLookAheadMatcher()470   ~LabelLookAheadMatcher() {
471     delete label_reachable_;
472   }
473 
474   // General matcher methods
475   LabelLookAheadMatcher<M, F, S> *Copy(bool safe = false) const {
476     return new LabelLookAheadMatcher<M, F, S>(*this, safe);
477   }
478 
Type(bool test)479   MatchType Type(bool test) const { return matcher_.Type(test); }
480 
SetState(StateId s)481   void SetState(StateId s) {
482     if (s_ == s)
483       return;
484     s_ = s;
485     match_set_state_ = false;
486     reach_set_state_ = false;
487   }
488 
Find(Label label)489   bool Find(Label label) {
490     if (!match_set_state_) {
491       matcher_.SetState(s_);
492       match_set_state_ = true;
493     }
494     return matcher_.Find(label);
495   }
496 
Done()497   bool Done() const { return matcher_.Done(); }
Value()498   const Arc& Value() const { return matcher_.Value(); }
Next()499   void Next() { matcher_.Next(); }
GetFst()500   const FST &GetFst() const { return matcher_.GetFst(); }
501 
Properties(uint64 inprops)502   uint64 Properties(uint64 inprops) const {
503     uint64 outprops = matcher_.Properties(inprops);
504     if (error_ || (label_reachable_ && label_reachable_->Error()))
505       outprops |= kError;
506     return outprops;
507   }
508 
Flags()509   uint32 Flags() const {
510     if (label_reachable_ && label_reachable_->GetData()->ReachInput())
511       return matcher_.Flags() | F | kInputLookAheadMatcher;
512     else if (label_reachable_ && !label_reachable_->GetData()->ReachInput())
513       return matcher_.Flags() | F | kOutputLookAheadMatcher;
514     else
515       return matcher_.Flags();
516   }
517 
518   // Writable matcher methods
GetData()519   MatcherData *GetData() const {
520     return label_reachable_ ? label_reachable_->GetData() : 0;
521   };
522 
523   // Look-ahead methods.
LookAheadLabel(Label label)524   bool LookAheadLabel(Label label) const {
525     if (label == 0)
526       return true;
527 
528     if (label_reachable_) {
529       if (!reach_set_state_) {
530         label_reachable_->SetState(s_);
531         reach_set_state_ = true;
532       }
533       return label_reachable_->Reach(label);
534     } else {
535       return true;
536     }
537   }
538 
539   // Checks if there is a matching (possibly super-final) transition
540   // at (s_, s).
541   template <class L>
542   bool LookAheadFst(const L &fst, StateId s);
543 
544   void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {
545     lfst_ = &fst;
546     if (label_reachable_)
547       label_reachable_->ReachInit(fst, copy);
548   }
549 
550   template <class L>
551   void InitLookAheadFst(const L& fst, bool copy = false) {
552     lfst_ = static_cast<const Fst<Arc> *>(&fst);
553     if (label_reachable_)
554       label_reachable_->ReachInit(fst, copy);
555   }
556 
557  private:
558   // This allows base class virtual access to non-virtual derived-
559   // class members of the same name. It makes the derived class more
560   // efficient to use but unsafe to further derive.
SetState_(StateId s)561   virtual void SetState_(StateId s) { SetState(s); }
Find_(Label label)562   virtual bool Find_(Label label) { return Find(label); }
Done_()563   virtual bool Done_() const { return Done(); }
Value_()564   virtual const Arc& Value_() const { return Value(); }
Next_()565   virtual void Next_() { Next(); }
566 
LookAheadLabel_(Label l)567   bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }
LookAheadFst_(const Fst<Arc> & fst,StateId s)568   bool LookAheadFst_(const Fst<Arc> &fst, StateId s) {
569     return LookAheadFst(fst, s);
570   }
571 
572   mutable M matcher_;
573   const Fst<Arc> *lfst_;                     // Look-ahead FST
574   LabelReachable<Arc, S> *label_reachable_;  // Label reachability info
575   StateId s_;                                // Matcher state
576   bool match_set_state_;                     // matcher_.SetState called?
577   mutable bool reach_set_state_;             // reachable_.SetState called?
578   bool error_;
579 };
580 
581 template <class M, uint32 F, class S>
582 template <class L> inline
LookAheadFst(const L & fst,StateId s)583 bool LabelLookAheadMatcher<M, F, S>::LookAheadFst(const L &fst, StateId s) {
584   if (static_cast<const Fst<Arc> *>(&fst) != lfst_)
585     InitLookAheadFst(fst);
586 
587   SetLookAheadWeight(Weight::One());
588   ClearLookAheadPrefix();
589 
590   if (!label_reachable_)
591     return true;
592 
593   label_reachable_->SetState(s_, s);
594   reach_set_state_ = true;
595 
596   bool compute_weight = F & kLookAheadWeight;
597   bool compute_prefix = F & kLookAheadPrefix;
598 
599   bool reach_input = Type(false) == MATCH_OUTPUT;
600   ArcIterator<L> aiter(fst, s);
601   bool reach_arc = label_reachable_->Reach(&aiter, 0,
602                                            internal::NumArcs(*lfst_, s),
603                                            reach_input, compute_weight);
604   if (reach_arc) {
605     ssize_t begin = label_reachable_->ReachBegin();
606     ssize_t end = label_reachable_->ReachEnd();
607     if (compute_prefix && end - begin == 1) {
608       aiter.Seek(begin);
609       SetLookAheadPrefix(aiter.Value());
610       compute_weight = false;
611     } else if (compute_weight) {
612       SetLookAheadWeight(label_reachable_->ReachWeight());
613     }
614   }
615   Weight lfinal = internal::Final(*lfst_, s);
616   bool reach_final = lfinal != Weight::Zero() &&
617       label_reachable_->ReachFinal();
618   if (reach_final && compute_weight)
619     SetLookAheadWeight(reach_arc ?
620                        Plus(LookAheadWeight(), lfinal) : lfinal);
621 
622   return reach_arc || reach_final;
623 }
624 
625 
626 // Label-lookahead relabeling class.
627 template <class A>
628 class LabelLookAheadRelabeler {
629  public:
630   typedef typename A::Label Label;
631   typedef LabelReachableData<Label> MatcherData;
632   typedef AddOnPair<MatcherData, MatcherData> D;
633 
634   // Relabels matcher Fst - initialization function object.
635   template <typename I>
636   LabelLookAheadRelabeler(I **impl);
637 
638   // Relabels arbitrary Fst. Class L should be a label-lookahead Fst.
639   template <class L>
Relabel(MutableFst<A> * fst,const L & mfst,bool relabel_input)640   static void Relabel(MutableFst<A> *fst, const L &mfst,
641                       bool relabel_input) {
642     typename L::Impl *impl = mfst.GetImpl();
643     D *data = impl->GetAddOn();
644     LabelReachable<A> reachable(data->First() ?
645                                   data->First() : data->Second());
646     reachable.Relabel(fst, relabel_input);
647   }
648 
649   // Returns relabeling pairs (cf. relabel.h::Relabel()).
650   // Class L should be a label-lookahead Fst.
651   // If 'avoid_collisions' is true, extra pairs are added to
652   // ensure no collisions when relabeling automata that have
653   // labels unseen here.
654   template <class L>
655   static void RelabelPairs(const L &mfst, vector<pair<Label, Label> > *pairs,
656                            bool avoid_collisions = false) {
657     typename L::Impl *impl = mfst.GetImpl();
658     D *data = impl->GetAddOn();
659     LabelReachable<A> reachable(data->First() ?
660                                   data->First() : data->Second());
661     reachable.RelabelPairs(pairs, avoid_collisions);
662   }
663 };
664 
665 template <class A>
666 template <typename I> inline
LabelLookAheadRelabeler(I ** impl)667 LabelLookAheadRelabeler<A>::LabelLookAheadRelabeler(I **impl) {
668   Fst<A> &fst = (*impl)->GetFst();
669   D *data = (*impl)->GetAddOn();
670   const string name = (*impl)->Type();
671   bool is_mutable = fst.Properties(kMutable, false);
672   MutableFst<A> *mfst = 0;
673   if (is_mutable) {
674     mfst = static_cast<MutableFst<A> *>(&fst);
675   } else {
676     mfst = new VectorFst<A>(fst);
677     data->IncrRefCount();
678     delete *impl;
679   }
680   if (data->First()) {  // reach_input
681     LabelReachable<A> reachable(data->First());
682     reachable.Relabel(mfst, true);
683     if (!FLAGS_save_relabel_ipairs.empty()) {
684       vector<pair<Label, Label> > pairs;
685       reachable.RelabelPairs(&pairs, true);
686       WriteLabelPairs(FLAGS_save_relabel_ipairs, pairs);
687     }
688   } else {
689     LabelReachable<A> reachable(data->Second());
690     reachable.Relabel(mfst, false);
691     if (!FLAGS_save_relabel_opairs.empty()) {
692       vector<pair<Label, Label> > pairs;
693       reachable.RelabelPairs(&pairs, true);
694       WriteLabelPairs(FLAGS_save_relabel_opairs, pairs);
695     }
696   }
697   if (!is_mutable) {
698     *impl = new I(*mfst, name);
699     (*impl)->SetAddOn(data);
700     delete mfst;
701     data->DecrRefCount();
702   }
703 }
704 
705 
706 // Generic lookahead matcher, templated on the FST definition
707 // - a wrapper around pointer to specific one.
708 template <class F>
709 class LookAheadMatcher {
710  public:
711   typedef F FST;
712   typedef typename F::Arc Arc;
713   typedef typename Arc::StateId StateId;
714   typedef typename Arc::Label Label;
715   typedef typename Arc::Weight Weight;
716   typedef LookAheadMatcherBase<Arc> LBase;
717 
LookAheadMatcher(const F & fst,MatchType match_type)718   LookAheadMatcher(const F &fst, MatchType match_type) {
719     base_ = fst.InitMatcher(match_type);
720     if (!base_)
721       base_ = new SortedMatcher<F>(fst, match_type);
722     lookahead_ = false;
723   }
724 
725   LookAheadMatcher(const LookAheadMatcher<F> &matcher, bool safe = false) {
726     base_ = matcher.base_->Copy(safe);
727     lookahead_ = matcher.lookahead_;
728   }
729 
~LookAheadMatcher()730   ~LookAheadMatcher() { delete base_; }
731 
732   // General matcher methods
733   LookAheadMatcher<F> *Copy(bool safe = false) const {
734       return new LookAheadMatcher<F>(*this, safe);
735   }
736 
Type(bool test)737   MatchType Type(bool test) const { return base_->Type(test); }
SetState(StateId s)738   void SetState(StateId s) { base_->SetState(s); }
Find(Label label)739   bool Find(Label label) { return base_->Find(label); }
Done()740   bool Done() const { return base_->Done(); }
Value()741   const Arc& Value() const { return base_->Value(); }
Next()742   void Next() { base_->Next(); }
GetFst()743   const F &GetFst() const { return static_cast<const F &>(base_->GetFst()); }
744 
Properties(uint64 props)745   uint64 Properties(uint64 props) const { return base_->Properties(props); }
746 
Flags()747   uint32 Flags() const { return base_->Flags(); }
748 
749   // Look-ahead methods
LookAheadLabel(Label label)750   bool LookAheadLabel(Label label) const {
751     if (LookAheadCheck()) {
752       LBase *lbase = static_cast<LBase *>(base_);
753       return lbase->LookAheadLabel(label);
754     } else {
755       return true;
756     }
757   }
758 
LookAheadFst(const Fst<Arc> & fst,StateId s)759   bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
760     if (LookAheadCheck()) {
761       LBase *lbase = static_cast<LBase *>(base_);
762       return lbase->LookAheadFst(fst, s);
763     } else {
764       return true;
765     }
766   }
767 
LookAheadWeight()768   Weight LookAheadWeight() const {
769     if (LookAheadCheck()) {
770       LBase *lbase = static_cast<LBase *>(base_);
771       return lbase->LookAheadWeight();
772     } else {
773       return Weight::One();
774     }
775   }
776 
LookAheadPrefix(Arc * arc)777   bool LookAheadPrefix(Arc *arc) const {
778     if (LookAheadCheck()) {
779       LBase *lbase = static_cast<LBase *>(base_);
780       return lbase->LookAheadPrefix(arc);
781     } else {
782       return false;
783     }
784   }
785 
786   void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {
787     if (LookAheadCheck()) {
788       LBase *lbase = static_cast<LBase *>(base_);
789       lbase->InitLookAheadFst(fst, copy);
790     }
791   }
792 
793  private:
LookAheadCheck()794   bool LookAheadCheck() const {
795     if (!lookahead_) {
796       lookahead_ = base_->Flags() &
797           (kInputLookAheadMatcher | kOutputLookAheadMatcher);
798       if (!lookahead_) {
799         FSTERROR() << "LookAheadMatcher: No look-ahead matcher defined";
800       }
801     }
802     return lookahead_;
803   }
804 
805   MatcherBase<Arc> *base_;
806   mutable bool lookahead_;
807 
808   void operator=(const LookAheadMatcher<Arc> &);  // disallow
809 };
810 
811 }  // namespace fst
812 
813 #endif  // FST_LIB_LOOKAHEAD_MATCHER_H__
814