1 // paren.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // Common classes for PDT parentheses
19
20 // \file
21
22 #ifndef FST_EXTENSIONS_PDT_PAREN_H_
23 #define FST_EXTENSIONS_PDT_PAREN_H_
24
25 #include <algorithm>
26 #include <unordered_map>
27 using std::tr1::unordered_map;
28 using std::tr1::unordered_multimap;
29 #include <tr1/unordered_set>
30 using std::tr1::unordered_set;
31 using std::tr1::unordered_multiset;
32 #include <set>
33
34 #include <fst/extensions/pdt/pdt.h>
35 #include <fst/extensions/pdt/collection.h>
36 #include <fst/fst.h>
37 #include <fst/dfs-visit.h>
38
39
40 namespace fst {
41
42 //
43 // ParenState: Pair of an open (close) parenthesis and
44 // its destination (source) state.
45 //
46
47 template <class A>
48 class ParenState {
49 public:
50 typedef typename A::Label Label;
51 typedef typename A::StateId StateId;
52
53 struct Hash {
operatorHash54 size_t operator()(const ParenState<A> &p) const {
55 return p.paren_id + p.state_id * kPrime;
56 }
57 };
58
59 Label paren_id; // ID of open (close) paren
60 StateId state_id; // destination (source) state of open (close) paren
61
ParenState()62 ParenState() : paren_id(kNoLabel), state_id(kNoStateId) {}
63
ParenState(Label p,StateId s)64 ParenState(Label p, StateId s) : paren_id(p), state_id(s) {}
65
66 bool operator==(const ParenState<A> &p) const {
67 if (&p == this)
68 return true;
69 return p.paren_id == this->paren_id && p.state_id == this->state_id;
70 }
71
72 bool operator!=(const ParenState<A> &p) const { return !(p == *this); }
73
74 bool operator<(const ParenState<A> &p) const {
75 return paren_id < this->paren.id ||
76 (p.paren_id == this->paren.id && p.state_id < this->state_id);
77 }
78
79 private:
80 static const size_t kPrime;
81 };
82
83 template <class A>
84 const size_t ParenState<A>::kPrime = 7853;
85
86
// Wraps a position in an STL (multi)map in an FST-style iterator
// (Done/Value/Next/Reset). Iteration runs over the consecutive entries whose
// key equals the key found at the starting position.
template <class M>
class MapIterator {
 public:
  typedef typename M::const_iterator StlIterator;
  typedef typename M::value_type PairType;
  typedef typename PairType::second_type ValueType;

  // Keeps a reference to 'm' (which must therefore outlive this object) and
  // starts at 'iter'. Passing m.end() gives an iteration that is Done()
  // immediately.
  MapIterator(const M &m, StlIterator iter)
      : container_(m), first_(iter), current_(iter) {}

  // True when the map is exhausted or the current entry's key differs from
  // the key at the starting position.
  bool Done() const {
    if (current_ == container_.end())
      return true;
    return current_->first != first_->first;
  }

  // Returns (a copy of) the mapped value at the current position.
  ValueType Value() const { return current_->second; }

  // Advances to the next entry.
  void Next() { ++current_; }

  // Rewinds to the starting position.
  void Reset() { current_ = first_; }

 private:
  const M &container_;   // map being iterated
  StlIterator first_;    // starting position
  StlIterator current_;  // current position
};
111
112 //
113 // PdtParenReachable: Provides various parenthesis reachability information
114 // on a PDT.
115 //
116
117 template <class A>
118 class PdtParenReachable {
119 public:
120 typedef typename A::StateId StateId;
121 typedef typename A::Label Label;
122 public:
123 // Maps from state ID to reachable paren IDs from (to) that state.
124 typedef unordered_multimap<StateId, Label> ParenMultiMap;
125
126 // Maps from paren ID and state ID to reachable state set ID
127 typedef unordered_map<ParenState<A>, ssize_t,
128 typename ParenState<A>::Hash> StateSetMap;
129
130 // Maps from paren ID and state ID to arcs exiting that state with that
131 // Label.
132 typedef unordered_multimap<ParenState<A>, A,
133 typename ParenState<A>::Hash> ParenArcMultiMap;
134
135 typedef MapIterator<ParenMultiMap> ParenIterator;
136
137 typedef MapIterator<ParenArcMultiMap> ParenArcIterator;
138
139 typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator;
140
141 // Computes close (open) parenthesis reachabilty information for
142 // a PDT with bounded stack.
PdtParenReachable(const Fst<A> & fst,const vector<pair<Label,Label>> & parens,bool close)143 PdtParenReachable(const Fst<A> &fst,
144 const vector<pair<Label, Label> > &parens, bool close)
145 : fst_(fst),
146 parens_(parens),
147 close_(close) {
148 for (Label i = 0; i < parens.size(); ++i) {
149 const pair<Label, Label> &p = parens[i];
150 paren_id_map_[p.first] = i;
151 paren_id_map_[p.second] = i;
152 }
153
154 if (close_) {
155 StateId start = fst.Start();
156 if (start == kNoStateId)
157 return;
158 DFSearch(start, start);
159 } else {
160 FSTERROR() << "PdtParenReachable: open paren info not implemented";
161 }
162 }
163
164 // Given a state ID, returns an iterator over paren IDs
165 // for close (open) parens reachable from that state along balanced
166 // paths.
FindParens(StateId s)167 ParenIterator FindParens(StateId s) const {
168 return ParenIterator(paren_multimap_, paren_multimap_.find(s));
169 }
170
171 // Given a paren ID and a state ID s, returns an iterator over
172 // states that can be reached along balanced paths from (to) s that
173 // have have close (open) parentheses matching the paren ID exiting
174 // (entering) those states.
FindStates(Label paren_id,StateId s)175 SetIterator FindStates(Label paren_id, StateId s) const {
176 ParenState<A> paren_state(paren_id, s);
177 typename StateSetMap::const_iterator id_it = set_map_.find(paren_state);
178 if (id_it == set_map_.end()) {
179 return state_sets_.FindSet(-1);
180 } else {
181 return state_sets_.FindSet(id_it->second);
182 }
183 }
184
185 // Given a paren Id and a state ID s, return an iterator over
186 // arcs that exit (enter) s and are labeled with a close (open)
187 // parenthesis matching the paren ID.
FindParenArcs(Label paren_id,StateId s)188 ParenArcIterator FindParenArcs(Label paren_id, StateId s) const {
189 ParenState<A> paren_state(paren_id, s);
190 return ParenArcIterator(paren_arc_multimap_,
191 paren_arc_multimap_.find(paren_state));
192 }
193
194 private:
195 // DFS that gathers paren and state set information.
196 // Bool returns false when cycle detected.
197 bool DFSearch(StateId s, StateId start);
198
199 // Unions state sets together gathered by the DFS.
200 void ComputeStateSet(StateId s);
201
202 // Gather state set(s) from state 'nexts'.
203 void UpdateStateSet(StateId nexts, set<Label> *paren_set,
204 vector< set<StateId> > *state_sets) const;
205
206 const Fst<A> &fst_;
207 const vector<pair<Label, Label> > &parens_; // Paren ID -> Labels
208 bool close_; // Close/open paren info?
209 unordered_map<Label, Label> paren_id_map_; // Paren labels -> ID
210 ParenMultiMap paren_multimap_; // Paren reachability
211 ParenArcMultiMap paren_arc_multimap_; // Paren Arcs
212 vector<char> state_color_; // DFS state
213 mutable Collection<ssize_t, StateId> state_sets_; // Reachable states -> ID
214 StateSetMap set_map_; // ID -> Reachable states
215 DISALLOW_COPY_AND_ASSIGN(PdtParenReachable);
216 };
217
218 // DFS that gathers paren and state set information.
219 template <class A>
DFSearch(StateId s,StateId start)220 bool PdtParenReachable<A>::DFSearch(StateId s, StateId start) {
221 if (s >= state_color_.size())
222 state_color_.resize(s + 1, kDfsWhite);
223
224 if (state_color_[s] == kDfsBlack)
225 return true;
226
227 if (state_color_[s] == kDfsGrey)
228 return false;
229
230 state_color_[s] = kDfsGrey;
231
232 for (ArcIterator<Fst<A> > aiter(fst_, s);
233 !aiter.Done();
234 aiter.Next()) {
235 const A &arc = aiter.Value();
236
237 typename unordered_map<Label, Label>::const_iterator pit
238 = paren_id_map_.find(arc.ilabel);
239 if (pit != paren_id_map_.end()) { // paren?
240 Label paren_id = pit->second;
241 if (arc.ilabel == parens_[paren_id].first) { // open paren
242 DFSearch(arc.nextstate, arc.nextstate);
243 for (SetIterator set_iter = FindStates(paren_id, arc.nextstate);
244 !set_iter.Done(); set_iter.Next()) {
245 for (ParenArcIterator paren_arc_iter =
246 FindParenArcs(paren_id, set_iter.Element());
247 !paren_arc_iter.Done();
248 paren_arc_iter.Next()) {
249 const A &cparc = paren_arc_iter.Value();
250 DFSearch(cparc.nextstate, start);
251 }
252 }
253 }
254 } else { // non-paren
255 if(!DFSearch(arc.nextstate, start)) {
256 FSTERROR() << "PdtReachable: Underlying cyclicity not supported";
257 return true;
258 }
259 }
260 }
261 ComputeStateSet(s);
262 state_color_[s] = kDfsBlack;
263 return true;
264 }
265
266 // Unions state sets together gathered by the DFS.
267 template <class A>
ComputeStateSet(StateId s)268 void PdtParenReachable<A>::ComputeStateSet(StateId s) {
269 set<Label> paren_set;
270 vector< set<StateId> > state_sets(parens_.size());
271 for (ArcIterator< Fst<A> > aiter(fst_, s);
272 !aiter.Done();
273 aiter.Next()) {
274 const A &arc = aiter.Value();
275
276 typename unordered_map<Label, Label>::const_iterator pit
277 = paren_id_map_.find(arc.ilabel);
278 if (pit != paren_id_map_.end()) { // paren?
279 Label paren_id = pit->second;
280 if (arc.ilabel == parens_[paren_id].first) { // open paren
281 for (SetIterator set_iter =
282 FindStates(paren_id, arc.nextstate);
283 !set_iter.Done(); set_iter.Next()) {
284 for (ParenArcIterator paren_arc_iter =
285 FindParenArcs(paren_id, set_iter.Element());
286 !paren_arc_iter.Done();
287 paren_arc_iter.Next()) {
288 const A &cparc = paren_arc_iter.Value();
289 UpdateStateSet(cparc.nextstate, &paren_set, &state_sets);
290 }
291 }
292 } else { // close paren
293 paren_set.insert(paren_id);
294 state_sets[paren_id].insert(s);
295 ParenState<A> paren_state(paren_id, s);
296 paren_arc_multimap_.insert(make_pair(paren_state, arc));
297 }
298 } else { // non-paren
299 UpdateStateSet(arc.nextstate, &paren_set, &state_sets);
300 }
301 }
302
303 vector<StateId> state_set;
304 for (typename set<Label>::iterator paren_iter = paren_set.begin();
305 paren_iter != paren_set.end(); ++paren_iter) {
306 state_set.clear();
307 Label paren_id = *paren_iter;
308 paren_multimap_.insert(make_pair(s, paren_id));
309 for (typename set<StateId>::iterator state_iter
310 = state_sets[paren_id].begin();
311 state_iter != state_sets[paren_id].end();
312 ++state_iter) {
313 state_set.push_back(*state_iter);
314 }
315 ParenState<A> paren_state(paren_id, s);
316 set_map_[paren_state] = state_sets_.FindId(state_set);
317 }
318 }
319
320 // Gather state set(s) from state 'nexts'.
321 template <class A>
UpdateStateSet(StateId nexts,set<Label> * paren_set,vector<set<StateId>> * state_sets)322 void PdtParenReachable<A>::UpdateStateSet(
323 StateId nexts, set<Label> *paren_set,
324 vector< set<StateId> > *state_sets) const {
325 for(ParenIterator paren_iter = FindParens(nexts);
326 !paren_iter.Done(); paren_iter.Next()) {
327 Label paren_id = paren_iter.Value();
328 paren_set->insert(paren_id);
329 for (SetIterator set_iter = FindStates(paren_id, nexts);
330 !set_iter.Done(); set_iter.Next()) {
331 (*state_sets)[paren_id].insert(set_iter.Element());
332 }
333 }
334 }
335
336
337 // Store balancing parenthesis data for a PDT. Allows on-the-fly
338 // construction (e.g. in PdtShortestPath) unlike PdtParenReachable above.
339 template <class A>
340 class PdtBalanceData {
341 public:
342 typedef typename A::StateId StateId;
343 typedef typename A::Label Label;
344
345 // Hash set for open parens
346 typedef unordered_set<ParenState<A>, typename ParenState<A>::Hash> OpenParenSet;
347
348 // Maps from open paren destination state to parenthesis ID.
349 typedef unordered_multimap<StateId, Label> OpenParenMap;
350
351 // Maps from open paren state to source states of matching close parens
352 typedef unordered_multimap<ParenState<A>, StateId,
353 typename ParenState<A>::Hash> CloseParenMap;
354
355 // Maps from open paren state to close source set ID
356 typedef unordered_map<ParenState<A>, ssize_t,
357 typename ParenState<A>::Hash> CloseSourceMap;
358
359 typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator;
360
PdtBalanceData()361 PdtBalanceData() {}
362
Clear()363 void Clear() {
364 open_paren_map_.clear();
365 close_paren_map_.clear();
366 }
367
368 // Adds an open parenthesis with destination state 'open_dest'.
OpenInsert(Label paren_id,StateId open_dest)369 void OpenInsert(Label paren_id, StateId open_dest) {
370 ParenState<A> key(paren_id, open_dest);
371 if (!open_paren_set_.count(key)) {
372 open_paren_set_.insert(key);
373 open_paren_map_.insert(make_pair(open_dest, paren_id));
374 }
375 }
376
377 // Adds a matching closing parenthesis with source state
378 // 'close_source' that balances an open_parenthesis with destination
379 // state 'open_dest' if OpenInsert() previously called
380 // (o.w. CloseInsert() does nothing).
CloseInsert(Label paren_id,StateId open_dest,StateId close_source)381 void CloseInsert(Label paren_id, StateId open_dest, StateId close_source) {
382 ParenState<A> key(paren_id, open_dest);
383 if (open_paren_set_.count(key))
384 close_paren_map_.insert(make_pair(key, close_source));
385 }
386
387 // Find close paren source states matching an open parenthesis.
388 // Methods that follow, iterate through those matching states.
389 // Should be called only after FinishInsert(open_dest).
Find(Label paren_id,StateId open_dest)390 SetIterator Find(Label paren_id, StateId open_dest) {
391 ParenState<A> close_key(paren_id, open_dest);
392 typename CloseSourceMap::const_iterator id_it =
393 close_source_map_.find(close_key);
394 if (id_it == close_source_map_.end()) {
395 return close_source_sets_.FindSet(-1);
396 } else {
397 return close_source_sets_.FindSet(id_it->second);
398 }
399 }
400
401 // Call when all open and close parenthesis insertions wrt open
402 // parentheses entering 'open_dest' are finished. Must be called
403 // before Find(open_dest). Stores close paren source state sets
404 // efficiently.
FinishInsert(StateId open_dest)405 void FinishInsert(StateId open_dest) {
406 vector<StateId> close_sources;
407 for (typename OpenParenMap::iterator oit = open_paren_map_.find(open_dest);
408 oit != open_paren_map_.end() && oit->first == open_dest;) {
409 Label paren_id = oit->second;
410 close_sources.clear();
411 ParenState<A> okey(paren_id, open_dest);
412 open_paren_set_.erase(open_paren_set_.find(okey));
413 for (typename CloseParenMap::iterator cit = close_paren_map_.find(okey);
414 cit != close_paren_map_.end() && cit->first == okey;) {
415 close_sources.push_back(cit->second);
416 close_paren_map_.erase(cit++);
417 }
418 sort(close_sources.begin(), close_sources.end());
419 typename vector<StateId>::iterator unique_end =
420 unique(close_sources.begin(), close_sources.end());
421 close_sources.resize(unique_end - close_sources.begin());
422
423 if (!close_sources.empty())
424 close_source_map_[okey] = close_source_sets_.FindId(close_sources);
425 open_paren_map_.erase(oit++);
426 }
427 }
428
429 // Return a new balance data object representing the reversed balance
430 // information.
431 PdtBalanceData<A> *Reverse(StateId num_states,
432 StateId num_split,
433 StateId state_id_shift) const;
434
435 private:
436 OpenParenSet open_paren_set_; // open par. at dest?
437
438 OpenParenMap open_paren_map_; // open parens per state
439 ParenState<A> open_dest_; // cur open dest. state
440 typename OpenParenMap::const_iterator open_iter_; // cur open parens/state
441
442 CloseParenMap close_paren_map_; // close states/open
443 // paren and state
444
445 CloseSourceMap close_source_map_; // paren, state to set ID
446 mutable Collection<ssize_t, StateId> close_source_sets_;
447 };
448
449 // Return a new balance data object representing the reversed balance
450 // information.
451 template <class A>
Reverse(StateId num_states,StateId num_split,StateId state_id_shift)452 PdtBalanceData<A> *PdtBalanceData<A>::Reverse(
453 StateId num_states,
454 StateId num_split,
455 StateId state_id_shift) const {
456 PdtBalanceData<A> *bd = new PdtBalanceData<A>;
457 unordered_set<StateId> close_sources;
458 StateId split_size = num_states / num_split;
459
460 for (StateId i = 0; i < num_states; i+= split_size) {
461 close_sources.clear();
462
463 for (typename CloseSourceMap::const_iterator
464 sit = close_source_map_.begin();
465 sit != close_source_map_.end();
466 ++sit) {
467 ParenState<A> okey = sit->first;
468 StateId open_dest = okey.state_id;
469 Label paren_id = okey.paren_id;
470 for (SetIterator set_iter = close_source_sets_.FindSet(sit->second);
471 !set_iter.Done(); set_iter.Next()) {
472 StateId close_source = set_iter.Element();
473 if ((close_source < i) || (close_source >= i + split_size))
474 continue;
475 close_sources.insert(close_source + state_id_shift);
476 bd->OpenInsert(paren_id, close_source + state_id_shift);
477 bd->CloseInsert(paren_id, close_source + state_id_shift,
478 open_dest + state_id_shift);
479 }
480 }
481
482 for (typename unordered_set<StateId>::const_iterator it
483 = close_sources.begin();
484 it != close_sources.end();
485 ++it) {
486 bd->FinishInsert(*it);
487 }
488
489 }
490 return bd;
491 }
492
493
494 } // namespace fst
495
496 #endif // FST_EXTENSIONS_PDT_PAREN_H_
497