1 // prune.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Author: allauzen@cs.nyu.edu (Cyril Allauzen)
16 //
17 // \file
18 // Functions implementing pruning.
19
20 #ifndef FST_LIB_PRUNE_H__
21 #define FST_LIB_PRUNE_H__
22
23 #include "fst/lib/arcfilter.h"
24 #include "fst/lib/shortest-distance.h"
25
26 namespace fst {
27
// Options for the Prune() functions.
//
// Template parameters:
//   A          arc type; only its nested 'Weight' type is used here.
//   ArcFilter  type of the arc filter object stored in 'filter'.
template <class A, class ArcFilter>
class PruneOptions {
 public:
  typedef typename A::Weight Weight;

  // Pruning threshold.
  Weight threshold;
  // Arc filter.
  ArcFilter filter;
  // If non-zero, passes in pre-computed shortest distance from initial state
  // (possibly resized). The vector is not owned by this object.
  std::vector<Weight> *idistance;
  // If non-zero, passes in pre-computed shortest distance to final states
  // (possibly resized). The vector is not owned by this object.
  std::vector<Weight> *fdistance;

  // Stores the threshold 't', a copy of the filter 'f' and the two optional
  // distance vectors 'id' and 'fd'; both vectors default to zero (none).
  PruneOptions(const Weight& t, ArcFilter f, std::vector<Weight> *id = 0,
               std::vector<Weight> *fd = 0)
      : threshold(t), filter(f), idistance(id), fdistance(fd) {}
};
48
49
50 // Pruning algorithm: this version modifies its input and it takes an
51 // options class as an argment. Delete states and arcs in 'fst' that
52 // do not belong to a successful path whose weight is no more than
53 // 'opts.threshold' Times() the weight of the shortest path. Weights
54 // need to be commutative and have the path property.
55 template <class Arc, class ArcFilter>
Prune(MutableFst<Arc> * fst,const PruneOptions<Arc,ArcFilter> & opts)56 void Prune(MutableFst<Arc> *fst,
57 const PruneOptions<Arc, ArcFilter> &opts) {
58 typedef typename Arc::Weight Weight;
59 typedef typename Arc::StateId StateId;
60
61 if ((Weight::Properties() & (kPath | kCommutative))
62 != (kPath | kCommutative))
63 LOG(FATAL) << "Prune: Weight needs to have the path property and"
64 << " be commutative: "
65 << Weight::Type();
66
67 StateId ns = fst->NumStates();
68 if (ns == 0) return;
69
70 vector<Weight> *idistance = opts.idistance;
71 vector<Weight> *fdistance = opts.fdistance;
72
73 if (!idistance) {
74 idistance = new vector<Weight>(ns, Weight::Zero());
75 ShortestDistance(*fst, idistance, false);
76 } else {
77 idistance->resize(ns, Weight::Zero());
78 }
79
80 if (!fdistance) {
81 fdistance = new vector<Weight>(ns, Weight::Zero());
82 ShortestDistance(*fst, fdistance, true);
83 } else {
84 fdistance->resize(ns, Weight::Zero());
85 }
86
87 vector<StateId> dead;
88 dead.push_back(fst->AddState());
89 NaturalLess<Weight> less;
90 Weight ceiling = Times((*fdistance)[fst->Start()], opts.threshold);
91
92 for (StateId state = 0; state < ns; ++state) {
93 if (less(ceiling, Times((*idistance)[state], (*fdistance)[state]))) {
94 dead.push_back(state);
95 continue;
96 }
97 for (MutableArcIterator< MutableFst<Arc> > it(fst, state);
98 !it.Done();
99 it.Next()) {
100 Arc arc = it.Value();
101 if (!opts.filter(arc)) continue;
102 Weight weight = Times(Times((*idistance)[state], arc.weight),
103 (*fdistance)[arc.nextstate]);
104 if(less(ceiling, weight)) {
105 arc.nextstate = dead[0];
106 it.SetValue(arc);
107 }
108 }
109 if (less(ceiling, Times((*idistance)[state], fst->Final(state))))
110 fst->SetFinal(state, Weight::Zero());
111 }
112
113 fst->DeleteStates(dead);
114
115 if (!opts.idistance)
116 delete idistance;
117 if (!opts.fdistance)
118 delete fdistance;
119 }
120
121
122 // Pruning algorithm: this version modifies its input and simply takes
123 // the pruning threshold as an argument. Delete states and arcs in
124 // 'fst' that do not belong to a successful path whose weight is no
125 // more than 'opts.threshold' Times() the weight of the shortest
126 // path. Weights need to be commutative and have the path property.
127 template <class Arc>
Prune(MutableFst<Arc> * fst,typename Arc::Weight threshold)128 void Prune(MutableFst<Arc> *fst, typename Arc::Weight threshold) {
129 PruneOptions<Arc, AnyArcFilter<Arc> > opts(threshold, AnyArcFilter<Arc>());
130 Prune(fst, opts);
131 }
132
133
134 // Pruning algorithm: this version writes the pruned input Fst to an
135 // output MutableFst and it takes an options class as an argument.
136 // 'ofst' contains states and arcs that belong to a successful path in
137 // 'ifst' whose weight is no more than 'opts.threshold' Times() the
138 // weight of the shortest path. Weights need to be commutative and
139 // have the path property.
140 template <class Arc, class ArcFilter>
Prune(const Fst<Arc> & ifst,MutableFst<Arc> * ofst,const PruneOptions<Arc,ArcFilter> & opts)141 void Prune(const Fst<Arc> &ifst,
142 MutableFst<Arc> *ofst,
143 const PruneOptions<Arc, ArcFilter> &opts) {
144 typedef typename Arc::Weight Weight;
145 typedef typename Arc::StateId StateId;
146
147 if ((Weight::Properties() & (kPath | kCommutative))
148 != (kPath | kCommutative))
149 LOG(FATAL) << "Prune: Weight needs to have the path property and"
150 << " be commutative: "
151 << Weight::Type();
152
153 ofst->DeleteStates();
154
155 if (ifst.Start() == kNoStateId)
156 return;
157
158 vector<Weight> *idistance = opts.idistance;
159 vector<Weight> *fdistance = opts.fdistance;
160
161 if (!idistance) {
162 idistance = new vector<Weight>;
163 ShortestDistance(ifst, idistance, false);
164 }
165
166 if (!fdistance) {
167 fdistance = new vector<Weight>;
168 ShortestDistance(ifst, fdistance, true);
169 }
170
171 vector<StateId> copy;
172 NaturalLess<Weight> less;
173 while (fdistance->size() <= ifst.Start())
174 fdistance->push_back(Weight::Zero());
175 Weight ceiling = Times((*fdistance)[ifst.Start()], opts.threshold);
176
177 for (StateIterator< Fst<Arc> > sit(ifst);
178 !sit.Done();
179 sit.Next()) {
180 StateId state = sit.Value();
181 while (idistance->size() <= state)
182 idistance->push_back(Weight::Zero());
183 while (fdistance->size() <= state)
184 fdistance->push_back(Weight::Zero());
185 while (copy.size() <= state)
186 copy.push_back(kNoStateId);
187
188 if (less(ceiling, Times((*idistance)[state], (*fdistance)[state])))
189 continue;
190
191 if (copy[state] == kNoStateId)
192 copy[state] = ofst->AddState();
193 if (!less(ceiling, Times((*idistance)[state], ifst.Final(state))))
194 ofst->SetFinal(copy[state], ifst.Final(state));
195
196 for (ArcIterator< Fst<Arc> > ait(ifst, state);
197 !ait.Done();
198 ait.Next()) {
199 Arc arc = ait.Value();
200
201 if (!opts.filter(arc)) continue;
202
203 while (idistance->size() <= arc.nextstate)
204 idistance->push_back(Weight::Zero());
205 while (fdistance->size() <= arc.nextstate)
206 fdistance->push_back(Weight::Zero());
207 while (copy.size() <= arc.nextstate)
208 copy.push_back(kNoStateId);
209
210 Weight weight = Times(Times((*idistance)[state], arc.weight),
211 (*fdistance)[arc.nextstate]);
212
213 if (!less(ceiling, weight)) {
214 if (copy[arc.nextstate] == kNoStateId)
215 copy[arc.nextstate] = ofst->AddState();
216 arc.nextstate = copy[arc.nextstate];
217 ofst->AddArc(copy[state], arc);
218 }
219 }
220 }
221
222 ofst->SetStart(copy[ifst.Start()]);
223
224 if (!opts.idistance)
225 delete idistance;
226 if (!opts.fdistance)
227 delete fdistance;
228 }
229
230
231 // Pruning algorithm: this version writes the pruned input Fst to an
232 // output MutableFst and simply takes the pruning threshold as an
233 // argument. 'ofst' contains states and arcs that belong to a
234 // successful path in 'ifst' whose weight is no more than
235 // 'opts.threshold' Times() the weight of the shortest path. Weights
236 // need to be commutative and have the path property.
237 template <class Arc>
Prune(const Fst<Arc> & ifst,MutableFst<Arc> * ofst,typename Arc::Weight threshold)238 void Prune(const Fst<Arc> &ifst,
239 MutableFst<Arc> *ofst,
240 typename Arc::Weight threshold) {
241 PruneOptions<Arc, AnyArcFilter<Arc> > opts(threshold, AnyArcFilter<Arc>());
242 Prune(ifst, ofst, opts);
243 }
244
245 } // namespace fst
246
#endif  // FST_LIB_PRUNE_H__
248