1 // rmepsilon.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Author: allauzen@cs.nyu.edu (Cyril Allauzen)
16 //
17 // \file
// Functions and classes that implement epsilon-removal.
19
20 #ifndef FST_LIB_RMEPSILON_H__
21 #define FST_LIB_RMEPSILON_H__
22
23 #include <ext/hash_map>
24 using __gnu_cxx::hash_map;
25 #include <ext/slist>
26 using __gnu_cxx::slist;
27
28 #include "fst/lib/arcfilter.h"
29 #include "fst/lib/cache.h"
30 #include "fst/lib/connect.h"
31 #include "fst/lib/factor-weight.h"
32 #include "fst/lib/invert.h"
33 #include "fst/lib/map.h"
34 #include "fst/lib/queue.h"
35 #include "fst/lib/shortest-distance.h"
36 #include "fst/lib/topsort.h"
37
38 namespace fst {
39
40 template <class Arc, class Queue>
41 struct RmEpsilonOptions
42 : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> > {
43 typedef typename Arc::StateId StateId;
44
45 bool connect; // Connect output
46
47 RmEpsilonOptions(Queue *q, float d = kDelta, bool c = true)
48 : ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> >(
49 q, EpsilonArcFilter<Arc>(), kNoStateId, d), connect(c) {}
50
51 };
52
53
54 // Computation state of the epsilon-removal algorithm.
55 template <class Arc, class Queue>
56 class RmEpsilonState {
57 public:
58 typedef typename Arc::Label Label;
59 typedef typename Arc::StateId StateId;
60 typedef typename Arc::Weight Weight;
61
RmEpsilonState(const Fst<Arc> & fst,vector<Weight> * distance,const RmEpsilonOptions<Arc,Queue> & opts)62 RmEpsilonState(const Fst<Arc> &fst,
63 vector<Weight> *distance,
64 const RmEpsilonOptions<Arc, Queue> &opts)
65 : fst_(fst), distance_(distance), sd_state_(fst_, distance, opts, true) {
66 }
67
68 // Compute arcs and final weight for state 's'
69 void Expand(StateId s);
70
71 // Returns arcs of expanded state.
Arcs()72 vector<Arc> &Arcs() { return arcs_; }
73
74 // Returns final weight of expanded state.
Final()75 const Weight &Final() const { return final_; }
76
77 private:
78 struct Element {
79 Label ilabel;
80 Label olabel;
81 StateId nextstate;
82
ElementElement83 Element() {}
84
ElementElement85 Element(Label i, Label o, StateId s)
86 : ilabel(i), olabel(o), nextstate(s) {}
87 };
88
89 class ElementKey {
90 public:
operator()91 size_t operator()(const Element& e) const {
92 return static_cast<size_t>(e.nextstate);
93 return static_cast<size_t>(e.nextstate +
94 e.ilabel * kPrime0 +
95 e.olabel * kPrime1);
96 }
97
98 private:
99 static const int kPrime0 = 7853;
100 static const int kPrime1 = 7867;
101 };
102
103 class ElementEqual {
104 public:
operator()105 bool operator()(const Element &e1, const Element &e2) const {
106 return (e1.ilabel == e2.ilabel) && (e1.olabel == e2.olabel)
107 && (e1.nextstate == e2.nextstate);
108 }
109 };
110
111 private:
112 typedef hash_map<Element, pair<StateId, ssize_t>,
113 ElementKey, ElementEqual> ElementMap;
114
115 const Fst<Arc> &fst_;
116 // Distance from state being expanded in epsilon-closure.
117 vector<Weight> *distance_;
118 // Shortest distance algorithm computation state.
119 ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc> > sd_state_;
120 // Maps an element 'e' to a pair 'p' corresponding to a position
121 // in the arcs vector of the state being expanded. 'e' corresponds
122 // to the position 'p.second' in the 'arcs_' vector if 'p.first' is
123 // equal to the state being expanded.
124 ElementMap element_map_;
125 EpsilonArcFilter<Arc> eps_filter_;
126 stack<StateId> eps_queue_; // Queue used to visit the epsilon-closure
127 vector<bool> visited_; // '[i] = true' if state 'i' has been visited
128 slist<StateId> visited_states_; // List of visited states
129 vector<Arc> arcs_; // Arcs of state being expanded
130 Weight final_; // Final weight of state being expanded
131
132 void operator=(const RmEpsilonState); // Disallow
133 };
134
135
136 template <class Arc, class Queue>
Expand(typename Arc::StateId source)137 void RmEpsilonState<Arc,Queue>::Expand(typename Arc::StateId source) {
138 sd_state_.ShortestDistance(source);
139 eps_queue_.push(source);
140 final_ = Weight::Zero();
141 arcs_.clear();
142
143 while (!eps_queue_.empty()) {
144 StateId state = eps_queue_.top();
145 eps_queue_.pop();
146
147 while ((StateId)visited_.size() <= state) visited_.push_back(false);
148 visited_[state] = true;
149 visited_states_.push_front(state);
150
151 for (ArcIterator< Fst<Arc> > ait(fst_, state);
152 !ait.Done();
153 ait.Next()) {
154 Arc arc = ait.Value();
155 arc.weight = Times((*distance_)[state], arc.weight);
156
157 if (eps_filter_(arc)) {
158 while ((StateId)visited_.size() <= arc.nextstate)
159 visited_.push_back(false);
160 if (!visited_[arc.nextstate])
161 eps_queue_.push(arc.nextstate);
162 } else {
163 Element element(arc.ilabel, arc.olabel, arc.nextstate);
164 typename ElementMap::iterator it = element_map_.find(element);
165 if (it == element_map_.end()) {
166 element_map_.insert(
167 pair<Element, pair<StateId, ssize_t> >
168 (element, pair<StateId, ssize_t>(source, arcs_.size())));
169 arcs_.push_back(arc);
170 } else {
171 if (((*it).second).first == source) {
172 Weight &w = arcs_[((*it).second).second].weight;
173 w = Plus(w, arc.weight);
174 } else {
175 ((*it).second).first = source;
176 ((*it).second).second = arcs_.size();
177 arcs_.push_back(arc);
178 }
179 }
180 }
181 }
182 final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state)));
183 }
184
185 while (!visited_states_.empty()) {
186 visited_[visited_states_.front()] = false;
187 visited_states_.pop_front();
188 }
189 }
190
191
192 // Removes epsilon-transitions (when both the input and output label
193 // are an epsilon) from a transducer. The result will be an equivalent
194 // FST that has no such epsilon transitions. This version modifies
195 // its input. It allows fine control via the options argument; see
196 // below for a simpler interface.
197 //
198 // The vector 'distance' will be used to hold the shortest distances
199 // during the epsilon-closure computation. The state queue discipline
200 // and convergence delta are taken in the options argument.
201 template <class Arc, class Queue>
RmEpsilon(MutableFst<Arc> * fst,vector<typename Arc::Weight> * distance,const RmEpsilonOptions<Arc,Queue> & opts)202 void RmEpsilon(MutableFst<Arc> *fst,
203 vector<typename Arc::Weight> *distance,
204 const RmEpsilonOptions<Arc, Queue> &opts) {
205 typedef typename Arc::StateId StateId;
206 typedef typename Arc::Weight Weight;
207 typedef typename Arc::Label Label;
208
209 // States sorted in topological order when (acyclic) or generic
210 // topological order (cyclic).
211 vector<StateId> states;
212
213 if (fst->Properties(kTopSorted, false) & kTopSorted) {
214 for (StateId i = 0; i < (StateId)fst->NumStates(); i++)
215 states.push_back(i);
216 } else if (fst->Properties(kAcyclic, false) & kAcyclic) {
217 vector<StateId> order;
218 bool acyclic;
219 TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
220 DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>());
221 if (!acyclic)
222 LOG(FATAL) << "RmEpsilon: not acyclic though property bit is set";
223 states.resize(order.size());
224 for (StateId i = 0; i < (StateId)order.size(); i++)
225 states[order[i]] = i;
226 } else {
227 uint64 props;
228 vector<StateId> scc;
229 SccVisitor<Arc> scc_visitor(&scc, 0, 0, &props);
230 DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>());
231 vector<StateId> first(scc.size(), kNoStateId);
232 vector<StateId> next(scc.size(), kNoStateId);
233 for (StateId i = 0; i < (StateId)scc.size(); i++) {
234 if (first[scc[i]] != kNoStateId)
235 next[i] = first[scc[i]];
236 first[scc[i]] = i;
237 }
238 for (StateId i = 0; i < (StateId)first.size(); i++)
239 for (StateId j = first[i]; j != kNoStateId; j = next[j])
240 states.push_back(j);
241 }
242
243 RmEpsilonState<Arc, Queue>
244 rmeps_state(*fst, distance, opts);
245
246 while (!states.empty()) {
247 StateId state = states.back();
248 states.pop_back();
249 rmeps_state.Expand(state);
250 fst->SetFinal(state, rmeps_state.Final());
251 fst->DeleteArcs(state);
252 vector<Arc> &arcs = rmeps_state.Arcs();
253 while (!arcs.empty()) {
254 fst->AddArc(state, arcs.back());
255 arcs.pop_back();
256 }
257 }
258
259 fst->SetProperties(RmEpsilonProperties(
260 fst->Properties(kFstProperties, false)),
261 kFstProperties);
262
263 if (opts.connect)
264 Connect(fst);
265 }
266
267
268 // Removes epsilon-transitions (when both the input and output label
269 // are an epsilon) from a transducer. The result will be an equivalent
270 // FST that has no such epsilon transitions. This version modifies its
271 // input. It has a simplified interface; see above for a version that
272 // allows finer control.
273 //
274 // Complexity:
275 // - Time:
276 // - Unweighted: O(V2 + V E)
277 // - Acyclic: O(V2 + V E)
278 // - Tropical semiring: O(V2 log V + V E)
279 // - General: exponential
280 // - Space: O(V E)
281 // where V = # of states visited, E = # of arcs.
282 //
283 // References:
284 // - Mehryar Mohri. Generic Epsilon-Removal and Input
285 // Epsilon-Normalization Algorithms for Weighted Transducers,
286 // "International Journal of Computer Science", 13(1):129-143 (2002).
287 template <class Arc>
288 void RmEpsilon(MutableFst<Arc> *fst, bool connect = true) {
289 typedef typename Arc::StateId StateId;
290 typedef typename Arc::Weight Weight;
291 typedef typename Arc::Label Label;
292
293 vector<Weight> distance;
294 AutoQueue<StateId> state_queue(*fst, &distance, EpsilonArcFilter<Arc>());
295 RmEpsilonOptions<Arc, AutoQueue<StateId> >
296 opts(&state_queue, kDelta, connect);
297
298 RmEpsilon(fst, &distance, opts);
299 }
300
301
302 struct RmEpsilonFstOptions : CacheOptions {
303 float delta;
304
305 RmEpsilonFstOptions(const CacheOptions &opts, float delta = kDelta)
CacheOptionsRmEpsilonFstOptions306 : CacheOptions(opts), delta(delta) {}
307
deltaRmEpsilonFstOptions308 explicit RmEpsilonFstOptions(float delta = kDelta) : delta(delta) {}
309 };
310
311
312 // Implementation of delayed RmEpsilonFst.
313 template <class A>
314 class RmEpsilonFstImpl : public CacheImpl<A> {
315 public:
316 using FstImpl<A>::SetType;
317 using FstImpl<A>::SetProperties;
318 using FstImpl<A>::Properties;
319 using FstImpl<A>::SetInputSymbols;
320 using FstImpl<A>::SetOutputSymbols;
321
322 using CacheBaseImpl< CacheState<A> >::HasStart;
323 using CacheBaseImpl< CacheState<A> >::HasFinal;
324 using CacheBaseImpl< CacheState<A> >::HasArcs;
325
326 typedef typename A::Label Label;
327 typedef typename A::Weight Weight;
328 typedef typename A::StateId StateId;
329 typedef CacheState<A> State;
330
RmEpsilonFstImpl(const Fst<A> & fst,const RmEpsilonFstOptions & opts)331 RmEpsilonFstImpl(const Fst<A>& fst, const RmEpsilonFstOptions &opts)
332 : CacheImpl<A>(opts),
333 fst_(fst.Copy()),
334 rmeps_state_(
335 *fst_,
336 &distance_,
337 RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, opts.delta,
338 false)
339 ) {
340 SetType("rmepsilon");
341 uint64 props = fst.Properties(kFstProperties, false);
342 SetProperties(RmEpsilonProperties(props, true), kCopyProperties);
343 }
344
~RmEpsilonFstImpl()345 ~RmEpsilonFstImpl() {
346 delete fst_;
347 }
348
Start()349 StateId Start() {
350 if (!HasStart()) {
351 SetStart(fst_->Start());
352 }
353 return CacheImpl<A>::Start();
354 }
355
Final(StateId s)356 Weight Final(StateId s) {
357 if (!HasFinal(s)) {
358 Expand(s);
359 }
360 return CacheImpl<A>::Final(s);
361 }
362
NumArcs(StateId s)363 size_t NumArcs(StateId s) {
364 if (!HasArcs(s))
365 Expand(s);
366 return CacheImpl<A>::NumArcs(s);
367 }
368
NumInputEpsilons(StateId s)369 size_t NumInputEpsilons(StateId s) {
370 if (!HasArcs(s))
371 Expand(s);
372 return CacheImpl<A>::NumInputEpsilons(s);
373 }
374
NumOutputEpsilons(StateId s)375 size_t NumOutputEpsilons(StateId s) {
376 if (!HasArcs(s))
377 Expand(s);
378 return CacheImpl<A>::NumOutputEpsilons(s);
379 }
380
InitArcIterator(StateId s,ArcIteratorData<A> * data)381 void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
382 if (!HasArcs(s))
383 Expand(s);
384 CacheImpl<A>::InitArcIterator(s, data);
385 }
386
Expand(StateId s)387 void Expand(StateId s) {
388 rmeps_state_.Expand(s);
389 SetFinal(s, rmeps_state_.Final());
390 vector<A> &arcs = rmeps_state_.Arcs();
391 while (!arcs.empty()) {
392 AddArc(s, arcs.back());
393 arcs.pop_back();
394 }
395 SetArcs(s);
396 }
397
398 private:
399 const Fst<A> *fst_;
400 vector<Weight> distance_;
401 FifoQueue<StateId> queue_;
402 RmEpsilonState<A, FifoQueue<StateId> > rmeps_state_;
403
404 DISALLOW_EVIL_CONSTRUCTORS(RmEpsilonFstImpl);
405 };
406
407
408 // Removes epsilon-transitions (when both the input and output label
409 // are an epsilon) from a transducer. The result will be an equivalent
410 // FST that has no such epsilon transitions. This version is a
411 // delayed Fst.
412 //
413 // Complexity:
414 // - Time:
415 // - Unweighted: O(v^2 + v e)
416 // - General: exponential
417 // - Space: O(v e)
418 // where v = # of states visited, e = # of arcs visited. Constant time
419 // to visit an input state or arc is assumed and exclusive of caching.
420 //
421 // References:
422 // - Mehryar Mohri. Generic Epsilon-Removal and Input
423 // Epsilon-Normalization Algorithms for Weighted Transducers,
424 // "International Journal of Computer Science", 13(1):129-143 (2002).
425 template <class A>
426 class RmEpsilonFst : public Fst<A> {
427 public:
428 friend class ArcIterator< RmEpsilonFst<A> >;
429 friend class CacheStateIterator< RmEpsilonFst<A> >;
430 friend class CacheArcIterator< RmEpsilonFst<A> >;
431
432 typedef A Arc;
433 typedef typename A::Weight Weight;
434 typedef typename A::StateId StateId;
435 typedef CacheState<A> State;
436
RmEpsilonFst(const Fst<A> & fst)437 RmEpsilonFst(const Fst<A> &fst)
438 : impl_(new RmEpsilonFstImpl<A>(fst, RmEpsilonFstOptions())) {}
439
RmEpsilonFst(const Fst<A> & fst,const RmEpsilonFstOptions & opts)440 RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts)
441 : impl_(new RmEpsilonFstImpl<A>(fst, opts)) {}
442
RmEpsilonFst(const RmEpsilonFst<A> & fst)443 explicit RmEpsilonFst(const RmEpsilonFst<A> &fst) : impl_(fst.impl_) {
444 impl_->IncrRefCount();
445 }
446
~RmEpsilonFst()447 virtual ~RmEpsilonFst() { if (!impl_->DecrRefCount()) delete impl_; }
448
Start()449 virtual StateId Start() const { return impl_->Start(); }
450
Final(StateId s)451 virtual Weight Final(StateId s) const { return impl_->Final(s); }
452
NumArcs(StateId s)453 virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
454
NumInputEpsilons(StateId s)455 virtual size_t NumInputEpsilons(StateId s) const {
456 return impl_->NumInputEpsilons(s);
457 }
458
NumOutputEpsilons(StateId s)459 virtual size_t NumOutputEpsilons(StateId s) const {
460 return impl_->NumOutputEpsilons(s);
461 }
462
Properties(uint64 mask,bool test)463 virtual uint64 Properties(uint64 mask, bool test) const {
464 if (test) {
465 uint64 known, test = TestProperties(*this, mask, &known);
466 impl_->SetProperties(test, known);
467 return test & mask;
468 } else {
469 return impl_->Properties(mask);
470 }
471 }
472
Type()473 virtual const string& Type() const { return impl_->Type(); }
474
Copy()475 virtual RmEpsilonFst<A> *Copy() const {
476 return new RmEpsilonFst<A>(*this);
477 }
478
InputSymbols()479 virtual const SymbolTable* InputSymbols() const {
480 return impl_->InputSymbols();
481 }
482
OutputSymbols()483 virtual const SymbolTable* OutputSymbols() const {
484 return impl_->OutputSymbols();
485 }
486
487 virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
488
InitArcIterator(StateId s,ArcIteratorData<A> * data)489 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
490 impl_->InitArcIterator(s, data);
491 }
492
493 protected:
Impl()494 RmEpsilonFstImpl<A> *Impl() { return impl_; }
495
496 private:
497 RmEpsilonFstImpl<A> *impl_;
498
499 void operator=(const RmEpsilonFst<A> &fst); // disallow
500 };
501
502
503 // Specialization for RmEpsilonFst.
504 template<class A>
505 class StateIterator< RmEpsilonFst<A> >
506 : public CacheStateIterator< RmEpsilonFst<A> > {
507 public:
StateIterator(const RmEpsilonFst<A> & fst)508 explicit StateIterator(const RmEpsilonFst<A> &fst)
509 : CacheStateIterator< RmEpsilonFst<A> >(fst) {}
510 };
511
512
513 // Specialization for RmEpsilonFst.
514 template <class A>
515 class ArcIterator< RmEpsilonFst<A> >
516 : public CacheArcIterator< RmEpsilonFst<A> > {
517 public:
518 typedef typename A::StateId StateId;
519
ArcIterator(const RmEpsilonFst<A> & fst,StateId s)520 ArcIterator(const RmEpsilonFst<A> &fst, StateId s)
521 : CacheArcIterator< RmEpsilonFst<A> >(fst, s) {
522 if (!fst.impl_->HasArcs(s))
523 fst.impl_->Expand(s);
524 }
525
526 private:
527 DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
528 };
529
530
531 template <class A> inline
InitStateIterator(StateIteratorData<A> * data)532 void RmEpsilonFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
533 data->base = new StateIterator< RmEpsilonFst<A> >(*this);
534 }
535
536
537 // Useful alias when using StdArc.
538 typedef RmEpsilonFst<StdArc> StdRmEpsilonFst;
539
540 } // namespace fst
541
542 #endif // FST_LIB_RMEPSILON_H__
543