1 // replace.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: johans@google.com (Johan Schalkwyk)
17 //
18 // \file
19 // Functions and classes for the recursive replacement of Fsts.
20 //
21
22 #ifndef FST_LIB_REPLACE_H__
23 #define FST_LIB_REPLACE_H__
24
25 #include <unordered_map>
26 using std::tr1::unordered_map;
27 using std::tr1::unordered_multimap;
28 #include <set>
29 #include <string>
30 #include <utility>
31 using std::pair; using std::make_pair;
32 #include <vector>
33 using std::vector;
34
35 #include <fst/cache.h>
36 #include <fst/expanded-fst.h>
37 #include <fst/fst.h>
38 #include <fst/matcher.h>
39 #include <fst/replace-util.h>
40 #include <fst/state-table.h>
41 #include <fst/test-properties.h>
42
43 namespace fst {
44
45 //
46 // REPLACE STATE TUPLES AND TABLES
47 //
48 // The replace state table has the form
49 //
50 // template <class A, class P>
51 // class ReplaceStateTable {
52 // public:
53 // typedef A Arc;
54 // typedef P PrefixId;
55 // typedef typename A::StateId StateId;
56 // typedef ReplaceStateTuple<StateId, PrefixId> StateTuple;
57 // typedef typename A::Label Label;
58 //
//     // Required constructor
60 // ReplaceStateTable(const vector<pair<Label, const Fst<A>*> > &fst_tuples,
61 // Label root);
62 //
63 // // Required copy constructor that does not copy state
64 // ReplaceStateTable(const ReplaceStateTable<A,P> &table);
65 //
66 // // Lookup state ID by tuple. If it doesn't exist, then add it.
67 // StateId FindState(const StateTuple &tuple);
68 //
69 // // Lookup state tuple by ID.
70 // const StateTuple &Tuple(StateId id) const;
71 // };
72
73
74 // \struct ReplaceStateTuple
75 // \brief Tuple of information that uniquely defines a state in replace
76 template <class S, class P>
77 struct ReplaceStateTuple {
78 typedef S StateId;
79 typedef P PrefixId;
80
ReplaceStateTupleReplaceStateTuple81 ReplaceStateTuple()
82 : prefix_id(-1), fst_id(kNoStateId), fst_state(kNoStateId) {}
83
ReplaceStateTupleReplaceStateTuple84 ReplaceStateTuple(PrefixId p, StateId f, StateId s)
85 : prefix_id(p), fst_id(f), fst_state(s) {}
86
87 PrefixId prefix_id; // index in prefix table
88 StateId fst_id; // current fst being walked
89 StateId fst_state; // current state in fst being walked, not to be
90 // confused with the state_id of the combined fst
91 };
92
93
94 // Equality of replace state tuples.
95 template <class S, class P>
96 inline bool operator==(const ReplaceStateTuple<S, P>& x,
97 const ReplaceStateTuple<S, P>& y) {
98 return x.prefix_id == y.prefix_id &&
99 x.fst_id == y.fst_id &&
100 x.fst_state == y.fst_state;
101 }
102
103
104 // \class ReplaceRootSelector
105 // Functor returning true for tuples corresponding to states in the root FST
106 template <class S, class P>
107 class ReplaceRootSelector {
108 public:
operator()109 bool operator()(const ReplaceStateTuple<S, P> &tuple) const {
110 return tuple.prefix_id == 0;
111 }
112 };
113
114
115 // \class ReplaceFingerprint
116 // Fingerprint for general replace state tuples.
117 template <class S, class P>
118 class ReplaceFingerprint {
119 public:
ReplaceFingerprint(const vector<uint64> * size_array)120 ReplaceFingerprint(const vector<uint64> *size_array)
121 : cumulative_size_array_(size_array) {}
122
operator()123 uint64 operator()(const ReplaceStateTuple<S, P> &tuple) const {
124 return tuple.prefix_id * (cumulative_size_array_->back()) +
125 cumulative_size_array_->at(tuple.fst_id - 1) +
126 tuple.fst_state;
127 }
128
129 private:
130 const vector<uint64> *cumulative_size_array_;
131 };
132
133
134 // \class ReplaceFstStateFingerprint
135 // Useful when the fst_state uniquely define the tuple.
136 template <class S, class P>
137 class ReplaceFstStateFingerprint {
138 public:
operator()139 uint64 operator()(const ReplaceStateTuple<S, P>& tuple) const {
140 return tuple.fst_state;
141 }
142 };
143
144
145 // \class ReplaceHash
146 // A generic hash function for replace state tuples.
147 template <typename S, typename P>
148 class ReplaceHash {
149 public:
operator()150 size_t operator()(const ReplaceStateTuple<S, P>& t) const {
151 return t.prefix_id + t.fst_id * kPrime0 + t.fst_state * kPrime1;
152 }
153 private:
154 static const size_t kPrime0;
155 static const size_t kPrime1;
156 };
157
158 template <typename S, typename P>
159 const size_t ReplaceHash<S, P>::kPrime0 = 7853;
160
161 template <typename S, typename P>
162 const size_t ReplaceHash<S, P>::kPrime1 = 7867;
163
164 template <class A, class T> class ReplaceFstMatcher;
165
166
167 // \class VectorHashReplaceStateTable
168 // A two-level state table for replace.
169 // Warning: calls CountStates to compute the number of states of each
170 // component Fst.
171 template <class A, class P = ssize_t>
172 class VectorHashReplaceStateTable {
173 public:
174 typedef A Arc;
175 typedef typename A::StateId StateId;
176 typedef typename A::Label Label;
177 typedef P PrefixId;
178 typedef ReplaceStateTuple<StateId, P> StateTuple;
179 typedef VectorHashStateTable<ReplaceStateTuple<StateId, P>,
180 ReplaceRootSelector<StateId, P>,
181 ReplaceFstStateFingerprint<StateId, P>,
182 ReplaceFingerprint<StateId, P> > StateTable;
183
VectorHashReplaceStateTable(const vector<pair<Label,const Fst<A> * >> & fst_tuples,Label root)184 VectorHashReplaceStateTable(
185 const vector<pair<Label, const Fst<A>*> > &fst_tuples,
186 Label root) : root_size_(0) {
187 cumulative_size_array_.push_back(0);
188 for (size_t i = 0; i < fst_tuples.size(); ++i) {
189 if (fst_tuples[i].first == root) {
190 root_size_ = CountStates(*(fst_tuples[i].second));
191 cumulative_size_array_.push_back(cumulative_size_array_.back());
192 } else {
193 cumulative_size_array_.push_back(cumulative_size_array_.back() +
194 CountStates(*(fst_tuples[i].second)));
195 }
196 }
197 state_table_ = new StateTable(
198 new ReplaceRootSelector<StateId, P>,
199 new ReplaceFstStateFingerprint<StateId, P>,
200 new ReplaceFingerprint<StateId, P>(&cumulative_size_array_),
201 root_size_,
202 root_size_ + cumulative_size_array_.back());
203 }
204
VectorHashReplaceStateTable(const VectorHashReplaceStateTable<A,P> & table)205 VectorHashReplaceStateTable(const VectorHashReplaceStateTable<A, P> &table)
206 : root_size_(table.root_size_),
207 cumulative_size_array_(table.cumulative_size_array_) {
208 state_table_ = new StateTable(
209 new ReplaceRootSelector<StateId, P>,
210 new ReplaceFstStateFingerprint<StateId, P>,
211 new ReplaceFingerprint<StateId, P>(&cumulative_size_array_),
212 root_size_,
213 root_size_ + cumulative_size_array_.back());
214 }
215
~VectorHashReplaceStateTable()216 ~VectorHashReplaceStateTable() {
217 delete state_table_;
218 }
219
FindState(const StateTuple & tuple)220 StateId FindState(const StateTuple &tuple) {
221 return state_table_->FindState(tuple);
222 }
223
Tuple(StateId id)224 const StateTuple &Tuple(StateId id) const {
225 return state_table_->Tuple(id);
226 }
227
228 private:
229 StateId root_size_;
230 vector<uint64> cumulative_size_array_;
231 StateTable *state_table_;
232 };
233
234
235 // \class DefaultReplaceStateTable
236 // Default replace state table
237 template <class A, class P = ssize_t>
238 class DefaultReplaceStateTable : public CompactHashStateTable<
239 ReplaceStateTuple<typename A::StateId, P>,
240 ReplaceHash<typename A::StateId, P> > {
241 public:
242 typedef A Arc;
243 typedef typename A::StateId StateId;
244 typedef typename A::Label Label;
245 typedef P PrefixId;
246 typedef ReplaceStateTuple<StateId, P> StateTuple;
247 typedef CompactHashStateTable<StateTuple,
248 ReplaceHash<StateId, PrefixId> > StateTable;
249
250 using StateTable::FindState;
251 using StateTable::Tuple;
252
DefaultReplaceStateTable(const vector<pair<Label,const Fst<A> * >> & fst_tuples,Label root)253 DefaultReplaceStateTable(
254 const vector<pair<Label, const Fst<A>*> > &fst_tuples,
255 Label root) {}
256
DefaultReplaceStateTable(const DefaultReplaceStateTable<A,P> & table)257 DefaultReplaceStateTable(const DefaultReplaceStateTable<A, P> &table)
258 : StateTable() {}
259 };
260
261 //
262 // REPLACE FST CLASS
263 //
264
265 // By default ReplaceFst will copy the input label of the 'replace arc'.
266 // For acceptors we do not want this behaviour. Instead we need to
267 // create an epsilon arc when recursing into the appropriate Fst.
268 // The 'epsilon_on_replace' option can be used to toggle this behaviour.
269 template <class A, class T = DefaultReplaceStateTable<A> >
270 struct ReplaceFstOptions : CacheOptions {
271 int64 root; // root rule for expansion
272 bool epsilon_on_replace;
273 bool take_ownership; // take ownership of input Fst(s)
274 T* state_table;
275
ReplaceFstOptionsReplaceFstOptions276 ReplaceFstOptions(const CacheOptions &opts, int64 r)
277 : CacheOptions(opts),
278 root(r),
279 epsilon_on_replace(false),
280 take_ownership(false),
281 state_table(0) {}
ReplaceFstOptionsReplaceFstOptions282 explicit ReplaceFstOptions(int64 r)
283 : root(r),
284 epsilon_on_replace(false),
285 take_ownership(false),
286 state_table(0) {}
ReplaceFstOptionsReplaceFstOptions287 ReplaceFstOptions(int64 r, bool epsilon_replace_arc)
288 : root(r),
289 epsilon_on_replace(epsilon_replace_arc),
290 take_ownership(false),
291 state_table(0) {}
ReplaceFstOptionsReplaceFstOptions292 ReplaceFstOptions()
293 : root(kNoLabel),
294 epsilon_on_replace(false),
295 take_ownership(false),
296 state_table(0) {}
297 };
298
299
300 // \class ReplaceFstImpl
301 // \brief Implementation class for replace class Fst
302 //
303 // The replace implementation class supports a dynamic
304 // expansion of a recursive transition network represented as Fst
// with dynamic replaceable arcs.
306 //
307 template <class A, class T>
308 class ReplaceFstImpl : public CacheImpl<A> {
309 friend class ReplaceFstMatcher<A, T>;
310
311 public:
312 using FstImpl<A>::SetType;
313 using FstImpl<A>::SetProperties;
314 using FstImpl<A>::WriteHeader;
315 using FstImpl<A>::SetInputSymbols;
316 using FstImpl<A>::SetOutputSymbols;
317 using FstImpl<A>::InputSymbols;
318 using FstImpl<A>::OutputSymbols;
319
320 using CacheImpl<A>::PushArc;
321 using CacheImpl<A>::HasArcs;
322 using CacheImpl<A>::HasFinal;
323 using CacheImpl<A>::HasStart;
324 using CacheImpl<A>::SetArcs;
325 using CacheImpl<A>::SetFinal;
326 using CacheImpl<A>::SetStart;
327
328 typedef typename A::Label Label;
329 typedef typename A::Weight Weight;
330 typedef typename A::StateId StateId;
331 typedef CacheState<A> State;
332 typedef A Arc;
333 typedef unordered_map<Label, Label> NonTerminalHash;
334
335 typedef T StateTable;
336 typedef typename T::PrefixId PrefixId;
337 typedef ReplaceStateTuple<StateId, PrefixId> StateTuple;
338
339 // constructor for replace class implementation.
340 // \param fst_tuples array of label/fst tuples, one for each non-terminal
ReplaceFstImpl(const vector<pair<Label,const Fst<A> * >> & fst_tuples,const ReplaceFstOptions<A,T> & opts)341 ReplaceFstImpl(const vector< pair<Label, const Fst<A>* > >& fst_tuples,
342 const ReplaceFstOptions<A, T> &opts)
343 : CacheImpl<A>(opts),
344 epsilon_on_replace_(opts.epsilon_on_replace),
345 state_table_(opts.state_table ? opts.state_table :
346 new StateTable(fst_tuples, opts.root)) {
347
348 SetType("replace");
349
350 if (fst_tuples.size() > 0) {
351 SetInputSymbols(fst_tuples[0].second->InputSymbols());
352 SetOutputSymbols(fst_tuples[0].second->OutputSymbols());
353 }
354
355 bool all_negative = true; // all nonterminals are negative?
356 bool dense_range = true; // all nonterminals are positive
357 // and form a dense range containing 1?
358 for (size_t i = 0; i < fst_tuples.size(); ++i) {
359 Label nonterminal = fst_tuples[i].first;
360 if (nonterminal >= 0)
361 all_negative = false;
362 if (nonterminal > fst_tuples.size() || nonterminal <= 0)
363 dense_range = false;
364 }
365
366 vector<uint64> inprops;
367 bool all_ilabel_sorted = true;
368 bool all_olabel_sorted = true;
369 bool all_non_empty = true;
370 fst_array_.push_back(0);
371 for (size_t i = 0; i < fst_tuples.size(); ++i) {
372 Label label = fst_tuples[i].first;
373 const Fst<A> *fst = fst_tuples[i].second;
374 nonterminal_hash_[label] = fst_array_.size();
375 nonterminal_set_.insert(label);
376 fst_array_.push_back(opts.take_ownership ? fst : fst->Copy());
377 if (fst->Start() == kNoStateId)
378 all_non_empty = false;
379 if(!fst->Properties(kILabelSorted, false))
380 all_ilabel_sorted = false;
381 if(!fst->Properties(kOLabelSorted, false))
382 all_olabel_sorted = false;
383 inprops.push_back(fst->Properties(kCopyProperties, false));
384 if (i) {
385 if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) {
386 FSTERROR() << "ReplaceFstImpl: input symbols of Fst " << i
387 << " does not match input symbols of base Fst (0'th fst)";
388 SetProperties(kError, kError);
389 }
390 if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) {
391 FSTERROR() << "ReplaceFstImpl: output symbols of Fst " << i
392 << " does not match output symbols of base Fst "
393 << "(0'th fst)";
394 SetProperties(kError, kError);
395 }
396 }
397 }
398 Label nonterminal = nonterminal_hash_[opts.root];
399 if ((nonterminal == 0) && (fst_array_.size() > 1)) {
400 FSTERROR() << "ReplaceFstImpl: no Fst corresponding to root label '"
401 << opts.root << "' in the input tuple vector";
402 SetProperties(kError, kError);
403 }
404 root_ = (nonterminal > 0) ? nonterminal : 1;
405
406 SetProperties(ReplaceProperties(inprops, root_ - 1, epsilon_on_replace_,
407 all_non_empty));
408 // We assume that all terminals are positive. The resulting
409 // ReplaceFst is known to be kILabelSorted when all sub-FSTs are
410 // kILabelSorted and one of the 3 following conditions is satisfied:
411 // 1. 'epsilon_on_replace' is false, or
412 // 2. all non-terminals are negative, or
413 // 3. all non-terninals are positive and form a dense range containing 1.
414 if (all_ilabel_sorted &&
415 (!epsilon_on_replace_ || all_negative || dense_range))
416 SetProperties(kILabelSorted, kILabelSorted);
417 // Similarly, the resulting ReplaceFst is known to be
418 // kOLabelSorted when all sub-FSTs are kOLabelSorted and one of
419 // the 2 following conditions is satisfied:
420 // 1. all non-terminals are negative, or
421 // 2. all non-terninals are positive and form a dense range containing 1.
422 if (all_olabel_sorted && (all_negative || dense_range))
423 SetProperties(kOLabelSorted, kOLabelSorted);
424
425 // Enable optional caching as long as sorted and all non empty.
426 if (Properties(kILabelSorted | kOLabelSorted) && all_non_empty)
427 always_cache_ = false;
428 else
429 always_cache_ = true;
430 VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = "
431 << (always_cache_ ? "true" : "false");
432 }
433
ReplaceFstImpl(const ReplaceFstImpl & impl)434 ReplaceFstImpl(const ReplaceFstImpl& impl)
435 : CacheImpl<A>(impl),
436 epsilon_on_replace_(impl.epsilon_on_replace_),
437 always_cache_(impl.always_cache_),
438 state_table_(new StateTable(*(impl.state_table_))),
439 nonterminal_set_(impl.nonterminal_set_),
440 nonterminal_hash_(impl.nonterminal_hash_),
441 root_(impl.root_) {
442 SetType("replace");
443 SetProperties(impl.Properties(), kCopyProperties);
444 SetInputSymbols(impl.InputSymbols());
445 SetOutputSymbols(impl.OutputSymbols());
446 fst_array_.reserve(impl.fst_array_.size());
447 fst_array_.push_back(0);
448 for (size_t i = 1; i < impl.fst_array_.size(); ++i) {
449 fst_array_.push_back(impl.fst_array_[i]->Copy(true));
450 }
451 }
452
// Destructor: logs the cache settings, then releases the state table and
// every component Fst held from slot 1 onwards.
~ReplaceFstImpl() {
  VLOG(2) << "~ReplaceFstImpl: gc = "
          << (CacheImpl<A>::GetCacheGc() ? "true" : "false")
          << ", gc_size = " << CacheImpl<A>::GetCacheSize()
          << ", gc_limit = " << CacheImpl<A>::GetCacheLimit();

  delete state_table_;

  const size_t num_slots = fst_array_.size();
  for (size_t slot = 1; slot < num_slots; ++slot)
    delete fst_array_[slot];
}
464
465 // Computes the dependency graph of the replace class and returns
466 // true if the dependencies are cyclic. Cyclic dependencies will result
467 // in an un-expandable replace fst.
CyclicDependencies()468 bool CyclicDependencies() const {
469 ReplaceUtil<A> replace_util(fst_array_, nonterminal_hash_, root_);
470 return replace_util.CyclicDependencies();
471 }
472
473 // Return or compute start state of replace fst
Start()474 StateId Start() {
475 if (!HasStart()) {
476 if (fst_array_.size() == 1) { // no fsts defined for replace
477 SetStart(kNoStateId);
478 return kNoStateId;
479 } else {
480 const Fst<A>* fst = fst_array_[root_];
481 StateId fst_start = fst->Start();
482 if (fst_start == kNoStateId) // root Fst is empty
483 return kNoStateId;
484
485 PrefixId prefix = GetPrefixId(StackPrefix());
486 StateId start = state_table_->FindState(
487 StateTuple(prefix, root_, fst_start));
488 SetStart(start);
489 return start;
490 }
491 } else {
492 return CacheImpl<A>::Start();
493 }
494 }
495
496 // return final weight of state (kInfWeight means state is not final)
Final(StateId s)497 Weight Final(StateId s) {
498 if (!HasFinal(s)) {
499 const StateTuple& tuple = state_table_->Tuple(s);
500 const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
501 const Fst<A>* fst = fst_array_[tuple.fst_id];
502 StateId fst_state = tuple.fst_state;
503
504 if (fst->Final(fst_state) != Weight::Zero() && stack.Depth() == 0)
505 SetFinal(s, fst->Final(fst_state));
506 else
507 SetFinal(s, Weight::Zero());
508 }
509 return CacheImpl<A>::Final(s);
510 }
511
NumArcs(StateId s)512 size_t NumArcs(StateId s) {
513 if (HasArcs(s)) { // If state cached, use the cached value.
514 return CacheImpl<A>::NumArcs(s);
515 } else if (always_cache_) { // If always caching, expand and cache state.
516 Expand(s);
517 return CacheImpl<A>::NumArcs(s);
518 } else { // Otherwise compute the number of arcs without expanding.
519 StateTuple tuple = state_table_->Tuple(s);
520 if (tuple.fst_state == kNoStateId)
521 return 0;
522
523 const Fst<A>* fst = fst_array_[tuple.fst_id];
524 size_t num_arcs = fst->NumArcs(tuple.fst_state);
525 if (ComputeFinalArc(tuple, 0))
526 num_arcs++;
527
528 return num_arcs;
529 }
530 }
531
532 // Returns whether a given label is a non terminal
IsNonTerminal(Label l)533 bool IsNonTerminal(Label l) const {
534 // TODO(allauzen): be smarter and take advantage of
535 // all_dense or all_negative.
536 // Use also in ComputeArc, this would require changes to replace
537 // so that recursing into an empty fst lead to a non co-accessible
538 // state instead of deleting the arc as done currently.
539 // Current use correct, since i/olabel sorted iff all_non_empty.
540 typename NonTerminalHash::const_iterator it =
541 nonterminal_hash_.find(l);
542 return it != nonterminal_hash_.end();
543 }
544
NumInputEpsilons(StateId s)545 size_t NumInputEpsilons(StateId s) {
546 if (HasArcs(s)) {
547 // If state cached, use the cached value.
548 return CacheImpl<A>::NumInputEpsilons(s);
549 } else if (always_cache_ || !Properties(kILabelSorted)) {
550 // If always caching or if the number of input epsilons is too expensive
551 // to compute without caching (i.e. not ilabel sorted),
552 // then expand and cache state.
553 Expand(s);
554 return CacheImpl<A>::NumInputEpsilons(s);
555 } else {
556 // Otherwise, compute the number of input epsilons without caching.
557 StateTuple tuple = state_table_->Tuple(s);
558 if (tuple.fst_state == kNoStateId)
559 return 0;
560 const Fst<A>* fst = fst_array_[tuple.fst_id];
561 size_t num = 0;
562 if (!epsilon_on_replace_) {
563 // If epsilon_on_replace is false, all input epsilon arcs
564 // are also input epsilons arcs in the underlying machine.
565 fst->NumInputEpsilons(tuple.fst_state);
566 } else {
567 // Otherwise, one need to consider that all non-terminal arcs
568 // in the underlying machine also become input epsilon arc.
569 ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state);
570 for (; !aiter.Done() &&
571 ((aiter.Value().ilabel == 0) ||
572 IsNonTerminal(aiter.Value().olabel));
573 aiter.Next())
574 ++num;
575 }
576 if (ComputeFinalArc(tuple, 0))
577 num++;
578 return num;
579 }
580 }
581
NumOutputEpsilons(StateId s)582 size_t NumOutputEpsilons(StateId s) {
583 if (HasArcs(s)) {
584 // If state cached, use the cached value.
585 return CacheImpl<A>::NumOutputEpsilons(s);
586 } else if(always_cache_ || !Properties(kOLabelSorted)) {
587 // If always caching or if the number of output epsilons is too expensive
588 // to compute without caching (i.e. not olabel sorted),
589 // then expand and cache state.
590 Expand(s);
591 return CacheImpl<A>::NumOutputEpsilons(s);
592 } else {
593 // Otherwise, compute the number of output epsilons without caching.
594 StateTuple tuple = state_table_->Tuple(s);
595 if (tuple.fst_state == kNoStateId)
596 return 0;
597 const Fst<A>* fst = fst_array_[tuple.fst_id];
598 size_t num = 0;
599 ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state);
600 for (; !aiter.Done() &&
601 ((aiter.Value().olabel == 0) ||
602 IsNonTerminal(aiter.Value().olabel));
603 aiter.Next())
604 ++num;
605 if (ComputeFinalArc(tuple, 0))
606 num++;
607 return num;
608 }
609 }
610
Properties()611 uint64 Properties() const { return Properties(kFstProperties); }
612
613 // Set error if found; return FST impl properties.
Properties(uint64 mask)614 uint64 Properties(uint64 mask) const {
615 if (mask & kError) {
616 for (size_t i = 1; i < fst_array_.size(); ++i) {
617 if (fst_array_[i]->Properties(kError, false))
618 SetProperties(kError, kError);
619 }
620 }
621 return FstImpl<Arc>::Properties(mask);
622 }
623
624 // return the base arc iterator, if arcs have not been computed yet,
625 // extend/recurse for new arcs.
InitArcIterator(StateId s,ArcIteratorData<A> * data)626 void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
627 if (!HasArcs(s))
628 Expand(s);
629 CacheImpl<A>::InitArcIterator(s, data);
630 // TODO(allauzen): Set behaviour of generic iterator
631 // Warning: ArcIterator<ReplaceFst<A> >::InitCache()
632 // relies on current behaviour.
633 }
634
635
636 // Extend current state (walk arcs one level deep)
Expand(StateId s)637 void Expand(StateId s) {
638 StateTuple tuple = state_table_->Tuple(s);
639
640 // If local fst is empty
641 if (tuple.fst_state == kNoStateId) {
642 SetArcs(s);
643 return;
644 }
645
646 ArcIterator< Fst<A> > aiter(
647 *(fst_array_[tuple.fst_id]), tuple.fst_state);
648 Arc arc;
649
650 // Create a final arc when needed
651 if (ComputeFinalArc(tuple, &arc))
652 PushArc(s, arc);
653
654 // Expand all arcs leaving the state
655 for (;!aiter.Done(); aiter.Next()) {
656 if (ComputeArc(tuple, aiter.Value(), &arc))
657 PushArc(s, arc);
658 }
659
660 SetArcs(s);
661 }
662
Expand(StateId s,const StateTuple & tuple,const ArcIteratorData<A> & data)663 void Expand(StateId s, const StateTuple &tuple,
664 const ArcIteratorData<A> &data) {
665 // If local fst is empty
666 if (tuple.fst_state == kNoStateId) {
667 SetArcs(s);
668 return;
669 }
670
671 ArcIterator< Fst<A> > aiter(data);
672 Arc arc;
673
674 // Create a final arc when needed
675 if (ComputeFinalArc(tuple, &arc))
676 AddArc(s, arc);
677
678 // Expand all arcs leaving the state
679 for (; !aiter.Done(); aiter.Next()) {
680 if (ComputeArc(tuple, aiter.Value(), &arc))
681 AddArc(s, arc);
682 }
683
684 SetArcs(s);
685 }
686
687 // If arcp == 0, only returns if a final arc is required, does not
688 // actually compute it.
689 bool ComputeFinalArc(const StateTuple &tuple, A* arcp,
690 uint32 flags = kArcValueFlags) {
691 const Fst<A>* fst = fst_array_[tuple.fst_id];
692 StateId fst_state = tuple.fst_state;
693 if (fst_state == kNoStateId)
694 return false;
695
696 // if state is final, pop up stack
697 const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
698 if (fst->Final(fst_state) != Weight::Zero() && stack.Depth()) {
699 if (arcp) {
700 arcp->ilabel = 0;
701 arcp->olabel = 0;
702 if (flags & kArcNextStateValue) {
703 PrefixId prefix_id = PopPrefix(stack);
704 const PrefixTuple& top = stack.Top();
705 arcp->nextstate = state_table_->FindState(
706 StateTuple(prefix_id, top.fst_id, top.nextstate));
707 }
708 if (flags & kArcWeightValue)
709 arcp->weight = fst->Final(fst_state);
710 }
711 return true;
712 } else {
713 return false;
714 }
715 }
716
717 // Compute the arc in the replace fst corresponding to a given
718 // in the underlying machine. Returns false if the underlying arc
719 // corresponds to no arc in the replace.
720 bool ComputeArc(const StateTuple &tuple, const A &arc, A* arcp,
721 uint32 flags = kArcValueFlags) {
722 if (!epsilon_on_replace_ &&
723 (flags == (flags & (kArcILabelValue | kArcWeightValue)))) {
724 *arcp = arc;
725 return true;
726 }
727
728 if (arc.olabel == 0) { // expand local fst
729 StateId nextstate = flags & kArcNextStateValue
730 ? state_table_->FindState(
731 StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
732 : kNoStateId;
733 *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate);
734 } else {
735 // check for non terminal
736 typename NonTerminalHash::const_iterator it =
737 nonterminal_hash_.find(arc.olabel);
738 if (it != nonterminal_hash_.end()) { // recurse into non terminal
739 Label nonterminal = it->second;
740 const Fst<A>* nt_fst = fst_array_[nonterminal];
741 PrefixId nt_prefix = PushPrefix(stackprefix_array_[tuple.prefix_id],
742 tuple.fst_id, arc.nextstate);
743
744 // if start state is valid replace, else arc is implicitly
745 // deleted
746 StateId nt_start = nt_fst->Start();
747 if (nt_start != kNoStateId) {
748 StateId nt_nextstate = flags & kArcNextStateValue
749 ? state_table_->FindState(
750 StateTuple(nt_prefix, nonterminal, nt_start))
751 : kNoStateId;
752 Label ilabel = (epsilon_on_replace_) ? 0 : arc.ilabel;
753 *arcp = A(ilabel, 0, arc.weight, nt_nextstate);
754 } else {
755 return false;
756 }
757 } else {
758 StateId nextstate = flags & kArcNextStateValue
759 ? state_table_->FindState(
760 StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
761 : kNoStateId;
762 *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate);
763 }
764 }
765 return true;
766 }
767
768 // Returns the arc iterator flags supported by this Fst.
ArcIteratorFlags()769 uint32 ArcIteratorFlags() const {
770 uint32 flags = kArcValueFlags;
771 if (!always_cache_)
772 flags |= kArcNoCache;
773 return flags;
774 }
775
GetStateTable()776 T* GetStateTable() const {
777 return state_table_;
778 }
779
GetFst(Label fst_id)780 const Fst<A>* GetFst(Label fst_id) const {
781 return fst_array_[fst_id];
782 }
783
EpsilonOnReplace()784 bool EpsilonOnReplace() const { return epsilon_on_replace_; }
785
786 // private helper classes
787 private:
788 static const size_t kPrime0;
789
790 // \class PrefixTuple
791 // \brief Tuple of fst_id and destination state (entry in stack prefix)
792 struct PrefixTuple {
PrefixTuplePrefixTuple793 PrefixTuple(Label f, StateId s) : fst_id(f), nextstate(s) {}
794
795 Label fst_id;
796 StateId nextstate;
797 };
798
799 // \class StackPrefix
800 // \brief Container for stack prefix.
801 class StackPrefix {
802 public:
StackPrefix()803 StackPrefix() {}
804
805 // copy constructor
StackPrefix(const StackPrefix & x)806 StackPrefix(const StackPrefix& x) :
807 prefix_(x.prefix_) {
808 }
809
Push(StateId fst_id,StateId nextstate)810 void Push(StateId fst_id, StateId nextstate) {
811 prefix_.push_back(PrefixTuple(fst_id, nextstate));
812 }
813
Pop()814 void Pop() {
815 prefix_.pop_back();
816 }
817
Top()818 const PrefixTuple& Top() const {
819 return prefix_[prefix_.size()-1];
820 }
821
Depth()822 size_t Depth() const {
823 return prefix_.size();
824 }
825
826 public:
827 vector<PrefixTuple> prefix_;
828 };
829
830
831 // \class StackPrefixEqual
832 // \brief Compare two stack prefix classes for equality
833 class StackPrefixEqual {
834 public:
operator()835 bool operator()(const StackPrefix& x, const StackPrefix& y) const {
836 if (x.prefix_.size() != y.prefix_.size()) return false;
837 for (size_t i = 0; i < x.prefix_.size(); ++i) {
838 if (x.prefix_[i].fst_id != y.prefix_[i].fst_id ||
839 x.prefix_[i].nextstate != y.prefix_[i].nextstate) return false;
840 }
841 return true;
842 }
843 };
844
845 //
846 // \class StackPrefixKey
847 // \brief Hash function for stack prefix to prefix id
848 class StackPrefixKey {
849 public:
operator()850 size_t operator()(const StackPrefix& x) const {
851 size_t sum = 0;
852 for (size_t i = 0; i < x.prefix_.size(); ++i) {
853 sum += x.prefix_[i].fst_id + x.prefix_[i].nextstate*kPrime0;
854 }
855 return sum;
856 }
857 };
858
859 typedef unordered_map<StackPrefix, PrefixId, StackPrefixKey, StackPrefixEqual>
860 StackPrefixHash;
861
862 // private methods
863 private:
864 // hash stack prefix (return unique index into stackprefix array)
GetPrefixId(const StackPrefix & prefix)865 PrefixId GetPrefixId(const StackPrefix& prefix) {
866 typename StackPrefixHash::iterator it = prefix_hash_.find(prefix);
867 if (it == prefix_hash_.end()) {
868 PrefixId prefix_id = stackprefix_array_.size();
869 stackprefix_array_.push_back(prefix);
870 prefix_hash_[prefix] = prefix_id;
871 return prefix_id;
872 } else {
873 return it->second;
874 }
875 }
876
877 // prefix id after a stack pop
PopPrefix(StackPrefix prefix)878 PrefixId PopPrefix(StackPrefix prefix) {
879 prefix.Pop();
880 return GetPrefixId(prefix);
881 }
882
883 // prefix id after a stack push
PushPrefix(StackPrefix prefix,Label fst_id,StateId nextstate)884 PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
885 prefix.Push(fst_id, nextstate);
886 return GetPrefixId(prefix);
887 }
888
889
// private data
 private:
  // runtime options
  // When true, arcs that recurse into a non-terminal get an epsilon input
  // label instead of a copy of the underlying arc's input label.
  bool epsilon_on_replace_;
  bool always_cache_;  // Optionally caching arc iterator disabled when true

  // state table
  // Maps state tuples to state ids and back; deleted in the destructor.
  StateTable *state_table_;

  // cross index of unique stack prefix
  // could potentially have one copy of prefix array
  // prefix_hash_ maps a prefix to its index in stackprefix_array_; both
  // are filled by GetPrefixId.
  StackPrefixHash prefix_hash_;
  vector<StackPrefix> stackprefix_array_;

  // Labels of all non-terminals supplied at construction.
  set<Label> nonterminal_set_;
  // Non-terminal label -> index of its Fst in fst_array_.
  NonTerminalHash nonterminal_hash_;
  // Component Fsts; index 0 is a null placeholder, the entries from
  // index 1 on are deleted in the destructor.
  vector<const Fst<A>*> fst_array_;
  // Index of the root Fst in fst_array_.
  Label root_;

  void operator=(const ReplaceFstImpl<A, T> &);  // disallow
910 };
911
912
913 template <class A, class T>
914 const size_t ReplaceFstImpl<A, T>::kPrime0 = 7853;
915
916 //
917 // \class ReplaceFst
// \brief Recursively replaces arcs in the root Fst with other Fsts.
919 // This version is a delayed Fst.
920 //
921 // ReplaceFst supports dynamic replacement of arcs in one Fst with
922 // another Fst. This replacement is recursive. ReplaceFst can be used
923 // to support a variety of delayed constructions such as recursive
924 // transition networks, union, or closure. It is constructed with an
925 // array of Fst(s). One Fst represents the root (or topology)
926 // machine. The root Fst refers to other Fsts by recursively replacing
927 // arcs labeled as non-terminals with the matching non-terminal
928 // Fst. Currently the ReplaceFst uses the output symbols of the arcs
929 // to determine whether the arc is a non-terminal arc or not. A
930 // non-terminal can be any label that is not a non-zero terminal label
931 // in the output alphabet.
932 //
933 // Note that the constructor uses a vector of pair<>. These correspond
934 // to the tuple of non-terminal Label and corresponding Fst. For example
935 // to implement the closure operation we need 2 Fsts. The first root
936 // Fst is a single Arc on the start State that self loops, it references
937 // the particular machine for which we are performing the closure operation.
938 //
939 // The ReplaceFst class supports an optionally caching arc iterator:
940 // ArcIterator< ReplaceFst<A> >
// The ReplaceFst needs to be built such that it is known to be ilabel
// or olabel sorted (see usage below).
943 //
944 // Observe that Matcher<Fst<A> > will use the optionally caching arc
945 // iterator when available (Fst is ilabel sorted and matching on the
946 // input, or Fst is olabel sorted and matching on the output).
947 // In order to obtain the most efficient behaviour, it is recommended
948 // to set 'epsilon_on_replace' to false (this means constructing acceptors
949 // as transducers with epsilons on the input side of nonterminal arcs)
950 // and matching on the input side.
951 //
952 // This class attaches interface to implementation and handles
953 // reference counting, delegating most methods to ImplToFst.
954 template <class A, class T = DefaultReplaceStateTable<A> >
955 class ReplaceFst : public ImplToFst< ReplaceFstImpl<A, T> > {
956 public:
957 friend class ArcIterator< ReplaceFst<A, T> >;
958 friend class StateIterator< ReplaceFst<A, T> >;
959 friend class ReplaceFstMatcher<A, T>;
960
961 typedef A Arc;
962 typedef typename A::Label Label;
963 typedef typename A::Weight Weight;
964 typedef typename A::StateId StateId;
965 typedef CacheState<A> State;
966 typedef ReplaceFstImpl<A, T> Impl;
967
968 using ImplToFst<Impl>::Properties;
969
ReplaceFst(const vector<pair<Label,const Fst<A> * >> & fst_array,Label root)970 ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
971 Label root)
972 : ImplToFst<Impl>(new Impl(fst_array, ReplaceFstOptions<A, T>(root))) {}
973
ReplaceFst(const vector<pair<Label,const Fst<A> * >> & fst_array,const ReplaceFstOptions<A,T> & opts)974 ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
975 const ReplaceFstOptions<A, T> &opts)
976 : ImplToFst<Impl>(new Impl(fst_array, opts)) {}
977
978 // See Fst<>::Copy() for doc.
979 ReplaceFst(const ReplaceFst<A, T>& fst, bool safe = false)
980 : ImplToFst<Impl>(fst, safe) {}
981
982 // Get a copy of this ReplaceFst. See Fst<>::Copy() for further doc.
983 virtual ReplaceFst<A, T> *Copy(bool safe = false) const {
984 return new ReplaceFst<A, T>(*this, safe);
985 }
986
987 virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
988
InitArcIterator(StateId s,ArcIteratorData<A> * data)989 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
990 GetImpl()->InitArcIterator(s, data);
991 }
992
InitMatcher(MatchType match_type)993 virtual MatcherBase<A> *InitMatcher(MatchType match_type) const {
994 if ((GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
995 ((match_type == MATCH_INPUT && Properties(kILabelSorted, false)) ||
996 (match_type == MATCH_OUTPUT && Properties(kOLabelSorted, false)))) {
997 return new ReplaceFstMatcher<A, T>(*this, match_type);
998 }
999 else {
1000 VLOG(2) << "Not using replace matcher";
1001 return 0;
1002 }
1003 }
1004
CyclicDependencies()1005 bool CyclicDependencies() const {
1006 return GetImpl()->CyclicDependencies();
1007 }
1008
1009 private:
1010 // Makes visible to friends.
GetImpl()1011 Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
1012
1013 void operator=(const ReplaceFst<A> &fst); // disallow
1014 };
1015
1016
1017 // Specialization for ReplaceFst.
1018 template<class A, class T>
1019 class StateIterator< ReplaceFst<A, T> >
1020 : public CacheStateIterator< ReplaceFst<A, T> > {
1021 public:
StateIterator(const ReplaceFst<A,T> & fst)1022 explicit StateIterator(const ReplaceFst<A, T> &fst)
1023 : CacheStateIterator< ReplaceFst<A, T> >(fst, fst.GetImpl()) {}
1024
1025 private:
1026 DISALLOW_COPY_AND_ASSIGN(StateIterator);
1027 };
1028
1029
1030 // Specialization for ReplaceFst.
1031 // Implements optional caching. It can be used as follows:
1032 //
1033 // ReplaceFst<A> replace;
1034 // ArcIterator< ReplaceFst<A> > aiter(replace, s);
1035 // // Note: ArcIterator< Fst<A> > is always a caching arc iterator.
1036 // aiter.SetFlags(kArcNoCache, kArcNoCache);
1037 // // Use the arc iterator, no arc will be cached, no state will be expanded.
1038 // // The varied 'kArcValueFlags' can be used to decide which part
1039 // // of arc values needs to be computed.
1040 // aiter.SetFlags(kArcILabelValue, kArcValueFlags);
1041 // // Only want the ilabel for this arc
1042 // aiter.Value(); // Does not compute the destination state.
1043 // aiter.Next();
1044 // aiter.SetFlags(kArcNextStateValue, kArcNextStateValue);
1045 // // Want both ilabel and nextstate for that arc
1046 // aiter.Value(); // Does compute the destination state and inserts it
1047 // // in the replace state table.
1048 // // No Arc has been cached at that point.
1049 //
template <class A, class T>
class ArcIterator< ReplaceFst<A, T> > {
 public:
  typedef A Arc;
  typedef typename A::StateId StateId;

  // Sets up iteration over the arcs leaving state 's' of 'fst'.
  // If the state is (or must be) cached, the iterator reads the cached
  // arcs. Otherwise it points at the arcs of the underlying component Fst
  // and leaves the cache/no-cache decision to Value() or SetFlags().
  ArcIterator(const ReplaceFst<A, T> &fst, StateId s)
      : fst_(fst), state_(s), pos_(0), offset_(0), flags_(0), arcs_(0),
        data_flags_(0), final_flags_(0) {
    // No reference is held yet on either arc array; see the destructor.
    cache_data_.ref_count = 0;
    local_data_.ref_count = 0;

    // If FST does not support optional caching, force caching.
    if(!(fst_.GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
       !(fst_.GetImpl()->HasArcs(state_)))
       fst_.GetImpl()->Expand(state_);

    // If state is already cached, use cached arcs array.
    if (fst_.GetImpl()->HasArcs(state_)) {
      (fst_.GetImpl())->template CacheImpl<A>::InitArcIterator(state_,
                                                               &cache_data_);
      num_arcs_ = cache_data_.narcs;
      arcs_ = cache_data_.arcs;      // 'arcs_' is a ptr to the cached arcs.
      data_flags_ = kArcValueFlags;  // All the arc member values are valid.
    } else {  // Otherwise delay decision until Value() is called.
      tuple_ = fst_.GetImpl()->GetStateTable()->Tuple(state_);
      if (tuple_.fst_state == kNoStateId) {
        // No underlying component state: nothing to iterate over.
        num_arcs_ = 0;
      } else {
        // The decision to cache or not to cache has been deferred
        // until Value() or SetFlags() is called. However, the arc
        // iterator is set up now to be ready for non-caching in order
        // to keep the Value() method simple and efficient.
        const Fst<A>* fst = fst_.GetImpl()->GetFst(tuple_.fst_id);
        fst->InitArcIterator(tuple_.fst_state, &local_data_);
        // 'arcs_' is a pointer to the arcs in the underlying machine.
        arcs_ = local_data_.arcs;
        // Compute the final arc (but not its destination state)
        // if a final arc is required.
        bool has_final_arc = fst_.GetImpl()->ComputeFinalArc(
            tuple_,
            &final_arc_,
            kArcValueFlags & ~kArcNextStateValue);
        // Set the arc value flags that hold for 'final_arc_'.
        final_flags_ = kArcValueFlags & ~kArcNextStateValue;
        // Compute the number of arcs.
        num_arcs_ = local_data_.narcs;
        if (has_final_arc)
          ++num_arcs_;
        // Set the offset between the underlying arc positions and
        // the positions in the arc iterator (1 if there is a final arc,
        // which then occupies position 0; otherwise 0).
        offset_ = num_arcs_ - local_data_.narcs;
        // Defers the decision to cache or not until Value() or
        // SetFlags() is called.
        data_flags_ = 0;
      }
    }
  }

  // Drops the references taken on the cached and/or local arc arrays,
  // if any were handed out when the arc iterator data was initialized.
  ~ArcIterator() {
    if (cache_data_.ref_count)
      --(*cache_data_.ref_count);
    if (local_data_.ref_count)
      --(*local_data_.ref_count);
  }

  // Expands and caches 'state_', then switches the iterator over to the
  // cached arcs. Const because it is called from Value(); it only touches
  // mutable members.
  void ExpandAndCache() const {
    // TODO(allauzen): revisit this
    // fst_.GetImpl()->Expand(state_, tuple_, local_data_);
    // (fst_.GetImpl())->CacheImpl<A>*>::InitArcIterator(state_,
    //                                               &cache_data_);
    //
    fst_.InitArcIterator(state_, &cache_data_);  // Expand and cache state.
    arcs_ = cache_data_.arcs;  // 'arcs_' is a pointer to the cached arcs.
    data_flags_ = kArcValueFlags;  // All the arc member values are valid.
    offset_ = 0;  // No offset

  }

  // (Re)configures the iterator according to 'flags_': reads the arcs of
  // the underlying machine when kArcNoCache is set, otherwise expands and
  // caches the state.
  void Init() {
    if (flags_ & kArcNoCache) {  // If caching is disabled
      // 'arcs_' is a pointer to the arcs in the underlying machine.
      arcs_ = local_data_.arcs;
      // Set the arcs value flags that hold for 'arcs_'.
      data_flags_ = kArcWeightValue;
      if (!fst_.GetImpl()->EpsilonOnReplace())
          data_flags_ |= kArcILabelValue;
      // Set the offset between the underlying arc positions and
      // the positions in the arc iterator.
      offset_ = num_arcs_ - local_data_.narcs;
    } else {  // Otherwise, expand and cache
      ExpandAndCache();
    }
  }

  // True once the position has moved past the last arc.
  bool Done() const { return pos_ >= num_arcs_; }

  // Returns the arc at the current position. Depending on the iterator
  // state this is a reference into the cached or underlying arc array,
  // or a reference to a member ('arc_' / 'final_arc_') that holds an arc
  // computed on the fly and is overwritten by later calls.
  const A& Value() const {
    // If 'data_flags_' was set to 0, non-caching was not requested
    if (!data_flags_) {
      // TODO(allauzen): revisit this.
      if (flags_ & kArcNoCache) {
        // Should never happen.
        FSTERROR() << "ReplaceFst: inconsistent arc iterator flags";
      }
      ExpandAndCache();  // Expand and cache.
    }

    if (pos_ - offset_ >= 0) {  // The requested arc is not the 'final' arc.
      const A& arc = arcs_[pos_ - offset_];
      if ((data_flags_ & flags_) == (flags_ & kArcValueFlags)) {
        // If the value flags for 'arc' match the required value flags
        // then return 'arc'.
        return arc;
      } else {
        // Otherwise, compute the corresponding arc on-the-fly.
        fst_.GetImpl()->ComputeArc(tuple_, arc, &arc_, flags_ & kArcValueFlags);
        return arc_;
      }
    } else {  // The requested arc is the 'final' arc.
      if ((final_flags_ & flags_) != (flags_ & kArcValueFlags)) {
        // If the arc value flags that hold for the final arc
        // do not match the requested value flags, then
        // 'final_arc_' needs to be updated.
        fst_.GetImpl()->ComputeFinalArc(tuple_, &final_arc_,
                                    flags_ & kArcValueFlags);
        final_flags_ = flags_ & kArcValueFlags;
      }
      return final_arc_;
    }
  }

  // Advances to the next arc.
  void Next() { ++pos_; }

  // Current position in the arc sequence.
  size_t Position() const { return pos_; }

  // Rewinds to the first arc.
  void Reset() { pos_ = 0;  }

  // Moves to position 'pos'; no bounds check is performed here.
  void Seek(size_t pos) { pos_ = pos; }

  // Currently active behavioral flags.
  uint32 Flags() const { return flags_; }

  // Clears the bits in 'mask', then sets those bits of 'f' that the Fst
  // supports (as given by the implementation's ArcIteratorFlags()).
  void SetFlags(uint32 f, uint32 mask) {
    // Update the flags taking into account what flags are supported
    // by the Fst.
    flags_ &= ~mask;
    flags_ |= (f & fst_.GetImpl()->ArcIteratorFlags());
    // If non-caching is not requested (and caching has not already
    // been performed), then flush 'data_flags_' to request caching
    // during the next call to Value().
    if (!(flags_ & kArcNoCache) && data_flags_ != kArcValueFlags) {
      if (!fst_.GetImpl()->HasArcs(state_))
         data_flags_ = 0;
    }
    // If 'data_flags_' has been flushed but non-caching is requested
    // before calling Value(), then set up the iterator for non-caching.
    if ((f & kArcNoCache) && (!data_flags_))
      Init();
  }

 private:
  const ReplaceFst<A, T> &fst_;           // Reference to the FST
  StateId state_;                         // State in the FST
  mutable typename T::StateTuple tuple_;  // Tuple corresponding to state_

  ssize_t pos_;             // Current position
  mutable ssize_t offset_;  // Offset between position in iterator and in arcs_
  ssize_t num_arcs_;        // Number of arcs at state_
  uint32 flags_;            // Behavioral flags for the arc iterator
  mutable Arc arc_;         // Memory to temporarily store computed arcs

  mutable ArcIteratorData<Arc> cache_data_;  // Arc iterator data in cache
  mutable ArcIteratorData<Arc> local_data_;  // Arc iterator data in local fst

  mutable const A* arcs_;       // Array of arcs
  mutable uint32 data_flags_;   // Arc value flags valid for data in arcs_
  mutable Arc final_arc_;       // Final arc (when required)
  mutable uint32 final_flags_;  // Arc value flags valid for final_arc_

  DISALLOW_COPY_AND_ASSIGN(ArcIterator);
};
1231
1232
1233 template <class A, class T>
1234 class ReplaceFstMatcher : public MatcherBase<A> {
1235 public:
1236 typedef A Arc;
1237 typedef typename A::StateId StateId;
1238 typedef typename A::Label Label;
1239 typedef MultiEpsMatcher<Matcher<Fst<A> > > LocalMatcher;
1240
ReplaceFstMatcher(const ReplaceFst<A,T> & fst,fst::MatchType match_type)1241 ReplaceFstMatcher(const ReplaceFst<A, T> &fst, fst::MatchType match_type)
1242 : fst_(fst),
1243 impl_(fst_.GetImpl()),
1244 s_(fst::kNoStateId),
1245 match_type_(match_type),
1246 current_loop_(false),
1247 final_arc_(false),
1248 loop_(fst::kNoLabel, 0, A::Weight::One(), fst::kNoStateId) {
1249 if (match_type_ == fst::MATCH_OUTPUT)
1250 swap(loop_.ilabel, loop_.olabel);
1251 InitMatchers();
1252 }
1253
1254 ReplaceFstMatcher(const ReplaceFstMatcher<A, T> &matcher, bool safe = false)
1255 : fst_(matcher.fst_),
1256 impl_(fst_.GetImpl()),
1257 s_(fst::kNoStateId),
1258 match_type_(matcher.match_type_),
1259 current_loop_(false),
1260 loop_(fst::kNoLabel, 0, A::Weight::One(), fst::kNoStateId) {
1261 if (match_type_ == fst::MATCH_OUTPUT)
1262 swap(loop_.ilabel, loop_.olabel);
1263 InitMatchers();
1264 }
1265
1266 // Create a local matcher for each component Fst of replace.
1267 // LocalMatcher is a multi epsilon wrapper matcher. MultiEpsilonMatcher
1268 // is used to match each non-terminal arc, since these non-terminal
1269 // turn into epsilons on recursion.
InitMatchers()1270 void InitMatchers() {
1271 const vector<const Fst<A>*>& fst_array = impl_->fst_array_;
1272 matcher_.resize(fst_array.size(), 0);
1273 for (size_t i = 0; i < fst_array.size(); ++i) {
1274 if (fst_array[i]) {
1275 matcher_[i] =
1276 new LocalMatcher(*fst_array[i], match_type_, kMultiEpsList);
1277
1278 typename set<Label>::iterator it = impl_->nonterminal_set_.begin();
1279 for (; it != impl_->nonterminal_set_.end(); ++it) {
1280 matcher_[i]->AddMultiEpsLabel(*it);
1281 }
1282 }
1283 }
1284 }
1285
1286 virtual ReplaceFstMatcher<A, T> *Copy(bool safe = false) const {
1287 return new ReplaceFstMatcher<A, T>(*this, safe);
1288 }
1289
~ReplaceFstMatcher()1290 virtual ~ReplaceFstMatcher() {
1291 for (size_t i = 0; i < matcher_.size(); ++i)
1292 delete matcher_[i];
1293 }
1294
Type(bool test)1295 virtual MatchType Type(bool test) const {
1296 if (match_type_ == MATCH_NONE)
1297 return match_type_;
1298
1299 uint64 true_prop = match_type_ == MATCH_INPUT ?
1300 kILabelSorted : kOLabelSorted;
1301 uint64 false_prop = match_type_ == MATCH_INPUT ?
1302 kNotILabelSorted : kNotOLabelSorted;
1303 uint64 props = fst_.Properties(true_prop | false_prop, test);
1304
1305 if (props & true_prop)
1306 return match_type_;
1307 else if (props & false_prop)
1308 return MATCH_NONE;
1309 else
1310 return MATCH_UNKNOWN;
1311 }
1312
GetFst()1313 virtual const Fst<A> &GetFst() const {
1314 return fst_;
1315 }
1316
Properties(uint64 props)1317 virtual uint64 Properties(uint64 props) const {
1318 return props;
1319 }
1320
1321 private:
1322 // Set the sate from which our matching happens.
SetState_(StateId s)1323 virtual void SetState_(StateId s) {
1324 if (s_ == s) return;
1325
1326 s_ = s;
1327 tuple_ = impl_->GetStateTable()->Tuple(s_);
1328 if (tuple_.fst_state == kNoStateId) {
1329 done_ = true;
1330 return;
1331 }
1332 // Get current matcher. Used for non epsilon matching
1333 current_matcher_ = matcher_[tuple_.fst_id];
1334 current_matcher_->SetState(tuple_.fst_state);
1335 loop_.nextstate = s_;
1336
1337 final_arc_ = false;
1338 }
1339
1340 // Search for label, from previous set state. If label == 0, first
1341 // hallucinate and epsilon loop, else use the underlying matcher to
1342 // search for the label or epsilons.
1343 // - Note since the ReplaceFST recursion on non-terminal arcs causes
1344 // epsilon transitions to be created we use the MultiEpsilonMatcher
1345 // to search for possible matches of non terminals.
1346 // - If the component Fst reaches a final state we also need to add
1347 // the exiting final arc.
Find_(Label label)1348 virtual bool Find_(Label label) {
1349 bool found = false;
1350 label_ = label;
1351 if (label_ == 0 || label_ == kNoLabel) {
1352 // Compute loop directly, saving Replace::ComputeArc
1353 if (label_ == 0) {
1354 current_loop_ = true;
1355 found = true;
1356 }
1357 // Search for matching multi epsilons
1358 final_arc_ = impl_->ComputeFinalArc(tuple_, 0);
1359 found = current_matcher_->Find(kNoLabel) || final_arc_ || found;
1360 } else {
1361 // Search on sub machine directly using sub machine matcher.
1362 found = current_matcher_->Find(label_);
1363 }
1364 return found;
1365 }
1366
Done_()1367 virtual bool Done_() const {
1368 return !current_loop_ && !final_arc_ && current_matcher_->Done();
1369 }
1370
Value_()1371 virtual const Arc& Value_() const {
1372 if (current_loop_) {
1373 return loop_;
1374 }
1375 if (final_arc_) {
1376 impl_->ComputeFinalArc(tuple_, &arc_);
1377 return arc_;
1378 }
1379 const Arc& component_arc = current_matcher_->Value();
1380 impl_->ComputeArc(tuple_, component_arc, &arc_);
1381 return arc_;
1382 }
1383
Next_()1384 virtual void Next_() {
1385 if (current_loop_) {
1386 current_loop_ = false;
1387 return;
1388 }
1389 if (final_arc_) {
1390 final_arc_ = false;
1391 return;
1392 }
1393 current_matcher_->Next();
1394 }
1395
1396 const ReplaceFst<A, T>& fst_;
1397 ReplaceFstImpl<A, T> *impl_;
1398 LocalMatcher* current_matcher_;
1399 vector<LocalMatcher*> matcher_;
1400
1401 StateId s_; // Current state
1402 Label label_; // Current label
1403
1404 MatchType match_type_; // Supplied by caller
1405 mutable bool done_;
1406 mutable bool current_loop_; // Current arc is the implicit loop
1407 mutable bool final_arc_; // Current arc for exiting recursion
1408 mutable typename T::StateTuple tuple_; // Tuple corresponding to state_
1409 mutable Arc arc_;
1410 Arc loop_;
1411 };
1412
1413 template <class A, class T> inline
InitStateIterator(StateIteratorData<A> * data)1414 void ReplaceFst<A, T>::InitStateIterator(StateIteratorData<A> *data) const {
1415 data->base = new StateIterator< ReplaceFst<A, T> >(*this);
1416 }
1417
// Convenience name for ReplaceFst over StdArc with the default
// replace state table.
typedef ReplaceFst<StdArc> StdReplaceFst;
1419
1420
// Recursively replaces arcs in the root Fst with other Fsts.
1422 // This version writes the result of replacement to an output MutableFst.
1423 //
1424 // Replace supports replacement of arcs in one Fst with another
1425 // Fst. This replacement is recursive. Replace takes an array of
1426 // Fst(s). One Fst represents the root (or topology) machine. The root
1427 // Fst refers to other Fsts by recursively replacing arcs labeled as
1428 // non-terminals with the matching non-terminal Fst. Currently Replace
1429 // uses the output symbols of the arcs to determine whether the arc is
1430 // a non-terminal arc or not. A non-terminal can be any label that is
1431 // not a non-zero terminal label in the output alphabet. Note that
1432 // input argument is a vector of pair<>. These correspond to the tuple
1433 // of non-terminal Label and corresponding Fst.
1434 template<class Arc>
Replace(const vector<pair<typename Arc::Label,const Fst<Arc> * >> & ifst_array,MutableFst<Arc> * ofst,typename Arc::Label root,bool epsilon_on_replace)1435 void Replace(const vector<pair<typename Arc::Label,
1436 const Fst<Arc>* > >& ifst_array,
1437 MutableFst<Arc> *ofst, typename Arc::Label root,
1438 bool epsilon_on_replace) {
1439 ReplaceFstOptions<Arc> opts(root, epsilon_on_replace);
1440 opts.gc_limit = 0; // Cache only the last state for fastest copy.
1441 *ofst = ReplaceFst<Arc>(ifst_array, opts);
1442 }
1443
1444 template<class Arc>
Replace(const vector<pair<typename Arc::Label,const Fst<Arc> * >> & ifst_array,MutableFst<Arc> * ofst,typename Arc::Label root)1445 void Replace(const vector<pair<typename Arc::Label,
1446 const Fst<Arc>* > >& ifst_array,
1447 MutableFst<Arc> *ofst, typename Arc::Label root) {
1448 Replace(ifst_array, ofst, root, false);
1449 }
1450
1451 } // namespace fst
1452
1453 #endif // FST_LIB_REPLACE_H__
1454