1 // replace-util.h
2
3
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 // Copyright 2005-2010 Google, Inc.
17 // Author: riley@google.com (Michael Riley)
18 //
19
20 // \file
21 // Utility classes for the recursive replacement of Fsts (RTNs).
22
23 #ifndef FST_LIB_REPLACE_UTIL_H__
24 #define FST_LIB_REPLACE_UTIL_H__
25
26 #include <vector>
27 using std::vector;
28 #include <tr1/unordered_map>
29 using std::tr1::unordered_map;
30 using std::tr1::unordered_multimap;
31 #include <tr1/unordered_set>
32 using std::tr1::unordered_set;
33 using std::tr1::unordered_multiset;
34 #include <map>
35
36 #include <fst/connect.h>
37 #include <fst/mutable-fst.h>
38 #include <fst/topsort.h>
39
40
41 namespace fst {
42
43 template <class Arc>
44 void Replace(const vector<pair<typename Arc::Label, const Fst<Arc>* > >&,
45 MutableFst<Arc> *, typename Arc::Label, bool);
46
47
48 // Utility class for the recursive replacement of Fsts (RTNs). The
49 // user provides a set of Label, Fst pairs at construction. These are
50 // used by methods for testing cyclic dependencies and connectedness
51 // and doing RTN connection and specific Fst replacement by label or
52 // for various optimization properties. The modified results can be
53 // obtained with the GetFstPairs() or GetMutableFstPairs() methods.
54 template <class Arc>
55 class ReplaceUtil {
56 public:
57 typedef typename Arc::Label Label;
58 typedef typename Arc::Weight Weight;
59 typedef typename Arc::StateId StateId;
60
61 typedef pair<Label, const Fst<Arc>*> FstPair;
62 typedef pair<Label, MutableFst<Arc>*> MutableFstPair;
63 typedef unordered_map<Label, Label> NonTerminalHash;
64
65 // Constructs from mutable Fsts; Fst ownership given to ReplaceUtil.
66 ReplaceUtil(const vector<MutableFstPair> &fst_pairs,
67 Label root_label, bool epsilon_on_replace = false);
68
69 // Constructs from Fsts; Fst ownership retained by caller.
70 ReplaceUtil(const vector<FstPair> &fst_pairs,
71 Label root_label, bool epsilon_on_replace = false);
72
73 // Constructs from ReplaceFst internals; ownership retained by caller.
74 ReplaceUtil(const vector<const Fst<Arc> *> &fst_array,
75 const NonTerminalHash &nonterminal_hash, Label root_fst,
76 bool epsilon_on_replace = false);
77
~ReplaceUtil()78 ~ReplaceUtil() {
79 for (Label i = 0; i < fst_array_.size(); ++i)
80 delete fst_array_[i];
81 }
82
83 // True if the non-terminal dependencies are cyclic. Cyclic
84 // dependencies will result in an unexpandable replace fst.
CyclicDependencies()85 bool CyclicDependencies() const {
86 GetDependencies(false);
87 return depprops_ & kCyclic;
88 }
89
90 // Returns true if no useless Fsts, states or transitions.
Connected()91 bool Connected() const {
92 GetDependencies(false);
93 uint64 props = kAccessible | kCoAccessible;
94 for (Label i = 0; i < fst_array_.size(); ++i) {
95 if (!fst_array_[i])
96 continue;
97 if (fst_array_[i]->Properties(props, true) != props || !depaccess_[i])
98 return false;
99 }
100 return true;
101 }
102
103 // Removes useless Fsts, states and transitions.
104 void Connect();
105
106 // Replaces Fsts specified by labels.
107 // Does nothing if there are cyclic dependencies.
108 void ReplaceLabels(const vector<Label> &labels);
109
110 // Replaces Fsts that have at most 'nstates' states, 'narcs' arcs and
111 // 'nnonterm' non-terminals (updating in reverse dependency order).
112 // Does nothing if there are cyclic dependencies.
113 void ReplaceBySize(size_t nstates, size_t narcs, size_t nnonterms);
114
115 // Replaces singleton Fsts.
116 // Does nothing if there are cyclic dependencies.
ReplaceTrivial()117 void ReplaceTrivial() { ReplaceBySize(2, 1, 1); }
118
119 // Replaces non-terminals that have at most 'ninstances' instances
120 // (updating in dependency order).
121 // Does nothing if there are cyclic dependencies.
122 void ReplaceByInstances(size_t ninstances);
123
124 // Replaces non-terminals that have only one instance.
125 // Does nothing if there are cyclic dependencies.
ReplaceUnique()126 void ReplaceUnique() { ReplaceByInstances(1); }
127
128 // Returns Label, Fst pairs; Fst ownership retained by ReplaceUtil.
129 void GetFstPairs(vector<FstPair> *fst_pairs);
130
131 // Returns Label, MutableFst pairs; Fst ownership given to caller.
132 void GetMutableFstPairs(vector<MutableFstPair> *mutable_fst_pairs);
133
134 private:
135 // Per Fst statistics
136 struct ReplaceStats {
137 StateId nstates; // # of states
138 StateId nfinal; // # of final states
139 size_t narcs; // # of arcs
140 Label nnonterms; // # of non-terminals in Fst
141 size_t nref; // # of non-terminal instances referring to this Fst
142
143 // # of times that ith Fst references this Fst
144 map<Label, size_t> inref;
145 // # of times that this Fst references the ith Fst
146 map<Label, size_t> outref;
147
ReplaceStatsReplaceStats148 ReplaceStats()
149 : nstates(0),
150 nfinal(0),
151 narcs(0),
152 nnonterms(0),
153 nref(0) {}
154 };
155
156 // Check Mutable Fsts exist o.w. create them.
157 void CheckMutableFsts();
158
159 // Computes the dependency graph of the replace Fsts.
160 // If 'stats' is true, dependency statistics computed as well.
161 void GetDependencies(bool stats) const;
162
ClearDependencies()163 void ClearDependencies() const {
164 depfst_.DeleteStates();
165 stats_.clear();
166 depprops_ = 0;
167 have_stats_ = false;
168 }
169
170 // Get topological order of dependencies. Returns false with cyclic input.
171 bool GetTopOrder(const Fst<Arc> &fst, vector<Label> *toporder) const;
172
173 // Update statistics assuming that jth Fst will be replaced.
174 void UpdateStats(Label j);
175
176 Label root_label_; // root non-terminal
177 Label root_fst_; // root Fst ID
178 bool epsilon_on_replace_; // see Replace()
179 vector<const Fst<Arc> *> fst_array_; // Fst per ID
180 vector<MutableFst<Arc> *> mutable_fst_array_; // MutableFst per ID
181 vector<Label> nonterminal_array_; // Fst ID to non-terminal
182 NonTerminalHash nonterminal_hash_; // non-terminal to Fst ID
183 mutable VectorFst<Arc> depfst_; // Fst ID dependencies
184 mutable vector<bool> depaccess_; // Fst ID accessibility
185 mutable uint64 depprops_; // dependency Fst props
186 mutable bool have_stats_; // have dependency statistics
187 mutable vector<ReplaceStats> stats_; // Per Fst statistics
188 DISALLOW_COPY_AND_ASSIGN(ReplaceUtil);
189 };
190
191 template <class Arc>
ReplaceUtil(const vector<MutableFstPair> & fst_pairs,Label root_label,bool epsilon_on_replace)192 ReplaceUtil<Arc>::ReplaceUtil(
193 const vector<MutableFstPair> &fst_pairs,
194 Label root_label, bool epsilon_on_replace)
195 : root_label_(root_label),
196 epsilon_on_replace_(epsilon_on_replace),
197 depprops_(0),
198 have_stats_(false) {
199 fst_array_.push_back(0);
200 mutable_fst_array_.push_back(0);
201 nonterminal_array_.push_back(kNoLabel);
202 for (Label i = 0; i < fst_pairs.size(); ++i) {
203 Label label = fst_pairs[i].first;
204 MutableFst<Arc> *fst = fst_pairs[i].second;
205 nonterminal_hash_[label] = fst_array_.size();
206 nonterminal_array_.push_back(label);
207 fst_array_.push_back(fst);
208 mutable_fst_array_.push_back(fst);
209 }
210 root_fst_ = nonterminal_hash_[root_label_];
211 if (!root_fst_)
212 FSTERROR() << "ReplaceUtil: no root FST for label: " << root_label_;
213 }
214
215 template <class Arc>
ReplaceUtil(const vector<FstPair> & fst_pairs,Label root_label,bool epsilon_on_replace)216 ReplaceUtil<Arc>::ReplaceUtil(
217 const vector<FstPair> &fst_pairs,
218 Label root_label, bool epsilon_on_replace)
219 : root_label_(root_label),
220 epsilon_on_replace_(epsilon_on_replace),
221 depprops_(0),
222 have_stats_(false) {
223 fst_array_.push_back(0);
224 nonterminal_array_.push_back(kNoLabel);
225 for (Label i = 0; i < fst_pairs.size(); ++i) {
226 Label label = fst_pairs[i].first;
227 const Fst<Arc> *fst = fst_pairs[i].second;
228 nonterminal_hash_[label] = fst_array_.size();
229 nonterminal_array_.push_back(label);
230 fst_array_.push_back(fst->Copy());
231 }
232 root_fst_ = nonterminal_hash_[root_label];
233 if (!root_fst_)
234 FSTERROR() << "ReplaceUtil: no root FST for label: " << root_label_;
235 }
236
237 template <class Arc>
ReplaceUtil(const vector<const Fst<Arc> * > & fst_array,const NonTerminalHash & nonterminal_hash,Label root_fst,bool epsilon_on_replace)238 ReplaceUtil<Arc>::ReplaceUtil(
239 const vector<const Fst<Arc> *> &fst_array,
240 const NonTerminalHash &nonterminal_hash, Label root_fst,
241 bool epsilon_on_replace)
242 : root_fst_(root_fst),
243 epsilon_on_replace_(epsilon_on_replace),
244 nonterminal_array_(fst_array.size()),
245 nonterminal_hash_(nonterminal_hash),
246 depprops_(0),
247 have_stats_(false) {
248 fst_array_.push_back(0);
249 for (Label i = 1; i < fst_array.size(); ++i)
250 fst_array_.push_back(fst_array[i]->Copy());
251 for (typename NonTerminalHash::const_iterator it =
252 nonterminal_hash.begin(); it != nonterminal_hash.end(); ++it)
253 nonterminal_array_[it->second] = it->first;
254 root_label_ = nonterminal_array_[root_fst_];
255 }
256
257 template <class Arc>
GetDependencies(bool stats)258 void ReplaceUtil<Arc>::GetDependencies(bool stats) const {
259 if (depfst_.NumStates() > 0) {
260 if (stats && !have_stats_)
261 ClearDependencies();
262 else
263 return;
264 }
265
266 have_stats_ = stats;
267 if (have_stats_)
268 stats_.reserve(fst_array_.size());
269
270 for (Label i = 0; i < fst_array_.size(); ++i) {
271 depfst_.AddState();
272 depfst_.SetFinal(i, Weight::One());
273 if (have_stats_)
274 stats_.push_back(ReplaceStats());
275 }
276 depfst_.SetStart(root_fst_);
277
278 // An arc from each state (representing the fst) to the
279 // state representing the fst being replaced
280 for (Label i = 0; i < fst_array_.size(); ++i) {
281 const Fst<Arc> *ifst = fst_array_[i];
282 if (!ifst)
283 continue;
284 for (StateIterator<Fst<Arc> > siter(*ifst); !siter.Done(); siter.Next()) {
285 StateId s = siter.Value();
286 if (have_stats_) {
287 ++stats_[i].nstates;
288 if (ifst->Final(s) != Weight::Zero())
289 ++stats_[i].nfinal;
290 }
291 for (ArcIterator<Fst<Arc> > aiter(*ifst, s);
292 !aiter.Done(); aiter.Next()) {
293 if (have_stats_)
294 ++stats_[i].narcs;
295 const Arc& arc = aiter.Value();
296
297 typename NonTerminalHash::const_iterator it =
298 nonterminal_hash_.find(arc.olabel);
299 if (it != nonterminal_hash_.end()) {
300 Label j = it->second;
301 depfst_.AddArc(i, Arc(arc.olabel, arc.olabel, Weight::One(), j));
302 if (have_stats_) {
303 ++stats_[i].nnonterms;
304 ++stats_[j].nref;
305 ++stats_[j].inref[i];
306 ++stats_[i].outref[j];
307 }
308 }
309 }
310 }
311 }
312
313 // Gets accessibility info
314 SccVisitor<Arc> scc_visitor(0, &depaccess_, 0, &depprops_);
315 DfsVisit(depfst_, &scc_visitor);
316 }
317
318 template <class Arc>
UpdateStats(Label j)319 void ReplaceUtil<Arc>::UpdateStats(Label j) {
320 if (!have_stats_) {
321 FSTERROR() << "ReplaceUtil::UpdateStats: stats not available";
322 return;
323 }
324
325 if (j == root_fst_) // can't replace root
326 return;
327
328 typedef typename map<Label, size_t>::iterator Iter;
329 for (Iter in = stats_[j].inref.begin();
330 in != stats_[j].inref.end();
331 ++in) {
332 Label i = in->first;
333 size_t ni = in->second;
334 stats_[i].nstates += stats_[j].nstates * ni;
335 stats_[i].narcs += (stats_[j].narcs + 1) * ni; // narcs - 1 + 2 (eps)
336 stats_[i].nnonterms += (stats_[j].nnonterms - 1) * ni;
337 stats_[i].outref.erase(stats_[i].outref.find(j));
338 for (Iter out = stats_[j].outref.begin();
339 out != stats_[j].outref.end();
340 ++out) {
341 Label k = out->first;
342 size_t nk = out->second;
343 stats_[i].outref[k] += ni * nk;
344 }
345 }
346
347 for (Iter out = stats_[j].outref.begin();
348 out != stats_[j].outref.end();
349 ++out) {
350 Label k = out->first;
351 size_t nk = out->second;
352 stats_[k].nref -= nk;
353 stats_[k].inref.erase(stats_[k].inref.find(j));
354 for (Iter in = stats_[j].inref.begin();
355 in != stats_[j].inref.end();
356 ++in) {
357 Label i = in->first;
358 size_t ni = in->second;
359 stats_[k].inref[i] += ni * nk;
360 stats_[k].nref += ni * nk;
361 }
362 }
363 }
364
365 template <class Arc>
CheckMutableFsts()366 void ReplaceUtil<Arc>::CheckMutableFsts() {
367 if (mutable_fst_array_.size() == 0) {
368 for (Label i = 0; i < fst_array_.size(); ++i) {
369 if (!fst_array_[i]) {
370 mutable_fst_array_.push_back(0);
371 } else {
372 mutable_fst_array_.push_back(new VectorFst<Arc>(*fst_array_[i]));
373 delete fst_array_[i];
374 fst_array_[i] = mutable_fst_array_[i];
375 }
376 }
377 }
378 }
379
380 template <class Arc>
Connect()381 void ReplaceUtil<Arc>::Connect() {
382 CheckMutableFsts();
383 uint64 props = kAccessible | kCoAccessible;
384 for (Label i = 0; i < mutable_fst_array_.size(); ++i) {
385 if (!mutable_fst_array_[i])
386 continue;
387 if (mutable_fst_array_[i]->Properties(props, false) != props)
388 fst::Connect(mutable_fst_array_[i]);
389 }
390 GetDependencies(false);
391 for (Label i = 0; i < mutable_fst_array_.size(); ++i) {
392 MutableFst<Arc> *fst = mutable_fst_array_[i];
393 if (fst && !depaccess_[i]) {
394 delete fst;
395 fst_array_[i] = 0;
396 mutable_fst_array_[i] = 0;
397 }
398 }
399 ClearDependencies();
400 }
401
402 template <class Arc>
GetTopOrder(const Fst<Arc> & fst,vector<Label> * toporder)403 bool ReplaceUtil<Arc>::GetTopOrder(const Fst<Arc> &fst,
404 vector<Label> *toporder) const {
405 // Finds topological order of dependencies.
406 vector<StateId> order;
407 bool acyclic = false;
408
409 TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
410 DfsVisit(fst, &top_order_visitor);
411 if (!acyclic) {
412 LOG(WARNING) << "ReplaceUtil::GetTopOrder: Cyclical label dependencies";
413 return false;
414 }
415
416 toporder->resize(order.size());
417 for (Label i = 0; i < order.size(); ++i)
418 (*toporder)[order[i]] = i;
419
420 return true;
421 }
422
423 template <class Arc>
ReplaceLabels(const vector<Label> & labels)424 void ReplaceUtil<Arc>::ReplaceLabels(const vector<Label> &labels) {
425 CheckMutableFsts();
426 unordered_set<Label> label_set;
427 for (Label i = 0; i < labels.size(); ++i)
428 if (labels[i] != root_label_) // can't replace root
429 label_set.insert(labels[i]);
430
431 // Finds Fst dependencies restricted to the labels requested.
432 GetDependencies(false);
433 VectorFst<Arc> pfst(depfst_);
434 for (StateId i = 0; i < pfst.NumStates(); ++i) {
435 vector<Arc> arcs;
436 for (ArcIterator< VectorFst<Arc> > aiter(pfst, i);
437 !aiter.Done(); aiter.Next()) {
438 const Arc &arc = aiter.Value();
439 Label label = nonterminal_array_[arc.nextstate];
440 if (label_set.count(label) > 0)
441 arcs.push_back(arc);
442 }
443 pfst.DeleteArcs(i);
444 for (size_t j = 0; j < arcs.size(); ++j)
445 pfst.AddArc(i, arcs[j]);
446 }
447
448 vector<Label> toporder;
449 if (!GetTopOrder(pfst, &toporder)) {
450 ClearDependencies();
451 return;
452 }
453
454 // Visits Fsts in reverse topological order of dependencies and
455 // performs replacements.
456 for (Label o = toporder.size() - 1; o >= 0; --o) {
457 vector<FstPair> fst_pairs;
458 StateId s = toporder[o];
459 for (ArcIterator< VectorFst<Arc> > aiter(pfst, s);
460 !aiter.Done(); aiter.Next()) {
461 const Arc &arc = aiter.Value();
462 Label label = nonterminal_array_[arc.nextstate];
463 const Fst<Arc> *fst = fst_array_[arc.nextstate];
464 fst_pairs.push_back(make_pair(label, fst));
465 }
466 if (fst_pairs.empty())
467 continue;
468 Label label = nonterminal_array_[s];
469 const Fst<Arc> *fst = fst_array_[s];
470 fst_pairs.push_back(make_pair(label, fst));
471
472 Replace(fst_pairs, mutable_fst_array_[s], label, epsilon_on_replace_);
473 }
474 ClearDependencies();
475 }
476
477 template <class Arc>
ReplaceBySize(size_t nstates,size_t narcs,size_t nnonterms)478 void ReplaceUtil<Arc>::ReplaceBySize(size_t nstates, size_t narcs,
479 size_t nnonterms) {
480 vector<Label> labels;
481 GetDependencies(true);
482
483 vector<Label> toporder;
484 if (!GetTopOrder(depfst_, &toporder)) {
485 ClearDependencies();
486 return;
487 }
488
489 for (Label o = toporder.size() - 1; o >= 0; --o) {
490 Label j = toporder[o];
491 if (stats_[j].nstates <= nstates &&
492 stats_[j].narcs <= narcs &&
493 stats_[j].nnonterms <= nnonterms) {
494 labels.push_back(nonterminal_array_[j]);
495 UpdateStats(j);
496 }
497 }
498 ReplaceLabels(labels);
499 }
500
501 template <class Arc>
ReplaceByInstances(size_t ninstances)502 void ReplaceUtil<Arc>::ReplaceByInstances(size_t ninstances) {
503 vector<Label> labels;
504 GetDependencies(true);
505
506 vector<Label> toporder;
507 if (!GetTopOrder(depfst_, &toporder)) {
508 ClearDependencies();
509 return;
510 }
511 for (Label o = 0; o < toporder.size(); ++o) {
512 Label j = toporder[o];
513 if (stats_[j].nref <= ninstances) {
514 labels.push_back(nonterminal_array_[j]);
515 UpdateStats(j);
516 }
517 }
518 ReplaceLabels(labels);
519 }
520
521 template <class Arc>
GetFstPairs(vector<FstPair> * fst_pairs)522 void ReplaceUtil<Arc>::GetFstPairs(vector<FstPair> *fst_pairs) {
523 CheckMutableFsts();
524 fst_pairs->clear();
525 for (Label i = 0; i < fst_array_.size(); ++i) {
526 Label label = nonterminal_array_[i];
527 const Fst<Arc> *fst = fst_array_[i];
528 if (!fst)
529 continue;
530 fst_pairs->push_back(make_pair(label, fst));
531 }
532 }
533
534 template <class Arc>
GetMutableFstPairs(vector<MutableFstPair> * mutable_fst_pairs)535 void ReplaceUtil<Arc>::GetMutableFstPairs(
536 vector<MutableFstPair> *mutable_fst_pairs) {
537 CheckMutableFsts();
538 mutable_fst_pairs->clear();
539 for (Label i = 0; i < mutable_fst_array_.size(); ++i) {
540 Label label = nonterminal_array_[i];
541 MutableFst<Arc> *fst = mutable_fst_array_[i];
542 if (!fst)
543 continue;
544 mutable_fst_pairs->push_back(make_pair(label, fst->Copy()));
545 }
546 }
547
548 } // namespace fst
549
550 #endif // FST_LIB_REPLACE_UTIL_H__
551