1 // minimize.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 //
// \file Functions and classes to minimize finite-state acceptors and transducers
17
18 #ifndef FST_LIB_MINIMIZE_H__
19 #define FST_LIB_MINIMIZE_H__
20
#include <algorithm>
#include <map>
#include <queue>
#include <vector>

#include "fst/lib/arcsort.h"
#include "fst/lib/arcsum.h"
#include "fst/lib/connect.h"
#include "fst/lib/dfs-visit.h"
#include "fst/lib/encode.h"
#include "fst/lib/factor-weight.h"
#include "fst/lib/fst.h"
#include "fst/lib/mutable-fst.h"
#include "fst/lib/partition.h"
#include "fst/lib/push.h"
#include "fst/lib/queue.h"
#include "fst/lib/reverse.h"
37
38 namespace fst {
39
40 // comparator for creating partition based on sorting on
41 // - states
42 // - final weight
43 // - out degree,
44 // - (input label, output label, weight, destination_block)
45 template <class A>
46 class StateComparator {
47 public:
48 typedef typename A::StateId StateId;
49 typedef typename A::Weight Weight;
50
51 static const int32 kCompareFinal = 0x0000001;
52 static const int32 kCompareOutDegree = 0x0000002;
53 static const int32 kCompareArcs = 0x0000004;
54 static const int32 kCompareAll = (kCompareFinal |
55 kCompareOutDegree |
56 kCompareArcs);
57
58 StateComparator(const Fst<A>& fst,
59 const Partition<typename A::StateId>& partition,
60 int32 flags = kCompareAll)
fst_(fst)61 : fst_(fst), partition_(partition), flags_(flags) {}
62
63 // compare state x with state y based on sort criteria
operator()64 bool operator()(const StateId x, const StateId y) const {
65 // check for final state equivalence
66 if (flags_ & kCompareFinal) {
67 const ssize_t xfinal = fst_.Final(x).Hash();
68 const ssize_t yfinal = fst_.Final(y).Hash();
69 if (xfinal < yfinal) return true;
70 else if (xfinal > yfinal) return false;
71 }
72
73 if (flags_ & kCompareOutDegree) {
74 // check for # arcs
75 if (fst_.NumArcs(x) < fst_.NumArcs(y)) return true;
76 if (fst_.NumArcs(x) > fst_.NumArcs(y)) return false;
77
78 if (flags_ & kCompareArcs) {
79 // # arcs are equal, check for arc match
80 for (ArcIterator<Fst<A> > aiter1(fst_, x), aiter2(fst_, y);
81 !aiter1.Done() && !aiter2.Done(); aiter1.Next(), aiter2.Next()) {
82 const A& arc1 = aiter1.Value();
83 const A& arc2 = aiter2.Value();
84 if (arc1.ilabel < arc2.ilabel) return true;
85 if (arc1.ilabel > arc2.ilabel) return false;
86
87 if (partition_.class_id(arc1.nextstate) <
88 partition_.class_id(arc2.nextstate)) return true;
89 if (partition_.class_id(arc1.nextstate) >
90 partition_.class_id(arc2.nextstate)) return false;
91 }
92 }
93 }
94
95 return false;
96 }
97
98 private:
99 const Fst<A>& fst_;
100 const Partition<typename A::StateId>& partition_;
101 const int32 flags_;
102 };
103
104 // Computes equivalence classes for cyclic Fsts. For cyclic minimization
105 // we use the classic HopCroft minimization algorithm, which is of
106 //
107 // O(E)log(N),
108 //
109 // where E is the number of edges in the machine and N is number of states.
110 //
111 // The following paper describes the original algorithm
112 // An N Log N algorithm for minimizing states in a finite automaton
113 // by John HopCroft, January 1971
114 //
115 template <class A, class Queue>
116 class CyclicMinimizer {
117 public:
118 typedef typename A::Label Label;
119 typedef typename A::StateId StateId;
120 typedef typename A::StateId ClassId;
121 typedef typename A::Weight Weight;
122 typedef ReverseArc<A> RevA;
123
CyclicMinimizer(const ExpandedFst<A> & fst)124 CyclicMinimizer(const ExpandedFst<A>& fst) {
125 Initialize(fst);
126 Compute(fst);
127 }
128
~CyclicMinimizer()129 ~CyclicMinimizer() {
130 delete aiter_queue_;
131 }
132
partition()133 const Partition<StateId>& partition() const {
134 return P_;
135 }
136
137 // helper classes
138 private:
139 typedef ArcIterator<Fst<RevA> > ArcIter;
140 class ArcIterCompare {
141 public:
ArcIterCompare(const Partition<StateId> & partition)142 ArcIterCompare(const Partition<StateId>& partition)
143 : partition_(partition) {}
144
ArcIterCompare(const ArcIterCompare & comp)145 ArcIterCompare(const ArcIterCompare& comp)
146 : partition_(comp.partition_) {}
147
148 // compare two iterators based on there input labels, and proto state
149 // (partition class Ids)
operator()150 bool operator()(const ArcIter* x, const ArcIter* y) const {
151 const RevA& xarc = x->Value();
152 const RevA& yarc = y->Value();
153 return (xarc.ilabel > yarc.ilabel);
154 }
155
156 private:
157 const Partition<StateId>& partition_;
158 };
159
160 typedef priority_queue<ArcIter*, vector<ArcIter*>, ArcIterCompare>
161 ArcIterQueue;
162
163 // helper methods
164 private:
165 // prepartitions the space into equivalence classes with
166 // same final weight
167 // same # arcs per state
168 // same outgoing arcs
PrePartition(const Fst<A> & fst)169 void PrePartition(const Fst<A>& fst) {
170 VLOG(5) << "PrePartition";
171
172 typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
173 StateComparator<A> comp(fst, P_, StateComparator<A>::kCompareFinal);
174 EquivalenceMap equiv_map(comp);
175
176 StateIterator<Fst<A> > siter(fst);
177 StateId class_id = P_.AddClass();
178 P_.Add(siter.Value(), class_id);
179 equiv_map[siter.Value()] = class_id;
180 L_.Enqueue(class_id);
181 for (siter.Next(); !siter.Done(); siter.Next()) {
182 StateId s = siter.Value();
183 typename EquivalenceMap::const_iterator it = equiv_map.find(s);
184 if (it == equiv_map.end()) {
185 class_id = P_.AddClass();
186 P_.Add(s, class_id);
187 equiv_map[s] = class_id;
188 L_.Enqueue(class_id);
189 } else {
190 P_.Add(s, it->second);
191 equiv_map[s] = it->second;
192 }
193 }
194
195 VLOG(5) << "Initial Partition: " << P_.num_classes();
196 }
197
198 // - Create inverse transition Tr_ = rev(fst)
199 // - loop over states in fst and split on final, creating two blocks
200 // in the partition corresponding to final, non-final
Initialize(const Fst<A> & fst)201 void Initialize(const Fst<A>& fst) {
202 // construct Tr
203 Reverse(fst, &Tr_);
204 ILabelCompare<RevA> ilabel_comp;
205 ArcSort(&Tr_, ilabel_comp);
206
207 // initial split (F, S - F)
208 P_.Initialize(Tr_.NumStates() - 1);
209
210 // prep partition
211 PrePartition(fst);
212
213 // allocate arc iterator queue
214 ArcIterCompare comp(P_);
215 aiter_queue_ = new ArcIterQueue(comp);
216 }
217
218 // partition all classes with destination C
Split(ClassId C)219 void Split(ClassId C) {
220 // Prep priority queue. Open arc iterator for each state in C, and
221 // insert into priority queue.
222 for (PartitionIterator<StateId> siter(P_, C);
223 !siter.Done(); siter.Next()) {
224 StateId s = siter.Value();
225 if (Tr_.NumArcs(s + 1))
226 aiter_queue_->push(new ArcIterator<Fst<RevA> >(Tr_, s + 1));
227 }
228
229 // Now pop arc iterator from queue, split entering equivalence class
230 // re-insert updated iterator into queue.
231 Label prev_label = -1;
232 while (!aiter_queue_->empty()) {
233 ArcIterator<Fst<RevA> >* aiter = aiter_queue_->top();
234 aiter_queue_->pop();
235 if (aiter->Done()) {
236 delete aiter;
237 continue;
238 }
239
240 const RevA& arc = aiter->Value();
241 StateId from_state = aiter->Value().nextstate - 1;
242 Label from_label = arc.ilabel;
243 if (prev_label != from_label)
244 P_.FinalizeSplit(&L_);
245
246 StateId from_class = P_.class_id(from_state);
247 if (P_.class_size(from_class) > 1)
248 P_.SplitOn(from_state);
249
250 prev_label = from_label;
251 aiter->Next();
252 if (aiter->Done())
253 delete aiter;
254 else
255 aiter_queue_->push(aiter);
256 }
257 P_.FinalizeSplit(&L_);
258 }
259
260 // Main loop for hopcroft minimization.
Compute(const Fst<A> & fst)261 void Compute(const Fst<A>& fst) {
262 // process active classes (FIFO, or FILO)
263 while (!L_.Empty()) {
264 ClassId C = L_.Head();
265 L_.Dequeue();
266
267 // split on C, all labels in C
268 Split(C);
269 }
270 }
271
272 // helper data
273 private:
274 // Partioning of states into equivalence classes
275 Partition<StateId> P_;
276
277 // L = set of active classes to be processed in partition P
278 Queue L_;
279
280 // reverse transition function
281 VectorFst<RevA> Tr_;
282
283 // Priority queue of open arc iterators for all states in the 'splitter'
284 // equivalence class
285 ArcIterQueue* aiter_queue_;
286 };
287
288
289 // Computes equivalence classes for acyclic Fsts. The implementation details
290 // for this algorithms is documented by the following paper.
291 //
292 // Minimization of acyclic deterministic automata in linear time
293 // Dominque Revuz
294 //
295 // Complexity O(|E|)
296 //
297 template <class A>
298 class AcyclicMinimizer {
299 public:
300 typedef typename A::Label Label;
301 typedef typename A::StateId StateId;
302 typedef typename A::StateId ClassId;
303 typedef typename A::Weight Weight;
304
AcyclicMinimizer(const ExpandedFst<A> & fst)305 AcyclicMinimizer(const ExpandedFst<A>& fst) {
306 Initialize(fst);
307 Refine(fst);
308 }
309
partition()310 const Partition<StateId>& partition() {
311 return partition_;
312 }
313
314 // helper classes
315 private:
316 // DFS visitor to compute the height (distance) to final state.
317 class HeightVisitor {
318 public:
HeightVisitor()319 HeightVisitor() : max_height_(0), num_states_(0) { }
320
321 // invoked before dfs visit
InitVisit(const Fst<A> & fst)322 void InitVisit(const Fst<A>& fst) {}
323
324 // invoked when state is discovered (2nd arg is DFS tree root)
InitState(StateId s,StateId root)325 bool InitState(StateId s, StateId root) {
326 // extend height array and initialize height (distance) to 0
327 for (size_t i = height_.size(); i <= (size_t)s; ++i)
328 height_.push_back(-1);
329
330 if (s >= (StateId)num_states_) num_states_ = s + 1;
331 return true;
332 }
333
334 // invoked when tree arc examined (to undiscoverted state)
TreeArc(StateId s,const A & arc)335 bool TreeArc(StateId s, const A& arc) {
336 return true;
337 }
338
339 // invoked when back arc examined (to unfinished state)
BackArc(StateId s,const A & arc)340 bool BackArc(StateId s, const A& arc) {
341 return true;
342 }
343
344 // invoked when forward or cross arc examined (to finished state)
ForwardOrCrossArc(StateId s,const A & arc)345 bool ForwardOrCrossArc(StateId s, const A& arc) {
346 if (height_[arc.nextstate] + 1 > height_[s])
347 height_[s] = height_[arc.nextstate] + 1;
348 return true;
349 }
350
351 // invoked when state finished (parent is kNoStateId for tree root)
FinishState(StateId s,StateId parent,const A * parent_arc)352 void FinishState(StateId s, StateId parent, const A* parent_arc) {
353 if (height_[s] == -1) height_[s] = 0;
354 StateId h = height_[s] + 1;
355 if (parent >= 0) {
356 if (h > height_[parent]) height_[parent] = h;
357 if (h > (StateId)max_height_) max_height_ = h;
358 }
359 }
360
361 // invoked after DFS visit
FinishVisit()362 void FinishVisit() {}
363
max_height()364 size_t max_height() const { return max_height_; }
365
height()366 const vector<StateId>& height() const { return height_; }
367
num_states()368 size_t num_states() const { return num_states_; }
369
370 private:
371 vector<StateId> height_;
372 size_t max_height_;
373 size_t num_states_;
374 };
375
376 // helper methods
377 private:
378 // cluster states according to height (distance to final state)
Initialize(const Fst<A> & fst)379 void Initialize(const Fst<A>& fst) {
380 // compute height (distance to final state)
381 HeightVisitor hvisitor;
382 DfsVisit(fst, &hvisitor);
383
384 // create initial partition based on height
385 partition_.Initialize(hvisitor.num_states());
386 partition_.AllocateClasses(hvisitor.max_height() + 1);
387 const vector<StateId>& hstates = hvisitor.height();
388 for (size_t s = 0; s < hstates.size(); ++s)
389 partition_.Add(s, hstates[s]);
390 }
391
392 // refine states based on arc sort (out degree, arc equivalence)
Refine(const Fst<A> & fst)393 void Refine(const Fst<A>& fst) {
394 typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
395 StateComparator<A> comp(fst, partition_);
396
397 // start with tail (height = 0)
398 size_t height = partition_.num_classes();
399 for (size_t h = 0; h < height; ++h) {
400 EquivalenceMap equiv_classes(comp);
401
402 // sort states within equivalence class
403 PartitionIterator<StateId> siter(partition_, h);
404 equiv_classes[siter.Value()] = h;
405 for (siter.Next(); !siter.Done(); siter.Next()) {
406 const StateId s = siter.Value();
407 typename EquivalenceMap::const_iterator it = equiv_classes.find(s);
408 if (it == equiv_classes.end())
409 equiv_classes[s] = partition_.AddClass();
410 else
411 equiv_classes[s] = it->second;
412 }
413
414 // create refined partition
415 for (siter.Reset(); !siter.Done();) {
416 const StateId s = siter.Value();
417 const StateId old_class = partition_.class_id(s);
418 const StateId new_class = equiv_classes[s];
419
420 // a move operation can invalidate the iterator, so
421 // we first update the iterator to the next element
422 // before we move the current element out of the list
423 siter.Next();
424 if (old_class != new_class)
425 partition_.Move(s, new_class);
426 }
427 }
428 }
429
430 private:
431 Partition<StateId> partition_;
432 };
433
434
435 // Given a partition and a mutable fst, merge states of Fst inplace
436 // (i.e. destructively). Merging works by taking the first state in
437 // a class of the partition to be the representative state for the class.
438 // Each arc is then reconnected to this state. All states in the class
439 // are merged by adding there arcs to the representative state.
440 template <class A>
MergeStates(const Partition<typename A::StateId> & partition,MutableFst<A> * fst)441 void MergeStates(
442 const Partition<typename A::StateId>& partition, MutableFst<A>* fst) {
443 typedef typename A::StateId StateId;
444
445 vector<StateId> state_map(partition.num_classes());
446 for (size_t i = 0; i < (size_t)partition.num_classes(); ++i) {
447 PartitionIterator<StateId> siter(partition, i);
448 state_map[i] = siter.Value(); // first state in partition;
449 }
450
451 // relabel destination states
452 for (size_t c = 0; c < (size_t)partition.num_classes(); ++c) {
453 for (PartitionIterator<StateId> siter(partition, c);
454 !siter.Done(); siter.Next()) {
455 StateId s = siter.Value();
456 for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
457 !aiter.Done(); aiter.Next()) {
458 A arc = aiter.Value();
459 arc.nextstate = state_map[partition.class_id(arc.nextstate)];
460
461 if (s == state_map[c]) // first state just set destination
462 aiter.SetValue(arc);
463 else
464 fst->AddArc(state_map[c], arc);
465 }
466 }
467 }
468 fst->SetStart(state_map[partition.class_id(fst->Start())]);
469
470 Connect(fst);
471 }
472
473 template <class A>
AcceptorMinimize(MutableFst<A> * fst)474 void AcceptorMinimize(MutableFst<A>* fst) {
475 typedef typename A::StateId StateId;
476 if (!(fst->Properties(kAcceptor | kUnweighted, true)))
477 LOG(FATAL) << "Input Fst is not an unweighted acceptor";
478
479 // connect fst before minimization, handles disconnected states
480 Connect(fst);
481 if (fst->NumStates() == 0) return;
482
483 if (fst->Properties(kAcyclic, true)) {
484 // Acyclic minimization (revuz)
485 VLOG(2) << "Acyclic Minimization";
486 AcyclicMinimizer<A> minimizer(*fst);
487 MergeStates(minimizer.partition(), fst);
488
489 } else {
490 // Cyclic minimizaton (hopcroft)
491 VLOG(2) << "Cyclic Minimization";
492 CyclicMinimizer<A, LifoQueue<StateId> > minimizer(*fst);
493 MergeStates(minimizer.partition(), fst);
494 }
495
496 // sort arcs before summing
497 ArcSort(fst, ILabelCompare<A>());
498
499 // sum in appropriate semiring
500 ArcSum(fst);
501 }
502
503
// In-place minimization of input-deterministic acceptors and transducers,
// weighted or unweighted. Dies (LOG(FATAL)) if the input is not
// input-deterministic.
//
// - Unweighted acceptors are minimized directly by AcceptorMinimize.
// - Weighted acceptors are pushed towards the initial state and their weights
//   are quantized. Labels and weights are then encoded together
//   (kEncodeLabels | kEncodeWeights), so AcceptorMinimize can work on the
//   encoded machine, which is decoded afterwards.
// - Transducers are first mapped to the Gallic (STRING_LEFT) semiring and
//   marked as acceptors, then handled like weighted acceptors. Finally they
//   are mapped back into 'fst':
//     * sfst == 0: through FactorWeightFst/FromGallicMapper.
//     * sfst != 0: through GallicToNewSymbolsMapper, which writes into
//       'sfst'. Presumably 'sfst' then maps the new output symbols to the
//       original output strings. NOTE(review): confirm.
//
// For acyclic automata we use an algorithm from Dominique Revuz that is
// linear in the number of arcs (edges) in the machine.
// Complexity = O(|E|).
//
// For cyclic automata we use the classical Hopcroft minimization.
// Complexity = O(|E| log |N|).
template <class A>
void Minimize(MutableFst<A>* fst, MutableFst<A>* sfst = 0) {
  uint64 props = fst->Properties(kAcceptor | kIDeterministic|
                                 kWeighted | kUnweighted, true);
  if (!(props & kIDeterministic))
    LOG(FATAL) << "Input Fst is not deterministic";

  if (!(props & kAcceptor)) {  // transducer (weighted or not)
    // Convert to the Gallic semiring; gfst is then flagged as an acceptor
    // and fst is emptied so it can receive the result.
    VectorFst< GallicArc<A, STRING_LEFT> > gfst;
    Map(*fst, &gfst, ToGallicMapper<A, STRING_LEFT>());
    fst->DeleteStates();
    gfst.SetProperties(kAcceptor, kAcceptor);
    Push(&gfst, REWEIGHT_TO_INITIAL);
    Map(&gfst, QuantizeMapper< GallicArc<A, STRING_LEFT> >());
    EncodeMapper< GallicArc<A, STRING_LEFT> >
      encoder(kEncodeLabels | kEncodeWeights, ENCODE);
    Encode(&gfst, &encoder);
    AcceptorMinimize(&gfst);
    Decode(&gfst, encoder);

    if (sfst == 0) {
      // Factor the Gallic weights, then map back to arcs of type A.
      // NOTE(review): GallicFactor presumably spreads multi-symbol output
      // strings over several arcs; confirm.
      FactorWeightFst< GallicArc<A, STRING_LEFT>,
        GallicFactor<typename A::Label,
        typename A::Weight, STRING_LEFT> > fwfst(gfst);
      Map(fwfst, fst, FromGallicMapper<A, STRING_LEFT>());
    } else {
      // sfst inherits fst's output symbols; the mapper writes into sfst, and
      // fst's output symbols become sfst's input symbols.
      sfst->SetOutputSymbols(fst->OutputSymbols());
      GallicToNewSymbolsMapper<A, STRING_LEFT> mapper(sfst);
      Map(gfst, fst, mapper);
      fst->SetOutputSymbols(sfst->InputSymbols());
    }
  } else if (props & kWeighted) {  // weighted acceptor
    Push(fst, REWEIGHT_TO_INITIAL);
    Map(fst, QuantizeMapper<A>());
    EncodeMapper<A> encoder(kEncodeLabels | kEncodeWeights, ENCODE);
    Encode(fst, &encoder);
    AcceptorMinimize(fst);
    Decode(fst, encoder);
  } else {  // unweighted acceptor
    AcceptorMinimize(fst);
  }
}
555
556 } // namespace fst
557
558 #endif // FST_LIB_MINIMIZE_H__
559