• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // prune.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Author: allauzen@cs.nyu.edu (Cyril Allauzen)
16 //
17 // \file
18 // Functions implementing pruning.
19 
20 #ifndef FST_LIB_PRUNE_H__
21 #define FST_LIB_PRUNE_H__
22 
23 #include "fst/lib/arcfilter.h"
24 #include "fst/lib/shortest-distance.h"
25 
26 namespace fst {
27 
28 template <class A, class ArcFilter>
29 class PruneOptions {
30  public:
31   typedef typename A::Weight Weight;
32 
33   // Pruning threshold.
34   Weight threshold;
35   // Arc filter.
36   ArcFilter filter;
37   // If non-zero, passes in pre-computed shortest distance from initial state
38   // (possibly resized).
39   vector<Weight> *idistance;
40   // If non-zero, passes in pre-computed shortest distance to final states
41   // (possibly resized).
42   vector<Weight> *fdistance;
43 
44   PruneOptions(const Weight& t, ArcFilter f, vector<Weight> *id = 0,
45                vector<Weight> *fd = 0)
threshold(t)46       : threshold(t), filter(f), idistance(id), fdistance(fd) {}
47 };
48 
49 
50 // Pruning algorithm: this version modifies its input and it takes an
51 // options class as an argment. Delete states and arcs in 'fst' that
52 // do not belong to a successful path whose weight is no more than
53 // 'opts.threshold' Times() the weight of the shortest path. Weights
54 // need to be commutative and have the path property.
55 template <class Arc, class ArcFilter>
Prune(MutableFst<Arc> * fst,const PruneOptions<Arc,ArcFilter> & opts)56 void Prune(MutableFst<Arc> *fst,
57            const PruneOptions<Arc, ArcFilter> &opts) {
58   typedef typename Arc::Weight Weight;
59   typedef typename Arc::StateId StateId;
60 
61   if ((Weight::Properties() & (kPath | kCommutative))
62       != (kPath | kCommutative))
63     LOG(FATAL) << "Prune: Weight needs to have the path property and"
64                << " be commutative: "
65                << Weight::Type();
66 
67   StateId ns = fst->NumStates();
68   if (ns == 0) return;
69 
70   vector<Weight> *idistance = opts.idistance;
71   vector<Weight> *fdistance = opts.fdistance;
72 
73   if (!idistance) {
74     idistance = new vector<Weight>(ns, Weight::Zero());
75     ShortestDistance(*fst, idistance, false);
76   } else {
77     idistance->resize(ns, Weight::Zero());
78   }
79 
80   if (!fdistance) {
81     fdistance = new vector<Weight>(ns, Weight::Zero());
82     ShortestDistance(*fst, fdistance, true);
83   } else {
84     fdistance->resize(ns, Weight::Zero());
85   }
86 
87   vector<StateId> dead;
88   dead.push_back(fst->AddState());
89   NaturalLess<Weight> less;
90   Weight ceiling = Times((*fdistance)[fst->Start()], opts.threshold);
91 
92   for (StateId state = 0; state < ns; ++state) {
93     if (less(ceiling, Times((*idistance)[state], (*fdistance)[state]))) {
94       dead.push_back(state);
95       continue;
96     }
97     for (MutableArcIterator< MutableFst<Arc> > it(fst, state);
98          !it.Done();
99          it.Next()) {
100       Arc arc = it.Value();
101       if (!opts.filter(arc)) continue;
102       Weight weight = Times(Times((*idistance)[state], arc.weight),
103                            (*fdistance)[arc.nextstate]);
104       if(less(ceiling, weight)) {
105         arc.nextstate = dead[0];
106         it.SetValue(arc);
107       }
108     }
109     if (less(ceiling, Times((*idistance)[state], fst->Final(state))))
110       fst->SetFinal(state, Weight::Zero());
111   }
112 
113   fst->DeleteStates(dead);
114 
115   if (!opts.idistance)
116     delete idistance;
117   if (!opts.fdistance)
118     delete fdistance;
119 }
120 
121 
122 // Pruning algorithm: this version modifies its input and simply takes
123 // the pruning threshold as an argument. Delete states and arcs in
124 // 'fst' that do not belong to a successful path whose weight is no
125 // more than 'opts.threshold' Times() the weight of the shortest
126 // path. Weights need to be commutative and have the path property.
127 template <class Arc>
Prune(MutableFst<Arc> * fst,typename Arc::Weight threshold)128 void Prune(MutableFst<Arc> *fst, typename Arc::Weight threshold) {
129   PruneOptions<Arc, AnyArcFilter<Arc> > opts(threshold, AnyArcFilter<Arc>());
130   Prune(fst, opts);
131 }
132 
133 
134 // Pruning algorithm: this version writes the pruned input Fst to an
135 // output MutableFst and it takes an options class as an argument.
136 // 'ofst' contains states and arcs that belong to a successful path in
137 // 'ifst' whose weight is no more than 'opts.threshold' Times() the
138 // weight of the shortest path. Weights need to be commutative and
139 // have the path property.
140 template <class Arc, class ArcFilter>
Prune(const Fst<Arc> & ifst,MutableFst<Arc> * ofst,const PruneOptions<Arc,ArcFilter> & opts)141 void Prune(const Fst<Arc> &ifst,
142            MutableFst<Arc> *ofst,
143            const PruneOptions<Arc, ArcFilter> &opts) {
144   typedef typename Arc::Weight Weight;
145   typedef typename Arc::StateId StateId;
146 
147   if ((Weight::Properties() & (kPath | kCommutative))
148       != (kPath | kCommutative))
149     LOG(FATAL) << "Prune: Weight needs to have the path property and"
150                << " be commutative: "
151                << Weight::Type();
152 
153   ofst->DeleteStates();
154 
155   if (ifst.Start() == kNoStateId)
156     return;
157 
158   vector<Weight> *idistance = opts.idistance;
159   vector<Weight> *fdistance = opts.fdistance;
160 
161   if (!idistance) {
162     idistance = new vector<Weight>;
163     ShortestDistance(ifst, idistance, false);
164   }
165 
166   if (!fdistance) {
167     fdistance = new vector<Weight>;
168     ShortestDistance(ifst, fdistance, true);
169   }
170 
171   vector<StateId> copy;
172   NaturalLess<Weight> less;
173   while (fdistance->size() <= ifst.Start())
174     fdistance->push_back(Weight::Zero());
175   Weight ceiling = Times((*fdistance)[ifst.Start()], opts.threshold);
176 
177   for (StateIterator< Fst<Arc> > sit(ifst);
178        !sit.Done();
179        sit.Next()) {
180     StateId state = sit.Value();
181     while (idistance->size() <= state)
182       idistance->push_back(Weight::Zero());
183     while (fdistance->size() <= state)
184       fdistance->push_back(Weight::Zero());
185     while (copy.size() <= state)
186       copy.push_back(kNoStateId);
187 
188     if (less(ceiling, Times((*idistance)[state], (*fdistance)[state])))
189       continue;
190 
191     if (copy[state] == kNoStateId)
192       copy[state] = ofst->AddState();
193     if (!less(ceiling, Times((*idistance)[state], ifst.Final(state))))
194       ofst->SetFinal(copy[state], ifst.Final(state));
195 
196     for (ArcIterator< Fst<Arc> > ait(ifst, state);
197          !ait.Done();
198          ait.Next()) {
199       Arc arc = ait.Value();
200 
201       if (!opts.filter(arc)) continue;
202 
203       while (idistance->size() <= arc.nextstate)
204         idistance->push_back(Weight::Zero());
205       while (fdistance->size() <= arc.nextstate)
206         fdistance->push_back(Weight::Zero());
207       while (copy.size() <= arc.nextstate)
208         copy.push_back(kNoStateId);
209 
210       Weight weight = Times(Times((*idistance)[state], arc.weight),
211                            (*fdistance)[arc.nextstate]);
212 
213       if (!less(ceiling, weight)) {
214         if (copy[arc.nextstate] == kNoStateId)
215           copy[arc.nextstate] = ofst->AddState();
216         arc.nextstate = copy[arc.nextstate];
217         ofst->AddArc(copy[state], arc);
218       }
219     }
220   }
221 
222   ofst->SetStart(copy[ifst.Start()]);
223 
224   if (!opts.idistance)
225     delete idistance;
226   if (!opts.fdistance)
227     delete fdistance;
228 }
229 
230 
231 // Pruning algorithm: this version writes the pruned input Fst to an
232 // output MutableFst and simply takes the pruning threshold as an
233 // argument.  'ofst' contains states and arcs that belong to a
234 // successful path in 'ifst' whose weight is no more than
235 // 'opts.threshold' Times() the weight of the shortest path. Weights
236 // need to be commutative and have the path property.
237 template <class Arc>
Prune(const Fst<Arc> & ifst,MutableFst<Arc> * ofst,typename Arc::Weight threshold)238 void Prune(const Fst<Arc> &ifst,
239            MutableFst<Arc> *ofst,
240            typename Arc::Weight threshold) {
241   PruneOptions<Arc, AnyArcFilter<Arc> > opts(threshold, AnyArcFilter<Arc>());
242   Prune(ifst, ofst, opts);
243 }
244 
245 } // namespace fst
246 
247 #endif // FST_LIB_PRUNE_H__
248