1
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 //
14 // Copyright 2005-2010 Google, Inc.
15 // Author: jpr@google.com (Jake Ratkiewicz)
16 // Convenience file for including all PDT operations at once, and/or
17 // registering them for new arc types.
18
19 #ifndef FST_EXTENSIONS_PDT_PDTSCRIPT_H_
20 #define FST_EXTENSIONS_PDT_PDTSCRIPT_H_
21
22 #include <utility>
23 using std::pair; using std::make_pair;
24 #include <vector>
25 using std::vector;
26
27 #include <fst/compose.h> // for ComposeOptions
28 #include <fst/util.h>
29
30 #include <fst/script/fst-class.h>
31 #include <fst/script/arg-packs.h>
32 #include <fst/script/shortest-path.h>
33
34 #include <fst/extensions/pdt/compose.h>
35 #include <fst/extensions/pdt/expand.h>
36 #include <fst/extensions/pdt/info.h>
37 #include <fst/extensions/pdt/replace.h>
38 #include <fst/extensions/pdt/reverse.h>
39 #include <fst/extensions/pdt/shortest-path.h>
40
41
42 namespace fst {
43 namespace script {
44
45 // PDT COMPOSE
46
47 typedef args::Package<const FstClass &,
48 const FstClass &,
49 const vector<pair<int64, int64> >&,
50 MutableFstClass *,
51 const PdtComposeOptions &,
52 bool> PdtComposeArgs;
53
54 template<class Arc>
PdtCompose(PdtComposeArgs * args)55 void PdtCompose(PdtComposeArgs *args) {
56 const Fst<Arc> &ifst1 = *(args->arg1.GetFst<Arc>());
57 const Fst<Arc> &ifst2 = *(args->arg2.GetFst<Arc>());
58 MutableFst<Arc> *ofst = args->arg4->GetMutableFst<Arc>();
59
60 vector<pair<typename Arc::Label, typename Arc::Label> > parens(
61 args->arg3.size());
62
63 for (size_t i = 0; i < parens.size(); ++i) {
64 parens[i].first = args->arg3[i].first;
65 parens[i].second = args->arg3[i].second;
66 }
67
68 if (args->arg6) {
69 Compose(ifst1, parens, ifst2, ofst, args->arg5);
70 } else {
71 Compose(ifst1, ifst2, parens, ofst, args->arg5);
72 }
73 }
74
75 void PdtCompose(const FstClass & ifst1,
76 const FstClass & ifst2,
77 const vector<pair<int64, int64> > &parens,
78 MutableFstClass *ofst,
79 const PdtComposeOptions &copts,
80 bool left_pdt);
81
82 // PDT EXPAND
83
84 struct PdtExpandOptions {
85 bool connect;
86 bool keep_parentheses;
87 WeightClass weight_threshold;
88
89 PdtExpandOptions(bool c = true, bool k = false,
90 WeightClass w = WeightClass::Zero())
connectPdtExpandOptions91 : connect(c), keep_parentheses(k), weight_threshold(w) {}
92 };
93
94 typedef args::Package<const FstClass &,
95 const vector<pair<int64, int64> >&,
96 MutableFstClass *, PdtExpandOptions> PdtExpandArgs;
97
98 template<class Arc>
PdtExpand(PdtExpandArgs * args)99 void PdtExpand(PdtExpandArgs *args) {
100 const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
101 MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>();
102
103 vector<pair<typename Arc::Label, typename Arc::Label> > parens(
104 args->arg2.size());
105 for (size_t i = 0; i < parens.size(); ++i) {
106 parens[i].first = args->arg2[i].first;
107 parens[i].second = args->arg2[i].second;
108 }
109 Expand(fst, parens, ofst,
110 ExpandOptions<Arc>(
111 args->arg4.connect, args->arg4.keep_parentheses,
112 *(args->arg4.weight_threshold.GetWeight<typename Arc::Weight>())));
113 }
114
115 void PdtExpand(const FstClass &ifst,
116 const vector<pair<int64, int64> > &parens,
117 MutableFstClass *ofst, const PdtExpandOptions &opts);
118
119 void PdtExpand(const FstClass &ifst,
120 const vector<pair<int64, int64> > &parens,
121 MutableFstClass *ofst, bool connect);
122
123 // PDT REPLACE
124
125 typedef args::Package<const vector<pair<int64, const FstClass*> > &,
126 MutableFstClass *,
127 vector<pair<int64, int64> > *,
128 const int64 &> PdtReplaceArgs;
129 template<class Arc>
PdtReplace(PdtReplaceArgs * args)130 void PdtReplace(PdtReplaceArgs *args) {
131 vector<pair<typename Arc::Label, const Fst<Arc> *> > tuples(
132 args->arg1.size());
133 for (size_t i = 0; i < tuples.size(); ++i) {
134 tuples[i].first = args->arg1[i].first;
135 tuples[i].second = (args->arg1[i].second)->GetFst<Arc>();
136 }
137 MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>();
138 vector<pair<typename Arc::Label, typename Arc::Label> > parens(
139 args->arg3->size());
140
141 for (size_t i = 0; i < parens.size(); ++i) {
142 parens[i].first = args->arg3->at(i).first;
143 parens[i].second = args->arg3->at(i).second;
144 }
145 Replace(tuples, ofst, &parens, args->arg4);
146
147 // now copy parens back
148 args->arg3->resize(parens.size());
149 for (size_t i = 0; i < parens.size(); ++i) {
150 (*args->arg3)[i].first = parens[i].first;
151 (*args->arg3)[i].second = parens[i].second;
152 }
153 }
154
155 void PdtReplace(const vector<pair<int64, const FstClass*> > &fst_tuples,
156 MutableFstClass *ofst,
157 vector<pair<int64, int64> > *parens,
158 const int64 &root);
159
160 // PDT REVERSE
161
162 typedef args::Package<const FstClass &,
163 const vector<pair<int64, int64> >&,
164 MutableFstClass *> PdtReverseArgs;
165
166 template<class Arc>
PdtReverse(PdtReverseArgs * args)167 void PdtReverse(PdtReverseArgs *args) {
168 const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
169 MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>();
170
171 vector<pair<typename Arc::Label, typename Arc::Label> > parens(
172 args->arg2.size());
173 for (size_t i = 0; i < parens.size(); ++i) {
174 parens[i].first = args->arg2[i].first;
175 parens[i].second = args->arg2[i].second;
176 }
177 Reverse(fst, parens, ofst);
178 }
179
180 void PdtReverse(const FstClass &ifst,
181 const vector<pair<int64, int64> > &parens,
182 MutableFstClass *ofst);
183
184
185 // PDT SHORTESTPATH
186
187 struct PdtShortestPathOptions {
188 QueueType queue_type;
189 bool keep_parentheses;
190 bool path_gc;
191
192 PdtShortestPathOptions(QueueType qt = FIFO_QUEUE,
193 bool kp = false, bool gc = true)
queue_typePdtShortestPathOptions194 : queue_type(qt), keep_parentheses(kp), path_gc(gc) {}
195 };
196
197 typedef args::Package<const FstClass &,
198 const vector<pair<int64, int64> >&,
199 MutableFstClass *,
200 const PdtShortestPathOptions &> PdtShortestPathArgs;
201
202 template<class Arc>
PdtShortestPath(PdtShortestPathArgs * args)203 void PdtShortestPath(PdtShortestPathArgs *args) {
204 typedef typename Arc::StateId StateId;
205 typedef typename Arc::Label Label;
206 typedef typename Arc::Weight Weight;
207
208 const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
209 MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>();
210 const PdtShortestPathOptions &opts = args->arg4;
211
212
213 vector<pair<Label, Label> > parens(args->arg2.size());
214 for (size_t i = 0; i < parens.size(); ++i) {
215 parens[i].first = args->arg2[i].first;
216 parens[i].second = args->arg2[i].second;
217 }
218
219 switch (opts.queue_type) {
220 default:
221 FSTERROR() << "Unknown queue type: " << opts.queue_type;
222 case FIFO_QUEUE: {
223 typedef FifoQueue<StateId> Queue;
224 fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
225 opts.path_gc);
226 ShortestPath(fst, parens, ofst, spopts);
227 return;
228 }
229 case LIFO_QUEUE: {
230 typedef LifoQueue<StateId> Queue;
231 fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
232 opts.path_gc);
233 ShortestPath(fst, parens, ofst, spopts);
234 return;
235 }
236 case STATE_ORDER_QUEUE: {
237 typedef StateOrderQueue<StateId> Queue;
238 fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
239 opts.path_gc);
240 ShortestPath(fst, parens, ofst, spopts);
241 return;
242 }
243 }
244 }
245
246 void PdtShortestPath(const FstClass &ifst,
247 const vector<pair<int64, int64> > &parens,
248 MutableFstClass *ofst,
249 const PdtShortestPathOptions &opts =
250 PdtShortestPathOptions());
251
252 // PRINT INFO
253
254 typedef args::Package<const FstClass &,
255 const vector<pair<int64, int64> > &> PrintPdtInfoArgs;
256
257 template<class Arc>
PrintPdtInfo(PrintPdtInfoArgs * args)258 void PrintPdtInfo(PrintPdtInfoArgs *args) {
259 const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
260 vector<pair<typename Arc::Label, typename Arc::Label> > parens(
261 args->arg2.size());
262 for (size_t i = 0; i < parens.size(); ++i) {
263 parens[i].first = args->arg2[i].first;
264 parens[i].second = args->arg2[i].second;
265 }
266 PdtInfo<Arc> pdtinfo(fst, parens);
267 PrintPdtInfo(pdtinfo);
268 }
269
270 void PrintPdtInfo(const FstClass &ifst,
271 const vector<pair<int64, int64> > &parens);
272
273 } // namespace script
274 } // namespace fst
275
276
// Registers every PDT script operation declared in this header for the arc
// type `ArcType`. The *Args typedefs are named unqualified, so this is
// presumably expanded inside namespace fst::script. The final registration
// has no trailing ';'; the caller supplies it.
#define REGISTER_FST_PDT_OPERATIONS(ArcType)                                \
  REGISTER_FST_OPERATION(PdtCompose, ArcType, PdtComposeArgs);              \
  REGISTER_FST_OPERATION(PdtExpand, ArcType, PdtExpandArgs);                \
  REGISTER_FST_OPERATION(PdtReplace, ArcType, PdtReplaceArgs);              \
  REGISTER_FST_OPERATION(PdtReverse, ArcType, PdtReverseArgs);              \
  REGISTER_FST_OPERATION(PdtShortestPath, ArcType, PdtShortestPathArgs);    \
  REGISTER_FST_OPERATION(PrintPdtInfo, ArcType, PrintPdtInfoArgs)
284 #endif // FST_EXTENSIONS_PDT_PDTSCRIPT_H_
285