• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // minimize.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 //
16 // \file Functions and classes to minimize finite-state acceptors and transducers
17 
18 #ifndef FST_LIB_MINIMIZE_H__
19 #define FST_LIB_MINIMIZE_H__
20 
21 #include <algorithm>
22 #include <map>
23 #include <queue>
24 
25 #include "fst/lib/arcsort.h"
26 #include "fst/lib/arcsum.h"
27 #include "fst/lib/connect.h"
28 #include "fst/lib/dfs-visit.h"
29 #include "fst/lib/encode.h"
30 #include "fst/lib/factor-weight.h"
31 #include "fst/lib/fst.h"
32 #include "fst/lib/mutable-fst.h"
33 #include "fst/lib/partition.h"
34 #include "fst/lib/push.h"
35 #include "fst/lib/queue.h"
36 #include "fst/lib/reverse.h"
37 
38 namespace fst {
39 
40 // comparator for creating partition based on sorting on
41 // - states
42 // - final weight
43 // - out degree,
44 // -  (input label, output label, weight, destination_block)
45 template <class A>
46 class StateComparator {
47  public:
48   typedef typename A::StateId StateId;
49   typedef typename A::Weight Weight;
50 
51   static const int32 kCompareFinal     = 0x0000001;
52   static const int32 kCompareOutDegree = 0x0000002;
53   static const int32 kCompareArcs      = 0x0000004;
54   static const int32 kCompareAll       = (kCompareFinal |
55                                           kCompareOutDegree |
56                                           kCompareArcs);
57 
58   StateComparator(const Fst<A>& fst,
59                   const Partition<typename A::StateId>& partition,
60                   int32 flags = kCompareAll)
fst_(fst)61       : fst_(fst), partition_(partition), flags_(flags) {}
62 
63   // compare state x with state y based on sort criteria
operator()64   bool operator()(const StateId x, const StateId y) const {
65     // check for final state equivalence
66     if (flags_ & kCompareFinal) {
67       const ssize_t xfinal = fst_.Final(x).Hash();
68       const ssize_t yfinal = fst_.Final(y).Hash();
69       if      (xfinal < yfinal) return true;
70       else if (xfinal > yfinal) return false;
71     }
72 
73     if (flags_ & kCompareOutDegree) {
74       // check for # arcs
75       if (fst_.NumArcs(x) < fst_.NumArcs(y)) return true;
76       if (fst_.NumArcs(x) > fst_.NumArcs(y)) return false;
77 
78       if (flags_ & kCompareArcs) {
79         // # arcs are equal, check for arc match
80         for (ArcIterator<Fst<A> > aiter1(fst_, x), aiter2(fst_, y);
81              !aiter1.Done() && !aiter2.Done(); aiter1.Next(), aiter2.Next()) {
82           const A& arc1 = aiter1.Value();
83           const A& arc2 = aiter2.Value();
84           if (arc1.ilabel < arc2.ilabel) return true;
85           if (arc1.ilabel > arc2.ilabel) return false;
86 
87           if (partition_.class_id(arc1.nextstate) <
88               partition_.class_id(arc2.nextstate)) return true;
89           if (partition_.class_id(arc1.nextstate) >
90               partition_.class_id(arc2.nextstate)) return false;
91         }
92       }
93     }
94 
95     return false;
96   }
97 
98  private:
99   const Fst<A>& fst_;
100   const Partition<typename A::StateId>& partition_;
101   const int32 flags_;
102 };
103 
104 // Computes equivalence classes for cyclic Fsts. For cyclic minimization
105 // we use the classic HopCroft minimization algorithm, which is of
106 //
107 //   O(E)log(N),
108 //
109 // where E is the number of edges in the machine and N is number of states.
110 //
111 // The following paper describes the original algorithm
112 //  An N Log N algorithm for minimizing states in a finite automaton
113 //  by John HopCroft, January 1971
114 //
115 template <class A, class Queue>
116 class CyclicMinimizer {
117  public:
118   typedef typename A::Label Label;
119   typedef typename A::StateId StateId;
120   typedef typename A::StateId ClassId;
121   typedef typename A::Weight Weight;
122   typedef ReverseArc<A> RevA;
123 
CyclicMinimizer(const ExpandedFst<A> & fst)124   CyclicMinimizer(const ExpandedFst<A>& fst) {
125     Initialize(fst);
126     Compute(fst);
127   }
128 
~CyclicMinimizer()129   ~CyclicMinimizer() {
130     delete aiter_queue_;
131   }
132 
partition()133   const Partition<StateId>& partition() const {
134     return P_;
135   }
136 
137   // helper classes
138  private:
139   typedef ArcIterator<Fst<RevA> > ArcIter;
140   class ArcIterCompare {
141    public:
ArcIterCompare(const Partition<StateId> & partition)142     ArcIterCompare(const Partition<StateId>& partition)
143         : partition_(partition) {}
144 
ArcIterCompare(const ArcIterCompare & comp)145     ArcIterCompare(const ArcIterCompare& comp)
146         : partition_(comp.partition_) {}
147 
148     // compare two iterators based on there input labels, and proto state
149     // (partition class Ids)
operator()150     bool operator()(const ArcIter* x, const ArcIter* y) const {
151       const RevA& xarc = x->Value();
152       const RevA& yarc = y->Value();
153       return (xarc.ilabel > yarc.ilabel);
154     }
155 
156    private:
157     const Partition<StateId>& partition_;
158   };
159 
160   typedef priority_queue<ArcIter*, vector<ArcIter*>, ArcIterCompare>
161   ArcIterQueue;
162 
163   // helper methods
164  private:
165   // prepartitions the space into equivalence classes with
166   //   same final weight
167   //   same # arcs per state
168   //   same outgoing arcs
PrePartition(const Fst<A> & fst)169   void PrePartition(const Fst<A>& fst) {
170     VLOG(5) << "PrePartition";
171 
172     typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
173     StateComparator<A> comp(fst, P_, StateComparator<A>::kCompareFinal);
174     EquivalenceMap equiv_map(comp);
175 
176     StateIterator<Fst<A> > siter(fst);
177     StateId class_id = P_.AddClass();
178     P_.Add(siter.Value(), class_id);
179     equiv_map[siter.Value()] = class_id;
180     L_.Enqueue(class_id);
181     for (siter.Next(); !siter.Done(); siter.Next()) {
182       StateId  s = siter.Value();
183       typename EquivalenceMap::const_iterator it = equiv_map.find(s);
184       if (it == equiv_map.end()) {
185         class_id = P_.AddClass();
186         P_.Add(s, class_id);
187         equiv_map[s] = class_id;
188         L_.Enqueue(class_id);
189       } else {
190         P_.Add(s, it->second);
191         equiv_map[s] = it->second;
192       }
193     }
194 
195     VLOG(5) << "Initial Partition: " << P_.num_classes();
196   }
197 
198   // - Create inverse transition Tr_ = rev(fst)
199   // - loop over states in fst and split on final, creating two blocks
200   //   in the partition corresponding to final, non-final
Initialize(const Fst<A> & fst)201   void Initialize(const Fst<A>& fst) {
202     // construct Tr
203     Reverse(fst, &Tr_);
204     ILabelCompare<RevA> ilabel_comp;
205     ArcSort(&Tr_, ilabel_comp);
206 
207     // initial split (F, S - F)
208     P_.Initialize(Tr_.NumStates() - 1);
209 
210     // prep partition
211     PrePartition(fst);
212 
213     // allocate arc iterator queue
214     ArcIterCompare comp(P_);
215     aiter_queue_ = new ArcIterQueue(comp);
216   }
217 
218   // partition all classes with destination C
Split(ClassId C)219   void Split(ClassId C) {
220     // Prep priority queue. Open arc iterator for each state in C, and
221     // insert into priority queue.
222     for (PartitionIterator<StateId> siter(P_, C);
223          !siter.Done(); siter.Next()) {
224       StateId s = siter.Value();
225       if (Tr_.NumArcs(s + 1))
226         aiter_queue_->push(new ArcIterator<Fst<RevA> >(Tr_, s + 1));
227     }
228 
229     // Now pop arc iterator from queue, split entering equivalence class
230     // re-insert updated iterator into queue.
231     Label prev_label = -1;
232     while (!aiter_queue_->empty()) {
233       ArcIterator<Fst<RevA> >* aiter = aiter_queue_->top();
234       aiter_queue_->pop();
235       if (aiter->Done()) {
236         delete aiter;
237         continue;
238      }
239 
240       const RevA& arc = aiter->Value();
241       StateId from_state = aiter->Value().nextstate - 1;
242       Label   from_label = arc.ilabel;
243       if (prev_label != from_label)
244         P_.FinalizeSplit(&L_);
245 
246       StateId from_class = P_.class_id(from_state);
247       if (P_.class_size(from_class) > 1)
248         P_.SplitOn(from_state);
249 
250       prev_label = from_label;
251       aiter->Next();
252       if (aiter->Done())
253         delete aiter;
254       else
255         aiter_queue_->push(aiter);
256     }
257     P_.FinalizeSplit(&L_);
258   }
259 
260   // Main loop for hopcroft minimization.
Compute(const Fst<A> & fst)261   void Compute(const Fst<A>& fst) {
262     // process active classes (FIFO, or FILO)
263     while (!L_.Empty()) {
264       ClassId C = L_.Head();
265       L_.Dequeue();
266 
267       // split on C, all labels in C
268       Split(C);
269     }
270   }
271 
272   // helper data
273  private:
274   // Partioning of states into equivalence classes
275   Partition<StateId> P_;
276 
277   // L = set of active classes to be processed in partition P
278   Queue L_;
279 
280   // reverse transition function
281   VectorFst<RevA> Tr_;
282 
283   // Priority queue of open arc iterators for all states in the 'splitter'
284   // equivalence class
285   ArcIterQueue* aiter_queue_;
286 };
287 
288 
289 // Computes equivalence classes for acyclic Fsts. The implementation details
290 // for this algorithms is documented by the following paper.
291 //
292 // Minimization of acyclic deterministic automata in linear time
293 //  Dominque Revuz
294 //
295 // Complexity O(|E|)
296 //
297 template <class A>
298 class AcyclicMinimizer {
299  public:
300   typedef typename A::Label Label;
301   typedef typename A::StateId StateId;
302   typedef typename A::StateId ClassId;
303   typedef typename A::Weight Weight;
304 
AcyclicMinimizer(const ExpandedFst<A> & fst)305   AcyclicMinimizer(const ExpandedFst<A>& fst) {
306     Initialize(fst);
307     Refine(fst);
308   }
309 
partition()310   const Partition<StateId>& partition() {
311     return partition_;
312   }
313 
314   // helper classes
315  private:
316   // DFS visitor to compute the height (distance) to final state.
317   class HeightVisitor {
318    public:
HeightVisitor()319     HeightVisitor() : max_height_(0), num_states_(0) { }
320 
321     // invoked before dfs visit
InitVisit(const Fst<A> & fst)322     void InitVisit(const Fst<A>& fst) {}
323 
324     // invoked when state is discovered (2nd arg is DFS tree root)
InitState(StateId s,StateId root)325     bool InitState(StateId s, StateId root) {
326       // extend height array and initialize height (distance) to 0
327       for (size_t i = height_.size(); i <= (size_t)s; ++i)
328         height_.push_back(-1);
329 
330       if (s >= (StateId)num_states_) num_states_ = s + 1;
331       return true;
332     }
333 
334     // invoked when tree arc examined (to undiscoverted state)
TreeArc(StateId s,const A & arc)335     bool TreeArc(StateId s, const A& arc) {
336       return true;
337     }
338 
339     // invoked when back arc examined (to unfinished state)
BackArc(StateId s,const A & arc)340     bool BackArc(StateId s, const A& arc) {
341       return true;
342     }
343 
344     // invoked when forward or cross arc examined (to finished state)
ForwardOrCrossArc(StateId s,const A & arc)345     bool ForwardOrCrossArc(StateId s, const A& arc) {
346       if (height_[arc.nextstate] + 1 > height_[s])
347         height_[s] = height_[arc.nextstate] + 1;
348       return true;
349     }
350 
351     // invoked when state finished (parent is kNoStateId for tree root)
FinishState(StateId s,StateId parent,const A * parent_arc)352     void FinishState(StateId s, StateId parent, const A* parent_arc) {
353       if (height_[s] == -1) height_[s] = 0;
354       StateId h = height_[s] +  1;
355       if (parent >= 0) {
356         if (h > height_[parent]) height_[parent] = h;
357         if (h > (StateId)max_height_)     max_height_ = h;
358       }
359     }
360 
361     // invoked after DFS visit
FinishVisit()362     void FinishVisit() {}
363 
max_height()364     size_t max_height() const { return max_height_; }
365 
height()366     const vector<StateId>& height() const { return height_; }
367 
num_states()368     const size_t num_states() const { return num_states_; }
369 
370    private:
371     vector<StateId> height_;
372     size_t max_height_;
373     size_t num_states_;
374   };
375 
376   // helper methods
377  private:
378   // cluster states according to height (distance to final state)
Initialize(const Fst<A> & fst)379   void Initialize(const Fst<A>& fst) {
380     // compute height (distance to final state)
381     HeightVisitor hvisitor;
382     DfsVisit(fst, &hvisitor);
383 
384     // create initial partition based on height
385     partition_.Initialize(hvisitor.num_states());
386     partition_.AllocateClasses(hvisitor.max_height() + 1);
387     const vector<StateId>& hstates = hvisitor.height();
388     for (size_t s = 0; s < hstates.size(); ++s)
389       partition_.Add(s, hstates[s]);
390   }
391 
392   // refine states based on arc sort (out degree, arc equivalence)
Refine(const Fst<A> & fst)393   void Refine(const Fst<A>& fst) {
394     typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
395     StateComparator<A> comp(fst, partition_);
396 
397     // start with tail (height = 0)
398     size_t height = partition_.num_classes();
399     for (size_t h = 0; h < height; ++h) {
400       EquivalenceMap equiv_classes(comp);
401 
402       // sort states within equivalence class
403       PartitionIterator<StateId> siter(partition_, h);
404       equiv_classes[siter.Value()] = h;
405       for (siter.Next(); !siter.Done(); siter.Next()) {
406         const StateId s = siter.Value();
407         typename EquivalenceMap::const_iterator it = equiv_classes.find(s);
408         if (it == equiv_classes.end())
409           equiv_classes[s] = partition_.AddClass();
410         else
411           equiv_classes[s] = it->second;
412       }
413 
414       // create refined partition
415       for (siter.Reset(); !siter.Done();) {
416         const StateId s = siter.Value();
417         const StateId old_class = partition_.class_id(s);
418         const StateId new_class = equiv_classes[s];
419 
420         // a move operation can invalidate the iterator, so
421         // we first update the iterator to the next element
422         // before we move the current element out of the list
423         siter.Next();
424         if (old_class != new_class)
425           partition_.Move(s, new_class);
426       }
427     }
428   }
429 
430  private:
431   Partition<StateId> partition_;
432 };
433 
434 
435 // Given a partition and a mutable fst, merge states of Fst inplace
436 // (i.e. destructively). Merging works by taking the first state in
437 // a class of the partition to be the representative state for the class.
438 // Each arc is then reconnected to this state. All states in the class
439 // are merged by adding there arcs to the representative state.
440 template <class A>
MergeStates(const Partition<typename A::StateId> & partition,MutableFst<A> * fst)441 void MergeStates(
442     const Partition<typename A::StateId>& partition, MutableFst<A>* fst) {
443   typedef typename A::StateId StateId;
444 
445   vector<StateId> state_map(partition.num_classes());
446   for (size_t i = 0; i < (size_t)partition.num_classes(); ++i) {
447     PartitionIterator<StateId> siter(partition, i);
448     state_map[i] = siter.Value();  // first state in partition;
449   }
450 
451   // relabel destination states
452   for (size_t c = 0; c < (size_t)partition.num_classes(); ++c) {
453     for (PartitionIterator<StateId> siter(partition, c);
454          !siter.Done(); siter.Next()) {
455       StateId s = siter.Value();
456       for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
457            !aiter.Done(); aiter.Next()) {
458         A arc = aiter.Value();
459         arc.nextstate = state_map[partition.class_id(arc.nextstate)];
460 
461         if (s == state_map[c])  // first state just set destination
462           aiter.SetValue(arc);
463         else
464           fst->AddArc(state_map[c], arc);
465       }
466     }
467   }
468   fst->SetStart(state_map[partition.class_id(fst->Start())]);
469 
470   Connect(fst);
471 }
472 
473 template <class A>
AcceptorMinimize(MutableFst<A> * fst)474 void AcceptorMinimize(MutableFst<A>* fst) {
475   typedef typename A::StateId StateId;
476   if (!(fst->Properties(kAcceptor | kUnweighted, true)))
477     LOG(FATAL) << "Input Fst is not an unweighted acceptor";
478 
479   // connect fst before minimization, handles disconnected states
480   Connect(fst);
481   if (fst->NumStates() == 0) return;
482 
483   if (fst->Properties(kAcyclic, true)) {
484     // Acyclic minimization (revuz)
485     VLOG(2) << "Acyclic Minimization";
486     AcyclicMinimizer<A> minimizer(*fst);
487     MergeStates(minimizer.partition(), fst);
488 
489   } else {
490     // Cyclic minimizaton (hopcroft)
491     VLOG(2) << "Cyclic Minimization";
492     CyclicMinimizer<A, LifoQueue<StateId> > minimizer(*fst);
493     MergeStates(minimizer.partition(), fst);
494   }
495 
496   // sort arcs before summing
497   ArcSort(fst, ILabelCompare<A>());
498 
499   // sum in appropriate semiring
500   ArcSum(fst);
501 }
502 
503 
504 // In place minimization of unweighted, deterministic acceptors
505 //
506 // For acyclic automata we use an algorithm from Dominique Revuz that is
507 // linear in the number of arcs (edges) in the machine.
508 //  Complexity = O(E)
509 //
510 // For cyclic automata we use the classical hopcroft minimization.
511 //  Complexity = O(|E|log(|N|)
512 //
513 template <class A>
514 void Minimize(MutableFst<A>* fst, MutableFst<A>* sfst = 0) {
515   uint64 props = fst->Properties(kAcceptor | kIDeterministic|
516                                  kWeighted | kUnweighted, true);
517   if (!(props & kIDeterministic))
518     LOG(FATAL) << "Input Fst is not deterministic";
519 
520   if (!(props & kAcceptor)) {  // weighted transducer
521     VectorFst< GallicArc<A, STRING_LEFT> > gfst;
522     Map(*fst, &gfst, ToGallicMapper<A, STRING_LEFT>());
523     fst->DeleteStates();
524     gfst.SetProperties(kAcceptor, kAcceptor);
525     Push(&gfst, REWEIGHT_TO_INITIAL);
526     Map(&gfst, QuantizeMapper< GallicArc<A, STRING_LEFT> >());
527     EncodeMapper< GallicArc<A, STRING_LEFT> >
528       encoder(kEncodeLabels | kEncodeWeights, ENCODE);
529     Encode(&gfst, &encoder);
530     AcceptorMinimize(&gfst);
531     Decode(&gfst, encoder);
532 
533     if (sfst == 0) {
534       FactorWeightFst< GallicArc<A, STRING_LEFT>,
535         GallicFactor<typename A::Label,
536         typename A::Weight, STRING_LEFT> > fwfst(gfst);
537       Map(fwfst, fst, FromGallicMapper<A, STRING_LEFT>());
538     } else {
539       sfst->SetOutputSymbols(fst->OutputSymbols());
540       GallicToNewSymbolsMapper<A, STRING_LEFT> mapper(sfst);
541       Map(gfst, fst, mapper);
542       fst->SetOutputSymbols(sfst->InputSymbols());
543     }
544   } else if (props & kWeighted) {  // weighted acceptor
545     Push(fst, REWEIGHT_TO_INITIAL);
546     Map(fst, QuantizeMapper<A>());
547     EncodeMapper<A> encoder(kEncodeLabels | kEncodeWeights, ENCODE);
548     Encode(fst, &encoder);
549     AcceptorMinimize(fst);
550     Decode(fst, encoder);
551   } else {  // unweighted acceptor
552     AcceptorMinimize(fst);
553   }
554 }
555 
556 }  // namespace fst
557 
558 #endif  // FST_LIB_MINIMIZE_H__
559