1 // shortest-path.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Author: allauzen@cs.nyu.edu (Cyril Allauzen)
16 //
17 // \file
18 // Functions to find shortest paths in an FST.
19
20 #ifndef FST_LIB_SHORTEST_PATH_H__
21 #define FST_LIB_SHORTEST_PATH_H__
22
23 #include <functional>
24
25 #include "fst/lib/cache.h"
26 #include "fst/lib/queue.h"
27 #include "fst/lib/shortest-distance.h"
28 #include "fst/lib/test-properties.h"
29
30 namespace fst {
31
32 template <class Arc, class Queue, class ArcFilter>
33 struct ShortestPathOptions
34 : public ShortestDistanceOptions<Arc, Queue, ArcFilter> {
35 typedef typename Arc::StateId StateId;
36
37 size_t nshortest; // return n-shortest paths
38 bool unique; // only return paths with distinct input strings
39 bool has_distance; // distance vector already contains the
40 // shortest distance from the initial state
41
42 ShortestPathOptions(Queue *q, ArcFilter filt, size_t n = 1, bool u = false,
43 bool hasdist = false, float d = kDelta)
44 : ShortestDistanceOptions<Arc, Queue, ArcFilter>(q, filt, kNoStateId, d),
45 nshortest(n), unique(u), has_distance(hasdist) {}
46 };
47
48
49 // Shortest-path algorithm: normally not called directly; prefer
50 // 'ShortestPath' below with n=1. 'ofst' contains the shortest path in
51 // 'ifst'. 'distance' returns the shortest distances from the source
52 // state to each state in 'ifst'. 'opts' is used to specify options
53 // such as the queue discipline, the arc filter and delta.
54 //
55 // The shortest path is the lowest weight path w.r.t. the natural
56 // semiring order.
57 //
58 // The weights need to be right distributive and have the path (kPath)
59 // property.
60 template<class Arc, class Queue, class ArcFilter>
SingleShortestPath(const Fst<Arc> & ifst,MutableFst<Arc> * ofst,vector<typename Arc::Weight> * distance,ShortestPathOptions<Arc,Queue,ArcFilter> & opts)61 void SingleShortestPath(const Fst<Arc> &ifst,
62 MutableFst<Arc> *ofst,
63 vector<typename Arc::Weight> *distance,
64 ShortestPathOptions<Arc, Queue, ArcFilter> &opts) {
65 typedef typename Arc::StateId StateId;
66 typedef typename Arc::Weight Weight;
67
68 ofst->DeleteStates();
69 ofst->SetInputSymbols(ifst.InputSymbols());
70 ofst->SetOutputSymbols(ifst.OutputSymbols());
71
72 if (ifst.Start() == kNoStateId)
73 return;
74
75 vector<Weight> rdistance;
76 vector<bool> enqueued;
77 vector<StateId> parent;
78 vector<Arc> arc_parent;
79
80 Queue *state_queue = opts.state_queue;
81 StateId source = opts.source == kNoStateId ? ifst.Start() : opts.source;
82 Weight f_distance = Weight::Zero();
83 StateId f_parent = kNoStateId;
84
85 distance->clear();
86 state_queue->Clear();
87 if (opts.nshortest != 1)
88 LOG(FATAL) << "SingleShortestPath: for nshortest > 1, use ShortestPath"
89 << " instead";
90 if ((Weight::Properties() & (kPath | kRightSemiring))
91 != (kPath | kRightSemiring))
92 LOG(FATAL) << "SingleShortestPath: Weight needs to have the path"
93 << " property and be right distributive: " << Weight::Type();
94
95 while (distance->size() < source) {
96 distance->push_back(Weight::Zero());
97 enqueued.push_back(false);
98 parent.push_back(kNoStateId);
99 arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId));
100 }
101 distance->push_back(Weight::One());
102 parent.push_back(kNoStateId);
103 arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId));
104 state_queue->Enqueue(source);
105 enqueued.push_back(true);
106
107 while (!state_queue->Empty()) {
108 StateId s = state_queue->Head();
109 state_queue->Dequeue();
110 enqueued[s] = false;
111 Weight sd = (*distance)[s];
112 for (ArcIterator< Fst<Arc> > aiter(ifst, s);
113 !aiter.Done();
114 aiter.Next()) {
115 const Arc &arc = aiter.Value();
116 while (distance->size() <= arc.nextstate) {
117 distance->push_back(Weight::Zero());
118 enqueued.push_back(false);
119 parent.push_back(kNoStateId);
120 arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(),
121 kNoStateId));
122 }
123 Weight &nd = (*distance)[arc.nextstate];
124 Weight w = Times(sd, arc.weight);
125 if (nd != Plus(nd, w)) {
126 nd = Plus(nd, w);
127 parent[arc.nextstate] = s;
128 arc_parent[arc.nextstate] = arc;
129 if (!enqueued[arc.nextstate]) {
130 state_queue->Enqueue(arc.nextstate);
131 enqueued[arc.nextstate] = true;
132 } else {
133 state_queue->Update(arc.nextstate);
134 }
135 }
136 }
137 if (ifst.Final(s) != Weight::Zero()) {
138 Weight w = Times(sd, ifst.Final(s));
139 if (f_distance != Plus(f_distance, w)) {
140 f_distance = Plus(f_distance, w);
141 f_parent = s;
142 }
143 }
144 }
145 (*distance)[source] = Weight::One();
146 parent[source] = kNoStateId;
147
148 StateId s_p = kNoStateId, d_p = kNoStateId;
149 for (StateId s = f_parent, d = kNoStateId;
150 s != kNoStateId;
151 d = s, s = parent[s]) {
152 enqueued[s] = true;
153 d_p = s_p;
154 s_p = ofst->AddState();
155 if (d == kNoStateId) {
156 ofst->SetFinal(s_p, ifst.Final(f_parent));
157 } else {
158 arc_parent[d].nextstate = d_p;
159 ofst->AddArc(s_p, arc_parent[d]);
160 }
161 }
162 ofst->SetStart(s_p);
163 }
164
165
166 template <class S, class W>
167 class ShortestPathCompare {
168 public:
169 typedef S StateId;
170 typedef W Weight;
171 typedef pair<StateId, Weight> Pair;
172
ShortestPathCompare(const vector<Pair> & pairs,const vector<Weight> & distance,StateId sfinal,float d)173 ShortestPathCompare(const vector<Pair>& pairs,
174 const vector<Weight>& distance,
175 StateId sfinal, float d)
176 : pairs_(pairs), distance_(distance), superfinal_(sfinal), delta_(d) {}
177
operator()178 bool operator()(const StateId x, const StateId y) const {
179 const Pair &px = pairs_[x];
180 const Pair &py = pairs_[y];
181 Weight wx = Times(distance_[px.first], px.second);
182 Weight wy = Times(distance_[py.first], py.second);
183 // Penalize complete paths to ensure correct results with inexact weights.
184 // This forms a strict weak order so long as ApproxEqual(a, b) =>
185 // ApproxEqual(a, c) for all c s.t. less_(a, c) && less_(c, b).
186 if (px.first == superfinal_ && py.first != superfinal_) {
187 return less_(wy, wx) || ApproxEqual(wx, wy, delta_);
188 } else if (py.first == superfinal_ && px.first != superfinal_) {
189 return less_(wy, wx) && !ApproxEqual(wx, wy, delta_);
190 } else {
191 return less_(wy, wx);
192 }
193 }
194
195 private:
196 const vector<Pair> &pairs_;
197 const vector<Weight> &distance_;
198 StateId superfinal_;
199 float delta_;
200 NaturalLess<Weight> less_;
201 };
202
203
204 // N-Shortest-path algorithm: this version allow fine control
205 // via the otpions argument. See below for a simpler interface.
206 //
207 // 'ofst' contains the n-shortest paths in 'ifst'. 'distance' returns
208 // the shortest distances from the source state to each state in
209 // 'ifst'. 'opts' is used to specify options such as the number of
210 // paths to return, whether they need to have distinct input
211 // strings, the queue discipline, the arc filter and the convergence
212 // delta.
213 //
214 // The n-shortest paths are the n-lowest weight paths w.r.t. the
215 // natural semiring order. The single path that can be
216 // read from the ith of at most n transitions leaving the initial
217 // state of 'ofst' is the ith shortest path.
218
219 // The weights need to be right distributive and have the path (kPath)
220 // property. They need to be left distributive as well for nshortest
221 // > 1.
222 //
223 // The algorithm is from Mohri and Riley, "An Efficient Algorithm for
224 // the n-best-strings problem", ICSLP 2002. The algorithm relies on
225 // the shortest-distance algorithm. There are some issues with the
226 // pseudo-code as written in the paper (viz., line 11).
227 template<class Arc, class Queue, class ArcFilter>
ShortestPath(const Fst<Arc> & ifst,MutableFst<Arc> * ofst,vector<typename Arc::Weight> * distance,ShortestPathOptions<Arc,Queue,ArcFilter> & opts)228 void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
229 vector<typename Arc::Weight> *distance,
230 ShortestPathOptions<Arc, Queue, ArcFilter> &opts) {
231 typedef typename Arc::StateId StateId;
232 typedef typename Arc::Weight Weight;
233 typedef pair<StateId, Weight> Pair;
234 typedef ReverseArc<Arc> ReverseArc;
235 typedef typename ReverseArc::Weight ReverseWeight;
236
237 size_t n = opts.nshortest;
238
239 if (n == 1) {
240 SingleShortestPath(ifst, ofst, distance, opts);
241 return;
242 }
243 ofst->DeleteStates();
244 ofst->SetInputSymbols(ifst.InputSymbols());
245 ofst->SetOutputSymbols(ifst.OutputSymbols());
246 if (n <= 0) return;
247 if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring))
248 LOG(FATAL) << "ShortestPath: n-shortest: Weight needs to have the "
249 << "path property and be distributive: "
250 << Weight::Type();
251 if (opts.unique)
252 LOG(FATAL) << "ShortestPath: n-shortest-string algorithm not "
253 << "currently implemented";
254
255 // Algorithm works on the reverse of 'fst' : 'rfst' 'distance' is
256 // the distance to the final state in 'rfst' 'ofst' is built as the
257 // reverse of the tree of n-shortest path in 'rfst'.
258
259 if (!opts.has_distance)
260 ShortestDistance(ifst, distance, opts);
261 VectorFst<ReverseArc> rfst;
262 Reverse(ifst, &rfst);
263 distance->insert(distance->begin(), Weight::One());
264 while (distance->size() < rfst.NumStates())
265 distance->push_back(Weight::Zero());
266
267
268 // Each state in 'ofst' corresponds to a path with weight w from the
269 // initial state of 'rfst' to a state s in 'rfst', that can be
270 // characterized by a pair (s,w). The vector 'pairs' maps each
271 // state in 'ofst' to the corresponding pair maps states in OFST to
272 // the corresponding pair (s,w).
273 vector<Pair> pairs;
274 // 'r[s]', 's' state in 'fst', is the number of states in 'ofst'
275 // which corresponding pair contains 's' ,i.e. , it is number of
276 // paths computed so far to 's'.
277 StateId superfinal = distance->size(); // superfinal must be handled
278 distance->push_back(Weight::One()); // differently when unique=true
279 ShortestPathCompare<StateId, Weight>
280 compare(pairs, *distance, superfinal, opts.delta);
281 vector<StateId> heap;
282 vector<int> r;
283 while (r.size() < distance->size())
284 r.push_back(0);
285 ofst->SetStart(ofst->AddState());
286 StateId final = ofst->AddState();
287 ofst->SetFinal(final, Weight::One());
288 while (pairs.size() <= final)
289 pairs.push_back(Pair(kNoStateId, Weight::Zero()));
290 pairs[final] = Pair(rfst.Start(), Weight::One());
291 heap.push_back(final);
292
293 while (!heap.empty()) {
294 pop_heap(heap.begin(), heap.end(), compare);
295 StateId state = heap.back();
296 Pair p = pairs[state];
297 heap.pop_back();
298
299 ++r[p.first];
300 if (p.first == superfinal)
301 ofst->AddArc(ofst->Start(), Arc(0, 0, Weight::One(), state));
302 if ((p.first == superfinal) && (r[p.first] == n)) break;
303 if (r[p.first] > n) continue;
304 if (p.first == superfinal)
305 continue;
306
307 for (ArcIterator< Fst<ReverseArc> > aiter(rfst, p.first);
308 !aiter.Done();
309 aiter.Next()) {
310 const ReverseArc &rarc = aiter.Value();
311 Arc arc(rarc.ilabel, rarc.olabel, rarc.weight.Reverse(), rarc.nextstate);
312 Weight w = Times(p.second, arc.weight);
313 StateId next = ofst->AddState();
314 pairs.push_back(Pair(arc.nextstate, w));
315 arc.nextstate = state;
316 ofst->AddArc(next, arc);
317 heap.push_back(next);
318 push_heap(heap.begin(), heap.end(), compare);
319 }
320
321 Weight finalw = rfst.Final(p.first).Reverse();
322 if (finalw != Weight::Zero()) {
323 Weight w = Times(p.second, finalw);
324 StateId next = ofst->AddState();
325 pairs.push_back(Pair(superfinal, w));
326 ofst->AddArc(next, Arc(0, 0, finalw, state));
327 heap.push_back(next);
328 push_heap(heap.begin(), heap.end(), compare);
329 }
330 }
331 Connect(ofst);
332 distance->erase(distance->begin());
333 distance->pop_back();
334 }
335
336 // Shortest-path algorithm: simplified interface. See above for a
337 // version that allows finer control.
338
339 // 'ofst' contains the 'n'-shortest paths in 'ifst'. The queue
340 // discipline is automatically selected. When 'unique' == true, only
341 // paths with distinct input labels are returned.
342 //
343 // The n-shortest paths are the n-lowest weight paths w.r.t. the
344 // natural semiring order. The single path that can be read from the
345 // ith of at most n transitions leaving the initial state of 'ofst' is
346 // the ith best path.
347 //
348 // The weights need to be right distributive and have the path
349 // (kPath) property.
350 template<class Arc>
351 void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
352 size_t n = 1, bool unique = false) {
353 vector<typename Arc::Weight> distance;
354 AnyArcFilter<Arc> arc_filter;
355 AutoQueue<typename Arc::StateId> state_queue(ifst, &distance, arc_filter);
356 ShortestPathOptions< Arc, AutoQueue<typename Arc::StateId>,
357 AnyArcFilter<Arc> > opts(&state_queue, arc_filter, n, unique);
358 ShortestPath(ifst, ofst, &distance, opts);
359 }
360
361 } // namespace fst
362
363 #endif // FST_LIB_SHORTEST_PATH_H__
364