• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // shortest-distance.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Author: allauzen@cs.nyu.edu (Cyril Allauzen)
16 //
17 // \file
18 // Functions and classes to find shortest distance in an FST.
19 
20 #ifndef FST_LIB_SHORTEST_DISTANCE_H__
21 #define FST_LIB_SHORTEST_DISTANCE_H__
22 
23 #include <deque>
24 
25 #include "fst/lib/arcfilter.h"
26 #include "fst/lib/cache.h"
27 #include "fst/lib/queue.h"
28 #include "fst/lib/reverse.h"
29 #include "fst/lib/test-properties.h"
30 
31 namespace fst {
32 
33 template <class Arc, class Queue, class ArcFilter>
34 struct ShortestDistanceOptions {
35   typedef typename Arc::StateId StateId;
36 
37   Queue *state_queue;    // Queue discipline used; owned by caller
38   ArcFilter arc_filter;  // Arc filter (e.g., limit to only epsilon graph)
39   StateId source;        // If kNoStateId, use the Fst's initial state
40   float delta;           // Determines the degree of convergence required
41 
42   ShortestDistanceOptions(Queue *q, ArcFilter filt, StateId src = kNoStateId,
43                           float d = kDelta)
state_queueShortestDistanceOptions44       : state_queue(q), arc_filter(filt), source(src), delta(d) {}
45 };
46 
47 
48 // Computation state of the shortest-distance algorithm. Reusable
49 // information is maintained across calls to member function
50 // ShortestDistance(source) when 'retain' is true for improved
51 // efficiency when calling multiple times from different source states
52 // (e.g., in epsilon removal). Vector 'distance' should not be
53 // modified by the user between these calls.
54 template<class Arc, class Queue, class ArcFilter>
55 class ShortestDistanceState {
56  public:
57   typedef typename Arc::StateId StateId;
58   typedef typename Arc::Weight Weight;
59 
ShortestDistanceState(const Fst<Arc> & fst,vector<Weight> * distance,const ShortestDistanceOptions<Arc,Queue,ArcFilter> & opts,bool retain)60   ShortestDistanceState(
61       const Fst<Arc> &fst,
62       vector<Weight> *distance,
63       const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts,
64       bool retain)
65       : fst_(fst.Copy()), distance_(distance), state_queue_(opts.state_queue),
66         arc_filter_(opts.arc_filter),
67         delta_(opts.delta), retain_(retain) {
68     distance_->clear();
69   }
70 
~ShortestDistanceState()71   ~ShortestDistanceState() {
72     delete fst_;
73   }
74 
75   void ShortestDistance(StateId source);
76 
77  private:
78   const Fst<Arc> *fst_;
79   vector<Weight> *distance_;
80   Queue *state_queue_;
81   ArcFilter arc_filter_;
82   float delta_;
83   bool retain_;                  // Retain and reuse information across calls
84 
85   vector<Weight> rdistance_;    // Relaxation distance.
86   vector<bool> enqueued_;       // Is state enqueued?
87   vector<StateId> sources_;     // Source state for ith state in 'distance_',
88                                 //  'rdistance_', and 'enqueued_' if retained.
89 };
90 
91 // Compute the shortest distance. If 'source' is kNoStateId, use
92 // the initial state of the Fst.
93 template <class Arc, class Queue, class ArcFilter>
ShortestDistance(StateId source)94 void ShortestDistanceState<Arc, Queue, ArcFilter>::ShortestDistance(
95     StateId source) {
96   if (fst_->Start() == kNoStateId)
97     return;
98 
99   if (!(Weight::Properties() & kRightSemiring))
100     LOG(FATAL) << "ShortestDistance: Weight needs to be right distributive: "
101                << Weight::Type();
102 
103   state_queue_->Clear();
104 
105   if (!retain_) {
106     distance_->clear();
107     rdistance_.clear();
108     enqueued_.clear();
109   }
110 
111   if (source == kNoStateId)
112     source = fst_->Start();
113 
114   while ((StateId)distance_->size() <= source) {
115     distance_->push_back(Weight::Zero());
116     rdistance_.push_back(Weight::Zero());
117     enqueued_.push_back(false);
118   }
119   if (retain_) {
120     while ((StateId)sources_.size() <= source)
121       sources_.push_back(kNoStateId);
122     sources_[source] = source;
123   }
124   (*distance_)[source] = Weight::One();
125   rdistance_[source] = Weight::One();
126   enqueued_[source] = true;
127 
128   state_queue_->Enqueue(source);
129 
130   while (!state_queue_->Empty()) {
131     StateId s = state_queue_->Head();
132     state_queue_->Dequeue();
133     while ((StateId)distance_->size() <= s) {
134       distance_->push_back(Weight::Zero());
135       rdistance_.push_back(Weight::Zero());
136       enqueued_.push_back(false);
137     }
138     enqueued_[s] = false;
139     Weight r = rdistance_[s];
140     rdistance_[s] = Weight::Zero();
141     for (ArcIterator< Fst<Arc> > aiter(*fst_, s);
142          !aiter.Done();
143          aiter.Next()) {
144       const Arc &arc = aiter.Value();
145       if (!arc_filter_(arc) || arc.weight == Weight::Zero())
146         continue;
147       while ((StateId)distance_->size() <= arc.nextstate) {
148         distance_->push_back(Weight::Zero());
149         rdistance_.push_back(Weight::Zero());
150         enqueued_.push_back(false);
151       }
152       if (retain_) {
153         while ((StateId)sources_.size() <= arc.nextstate)
154           sources_.push_back(kNoStateId);
155         if (sources_[arc.nextstate] != source) {
156           (*distance_)[arc.nextstate] = Weight::Zero();
157           rdistance_[arc.nextstate] = Weight::Zero();
158           enqueued_[arc.nextstate] = false;
159           sources_[arc.nextstate] = source;
160         }
161       }
162       Weight &nd = (*distance_)[arc.nextstate];
163       Weight &nr = rdistance_[arc.nextstate];
164       Weight w = Times(r, arc.weight);
165       if (!ApproxEqual(nd, Plus(nd, w), delta_)) {
166         nd = Plus(nd, w);
167         nr = Plus(nr, w);
168         if (!enqueued_[arc.nextstate]) {
169           state_queue_->Enqueue(arc.nextstate);
170           enqueued_[arc.nextstate] = true;
171         } else {
172           state_queue_->Update(arc.nextstate);
173         }
174       }
175     }
176   }
177 }
178 
179 
// Shortest-distance algorithm: this version allows fine control
// via the options argument. See below for a simpler interface.
//
// This computes the shortest distance from the 'opts.source' state to
// each visited state S and stores the value in the 'distance' vector.
// An unvisited state S has distance Zero(), which will be stored in
// the 'distance' vector if S is less than the maximum visited state.
// The state queue discipline, arc filter, and convergence delta are
// taken in the options argument.
//
// The weights must be right distributive and k-closed (i.e., 1 +
// x + x^2 + ... + x^(k + 1) = 1 + x + x^2 + ... + x^k).
//
// The algorithm is from Mohri, "Semiring Framework and Algorithms for
// Shortest-Distance Problems", Journal of Automata, Languages and
// Combinatorics 7(3):321-350, 2002. The complexity of the algorithm
// depends on the properties of the semiring and the queue discipline
// used. Refer to the paper for more details.
198 template<class Arc, class Queue, class ArcFilter>
ShortestDistance(const Fst<Arc> & fst,vector<typename Arc::Weight> * distance,const ShortestDistanceOptions<Arc,Queue,ArcFilter> & opts)199 void ShortestDistance(
200     const Fst<Arc> &fst,
201     vector<typename Arc::Weight> *distance,
202     const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) {
203 
204   ShortestDistanceState<Arc, Queue, ArcFilter>
205     sd_state(fst, distance, opts, false);
206   sd_state.ShortestDistance(opts.source);
207 }
208 
// Shortest-distance algorithm: simplified interface. See above for a
// version that allows finer control.
//
// If 'reverse' is false, this computes the shortest distance from the
// initial state to each state S and stores the value in the
// 'distance' vector. If 'reverse' is true, this computes the shortest
// distance from each state to the final states.  An unvisited state S
// has distance Zero(), which will be stored in the 'distance' vector
// if S is less than the maximum visited state.  The state queue
// discipline is automatically selected.
//
// The weights must be right (left) distributive if reverse is
// false (true) and k-closed (i.e., 1 + x + x^2 + ... + x^(k + 1) = 1 +
// x + x^2 + ... + x^k).
//
// The algorithm is from Mohri, "Semiring Framework and Algorithms for
// Shortest-Distance Problems", Journal of Automata, Languages and
// Combinatorics 7(3):321-350, 2002. The complexity of the algorithm
// depends on the properties of the semiring and the queue discipline
// used. Refer to the paper for more details.
229 template<class Arc>
230 void ShortestDistance(const Fst<Arc> &fst,
231                       vector<typename Arc::Weight> *distance,
232                       bool reverse = false) {
233   typedef typename Arc::StateId StateId;
234   typedef typename Arc::Weight Weight;
235 
236   if (!reverse) {
237     AnyArcFilter<Arc> arc_filter;
238     AutoQueue<StateId> state_queue(fst, distance, arc_filter);
239     ShortestDistanceOptions< Arc, AutoQueue<StateId>, AnyArcFilter<Arc> >
240       opts(&state_queue, arc_filter);
241     ShortestDistance(fst, distance, opts);
242   } else {
243     typedef ReverseArc<Arc> ReverseArc;
244     typedef typename ReverseArc::Weight ReverseWeight;
245     AnyArcFilter<ReverseArc> rarc_filter;
246     VectorFst<ReverseArc> rfst;
247     Reverse(fst, &rfst);
248     vector<ReverseWeight> rdistance;
249     AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter);
250     ShortestDistanceOptions< ReverseArc, AutoQueue<StateId>,
251       AnyArcFilter<ReverseArc> >
252       ropts(&state_queue, rarc_filter);
253     ShortestDistance(rfst, &rdistance, ropts);
254     distance->clear();
255     while (distance->size() < rdistance.size() - 1)
256       distance->push_back(rdistance[distance->size() + 1].Reverse());
257   }
258 }
259 
260 }  // namespace fst
261 
262 #endif  // FST_LIB_SHORTEST_DISTANCE_H__
263