• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // rmepsilon.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Author: allauzen@cs.nyu.edu (Cyril Allauzen)
16 //
17 // \file
18 // Functions and classes that implement epsilon-removal.
19 
20 #ifndef FST_LIB_RMEPSILON_H__
21 #define FST_LIB_RMEPSILON_H__
22 
23 #include <ext/hash_map>
24 using __gnu_cxx::hash_map;
25 #include <ext/slist>
26 using __gnu_cxx::slist;
27 
28 #include "fst/lib/arcfilter.h"
29 #include "fst/lib/cache.h"
30 #include "fst/lib/connect.h"
31 #include "fst/lib/factor-weight.h"
32 #include "fst/lib/invert.h"
33 #include "fst/lib/map.h"
34 #include "fst/lib/queue.h"
35 #include "fst/lib/shortest-distance.h"
36 #include "fst/lib/topsort.h"
37 
38 namespace fst {
39 
40 template <class Arc, class Queue>
41 struct RmEpsilonOptions
42     : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> > {
43   typedef typename Arc::StateId StateId;
44 
45   bool connect;  // Connect output
46 
47   RmEpsilonOptions(Queue *q, float d = kDelta, bool c = true)
48       : ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> >(
49           q, EpsilonArcFilter<Arc>(), kNoStateId, d), connect(c) {}
50 
51 };
52 
53 
54 // Computation state of the epsilon-removal algorithm.
55 template <class Arc, class Queue>
56 class RmEpsilonState {
57  public:
58   typedef typename Arc::Label Label;
59   typedef typename Arc::StateId StateId;
60   typedef typename Arc::Weight Weight;
61 
RmEpsilonState(const Fst<Arc> & fst,vector<Weight> * distance,const RmEpsilonOptions<Arc,Queue> & opts)62   RmEpsilonState(const Fst<Arc> &fst,
63                  vector<Weight> *distance,
64                  const RmEpsilonOptions<Arc, Queue> &opts)
65       : fst_(fst), distance_(distance), sd_state_(fst_, distance, opts, true) {
66   }
67 
68   // Compute arcs and final weight for state 's'
69   void Expand(StateId s);
70 
71   // Returns arcs of expanded state.
Arcs()72   vector<Arc> &Arcs() { return arcs_; }
73 
74   // Returns final weight of expanded state.
Final()75   const Weight &Final() const { return final_; }
76 
77  private:
78   struct Element {
79     Label ilabel;
80     Label olabel;
81     StateId nextstate;
82 
ElementElement83     Element() {}
84 
ElementElement85     Element(Label i, Label o, StateId s)
86         : ilabel(i), olabel(o), nextstate(s) {}
87   };
88 
89   class ElementKey {
90    public:
operator()91     size_t operator()(const Element& e) const {
92       return static_cast<size_t>(e.nextstate);
93       return static_cast<size_t>(e.nextstate +
94                                  e.ilabel * kPrime0 +
95                                  e.olabel * kPrime1);
96     }
97 
98    private:
99     static const int kPrime0 = 7853;
100     static const int kPrime1 = 7867;
101   };
102 
103   class ElementEqual {
104    public:
operator()105     bool operator()(const Element &e1, const Element &e2) const {
106       return (e1.ilabel == e2.ilabel) &&  (e1.olabel == e2.olabel)
107                          && (e1.nextstate == e2.nextstate);
108     }
109   };
110 
111  private:
112   typedef hash_map<Element, pair<StateId, ssize_t>,
113                    ElementKey, ElementEqual> ElementMap;
114 
115   const Fst<Arc> &fst_;
116   // Distance from state being expanded in epsilon-closure.
117   vector<Weight> *distance_;
118   // Shortest distance algorithm computation state.
119   ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc> > sd_state_;
120   // Maps an element 'e' to a pair 'p' corresponding to a position
121   // in the arcs vector of the state being expanded. 'e' corresponds
122   // to the position 'p.second' in the 'arcs_' vector if 'p.first' is
123   // equal to the state being expanded.
124   ElementMap element_map_;
125   EpsilonArcFilter<Arc> eps_filter_;
126   stack<StateId> eps_queue_;      // Queue used to visit the epsilon-closure
127   vector<bool> visited_;          // '[i] = true' if state 'i' has been visited
128   slist<StateId> visited_states_; // List of visited states
129   vector<Arc> arcs_;              // Arcs of state being expanded
130   Weight final_;                  // Final weight of state being expanded
131 
132   void operator=(const RmEpsilonState);  // Disallow
133 };
134 
135 
136 template <class Arc, class Queue>
Expand(typename Arc::StateId source)137 void RmEpsilonState<Arc,Queue>::Expand(typename Arc::StateId source) {
138    sd_state_.ShortestDistance(source);
139    eps_queue_.push(source);
140    final_ = Weight::Zero();
141    arcs_.clear();
142 
143    while (!eps_queue_.empty()) {
144      StateId state = eps_queue_.top();
145      eps_queue_.pop();
146 
147      while ((StateId)visited_.size() <= state) visited_.push_back(false);
148      visited_[state] = true;
149      visited_states_.push_front(state);
150 
151      for (ArcIterator< Fst<Arc> > ait(fst_, state);
152           !ait.Done();
153           ait.Next()) {
154        Arc arc = ait.Value();
155        arc.weight = Times((*distance_)[state], arc.weight);
156 
157        if (eps_filter_(arc)) {
158          while ((StateId)visited_.size() <= arc.nextstate)
159            visited_.push_back(false);
160          if (!visited_[arc.nextstate])
161            eps_queue_.push(arc.nextstate);
162        } else {
163           Element element(arc.ilabel, arc.olabel, arc.nextstate);
164           typename ElementMap::iterator it = element_map_.find(element);
165           if (it == element_map_.end()) {
166             element_map_.insert(
167                 pair<Element, pair<StateId, ssize_t> >
168                 (element, pair<StateId, ssize_t>(source, arcs_.size())));
169             arcs_.push_back(arc);
170           } else {
171             if (((*it).second).first == source) {
172               Weight &w = arcs_[((*it).second).second].weight;
173               w = Plus(w, arc.weight);
174             } else {
175               ((*it).second).first = source;
176               ((*it).second).second = arcs_.size();
177               arcs_.push_back(arc);
178             }
179           }
180         }
181      }
182      final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state)));
183    }
184 
185    while (!visited_states_.empty()) {
186      visited_[visited_states_.front()] = false;
187      visited_states_.pop_front();
188    }
189 }
190 
191 
192 // Removes epsilon-transitions (when both the input and output label
193 // are an epsilon) from a transducer. The result will be an equivalent
194 // FST that has no such epsilon transitions.  This version modifies
195 // its input. It allows fine control via the options argument; see
196 // below for a simpler interface.
197 //
198 // The vector 'distance' will be used to hold the shortest distances
199 // during the epsilon-closure computation. The state queue discipline
200 // and convergence delta are taken in the options argument.
201 template <class Arc, class Queue>
RmEpsilon(MutableFst<Arc> * fst,vector<typename Arc::Weight> * distance,const RmEpsilonOptions<Arc,Queue> & opts)202 void RmEpsilon(MutableFst<Arc> *fst,
203                vector<typename Arc::Weight> *distance,
204                const RmEpsilonOptions<Arc, Queue> &opts) {
205   typedef typename Arc::StateId StateId;
206   typedef typename Arc::Weight Weight;
207   typedef typename Arc::Label Label;
208 
209   // States sorted in topological order when (acyclic) or generic
210   // topological order (cyclic).
211   vector<StateId> states;
212 
213   if (fst->Properties(kTopSorted, false) & kTopSorted) {
214     for (StateId i = 0; i < (StateId)fst->NumStates(); i++)
215       states.push_back(i);
216   } else if (fst->Properties(kAcyclic, false) & kAcyclic) {
217     vector<StateId> order;
218     bool acyclic;
219     TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
220     DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>());
221     if (!acyclic)
222       LOG(FATAL) << "RmEpsilon: not acyclic though property bit is set";
223     states.resize(order.size());
224     for (StateId i = 0; i < (StateId)order.size(); i++)
225       states[order[i]] = i;
226   } else {
227      uint64 props;
228      vector<StateId> scc;
229      SccVisitor<Arc> scc_visitor(&scc, 0, 0, &props);
230      DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>());
231      vector<StateId> first(scc.size(), kNoStateId);
232      vector<StateId> next(scc.size(), kNoStateId);
233      for (StateId i = 0; i < (StateId)scc.size(); i++) {
234        if (first[scc[i]] != kNoStateId)
235          next[i] = first[scc[i]];
236        first[scc[i]] = i;
237      }
238      for (StateId i = 0; i < (StateId)first.size(); i++)
239        for (StateId j = first[i]; j != kNoStateId; j = next[j])
240          states.push_back(j);
241   }
242 
243   RmEpsilonState<Arc, Queue>
244     rmeps_state(*fst, distance, opts);
245 
246   while (!states.empty()) {
247     StateId state = states.back();
248     states.pop_back();
249     rmeps_state.Expand(state);
250     fst->SetFinal(state, rmeps_state.Final());
251     fst->DeleteArcs(state);
252     vector<Arc> &arcs = rmeps_state.Arcs();
253     while (!arcs.empty()) {
254       fst->AddArc(state, arcs.back());
255       arcs.pop_back();
256     }
257   }
258 
259   fst->SetProperties(RmEpsilonProperties(
260                          fst->Properties(kFstProperties, false)),
261                      kFstProperties);
262 
263   if (opts.connect)
264     Connect(fst);
265 }
266 
267 
268 // Removes epsilon-transitions (when both the input and output label
269 // are an epsilon) from a transducer. The result will be an equivalent
270 // FST that has no such epsilon transitions. This version modifies its
271 // input. It has a simplified interface; see above for a version that
272 // allows finer control.
273 //
274 // Complexity:
275 // - Time:
276 //   - Unweighted: O(V2 + V E)
277 //   - Acyclic: O(V2 + V E)
278 //   - Tropical semiring: O(V2 log V + V E)
279 //   - General: exponential
280 // - Space: O(V E)
281 // where V = # of states visited, E = # of arcs.
282 //
283 // References:
284 // - Mehryar Mohri. Generic Epsilon-Removal and Input
285 //   Epsilon-Normalization Algorithms for Weighted Transducers,
286 //   "International Journal of Computer Science", 13(1):129-143 (2002).
287 template <class Arc>
288 void RmEpsilon(MutableFst<Arc> *fst, bool connect = true) {
289   typedef typename Arc::StateId StateId;
290   typedef typename Arc::Weight Weight;
291   typedef typename Arc::Label Label;
292 
293   vector<Weight> distance;
294   AutoQueue<StateId> state_queue(*fst, &distance, EpsilonArcFilter<Arc>());
295   RmEpsilonOptions<Arc, AutoQueue<StateId> >
296     opts(&state_queue, kDelta, connect);
297 
298   RmEpsilon(fst, &distance, opts);
299 }
300 
301 
302 struct RmEpsilonFstOptions : CacheOptions {
303   float delta;
304 
305   RmEpsilonFstOptions(const CacheOptions &opts, float delta = kDelta)
CacheOptionsRmEpsilonFstOptions306       : CacheOptions(opts), delta(delta) {}
307 
deltaRmEpsilonFstOptions308   explicit RmEpsilonFstOptions(float delta = kDelta) : delta(delta) {}
309 };
310 
311 
312 // Implementation of delayed RmEpsilonFst.
313 template <class A>
314 class RmEpsilonFstImpl : public CacheImpl<A> {
315  public:
316   using FstImpl<A>::SetType;
317   using FstImpl<A>::SetProperties;
318   using FstImpl<A>::Properties;
319   using FstImpl<A>::SetInputSymbols;
320   using FstImpl<A>::SetOutputSymbols;
321 
322   using CacheBaseImpl< CacheState<A> >::HasStart;
323   using CacheBaseImpl< CacheState<A> >::HasFinal;
324   using CacheBaseImpl< CacheState<A> >::HasArcs;
325 
326   typedef typename A::Label Label;
327   typedef typename A::Weight Weight;
328   typedef typename A::StateId StateId;
329   typedef CacheState<A> State;
330 
RmEpsilonFstImpl(const Fst<A> & fst,const RmEpsilonFstOptions & opts)331   RmEpsilonFstImpl(const Fst<A>& fst, const RmEpsilonFstOptions &opts)
332       : CacheImpl<A>(opts),
333         fst_(fst.Copy()),
334         rmeps_state_(
335             *fst_,
336             &distance_,
337             RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, opts.delta,
338                                                      false)
339             ) {
340     SetType("rmepsilon");
341     uint64 props = fst.Properties(kFstProperties, false);
342     SetProperties(RmEpsilonProperties(props, true), kCopyProperties);
343   }
344 
~RmEpsilonFstImpl()345   ~RmEpsilonFstImpl() {
346     delete fst_;
347   }
348 
Start()349   StateId Start() {
350     if (!HasStart()) {
351       SetStart(fst_->Start());
352     }
353     return CacheImpl<A>::Start();
354   }
355 
Final(StateId s)356   Weight Final(StateId s) {
357     if (!HasFinal(s)) {
358       Expand(s);
359     }
360     return CacheImpl<A>::Final(s);
361   }
362 
NumArcs(StateId s)363   size_t NumArcs(StateId s) {
364     if (!HasArcs(s))
365       Expand(s);
366     return CacheImpl<A>::NumArcs(s);
367   }
368 
NumInputEpsilons(StateId s)369   size_t NumInputEpsilons(StateId s) {
370     if (!HasArcs(s))
371       Expand(s);
372     return CacheImpl<A>::NumInputEpsilons(s);
373   }
374 
NumOutputEpsilons(StateId s)375   size_t NumOutputEpsilons(StateId s) {
376     if (!HasArcs(s))
377       Expand(s);
378     return CacheImpl<A>::NumOutputEpsilons(s);
379   }
380 
InitArcIterator(StateId s,ArcIteratorData<A> * data)381   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
382     if (!HasArcs(s))
383       Expand(s);
384     CacheImpl<A>::InitArcIterator(s, data);
385   }
386 
Expand(StateId s)387   void Expand(StateId s) {
388     rmeps_state_.Expand(s);
389     SetFinal(s, rmeps_state_.Final());
390     vector<A> &arcs = rmeps_state_.Arcs();
391     while (!arcs.empty()) {
392       AddArc(s, arcs.back());
393       arcs.pop_back();
394     }
395     SetArcs(s);
396   }
397 
398  private:
399   const Fst<A> *fst_;
400   vector<Weight> distance_;
401   FifoQueue<StateId> queue_;
402   RmEpsilonState<A, FifoQueue<StateId> > rmeps_state_;
403 
404   DISALLOW_EVIL_CONSTRUCTORS(RmEpsilonFstImpl);
405 };
406 
407 
408 // Removes epsilon-transitions (when both the input and output label
409 // are an epsilon) from a transducer. The result will be an equivalent
410 // FST that has no such epsilon transitions.  This version is a
411 // delayed Fst.
412 //
413 // Complexity:
414 // - Time:
415 //   - Unweighted: O(v^2 + v e)
416 //   - General: exponential
417 // - Space: O(v e)
418 // where v = # of states visited, e = # of arcs visited. Constant time
419 // to visit an input state or arc is assumed and exclusive of caching.
420 //
421 // References:
422 // - Mehryar Mohri. Generic Epsilon-Removal and Input
423 //   Epsilon-Normalization Algorithms for Weighted Transducers,
424 //   "International Journal of Computer Science", 13(1):129-143 (2002).
425 template <class A>
426 class RmEpsilonFst : public Fst<A> {
427  public:
428   friend class ArcIterator< RmEpsilonFst<A> >;
429   friend class CacheStateIterator< RmEpsilonFst<A> >;
430   friend class CacheArcIterator< RmEpsilonFst<A> >;
431 
432   typedef A Arc;
433   typedef typename A::Weight Weight;
434   typedef typename A::StateId StateId;
435   typedef CacheState<A> State;
436 
RmEpsilonFst(const Fst<A> & fst)437   RmEpsilonFst(const Fst<A> &fst)
438       : impl_(new RmEpsilonFstImpl<A>(fst, RmEpsilonFstOptions())) {}
439 
RmEpsilonFst(const Fst<A> & fst,const RmEpsilonFstOptions & opts)440   RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts)
441       : impl_(new RmEpsilonFstImpl<A>(fst, opts)) {}
442 
RmEpsilonFst(const RmEpsilonFst<A> & fst)443   explicit RmEpsilonFst(const RmEpsilonFst<A> &fst) : impl_(fst.impl_) {
444     impl_->IncrRefCount();
445   }
446 
~RmEpsilonFst()447   virtual ~RmEpsilonFst() { if (!impl_->DecrRefCount()) delete impl_;  }
448 
Start()449   virtual StateId Start() const { return impl_->Start(); }
450 
Final(StateId s)451   virtual Weight Final(StateId s) const { return impl_->Final(s); }
452 
NumArcs(StateId s)453   virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
454 
NumInputEpsilons(StateId s)455   virtual size_t NumInputEpsilons(StateId s) const {
456     return impl_->NumInputEpsilons(s);
457   }
458 
NumOutputEpsilons(StateId s)459   virtual size_t NumOutputEpsilons(StateId s) const {
460     return impl_->NumOutputEpsilons(s);
461   }
462 
Properties(uint64 mask,bool test)463   virtual uint64 Properties(uint64 mask, bool test) const {
464     if (test) {
465       uint64 known, test = TestProperties(*this, mask, &known);
466       impl_->SetProperties(test, known);
467       return test & mask;
468     } else {
469       return impl_->Properties(mask);
470     }
471   }
472 
Type()473   virtual const string& Type() const { return impl_->Type(); }
474 
Copy()475   virtual RmEpsilonFst<A> *Copy() const {
476     return new RmEpsilonFst<A>(*this);
477   }
478 
InputSymbols()479   virtual const SymbolTable* InputSymbols() const {
480     return impl_->InputSymbols();
481   }
482 
OutputSymbols()483   virtual const SymbolTable* OutputSymbols() const {
484     return impl_->OutputSymbols();
485   }
486 
487   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
488 
InitArcIterator(StateId s,ArcIteratorData<A> * data)489   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
490     impl_->InitArcIterator(s, data);
491   }
492 
493  protected:
Impl()494   RmEpsilonFstImpl<A> *Impl() { return impl_; }
495 
496  private:
497   RmEpsilonFstImpl<A> *impl_;
498 
499   void operator=(const RmEpsilonFst<A> &fst);  // disallow
500 };
501 
502 
503 // Specialization for RmEpsilonFst.
504 template<class A>
505 class StateIterator< RmEpsilonFst<A> >
506     : public CacheStateIterator< RmEpsilonFst<A> > {
507  public:
StateIterator(const RmEpsilonFst<A> & fst)508   explicit StateIterator(const RmEpsilonFst<A> &fst)
509       : CacheStateIterator< RmEpsilonFst<A> >(fst) {}
510 };
511 
512 
513 // Specialization for RmEpsilonFst.
514 template <class A>
515 class ArcIterator< RmEpsilonFst<A> >
516     : public CacheArcIterator< RmEpsilonFst<A> > {
517  public:
518   typedef typename A::StateId StateId;
519 
ArcIterator(const RmEpsilonFst<A> & fst,StateId s)520   ArcIterator(const RmEpsilonFst<A> &fst, StateId s)
521       : CacheArcIterator< RmEpsilonFst<A> >(fst, s) {
522     if (!fst.impl_->HasArcs(s))
523       fst.impl_->Expand(s);
524   }
525 
526  private:
527   DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
528 };
529 
530 
531 template <class A> inline
InitStateIterator(StateIteratorData<A> * data)532 void RmEpsilonFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
533   data->base = new StateIterator< RmEpsilonFst<A> >(*this);
534 }
535 
536 
537 // Useful alias when using StdArc.
538 typedef RmEpsilonFst<StdArc> StdRmEpsilonFst;
539 
540 }  // namespace fst
541 
542 #endif  // FST_LIB_RMEPSILON_H__
543