• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // replace-util.h
2 
3 
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 // Copyright 2005-2010 Google, Inc.
17 // Author: riley@google.com (Michael Riley)
18 //
19 
20 // \file
21 // Utility classes for the recursive replacement of Fsts (RTNs).
22 
23 #ifndef FST_LIB_REPLACE_UTIL_H__
24 #define FST_LIB_REPLACE_UTIL_H__
25 
26 #include <vector>
27 using std::vector;
28 #include <tr1/unordered_map>
29 using std::tr1::unordered_map;
30 using std::tr1::unordered_multimap;
31 #include <tr1/unordered_set>
32 using std::tr1::unordered_set;
33 using std::tr1::unordered_multiset;
34 #include <map>
35 
36 #include <fst/connect.h>
37 #include <fst/mutable-fst.h>
38 #include <fst/topsort.h>
39 
40 
41 namespace fst {
42 
43 template <class Arc>
44 void Replace(const vector<pair<typename Arc::Label, const Fst<Arc>* > >&,
45              MutableFst<Arc> *, typename Arc::Label, bool);
46 
47 
48 // Utility class for the recursive replacement of Fsts (RTNs). The
49 // user provides a set of Label, Fst pairs at construction. These are
50 // used by methods for testing cyclic dependencies and connectedness
51 // and doing RTN connection and specific Fst replacement by label or
52 // for various optimization properties. The modified results can be
53 // obtained with the GetFstPairs() or GetMutableFstPairs() methods.
54 template <class Arc>
55 class ReplaceUtil {
56  public:
57   typedef typename Arc::Label Label;
58   typedef typename Arc::Weight Weight;
59   typedef typename Arc::StateId StateId;
60 
61   typedef pair<Label, const Fst<Arc>*> FstPair;
62   typedef pair<Label, MutableFst<Arc>*> MutableFstPair;
63   typedef unordered_map<Label, Label> NonTerminalHash;
64 
65   // Constructs from mutable Fsts; Fst ownership given to ReplaceUtil.
66   ReplaceUtil(const vector<MutableFstPair> &fst_pairs,
67               Label root_label, bool epsilon_on_replace = false);
68 
69   // Constructs from Fsts; Fst ownership retained by caller.
70   ReplaceUtil(const vector<FstPair> &fst_pairs,
71               Label root_label, bool epsilon_on_replace = false);
72 
73   // Constructs from ReplaceFst internals; ownership retained by caller.
74   ReplaceUtil(const vector<const Fst<Arc> *> &fst_array,
75               const NonTerminalHash &nonterminal_hash, Label root_fst,
76               bool epsilon_on_replace = false);
77 
~ReplaceUtil()78   ~ReplaceUtil() {
79     for (Label i = 0; i < fst_array_.size(); ++i)
80       delete fst_array_[i];
81   }
82 
83   // True if the non-terminal dependencies are cyclic. Cyclic
84   // dependencies will result in an unexpandable replace fst.
CyclicDependencies()85   bool CyclicDependencies() const {
86     GetDependencies(false);
87     return depprops_ & kCyclic;
88   }
89 
90   // Returns true if no useless Fsts, states or transitions.
Connected()91   bool Connected() const {
92     GetDependencies(false);
93     uint64 props = kAccessible | kCoAccessible;
94     for (Label i = 0; i < fst_array_.size(); ++i) {
95       if (!fst_array_[i])
96         continue;
97       if (fst_array_[i]->Properties(props, true) != props || !depaccess_[i])
98         return false;
99     }
100     return true;
101   }
102 
103   // Removes useless Fsts, states and transitions.
104   void Connect();
105 
106   // Replaces Fsts specified by labels.
107   // Does nothing if there are cyclic dependencies.
108   void ReplaceLabels(const vector<Label> &labels);
109 
110   // Replaces Fsts that have at most 'nstates' states, 'narcs' arcs and
111   // 'nnonterm' non-terminals (updating in reverse dependency order).
112   // Does nothing if there are cyclic dependencies.
113   void ReplaceBySize(size_t nstates, size_t narcs, size_t nnonterms);
114 
115   // Replaces singleton Fsts.
116   // Does nothing if there are cyclic dependencies.
ReplaceTrivial()117   void ReplaceTrivial() { ReplaceBySize(2, 1, 1); }
118 
119   // Replaces non-terminals that have at most 'ninstances' instances
120   // (updating in dependency order).
121   // Does nothing if there are cyclic dependencies.
122   void ReplaceByInstances(size_t ninstances);
123 
124   // Replaces non-terminals that have only one instance.
125   // Does nothing if there are cyclic dependencies.
ReplaceUnique()126   void ReplaceUnique() { ReplaceByInstances(1); }
127 
128   // Returns Label, Fst pairs; Fst ownership retained by ReplaceUtil.
129   void GetFstPairs(vector<FstPair> *fst_pairs);
130 
131   // Returns Label, MutableFst pairs; Fst ownership given to caller.
132   void GetMutableFstPairs(vector<MutableFstPair> *mutable_fst_pairs);
133 
134  private:
135   // Per Fst statistics
136   struct ReplaceStats {
137     StateId nstates;    // # of states
138     StateId nfinal;     // # of final states
139     size_t narcs;       // # of arcs
140     Label nnonterms;    // # of non-terminals in Fst
141     size_t nref;        // # of non-terminal instances referring to this Fst
142 
143     // # of times that ith Fst references this Fst
144     map<Label, size_t> inref;
145     // # of times that this Fst references the ith Fst
146     map<Label, size_t> outref;
147 
ReplaceStatsReplaceStats148     ReplaceStats()
149         : nstates(0),
150           nfinal(0),
151           narcs(0),
152           nnonterms(0),
153           nref(0) {}
154   };
155 
156   // Check Mutable Fsts exist o.w. create them.
157   void CheckMutableFsts();
158 
159   // Computes the dependency graph of the replace Fsts.
160   // If 'stats' is true, dependency statistics computed as well.
161   void GetDependencies(bool stats) const;
162 
ClearDependencies()163   void ClearDependencies() const {
164     depfst_.DeleteStates();
165     stats_.clear();
166     depprops_ = 0;
167     have_stats_ = false;
168   }
169 
170   // Get topological order of dependencies. Returns false with cyclic input.
171   bool GetTopOrder(const Fst<Arc> &fst, vector<Label> *toporder) const;
172 
173   // Update statistics assuming that jth Fst will be replaced.
174   void UpdateStats(Label j);
175 
176   Label root_label_;                              // root non-terminal
177   Label root_fst_;                                // root Fst ID
178   bool epsilon_on_replace_;                       // see Replace()
179   vector<const Fst<Arc> *> fst_array_;            // Fst per ID
180   vector<MutableFst<Arc> *> mutable_fst_array_;   // MutableFst per ID
181   vector<Label> nonterminal_array_;               // Fst ID to non-terminal
182   NonTerminalHash nonterminal_hash_;              // non-terminal to Fst ID
183   mutable VectorFst<Arc> depfst_;                 // Fst ID dependencies
184   mutable vector<bool> depaccess_;                // Fst ID accessibility
185   mutable uint64 depprops_;                       // dependency Fst props
186   mutable bool have_stats_;                       // have dependency statistics
187   mutable vector<ReplaceStats> stats_;            // Per Fst statistics
188   DISALLOW_COPY_AND_ASSIGN(ReplaceUtil);
189 };
190 
191 template <class Arc>
ReplaceUtil(const vector<MutableFstPair> & fst_pairs,Label root_label,bool epsilon_on_replace)192 ReplaceUtil<Arc>::ReplaceUtil(
193     const vector<MutableFstPair> &fst_pairs,
194     Label root_label, bool epsilon_on_replace)
195     : root_label_(root_label),
196       epsilon_on_replace_(epsilon_on_replace),
197       depprops_(0),
198       have_stats_(false) {
199   fst_array_.push_back(0);
200   mutable_fst_array_.push_back(0);
201   nonterminal_array_.push_back(kNoLabel);
202   for (Label i = 0; i < fst_pairs.size(); ++i) {
203     Label label = fst_pairs[i].first;
204     MutableFst<Arc> *fst = fst_pairs[i].second;
205     nonterminal_hash_[label] = fst_array_.size();
206     nonterminal_array_.push_back(label);
207     fst_array_.push_back(fst);
208     mutable_fst_array_.push_back(fst);
209   }
210   root_fst_ = nonterminal_hash_[root_label_];
211   if (!root_fst_)
212     FSTERROR() << "ReplaceUtil: no root FST for label: " << root_label_;
213 }
214 
215 template <class Arc>
ReplaceUtil(const vector<FstPair> & fst_pairs,Label root_label,bool epsilon_on_replace)216 ReplaceUtil<Arc>::ReplaceUtil(
217     const vector<FstPair> &fst_pairs,
218     Label root_label, bool epsilon_on_replace)
219     : root_label_(root_label),
220       epsilon_on_replace_(epsilon_on_replace),
221       depprops_(0),
222       have_stats_(false) {
223   fst_array_.push_back(0);
224   nonterminal_array_.push_back(kNoLabel);
225   for (Label i = 0; i < fst_pairs.size(); ++i) {
226     Label label = fst_pairs[i].first;
227     const Fst<Arc> *fst = fst_pairs[i].second;
228     nonterminal_hash_[label] = fst_array_.size();
229     nonterminal_array_.push_back(label);
230     fst_array_.push_back(fst->Copy());
231   }
232   root_fst_ = nonterminal_hash_[root_label];
233   if (!root_fst_)
234     FSTERROR() << "ReplaceUtil: no root FST for label: " << root_label_;
235 }
236 
237 template <class Arc>
ReplaceUtil(const vector<const Fst<Arc> * > & fst_array,const NonTerminalHash & nonterminal_hash,Label root_fst,bool epsilon_on_replace)238 ReplaceUtil<Arc>::ReplaceUtil(
239     const vector<const Fst<Arc> *> &fst_array,
240     const NonTerminalHash &nonterminal_hash, Label root_fst,
241     bool epsilon_on_replace)
242     : root_fst_(root_fst),
243       epsilon_on_replace_(epsilon_on_replace),
244       nonterminal_array_(fst_array.size()),
245       nonterminal_hash_(nonterminal_hash),
246       depprops_(0),
247       have_stats_(false) {
248   fst_array_.push_back(0);
249   for (Label i = 1; i < fst_array.size(); ++i)
250     fst_array_.push_back(fst_array[i]->Copy());
251   for (typename NonTerminalHash::const_iterator it =
252            nonterminal_hash.begin(); it != nonterminal_hash.end(); ++it)
253     nonterminal_array_[it->second] = it->first;
254   root_label_ = nonterminal_array_[root_fst_];
255 }
256 
257 template <class Arc>
GetDependencies(bool stats)258 void ReplaceUtil<Arc>::GetDependencies(bool stats) const {
259   if (depfst_.NumStates() > 0) {
260     if (stats && !have_stats_)
261       ClearDependencies();
262     else
263       return;
264   }
265 
266   have_stats_ = stats;
267   if (have_stats_)
268     stats_.reserve(fst_array_.size());
269 
270   for (Label i = 0; i < fst_array_.size(); ++i) {
271     depfst_.AddState();
272     depfst_.SetFinal(i, Weight::One());
273     if (have_stats_)
274       stats_.push_back(ReplaceStats());
275   }
276   depfst_.SetStart(root_fst_);
277 
278   // An arc from each state (representing the fst) to the
279   // state representing the fst being replaced
280   for (Label i = 0; i < fst_array_.size(); ++i) {
281     const Fst<Arc> *ifst = fst_array_[i];
282     if (!ifst)
283       continue;
284     for (StateIterator<Fst<Arc> > siter(*ifst); !siter.Done(); siter.Next()) {
285       StateId s = siter.Value();
286       if (have_stats_) {
287         ++stats_[i].nstates;
288         if (ifst->Final(s) != Weight::Zero())
289           ++stats_[i].nfinal;
290       }
291       for (ArcIterator<Fst<Arc> > aiter(*ifst, s);
292            !aiter.Done(); aiter.Next()) {
293         if (have_stats_)
294           ++stats_[i].narcs;
295         const Arc& arc = aiter.Value();
296 
297         typename NonTerminalHash::const_iterator it =
298             nonterminal_hash_.find(arc.olabel);
299         if (it != nonterminal_hash_.end()) {
300           Label j = it->second;
301           depfst_.AddArc(i, Arc(arc.olabel, arc.olabel, Weight::One(), j));
302           if (have_stats_) {
303             ++stats_[i].nnonterms;
304             ++stats_[j].nref;
305             ++stats_[j].inref[i];
306             ++stats_[i].outref[j];
307           }
308         }
309       }
310     }
311   }
312 
313   // Gets accessibility info
314   SccVisitor<Arc> scc_visitor(0, &depaccess_, 0, &depprops_);
315   DfsVisit(depfst_, &scc_visitor);
316 }
317 
318 template <class Arc>
UpdateStats(Label j)319 void ReplaceUtil<Arc>::UpdateStats(Label j) {
320   if (!have_stats_) {
321     FSTERROR() << "ReplaceUtil::UpdateStats: stats not available";
322     return;
323   }
324 
325   if (j == root_fst_)  // can't replace root
326     return;
327 
328   typedef typename map<Label, size_t>::iterator Iter;
329   for (Iter in = stats_[j].inref.begin();
330        in != stats_[j].inref.end();
331        ++in) {
332     Label i = in->first;
333     size_t ni = in->second;
334     stats_[i].nstates += stats_[j].nstates * ni;
335     stats_[i].narcs += (stats_[j].narcs + 1) * ni;  // narcs - 1 + 2 (eps)
336     stats_[i].nnonterms += (stats_[j].nnonterms - 1) * ni;
337     stats_[i].outref.erase(stats_[i].outref.find(j));
338     for (Iter out = stats_[j].outref.begin();
339          out != stats_[j].outref.end();
340          ++out) {
341       Label k = out->first;
342       size_t nk = out->second;
343       stats_[i].outref[k] += ni * nk;
344     }
345   }
346 
347   for (Iter out = stats_[j].outref.begin();
348        out != stats_[j].outref.end();
349        ++out) {
350     Label k = out->first;
351     size_t nk = out->second;
352     stats_[k].nref -= nk;
353     stats_[k].inref.erase(stats_[k].inref.find(j));
354     for (Iter in = stats_[j].inref.begin();
355          in != stats_[j].inref.end();
356          ++in) {
357       Label i = in->first;
358       size_t ni = in->second;
359       stats_[k].inref[i] += ni * nk;
360       stats_[k].nref += ni * nk;
361     }
362   }
363 }
364 
365 template <class Arc>
CheckMutableFsts()366 void ReplaceUtil<Arc>::CheckMutableFsts() {
367   if (mutable_fst_array_.size() == 0) {
368     for (Label i = 0; i < fst_array_.size(); ++i) {
369       if (!fst_array_[i]) {
370         mutable_fst_array_.push_back(0);
371       } else {
372         mutable_fst_array_.push_back(new VectorFst<Arc>(*fst_array_[i]));
373         delete fst_array_[i];
374         fst_array_[i] = mutable_fst_array_[i];
375       }
376     }
377   }
378 }
379 
380 template <class Arc>
Connect()381 void ReplaceUtil<Arc>::Connect() {
382   CheckMutableFsts();
383   uint64 props = kAccessible | kCoAccessible;
384   for (Label i = 0; i < mutable_fst_array_.size(); ++i) {
385     if (!mutable_fst_array_[i])
386       continue;
387     if (mutable_fst_array_[i]->Properties(props, false) != props)
388       fst::Connect(mutable_fst_array_[i]);
389   }
390   GetDependencies(false);
391   for (Label i = 0; i < mutable_fst_array_.size(); ++i) {
392     MutableFst<Arc> *fst = mutable_fst_array_[i];
393     if (fst && !depaccess_[i]) {
394       delete fst;
395       fst_array_[i] = 0;
396       mutable_fst_array_[i] = 0;
397     }
398   }
399   ClearDependencies();
400 }
401 
402 template <class Arc>
GetTopOrder(const Fst<Arc> & fst,vector<Label> * toporder)403 bool ReplaceUtil<Arc>::GetTopOrder(const Fst<Arc> &fst,
404                                    vector<Label> *toporder) const {
405   // Finds topological order of dependencies.
406   vector<StateId> order;
407   bool acyclic = false;
408 
409   TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
410   DfsVisit(fst, &top_order_visitor);
411   if (!acyclic) {
412     LOG(WARNING) << "ReplaceUtil::GetTopOrder: Cyclical label dependencies";
413     return false;
414   }
415 
416   toporder->resize(order.size());
417   for (Label i = 0; i < order.size(); ++i)
418     (*toporder)[order[i]] = i;
419 
420   return true;
421 }
422 
423 template <class Arc>
ReplaceLabels(const vector<Label> & labels)424 void ReplaceUtil<Arc>::ReplaceLabels(const vector<Label> &labels) {
425   CheckMutableFsts();
426   unordered_set<Label> label_set;
427   for (Label i = 0; i < labels.size(); ++i)
428     if (labels[i] != root_label_)  // can't replace root
429       label_set.insert(labels[i]);
430 
431   // Finds Fst dependencies restricted to the labels requested.
432   GetDependencies(false);
433   VectorFst<Arc> pfst(depfst_);
434   for (StateId i = 0; i < pfst.NumStates(); ++i) {
435     vector<Arc> arcs;
436     for (ArcIterator< VectorFst<Arc> > aiter(pfst, i);
437          !aiter.Done(); aiter.Next()) {
438       const Arc &arc = aiter.Value();
439       Label label = nonterminal_array_[arc.nextstate];
440       if (label_set.count(label) > 0)
441         arcs.push_back(arc);
442     }
443     pfst.DeleteArcs(i);
444     for (size_t j = 0; j < arcs.size(); ++j)
445       pfst.AddArc(i, arcs[j]);
446   }
447 
448   vector<Label> toporder;
449   if (!GetTopOrder(pfst, &toporder)) {
450     ClearDependencies();
451     return;
452   }
453 
454   // Visits Fsts in reverse topological order of dependencies and
455   // performs replacements.
456   for (Label o = toporder.size() - 1; o >= 0;  --o) {
457     vector<FstPair> fst_pairs;
458     StateId s = toporder[o];
459     for (ArcIterator< VectorFst<Arc> > aiter(pfst, s);
460          !aiter.Done(); aiter.Next()) {
461       const Arc &arc = aiter.Value();
462       Label label = nonterminal_array_[arc.nextstate];
463       const Fst<Arc> *fst = fst_array_[arc.nextstate];
464       fst_pairs.push_back(make_pair(label, fst));
465     }
466     if (fst_pairs.empty())
467         continue;
468     Label label = nonterminal_array_[s];
469     const Fst<Arc> *fst = fst_array_[s];
470     fst_pairs.push_back(make_pair(label, fst));
471 
472     Replace(fst_pairs, mutable_fst_array_[s], label, epsilon_on_replace_);
473   }
474   ClearDependencies();
475 }
476 
477 template <class Arc>
ReplaceBySize(size_t nstates,size_t narcs,size_t nnonterms)478 void ReplaceUtil<Arc>::ReplaceBySize(size_t nstates, size_t narcs,
479                                      size_t nnonterms) {
480   vector<Label> labels;
481   GetDependencies(true);
482 
483   vector<Label> toporder;
484   if (!GetTopOrder(depfst_, &toporder)) {
485     ClearDependencies();
486     return;
487   }
488 
489   for (Label o = toporder.size() - 1; o >= 0; --o) {
490     Label j = toporder[o];
491     if (stats_[j].nstates <= nstates &&
492         stats_[j].narcs <= narcs &&
493         stats_[j].nnonterms <= nnonterms) {
494       labels.push_back(nonterminal_array_[j]);
495       UpdateStats(j);
496     }
497   }
498   ReplaceLabels(labels);
499 }
500 
501 template <class Arc>
ReplaceByInstances(size_t ninstances)502 void ReplaceUtil<Arc>::ReplaceByInstances(size_t ninstances) {
503   vector<Label> labels;
504   GetDependencies(true);
505 
506   vector<Label> toporder;
507   if (!GetTopOrder(depfst_, &toporder)) {
508     ClearDependencies();
509     return;
510   }
511   for (Label o = 0; o < toporder.size(); ++o) {
512     Label j = toporder[o];
513     if (stats_[j].nref <= ninstances) {
514       labels.push_back(nonterminal_array_[j]);
515       UpdateStats(j);
516     }
517   }
518   ReplaceLabels(labels);
519 }
520 
521 template <class Arc>
GetFstPairs(vector<FstPair> * fst_pairs)522 void ReplaceUtil<Arc>::GetFstPairs(vector<FstPair> *fst_pairs) {
523   CheckMutableFsts();
524   fst_pairs->clear();
525   for (Label i = 0; i < fst_array_.size(); ++i) {
526     Label label = nonterminal_array_[i];
527     const Fst<Arc> *fst = fst_array_[i];
528     if (!fst)
529       continue;
530     fst_pairs->push_back(make_pair(label, fst));
531   }
532 }
533 
534 template <class Arc>
GetMutableFstPairs(vector<MutableFstPair> * mutable_fst_pairs)535 void ReplaceUtil<Arc>::GetMutableFstPairs(
536     vector<MutableFstPair> *mutable_fst_pairs) {
537   CheckMutableFsts();
538   mutable_fst_pairs->clear();
539   for (Label i = 0; i < mutable_fst_array_.size(); ++i) {
540     Label label = nonterminal_array_[i];
541     MutableFst<Arc> *fst = mutable_fst_array_[i];
542     if (!fst)
543       continue;
544     mutable_fst_pairs->push_back(make_pair(label, fst->Copy()));
545   }
546 }
547 
548 }  // namespace fst
549 
550 #endif  // FST_LIB_REPLACE_UTIL_H__
551