• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // rmepsilon.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: allauzen@google.com (Cyril Allauzen)
17 //
18 // \file
19 // Functions and classes that implement epsilon-removal.
20 
21 #ifndef FST_LIB_RMEPSILON_H__
22 #define FST_LIB_RMEPSILON_H__
23 
24 #include <unordered_map>
25 using std::tr1::unordered_map;
26 using std::tr1::unordered_multimap;
27 #include <fst/slist.h>
28 #include <stack>
29 #include <string>
30 #include <utility>
31 using std::pair; using std::make_pair;
32 #include <vector>
33 using std::vector;
34 
35 #include <fst/arcfilter.h>
36 #include <fst/cache.h>
37 #include <fst/connect.h>
38 #include <fst/factor-weight.h>
39 #include <fst/invert.h>
40 #include <fst/prune.h>
41 #include <fst/queue.h>
42 #include <fst/shortest-distance.h>
43 #include <fst/topsort.h>
44 
45 
46 namespace fst {
47 
48 template <class Arc, class Queue>
49 class RmEpsilonOptions
50     : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> > {
51  public:
52   typedef typename Arc::StateId StateId;
53   typedef typename Arc::Weight Weight;
54 
55   bool connect;              // Connect output
56   Weight weight_threshold;   // Pruning weight threshold.
57   StateId state_threshold;   // Pruning state threshold.
58 
59   explicit RmEpsilonOptions(Queue *q, float d = kDelta, bool c = true,
60                             Weight w = Weight::Zero(),
61                             StateId n = kNoStateId)
62       : ShortestDistanceOptions< Arc, Queue, EpsilonArcFilter<Arc> >(
63           q, EpsilonArcFilter<Arc>(), kNoStateId, d),
64         connect(c), weight_threshold(w), state_threshold(n) {}
65  private:
66   RmEpsilonOptions();  // disallow
67 };
68 
69 // Computation state of the epsilon-removal algorithm.
70 template <class Arc, class Queue>
71 class RmEpsilonState {
72  public:
73   typedef typename Arc::Label Label;
74   typedef typename Arc::StateId StateId;
75   typedef typename Arc::Weight Weight;
76 
RmEpsilonState(const Fst<Arc> & fst,vector<Weight> * distance,const RmEpsilonOptions<Arc,Queue> & opts)77   RmEpsilonState(const Fst<Arc> &fst,
78                  vector<Weight> *distance,
79                  const RmEpsilonOptions<Arc, Queue> &opts)
80       : fst_(fst), distance_(distance), sd_state_(fst_, distance, opts, true),
81         expand_id_(0) {}
82 
83   // Compute arcs and final weight for state 's'
84   void Expand(StateId s);
85 
86   // Returns arcs of expanded state.
Arcs()87   vector<Arc> &Arcs() { return arcs_; }
88 
89   // Returns final weight of expanded state.
Final()90   const Weight &Final() const { return final_; }
91 
92   // Return true if an error has occured.
Error()93   bool Error() const { return sd_state_.Error(); }
94 
95  private:
96   static const size_t kPrime0 = 7853;
97   static const size_t kPrime1 = 7867;
98 
99   struct Element {
100     Label ilabel;
101     Label olabel;
102     StateId nextstate;
103 
ElementElement104     Element() {}
105 
ElementElement106     Element(Label i, Label o, StateId s)
107         : ilabel(i), olabel(o), nextstate(s) {}
108   };
109 
110   class ElementKey {
111    public:
operator()112     size_t operator()(const Element& e) const {
113       return static_cast<size_t>(e.nextstate);
114       return static_cast<size_t>(e.nextstate +
115                                  e.ilabel * kPrime0 +
116                                  e.olabel * kPrime1);
117     }
118 
119    private:
120   };
121 
122   class ElementEqual {
123    public:
operator()124     bool operator()(const Element &e1, const Element &e2) const {
125       return (e1.ilabel == e2.ilabel) &&  (e1.olabel == e2.olabel)
126                          && (e1.nextstate == e2.nextstate);
127     }
128   };
129 
130   typedef unordered_map<Element, pair<StateId, size_t>,
131                    ElementKey, ElementEqual> ElementMap;
132 
133   const Fst<Arc> &fst_;
134   // Distance from state being expanded in epsilon-closure.
135   vector<Weight> *distance_;
136   // Shortest distance algorithm computation state.
137   ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc> > sd_state_;
138   // Maps an element 'e' to a pair 'p' corresponding to a position
139   // in the arcs vector of the state being expanded. 'e' corresponds
140   // to the position 'p.second' in the 'arcs_' vector if 'p.first' is
141   // equal to the state being expanded.
142   ElementMap element_map_;
143   EpsilonArcFilter<Arc> eps_filter_;
144   stack<StateId> eps_queue_;      // Queue used to visit the epsilon-closure
145   vector<bool> visited_;          // '[i] = true' if state 'i' has been visited
146   slist<StateId> visited_states_; // List of visited states
147   vector<Arc> arcs_;              // Arcs of state being expanded
148   Weight final_;                  // Final weight of state being expanded
149   StateId expand_id_;             // Unique ID for each call to Expand
150 
151   DISALLOW_COPY_AND_ASSIGN(RmEpsilonState);
152 };
153 
154 template <class Arc, class Queue>
155 const size_t RmEpsilonState<Arc, Queue>::kPrime0;
156 template <class Arc, class Queue>
157 const size_t RmEpsilonState<Arc, Queue>::kPrime1;
158 
159 
160 template <class Arc, class Queue>
Expand(typename Arc::StateId source)161 void RmEpsilonState<Arc,Queue>::Expand(typename Arc::StateId source) {
162    final_ = Weight::Zero();
163    arcs_.clear();
164    sd_state_.ShortestDistance(source);
165    if (sd_state_.Error())
166      return;
167    eps_queue_.push(source);
168 
169    while (!eps_queue_.empty()) {
170      StateId state = eps_queue_.top();
171      eps_queue_.pop();
172 
173      while (visited_.size() <= state) visited_.push_back(false);
174      if (visited_[state]) continue;
175      visited_[state] = true;
176      visited_states_.push_front(state);
177 
178      for (ArcIterator< Fst<Arc> > ait(fst_, state);
179           !ait.Done();
180           ait.Next()) {
181        Arc arc = ait.Value();
182        arc.weight = Times((*distance_)[state], arc.weight);
183 
184        if (eps_filter_(arc)) {
185          while (visited_.size() <= arc.nextstate)
186            visited_.push_back(false);
187          if (!visited_[arc.nextstate])
188            eps_queue_.push(arc.nextstate);
189        } else {
190           Element element(arc.ilabel, arc.olabel, arc.nextstate);
191           typename ElementMap::iterator it = element_map_.find(element);
192           if (it == element_map_.end()) {
193             element_map_.insert(
194                 pair<Element, pair<StateId, size_t> >
195                 (element, pair<StateId, size_t>(expand_id_, arcs_.size())));
196             arcs_.push_back(arc);
197           } else {
198             if (((*it).second).first == expand_id_) {
199               Weight &w = arcs_[((*it).second).second].weight;
200               w = Plus(w, arc.weight);
201             } else {
202               ((*it).second).first = expand_id_;
203               ((*it).second).second = arcs_.size();
204               arcs_.push_back(arc);
205             }
206           }
207         }
208      }
209      final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state)));
210    }
211 
212    while (!visited_states_.empty()) {
213      visited_[visited_states_.front()] = false;
214      visited_states_.pop_front();
215    }
216    ++expand_id_;
217 }
218 
219 // Removes epsilon-transitions (when both the input and output label
220 // are an epsilon) from a transducer. The result will be an equivalent
221 // FST that has no such epsilon transitions.  This version modifies
222 // its input. It allows fine control via the options argument; see
223 // below for a simpler interface.
224 //
225 // The vector 'distance' will be used to hold the shortest distances
226 // during the epsilon-closure computation. The state queue discipline
227 // and convergence delta are taken in the options argument.
228 template <class Arc, class Queue>
RmEpsilon(MutableFst<Arc> * fst,vector<typename Arc::Weight> * distance,const RmEpsilonOptions<Arc,Queue> & opts)229 void RmEpsilon(MutableFst<Arc> *fst,
230                vector<typename Arc::Weight> *distance,
231                const RmEpsilonOptions<Arc, Queue> &opts) {
232   typedef typename Arc::StateId StateId;
233   typedef typename Arc::Weight Weight;
234   typedef typename Arc::Label Label;
235 
236   if (fst->Start() == kNoStateId) {
237     return;
238   }
239 
240   // 'noneps_in[s]' will be set to true iff 's' admits a non-epsilon
241   // incoming transition or is the start state.
242   vector<bool> noneps_in(fst->NumStates(), false);
243   noneps_in[fst->Start()] = true;
244   for (StateId i = 0; i < fst->NumStates(); ++i) {
245     for (ArcIterator<Fst<Arc> > aiter(*fst, i);
246          !aiter.Done();
247          aiter.Next()) {
248       if (aiter.Value().ilabel != 0 || aiter.Value().olabel != 0)
249         noneps_in[aiter.Value().nextstate] = true;
250     }
251   }
252 
253   // States sorted in topological order when (acyclic) or generic
254   // topological order (cyclic).
255   vector<StateId> states;
256   states.reserve(fst->NumStates());
257 
258   if (fst->Properties(kTopSorted, false) & kTopSorted) {
259     for (StateId i = 0; i < fst->NumStates(); i++)
260       states.push_back(i);
261   } else if (fst->Properties(kAcyclic, false) & kAcyclic) {
262     vector<StateId> order;
263     bool acyclic;
264     TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
265     DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>());
266     // Sanity check: should be acyclic if property bit is set.
267     if(!acyclic) {
268       FSTERROR() << "RmEpsilon: inconsistent acyclic property bit";
269       fst->SetProperties(kError, kError);
270       return;
271     }
272     states.resize(order.size());
273     for (StateId i = 0; i < order.size(); i++)
274       states[order[i]] = i;
275   } else {
276      uint64 props;
277      vector<StateId> scc;
278      SccVisitor<Arc> scc_visitor(&scc, 0, 0, &props);
279      DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>());
280      vector<StateId> first(scc.size(), kNoStateId);
281      vector<StateId> next(scc.size(), kNoStateId);
282      for (StateId i = 0; i < scc.size(); i++) {
283        if (first[scc[i]] != kNoStateId)
284          next[i] = first[scc[i]];
285        first[scc[i]] = i;
286      }
287      for (StateId i = 0; i < first.size(); i++)
288        for (StateId j = first[i]; j != kNoStateId; j = next[j])
289          states.push_back(j);
290   }
291 
292   RmEpsilonState<Arc, Queue>
293     rmeps_state(*fst, distance, opts);
294 
295   while (!states.empty()) {
296     StateId state = states.back();
297     states.pop_back();
298     if (!noneps_in[state])
299       continue;
300     rmeps_state.Expand(state);
301     fst->SetFinal(state, rmeps_state.Final());
302     fst->DeleteArcs(state);
303     vector<Arc> &arcs = rmeps_state.Arcs();
304     fst->ReserveArcs(state, arcs.size());
305     while (!arcs.empty()) {
306       fst->AddArc(state, arcs.back());
307       arcs.pop_back();
308     }
309   }
310 
311   for (StateId s = 0; s < fst->NumStates(); ++s) {
312     if (!noneps_in[s])
313       fst->DeleteArcs(s);
314   }
315 
316   if(rmeps_state.Error())
317     fst->SetProperties(kError, kError);
318   fst->SetProperties(
319       RmEpsilonProperties(fst->Properties(kFstProperties, false)),
320       kFstProperties);
321 
322   if (opts.weight_threshold != Weight::Zero() ||
323       opts.state_threshold != kNoStateId)
324     Prune(fst, opts.weight_threshold, opts.state_threshold);
325   if (opts.connect && (opts.weight_threshold == Weight::Zero() ||
326                        opts.state_threshold != kNoStateId))
327     Connect(fst);
328 }
329 
330 // Removes epsilon-transitions (when both the input and output label
331 // are an epsilon) from a transducer. The result will be an equivalent
332 // FST that has no such epsilon transitions. This version modifies its
333 // input. It has a simplified interface; see above for a version that
334 // allows finer control.
335 //
336 // Complexity:
337 // - Time:
338 //   - Unweighted: O(V2 + V E)
339 //   - Acyclic: O(V2 + V E)
340 //   - Tropical semiring: O(V2 log V + V E)
341 //   - General: exponential
342 // - Space: O(V E)
343 // where V = # of states visited, E = # of arcs.
344 //
345 // References:
346 // - Mehryar Mohri. Generic Epsilon-Removal and Input
347 //   Epsilon-Normalization Algorithms for Weighted Transducers,
348 //   "International Journal of Computer Science", 13(1):129-143 (2002).
349 template <class Arc>
350 void RmEpsilon(MutableFst<Arc> *fst,
351                bool connect = true,
352                typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
353                typename Arc::StateId state_threshold = kNoStateId,
354                float delta = kDelta) {
355   typedef typename Arc::StateId StateId;
356   typedef typename Arc::Weight Weight;
357   typedef typename Arc::Label Label;
358 
359   vector<Weight> distance;
360   AutoQueue<StateId> state_queue(*fst, &distance, EpsilonArcFilter<Arc>());
361   RmEpsilonOptions<Arc, AutoQueue<StateId> >
362       opts(&state_queue, delta, connect, weight_threshold, state_threshold);
363 
364   RmEpsilon(fst, &distance, opts);
365 }
366 
367 
368 struct RmEpsilonFstOptions : CacheOptions {
369   float delta;
370 
371   RmEpsilonFstOptions(const CacheOptions &opts, float delta = kDelta)
CacheOptionsRmEpsilonFstOptions372       : CacheOptions(opts), delta(delta) {}
373 
deltaRmEpsilonFstOptions374   explicit RmEpsilonFstOptions(float delta = kDelta) : delta(delta) {}
375 };
376 
377 
378 // Implementation of delayed RmEpsilonFst.
379 template <class A>
380 class RmEpsilonFstImpl : public CacheImpl<A> {
381  public:
382   using FstImpl<A>::SetType;
383   using FstImpl<A>::SetProperties;
384   using FstImpl<A>::SetInputSymbols;
385   using FstImpl<A>::SetOutputSymbols;
386 
387   using CacheBaseImpl< CacheState<A> >::PushArc;
388   using CacheBaseImpl< CacheState<A> >::HasArcs;
389   using CacheBaseImpl< CacheState<A> >::HasFinal;
390   using CacheBaseImpl< CacheState<A> >::HasStart;
391   using CacheBaseImpl< CacheState<A> >::SetArcs;
392   using CacheBaseImpl< CacheState<A> >::SetFinal;
393   using CacheBaseImpl< CacheState<A> >::SetStart;
394 
395   typedef typename A::Label Label;
396   typedef typename A::Weight Weight;
397   typedef typename A::StateId StateId;
398   typedef CacheState<A> State;
399 
RmEpsilonFstImpl(const Fst<A> & fst,const RmEpsilonFstOptions & opts)400   RmEpsilonFstImpl(const Fst<A>& fst, const RmEpsilonFstOptions &opts)
401       : CacheImpl<A>(opts),
402         fst_(fst.Copy()),
403         delta_(opts.delta),
404         rmeps_state_(
405             *fst_,
406             &distance_,
407             RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, delta_, false)) {
408     SetType("rmepsilon");
409     uint64 props = fst.Properties(kFstProperties, false);
410     SetProperties(RmEpsilonProperties(props, true), kCopyProperties);
411     SetInputSymbols(fst.InputSymbols());
412     SetOutputSymbols(fst.OutputSymbols());
413   }
414 
RmEpsilonFstImpl(const RmEpsilonFstImpl & impl)415   RmEpsilonFstImpl(const RmEpsilonFstImpl &impl)
416       : CacheImpl<A>(impl),
417         fst_(impl.fst_->Copy(true)),
418         delta_(impl.delta_),
419         rmeps_state_(
420             *fst_,
421             &distance_,
422             RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, delta_, false)) {
423     SetType("rmepsilon");
424     SetProperties(impl.Properties(), kCopyProperties);
425     SetInputSymbols(impl.InputSymbols());
426     SetOutputSymbols(impl.OutputSymbols());
427   }
428 
~RmEpsilonFstImpl()429   ~RmEpsilonFstImpl() {
430     delete fst_;
431   }
432 
Start()433   StateId Start() {
434     if (!HasStart()) {
435       SetStart(fst_->Start());
436     }
437     return CacheImpl<A>::Start();
438   }
439 
Final(StateId s)440   Weight Final(StateId s) {
441     if (!HasFinal(s)) {
442       Expand(s);
443     }
444     return CacheImpl<A>::Final(s);
445   }
446 
NumArcs(StateId s)447   size_t NumArcs(StateId s) {
448     if (!HasArcs(s))
449       Expand(s);
450     return CacheImpl<A>::NumArcs(s);
451   }
452 
NumInputEpsilons(StateId s)453   size_t NumInputEpsilons(StateId s) {
454     if (!HasArcs(s))
455       Expand(s);
456     return CacheImpl<A>::NumInputEpsilons(s);
457   }
458 
NumOutputEpsilons(StateId s)459   size_t NumOutputEpsilons(StateId s) {
460     if (!HasArcs(s))
461       Expand(s);
462     return CacheImpl<A>::NumOutputEpsilons(s);
463   }
464 
Properties()465   uint64 Properties() const { return Properties(kFstProperties); }
466 
467   // Set error if found; return FST impl properties.
Properties(uint64 mask)468   uint64 Properties(uint64 mask) const {
469     if ((mask & kError) &&
470         (fst_->Properties(kError, false) || rmeps_state_.Error()))
471       SetProperties(kError, kError);
472     return FstImpl<A>::Properties(mask);
473   }
474 
InitArcIterator(StateId s,ArcIteratorData<A> * data)475   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
476     if (!HasArcs(s))
477       Expand(s);
478     CacheImpl<A>::InitArcIterator(s, data);
479   }
480 
Expand(StateId s)481   void Expand(StateId s) {
482     rmeps_state_.Expand(s);
483     SetFinal(s, rmeps_state_.Final());
484     vector<A> &arcs = rmeps_state_.Arcs();
485     while (!arcs.empty()) {
486       PushArc(s, arcs.back());
487       arcs.pop_back();
488     }
489     SetArcs(s);
490   }
491 
492  private:
493   const Fst<A> *fst_;
494   float delta_;
495   vector<Weight> distance_;
496   FifoQueue<StateId> queue_;
497   RmEpsilonState<A, FifoQueue<StateId> > rmeps_state_;
498 
499   void operator=(const RmEpsilonFstImpl<A> &);  // disallow
500 };
501 
502 
503 // Removes epsilon-transitions (when both the input and output label
504 // are an epsilon) from a transducer. The result will be an equivalent
505 // FST that has no such epsilon transitions.  This version is a
506 // delayed Fst.
507 //
508 // Complexity:
509 // - Time:
510 //   - Unweighted: O(v^2 + v e)
511 //   - General: exponential
512 // - Space: O(v e)
513 // where v = # of states visited, e = # of arcs visited. Constant time
514 // to visit an input state or arc is assumed and exclusive of caching.
515 //
516 // References:
517 // - Mehryar Mohri. Generic Epsilon-Removal and Input
518 //   Epsilon-Normalization Algorithms for Weighted Transducers,
519 //   "International Journal of Computer Science", 13(1):129-143 (2002).
520 //
521 // This class attaches interface to implementation and handles
522 // reference counting, delegating most methods to ImplToFst.
523 template <class A>
524 class RmEpsilonFst : public ImplToFst< RmEpsilonFstImpl<A> > {
525  public:
526   friend class ArcIterator< RmEpsilonFst<A> >;
527   friend class StateIterator< RmEpsilonFst<A> >;
528 
529   typedef A Arc;
530   typedef typename A::StateId StateId;
531   typedef CacheState<A> State;
532   typedef RmEpsilonFstImpl<A> Impl;
533 
RmEpsilonFst(const Fst<A> & fst)534   RmEpsilonFst(const Fst<A> &fst)
535       : ImplToFst<Impl>(new Impl(fst, RmEpsilonFstOptions())) {}
536 
RmEpsilonFst(const Fst<A> & fst,const RmEpsilonFstOptions & opts)537   RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts)
538       : ImplToFst<Impl>(new Impl(fst, opts)) {}
539 
540   // See Fst<>::Copy() for doc.
541   RmEpsilonFst(const RmEpsilonFst<A> &fst, bool safe = false)
542       : ImplToFst<Impl>(fst, safe) {}
543 
544   // Get a copy of this RmEpsilonFst. See Fst<>::Copy() for further doc.
545   virtual RmEpsilonFst<A> *Copy(bool safe = false) const {
546     return new RmEpsilonFst<A>(*this, safe);
547   }
548 
549   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
550 
InitArcIterator(StateId s,ArcIteratorData<Arc> * data)551   virtual void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
552     GetImpl()->InitArcIterator(s, data);
553   }
554 
555  private:
556   // Makes visible to friends.
GetImpl()557   Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
558 
559   void operator=(const RmEpsilonFst<A> &fst);  // disallow
560 };
561 
562 // Specialization for RmEpsilonFst.
563 template<class A>
564 class StateIterator< RmEpsilonFst<A> >
565     : public CacheStateIterator< RmEpsilonFst<A> > {
566  public:
StateIterator(const RmEpsilonFst<A> & fst)567   explicit StateIterator(const RmEpsilonFst<A> &fst)
568       : CacheStateIterator< RmEpsilonFst<A> >(fst, fst.GetImpl()) {}
569 };
570 
571 
572 // Specialization for RmEpsilonFst.
573 template <class A>
574 class ArcIterator< RmEpsilonFst<A> >
575     : public CacheArcIterator< RmEpsilonFst<A> > {
576  public:
577   typedef typename A::StateId StateId;
578 
ArcIterator(const RmEpsilonFst<A> & fst,StateId s)579   ArcIterator(const RmEpsilonFst<A> &fst, StateId s)
580       : CacheArcIterator< RmEpsilonFst<A> >(fst.GetImpl(), s) {
581     if (!fst.GetImpl()->HasArcs(s))
582       fst.GetImpl()->Expand(s);
583   }
584 
585  private:
586   DISALLOW_COPY_AND_ASSIGN(ArcIterator);
587 };
588 
589 
590 template <class A> inline
InitStateIterator(StateIteratorData<A> * data)591 void RmEpsilonFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
592   data->base = new StateIterator< RmEpsilonFst<A> >(*this);
593 }
594 
595 
596 // Useful alias when using StdArc.
597 typedef RmEpsilonFst<StdArc> StdRmEpsilonFst;
598 
599 }  // namespace fst
600 
601 #endif  // FST_LIB_RMEPSILON_H__
602