• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // replace.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Recursively replace Fst arcs with other Fst(s) returning a PDT.
20 
21 #ifndef FST_EXTENSIONS_PDT_REPLACE_H__
22 #define FST_EXTENSIONS_PDT_REPLACE_H__
23 
24 #include <fst/replace.h>
25 
26 namespace fst {
27 
// Hash functor for the (FST index, state ID) keys that map to parenthesis
// pair IDs. The state ID is scaled by a small prime and added to the index.
template <typename S>
struct ReplaceParenHash {
  size_t operator()(const pair<size_t, S> &key) const {
    const size_t fst_index = key.first;
    return fst_index + kPrime * key.second;
  }
 private:
  static const size_t kPrime = 7853;
};

template <typename S> const size_t ReplaceParenHash<S>::kPrime;
39 
40 // Builds a pushdown transducer (PDT) from an RTN specification
41 // identical to that in fst/lib/replace.h. The result is a PDT
42 // encoded as the FST 'ofst' where some transitions are labeled with
43 // open or close parentheses. To be interpreted as a PDT, the parens
44 // must balance on a path (see PdtExpand()). The open/close
45 // parenthesis label pairs are returned in 'parens'.
46 template <class Arc>
Replace(const vector<pair<typename Arc::Label,const Fst<Arc> * >> & ifst_array,MutableFst<Arc> * ofst,vector<pair<typename Arc::Label,typename Arc::Label>> * parens,typename Arc::Label root)47 void Replace(const vector<pair<typename Arc::Label,
48              const Fst<Arc>* > >& ifst_array,
49              MutableFst<Arc> *ofst,
50              vector<pair<typename Arc::Label,
51              typename Arc::Label> > *parens,
52              typename Arc::Label root) {
53   typedef typename Arc::Label Label;
54   typedef typename Arc::StateId StateId;
55   typedef typename Arc::Weight Weight;
56 
57   ofst->DeleteStates();
58   parens->clear();
59 
60   unordered_map<Label, size_t> label2id;
61   for (size_t i = 0; i < ifst_array.size(); ++i)
62     label2id[ifst_array[i].first] = i;
63 
64   Label max_label = kNoLabel;
65 
66   deque<size_t> non_term_queue;  // Queue of non-terminals to replace
67   unordered_set<Label> non_term_set;  // Set of non-terminals to replace
68   non_term_queue.push_back(root);
69   non_term_set.insert(root);
70 
71   // PDT state corr. to ith replace FST start state.
72   vector<StateId> fst_start(ifst_array.size(), kNoLabel);
73   // PDT state, weight pairs corr. to ith replace FST final state & weights.
74   vector< vector<pair<StateId, Weight> > > fst_final(ifst_array.size());
75 
76   // Builds single Fst combining all referenced input Fsts. Leaves in the
77   // non-termnals for now.  Tabulate the PDT states that correspond to
78   // the start and final states of the input Fsts.
79   for (StateId soff = 0; !non_term_queue.empty(); soff = ofst->NumStates()) {
80     Label label = non_term_queue.front();
81     non_term_queue.pop_front();
82     size_t fst_id = label2id[label];
83 
84     const Fst<Arc> *ifst = ifst_array[fst_id].second;
85     for (StateIterator< Fst<Arc> > siter(*ifst);
86          !siter.Done(); siter.Next()) {
87       StateId is = siter.Value();
88       StateId os = ofst->AddState();
89       if (is == ifst->Start()) {
90         fst_start[fst_id] = os;
91         if (label == root)
92           ofst->SetStart(os);
93       }
94       if (ifst->Final(is) != Weight::Zero()) {
95         if (label == root)
96           ofst->SetFinal(os, ifst->Final(is));
97         fst_final[fst_id].push_back(make_pair(os, ifst->Final(is)));
98       }
99       for (ArcIterator< Fst<Arc> > aiter(*ifst, is);
100            !aiter.Done(); aiter.Next()) {
101         Arc arc = aiter.Value();
102         if (max_label == kNoLabel || arc.olabel > max_label)
103           max_label = arc.olabel;
104         typename unordered_map<Label, size_t>::const_iterator it =
105             label2id.find(arc.olabel);
106         if (it != label2id.end()) {
107           size_t nfst_id = it->second;
108           if (ifst_array[nfst_id].second->Start() == -1)
109             continue;
110           if (non_term_set.count(arc.olabel) == 0) {
111             non_term_queue.push_back(arc.olabel);
112             non_term_set.insert(arc.olabel);
113           }
114         }
115         arc.nextstate += soff;
116         ofst->AddArc(os, arc);
117       }
118     }
119   }
120 
121   // Changes each non-terminal transition to an open parenthesis
122   // transition redirected to the PDT state that corresponds to the
123   // start state of the input FST for the non-terminal. Adds close parenthesis
124   // transitions from the PDT states corr. to the final states of the
125   // input FST for the non-terminal to the former destination state of the
126   // non-terminal transition.
127 
128   typedef MutableArcIterator< MutableFst<Arc> > MIter;
129   typedef unordered_map<pair<size_t, StateId >, size_t,
130                    ReplaceParenHash<StateId> > ParenMap;
131 
132   // Parenthesis pair ID per fst, state pair.
133   ParenMap paren_map;
134   // # of parenthesis pairs per fst.
135   vector<size_t> nparens(ifst_array.size(), 0);
136   // Initial open parenthesis label
137   Label first_paren = max_label + 1;
138 
139   for (StateIterator< Fst<Arc> > siter(*ofst);
140        !siter.Done(); siter.Next()) {
141     StateId os = siter.Value();
142     MIter *aiter = new MIter(ofst, os);
143     for (size_t n = 0; !aiter->Done(); aiter->Next(), ++n) {
144       Arc arc = aiter->Value();
145       typename unordered_map<Label, size_t>::const_iterator lit =
146           label2id.find(arc.olabel);
147       if (lit != label2id.end()) {
148         size_t nfst_id = lit->second;
149 
150         // Get parentheses. Ensures distinct parenthesis pair per
151         // non-terminal and destination state but otherwise reuses them.
152         Label open_paren = kNoLabel, close_paren = kNoLabel;
153         pair<size_t, StateId> paren_key(nfst_id, arc.nextstate);
154         typename ParenMap::const_iterator pit = paren_map.find(paren_key);
155         if (pit != paren_map.end()) {
156           size_t paren_id = pit->second;
157           open_paren = (*parens)[paren_id].first;
158           close_paren = (*parens)[paren_id].second;
159         } else {
160           size_t paren_id = nparens[nfst_id]++;
161           open_paren = first_paren + 2 * paren_id;
162           close_paren = open_paren + 1;
163           paren_map[paren_key] = paren_id;
164           if (paren_id >= parens->size())
165             parens->push_back(make_pair(open_paren, close_paren));
166         }
167 
168         // Sets open parenthesis.
169         Arc sarc(open_paren, open_paren, arc.weight, fst_start[nfst_id]);
170         aiter->SetValue(sarc);
171 
172         // Adds close parentheses.
173         for (size_t i = 0; i < fst_final[nfst_id].size(); ++i) {
174           pair<StateId, Weight> &p = fst_final[nfst_id][i];
175           Arc farc(close_paren, close_paren, p.second, arc.nextstate);
176 
177           ofst->AddArc(p.first, farc);
178           if (os == p.first) {  // Invalidated iterator
179             delete aiter;
180             aiter = new MIter(ofst, os);
181             aiter->Seek(n);
182           }
183         }
184       }
185     }
186     delete aiter;
187   }
188 }
189 
190 }  // namespace fst
191 
192 #endif  // FST_EXTENSIONS_PDT_REPLACE_H__
193