1 // shortest-path.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Functions to find shortest paths in a PDT.
20
21 #ifndef FST_EXTENSIONS_PDT_SHORTEST_PATH_H__
22 #define FST_EXTENSIONS_PDT_SHORTEST_PATH_H__
23
24 #include <fst/shortest-path.h>
25 #include <fst/extensions/pdt/paren.h>
26 #include <fst/extensions/pdt/pdt.h>
27
28 #include <unordered_map>
29 using std::tr1::unordered_map;
30 using std::tr1::unordered_multimap;
31 #include <tr1/unordered_set>
32 using std::tr1::unordered_set;
33 using std::tr1::unordered_multiset;
34 #include <stack>
35 #include <vector>
36 using std::vector;
37
38 namespace fst {
39
// Options for PdtShortestPath. The 'Arc' and 'Queue' parameters are not used
// by the struct itself; they tie an options object to the
// PdtShortestPath<Arc, Queue> instantiation that consumes it.
template <class Arc, class Queue>
struct PdtShortestPathOptions {
  bool keep_parentheses;  // If false, paren labels on the output path are 0.
  bool path_gc;           // Enables GC of unneeded search data.

  PdtShortestPathOptions(bool kp = false, bool gc = true) {
    keep_parentheses = kp;
    path_gc = gc;
  }
};
48
49
50 // Class to store PDT shortest path results. Stores shortest path
51 // tree info 'Distance()', Parent(), and ArcParent() information keyed
52 // on two types:
53 // (1) By SearchState: This is a usual node in a shortest path tree but:
54 // (a) is w.r.t a PDT search state - a pair of a PDT state and
55 // a 'start' state, which is either the PDT start state or
56 // the destination state of an open parenthesis.
57 // (b) the Distance() is from this 'start' state to the search state.
58 // (c) Parent().state is kNoLabel for the 'start' state.
59 //
60 // (2) By ParenSpec: This connects shortest path trees depending on the
61 // the parenthesis taken. Given the parenthesis spec:
62 // (a) the Distance() is from the Parent() 'start' state to the
63 // parenthesis destination state.
64 // (b) the ArcParent() is the parenthesis arc.
65 template <class Arc>
66 class PdtShortestPathData {
67 public:
68 static const uint8 kFinal;
69
70 typedef typename Arc::StateId StateId;
71 typedef typename Arc::Weight Weight;
72 typedef typename Arc::Label Label;
73
74 struct SearchState {
SearchStateSearchState75 SearchState() : state(kNoStateId), start(kNoStateId) {}
76
SearchStateSearchState77 SearchState(StateId s, StateId t) : state(s), start(t) {}
78
79 bool operator==(const SearchState &s) const {
80 if (&s == this)
81 return true;
82 return s.state == this->state && s.start == this->start;
83 }
84
85 StateId state; // PDT state
86 StateId start; // PDT paren 'source' state
87 };
88
89
90 // Specifies paren id, source and dest 'start' states of a paren.
91 // These are the 'start' states of the respective sub-graphs.
92 struct ParenSpec {
ParenSpecParenSpec93 ParenSpec()
94 : paren_id(kNoLabel), src_start(kNoStateId), dest_start(kNoStateId) {}
95
ParenSpecParenSpec96 ParenSpec(Label id, StateId s, StateId d)
97 : paren_id(id), src_start(s), dest_start(d) {}
98
99 Label paren_id; // Id of parenthesis
100 StateId src_start; // sub-graph 'start' state for paren source.
101 StateId dest_start; // sub-graph 'start' state for paren dest.
102
103 bool operator==(const ParenSpec &x) const {
104 if (&x == this)
105 return true;
106 return x.paren_id == this->paren_id &&
107 x.src_start == this->src_start &&
108 x.dest_start == this->dest_start;
109 }
110 };
111
112 struct SearchData {
SearchDataSearchData113 SearchData() : distance(Weight::Zero()),
114 parent(kNoStateId, kNoStateId),
115 paren_id(kNoLabel),
116 flags(0) {}
117
118 Weight distance; // Distance to this state from PDT 'start' state
119 SearchState parent; // Parent state in shortest path tree
120 int16 paren_id; // If parent arc has paren, paren ID, o.w. kNoLabel
121 uint8 flags; // First byte reserved for PdtShortestPathData use
122 };
123
PdtShortestPathData(bool gc)124 PdtShortestPathData(bool gc)
125 : state_(kNoStateId, kNoStateId),
126 paren_(kNoLabel, kNoStateId, kNoStateId),
127 gc_(gc),
128 nstates_(0),
129 ngc_(0),
130 finished_(false) {}
131
~PdtShortestPathData()132 ~PdtShortestPathData() {
133 VLOG(1) << "opm size: " << paren_map_.size();
134 VLOG(1) << "# of search states: " << nstates_;
135 if (gc_)
136 VLOG(1) << "# of GC'd search states: " << ngc_;
137 }
138
Clear()139 void Clear() {
140 search_map_.clear();
141 search_multimap_.clear();
142 paren_map_.clear();
143 state_ = SearchState(kNoStateId, kNoStateId);
144 nstates_ = 0;
145 ngc_ = 0;
146 }
147
Distance(SearchState s)148 Weight Distance(SearchState s) const {
149 SearchData *data = GetSearchData(s);
150 return data->distance;
151 }
152
Distance(const ParenSpec & paren)153 Weight Distance(const ParenSpec &paren) const {
154 SearchData *data = GetSearchData(paren);
155 return data->distance;
156 }
157
Parent(SearchState s)158 SearchState Parent(SearchState s) const {
159 SearchData *data = GetSearchData(s);
160 return data->parent;
161 }
162
Parent(const ParenSpec & paren)163 SearchState Parent(const ParenSpec &paren) const {
164 SearchData *data = GetSearchData(paren);
165 return data->parent;
166 }
167
ParenId(SearchState s)168 Label ParenId(SearchState s) const {
169 SearchData *data = GetSearchData(s);
170 return data->paren_id;
171 }
172
Flags(SearchState s)173 uint8 Flags(SearchState s) const {
174 SearchData *data = GetSearchData(s);
175 return data->flags;
176 }
177
SetDistance(SearchState s,Weight w)178 void SetDistance(SearchState s, Weight w) {
179 SearchData *data = GetSearchData(s);
180 data->distance = w;
181 }
182
SetDistance(const ParenSpec & paren,Weight w)183 void SetDistance(const ParenSpec &paren, Weight w) {
184 SearchData *data = GetSearchData(paren);
185 data->distance = w;
186 }
187
SetParent(SearchState s,SearchState p)188 void SetParent(SearchState s, SearchState p) {
189 SearchData *data = GetSearchData(s);
190 data->parent = p;
191 }
192
SetParent(const ParenSpec & paren,SearchState p)193 void SetParent(const ParenSpec &paren, SearchState p) {
194 SearchData *data = GetSearchData(paren);
195 data->parent = p;
196 }
197
SetParenId(SearchState s,Label p)198 void SetParenId(SearchState s, Label p) {
199 if (p >= 32768)
200 FSTERROR() << "PdtShortestPathData: Paren ID does not fits in an int16";
201 SearchData *data = GetSearchData(s);
202 data->paren_id = p;
203 }
204
SetFlags(SearchState s,uint8 f,uint8 mask)205 void SetFlags(SearchState s, uint8 f, uint8 mask) {
206 SearchData *data = GetSearchData(s);
207 data->flags &= ~mask;
208 data->flags |= f & mask;
209 }
210
211 void GC(StateId s);
212
Finish()213 void Finish() { finished_ = true; }
214
215 private:
216 static const Arc kNoArc;
217 static const size_t kPrime0;
218 static const size_t kPrime1;
219 static const uint8 kInited;
220 static const uint8 kMarked;
221
222 // Hash for search state
223 struct SearchStateHash {
operatorSearchStateHash224 size_t operator()(const SearchState &s) const {
225 return s.state + s.start * kPrime0;
226 }
227 };
228
229 // Hash for paren map
230 struct ParenHash {
operatorParenHash231 size_t operator()(const ParenSpec &paren) const {
232 return paren.paren_id + paren.src_start * kPrime0 +
233 paren.dest_start * kPrime1;
234 }
235 };
236
237 typedef unordered_map<SearchState, SearchData, SearchStateHash> SearchMap;
238
239 typedef unordered_multimap<StateId, StateId> SearchMultimap;
240
241 // Hash map from paren spec to open paren data
242 typedef unordered_map<ParenSpec, SearchData, ParenHash> ParenMap;
243
GetSearchData(SearchState s)244 SearchData *GetSearchData(SearchState s) const {
245 if (s == state_)
246 return state_data_;
247 if (finished_) {
248 typename SearchMap::iterator it = search_map_.find(s);
249 if (it == search_map_.end())
250 return &null_search_data_;
251 state_ = s;
252 return state_data_ = &(it->second);
253 } else {
254 state_ = s;
255 state_data_ = &search_map_[s];
256 if (!(state_data_->flags & kInited)) {
257 ++nstates_;
258 if (gc_)
259 search_multimap_.insert(make_pair(s.start, s.state));
260 state_data_->flags = kInited;
261 }
262 return state_data_;
263 }
264 }
265
GetSearchData(ParenSpec paren)266 SearchData *GetSearchData(ParenSpec paren) const {
267 if (paren == paren_)
268 return paren_data_;
269 if (finished_) {
270 typename ParenMap::iterator it = paren_map_.find(paren);
271 if (it == paren_map_.end())
272 return &null_search_data_;
273 paren_ = paren;
274 return state_data_ = &(it->second);
275 } else {
276 paren_ = paren;
277 return paren_data_ = &paren_map_[paren];
278 }
279 }
280
281 mutable SearchMap search_map_; // Maps from search state to data
282 mutable SearchMultimap search_multimap_; // Maps from 'start' to subgraph
283 mutable ParenMap paren_map_; // Maps paren spec to search data
284 mutable SearchState state_; // Last state accessed
285 mutable SearchData *state_data_; // Last state data accessed
286 mutable ParenSpec paren_; // Last paren spec accessed
287 mutable SearchData *paren_data_; // Last paren data accessed
288 bool gc_; // Allow GC?
289 mutable size_t nstates_; // Total number of search states
290 size_t ngc_; // Number of GC'd search states
291 mutable SearchData null_search_data_; // Null search data
292 bool finished_; // Read-only access when true
293
294 DISALLOW_COPY_AND_ASSIGN(PdtShortestPathData);
295 };
296
297 // Deletes inaccessible search data from a given 'start' (open paren dest)
298 // state. Assumes 'final' (close paren source or PDT final) states have
299 // been flagged 'kFinal'.
300 template<class Arc>
GC(StateId start)301 void PdtShortestPathData<Arc>::GC(StateId start) {
302 if (!gc_)
303 return;
304 vector<StateId> final;
305 for (typename SearchMultimap::iterator mmit = search_multimap_.find(start);
306 mmit != search_multimap_.end() && mmit->first == start;
307 ++mmit) {
308 SearchState s(mmit->second, start);
309 const SearchData &data = search_map_[s];
310 if (data.flags & kFinal)
311 final.push_back(s.state);
312 }
313
314 // Mark phase
315 for (size_t i = 0; i < final.size(); ++i) {
316 SearchState s(final[i], start);
317 while (s.state != kNoLabel) {
318 SearchData *sdata = &search_map_[s];
319 if (sdata->flags & kMarked)
320 break;
321 sdata->flags |= kMarked;
322 SearchState p = sdata->parent;
323 if (p.start != start && p.start != kNoLabel) { // entering sub-subgraph
324 ParenSpec paren(sdata->paren_id, s.start, p.start);
325 SearchData *pdata = &paren_map_[paren];
326 s = pdata->parent;
327 } else {
328 s = p;
329 }
330 }
331 }
332
333 // Sweep phase
334 typename SearchMultimap::iterator mmit = search_multimap_.find(start);
335 while (mmit != search_multimap_.end() && mmit->first == start) {
336 SearchState s(mmit->second, start);
337 typename SearchMap::iterator mit = search_map_.find(s);
338 const SearchData &data = mit->second;
339 if (!(data.flags & kMarked)) {
340 search_map_.erase(mit);
341 ++ngc_;
342 }
343 search_multimap_.erase(mmit++);
344 }
345 }
346
347 template<class Arc> const Arc PdtShortestPathData<Arc>::kNoArc
348 = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId);
349
350 template<class Arc> const size_t PdtShortestPathData<Arc>::kPrime0 = 7853;
351
352 template<class Arc> const size_t PdtShortestPathData<Arc>::kPrime1 = 7867;
353
354 template<class Arc> const uint8 PdtShortestPathData<Arc>::kInited = 0x01;
355
356 template<class Arc> const uint8 PdtShortestPathData<Arc>::kFinal = 0x02;
357
358 template<class Arc> const uint8 PdtShortestPathData<Arc>::kMarked = 0x04;
359
360
361 // This computes the single source shortest (balanced) path (SSSP)
362 // through a weighted PDT that has a bounded stack (i.e. is expandable
363 // as an FST). It is a generalization of the classic SSSP graph
364 // algorithm that removes a state s from a queue (defined by a
365 // user-provided queue type) and relaxes the destination states of
366 // transitions leaving s. In this PDT version, states that have
367 // entering open parentheses are treated as source states for a
368 // sub-graph SSSP problem with the shortest path up to the open
369 // parenthesis being first saved. When a close parenthesis is then
370 // encountered any balancing open parenthesis is examined for this
371 // saved information and multiplied back. In this way, each sub-graph
372 // is entered only once rather than repeatedly. If every state in the
373 // input PDT has the property that there is a unique 'start' state for
374 // it with entering open parentheses, then this algorithm is quite
375 // straight-forward. In general, this will not be the case, so the
376 // algorithm (implicitly) creates a new graph where each state is a
377 // pair of an original state and a possible parenthesis 'start' state
378 // for that state.
379 template<class Arc, class Queue>
380 class PdtShortestPath {
381 public:
382 typedef typename Arc::StateId StateId;
383 typedef typename Arc::Weight Weight;
384 typedef typename Arc::Label Label;
385
386 typedef PdtShortestPathData<Arc> SpData;
387 typedef typename SpData::SearchState SearchState;
388 typedef typename SpData::ParenSpec ParenSpec;
389
390 typedef typename PdtParenReachable<Arc>::SetIterator StateSetIterator;
391 typedef typename PdtBalanceData<Arc>::SetIterator CloseSourceIterator;
392
PdtShortestPath(const Fst<Arc> & ifst,const vector<pair<Label,Label>> & parens,const PdtShortestPathOptions<Arc,Queue> & opts)393 PdtShortestPath(const Fst<Arc> &ifst,
394 const vector<pair<Label, Label> > &parens,
395 const PdtShortestPathOptions<Arc, Queue> &opts)
396 : kFinal(SpData::kFinal),
397 ifst_(ifst.Copy()),
398 parens_(parens),
399 keep_parens_(opts.keep_parentheses),
400 start_(ifst.Start()),
401 sp_data_(opts.path_gc),
402 error_(false) {
403
404 if ((Weight::Properties() & (kPath | kRightSemiring))
405 != (kPath | kRightSemiring)) {
406 FSTERROR() << "SingleShortestPath: Weight needs to have the path"
407 << " property and be right distributive: " << Weight::Type();
408 error_ = true;
409 }
410
411 for (Label i = 0; i < parens.size(); ++i) {
412 const pair<Label, Label> &p = parens[i];
413 paren_id_map_[p.first] = i;
414 paren_id_map_[p.second] = i;
415 }
416 };
417
~PdtShortestPath()418 ~PdtShortestPath() {
419 VLOG(1) << "# of input states: " << CountStates(*ifst_);
420 VLOG(1) << "# of enqueued: " << nenqueued_;
421 VLOG(1) << "cpmm size: " << close_paren_multimap_.size();
422 delete ifst_;
423 }
424
ShortestPath(MutableFst<Arc> * ofst)425 void ShortestPath(MutableFst<Arc> *ofst) {
426 Init(ofst);
427 GetDistance(start_);
428 GetPath();
429 sp_data_.Finish();
430 if (error_) ofst->SetProperties(kError, kError);
431 }
432
GetShortestPathData()433 const PdtShortestPathData<Arc> &GetShortestPathData() const {
434 return sp_data_;
435 }
436
GetBalanceData()437 PdtBalanceData<Arc> *GetBalanceData() { return &balance_data_; }
438
439 private:
440 static const Arc kNoArc;
441 static const uint8 kEnqueued;
442 static const uint8 kExpanded;
443 const uint8 kFinal;
444
445 public:
446 // Hash multimap from close paren label to an paren arc.
447 typedef unordered_multimap<ParenState<Arc>, Arc,
448 typename ParenState<Arc>::Hash> CloseParenMultimap;
449
GetCloseParenMultimap()450 const CloseParenMultimap &GetCloseParenMultimap() const {
451 return close_paren_multimap_;
452 }
453
454 private:
455 void Init(MutableFst<Arc> *ofst);
456 void GetDistance(StateId start);
457 void ProcFinal(SearchState s);
458 void ProcArcs(SearchState s);
459 void ProcOpenParen(Label paren_id, SearchState s, Arc arc, Weight w);
460 void ProcCloseParen(Label paren_id, SearchState s, const Arc &arc, Weight w);
461 void ProcNonParen(SearchState s, const Arc &arc, Weight w);
462 void Relax(SearchState s, SearchState t, Arc arc, Weight w, Label paren_id);
463 void Enqueue(SearchState d);
464 void GetPath();
465 Arc GetPathArc(SearchState s, SearchState p, Label paren_id, bool open);
466
467 Fst<Arc> *ifst_;
468 MutableFst<Arc> *ofst_;
469 const vector<pair<Label, Label> > &parens_;
470 bool keep_parens_;
471 Queue *state_queue_; // current state queue
472 StateId start_;
473 Weight f_distance_;
474 SearchState f_parent_;
475 SpData sp_data_;
476 unordered_map<Label, Label> paren_id_map_;
477 CloseParenMultimap close_paren_multimap_;
478 PdtBalanceData<Arc> balance_data_;
479 ssize_t nenqueued_;
480 bool error_;
481
482 DISALLOW_COPY_AND_ASSIGN(PdtShortestPath);
483 };
484
485 template<class Arc, class Queue>
Init(MutableFst<Arc> * ofst)486 void PdtShortestPath<Arc, Queue>::Init(MutableFst<Arc> *ofst) {
487 ofst_ = ofst;
488 ofst->DeleteStates();
489 ofst->SetInputSymbols(ifst_->InputSymbols());
490 ofst->SetOutputSymbols(ifst_->OutputSymbols());
491
492 if (ifst_->Start() == kNoStateId)
493 return;
494
495 f_distance_ = Weight::Zero();
496 f_parent_ = SearchState(kNoStateId, kNoStateId);
497
498 sp_data_.Clear();
499 close_paren_multimap_.clear();
500 balance_data_.Clear();
501 nenqueued_ = 0;
502
503 // Find open parens per destination state and close parens per source state.
504 for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) {
505 StateId s = siter.Value();
506 for (ArcIterator<Fst<Arc> > aiter(*ifst_, s);
507 !aiter.Done(); aiter.Next()) {
508 const Arc &arc = aiter.Value();
509 typename unordered_map<Label, Label>::const_iterator pit
510 = paren_id_map_.find(arc.ilabel);
511 if (pit != paren_id_map_.end()) { // Is a paren?
512 Label paren_id = pit->second;
513 if (arc.ilabel == parens_[paren_id].first) { // Open paren
514 balance_data_.OpenInsert(paren_id, arc.nextstate);
515 } else { // Close paren
516 ParenState<Arc> paren_state(paren_id, s);
517 close_paren_multimap_.insert(make_pair(paren_state, arc));
518 }
519 }
520 }
521 }
522 }
523
524 // Computes the shortest distance stored in a recursive way. Each
525 // sub-graph (i.e. different paren 'start' state) begins with weight One().
526 template<class Arc, class Queue>
GetDistance(StateId start)527 void PdtShortestPath<Arc, Queue>::GetDistance(StateId start) {
528 if (start == kNoStateId)
529 return;
530
531 Queue state_queue;
532 state_queue_ = &state_queue;
533 SearchState q(start, start);
534 Enqueue(q);
535 sp_data_.SetDistance(q, Weight::One());
536
537 while (!state_queue_->Empty()) {
538 StateId state = state_queue_->Head();
539 state_queue_->Dequeue();
540 SearchState s(state, start);
541 sp_data_.SetFlags(s, 0, kEnqueued);
542 ProcFinal(s);
543 ProcArcs(s);
544 sp_data_.SetFlags(s, kExpanded, kExpanded);
545 }
546 balance_data_.FinishInsert(start);
547 sp_data_.GC(start);
548 }
549
550 // Updates best complete path.
551 template<class Arc, class Queue>
ProcFinal(SearchState s)552 void PdtShortestPath<Arc, Queue>::ProcFinal(SearchState s) {
553 if (ifst_->Final(s.state) != Weight::Zero() && s.start == start_) {
554 Weight w = Times(sp_data_.Distance(s),
555 ifst_->Final(s.state));
556 if (f_distance_ != Plus(f_distance_, w)) {
557 if (f_parent_.state != kNoStateId)
558 sp_data_.SetFlags(f_parent_, 0, kFinal);
559 sp_data_.SetFlags(s, kFinal, kFinal);
560
561 f_distance_ = Plus(f_distance_, w);
562 f_parent_ = s;
563 }
564 }
565 }
566
567 // Processes all arcs leaving the state s.
568 template<class Arc, class Queue>
ProcArcs(SearchState s)569 void PdtShortestPath<Arc, Queue>::ProcArcs(SearchState s) {
570 for (ArcIterator< Fst<Arc> > aiter(*ifst_, s.state);
571 !aiter.Done();
572 aiter.Next()) {
573 Arc arc = aiter.Value();
574 Weight w = Times(sp_data_.Distance(s), arc.weight);
575
576 typename unordered_map<Label, Label>::const_iterator pit
577 = paren_id_map_.find(arc.ilabel);
578 if (pit != paren_id_map_.end()) { // Is a paren?
579 Label paren_id = pit->second;
580 if (arc.ilabel == parens_[paren_id].first)
581 ProcOpenParen(paren_id, s, arc, w);
582 else
583 ProcCloseParen(paren_id, s, arc, w);
584 } else {
585 ProcNonParen(s, arc, w);
586 }
587 }
588 }
589
590 // Saves the shortest path info for reaching this parenthesis
591 // and starts a new SSSP in the sub-graph pointed to by the parenthesis
592 // if previously unvisited. Otherwise it finds any previously encountered
593 // closing parentheses and relaxes them using the recursively stored
594 // shortest distance to them.
595 template<class Arc, class Queue> inline
ProcOpenParen(Label paren_id,SearchState s,Arc arc,Weight w)596 void PdtShortestPath<Arc, Queue>::ProcOpenParen(
597 Label paren_id, SearchState s, Arc arc, Weight w) {
598
599 SearchState d(arc.nextstate, arc.nextstate);
600 ParenSpec paren(paren_id, s.start, d.start);
601 Weight pdist = sp_data_.Distance(paren);
602 if (pdist != Plus(pdist, w)) {
603 sp_data_.SetDistance(paren, w);
604 sp_data_.SetParent(paren, s);
605 Weight dist = sp_data_.Distance(d);
606 if (dist == Weight::Zero()) {
607 Queue *state_queue = state_queue_;
608 GetDistance(d.start);
609 state_queue_ = state_queue;
610 }
611 for (CloseSourceIterator set_iter =
612 balance_data_.Find(paren_id, arc.nextstate);
613 !set_iter.Done(); set_iter.Next()) {
614 SearchState cpstate(set_iter.Element(), d.start);
615 ParenState<Arc> paren_state(paren_id, cpstate.state);
616 for (typename CloseParenMultimap::const_iterator cpit =
617 close_paren_multimap_.find(paren_state);
618 cpit != close_paren_multimap_.end() && paren_state == cpit->first;
619 ++cpit) {
620 const Arc &cparc = cpit->second;
621 Weight cpw = Times(w, Times(sp_data_.Distance(cpstate),
622 cparc.weight));
623 Relax(cpstate, s, cparc, cpw, paren_id);
624 }
625 }
626 }
627 }
628
629 // Saves the correspondence between each closing parenthesis and its
630 // balancing open parenthesis info. Relaxes any close parenthesis
631 // destination state that has a balancing previously encountered open
632 // parenthesis.
633 template<class Arc, class Queue> inline
ProcCloseParen(Label paren_id,SearchState s,const Arc & arc,Weight w)634 void PdtShortestPath<Arc, Queue>::ProcCloseParen(
635 Label paren_id, SearchState s, const Arc &arc, Weight w) {
636 ParenState<Arc> paren_state(paren_id, s.start);
637 if (!(sp_data_.Flags(s) & kExpanded)) {
638 balance_data_.CloseInsert(paren_id, s.start, s.state);
639 sp_data_.SetFlags(s, kFinal, kFinal);
640 }
641 }
642
643 // For non-parentheses, classical relaxation.
644 template<class Arc, class Queue> inline
ProcNonParen(SearchState s,const Arc & arc,Weight w)645 void PdtShortestPath<Arc, Queue>::ProcNonParen(
646 SearchState s, const Arc &arc, Weight w) {
647 Relax(s, s, arc, w, kNoLabel);
648 }
649
650 // Classical relaxation on the search graph for 'arc' from state 's'.
651 // State 't' is in the same sub-graph as the nextstate should be (i.e.
652 // has the same paren 'start'.
653 template<class Arc, class Queue> inline
Relax(SearchState s,SearchState t,Arc arc,Weight w,Label paren_id)654 void PdtShortestPath<Arc, Queue>::Relax(
655 SearchState s, SearchState t, Arc arc, Weight w, Label paren_id) {
656 SearchState d(arc.nextstate, t.start);
657 Weight dist = sp_data_.Distance(d);
658 if (dist != Plus(dist, w)) {
659 sp_data_.SetParent(d, s);
660 sp_data_.SetParenId(d, paren_id);
661 sp_data_.SetDistance(d, Plus(dist, w));
662 Enqueue(d);
663 }
664 }
665
666 template<class Arc, class Queue> inline
Enqueue(SearchState s)667 void PdtShortestPath<Arc, Queue>::Enqueue(SearchState s) {
668 if (!(sp_data_.Flags(s) & kEnqueued)) {
669 state_queue_->Enqueue(s.state);
670 sp_data_.SetFlags(s, kEnqueued, kEnqueued);
671 ++nenqueued_;
672 } else {
673 state_queue_->Update(s.state);
674 }
675 }
676
// Follows parent pointers to find the shortest path. Uses a stack
// since the shortest distance is stored recursively.
//
// The output path is built backwards from the best final search state
// 'f_parent_'. Each iteration adds one output state 's_p'. After the first
// iteration it also adds the arc from 's_p' to the previously added state
// 'd_p'.
// - Crossing a close paren (walking backwards) pushes the matching ParenSpec.
// - Reaching a search state whose Parent().state is kNoStateId (a sub-graph
//   'start') resumes from Parent() of the top ParenSpec, which is the source
//   of the matching open paren.
// - That ParenSpec is popped when the open paren arc is emitted.
// The last state added becomes the output start state.
template<class Arc, class Queue>
void PdtShortestPath<Arc, Queue>::GetPath() {
  SearchState s = f_parent_, d = SearchState(kNoStateId, kNoStateId);
  StateId s_p = kNoStateId, d_p = kNoStateId;
  Arc arc(kNoArc);
  Label paren_id = kNoLabel;
  stack<ParenSpec> paren_stack;
  while (s.state != kNoStateId) {
    d_p = s_p;
    s_p = ofst_->AddState();
    if (d.state == kNoStateId) {
      // First output state: it carries the final weight of the best state.
      ofst_->SetFinal(s_p, ifst_->Final(f_parent_.state));
    } else {
      if (paren_id != kNoLabel) {                     // paren?
        if (arc.ilabel == parens_[paren_id].first) {  // open paren
          paren_stack.pop();
        } else {                                      // close paren
          ParenSpec paren(paren_id, d.start, s.start);
          paren_stack.push(paren);
        }
        if (!keep_parens_)
          arc.ilabel = arc.olabel = 0;
      }
      arc.nextstate = d_p;
      ofst_->AddArc(s_p, arc);
    }
    d = s;
    s = sp_data_.Parent(d);
    paren_id = sp_data_.ParenId(d);
    if (s.state != kNoStateId) {
      // Parent in the search tree: non-paren or close paren arc.
      arc = GetPathArc(s, d, paren_id, false);
    } else if (!paren_stack.empty()) {
      // At a sub-graph 'start': continue before the matching open paren.
      ParenSpec paren = paren_stack.top();
      s = sp_data_.Parent(paren);
      paren_id = paren.paren_id;
      arc = GetPathArc(s, d, paren_id, true);
    }
  }
  ofst_->SetStart(s_p);
  ofst_->SetProperties(
      ShortestPathProperties(ofst_->Properties(kFstProperties, false)),
      kFstProperties);
}
722
723
724 // Finds transition with least weight between two states with label matching
725 // paren_id and open/close paren type or a non-paren if kNoLabel.
726 template<class Arc, class Queue>
GetPathArc(SearchState s,SearchState d,Label paren_id,bool open_paren)727 Arc PdtShortestPath<Arc, Queue>::GetPathArc(
728 SearchState s, SearchState d, Label paren_id, bool open_paren) {
729 Arc path_arc = kNoArc;
730 for (ArcIterator< Fst<Arc> > aiter(*ifst_, s.state);
731 !aiter.Done();
732 aiter.Next()) {
733 const Arc &arc = aiter.Value();
734 if (arc.nextstate != d.state)
735 continue;
736 Label arc_paren_id = kNoLabel;
737 typename unordered_map<Label, Label>::const_iterator pit
738 = paren_id_map_.find(arc.ilabel);
739 if (pit != paren_id_map_.end()) {
740 arc_paren_id = pit->second;
741 bool arc_open_paren = arc.ilabel == parens_[arc_paren_id].first;
742 if (arc_open_paren != open_paren)
743 continue;
744 }
745 if (arc_paren_id != paren_id)
746 continue;
747 if (arc.weight == Plus(arc.weight, path_arc.weight))
748 path_arc = arc;
749 }
750 if (path_arc.nextstate == kNoStateId) {
751 FSTERROR() << "PdtShortestPath::GetPathArc failed to find arc";
752 error_ = true;
753 }
754 return path_arc;
755 }
756
757 template<class Arc, class Queue>
758 const Arc PdtShortestPath<Arc, Queue>::kNoArc
759 = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId);
760
761 template<class Arc, class Queue>
762 const uint8 PdtShortestPath<Arc, Queue>::kEnqueued = 0x10;
763
764 template<class Arc, class Queue>
765 const uint8 PdtShortestPath<Arc, Queue>::kExpanded = 0x20;
766
767 template<class Arc, class Queue>
ShortestPath(const Fst<Arc> & ifst,const vector<pair<typename Arc::Label,typename Arc::Label>> & parens,MutableFst<Arc> * ofst,const PdtShortestPathOptions<Arc,Queue> & opts)768 void ShortestPath(const Fst<Arc> &ifst,
769 const vector<pair<typename Arc::Label,
770 typename Arc::Label> > &parens,
771 MutableFst<Arc> *ofst,
772 const PdtShortestPathOptions<Arc, Queue> &opts) {
773 PdtShortestPath<Arc, Queue> psp(ifst, parens, opts);
774 psp.ShortestPath(ofst);
775 }
776
777 template<class Arc>
ShortestPath(const Fst<Arc> & ifst,const vector<pair<typename Arc::Label,typename Arc::Label>> & parens,MutableFst<Arc> * ofst)778 void ShortestPath(const Fst<Arc> &ifst,
779 const vector<pair<typename Arc::Label,
780 typename Arc::Label> > &parens,
781 MutableFst<Arc> *ofst) {
782 typedef FifoQueue<typename Arc::StateId> Queue;
783 PdtShortestPathOptions<Arc, Queue> opts;
784 PdtShortestPath<Arc, Queue> psp(ifst, parens, opts);
785 psp.ShortestPath(ofst);
786 }
787
788 } // namespace fst
789
790 #endif // FST_EXTENSIONS_PDT_SHORTEST_PATH_H__
791