• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // shortest-path.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: allauzen@google.com (Cyril Allauzen)
17 //
18 // \file
19 // Functions to find shortest paths in an FST.
20 
21 #ifndef FST_LIB_SHORTEST_PATH_H__
22 #define FST_LIB_SHORTEST_PATH_H__
23 
#include <algorithm>
#include <functional>
#include <utility>
using std::pair; using std::make_pair;
#include <vector>
using std::vector;
29 
30 #include <fst/cache.h>
31 #include <fst/determinize.h>
32 #include <fst/queue.h>
33 #include <fst/shortest-distance.h>
34 #include <fst/test-properties.h>
35 
36 
37 namespace fst {
38 
39 template <class Arc, class Queue, class ArcFilter>
40 struct ShortestPathOptions
41     : public ShortestDistanceOptions<Arc, Queue, ArcFilter> {
42   typedef typename Arc::StateId StateId;
43   typedef typename Arc::Weight Weight;
44   size_t nshortest;   // return n-shortest paths
45   bool unique;        // only return paths with distinct input strings
46   bool has_distance;  // distance vector already contains the
47                       // shortest distance from the initial state
48   bool first_path;    // Single shortest path stops after finding the first
49                       // path to a final state. That path is the shortest path
50                       // only when using the ShortestFirstQueue and
51                       // only when all the weights in the FST are between
52                       // One() and Zero() according to NaturalLess.
53   Weight weight_threshold;   // pruning weight threshold.
54   StateId state_threshold;   // pruning state threshold.
55 
56   ShortestPathOptions(Queue *q, ArcFilter filt, size_t n = 1, bool u = false,
57                       bool hasdist = false, float d = kDelta,
58                       bool fp = false, Weight w = Weight::Zero(),
59                       StateId s = kNoStateId)
60       : ShortestDistanceOptions<Arc, Queue, ArcFilter>(q, filt, kNoStateId, d),
61         nshortest(n), unique(u), has_distance(hasdist), first_path(fp),
62         weight_threshold(w), state_threshold(s) {}
63 };
64 
65 
66 // Shortest-path algorithm: normally not called directly; prefer
67 // 'ShortestPath' below with n=1. 'ofst' contains the shortest path in
68 // 'ifst'. 'distance' returns the shortest distances from the source
69 // state to each state in 'ifst'. 'opts' is used to specify options
70 // such as the queue discipline, the arc filter and delta.
71 //
72 // The shortest path is the lowest weight path w.r.t. the natural
73 // semiring order.
74 //
75 // The weights need to be right distributive and have the path (kPath)
76 // property.
77 template<class Arc, class Queue, class ArcFilter>
SingleShortestPath(const Fst<Arc> & ifst,MutableFst<Arc> * ofst,vector<typename Arc::Weight> * distance,ShortestPathOptions<Arc,Queue,ArcFilter> & opts)78 void SingleShortestPath(const Fst<Arc> &ifst,
79                   MutableFst<Arc> *ofst,
80                   vector<typename Arc::Weight> *distance,
81                   ShortestPathOptions<Arc, Queue, ArcFilter> &opts) {
82   typedef typename Arc::StateId StateId;
83   typedef typename Arc::Weight Weight;
84 
85   ofst->DeleteStates();
86   ofst->SetInputSymbols(ifst.InputSymbols());
87   ofst->SetOutputSymbols(ifst.OutputSymbols());
88 
89   if (ifst.Start() == kNoStateId) {
90     if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
91     return;
92   }
93 
94   vector<bool> enqueued;
95   vector<StateId> parent;
96   vector<Arc> arc_parent;
97 
98   Queue *state_queue = opts.state_queue;
99   StateId source = opts.source == kNoStateId ? ifst.Start() : opts.source;
100   Weight f_distance = Weight::Zero();
101   StateId f_parent = kNoStateId;
102 
103   distance->clear();
104   state_queue->Clear();
105   if (opts.nshortest != 1) {
106     FSTERROR() << "SingleShortestPath: for nshortest > 1, use ShortestPath"
107                << " instead";
108     ofst->SetProperties(kError, kError);
109     return;
110   }
111   if (opts.weight_threshold != Weight::Zero() ||
112       opts.state_threshold != kNoStateId) {
113     FSTERROR() <<
114         "SingleShortestPath: weight and state thresholds not applicable";
115     ofst->SetProperties(kError, kError);
116     return;
117   }
118   if ((Weight::Properties() & (kPath | kRightSemiring))
119       != (kPath | kRightSemiring)) {
120     FSTERROR() << "SingleShortestPath: Weight needs to have the path"
121                << " property and be right distributive: " << Weight::Type();
122     ofst->SetProperties(kError, kError);
123     return;
124   }
125   while (distance->size() < source) {
126     distance->push_back(Weight::Zero());
127     enqueued.push_back(false);
128     parent.push_back(kNoStateId);
129     arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId));
130   }
131   distance->push_back(Weight::One());
132   parent.push_back(kNoStateId);
133   arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId));
134   state_queue->Enqueue(source);
135   enqueued.push_back(true);
136 
137   while (!state_queue->Empty()) {
138     StateId s = state_queue->Head();
139     state_queue->Dequeue();
140     enqueued[s] = false;
141     Weight sd = (*distance)[s];
142     if (ifst.Final(s) != Weight::Zero()) {
143       Weight w = Times(sd, ifst.Final(s));
144       if (f_distance != Plus(f_distance, w)) {
145         f_distance = Plus(f_distance, w);
146         f_parent = s;
147       }
148       if (!f_distance.Member()) {
149         ofst->SetProperties(kError, kError);
150         return;
151       }
152       if (opts.first_path)
153         break;
154     }
155     for (ArcIterator< Fst<Arc> > aiter(ifst, s);
156          !aiter.Done();
157          aiter.Next()) {
158       const Arc &arc = aiter.Value();
159       while (distance->size() <= arc.nextstate) {
160         distance->push_back(Weight::Zero());
161         enqueued.push_back(false);
162         parent.push_back(kNoStateId);
163         arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(),
164                                  kNoStateId));
165       }
166       Weight &nd = (*distance)[arc.nextstate];
167       Weight w = Times(sd, arc.weight);
168       if (nd != Plus(nd, w)) {
169         nd = Plus(nd, w);
170         if (!nd.Member()) {
171           ofst->SetProperties(kError, kError);
172           return;
173         }
174         parent[arc.nextstate] = s;
175         arc_parent[arc.nextstate] = arc;
176         if (!enqueued[arc.nextstate]) {
177           state_queue->Enqueue(arc.nextstate);
178           enqueued[arc.nextstate] = true;
179         } else {
180           state_queue->Update(arc.nextstate);
181         }
182       }
183     }
184   }
185 
186   StateId s_p = kNoStateId, d_p = kNoStateId;
187   for (StateId s = f_parent, d = kNoStateId;
188        s != kNoStateId;
189        d = s, s = parent[s]) {
190     d_p = s_p;
191     s_p = ofst->AddState();
192     if (d == kNoStateId) {
193       ofst->SetFinal(s_p, ifst.Final(f_parent));
194     } else {
195       arc_parent[d].nextstate = d_p;
196       ofst->AddArc(s_p, arc_parent[d]);
197     }
198   }
199   ofst->SetStart(s_p);
200   if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
201   ofst->SetProperties(
202       ShortestPathProperties(ofst->Properties(kFstProperties, false)),
203       kFstProperties);
204 }
205 
206 
207 template <class S, class W>
208 class ShortestPathCompare {
209  public:
210   typedef S StateId;
211   typedef W Weight;
212   typedef pair<StateId, Weight> Pair;
213 
ShortestPathCompare(const vector<Pair> & pairs,const vector<Weight> & distance,StateId sfinal,float d)214   ShortestPathCompare(const vector<Pair>& pairs,
215                       const vector<Weight>& distance,
216                       StateId sfinal, float d)
217       : pairs_(pairs), distance_(distance), superfinal_(sfinal), delta_(d)  {}
218 
operator()219   bool operator()(const StateId x, const StateId y) const {
220     const Pair &px = pairs_[x];
221     const Pair &py = pairs_[y];
222     Weight dx = px.first == superfinal_ ? Weight::One() :
223         px.first < distance_.size() ? distance_[px.first] : Weight::Zero();
224     Weight dy = py.first == superfinal_ ? Weight::One() :
225         py.first < distance_.size() ? distance_[py.first] : Weight::Zero();
226     Weight wx = Times(dx, px.second);
227     Weight wy = Times(dy, py.second);
228     // Penalize complete paths to ensure correct results with inexact weights.
229     // This forms a strict weak order so long as ApproxEqual(a, b) =>
230     // ApproxEqual(a, c) for all c s.t. less_(a, c) && less_(c, b).
231     if (px.first == superfinal_ && py.first != superfinal_) {
232       return less_(wy, wx) || ApproxEqual(wx, wy, delta_);
233     } else if (py.first == superfinal_ && px.first != superfinal_) {
234       return less_(wy, wx) && !ApproxEqual(wx, wy, delta_);
235     } else {
236       return less_(wy, wx);
237     }
238   }
239 
240  private:
241   const vector<Pair> &pairs_;
242   const vector<Weight> &distance_;
243   StateId superfinal_;
244   float delta_;
245   NaturalLess<Weight> less_;
246 };
247 
248 
249 // N-Shortest-path algorithm: implements the core n-shortest path
250 // algorithm. The output is built REVERSED. See below for versions with
251 // more options and not reversed.
252 //
253 // 'ofst' contains the REVERSE of 'n'-shortest paths in 'ifst'.
254 // 'distance' must contain the shortest distance from each state to a final
255 // state in 'ifst'. 'delta' is the convergence delta.
256 //
257 // The n-shortest paths are the n-lowest weight paths w.r.t. the
258 // natural semiring order. The single path that can be read from the
259 // ith of at most n transitions leaving the initial state of 'ofst' is
260 // the ith shortest path. Disregarding the initial state and initial
261 // transitions, the n-shortest paths, in fact, form a tree rooted at
262 // the single final state.
263 //
264 // The weights need to be left and right distributive (kSemiring) and
265 // have the path (kPath) property.
266 //
267 // The algorithm is from Mohri and Riley, "An Efficient Algorithm for
268 // the n-best-strings problem", ICSLP 2002. The algorithm relies on
269 // the shortest-distance algorithm. There are some issues with the
270 // pseudo-code as written in the paper (viz., line 11).
271 //
// IMPLEMENTATION NOTE: The input fst 'ifst' can be a delayed fst, and
// at any state in its expansion the values of the distance vector need only
// be defined at that time for the states that are known to exist.
275 template<class Arc, class RevArc>
276 void NShortestPath(const Fst<RevArc> &ifst,
277                    MutableFst<Arc> *ofst,
278                    const vector<typename Arc::Weight> &distance,
279                    size_t n,
280                    float delta = kDelta,
281                    typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
282                    typename Arc::StateId state_threshold = kNoStateId) {
283   typedef typename Arc::StateId StateId;
284   typedef typename Arc::Weight Weight;
285   typedef pair<StateId, Weight> Pair;
286   typedef typename RevArc::Weight RevWeight;
287 
288   if (n <= 0) return;
289   if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring)) {
290     FSTERROR() << "NShortestPath: Weight needs to have the "
291                  << "path property and be distributive: "
292                  << Weight::Type();
293     ofst->SetProperties(kError, kError);
294     return;
295   }
296   ofst->DeleteStates();
297   ofst->SetInputSymbols(ifst.InputSymbols());
298   ofst->SetOutputSymbols(ifst.OutputSymbols());
299   // Each state in 'ofst' corresponds to a path with weight w from the
300   // initial state of 'ifst' to a state s in 'ifst', that can be
301   // characterized by a pair (s,w).  The vector 'pairs' maps each
302   // state in 'ofst' to the corresponding pair maps states in OFST to
303   // the corresponding pair (s,w).
304   vector<Pair> pairs;
305   // The supefinal state is denoted by -1, 'compare' knows that the
306   // distance from 'superfinal' to the final state is 'Weight::One()',
307   // hence 'distance[superfinal]' is not needed.
308   StateId superfinal = -1;
309   ShortestPathCompare<StateId, Weight>
310     compare(pairs, distance, superfinal, delta);
311   vector<StateId> heap;
312   // 'r[s + 1]', 's' state in 'fst', is the number of states in 'ofst'
313   // which corresponding pair contains 's' ,i.e. , it is number of
314   // paths computed so far to 's'. Valid for 's == -1' (superfinal).
315   vector<int> r;
316   NaturalLess<Weight> less;
317   if (ifst.Start() == kNoStateId ||
318       distance.size() <= ifst.Start() ||
319       distance[ifst.Start()] == Weight::Zero() ||
320       less(weight_threshold, Weight::One()) ||
321       state_threshold == 0) {
322     if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
323     return;
324   }
325   ofst->SetStart(ofst->AddState());
326   StateId final = ofst->AddState();
327   ofst->SetFinal(final, Weight::One());
328   while (pairs.size() <= final)
329     pairs.push_back(Pair(kNoStateId, Weight::Zero()));
330   pairs[final] = Pair(ifst.Start(), Weight::One());
331   heap.push_back(final);
332   Weight limit = Times(distance[ifst.Start()], weight_threshold);
333 
334   while (!heap.empty()) {
335     pop_heap(heap.begin(), heap.end(), compare);
336     StateId state = heap.back();
337     Pair p = pairs[state];
338     heap.pop_back();
339     Weight d = p.first == superfinal ? Weight::One() :
340         p.first < distance.size() ? distance[p.first] : Weight::Zero();
341 
342     if (less(limit, Times(d, p.second)) ||
343         (state_threshold != kNoStateId &&
344          ofst->NumStates() >= state_threshold))
345       continue;
346 
347     while (r.size() <= p.first + 1) r.push_back(0);
348     ++r[p.first + 1];
349     if (p.first == superfinal)
350       ofst->AddArc(ofst->Start(), Arc(0, 0, Weight::One(), state));
351     if ((p.first == superfinal) && (r[p.first + 1] == n)) break;
352     if (r[p.first + 1] > n) continue;
353     if (p.first == superfinal) continue;
354 
355     for (ArcIterator< Fst<RevArc> > aiter(ifst, p.first);
356          !aiter.Done();
357          aiter.Next()) {
358       const RevArc &rarc = aiter.Value();
359       Arc arc(rarc.ilabel, rarc.olabel, rarc.weight.Reverse(), rarc.nextstate);
360       Weight w = Times(p.second, arc.weight);
361       StateId next = ofst->AddState();
362       pairs.push_back(Pair(arc.nextstate, w));
363       arc.nextstate = state;
364       ofst->AddArc(next, arc);
365       heap.push_back(next);
366       push_heap(heap.begin(), heap.end(), compare);
367     }
368 
369     Weight finalw = ifst.Final(p.first).Reverse();
370     if (finalw != Weight::Zero()) {
371       Weight w = Times(p.second, finalw);
372       StateId next = ofst->AddState();
373       pairs.push_back(Pair(superfinal, w));
374       ofst->AddArc(next, Arc(0, 0, finalw, state));
375       heap.push_back(next);
376       push_heap(heap.begin(), heap.end(), compare);
377     }
378   }
379   Connect(ofst);
380   if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
381   ofst->SetProperties(
382       ShortestPathProperties(ofst->Properties(kFstProperties, false)),
383       kFstProperties);
384 }
385 
386 
// N-Shortest-path algorithm: this version allows fine control
// via the options argument. See below for a simpler interface.
389 //
390 // 'ofst' contains the n-shortest paths in 'ifst'. 'distance' returns
391 // the shortest distances from the source state to each state in
392 // 'ifst'. 'opts' is used to specify options such as the number of
393 // paths to return, whether they need to have distinct input
394 // strings, the queue discipline, the arc filter and the convergence
395 // delta.
396 //
// The n-shortest paths are the n-lowest weight paths w.r.t. the
// natural semiring order. The single path that can be read from the
// ith of at most n transitions leaving the initial state of 'ofst' is
// the ith shortest path. Disregarding the initial state and initial
// transitions, the n-shortest paths, in fact, form a tree rooted at
// the single final state.
//
// The weights need to be right distributive and have the path (kPath)
// property. They need to be left distributive as well for nshortest
// > 1.
407 //
408 // The algorithm is from Mohri and Riley, "An Efficient Algorithm for
409 // the n-best-strings problem", ICSLP 2002. The algorithm relies on
410 // the shortest-distance algorithm. There are some issues with the
411 // pseudo-code as written in the paper (viz., line 11).
412 template<class Arc, class Queue, class ArcFilter>
ShortestPath(const Fst<Arc> & ifst,MutableFst<Arc> * ofst,vector<typename Arc::Weight> * distance,ShortestPathOptions<Arc,Queue,ArcFilter> & opts)413 void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
414                   vector<typename Arc::Weight> *distance,
415                   ShortestPathOptions<Arc, Queue, ArcFilter> &opts) {
416   typedef typename Arc::StateId StateId;
417   typedef typename Arc::Weight Weight;
418   typedef ReverseArc<Arc> ReverseArc;
419 
420   size_t n = opts.nshortest;
421   if (n == 1) {
422     SingleShortestPath(ifst, ofst, distance, opts);
423     return;
424   }
425   if (n <= 0) return;
426   if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring)) {
427     FSTERROR() << "ShortestPath: n-shortest: Weight needs to have the "
428                << "path property and be distributive: "
429                << Weight::Type();
430     ofst->SetProperties(kError, kError);
431     return;
432   }
433   if (!opts.has_distance) {
434     ShortestDistance(ifst, distance, opts);
435     if (distance->size() == 1 && !(*distance)[0].Member()) {
436       ofst->SetProperties(kError, kError);
437       return;
438     }
439   }
440   // Algorithm works on the reverse of 'fst' : 'rfst', 'distance' is
441   // the distance to the final state in 'rfst', 'ofst' is built as the
442   // reverse of the tree of n-shortest path in 'rfst'.
443   VectorFst<ReverseArc> rfst;
444   Reverse(ifst, &rfst);
445   Weight d = Weight::Zero();
446   for (ArcIterator< VectorFst<ReverseArc> > aiter(rfst, 0);
447        !aiter.Done(); aiter.Next()) {
448     const ReverseArc &arc = aiter.Value();
449     StateId s = arc.nextstate - 1;
450     if (s < distance->size())
451       d = Plus(d, Times(arc.weight.Reverse(), (*distance)[s]));
452   }
453   distance->insert(distance->begin(), d);
454 
455   if (!opts.unique) {
456     NShortestPath(rfst, ofst, *distance, n, opts.delta,
457                   opts.weight_threshold, opts.state_threshold);
458   } else {
459     vector<Weight> ddistance;
460     DeterminizeFstOptions<ReverseArc> dopts(opts.delta);
461     DeterminizeFst<ReverseArc> dfst(rfst, *distance, &ddistance, dopts);
462     NShortestPath(dfst, ofst, ddistance, n, opts.delta,
463                   opts.weight_threshold, opts.state_threshold);
464   }
465   distance->erase(distance->begin());
466 }
467 
468 
469 // Shortest-path algorithm: simplified interface. See above for a
470 // version that allows finer control.
471 //
// 'ofst' contains the 'n'-shortest paths in 'ifst'. The queue
// discipline is automatically selected. When 'unique' == true, only
// paths with distinct input strings are returned.
475 //
476 // The n-shortest paths are the n-lowest weight paths w.r.t. the
477 // natural semiring order. The single path that can be read from the
478 // ith of at most n transitions leaving the initial state of 'ofst' is
479 // the ith best path.
480 //
// The weights need to be right distributive and have the path
// (kPath) property. They need to be left distributive as well for
// n > 1.
483 template<class Arc>
484 void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
485                   size_t n = 1, bool unique = false,
486                   bool first_path = false,
487                   typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
488                   typename Arc::StateId state_threshold = kNoStateId) {
489   vector<typename Arc::Weight> distance;
490   AnyArcFilter<Arc> arc_filter;
491   AutoQueue<typename Arc::StateId> state_queue(ifst, &distance, arc_filter);
492   ShortestPathOptions< Arc, AutoQueue<typename Arc::StateId>,
493       AnyArcFilter<Arc> > opts(&state_queue, arc_filter, n, unique, false,
494                                kDelta, first_path, weight_threshold,
495                                state_threshold);
496   ShortestPath(ifst, ofst, &distance, opts);
497 }
498 
499 }  // namespace fst
500 
501 #endif  // FST_LIB_SHORTEST_PATH_H__
502