1 // shortest-path.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: allauzen@google.com (Cyril Allauzen)
17 //
18 // \file
19 // Functions to find shortest paths in an FST.
20
21 #ifndef FST_LIB_SHORTEST_PATH_H__
22 #define FST_LIB_SHORTEST_PATH_H__
23
24 #include <functional>
25 #include <utility>
26 using std::pair; using std::make_pair;
27 #include <vector>
28 using std::vector;
29
30 #include <fst/cache.h>
31 #include <fst/determinize.h>
32 #include <fst/queue.h>
33 #include <fst/shortest-distance.h>
34 #include <fst/test-properties.h>
35
36
37 namespace fst {
38
39 template <class Arc, class Queue, class ArcFilter>
40 struct ShortestPathOptions
41 : public ShortestDistanceOptions<Arc, Queue, ArcFilter> {
42 typedef typename Arc::StateId StateId;
43 typedef typename Arc::Weight Weight;
44 size_t nshortest; // return n-shortest paths
45 bool unique; // only return paths with distinct input strings
46 bool has_distance; // distance vector already contains the
47 // shortest distance from the initial state
48 bool first_path; // Single shortest path stops after finding the first
49 // path to a final state. That path is the shortest path
50 // only when using the ShortestFirstQueue and
51 // only when all the weights in the FST are between
52 // One() and Zero() according to NaturalLess.
53 Weight weight_threshold; // pruning weight threshold.
54 StateId state_threshold; // pruning state threshold.
55
56 ShortestPathOptions(Queue *q, ArcFilter filt, size_t n = 1, bool u = false,
57 bool hasdist = false, float d = kDelta,
58 bool fp = false, Weight w = Weight::Zero(),
59 StateId s = kNoStateId)
60 : ShortestDistanceOptions<Arc, Queue, ArcFilter>(q, filt, kNoStateId, d),
61 nshortest(n), unique(u), has_distance(hasdist), first_path(fp),
62 weight_threshold(w), state_threshold(s) {}
63 };
64
65
66 // Shortest-path algorithm: normally not called directly; prefer
67 // 'ShortestPath' below with n=1. 'ofst' contains the shortest path in
68 // 'ifst'. 'distance' returns the shortest distances from the source
69 // state to each state in 'ifst'. 'opts' is used to specify options
70 // such as the queue discipline, the arc filter and delta.
71 //
72 // The shortest path is the lowest weight path w.r.t. the natural
73 // semiring order.
74 //
75 // The weights need to be right distributive and have the path (kPath)
76 // property.
77 template<class Arc, class Queue, class ArcFilter>
SingleShortestPath(const Fst<Arc> & ifst,MutableFst<Arc> * ofst,vector<typename Arc::Weight> * distance,ShortestPathOptions<Arc,Queue,ArcFilter> & opts)78 void SingleShortestPath(const Fst<Arc> &ifst,
79 MutableFst<Arc> *ofst,
80 vector<typename Arc::Weight> *distance,
81 ShortestPathOptions<Arc, Queue, ArcFilter> &opts) {
82 typedef typename Arc::StateId StateId;
83 typedef typename Arc::Weight Weight;
84
85 ofst->DeleteStates();
86 ofst->SetInputSymbols(ifst.InputSymbols());
87 ofst->SetOutputSymbols(ifst.OutputSymbols());
88
89 if (ifst.Start() == kNoStateId) {
90 if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
91 return;
92 }
93
94 vector<bool> enqueued;
95 vector<StateId> parent;
96 vector<Arc> arc_parent;
97
98 Queue *state_queue = opts.state_queue;
99 StateId source = opts.source == kNoStateId ? ifst.Start() : opts.source;
100 Weight f_distance = Weight::Zero();
101 StateId f_parent = kNoStateId;
102
103 distance->clear();
104 state_queue->Clear();
105 if (opts.nshortest != 1) {
106 FSTERROR() << "SingleShortestPath: for nshortest > 1, use ShortestPath"
107 << " instead";
108 ofst->SetProperties(kError, kError);
109 return;
110 }
111 if (opts.weight_threshold != Weight::Zero() ||
112 opts.state_threshold != kNoStateId) {
113 FSTERROR() <<
114 "SingleShortestPath: weight and state thresholds not applicable";
115 ofst->SetProperties(kError, kError);
116 return;
117 }
118 if ((Weight::Properties() & (kPath | kRightSemiring))
119 != (kPath | kRightSemiring)) {
120 FSTERROR() << "SingleShortestPath: Weight needs to have the path"
121 << " property and be right distributive: " << Weight::Type();
122 ofst->SetProperties(kError, kError);
123 return;
124 }
125 while (distance->size() < source) {
126 distance->push_back(Weight::Zero());
127 enqueued.push_back(false);
128 parent.push_back(kNoStateId);
129 arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId));
130 }
131 distance->push_back(Weight::One());
132 parent.push_back(kNoStateId);
133 arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId));
134 state_queue->Enqueue(source);
135 enqueued.push_back(true);
136
137 while (!state_queue->Empty()) {
138 StateId s = state_queue->Head();
139 state_queue->Dequeue();
140 enqueued[s] = false;
141 Weight sd = (*distance)[s];
142 if (ifst.Final(s) != Weight::Zero()) {
143 Weight w = Times(sd, ifst.Final(s));
144 if (f_distance != Plus(f_distance, w)) {
145 f_distance = Plus(f_distance, w);
146 f_parent = s;
147 }
148 if (!f_distance.Member()) {
149 ofst->SetProperties(kError, kError);
150 return;
151 }
152 if (opts.first_path)
153 break;
154 }
155 for (ArcIterator< Fst<Arc> > aiter(ifst, s);
156 !aiter.Done();
157 aiter.Next()) {
158 const Arc &arc = aiter.Value();
159 while (distance->size() <= arc.nextstate) {
160 distance->push_back(Weight::Zero());
161 enqueued.push_back(false);
162 parent.push_back(kNoStateId);
163 arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(),
164 kNoStateId));
165 }
166 Weight &nd = (*distance)[arc.nextstate];
167 Weight w = Times(sd, arc.weight);
168 if (nd != Plus(nd, w)) {
169 nd = Plus(nd, w);
170 if (!nd.Member()) {
171 ofst->SetProperties(kError, kError);
172 return;
173 }
174 parent[arc.nextstate] = s;
175 arc_parent[arc.nextstate] = arc;
176 if (!enqueued[arc.nextstate]) {
177 state_queue->Enqueue(arc.nextstate);
178 enqueued[arc.nextstate] = true;
179 } else {
180 state_queue->Update(arc.nextstate);
181 }
182 }
183 }
184 }
185
186 StateId s_p = kNoStateId, d_p = kNoStateId;
187 for (StateId s = f_parent, d = kNoStateId;
188 s != kNoStateId;
189 d = s, s = parent[s]) {
190 d_p = s_p;
191 s_p = ofst->AddState();
192 if (d == kNoStateId) {
193 ofst->SetFinal(s_p, ifst.Final(f_parent));
194 } else {
195 arc_parent[d].nextstate = d_p;
196 ofst->AddArc(s_p, arc_parent[d]);
197 }
198 }
199 ofst->SetStart(s_p);
200 if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
201 ofst->SetProperties(
202 ShortestPathProperties(ofst->Properties(kFstProperties, false)),
203 kFstProperties);
204 }
205
206
207 template <class S, class W>
208 class ShortestPathCompare {
209 public:
210 typedef S StateId;
211 typedef W Weight;
212 typedef pair<StateId, Weight> Pair;
213
ShortestPathCompare(const vector<Pair> & pairs,const vector<Weight> & distance,StateId sfinal,float d)214 ShortestPathCompare(const vector<Pair>& pairs,
215 const vector<Weight>& distance,
216 StateId sfinal, float d)
217 : pairs_(pairs), distance_(distance), superfinal_(sfinal), delta_(d) {}
218
operator()219 bool operator()(const StateId x, const StateId y) const {
220 const Pair &px = pairs_[x];
221 const Pair &py = pairs_[y];
222 Weight dx = px.first == superfinal_ ? Weight::One() :
223 px.first < distance_.size() ? distance_[px.first] : Weight::Zero();
224 Weight dy = py.first == superfinal_ ? Weight::One() :
225 py.first < distance_.size() ? distance_[py.first] : Weight::Zero();
226 Weight wx = Times(dx, px.second);
227 Weight wy = Times(dy, py.second);
228 // Penalize complete paths to ensure correct results with inexact weights.
229 // This forms a strict weak order so long as ApproxEqual(a, b) =>
230 // ApproxEqual(a, c) for all c s.t. less_(a, c) && less_(c, b).
231 if (px.first == superfinal_ && py.first != superfinal_) {
232 return less_(wy, wx) || ApproxEqual(wx, wy, delta_);
233 } else if (py.first == superfinal_ && px.first != superfinal_) {
234 return less_(wy, wx) && !ApproxEqual(wx, wy, delta_);
235 } else {
236 return less_(wy, wx);
237 }
238 }
239
240 private:
241 const vector<Pair> &pairs_;
242 const vector<Weight> &distance_;
243 StateId superfinal_;
244 float delta_;
245 NaturalLess<Weight> less_;
246 };
247
248
// N-Shortest-path algorithm: implements the core n-shortest path
// algorithm. The output is built REVERSED. See below for versions with
// more options and not reversed.
//
// 'ofst' contains the REVERSE of 'n'-shortest paths in 'ifst'.
// 'distance' must contain the shortest distance from each state to a final
// state in 'ifst'. 'delta' is the convergence delta used by the heap
// comparator's ApproxEqual tie-breaking.
//
// Pruning: a popped path is discarded when its estimated total weight is
// worse (w.r.t. NaturalLess) than distance[start] times 'weight_threshold',
// and nothing more is expanded once 'ofst' has 'state_threshold' states
// (kNoStateId disables the state limit). A 'weight_threshold' lighter than
// One() or a 'state_threshold' of 0 yields an empty 'ofst'.
//
// The n-shortest paths are the n-lowest weight paths w.r.t. the
// natural semiring order. The single path that can be read from the
// ith of at most n transitions leaving the initial state of 'ofst' is
// the ith shortest path. Disregarding the initial state and initial
// transitions, the n-shortest paths, in fact, form a tree rooted at
// the single final state.
//
// The weights need to be left and right distributive (kSemiring) and
// have the path (kPath) property.
//
// The algorithm is from Mohri and Riley, "An Efficient Algorithm for
// the n-best-strings problem", ICSLP 2002. The algorithm relies on
// the shortest-distance algorithm. There are some issues with the
// pseudo-code as written in the paper (viz., line 11).
//
// IMPLEMENTATION NOTE: The input fst 'ifst' can be a delayed fst, and
// at any state in its expansion the values of the distance vector need only
// be defined at that time for the states that are known to exist.
template<class Arc, class RevArc>
void NShortestPath(const Fst<RevArc> &ifst,
                   MutableFst<Arc> *ofst,
                   const vector<typename Arc::Weight> &distance,
                   size_t n,
                   float delta = kDelta,
                   typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
                   typename Arc::StateId state_threshold = kNoStateId) {
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Weight Weight;
  typedef pair<StateId, Weight> Pair;
  typedef typename RevArc::Weight RevWeight;

  if (n <= 0) return;
  if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring)) {
    FSTERROR() << "NShortestPath: Weight needs to have the "
               << "path property and be distributive: "
               << Weight::Type();
    ofst->SetProperties(kError, kError);
    return;
  }
  ofst->DeleteStates();
  ofst->SetInputSymbols(ifst.InputSymbols());
  ofst->SetOutputSymbols(ifst.OutputSymbols());
  // Each state in 'ofst' corresponds to a path with weight w from the
  // initial state of 'ifst' to a state s in 'ifst', that can be
  // characterized by a pair (s,w). The vector 'pairs' maps each
  // state in 'ofst' to its corresponding pair (s,w), so it is indexed by
  // 'ofst' state id: every AddState() below is followed by exactly one
  // push_back on 'pairs'.
  vector<Pair> pairs;
  // The superfinal state is denoted by -1; 'compare' knows that the
  // distance from 'superfinal' to the final state is 'Weight::One()',
  // hence 'distance[superfinal]' is not needed.
  StateId superfinal = -1;
  ShortestPathCompare<StateId, Weight>
    compare(pairs, distance, superfinal, delta);
  // Heap of 'ofst' state ids; the lightest estimated path is popped first.
  vector<StateId> heap;
  // 'r[s + 1]', 's' a state in 'ifst', is the number of states in 'ofst'
  // whose corresponding pair contains 's', i.e., it is the number of
  // paths computed so far to 's'. Valid for 's == -1' (superfinal).
  vector<int> r;
  NaturalLess<Weight> less;
  if (ifst.Start() == kNoStateId ||
      distance.size() <= ifst.Start() ||
      distance[ifst.Start()] == Weight::Zero() ||
      less(weight_threshold, Weight::One()) ||
      state_threshold == 0) {
    if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
    return;
  }
  ofst->SetStart(ofst->AddState());
  StateId final = ofst->AddState();
  ofst->SetFinal(final, Weight::One());
  // Pad 'pairs' so that index 'final' exists (the start state's entry is a
  // placeholder and is never pushed onto the heap).
  while (pairs.size() <= final)
    pairs.push_back(Pair(kNoStateId, Weight::Zero()));
  pairs[final] = Pair(ifst.Start(), Weight::One());
  heap.push_back(final);
  Weight limit = Times(distance[ifst.Start()], weight_threshold);

  while (!heap.empty()) {
    pop_heap(heap.begin(), heap.end(), compare);
    StateId state = heap.back();
    Pair p = pairs[state];
    heap.pop_back();
    Weight d = p.first == superfinal ? Weight::One() :
        p.first < distance.size() ? distance[p.first] : Weight::Zero();

    // Pruning by weight and by number of output states.
    if (less(limit, Times(d, p.second)) ||
        (state_threshold != kNoStateId &&
         ofst->NumStates() >= state_threshold))
      continue;

    while (r.size() <= p.first + 1) r.push_back(0);
    ++r[p.first + 1];
    // A popped superfinal pair is a complete path: expose it from the start.
    if (p.first == superfinal)
      ofst->AddArc(ofst->Start(), Arc(0, 0, Weight::One(), state));
    // Stop once the n-th complete path has been emitted.
    if ((p.first == superfinal) && (r[p.first + 1] == n)) break;
    // Only the first n paths reaching a state are extended.
    if (r[p.first + 1] > n) continue;
    if (p.first == superfinal) continue;

    // Extend the path along each arc of 'ifst'. The new 'ofst' state points
    // back to 'state', which is why 'ofst' is built reversed.
    for (ArcIterator< Fst<RevArc> > aiter(ifst, p.first);
         !aiter.Done();
         aiter.Next()) {
      const RevArc &rarc = aiter.Value();
      Arc arc(rarc.ilabel, rarc.olabel, rarc.weight.Reverse(), rarc.nextstate);
      Weight w = Times(p.second, arc.weight);
      StateId next = ofst->AddState();
      pairs.push_back(Pair(arc.nextstate, w));
      arc.nextstate = state;
      ofst->AddArc(next, arc);
      heap.push_back(next);
      push_heap(heap.begin(), heap.end(), compare);
    }

    // A final state of 'ifst' also leads to the superfinal state.
    Weight finalw = ifst.Final(p.first).Reverse();
    if (finalw != Weight::Zero()) {
      Weight w = Times(p.second, finalw);
      StateId next = ofst->AddState();
      pairs.push_back(Pair(superfinal, w));
      ofst->AddArc(next, Arc(0, 0, finalw, state));
      heap.push_back(next);
      push_heap(heap.begin(), heap.end(), compare);
    }
  }
  // Drop the states that are not on one of the emitted paths.
  Connect(ofst);
  if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
  ofst->SetProperties(
    ShortestPathProperties(ofst->Properties(kFstProperties, false)),
    kFstProperties);
}
385
386
387 // N-Shortest-path algorithm: this version allow fine control
388 // via the options argument. See below for a simpler interface.
389 //
390 // 'ofst' contains the n-shortest paths in 'ifst'. 'distance' returns
391 // the shortest distances from the source state to each state in
392 // 'ifst'. 'opts' is used to specify options such as the number of
393 // paths to return, whether they need to have distinct input
394 // strings, the queue discipline, the arc filter and the convergence
395 // delta.
396 //
397 // The n-shortest paths are the n-lowest weight paths w.r.t. the
398 // natural semiring order. The single path that can be read from the
399 // ith of at most n transitions leaving the initial state of 'ofst' is
400 // the ith shortest path. Disregarding the initial state and initial
401 // transitions, The n-shortest paths, in fact, form a tree rooted at
402 // the single final state.
403
404 // The weights need to be right distributive and have the path (kPath)
405 // property. They need to be left distributive as well for nshortest
406 // > 1.
407 //
408 // The algorithm is from Mohri and Riley, "An Efficient Algorithm for
409 // the n-best-strings problem", ICSLP 2002. The algorithm relies on
410 // the shortest-distance algorithm. There are some issues with the
411 // pseudo-code as written in the paper (viz., line 11).
412 template<class Arc, class Queue, class ArcFilter>
ShortestPath(const Fst<Arc> & ifst,MutableFst<Arc> * ofst,vector<typename Arc::Weight> * distance,ShortestPathOptions<Arc,Queue,ArcFilter> & opts)413 void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
414 vector<typename Arc::Weight> *distance,
415 ShortestPathOptions<Arc, Queue, ArcFilter> &opts) {
416 typedef typename Arc::StateId StateId;
417 typedef typename Arc::Weight Weight;
418 typedef ReverseArc<Arc> ReverseArc;
419
420 size_t n = opts.nshortest;
421 if (n == 1) {
422 SingleShortestPath(ifst, ofst, distance, opts);
423 return;
424 }
425 if (n <= 0) return;
426 if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring)) {
427 FSTERROR() << "ShortestPath: n-shortest: Weight needs to have the "
428 << "path property and be distributive: "
429 << Weight::Type();
430 ofst->SetProperties(kError, kError);
431 return;
432 }
433 if (!opts.has_distance) {
434 ShortestDistance(ifst, distance, opts);
435 if (distance->size() == 1 && !(*distance)[0].Member()) {
436 ofst->SetProperties(kError, kError);
437 return;
438 }
439 }
440 // Algorithm works on the reverse of 'fst' : 'rfst', 'distance' is
441 // the distance to the final state in 'rfst', 'ofst' is built as the
442 // reverse of the tree of n-shortest path in 'rfst'.
443 VectorFst<ReverseArc> rfst;
444 Reverse(ifst, &rfst);
445 Weight d = Weight::Zero();
446 for (ArcIterator< VectorFst<ReverseArc> > aiter(rfst, 0);
447 !aiter.Done(); aiter.Next()) {
448 const ReverseArc &arc = aiter.Value();
449 StateId s = arc.nextstate - 1;
450 if (s < distance->size())
451 d = Plus(d, Times(arc.weight.Reverse(), (*distance)[s]));
452 }
453 distance->insert(distance->begin(), d);
454
455 if (!opts.unique) {
456 NShortestPath(rfst, ofst, *distance, n, opts.delta,
457 opts.weight_threshold, opts.state_threshold);
458 } else {
459 vector<Weight> ddistance;
460 DeterminizeFstOptions<ReverseArc> dopts(opts.delta);
461 DeterminizeFst<ReverseArc> dfst(rfst, *distance, &ddistance, dopts);
462 NShortestPath(dfst, ofst, ddistance, n, opts.delta,
463 opts.weight_threshold, opts.state_threshold);
464 }
465 distance->erase(distance->begin());
466 }
467
468
469 // Shortest-path algorithm: simplified interface. See above for a
470 // version that allows finer control.
471 //
472 // 'ofst' contains the 'n'-shortest paths in 'ifst'. The queue
473 // discipline is automatically selected. When 'unique' == true, only
474 // paths with distinct input labels are returned.
475 //
476 // The n-shortest paths are the n-lowest weight paths w.r.t. the
477 // natural semiring order. The single path that can be read from the
478 // ith of at most n transitions leaving the initial state of 'ofst' is
479 // the ith best path.
480 //
481 // The weights need to be right distributive and have the path
482 // (kPath) property.
483 template<class Arc>
484 void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
485 size_t n = 1, bool unique = false,
486 bool first_path = false,
487 typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
488 typename Arc::StateId state_threshold = kNoStateId) {
489 vector<typename Arc::Weight> distance;
490 AnyArcFilter<Arc> arc_filter;
491 AutoQueue<typename Arc::StateId> state_queue(ifst, &distance, arc_filter);
492 ShortestPathOptions< Arc, AutoQueue<typename Arc::StateId>,
493 AnyArcFilter<Arc> > opts(&state_queue, arc_filter, n, unique, false,
494 kDelta, first_path, weight_threshold,
495 state_threshold);
496 ShortestPath(ifst, ofst, &distance, opts);
497 }
498
499 } // namespace fst
500
501 #endif // FST_LIB_SHORTEST_PATH_H__
502