1 // replace.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 //
16 // \file
17 // Functions and classes for the recursive replacement of Fsts.
18 //
19
20 #ifndef FST_LIB_REPLACE_H__
21 #define FST_LIB_REPLACE_H__
22
23 #include <ext/hash_map>
24 using __gnu_cxx::hash_map;
25
26 #include "fst/lib/fst.h"
27 #include "fst/lib/cache.h"
28 #include "fst/lib/test-properties.h"
29
30 namespace fst {
31
32 // By default ReplaceFst will copy the input label of the 'replace arc'.
33 // For acceptors we do not want this behaviour. Instead we need to
34 // create an epsilon arc when recursing into the appropriate Fst.
35 // The epsilon_on_replace option can be used to toggle this behaviour.
36 struct ReplaceFstOptions : CacheOptions {
37 int64 root; // root rule for expansion
38 bool epsilon_on_replace;
39
ReplaceFstOptionsReplaceFstOptions40 ReplaceFstOptions(const CacheOptions &opts, int64 r)
41 : CacheOptions(opts), root(r), epsilon_on_replace(false) {}
ReplaceFstOptionsReplaceFstOptions42 explicit ReplaceFstOptions(int64 r)
43 : root(r), epsilon_on_replace(false) {}
ReplaceFstOptionsReplaceFstOptions44 ReplaceFstOptions(int64 r, bool epsilon_replace_arc)
45 : root(r), epsilon_on_replace(epsilon_replace_arc) {}
ReplaceFstOptionsReplaceFstOptions46 ReplaceFstOptions()
47 : root(kNoLabel), epsilon_on_replace(false) {}
48 };
49
50 //
51 // \class ReplaceFstImpl
52 // \brief Implementation class for replace class Fst
53 //
54 // The replace implementation class supports a dynamic
55 // expansion of a recursive transition network represented as Fst
56 // with dynamic replacable arcs.
57 //
58 template <class A>
59 class ReplaceFstImpl : public CacheImpl<A> {
60 public:
61 using FstImpl<A>::SetType;
62 using FstImpl<A>::SetProperties;
63 using FstImpl<A>::Properties;
64 using FstImpl<A>::SetInputSymbols;
65 using FstImpl<A>::SetOutputSymbols;
66 using FstImpl<A>::InputSymbols;
67 using FstImpl<A>::OutputSymbols;
68
69 using CacheImpl<A>::HasStart;
70 using CacheImpl<A>::HasArcs;
71 using CacheImpl<A>::SetStart;
72
73 typedef typename A::Label Label;
74 typedef typename A::Weight Weight;
75 typedef typename A::StateId StateId;
76 typedef CacheState<A> State;
77 typedef A Arc;
78 typedef hash_map<Label, Label> NonTerminalHash;
79
80
81 // \struct StateTuple
82 // \brief Tuple of information that uniquely defines a state
83 struct StateTuple {
84 typedef int PrefixId;
85
StateTupleStateTuple86 StateTuple() {}
StateTupleStateTuple87 StateTuple(PrefixId p, StateId f, StateId s) :
88 prefix_id(p), fst_id(f), fst_state(s) {}
89
90 PrefixId prefix_id; // index in prefix table
91 StateId fst_id; // current fst being walked
92 StateId fst_state; // current state in fst being walked, not to be
93 // confused with the state_id of the combined fst
94 };
95
96 // constructor for replace class implementation.
97 // \param fst_tuples array of label/fst tuples, one for each non-terminal
ReplaceFstImpl(const vector<pair<Label,const Fst<A> * >> & fst_tuples,const ReplaceFstOptions & opts)98 ReplaceFstImpl(const vector< pair<Label, const Fst<A>* > >& fst_tuples,
99 const ReplaceFstOptions &opts)
100 : CacheImpl<A>(opts), opts_(opts) {
101 SetType("replace");
102 if (fst_tuples.size() > 0) {
103 SetInputSymbols(fst_tuples[0].second->InputSymbols());
104 SetOutputSymbols(fst_tuples[0].second->OutputSymbols());
105 }
106
107 fst_array_.push_back(0);
108 for (size_t i = 0; i < fst_tuples.size(); ++i)
109 AddFst(fst_tuples[i].first, fst_tuples[i].second);
110
111 SetRoot(opts.root);
112 }
113
ReplaceFstImpl(const ReplaceFstOptions & opts)114 explicit ReplaceFstImpl(const ReplaceFstOptions &opts)
115 : CacheImpl<A>(opts), opts_(opts), root_(kNoLabel) {
116 fst_array_.push_back(0);
117 }
118
ReplaceFstImpl(const ReplaceFstImpl & impl)119 ReplaceFstImpl(const ReplaceFstImpl& impl)
120 : opts_(impl.opts_), state_tuples_(impl.state_tuples_),
121 state_hash_(impl.state_hash_),
122 prefix_hash_(impl.prefix_hash_),
123 stackprefix_array_(impl.stackprefix_array_),
124 nonterminal_hash_(impl.nonterminal_hash_),
125 root_(impl.root_) {
126 SetType("replace");
127 SetProperties(impl.Properties(), kCopyProperties);
128 SetInputSymbols(InputSymbols());
129 SetOutputSymbols(OutputSymbols());
130 fst_array_.reserve(impl.fst_array_.size());
131 fst_array_.push_back(0);
132 for (size_t i = 1; i < impl.fst_array_.size(); ++i)
133 fst_array_.push_back(impl.fst_array_[i]->Copy());
134 }
135
~ReplaceFstImpl()136 ~ReplaceFstImpl() {
137 for (size_t i = 1; i < fst_array_.size(); ++i) {
138 delete fst_array_[i];
139 }
140 }
141
142 // Add to Fst array
AddFst(Label label,const Fst<A> * fst)143 void AddFst(Label label, const Fst<A>* fst) {
144 nonterminal_hash_[label] = fst_array_.size();
145 fst_array_.push_back(fst->Copy());
146 if (fst_array_.size() > 1) {
147 vector<uint64> inprops(fst_array_.size());
148
149 for (size_t i = 1; i < fst_array_.size(); ++i) {
150 inprops[i] = fst_array_[i]->Properties(kCopyProperties, false);
151 }
152 SetProperties(ReplaceProperties(inprops));
153
154 const SymbolTable* isymbols = fst_array_[1]->InputSymbols();
155 const SymbolTable* osymbols = fst_array_[1]->OutputSymbols();
156 for (size_t i = 2; i < fst_array_.size(); ++i) {
157 if (!CompatSymbols(isymbols, fst_array_[i]->InputSymbols())) {
158 LOG(FATAL) << "ReplaceFst::AddFst input symbols of Fst " << i-1
159 << " does not match input symbols of base Fst (0'th fst)";
160 }
161 if (!CompatSymbols(osymbols, fst_array_[i]->OutputSymbols())) {
162 LOG(FATAL) << "ReplaceFst::AddFst output symbols of Fst " << i-1
163 << " does not match output symbols of base Fst "
164 << "(0'th fst)";
165 }
166 }
167 }
168 }
169
170 // Computes the dependency graph of the replace class and returns
171 // true if the dependencies are cyclic. Cyclic dependencies will result
172 // in an un-expandable replace fst.
CyclicDependencies()173 bool CyclicDependencies() const {
174 StdVectorFst depfst;
175
176 // one state for each fst
177 for (size_t i = 1; i < fst_array_.size(); ++i)
178 depfst.AddState();
179
180 // an arc from each state (representing the fst) to the
181 // state representing the fst being replaced
182 for (size_t i = 1; i < fst_array_.size(); ++i) {
183 for (StateIterator<Fst<A> > siter(*(fst_array_[i]));
184 !siter.Done(); siter.Next()) {
185 for (ArcIterator<Fst<A> > aiter(*(fst_array_[i]), siter.Value());
186 !aiter.Done(); aiter.Next()) {
187 const A& arc = aiter.Value();
188
189 typename NonTerminalHash::const_iterator it =
190 nonterminal_hash_.find(arc.olabel);
191 if (it != nonterminal_hash_.end()) {
192 Label j = it->second - 1;
193 depfst.AddArc(i - 1, A(arc.olabel, arc.olabel, Weight::One(), j));
194 }
195 }
196 }
197 }
198
199 depfst.SetStart(root_ - 1);
200 depfst.SetFinal(root_ - 1, Weight::One());
201 return depfst.Properties(kCyclic, true);
202 }
203
204 // set root rule for expansion
SetRoot(Label root)205 void SetRoot(Label root) {
206 Label nonterminal = nonterminal_hash_[root];
207 root_ = (nonterminal > 0) ? nonterminal : 1;
208 }
209
210 // Change Fst array
SetFst(Label label,const Fst<A> * fst)211 void SetFst(Label label, const Fst<A>* fst) {
212 Label nonterminal = nonterminal_hash_[label];
213 delete fst_array_[nonterminal];
214 fst_array_[nonterminal] = fst->Copy();
215 }
216
217 // Return or compute start state of replace fst
Start()218 StateId Start() {
219 if (!HasStart()) {
220 if (fst_array_.size() == 1) { // no fsts defined for replace
221 SetStart(kNoStateId);
222 return kNoStateId;
223 } else {
224 const Fst<A>* fst = fst_array_[root_];
225 StateId fst_start = fst->Start();
226 if (fst_start == kNoStateId) // root Fst is empty
227 return kNoStateId;
228
229 int prefix = PrefixId(StackPrefix());
230 StateId start = FindState(StateTuple(prefix, root_, fst_start));
231 SetStart(start);
232 return start;
233 }
234 } else {
235 return CacheImpl<A>::Start();
236 }
237 }
238
239 // return final weight of state (kInfWeight means state is not final)
Final(StateId s)240 Weight Final(StateId s) {
241 if (!HasFinal(s)) {
242 const StateTuple& tuple = state_tuples_[s];
243 const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
244 const Fst<A>* fst = fst_array_[tuple.fst_id];
245 StateId fst_state = tuple.fst_state;
246
247 if (fst->Final(fst_state) != Weight::Zero() && stack.Depth() == 0)
248 SetFinal(s, fst->Final(fst_state));
249 else
250 SetFinal(s, Weight::Zero());
251 }
252 return CacheImpl<A>::Final(s);
253 }
254
NumArcs(StateId s)255 size_t NumArcs(StateId s) {
256 if (!HasArcs(s))
257 Expand(s);
258 return CacheImpl<A>::NumArcs(s);
259 }
260
NumInputEpsilons(StateId s)261 size_t NumInputEpsilons(StateId s) {
262 if (!HasArcs(s))
263 Expand(s);
264 return CacheImpl<A>::NumInputEpsilons(s);
265 }
266
NumOutputEpsilons(StateId s)267 size_t NumOutputEpsilons(StateId s) {
268 if (!HasArcs(s))
269 Expand(s);
270 return CacheImpl<A>::NumOutputEpsilons(s);
271 }
272
273 // return the base arc iterator, if arcs have not been computed yet,
274 // extend/recurse for new arcs.
InitArcIterator(StateId s,ArcIteratorData<A> * data)275 void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
276 if (!HasArcs(s))
277 Expand(s);
278 CacheImpl<A>::InitArcIterator(s, data);
279 }
280
281 // Find/create an Fst state given a StateTuple. Only create a new
282 // state if StateTuple is not found in the state hash.
FindState(const StateTuple & tuple)283 StateId FindState(const StateTuple& tuple) {
284 typename StateTupleHash::iterator it = state_hash_.find(tuple);
285 if (it == state_hash_.end()) {
286 StateId new_state_id = state_tuples_.size();
287 state_tuples_.push_back(tuple);
288 state_hash_[tuple] = new_state_id;
289 return new_state_id;
290 } else {
291 return it->second;
292 }
293 }
294
295 // extend current state (walk arcs one level deep)
Expand(StateId s)296 void Expand(StateId s) {
297 StateTuple tuple = state_tuples_[s];
298 const Fst<A>* fst = fst_array_[tuple.fst_id];
299 StateId fst_state = tuple.fst_state;
300 if (fst_state == kNoStateId) {
301 SetArcs(s);
302 return;
303 }
304
305 // if state is final, pop up stack
306 const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
307 if (fst->Final(fst_state) != Weight::Zero() && stack.Depth()) {
308 int prefix_id = PopPrefix(stack);
309 const PrefixTuple& top = stack.Top();
310
311 StateId nextstate =
312 FindState(StateTuple(prefix_id, top.fst_id, top.nextstate));
313 AddArc(s, A(0, 0, fst->Final(fst_state), nextstate));
314 }
315
316 // extend arcs leaving the state
317 for (ArcIterator< Fst<A> > aiter(*fst, fst_state);
318 !aiter.Done(); aiter.Next()) {
319 const Arc& arc = aiter.Value();
320 if (arc.olabel == 0) { // expand local fst
321 StateId nextstate =
322 FindState(StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate));
323 AddArc(s, A(arc.ilabel, arc.olabel, arc.weight, nextstate));
324 } else {
325 // check for non terminal
326 typename NonTerminalHash::const_iterator it =
327 nonterminal_hash_.find(arc.olabel);
328 if (it != nonterminal_hash_.end()) { // recurse into non terminal
329 Label nonterminal = it->second;
330 const Fst<A>* nt_fst = fst_array_[nonterminal];
331 int nt_prefix = PushPrefix(stackprefix_array_[tuple.prefix_id],
332 tuple.fst_id, arc.nextstate);
333
334 // if start state is valid replace, else arc is implicitly
335 // deleted
336 StateId nt_start = nt_fst->Start();
337 if (nt_start != kNoStateId) {
338 StateId nt_nextstate = FindState(
339 StateTuple(nt_prefix, nonterminal, nt_start));
340 Label ilabel = (opts_.epsilon_on_replace) ? 0 : arc.ilabel;
341 AddArc(s, A(ilabel, 0, arc.weight, nt_nextstate));
342 }
343 } else {
344 StateId nextstate =
345 FindState(
346 StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate));
347 AddArc(s, A(arc.ilabel, arc.olabel, arc.weight, nextstate));
348 }
349 }
350 }
351
352 SetArcs(s);
353 }
354
355
356 // private helper classes
357 private:
358 static const int kPrime0 = 7853;
359 static const int kPrime1 = 7867;
360
361 // \class StateTupleEqual
362 // \brief Compare two StateTuples for equality
363 class StateTupleEqual {
364 public:
operator()365 bool operator()(const StateTuple& x, const StateTuple& y) const {
366 return ((x.prefix_id == y.prefix_id) && (x.fst_id == y.fst_id) &&
367 (x.fst_state == y.fst_state));
368 }
369 };
370
371 // \class StateTupleKey
372 // \brief Hash function for StateTuple to Fst states
373 class StateTupleKey {
374 public:
operator()375 size_t operator()(const StateTuple& x) const {
376 return static_cast<size_t>(x.prefix_id +
377 x.fst_id * kPrime0 +
378 x.fst_state * kPrime1);
379 }
380 };
381
382 typedef hash_map<StateTuple, StateId, StateTupleKey, StateTupleEqual>
383 StateTupleHash;
384
385 // \class PrefixTuple
386 // \brief Tuple of fst_id and destination state (entry in stack prefix)
387 struct PrefixTuple {
PrefixTuplePrefixTuple388 PrefixTuple(Label f, StateId s) : fst_id(f), nextstate(s) {}
389
390 Label fst_id;
391 StateId nextstate;
392 };
393
394 // \class StackPrefix
395 // \brief Container for stack prefix.
396 class StackPrefix {
397 public:
StackPrefix()398 StackPrefix() {}
399
400 // copy constructor
StackPrefix(const StackPrefix & x)401 StackPrefix(const StackPrefix& x) :
402 prefix_(x.prefix_) {
403 }
404
Push(int fst_id,StateId nextstate)405 void Push(int fst_id, StateId nextstate) {
406 prefix_.push_back(PrefixTuple(fst_id, nextstate));
407 }
408
Pop()409 void Pop() {
410 prefix_.pop_back();
411 }
412
Top()413 const PrefixTuple& Top() const {
414 return prefix_[prefix_.size()-1];
415 }
416
Depth()417 size_t Depth() const {
418 return prefix_.size();
419 }
420
421 public:
422 vector<PrefixTuple> prefix_;
423 };
424
425
426 // \class StackPrefixEqual
427 // \brief Compare two stack prefix classes for equality
428 class StackPrefixEqual {
429 public:
operator()430 bool operator()(const StackPrefix& x, const StackPrefix& y) const {
431 if (x.prefix_.size() != y.prefix_.size()) return false;
432 for (size_t i = 0; i < x.prefix_.size(); ++i) {
433 if (x.prefix_[i].fst_id != y.prefix_[i].fst_id ||
434 x.prefix_[i].nextstate != y.prefix_[i].nextstate) return false;
435 }
436 return true;
437 }
438 };
439
440 //
441 // \class StackPrefixKey
442 // \brief Hash function for stack prefix to prefix id
443 class StackPrefixKey {
444 public:
operator()445 size_t operator()(const StackPrefix& x) const {
446 int sum = 0;
447 for (size_t i = 0; i < x.prefix_.size(); ++i) {
448 sum += x.prefix_[i].fst_id + x.prefix_[i].nextstate*kPrime0;
449 }
450 return (size_t) sum;
451 }
452 };
453
454 typedef hash_map<StackPrefix, int, StackPrefixKey, StackPrefixEqual>
455 StackPrefixHash;
456
457 // private methods
458 private:
459 // hash stack prefix (return unique index into stackprefix array)
PrefixId(const StackPrefix & prefix)460 int PrefixId(const StackPrefix& prefix) {
461 typename StackPrefixHash::iterator it = prefix_hash_.find(prefix);
462 if (it == prefix_hash_.end()) {
463 int prefix_id = stackprefix_array_.size();
464 stackprefix_array_.push_back(prefix);
465 prefix_hash_[prefix] = prefix_id;
466 return prefix_id;
467 } else {
468 return it->second;
469 }
470 }
471
472 // prefix id after a stack pop
PopPrefix(StackPrefix prefix)473 int PopPrefix(StackPrefix prefix) {
474 prefix.Pop();
475 return PrefixId(prefix);
476 }
477
478 // prefix id after a stack push
PushPrefix(StackPrefix prefix,Label fst_id,StateId nextstate)479 int PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
480 prefix.Push(fst_id, nextstate);
481 return PrefixId(prefix);
482 }
483
484
485 // private data
486 private:
487 // runtime options
488 ReplaceFstOptions opts_;
489
490 // maps from StateId to StateTuple
491 vector<StateTuple> state_tuples_;
492
493 // hashes from StateTuple to StateId
494 StateTupleHash state_hash_;
495
496 // cross index of unique stack prefix
497 // could potentially have one copy of prefix array
498 StackPrefixHash prefix_hash_;
499 vector<StackPrefix> stackprefix_array_;
500
501 NonTerminalHash nonterminal_hash_;
502 vector<const Fst<A>*> fst_array_;
503
504 Label root_;
505
506 void operator=(const ReplaceFstImpl<A> &); // disallow
507 };
508
509
510 //
511 // \class ReplaceFst
512 // \brief Recursivively replaces arcs in the root Fst with other Fsts.
513 // This version is a delayed Fst.
514 //
515 // ReplaceFst supports dynamic replacement of arcs in one Fst with
516 // another Fst. This replacement is recursive. ReplaceFst can be used
517 // to support a variety of delayed constructions such as recursive
518 // transition networks, union, or closure. It is constructed with an
519 // array of Fst(s). One Fst represents the root (or topology)
520 // machine. The root Fst refers to other Fsts by recursively replacing
521 // arcs labeled as non-terminals with the matching non-terminal
522 // Fst. Currently the ReplaceFst uses the output symbols of the arcs
523 // to determine whether the arc is a non-terminal arc or not. A
524 // non-terminal can be any label that is not a non-zero terminal label
525 // in the output alphabet.
526 //
527 // Note that the constructor uses a vector of pair<>. These correspond
528 // to the tuple of non-terminal Label and corresponding Fst. For example
529 // to implement the closure operation we need 2 Fsts. The first root
530 // Fst is a single Arc on the start State that self loops, it references
531 // the particular machine for which we are performing the closure operation.
532 //
533 template <class A>
534 class ReplaceFst : public Fst<A> {
535 public:
536 friend class ArcIterator< ReplaceFst<A> >;
537 friend class CacheStateIterator< ReplaceFst<A> >;
538 friend class CacheArcIterator< ReplaceFst<A> >;
539
540 typedef A Arc;
541 typedef typename A::Label Label;
542 typedef typename A::Weight Weight;
543 typedef typename A::StateId StateId;
544 typedef CacheState<A> State;
545
ReplaceFst(const vector<pair<Label,const Fst<A> * >> & fst_array,Label root)546 ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
547 Label root)
548 : impl_(new ReplaceFstImpl<A>(fst_array, ReplaceFstOptions(root))) {}
549
ReplaceFst(const vector<pair<Label,const Fst<A> * >> & fst_array,const ReplaceFstOptions & opts)550 ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
551 const ReplaceFstOptions &opts)
552 : impl_(new ReplaceFstImpl<A>(fst_array, opts)) {}
553
ReplaceFst(const ReplaceFst<A> & fst)554 ReplaceFst(const ReplaceFst<A>& fst) :
555 impl_(new ReplaceFstImpl<A>(*(fst.impl_))) {}
556
~ReplaceFst()557 virtual ~ReplaceFst() {
558 delete impl_;
559 }
560
Start()561 virtual StateId Start() const {
562 return impl_->Start();
563 }
564
Final(StateId s)565 virtual Weight Final(StateId s) const {
566 return impl_->Final(s);
567 }
568
NumArcs(StateId s)569 virtual size_t NumArcs(StateId s) const {
570 return impl_->NumArcs(s);
571 }
572
NumInputEpsilons(StateId s)573 virtual size_t NumInputEpsilons(StateId s) const {
574 return impl_->NumInputEpsilons(s);
575 }
576
NumOutputEpsilons(StateId s)577 virtual size_t NumOutputEpsilons(StateId s) const {
578 return impl_->NumOutputEpsilons(s);
579 }
580
Properties(uint64 mask,bool test)581 virtual uint64 Properties(uint64 mask, bool test) const {
582 if (test) {
583 uint64 known, test = TestProperties(*this, mask, &known);
584 impl_->SetProperties(test, known);
585 return test & mask;
586 } else {
587 return impl_->Properties(mask);
588 }
589 }
590
Type()591 virtual const string& Type() const {
592 return impl_->Type();
593 }
594
Copy()595 virtual ReplaceFst<A>* Copy() const {
596 return new ReplaceFst<A>(*this);
597 }
598
InputSymbols()599 virtual const SymbolTable* InputSymbols() const {
600 return impl_->InputSymbols();
601 }
602
OutputSymbols()603 virtual const SymbolTable* OutputSymbols() const {
604 return impl_->OutputSymbols();
605 }
606
607 virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
608
InitArcIterator(StateId s,ArcIteratorData<A> * data)609 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
610 impl_->InitArcIterator(s, data);
611 }
612
CyclicDependencies()613 bool CyclicDependencies() const {
614 return impl_->CyclicDependencies();
615 }
616
617 private:
618 ReplaceFstImpl<A>* impl_;
619 };
620
621
622 // Specialization for ReplaceFst.
623 template<class A>
624 class StateIterator< ReplaceFst<A> >
625 : public CacheStateIterator< ReplaceFst<A> > {
626 public:
StateIterator(const ReplaceFst<A> & fst)627 explicit StateIterator(const ReplaceFst<A> &fst)
628 : CacheStateIterator< ReplaceFst<A> >(fst) {}
629
630 private:
631 DISALLOW_EVIL_CONSTRUCTORS(StateIterator);
632 };
633
634 // Specialization for ReplaceFst.
635 template <class A>
636 class ArcIterator< ReplaceFst<A> >
637 : public CacheArcIterator< ReplaceFst<A> > {
638 public:
639 typedef typename A::StateId StateId;
640
ArcIterator(const ReplaceFst<A> & fst,StateId s)641 ArcIterator(const ReplaceFst<A> &fst, StateId s)
642 : CacheArcIterator< ReplaceFst<A> >(fst, s) {
643 if (!fst.impl_->HasArcs(s))
644 fst.impl_->Expand(s);
645 }
646
647 private:
648 DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
649 };
650
651 template <class A> inline
InitStateIterator(StateIteratorData<A> * data)652 void ReplaceFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
653 data->base = new StateIterator< ReplaceFst<A> >(*this);
654 }
655
656 typedef ReplaceFst<StdArc> StdReplaceFst;
657
658
659 // // Recursivively replaces arcs in the root Fst with other Fsts.
660 // This version writes the result of replacement to an output MutableFst.
661 //
662 // Replace supports replacement of arcs in one Fst with another
663 // Fst. This replacement is recursive. Replace takes an array of
664 // Fst(s). One Fst represents the root (or topology) machine. The root
665 // Fst refers to other Fsts by recursively replacing arcs labeled as
666 // non-terminals with the matching non-terminal Fst. Currently Replace
667 // uses the output symbols of the arcs to determine whether the arc is
668 // a non-terminal arc or not. A non-terminal can be any label that is
669 // not a non-zero terminal label in the output alphabet. Note that
670 // input argument is a vector of pair<>. These correspond to the tuple
671 // of non-terminal Label and corresponding Fst.
672 template<class Arc>
Replace(const vector<pair<typename Arc::Label,const Fst<Arc> * >> & ifst_array,MutableFst<Arc> * ofst,typename Arc::Label root)673 void Replace(const vector<pair<typename Arc::Label,
674 const Fst<Arc>* > >& ifst_array,
675 MutableFst<Arc> *ofst, typename Arc::Label root) {
676 ReplaceFstOptions opts(root);
677 opts.gc_limit = 0; // Cache only the last state for fastest copy.
678 *ofst = ReplaceFst<Arc>(ifst_array, opts);
679 }
680
681 }
682
683 #endif // FST_LIB_REPLACE_H__
684