1 // rmepsilon.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: allauzen@google.com (Cyril Allauzen)
17 //
18 // \file
19 // Functions and classes that implement epsilon-removal.
20
21 #ifndef FST_LIB_RMEPSILON_H__
22 #define FST_LIB_RMEPSILON_H__
23
24 #include <tr1/unordered_map>
25 using std::tr1::unordered_map;
26 using std::tr1::unordered_multimap;
27 #include <fst/slist.h>
28 #include <stack>
29 #include <string>
30 #include <utility>
31 using std::pair; using std::make_pair;
32 #include <vector>
33 using std::vector;
34
35 #include <fst/arcfilter.h>
36 #include <fst/cache.h>
37 #include <fst/connect.h>
38 #include <fst/factor-weight.h>
39 #include <fst/invert.h>
40 #include <fst/prune.h>
41 #include <fst/queue.h>
42 #include <fst/shortest-distance.h>
43 #include <fst/topsort.h>
44
45
46 namespace fst {
47
48 template <class Arc, class Queue>
49 class RmEpsilonOptions
50 : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> > {
51 public:
52 typedef typename Arc::StateId StateId;
53 typedef typename Arc::Weight Weight;
54
55 bool connect; // Connect output
56 Weight weight_threshold; // Pruning weight threshold.
57 StateId state_threshold; // Pruning state threshold.
58
59 explicit RmEpsilonOptions(Queue *q, float d = kDelta, bool c = true,
60 Weight w = Weight::Zero(),
61 StateId n = kNoStateId)
62 : ShortestDistanceOptions< Arc, Queue, EpsilonArcFilter<Arc> >(
63 q, EpsilonArcFilter<Arc>(), kNoStateId, d),
64 connect(c), weight_threshold(w), state_threshold(n) {}
65 private:
66 RmEpsilonOptions(); // disallow
67 };
68
69 // Computation state of the epsilon-removal algorithm.
70 template <class Arc, class Queue>
71 class RmEpsilonState {
72 public:
73 typedef typename Arc::Label Label;
74 typedef typename Arc::StateId StateId;
75 typedef typename Arc::Weight Weight;
76
RmEpsilonState(const Fst<Arc> & fst,vector<Weight> * distance,const RmEpsilonOptions<Arc,Queue> & opts)77 RmEpsilonState(const Fst<Arc> &fst,
78 vector<Weight> *distance,
79 const RmEpsilonOptions<Arc, Queue> &opts)
80 : fst_(fst), distance_(distance), sd_state_(fst_, distance, opts, true),
81 expand_id_(0) {}
82
83 // Compute arcs and final weight for state 's'
84 void Expand(StateId s);
85
86 // Returns arcs of expanded state.
Arcs()87 vector<Arc> &Arcs() { return arcs_; }
88
89 // Returns final weight of expanded state.
Final()90 const Weight &Final() const { return final_; }
91
92 // Return true if an error has occured.
Error()93 bool Error() const { return sd_state_.Error(); }
94
95 private:
96 static const size_t kPrime0 = 7853;
97 static const size_t kPrime1 = 7867;
98
99 struct Element {
100 Label ilabel;
101 Label olabel;
102 StateId nextstate;
103
ElementElement104 Element() {}
105
ElementElement106 Element(Label i, Label o, StateId s)
107 : ilabel(i), olabel(o), nextstate(s) {}
108 };
109
110 class ElementKey {
111 public:
operator()112 size_t operator()(const Element& e) const {
113 return static_cast<size_t>(e.nextstate +
114 e.ilabel * kPrime0 +
115 e.olabel * kPrime1);
116 }
117
118 private:
119 };
120
121 class ElementEqual {
122 public:
operator()123 bool operator()(const Element &e1, const Element &e2) const {
124 return (e1.ilabel == e2.ilabel) && (e1.olabel == e2.olabel)
125 && (e1.nextstate == e2.nextstate);
126 }
127 };
128
129 typedef unordered_map<Element, pair<StateId, size_t>,
130 ElementKey, ElementEqual> ElementMap;
131
132 const Fst<Arc> &fst_;
133 // Distance from state being expanded in epsilon-closure.
134 vector<Weight> *distance_;
135 // Shortest distance algorithm computation state.
136 ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc> > sd_state_;
137 // Maps an element 'e' to a pair 'p' corresponding to a position
138 // in the arcs vector of the state being expanded. 'e' corresponds
139 // to the position 'p.second' in the 'arcs_' vector if 'p.first' is
140 // equal to the state being expanded.
141 ElementMap element_map_;
142 EpsilonArcFilter<Arc> eps_filter_;
143 stack<StateId> eps_queue_; // Queue used to visit the epsilon-closure
144 vector<bool> visited_; // '[i] = true' if state 'i' has been visited
145 slist<StateId> visited_states_; // List of visited states
146 vector<Arc> arcs_; // Arcs of state being expanded
147 Weight final_; // Final weight of state being expanded
148 StateId expand_id_; // Unique ID for each call to Expand
149
150 DISALLOW_COPY_AND_ASSIGN(RmEpsilonState);
151 };
152
153 template <class Arc, class Queue>
154 const size_t RmEpsilonState<Arc, Queue>::kPrime0;
155 template <class Arc, class Queue>
156 const size_t RmEpsilonState<Arc, Queue>::kPrime1;
157
158
159 template <class Arc, class Queue>
Expand(typename Arc::StateId source)160 void RmEpsilonState<Arc,Queue>::Expand(typename Arc::StateId source) {
161 final_ = Weight::Zero();
162 arcs_.clear();
163 sd_state_.ShortestDistance(source);
164 if (sd_state_.Error())
165 return;
166 eps_queue_.push(source);
167
168 while (!eps_queue_.empty()) {
169 StateId state = eps_queue_.top();
170 eps_queue_.pop();
171
172 while (visited_.size() <= state) visited_.push_back(false);
173 if (visited_[state]) continue;
174 visited_[state] = true;
175 visited_states_.push_front(state);
176
177 for (ArcIterator< Fst<Arc> > ait(fst_, state);
178 !ait.Done();
179 ait.Next()) {
180 Arc arc = ait.Value();
181 arc.weight = Times((*distance_)[state], arc.weight);
182
183 if (eps_filter_(arc)) {
184 while (visited_.size() <= arc.nextstate)
185 visited_.push_back(false);
186 if (!visited_[arc.nextstate])
187 eps_queue_.push(arc.nextstate);
188 } else {
189 Element element(arc.ilabel, arc.olabel, arc.nextstate);
190 typename ElementMap::iterator it = element_map_.find(element);
191 if (it == element_map_.end()) {
192 element_map_.insert(
193 pair<Element, pair<StateId, size_t> >
194 (element, pair<StateId, size_t>(expand_id_, arcs_.size())));
195 arcs_.push_back(arc);
196 } else {
197 if (((*it).second).first == expand_id_) {
198 Weight &w = arcs_[((*it).second).second].weight;
199 w = Plus(w, arc.weight);
200 } else {
201 ((*it).second).first = expand_id_;
202 ((*it).second).second = arcs_.size();
203 arcs_.push_back(arc);
204 }
205 }
206 }
207 }
208 final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state)));
209 }
210
211 while (!visited_states_.empty()) {
212 visited_[visited_states_.front()] = false;
213 visited_states_.pop_front();
214 }
215 ++expand_id_;
216 }
217
218 // Removes epsilon-transitions (when both the input and output label
219 // are an epsilon) from a transducer. The result will be an equivalent
220 // FST that has no such epsilon transitions. This version modifies
221 // its input. It allows fine control via the options argument; see
222 // below for a simpler interface.
223 //
224 // The vector 'distance' will be used to hold the shortest distances
225 // during the epsilon-closure computation. The state queue discipline
226 // and convergence delta are taken in the options argument.
227 template <class Arc, class Queue>
RmEpsilon(MutableFst<Arc> * fst,vector<typename Arc::Weight> * distance,const RmEpsilonOptions<Arc,Queue> & opts)228 void RmEpsilon(MutableFst<Arc> *fst,
229 vector<typename Arc::Weight> *distance,
230 const RmEpsilonOptions<Arc, Queue> &opts) {
231 typedef typename Arc::StateId StateId;
232 typedef typename Arc::Weight Weight;
233 typedef typename Arc::Label Label;
234
235 if (fst->Start() == kNoStateId) {
236 return;
237 }
238
239 // 'noneps_in[s]' will be set to true iff 's' admits a non-epsilon
240 // incoming transition or is the start state.
241 vector<bool> noneps_in(fst->NumStates(), false);
242 noneps_in[fst->Start()] = true;
243 for (StateId i = 0; i < fst->NumStates(); ++i) {
244 for (ArcIterator<Fst<Arc> > aiter(*fst, i);
245 !aiter.Done();
246 aiter.Next()) {
247 if (aiter.Value().ilabel != 0 || aiter.Value().olabel != 0)
248 noneps_in[aiter.Value().nextstate] = true;
249 }
250 }
251
252 // States sorted in topological order when (acyclic) or generic
253 // topological order (cyclic).
254 vector<StateId> states;
255 states.reserve(fst->NumStates());
256
257 if (fst->Properties(kTopSorted, false) & kTopSorted) {
258 for (StateId i = 0; i < fst->NumStates(); i++)
259 states.push_back(i);
260 } else if (fst->Properties(kAcyclic, false) & kAcyclic) {
261 vector<StateId> order;
262 bool acyclic;
263 TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
264 DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>());
265 // Sanity check: should be acyclic if property bit is set.
266 if(!acyclic) {
267 FSTERROR() << "RmEpsilon: inconsistent acyclic property bit";
268 fst->SetProperties(kError, kError);
269 return;
270 }
271 states.resize(order.size());
272 for (StateId i = 0; i < order.size(); i++)
273 states[order[i]] = i;
274 } else {
275 uint64 props;
276 vector<StateId> scc;
277 SccVisitor<Arc> scc_visitor(&scc, 0, 0, &props);
278 DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>());
279 vector<StateId> first(scc.size(), kNoStateId);
280 vector<StateId> next(scc.size(), kNoStateId);
281 for (StateId i = 0; i < scc.size(); i++) {
282 if (first[scc[i]] != kNoStateId)
283 next[i] = first[scc[i]];
284 first[scc[i]] = i;
285 }
286 for (StateId i = 0; i < first.size(); i++)
287 for (StateId j = first[i]; j != kNoStateId; j = next[j])
288 states.push_back(j);
289 }
290
291 RmEpsilonState<Arc, Queue>
292 rmeps_state(*fst, distance, opts);
293
294 while (!states.empty()) {
295 StateId state = states.back();
296 states.pop_back();
297 if (!noneps_in[state])
298 continue;
299 rmeps_state.Expand(state);
300 fst->SetFinal(state, rmeps_state.Final());
301 fst->DeleteArcs(state);
302 vector<Arc> &arcs = rmeps_state.Arcs();
303 fst->ReserveArcs(state, arcs.size());
304 while (!arcs.empty()) {
305 fst->AddArc(state, arcs.back());
306 arcs.pop_back();
307 }
308 }
309
310 for (StateId s = 0; s < fst->NumStates(); ++s) {
311 if (!noneps_in[s])
312 fst->DeleteArcs(s);
313 }
314
315 if(rmeps_state.Error())
316 fst->SetProperties(kError, kError);
317 fst->SetProperties(
318 RmEpsilonProperties(fst->Properties(kFstProperties, false)),
319 kFstProperties);
320
321 if (opts.weight_threshold != Weight::Zero() ||
322 opts.state_threshold != kNoStateId)
323 Prune(fst, opts.weight_threshold, opts.state_threshold);
324 if (opts.connect && (opts.weight_threshold == Weight::Zero() ||
325 opts.state_threshold != kNoStateId))
326 Connect(fst);
327 }
328
329 // Removes epsilon-transitions (when both the input and output label
330 // are an epsilon) from a transducer. The result will be an equivalent
331 // FST that has no such epsilon transitions. This version modifies its
332 // input. It has a simplified interface; see above for a version that
333 // allows finer control.
334 //
335 // Complexity:
336 // - Time:
337 // - Unweighted: O(V2 + V E)
338 // - Acyclic: O(V2 + V E)
339 // - Tropical semiring: O(V2 log V + V E)
340 // - General: exponential
341 // - Space: O(V E)
342 // where V = # of states visited, E = # of arcs.
343 //
344 // References:
345 // - Mehryar Mohri. Generic Epsilon-Removal and Input
346 // Epsilon-Normalization Algorithms for Weighted Transducers,
347 // "International Journal of Computer Science", 13(1):129-143 (2002).
348 template <class Arc>
349 void RmEpsilon(MutableFst<Arc> *fst,
350 bool connect = true,
351 typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
352 typename Arc::StateId state_threshold = kNoStateId,
353 float delta = kDelta) {
354 typedef typename Arc::StateId StateId;
355 typedef typename Arc::Weight Weight;
356 typedef typename Arc::Label Label;
357
358 vector<Weight> distance;
359 AutoQueue<StateId> state_queue(*fst, &distance, EpsilonArcFilter<Arc>());
360 RmEpsilonOptions<Arc, AutoQueue<StateId> >
361 opts(&state_queue, delta, connect, weight_threshold, state_threshold);
362
363 RmEpsilon(fst, &distance, opts);
364 }
365
366
367 struct RmEpsilonFstOptions : CacheOptions {
368 float delta;
369
370 RmEpsilonFstOptions(const CacheOptions &opts, float delta = kDelta)
CacheOptionsRmEpsilonFstOptions371 : CacheOptions(opts), delta(delta) {}
372
deltaRmEpsilonFstOptions373 explicit RmEpsilonFstOptions(float delta = kDelta) : delta(delta) {}
374 };
375
376
377 // Implementation of delayed RmEpsilonFst.
378 template <class A>
379 class RmEpsilonFstImpl : public CacheImpl<A> {
380 public:
381 using FstImpl<A>::SetType;
382 using FstImpl<A>::SetProperties;
383 using FstImpl<A>::SetInputSymbols;
384 using FstImpl<A>::SetOutputSymbols;
385
386 using CacheBaseImpl< CacheState<A> >::PushArc;
387 using CacheBaseImpl< CacheState<A> >::HasArcs;
388 using CacheBaseImpl< CacheState<A> >::HasFinal;
389 using CacheBaseImpl< CacheState<A> >::HasStart;
390 using CacheBaseImpl< CacheState<A> >::SetArcs;
391 using CacheBaseImpl< CacheState<A> >::SetFinal;
392 using CacheBaseImpl< CacheState<A> >::SetStart;
393
394 typedef typename A::Label Label;
395 typedef typename A::Weight Weight;
396 typedef typename A::StateId StateId;
397 typedef CacheState<A> State;
398
RmEpsilonFstImpl(const Fst<A> & fst,const RmEpsilonFstOptions & opts)399 RmEpsilonFstImpl(const Fst<A>& fst, const RmEpsilonFstOptions &opts)
400 : CacheImpl<A>(opts),
401 fst_(fst.Copy()),
402 delta_(opts.delta),
403 rmeps_state_(
404 *fst_,
405 &distance_,
406 RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, delta_, false)) {
407 SetType("rmepsilon");
408 uint64 props = fst.Properties(kFstProperties, false);
409 SetProperties(RmEpsilonProperties(props, true), kCopyProperties);
410 SetInputSymbols(fst.InputSymbols());
411 SetOutputSymbols(fst.OutputSymbols());
412 }
413
RmEpsilonFstImpl(const RmEpsilonFstImpl & impl)414 RmEpsilonFstImpl(const RmEpsilonFstImpl &impl)
415 : CacheImpl<A>(impl),
416 fst_(impl.fst_->Copy(true)),
417 delta_(impl.delta_),
418 rmeps_state_(
419 *fst_,
420 &distance_,
421 RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, delta_, false)) {
422 SetType("rmepsilon");
423 SetProperties(impl.Properties(), kCopyProperties);
424 SetInputSymbols(impl.InputSymbols());
425 SetOutputSymbols(impl.OutputSymbols());
426 }
427
~RmEpsilonFstImpl()428 ~RmEpsilonFstImpl() {
429 delete fst_;
430 }
431
Start()432 StateId Start() {
433 if (!HasStart()) {
434 SetStart(fst_->Start());
435 }
436 return CacheImpl<A>::Start();
437 }
438
Final(StateId s)439 Weight Final(StateId s) {
440 if (!HasFinal(s)) {
441 Expand(s);
442 }
443 return CacheImpl<A>::Final(s);
444 }
445
NumArcs(StateId s)446 size_t NumArcs(StateId s) {
447 if (!HasArcs(s))
448 Expand(s);
449 return CacheImpl<A>::NumArcs(s);
450 }
451
NumInputEpsilons(StateId s)452 size_t NumInputEpsilons(StateId s) {
453 if (!HasArcs(s))
454 Expand(s);
455 return CacheImpl<A>::NumInputEpsilons(s);
456 }
457
NumOutputEpsilons(StateId s)458 size_t NumOutputEpsilons(StateId s) {
459 if (!HasArcs(s))
460 Expand(s);
461 return CacheImpl<A>::NumOutputEpsilons(s);
462 }
463
Properties()464 uint64 Properties() const { return Properties(kFstProperties); }
465
466 // Set error if found; return FST impl properties.
Properties(uint64 mask)467 uint64 Properties(uint64 mask) const {
468 if ((mask & kError) &&
469 (fst_->Properties(kError, false) || rmeps_state_.Error()))
470 SetProperties(kError, kError);
471 return FstImpl<A>::Properties(mask);
472 }
473
InitArcIterator(StateId s,ArcIteratorData<A> * data)474 void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
475 if (!HasArcs(s))
476 Expand(s);
477 CacheImpl<A>::InitArcIterator(s, data);
478 }
479
Expand(StateId s)480 void Expand(StateId s) {
481 rmeps_state_.Expand(s);
482 SetFinal(s, rmeps_state_.Final());
483 vector<A> &arcs = rmeps_state_.Arcs();
484 while (!arcs.empty()) {
485 PushArc(s, arcs.back());
486 arcs.pop_back();
487 }
488 SetArcs(s);
489 }
490
491 private:
492 const Fst<A> *fst_;
493 float delta_;
494 vector<Weight> distance_;
495 FifoQueue<StateId> queue_;
496 RmEpsilonState<A, FifoQueue<StateId> > rmeps_state_;
497
498 void operator=(const RmEpsilonFstImpl<A> &); // disallow
499 };
500
501
502 // Removes epsilon-transitions (when both the input and output label
503 // are an epsilon) from a transducer. The result will be an equivalent
504 // FST that has no such epsilon transitions. This version is a
505 // delayed Fst.
506 //
507 // Complexity:
508 // - Time:
509 // - Unweighted: O(v^2 + v e)
510 // - General: exponential
511 // - Space: O(v e)
512 // where v = # of states visited, e = # of arcs visited. Constant time
513 // to visit an input state or arc is assumed and exclusive of caching.
514 //
515 // References:
516 // - Mehryar Mohri. Generic Epsilon-Removal and Input
517 // Epsilon-Normalization Algorithms for Weighted Transducers,
518 // "International Journal of Computer Science", 13(1):129-143 (2002).
519 //
520 // This class attaches interface to implementation and handles
521 // reference counting, delegating most methods to ImplToFst.
522 template <class A>
523 class RmEpsilonFst : public ImplToFst< RmEpsilonFstImpl<A> > {
524 public:
525 friend class ArcIterator< RmEpsilonFst<A> >;
526 friend class StateIterator< RmEpsilonFst<A> >;
527
528 typedef A Arc;
529 typedef typename A::StateId StateId;
530 typedef CacheState<A> State;
531 typedef RmEpsilonFstImpl<A> Impl;
532
RmEpsilonFst(const Fst<A> & fst)533 RmEpsilonFst(const Fst<A> &fst)
534 : ImplToFst<Impl>(new Impl(fst, RmEpsilonFstOptions())) {}
535
RmEpsilonFst(const Fst<A> & fst,const RmEpsilonFstOptions & opts)536 RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts)
537 : ImplToFst<Impl>(new Impl(fst, opts)) {}
538
539 // See Fst<>::Copy() for doc.
540 RmEpsilonFst(const RmEpsilonFst<A> &fst, bool safe = false)
541 : ImplToFst<Impl>(fst, safe) {}
542
543 // Get a copy of this RmEpsilonFst. See Fst<>::Copy() for further doc.
544 virtual RmEpsilonFst<A> *Copy(bool safe = false) const {
545 return new RmEpsilonFst<A>(*this, safe);
546 }
547
548 virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
549
InitArcIterator(StateId s,ArcIteratorData<Arc> * data)550 virtual void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
551 GetImpl()->InitArcIterator(s, data);
552 }
553
554 private:
555 // Makes visible to friends.
GetImpl()556 Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
557
558 void operator=(const RmEpsilonFst<A> &fst); // disallow
559 };
560
561 // Specialization for RmEpsilonFst.
562 template<class A>
563 class StateIterator< RmEpsilonFst<A> >
564 : public CacheStateIterator< RmEpsilonFst<A> > {
565 public:
StateIterator(const RmEpsilonFst<A> & fst)566 explicit StateIterator(const RmEpsilonFst<A> &fst)
567 : CacheStateIterator< RmEpsilonFst<A> >(fst, fst.GetImpl()) {}
568 };
569
570
571 // Specialization for RmEpsilonFst.
572 template <class A>
573 class ArcIterator< RmEpsilonFst<A> >
574 : public CacheArcIterator< RmEpsilonFst<A> > {
575 public:
576 typedef typename A::StateId StateId;
577
ArcIterator(const RmEpsilonFst<A> & fst,StateId s)578 ArcIterator(const RmEpsilonFst<A> &fst, StateId s)
579 : CacheArcIterator< RmEpsilonFst<A> >(fst.GetImpl(), s) {
580 if (!fst.GetImpl()->HasArcs(s))
581 fst.GetImpl()->Expand(s);
582 }
583
584 private:
585 DISALLOW_COPY_AND_ASSIGN(ArcIterator);
586 };
587
588
589 template <class A> inline
InitStateIterator(StateIteratorData<A> * data)590 void RmEpsilonFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
591 data->base = new StateIterator< RmEpsilonFst<A> >(*this);
592 }
593
594
595 // Useful alias when using StdArc.
596 typedef RmEpsilonFst<StdArc> StdRmEpsilonFst;
597
598 } // namespace fst
599
600 #endif // FST_LIB_RMEPSILON_H__
601