// prune.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: allauzen@google.com (Cyril Allauzen)
//
// \file
// Functions implementing pruning.
20 
21 #ifndef FST_LIB_PRUNE_H__
22 #define FST_LIB_PRUNE_H__
23 
24 #include <vector>
25 using std::vector;
26 
27 #include <fst/arcfilter.h>
28 #include <fst/heap.h>
29 #include <fst/shortest-distance.h>
30 
31 
32 namespace fst {
33 
34 template <class A, class ArcFilter>
35 class PruneOptions {
36  public:
37   typedef typename A::Weight Weight;
38   typedef typename A::StateId StateId;
39 
40   // Pruning weight threshold.
41   Weight weight_threshold;
42   // Pruning state threshold.
43   StateId state_threshold;
44   // Arc filter.
45   ArcFilter filter;
46   // If non-zero, passes in pre-computed shortest distance to final states.
47   const vector<Weight> *distance;
48   // Determines the degree of convergence required when computing shortest
49   // distances.
50   float delta;
51 
52   explicit PruneOptions(const Weight& w, StateId s, ArcFilter f,
53                         vector<Weight> *d = 0, float e = kDelta)
weight_threshold(w)54       : weight_threshold(w),
55         state_threshold(s),
56         filter(f),
57         distance(d),
58         delta(e) {}
59  private:
60   PruneOptions();  // disallow
61 };
62 
63 
64 template <class S, class W>
65 class PruneCompare {
66  public:
67   typedef S StateId;
68   typedef W Weight;
69 
PruneCompare(const vector<Weight> & idistance,const vector<Weight> & fdistance)70   PruneCompare(const vector<Weight> &idistance,
71                const vector<Weight> &fdistance)
72       : idistance_(idistance), fdistance_(fdistance) {}
73 
operator()74   bool operator()(const StateId x, const StateId y) const {
75     Weight wx = Times(x < idistance_.size() ? idistance_[x] : Weight::Zero(),
76                       x < fdistance_.size() ? fdistance_[x] : Weight::Zero());
77     Weight wy = Times(y < idistance_.size() ? idistance_[y] : Weight::Zero(),
78                       y < fdistance_.size() ? fdistance_[y] : Weight::Zero());
79     return less_(wx, wy);
80   }
81 
82  private:
83   const vector<Weight> &idistance_;
84   const vector<Weight> &fdistance_;
85   NaturalLess<Weight> less_;
86 };
87 
88 
89 
90 // Pruning algorithm: this version modifies its input and it takes an
91 // options class as an argment. Delete states and arcs in 'fst' that
92 // do not belong to a successful path whose weight is no more than
93 // the weight of the shortest path Times() 'opts.weight_threshold'.
94 // When 'opts.state_threshold != kNoStateId', the resulting transducer
95 // will restricted further to have at most 'opts.state_threshold'
96 // states. Weights need to be commutative and have the path
97 // property. The weight 'w' of any cycle needs to be bounded, i.e.,
98 // 'Plus(w, W::One()) = One()'.
99 template <class Arc, class ArcFilter>
Prune(MutableFst<Arc> * fst,const PruneOptions<Arc,ArcFilter> & opts)100 void Prune(MutableFst<Arc> *fst,
101            const PruneOptions<Arc, ArcFilter> &opts) {
102   typedef typename Arc::Weight Weight;
103   typedef typename Arc::StateId StateId;
104 
105   if ((Weight::Properties() & (kPath | kCommutative))
106       != (kPath | kCommutative)) {
107     FSTERROR() << "Prune: Weight needs to have the path property and"
108                << " be commutative: "
109                << Weight::Type();
110     fst->SetProperties(kError, kError);
111     return;
112   }
113   StateId ns = fst->NumStates();
114   if (ns == 0) return;
115   vector<Weight> idistance(ns, Weight::Zero());
116   vector<Weight> tmp;
117   if (!opts.distance) {
118     tmp.reserve(ns);
119     ShortestDistance(*fst, &tmp, true, opts.delta);
120   }
121   const vector<Weight> *fdistance = opts.distance ? opts.distance : &tmp;
122 
123   if ((opts.state_threshold == 0) ||
124       (fdistance->size() <= fst->Start()) ||
125       ((*fdistance)[fst->Start()] == Weight::Zero())) {
126     fst->DeleteStates();
127     return;
128   }
129   PruneCompare<StateId, Weight> compare(idistance, *fdistance);
130   Heap< StateId, PruneCompare<StateId, Weight>, false> heap(compare);
131   vector<bool> visited(ns, false);
132   vector<size_t> enqueued(ns, kNoKey);
133   vector<StateId> dead;
134   dead.push_back(fst->AddState());
135   NaturalLess<Weight> less;
136   Weight limit = Times((*fdistance)[fst->Start()], opts.weight_threshold);
137 
138   StateId num_visited = 0;
139   StateId s = fst->Start();
140   if (!less(limit, (*fdistance)[s])) {
141     idistance[s] = Weight::One();
142     enqueued[s] = heap.Insert(s);
143     ++num_visited;
144   }
145 
146   while (!heap.Empty()) {
147     s = heap.Top();
148     heap.Pop();
149     enqueued[s] = kNoKey;
150     visited[s] = true;
151     if (less(limit, Times(idistance[s], fst->Final(s))))
152       fst->SetFinal(s, Weight::Zero());
153     for (MutableArcIterator< MutableFst<Arc> > ait(fst, s);
154          !ait.Done();
155          ait.Next()) {
156       Arc arc = ait.Value();
157       if (!opts.filter(arc)) continue;
158       Weight weight = Times(Times(idistance[s], arc.weight),
159                             arc.nextstate < fdistance->size()
160                             ? (*fdistance)[arc.nextstate]
161                             : Weight::Zero());
162       if (less(limit, weight)) {
163         arc.nextstate = dead[0];
164         ait.SetValue(arc);
165         continue;
166       }
167       if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate]))
168         idistance[arc.nextstate] = Times(idistance[s], arc.weight);
169       if (visited[arc.nextstate]) continue;
170       if ((opts.state_threshold != kNoStateId) &&
171           (num_visited >= opts.state_threshold))
172         continue;
173       if (enqueued[arc.nextstate] == kNoKey) {
174         enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
175         ++num_visited;
176       } else {
177         heap.Update(enqueued[arc.nextstate], arc.nextstate);
178       }
179     }
180   }
181   for (size_t i = 0; i < visited.size(); ++i)
182     if (!visited[i]) dead.push_back(i);
183   fst->DeleteStates(dead);
184 }
185 
186 
187 // Pruning algorithm: this version modifies its input and simply takes
188 // the pruning threshold as an argument. Delete states and arcs in
189 // 'fst' that do not belong to a successful path whose weight is no
190 // more than the weight of the shortest path Times()
191 // 'weight_threshold'.  When 'state_threshold != kNoStateId', the
192 // resulting transducer will be restricted further to have at most
193 // 'opts.state_threshold' states. Weights need to be commutative and
194 // have the path property. The weight 'w' of any cycle needs to be
195 // bounded, i.e., 'Plus(w, W::One()) = One()'.
196 template <class Arc>
197 void Prune(MutableFst<Arc> *fst,
198            typename Arc::Weight weight_threshold,
199            typename Arc::StateId state_threshold = kNoStateId,
200            double delta = kDelta) {
201   PruneOptions<Arc, AnyArcFilter<Arc> > opts(weight_threshold, state_threshold,
202                                              AnyArcFilter<Arc>(), 0, delta);
203   Prune(fst, opts);
204 }
205 
206 
207 // Pruning algorithm: this version writes the pruned input Fst to an
208 // output MutableFst and it takes an options class as an argument.
209 // 'ofst' contains states and arcs that belong to a successful path in
210 // 'ifst' whose weight is no more than the weight of the shortest path
211 // Times() 'opts.weight_threshold'. When 'opts.state_threshold !=
212 // kNoStateId', 'ofst' will be restricted further to have at most
213 // 'opts.state_threshold' states. Weights need to be commutative and
214 // have the path property. The weight 'w' of any cycle needs to be
215 // bounded, i.e., 'Plus(w, W::One()) = One()'.
216 template <class Arc, class ArcFilter>
Prune(const Fst<Arc> & ifst,MutableFst<Arc> * ofst,const PruneOptions<Arc,ArcFilter> & opts)217 void Prune(const Fst<Arc> &ifst,
218            MutableFst<Arc> *ofst,
219            const PruneOptions<Arc, ArcFilter> &opts) {
220   typedef typename Arc::Weight Weight;
221   typedef typename Arc::StateId StateId;
222 
223   if ((Weight::Properties() & (kPath | kCommutative))
224       != (kPath | kCommutative)) {
225     FSTERROR() << "Prune: Weight needs to have the path property and"
226                << " be commutative: "
227                << Weight::Type();
228     ofst->SetProperties(kError, kError);
229     return;
230   }
231   ofst->DeleteStates();
232   ofst->SetInputSymbols(ifst.InputSymbols());
233   ofst->SetOutputSymbols(ifst.OutputSymbols());
234   if (ifst.Start() == kNoStateId)
235     return;
236   NaturalLess<Weight> less;
237   if (less(opts.weight_threshold, Weight::One()) ||
238       (opts.state_threshold == 0))
239     return;
240   vector<Weight> idistance;
241   vector<Weight> tmp;
242   if (!opts.distance)
243     ShortestDistance(ifst, &tmp, true, opts.delta);
244   const vector<Weight> *fdistance = opts.distance ? opts.distance : &tmp;
245 
246   if ((fdistance->size() <= ifst.Start()) ||
247       ((*fdistance)[ifst.Start()] == Weight::Zero())) {
248     return;
249   }
250   PruneCompare<StateId, Weight> compare(idistance, *fdistance);
251   Heap< StateId, PruneCompare<StateId, Weight>, false> heap(compare);
252   vector<StateId> copy;
253   vector<size_t> enqueued;
254   vector<bool> visited;
255 
256   StateId s = ifst.Start();
257   Weight limit = Times(s < fdistance->size() ? (*fdistance)[s] : Weight::Zero(),
258                          opts.weight_threshold);
259   while (copy.size() <= s)
260     copy.push_back(kNoStateId);
261   copy[s] = ofst->AddState();
262   ofst->SetStart(copy[s]);
263   while (idistance.size() <= s)
264     idistance.push_back(Weight::Zero());
265   idistance[s] = Weight::One();
266   while (enqueued.size() <= s) {
267     enqueued.push_back(kNoKey);
268     visited.push_back(false);
269   }
270   enqueued[s] = heap.Insert(s);
271 
272   while (!heap.Empty()) {
273     s = heap.Top();
274     heap.Pop();
275     enqueued[s] = kNoKey;
276     visited[s] = true;
277     if (!less(limit, Times(idistance[s], ifst.Final(s))))
278       ofst->SetFinal(copy[s], ifst.Final(s));
279     for (ArcIterator< Fst<Arc> > ait(ifst, s);
280          !ait.Done();
281          ait.Next()) {
282       const Arc &arc = ait.Value();
283       if (!opts.filter(arc)) continue;
284       Weight weight = Times(Times(idistance[s], arc.weight),
285                             arc.nextstate < fdistance->size()
286                             ? (*fdistance)[arc.nextstate]
287                             : Weight::Zero());
288       if (less(limit, weight)) continue;
289       if ((opts.state_threshold != kNoStateId) &&
290           (ofst->NumStates() >= opts.state_threshold))
291         continue;
292       while (idistance.size() <= arc.nextstate)
293         idistance.push_back(Weight::Zero());
294       if (less(Times(idistance[s], arc.weight),
295                idistance[arc.nextstate]))
296         idistance[arc.nextstate] = Times(idistance[s], arc.weight);
297       while (copy.size() <= arc.nextstate)
298         copy.push_back(kNoStateId);
299       if (copy[arc.nextstate] == kNoStateId)
300         copy[arc.nextstate] = ofst->AddState();
301       ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight,
302                                 copy[arc.nextstate]));
303       while (enqueued.size() <= arc.nextstate) {
304         enqueued.push_back(kNoKey);
305         visited.push_back(false);
306       }
307       if (visited[arc.nextstate]) continue;
308       if (enqueued[arc.nextstate] == kNoKey)
309         enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
310       else
311         heap.Update(enqueued[arc.nextstate], arc.nextstate);
312     }
313   }
314 }
315 
316 
317 // Pruning algorithm: this version writes the pruned input Fst to an
318 // output MutableFst and simply takes the pruning threshold as an
319 // argument.  'ofst' contains states and arcs that belong to a
320 // successful path in 'ifst' whose weight is no more than
321 // the weight of the shortest path Times() 'weight_threshold'. When
322 // 'state_threshold != kNoStateId', 'ofst' will be restricted further
323 // to have at most 'opts.state_threshold' states. Weights need to be
324 // commutative and have the path property. The weight 'w' of any cycle
325 // needs to be bounded, i.e., 'Plus(w, W::One()) = W::One()'.
326 template <class Arc>
327 void Prune(const Fst<Arc> &ifst,
328            MutableFst<Arc> *ofst,
329            typename Arc::Weight weight_threshold,
330            typename Arc::StateId state_threshold = kNoStateId,
331            float delta = kDelta) {
332   PruneOptions<Arc, AnyArcFilter<Arc> > opts(weight_threshold, state_threshold,
333                                              AnyArcFilter<Arc>(), 0, delta);
334   Prune(ifst, ofst, opts);
335 }
336 
337 }  // namespace fst
338 
#endif  // FST_LIB_PRUNE_H__
340