1 // prune.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: allauzen@google.com (Cyril Allauzen)
17 //
18 // \file
19 // Functions implementing pruning.
20
21 #ifndef FST_LIB_PRUNE_H__
22 #define FST_LIB_PRUNE_H__
23
24 #include <vector>
25 using std::vector;
26
27 #include <fst/arcfilter.h>
28 #include <fst/heap.h>
29 #include <fst/shortest-distance.h>
30
31
32 namespace fst {
33
34 template <class A, class ArcFilter>
35 class PruneOptions {
36 public:
37 typedef typename A::Weight Weight;
38 typedef typename A::StateId StateId;
39
40 // Pruning weight threshold.
41 Weight weight_threshold;
42 // Pruning state threshold.
43 StateId state_threshold;
44 // Arc filter.
45 ArcFilter filter;
46 // If non-zero, passes in pre-computed shortest distance to final states.
47 const vector<Weight> *distance;
48 // Determines the degree of convergence required when computing shortest
49 // distances.
50 float delta;
51
52 explicit PruneOptions(const Weight& w, StateId s, ArcFilter f,
53 vector<Weight> *d = 0, float e = kDelta)
weight_threshold(w)54 : weight_threshold(w),
55 state_threshold(s),
56 filter(f),
57 distance(d),
58 delta(e) {}
59 private:
60 PruneOptions(); // disallow
61 };
62
63
64 template <class S, class W>
65 class PruneCompare {
66 public:
67 typedef S StateId;
68 typedef W Weight;
69
PruneCompare(const vector<Weight> & idistance,const vector<Weight> & fdistance)70 PruneCompare(const vector<Weight> &idistance,
71 const vector<Weight> &fdistance)
72 : idistance_(idistance), fdistance_(fdistance) {}
73
operator()74 bool operator()(const StateId x, const StateId y) const {
75 Weight wx = Times(x < idistance_.size() ? idistance_[x] : Weight::Zero(),
76 x < fdistance_.size() ? fdistance_[x] : Weight::Zero());
77 Weight wy = Times(y < idistance_.size() ? idistance_[y] : Weight::Zero(),
78 y < fdistance_.size() ? fdistance_[y] : Weight::Zero());
79 return less_(wx, wy);
80 }
81
82 private:
83 const vector<Weight> &idistance_;
84 const vector<Weight> &fdistance_;
85 NaturalLess<Weight> less_;
86 };
87
88
89
90 // Pruning algorithm: this version modifies its input and it takes an
91 // options class as an argment. Delete states and arcs in 'fst' that
92 // do not belong to a successful path whose weight is no more than
93 // the weight of the shortest path Times() 'opts.weight_threshold'.
94 // When 'opts.state_threshold != kNoStateId', the resulting transducer
95 // will restricted further to have at most 'opts.state_threshold'
96 // states. Weights need to be commutative and have the path
97 // property. The weight 'w' of any cycle needs to be bounded, i.e.,
98 // 'Plus(w, W::One()) = One()'.
99 template <class Arc, class ArcFilter>
Prune(MutableFst<Arc> * fst,const PruneOptions<Arc,ArcFilter> & opts)100 void Prune(MutableFst<Arc> *fst,
101 const PruneOptions<Arc, ArcFilter> &opts) {
102 typedef typename Arc::Weight Weight;
103 typedef typename Arc::StateId StateId;
104
105 if ((Weight::Properties() & (kPath | kCommutative))
106 != (kPath | kCommutative)) {
107 FSTERROR() << "Prune: Weight needs to have the path property and"
108 << " be commutative: "
109 << Weight::Type();
110 fst->SetProperties(kError, kError);
111 return;
112 }
113 StateId ns = fst->NumStates();
114 if (ns == 0) return;
115 vector<Weight> idistance(ns, Weight::Zero());
116 vector<Weight> tmp;
117 if (!opts.distance) {
118 tmp.reserve(ns);
119 ShortestDistance(*fst, &tmp, true, opts.delta);
120 }
121 const vector<Weight> *fdistance = opts.distance ? opts.distance : &tmp;
122
123 if ((opts.state_threshold == 0) ||
124 (fdistance->size() <= fst->Start()) ||
125 ((*fdistance)[fst->Start()] == Weight::Zero())) {
126 fst->DeleteStates();
127 return;
128 }
129 PruneCompare<StateId, Weight> compare(idistance, *fdistance);
130 Heap< StateId, PruneCompare<StateId, Weight>, false> heap(compare);
131 vector<bool> visited(ns, false);
132 vector<size_t> enqueued(ns, kNoKey);
133 vector<StateId> dead;
134 dead.push_back(fst->AddState());
135 NaturalLess<Weight> less;
136 Weight limit = Times((*fdistance)[fst->Start()], opts.weight_threshold);
137
138 StateId num_visited = 0;
139 StateId s = fst->Start();
140 if (!less(limit, (*fdistance)[s])) {
141 idistance[s] = Weight::One();
142 enqueued[s] = heap.Insert(s);
143 ++num_visited;
144 }
145
146 while (!heap.Empty()) {
147 s = heap.Top();
148 heap.Pop();
149 enqueued[s] = kNoKey;
150 visited[s] = true;
151 if (less(limit, Times(idistance[s], fst->Final(s))))
152 fst->SetFinal(s, Weight::Zero());
153 for (MutableArcIterator< MutableFst<Arc> > ait(fst, s);
154 !ait.Done();
155 ait.Next()) {
156 Arc arc = ait.Value();
157 if (!opts.filter(arc)) continue;
158 Weight weight = Times(Times(idistance[s], arc.weight),
159 arc.nextstate < fdistance->size()
160 ? (*fdistance)[arc.nextstate]
161 : Weight::Zero());
162 if (less(limit, weight)) {
163 arc.nextstate = dead[0];
164 ait.SetValue(arc);
165 continue;
166 }
167 if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate]))
168 idistance[arc.nextstate] = Times(idistance[s], arc.weight);
169 if (visited[arc.nextstate]) continue;
170 if ((opts.state_threshold != kNoStateId) &&
171 (num_visited >= opts.state_threshold))
172 continue;
173 if (enqueued[arc.nextstate] == kNoKey) {
174 enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
175 ++num_visited;
176 } else {
177 heap.Update(enqueued[arc.nextstate], arc.nextstate);
178 }
179 }
180 }
181 for (size_t i = 0; i < visited.size(); ++i)
182 if (!visited[i]) dead.push_back(i);
183 fst->DeleteStates(dead);
184 }
185
186
187 // Pruning algorithm: this version modifies its input and simply takes
188 // the pruning threshold as an argument. Delete states and arcs in
189 // 'fst' that do not belong to a successful path whose weight is no
190 // more than the weight of the shortest path Times()
191 // 'weight_threshold'. When 'state_threshold != kNoStateId', the
192 // resulting transducer will be restricted further to have at most
193 // 'opts.state_threshold' states. Weights need to be commutative and
194 // have the path property. The weight 'w' of any cycle needs to be
195 // bounded, i.e., 'Plus(w, W::One()) = One()'.
196 template <class Arc>
197 void Prune(MutableFst<Arc> *fst,
198 typename Arc::Weight weight_threshold,
199 typename Arc::StateId state_threshold = kNoStateId,
200 double delta = kDelta) {
201 PruneOptions<Arc, AnyArcFilter<Arc> > opts(weight_threshold, state_threshold,
202 AnyArcFilter<Arc>(), 0, delta);
203 Prune(fst, opts);
204 }
205
206
207 // Pruning algorithm: this version writes the pruned input Fst to an
208 // output MutableFst and it takes an options class as an argument.
209 // 'ofst' contains states and arcs that belong to a successful path in
210 // 'ifst' whose weight is no more than the weight of the shortest path
211 // Times() 'opts.weight_threshold'. When 'opts.state_threshold !=
212 // kNoStateId', 'ofst' will be restricted further to have at most
213 // 'opts.state_threshold' states. Weights need to be commutative and
214 // have the path property. The weight 'w' of any cycle needs to be
215 // bounded, i.e., 'Plus(w, W::One()) = One()'.
216 template <class Arc, class ArcFilter>
Prune(const Fst<Arc> & ifst,MutableFst<Arc> * ofst,const PruneOptions<Arc,ArcFilter> & opts)217 void Prune(const Fst<Arc> &ifst,
218 MutableFst<Arc> *ofst,
219 const PruneOptions<Arc, ArcFilter> &opts) {
220 typedef typename Arc::Weight Weight;
221 typedef typename Arc::StateId StateId;
222
223 if ((Weight::Properties() & (kPath | kCommutative))
224 != (kPath | kCommutative)) {
225 FSTERROR() << "Prune: Weight needs to have the path property and"
226 << " be commutative: "
227 << Weight::Type();
228 ofst->SetProperties(kError, kError);
229 return;
230 }
231 ofst->DeleteStates();
232 ofst->SetInputSymbols(ifst.InputSymbols());
233 ofst->SetOutputSymbols(ifst.OutputSymbols());
234 if (ifst.Start() == kNoStateId)
235 return;
236 NaturalLess<Weight> less;
237 if (less(opts.weight_threshold, Weight::One()) ||
238 (opts.state_threshold == 0))
239 return;
240 vector<Weight> idistance;
241 vector<Weight> tmp;
242 if (!opts.distance)
243 ShortestDistance(ifst, &tmp, true, opts.delta);
244 const vector<Weight> *fdistance = opts.distance ? opts.distance : &tmp;
245
246 if ((fdistance->size() <= ifst.Start()) ||
247 ((*fdistance)[ifst.Start()] == Weight::Zero())) {
248 return;
249 }
250 PruneCompare<StateId, Weight> compare(idistance, *fdistance);
251 Heap< StateId, PruneCompare<StateId, Weight>, false> heap(compare);
252 vector<StateId> copy;
253 vector<size_t> enqueued;
254 vector<bool> visited;
255
256 StateId s = ifst.Start();
257 Weight limit = Times(s < fdistance->size() ? (*fdistance)[s] : Weight::Zero(),
258 opts.weight_threshold);
259 while (copy.size() <= s)
260 copy.push_back(kNoStateId);
261 copy[s] = ofst->AddState();
262 ofst->SetStart(copy[s]);
263 while (idistance.size() <= s)
264 idistance.push_back(Weight::Zero());
265 idistance[s] = Weight::One();
266 while (enqueued.size() <= s) {
267 enqueued.push_back(kNoKey);
268 visited.push_back(false);
269 }
270 enqueued[s] = heap.Insert(s);
271
272 while (!heap.Empty()) {
273 s = heap.Top();
274 heap.Pop();
275 enqueued[s] = kNoKey;
276 visited[s] = true;
277 if (!less(limit, Times(idistance[s], ifst.Final(s))))
278 ofst->SetFinal(copy[s], ifst.Final(s));
279 for (ArcIterator< Fst<Arc> > ait(ifst, s);
280 !ait.Done();
281 ait.Next()) {
282 const Arc &arc = ait.Value();
283 if (!opts.filter(arc)) continue;
284 Weight weight = Times(Times(idistance[s], arc.weight),
285 arc.nextstate < fdistance->size()
286 ? (*fdistance)[arc.nextstate]
287 : Weight::Zero());
288 if (less(limit, weight)) continue;
289 if ((opts.state_threshold != kNoStateId) &&
290 (ofst->NumStates() >= opts.state_threshold))
291 continue;
292 while (idistance.size() <= arc.nextstate)
293 idistance.push_back(Weight::Zero());
294 if (less(Times(idistance[s], arc.weight),
295 idistance[arc.nextstate]))
296 idistance[arc.nextstate] = Times(idistance[s], arc.weight);
297 while (copy.size() <= arc.nextstate)
298 copy.push_back(kNoStateId);
299 if (copy[arc.nextstate] == kNoStateId)
300 copy[arc.nextstate] = ofst->AddState();
301 ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight,
302 copy[arc.nextstate]));
303 while (enqueued.size() <= arc.nextstate) {
304 enqueued.push_back(kNoKey);
305 visited.push_back(false);
306 }
307 if (visited[arc.nextstate]) continue;
308 if (enqueued[arc.nextstate] == kNoKey)
309 enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
310 else
311 heap.Update(enqueued[arc.nextstate], arc.nextstate);
312 }
313 }
314 }
315
316
317 // Pruning algorithm: this version writes the pruned input Fst to an
318 // output MutableFst and simply takes the pruning threshold as an
319 // argument. 'ofst' contains states and arcs that belong to a
320 // successful path in 'ifst' whose weight is no more than
321 // the weight of the shortest path Times() 'weight_threshold'. When
322 // 'state_threshold != kNoStateId', 'ofst' will be restricted further
323 // to have at most 'opts.state_threshold' states. Weights need to be
324 // commutative and have the path property. The weight 'w' of any cycle
325 // needs to be bounded, i.e., 'Plus(w, W::One()) = W::One()'.
326 template <class Arc>
327 void Prune(const Fst<Arc> &ifst,
328 MutableFst<Arc> *ofst,
329 typename Arc::Weight weight_threshold,
330 typename Arc::StateId state_threshold = kNoStateId,
331 float delta = kDelta) {
332 PruneOptions<Arc, AnyArcFilter<Arc> > opts(weight_threshold, state_threshold,
333 AnyArcFilter<Arc>(), 0, delta);
334 Prune(ifst, ofst, opts);
335 }
336
337 } // namespace fst
338
339 #endif // FST_LIB_PRUNE_H_
340