• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // paren.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // Common classes for PDT parentheses
19 
20 // \file
21 
22 #ifndef FST_EXTENSIONS_PDT_PAREN_H_
23 #define FST_EXTENSIONS_PDT_PAREN_H_
24 
25 #include <algorithm>
26 #include <unordered_map>
27 using std::tr1::unordered_map;
28 using std::tr1::unordered_multimap;
29 #include <tr1/unordered_set>
30 using std::tr1::unordered_set;
31 using std::tr1::unordered_multiset;
32 #include <set>
33 
34 #include <fst/extensions/pdt/pdt.h>
35 #include <fst/extensions/pdt/collection.h>
36 #include <fst/fst.h>
37 #include <fst/dfs-visit.h>
38 
39 
40 namespace fst {
41 
42 //
43 // ParenState: Pair of an open (close) parenthesis and
44 // its destination (source) state.
45 //
46 
47 template <class A>
48 class ParenState {
49  public:
50   typedef typename A::Label Label;
51   typedef typename A::StateId StateId;
52 
53   struct Hash {
operatorHash54     size_t operator()(const ParenState<A> &p) const {
55       return p.paren_id + p.state_id * kPrime;
56     }
57   };
58 
59   Label paren_id;     // ID of open (close) paren
60   StateId state_id;   // destination (source) state of open (close) paren
61 
ParenState()62   ParenState() : paren_id(kNoLabel), state_id(kNoStateId) {}
63 
ParenState(Label p,StateId s)64   ParenState(Label p, StateId s) : paren_id(p), state_id(s) {}
65 
66   bool operator==(const ParenState<A> &p) const {
67     if (&p == this)
68       return true;
69     return p.paren_id == this->paren_id && p.state_id == this->state_id;
70   }
71 
72   bool operator!=(const ParenState<A> &p) const { return !(p == *this); }
73 
74   bool operator<(const ParenState<A> &p) const {
75     return paren_id < this->paren.id ||
76         (p.paren_id == this->paren.id && p.state_id < this->state_id);
77   }
78 
79  private:
80   static const size_t kPrime;
81 };
82 
83 template <class A>
84 const size_t ParenState<A>::kPrime = 7853;
85 
86 
// Wraps an STL (multi)map and a starting position in an FST-style
// iterator (Done/Value/Next/Reset). Iteration covers the run of
// consecutive entries whose key equals the key at the starting position;
// if the starting position is the map's end(), the iterator is
// immediately Done(). The map is held by reference, not copied.
template <class M>
class MapIterator {
 public:
  typedef typename M::const_iterator StlIterator;
  typedef typename M::value_type PairType;
  typedef typename PairType::second_type ValueType;

  MapIterator(const M &m, StlIterator iter)
      : container_(m), first_(iter), current_(iter) {}

  // True once the cursor has run off the map or onto a different key.
  bool Done() const {
    if (current_ == container_.end())
      return true;
    return current_->first != first_->first;
  }

  // Returns a copy of the mapped value at the cursor.
  ValueType Value() const { return (*current_).second; }

  // Advances the cursor by one entry.
  void Next() { ++current_; }

  // Moves the cursor back to the starting position.
  void Reset() { current_ = first_; }

 private:
  const M &container_;     // underlying map (not owned)
  StlIterator first_;      // starting position; fixes the key of interest
  StlIterator current_;    // cursor
};
111 
112 //
113 // PdtParenReachable: Provides various parenthesis reachability information
114 // on a PDT.
115 //
116 
117 template <class A>
118 class PdtParenReachable {
119  public:
120   typedef typename A::StateId StateId;
121   typedef typename A::Label Label;
122  public:
123   // Maps from state ID to reachable paren IDs from (to) that state.
124   typedef unordered_multimap<StateId, Label> ParenMultiMap;
125 
126   // Maps from paren ID and state ID to reachable state set ID
127   typedef unordered_map<ParenState<A>, ssize_t,
128                    typename ParenState<A>::Hash> StateSetMap;
129 
130   // Maps from paren ID and state ID to arcs exiting that state with that
131   // Label.
132   typedef unordered_multimap<ParenState<A>, A,
133                         typename ParenState<A>::Hash> ParenArcMultiMap;
134 
135   typedef MapIterator<ParenMultiMap> ParenIterator;
136 
137   typedef MapIterator<ParenArcMultiMap> ParenArcIterator;
138 
139   typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator;
140 
141   // Computes close (open) parenthesis reachabilty information for
142   // a PDT with bounded stack.
PdtParenReachable(const Fst<A> & fst,const vector<pair<Label,Label>> & parens,bool close)143   PdtParenReachable(const Fst<A> &fst,
144                     const vector<pair<Label, Label> > &parens, bool close)
145       : fst_(fst),
146         parens_(parens),
147         close_(close) {
148     for (Label i = 0; i < parens.size(); ++i) {
149       const pair<Label, Label>  &p = parens[i];
150       paren_id_map_[p.first] = i;
151       paren_id_map_[p.second] = i;
152     }
153 
154     if (close_) {
155       StateId start = fst.Start();
156       if (start == kNoStateId)
157         return;
158       DFSearch(start, start);
159     } else {
160       FSTERROR() << "PdtParenReachable: open paren info not implemented";
161     }
162   }
163 
164   // Given a state ID, returns an iterator over paren IDs
165   // for close (open) parens reachable from that state along balanced
166   // paths.
FindParens(StateId s)167   ParenIterator FindParens(StateId s) const {
168     return ParenIterator(paren_multimap_, paren_multimap_.find(s));
169   }
170 
171   // Given a paren ID and a state ID s, returns an iterator over
172   // states that can be reached along balanced paths from (to) s that
173   // have have close (open) parentheses matching the paren ID exiting
174   // (entering) those states.
FindStates(Label paren_id,StateId s)175   SetIterator FindStates(Label paren_id, StateId s) const {
176     ParenState<A> paren_state(paren_id, s);
177     typename StateSetMap::const_iterator id_it = set_map_.find(paren_state);
178     if (id_it == set_map_.end()) {
179       return state_sets_.FindSet(-1);
180     } else {
181       return state_sets_.FindSet(id_it->second);
182     }
183   }
184 
185   // Given a paren Id and a state ID s, return an iterator over
186   // arcs that exit (enter) s and are labeled with a close (open)
187   // parenthesis matching the paren ID.
FindParenArcs(Label paren_id,StateId s)188   ParenArcIterator FindParenArcs(Label paren_id, StateId s) const {
189     ParenState<A> paren_state(paren_id, s);
190     return ParenArcIterator(paren_arc_multimap_,
191                             paren_arc_multimap_.find(paren_state));
192   }
193 
194  private:
195   // DFS that gathers paren and state set information.
196   // Bool returns false when cycle detected.
197   bool DFSearch(StateId s, StateId start);
198 
199   // Unions state sets together gathered by the DFS.
200   void ComputeStateSet(StateId s);
201 
202   // Gather state set(s) from state 'nexts'.
203   void UpdateStateSet(StateId nexts, set<Label> *paren_set,
204                       vector< set<StateId> > *state_sets) const;
205 
206   const Fst<A> &fst_;
207   const vector<pair<Label, Label> > &parens_;         // Paren ID -> Labels
208   bool close_;                                        // Close/open paren info?
209   unordered_map<Label, Label> paren_id_map_;               // Paren labels -> ID
210   ParenMultiMap paren_multimap_;                      // Paren reachability
211   ParenArcMultiMap paren_arc_multimap_;               // Paren Arcs
212   vector<char> state_color_;                          // DFS state
213   mutable Collection<ssize_t, StateId> state_sets_;   // Reachable states -> ID
214   StateSetMap set_map_;                               // ID -> Reachable states
215   DISALLOW_COPY_AND_ASSIGN(PdtParenReachable);
216 };
217 
218 // DFS that gathers paren and state set information.
219 template <class A>
DFSearch(StateId s,StateId start)220 bool PdtParenReachable<A>::DFSearch(StateId s, StateId start) {
221   if (s >= state_color_.size())
222     state_color_.resize(s + 1, kDfsWhite);
223 
224   if (state_color_[s] == kDfsBlack)
225     return true;
226 
227   if (state_color_[s] == kDfsGrey)
228     return false;
229 
230   state_color_[s] = kDfsGrey;
231 
232   for (ArcIterator<Fst<A> > aiter(fst_, s);
233        !aiter.Done();
234        aiter.Next()) {
235     const A &arc = aiter.Value();
236 
237     typename unordered_map<Label, Label>::const_iterator pit
238         = paren_id_map_.find(arc.ilabel);
239     if (pit != paren_id_map_.end()) {               // paren?
240       Label paren_id = pit->second;
241       if (arc.ilabel == parens_[paren_id].first) {  // open paren
242         DFSearch(arc.nextstate, arc.nextstate);
243         for (SetIterator set_iter = FindStates(paren_id, arc.nextstate);
244              !set_iter.Done(); set_iter.Next()) {
245           for (ParenArcIterator paren_arc_iter =
246                    FindParenArcs(paren_id, set_iter.Element());
247                !paren_arc_iter.Done();
248                paren_arc_iter.Next()) {
249             const A &cparc = paren_arc_iter.Value();
250             DFSearch(cparc.nextstate, start);
251           }
252         }
253       }
254     } else {                                       // non-paren
255       if(!DFSearch(arc.nextstate, start)) {
256         FSTERROR() << "PdtReachable: Underlying cyclicity not supported";
257         return true;
258       }
259     }
260   }
261   ComputeStateSet(s);
262   state_color_[s] = kDfsBlack;
263   return true;
264 }
265 
266 // Unions state sets together gathered by the DFS.
267 template <class A>
ComputeStateSet(StateId s)268 void PdtParenReachable<A>::ComputeStateSet(StateId s) {
269   set<Label> paren_set;
270   vector< set<StateId> > state_sets(parens_.size());
271   for (ArcIterator< Fst<A> > aiter(fst_, s);
272        !aiter.Done();
273        aiter.Next()) {
274     const A &arc = aiter.Value();
275 
276     typename unordered_map<Label, Label>::const_iterator pit
277         = paren_id_map_.find(arc.ilabel);
278     if (pit != paren_id_map_.end()) {               // paren?
279       Label paren_id = pit->second;
280       if (arc.ilabel == parens_[paren_id].first) {  // open paren
281         for (SetIterator set_iter =
282                  FindStates(paren_id, arc.nextstate);
283              !set_iter.Done(); set_iter.Next()) {
284           for (ParenArcIterator paren_arc_iter =
285                    FindParenArcs(paren_id, set_iter.Element());
286                !paren_arc_iter.Done();
287                paren_arc_iter.Next()) {
288             const A &cparc = paren_arc_iter.Value();
289             UpdateStateSet(cparc.nextstate, &paren_set, &state_sets);
290           }
291         }
292       } else {                                      // close paren
293         paren_set.insert(paren_id);
294         state_sets[paren_id].insert(s);
295         ParenState<A> paren_state(paren_id, s);
296         paren_arc_multimap_.insert(make_pair(paren_state, arc));
297       }
298     } else {                                        // non-paren
299       UpdateStateSet(arc.nextstate, &paren_set, &state_sets);
300     }
301   }
302 
303   vector<StateId> state_set;
304   for (typename set<Label>::iterator paren_iter = paren_set.begin();
305        paren_iter != paren_set.end(); ++paren_iter) {
306     state_set.clear();
307     Label paren_id = *paren_iter;
308     paren_multimap_.insert(make_pair(s, paren_id));
309     for (typename set<StateId>::iterator state_iter
310              = state_sets[paren_id].begin();
311          state_iter != state_sets[paren_id].end();
312          ++state_iter) {
313       state_set.push_back(*state_iter);
314     }
315     ParenState<A> paren_state(paren_id, s);
316     set_map_[paren_state] = state_sets_.FindId(state_set);
317   }
318 }
319 
320 // Gather state set(s) from state 'nexts'.
321 template <class A>
UpdateStateSet(StateId nexts,set<Label> * paren_set,vector<set<StateId>> * state_sets)322 void PdtParenReachable<A>::UpdateStateSet(
323     StateId nexts, set<Label> *paren_set,
324     vector< set<StateId> > *state_sets) const {
325   for(ParenIterator paren_iter = FindParens(nexts);
326       !paren_iter.Done(); paren_iter.Next()) {
327     Label paren_id = paren_iter.Value();
328     paren_set->insert(paren_id);
329     for (SetIterator set_iter = FindStates(paren_id, nexts);
330          !set_iter.Done(); set_iter.Next()) {
331       (*state_sets)[paren_id].insert(set_iter.Element());
332     }
333   }
334 }
335 
336 
337 // Store balancing parenthesis data for a PDT. Allows on-the-fly
338 // construction (e.g. in PdtShortestPath) unlike PdtParenReachable above.
339 template <class A>
340 class PdtBalanceData {
341  public:
342   typedef typename A::StateId StateId;
343   typedef typename A::Label Label;
344 
345   // Hash set for open parens
346   typedef unordered_set<ParenState<A>, typename ParenState<A>::Hash> OpenParenSet;
347 
348   // Maps from open paren destination state to parenthesis ID.
349   typedef unordered_multimap<StateId, Label> OpenParenMap;
350 
351   // Maps from open paren state to source states of matching close parens
352   typedef unordered_multimap<ParenState<A>, StateId,
353                         typename ParenState<A>::Hash> CloseParenMap;
354 
355   // Maps from open paren state to close source set ID
356   typedef unordered_map<ParenState<A>, ssize_t,
357                    typename ParenState<A>::Hash> CloseSourceMap;
358 
359   typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator;
360 
PdtBalanceData()361   PdtBalanceData() {}
362 
Clear()363   void Clear() {
364     open_paren_map_.clear();
365     close_paren_map_.clear();
366   }
367 
368   // Adds an open parenthesis with destination state 'open_dest'.
OpenInsert(Label paren_id,StateId open_dest)369   void OpenInsert(Label paren_id, StateId open_dest) {
370     ParenState<A> key(paren_id, open_dest);
371     if (!open_paren_set_.count(key)) {
372       open_paren_set_.insert(key);
373       open_paren_map_.insert(make_pair(open_dest, paren_id));
374     }
375   }
376 
377   // Adds a matching closing parenthesis with source state
378   // 'close_source' that balances an open_parenthesis with destination
379   // state 'open_dest' if OpenInsert() previously called
380   // (o.w. CloseInsert() does nothing).
CloseInsert(Label paren_id,StateId open_dest,StateId close_source)381   void CloseInsert(Label paren_id, StateId open_dest, StateId close_source) {
382     ParenState<A> key(paren_id, open_dest);
383     if (open_paren_set_.count(key))
384       close_paren_map_.insert(make_pair(key, close_source));
385   }
386 
387   // Find close paren source states matching an open parenthesis.
388   // Methods that follow, iterate through those matching states.
389   // Should be called only after FinishInsert(open_dest).
Find(Label paren_id,StateId open_dest)390   SetIterator Find(Label paren_id, StateId open_dest) {
391     ParenState<A> close_key(paren_id, open_dest);
392     typename CloseSourceMap::const_iterator id_it =
393         close_source_map_.find(close_key);
394     if (id_it == close_source_map_.end()) {
395       return close_source_sets_.FindSet(-1);
396     } else {
397       return close_source_sets_.FindSet(id_it->second);
398     }
399   }
400 
401   // Call when all open and close parenthesis insertions wrt open
402   // parentheses entering 'open_dest' are finished. Must be called
403   // before Find(open_dest). Stores close paren source state sets
404   // efficiently.
FinishInsert(StateId open_dest)405   void FinishInsert(StateId open_dest) {
406     vector<StateId> close_sources;
407     for (typename OpenParenMap::iterator oit = open_paren_map_.find(open_dest);
408          oit != open_paren_map_.end() && oit->first == open_dest;) {
409       Label paren_id = oit->second;
410       close_sources.clear();
411       ParenState<A> okey(paren_id, open_dest);
412       open_paren_set_.erase(open_paren_set_.find(okey));
413       for (typename CloseParenMap::iterator cit = close_paren_map_.find(okey);
414            cit != close_paren_map_.end() && cit->first == okey;) {
415         close_sources.push_back(cit->second);
416         close_paren_map_.erase(cit++);
417       }
418       sort(close_sources.begin(), close_sources.end());
419       typename vector<StateId>::iterator unique_end =
420           unique(close_sources.begin(), close_sources.end());
421       close_sources.resize(unique_end - close_sources.begin());
422 
423       if (!close_sources.empty())
424         close_source_map_[okey] = close_source_sets_.FindId(close_sources);
425       open_paren_map_.erase(oit++);
426     }
427   }
428 
429   // Return a new balance data object representing the reversed balance
430   // information.
431   PdtBalanceData<A> *Reverse(StateId num_states,
432                                StateId num_split,
433                                StateId state_id_shift) const;
434 
435  private:
436   OpenParenSet open_paren_set_;                      // open par. at dest?
437 
438   OpenParenMap open_paren_map_;                      // open parens per state
439   ParenState<A> open_dest_;                          // cur open dest. state
440   typename OpenParenMap::const_iterator open_iter_;  // cur open parens/state
441 
442   CloseParenMap close_paren_map_;                    // close states/open
443                                                      //  paren and state
444 
445   CloseSourceMap close_source_map_;                  // paren, state to set ID
446   mutable Collection<ssize_t, StateId> close_source_sets_;
447 };
448 
449 // Return a new balance data object representing the reversed balance
450 // information.
451 template <class A>
Reverse(StateId num_states,StateId num_split,StateId state_id_shift)452 PdtBalanceData<A> *PdtBalanceData<A>::Reverse(
453     StateId num_states,
454     StateId num_split,
455     StateId state_id_shift) const {
456   PdtBalanceData<A> *bd = new PdtBalanceData<A>;
457   unordered_set<StateId> close_sources;
458   StateId split_size = num_states / num_split;
459 
460   for (StateId i = 0; i < num_states; i+= split_size) {
461     close_sources.clear();
462 
463     for (typename CloseSourceMap::const_iterator
464              sit = close_source_map_.begin();
465          sit != close_source_map_.end();
466          ++sit) {
467       ParenState<A> okey = sit->first;
468       StateId open_dest = okey.state_id;
469       Label paren_id = okey.paren_id;
470       for (SetIterator set_iter = close_source_sets_.FindSet(sit->second);
471            !set_iter.Done(); set_iter.Next()) {
472         StateId close_source = set_iter.Element();
473         if ((close_source < i) || (close_source >= i + split_size))
474           continue;
475         close_sources.insert(close_source + state_id_shift);
476         bd->OpenInsert(paren_id, close_source + state_id_shift);
477         bd->CloseInsert(paren_id, close_source + state_id_shift,
478                         open_dest + state_id_shift);
479       }
480     }
481 
482     for (typename unordered_set<StateId>::const_iterator it
483              = close_sources.begin();
484          it != close_sources.end();
485          ++it) {
486       bd->FinishInsert(*it);
487     }
488 
489   }
490   return bd;
491 }
492 
493 
494 }  // namespace fst
495 
496 #endif  // FST_EXTENSIONS_PDT_PAREN_H_
497