// state-map.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
// Classes to map over/transform states, e.g., to sort transitions.
// Consider using them when the operation does not change the number of states.
21
22 #ifndef FST_LIB_STATE_MAP_H__
23 #define FST_LIB_STATE_MAP_H__
24
#include <algorithm>
#include <string>
#include <unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
#include <utility>
using std::pair; using std::make_pair;
#include <vector>

#include <fst/arc-map.h>
#include <fst/cache.h>
#include <fst/mutable-fst.h>
36
37
38 namespace fst {
39
// StateMapper Interface - class determines how states are mapped.
// Useful for implementing operations that do not change the number of states.
42 //
43 // class StateMapper {
44 // public:
45 // typedef A FromArc;
46 // typedef B ToArc;
47 //
48 // // Typical constructor
49 // StateMapper(const Fst<A> &fst);
50 // // Required copy constructor that allows updating Fst argument;
51 // // pass only if relevant and changed.
52 // StateMapper(const StateMapper &mapper, const Fst<A> *fst = 0);
53 //
54 // // Specifies initial state of result
55 // B::StateId Start() const;
56 // // Specifies state's final weight in result
57 // B::Weight Final(B::StateId s) const;
58 //
59 // // These methods iterate through a state's arcs in result
60 // // Specifies state to iterate over
61 // void SetState(B::StateId s);
62 // // End of arcs?
63 // bool Done() const;
//   // Current arc
//   const B &Value() const;
67 // // Advance to next arc (when !Done)
68 // void Next();
69 //
//   // Specifies input symbol table action the mapper requires (see
//   // MapSymbolsAction in arc-map.h).
//   MapSymbolsAction InputSymbolsAction() const;
//   // Specifies output symbol table action the mapper requires (see
//   // MapSymbolsAction in arc-map.h).
//   MapSymbolsAction OutputSymbolsAction() const;
74 // // This specifies the known properties of an Fst mapped by this
75 // // mapper. It takes as argument the input Fst's known properties.
76 // uint64 Properties(uint64 props) const;
77 // };
78 //
// We include various state map versions below. One dimension of
// variation is whether the mapping mutates its input, writes to a
// new result Fst, or is an on-the-fly Fst. Another dimension is how
// we pass the mapper. We allow passing the mapper by pointer
// for cases where we need to change the state of the user's mapper.
// We also include map versions that pass the mapper
// by value or const reference when this suffices.
86
87 // Maps an arc type A using a mapper function object C, passed
88 // by pointer. This version modifies its Fst input.
89 template<class A, class C>
StateMap(MutableFst<A> * fst,C * mapper)90 void StateMap(MutableFst<A> *fst, C* mapper) {
91 typedef typename A::StateId StateId;
92 typedef typename A::Weight Weight;
93
94 if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS)
95 fst->SetInputSymbols(0);
96
97 if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS)
98 fst->SetOutputSymbols(0);
99
100 if (fst->Start() == kNoStateId)
101 return;
102
103 uint64 props = fst->Properties(kFstProperties, false);
104
105 fst->SetStart(mapper->Start());
106
107 for (StateId s = 0; s < fst->NumStates(); ++s) {
108 mapper->SetState(s);
109 fst->DeleteArcs(s);
110 for (; !mapper->Done(); mapper->Next())
111 fst->AddArc(s, mapper->Value());
112 fst->SetFinal(s, mapper->Final(s));
113 }
114
115 fst->SetProperties(mapper->Properties(props), kFstProperties);
116 }
117
118 // Maps an arc type A using a mapper function object C, passed
119 // by value. This version modifies its Fst input.
120 template<class A, class C>
StateMap(MutableFst<A> * fst,C mapper)121 void StateMap(MutableFst<A> *fst, C mapper) {
122 StateMap(fst, &mapper);
123 }
124
125
126 // Maps an arc type A to an arc type B using mapper function
127 // object C, passed by pointer. This version writes the mapped
128 // input Fst to an output MutableFst.
129 template<class A, class B, class C>
StateMap(const Fst<A> & ifst,MutableFst<B> * ofst,C * mapper)130 void StateMap(const Fst<A> &ifst, MutableFst<B> *ofst, C* mapper) {
131 typedef typename A::StateId StateId;
132 typedef typename A::Weight Weight;
133
134 ofst->DeleteStates();
135
136 if (mapper->InputSymbolsAction() == MAP_COPY_SYMBOLS)
137 ofst->SetInputSymbols(ifst.InputSymbols());
138 else if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS)
139 ofst->SetInputSymbols(0);
140
141 if (mapper->OutputSymbolsAction() == MAP_COPY_SYMBOLS)
142 ofst->SetOutputSymbols(ifst.OutputSymbols());
143 else if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS)
144 ofst->SetOutputSymbols(0);
145
146 uint64 iprops = ifst.Properties(kCopyProperties, false);
147
148 if (ifst.Start() == kNoStateId) {
149 if (iprops & kError) ofst->SetProperties(kError, kError);
150 return;
151 }
152
153 // Add all states.
154 if (ifst.Properties(kExpanded, false))
155 ofst->ReserveStates(CountStates(ifst));
156 for (StateIterator< Fst<A> > siter(ifst); !siter.Done(); siter.Next())
157 ofst->AddState();
158
159 ofst->SetStart(mapper->Start());
160
161 for (StateIterator< Fst<A> > siter(ifst); !siter.Done(); siter.Next()) {
162 StateId s = siter.Value();
163 mapper->SetState(s);
164 for (; !mapper->Done(); mapper->Next())
165 ofst->AddArc(s, mapper->Value());
166 ofst->SetFinal(s, mapper->Final(s));
167 }
168
169 uint64 oprops = ofst->Properties(kFstProperties, false);
170 ofst->SetProperties(mapper->Properties(iprops) | oprops, kFstProperties);
171 }
172
173 // Maps an arc type A to an arc type B using mapper function
174 // object C, passed by value. This version writes the mapped input
175 // Fst to an output MutableFst.
176 template<class A, class B, class C>
StateMap(const Fst<A> & ifst,MutableFst<B> * ofst,C mapper)177 void StateMap(const Fst<A> &ifst, MutableFst<B> *ofst, C mapper) {
178 StateMap(ifst, ofst, &mapper);
179 }
180
181 typedef CacheOptions StateMapFstOptions;
182
183 template <class A, class B, class C> class StateMapFst;
184
185 // Implementation of delayed StateMapFst.
186 template <class A, class B, class C>
187 class StateMapFstImpl : public CacheImpl<B> {
188 public:
189 using FstImpl<B>::SetType;
190 using FstImpl<B>::SetProperties;
191 using FstImpl<B>::SetInputSymbols;
192 using FstImpl<B>::SetOutputSymbols;
193
194 using VectorFstBaseImpl<typename CacheImpl<B>::State>::NumStates;
195
196 using CacheImpl<B>::PushArc;
197 using CacheImpl<B>::HasArcs;
198 using CacheImpl<B>::HasFinal;
199 using CacheImpl<B>::HasStart;
200 using CacheImpl<B>::SetArcs;
201 using CacheImpl<B>::SetFinal;
202 using CacheImpl<B>::SetStart;
203
204 friend class StateIterator< StateMapFst<A, B, C> >;
205
206 typedef B Arc;
207 typedef typename B::Weight Weight;
208 typedef typename B::StateId StateId;
209
StateMapFstImpl(const Fst<A> & fst,const C & mapper,const StateMapFstOptions & opts)210 StateMapFstImpl(const Fst<A> &fst, const C &mapper,
211 const StateMapFstOptions& opts)
212 : CacheImpl<B>(opts),
213 fst_(fst.Copy()),
214 mapper_(new C(mapper, fst_)),
215 own_mapper_(true) {
216 Init();
217 }
218
StateMapFstImpl(const Fst<A> & fst,C * mapper,const StateMapFstOptions & opts)219 StateMapFstImpl(const Fst<A> &fst, C *mapper,
220 const StateMapFstOptions& opts)
221 : CacheImpl<B>(opts),
222 fst_(fst.Copy()),
223 mapper_(mapper),
224 own_mapper_(false) {
225 Init();
226 }
227
StateMapFstImpl(const StateMapFstImpl<A,B,C> & impl)228 StateMapFstImpl(const StateMapFstImpl<A, B, C> &impl)
229 : CacheImpl<B>(impl),
230 fst_(impl.fst_->Copy(true)),
231 mapper_(new C(*impl.mapper_, fst_)),
232 own_mapper_(true) {
233 Init();
234 }
235
~StateMapFstImpl()236 ~StateMapFstImpl() {
237 delete fst_;
238 if (own_mapper_) delete mapper_;
239 }
240
Start()241 StateId Start() {
242 if (!HasStart())
243 SetStart(mapper_->Start());
244 return CacheImpl<B>::Start();
245 }
246
Final(StateId s)247 Weight Final(StateId s) {
248 if (!HasFinal(s))
249 SetFinal(s, mapper_->Final(s));
250 return CacheImpl<B>::Final(s);
251 }
252
NumArcs(StateId s)253 size_t NumArcs(StateId s) {
254 if (!HasArcs(s))
255 Expand(s);
256 return CacheImpl<B>::NumArcs(s);
257 }
258
NumInputEpsilons(StateId s)259 size_t NumInputEpsilons(StateId s) {
260 if (!HasArcs(s))
261 Expand(s);
262 return CacheImpl<B>::NumInputEpsilons(s);
263 }
264
NumOutputEpsilons(StateId s)265 size_t NumOutputEpsilons(StateId s) {
266 if (!HasArcs(s))
267 Expand(s);
268 return CacheImpl<B>::NumOutputEpsilons(s);
269 }
270
InitStateIterator(StateIteratorData<A> * data)271 void InitStateIterator(StateIteratorData<A> *data) const {
272 fst_->InitStateIterator(data);
273 }
274
InitArcIterator(StateId s,ArcIteratorData<B> * data)275 void InitArcIterator(StateId s, ArcIteratorData<B> *data) {
276 if (!HasArcs(s))
277 Expand(s);
278 CacheImpl<B>::InitArcIterator(s, data);
279 }
280
Properties()281 uint64 Properties() const { return Properties(kFstProperties); }
282
283 // Set error if found; return FST impl properties.
Properties(uint64 mask)284 uint64 Properties(uint64 mask) const {
285 if ((mask & kError) && (fst_->Properties(kError, false) ||
286 (mapper_->Properties(0) & kError)))
287 SetProperties(kError, kError);
288 return FstImpl<Arc>::Properties(mask);
289 }
290
Expand(StateId s)291 void Expand(StateId s) {
292 // Add exiting arcs.
293 for (mapper_->SetState(s); !mapper_->Done(); mapper_->Next())
294 PushArc(s, mapper_->Value());
295 SetArcs(s);
296 }
297
298 private:
Init()299 void Init() {
300 SetType("statemap");
301
302 if (mapper_->InputSymbolsAction() == MAP_COPY_SYMBOLS)
303 SetInputSymbols(fst_->InputSymbols());
304 else if (mapper_->InputSymbolsAction() == MAP_CLEAR_SYMBOLS)
305 SetInputSymbols(0);
306
307 if (mapper_->OutputSymbolsAction() == MAP_COPY_SYMBOLS)
308 SetOutputSymbols(fst_->OutputSymbols());
309 else if (mapper_->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS)
310 SetOutputSymbols(0);
311
312 uint64 props = fst_->Properties(kCopyProperties, false);
313 SetProperties(mapper_->Properties(props));
314 }
315
316 const Fst<A> *fst_;
317 C* mapper_;
318 bool own_mapper_;
319
320 void operator=(const StateMapFstImpl<A, B, C> &); // disallow
321 };
322
323
324 // Maps an arc type A to an arc type B using Mapper function object
325 // C. This version is a delayed Fst.
326 template <class A, class B, class C>
327 class StateMapFst : public ImplToFst< StateMapFstImpl<A, B, C> > {
328 public:
329 friend class ArcIterator< StateMapFst<A, B, C> >;
330
331 typedef B Arc;
332 typedef typename B::Weight Weight;
333 typedef typename B::StateId StateId;
334 typedef CacheState<B> State;
335 typedef StateMapFstImpl<A, B, C> Impl;
336
StateMapFst(const Fst<A> & fst,const C & mapper,const StateMapFstOptions & opts)337 StateMapFst(const Fst<A> &fst, const C &mapper,
338 const StateMapFstOptions& opts)
339 : ImplToFst<Impl>(new Impl(fst, mapper, opts)) {}
340
StateMapFst(const Fst<A> & fst,C * mapper,const StateMapFstOptions & opts)341 StateMapFst(const Fst<A> &fst, C* mapper, const StateMapFstOptions& opts)
342 : ImplToFst<Impl>(new Impl(fst, mapper, opts)) {}
343
StateMapFst(const Fst<A> & fst,const C & mapper)344 StateMapFst(const Fst<A> &fst, const C &mapper)
345 : ImplToFst<Impl>(new Impl(fst, mapper, StateMapFstOptions())) {}
346
StateMapFst(const Fst<A> & fst,C * mapper)347 StateMapFst(const Fst<A> &fst, C* mapper)
348 : ImplToFst<Impl>(new Impl(fst, mapper, StateMapFstOptions())) {}
349
350 // See Fst<>::Copy() for doc.
351 StateMapFst(const StateMapFst<A, B, C> &fst, bool safe = false)
352 : ImplToFst<Impl>(fst, safe) {}
353
354 // Get a copy of this StateMapFst. See Fst<>::Copy() for further doc.
355 virtual StateMapFst<A, B, C> *Copy(bool safe = false) const {
356 return new StateMapFst<A, B, C>(*this, safe);
357 }
358
InitStateIterator(StateIteratorData<A> * data)359 virtual void InitStateIterator(StateIteratorData<A> *data) const {
360 GetImpl()->InitStateIterator(data);
361 }
362
InitArcIterator(StateId s,ArcIteratorData<B> * data)363 virtual void InitArcIterator(StateId s, ArcIteratorData<B> *data) const {
364 GetImpl()->InitArcIterator(s, data);
365 }
366
367 private:
368 // Makes visible to friends.
GetImpl()369 Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
370
371 void operator=(const StateMapFst<A, B, C> &fst); // disallow
372 };
373
374
375 // Specialization for StateMapFst.
376 template <class A, class B, class C>
377 class ArcIterator< StateMapFst<A, B, C> >
378 : public CacheArcIterator< StateMapFst<A, B, C> > {
379 public:
380 typedef typename A::StateId StateId;
381
ArcIterator(const StateMapFst<A,B,C> & fst,StateId s)382 ArcIterator(const StateMapFst<A, B, C> &fst, StateId s)
383 : CacheArcIterator< StateMapFst<A, B, C> >(fst.GetImpl(), s) {
384 if (!fst.GetImpl()->HasArcs(s))
385 fst.GetImpl()->Expand(s);
386 }
387
388 private:
389 DISALLOW_COPY_AND_ASSIGN(ArcIterator);
390 };
391
392 //
393 // Utility Mappers
394 //
395
396 // Mapper that returns its input.
397 template <class A>
398 class IdentityStateMapper {
399 public:
400 typedef A FromArc;
401 typedef A ToArc;
402
403 typedef typename A::StateId StateId;
404 typedef typename A::Weight Weight;
405
IdentityStateMapper(const Fst<A> & fst)406 explicit IdentityStateMapper(const Fst<A> &fst) : fst_(fst), aiter_(0) {}
407
408 // Allows updating Fst argument; pass only if changed.
409 IdentityStateMapper(const IdentityStateMapper<A> &mapper,
410 const Fst<A> *fst = 0)
411 : fst_(fst ? *fst : mapper.fst_), aiter_(0) {}
412
~IdentityStateMapper()413 ~IdentityStateMapper() { delete aiter_; }
414
Start()415 StateId Start() const { return fst_.Start(); }
416
Final(StateId s)417 Weight Final(StateId s) const { return fst_.Final(s); }
418
SetState(StateId s)419 void SetState(StateId s) {
420 if (aiter_) delete aiter_;
421 aiter_ = new ArcIterator< Fst<A> >(fst_, s);
422 }
423
Done()424 bool Done() const { return aiter_->Done(); }
Value()425 const A &Value() const { return aiter_->Value(); }
Next()426 void Next() { aiter_->Next(); }
427
InputSymbolsAction()428 MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
OutputSymbolsAction()429 MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS;}
430
Properties(uint64 props)431 uint64 Properties(uint64 props) const { return props; }
432
433 private:
434 const Fst<A> &fst_;
435 ArcIterator< Fst<A> > *aiter_;
436 };
437
438 template <class A>
439 class ArcSumMapper {
440 public:
441 typedef A FromArc;
442 typedef A ToArc;
443
444 typedef typename A::StateId StateId;
445 typedef typename A::Weight Weight;
446
ArcSumMapper(const Fst<A> & fst)447 explicit ArcSumMapper(const Fst<A> &fst) : fst_(fst), i_(0) {}
448
449 // Allows updating Fst argument; pass only if changed.
450 ArcSumMapper(const ArcSumMapper<A> &mapper,
451 const Fst<A> *fst = 0)
452 : fst_(fst ? *fst : mapper.fst_), i_(0) {}
453
Start()454 StateId Start() const { return fst_.Start(); }
Final(StateId s)455 Weight Final(StateId s) const { return fst_.Final(s); }
456
SetState(StateId s)457 void SetState(StateId s) {
458 i_ = 0;
459 arcs_.clear();
460 arcs_.reserve(fst_.NumArcs(s));
461 for (ArcIterator<Fst<A> > aiter(fst_, s); !aiter.Done(); aiter.Next())
462 arcs_.push_back(aiter.Value());
463
464 // First sorts the exiting arcs by input label, output label
465 // and destination state and then sums weights of arcs with
466 // the same input label, output label, and destination state.
467 sort(arcs_.begin(), arcs_.end(), comp_);
468 size_t narcs = 0;
469 for (size_t i = 0; i < arcs_.size(); ++i) {
470 if (narcs > 0 && equal_(arcs_[i], arcs_[narcs - 1])) {
471 arcs_[narcs - 1].weight = Plus(arcs_[narcs - 1].weight,
472 arcs_[i].weight);
473 } else {
474 arcs_[narcs++] = arcs_[i];
475 }
476 }
477 arcs_.resize(narcs);
478 }
479
Done()480 bool Done() const { return i_ >= arcs_.size(); }
Value()481 const A &Value() const { return arcs_[i_]; }
Next()482 void Next() { ++i_; }
483
InputSymbolsAction()484 MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
OutputSymbolsAction()485 MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
486
Properties(uint64 props)487 uint64 Properties(uint64 props) const {
488 return props & kArcSortProperties &
489 kDeleteArcsProperties & kWeightInvariantProperties;
490 }
491
492 private:
493 struct Compare {
operatorCompare494 bool operator()(const A& x, const A& y) {
495 if (x.ilabel < y.ilabel) return true;
496 if (x.ilabel > y.ilabel) return false;
497 if (x.olabel < y.olabel) return true;
498 if (x.olabel > y.olabel) return false;
499 if (x.nextstate < y.nextstate) return true;
500 if (x.nextstate > y.nextstate) return false;
501 return false;
502 }
503 };
504
505 struct Equal {
operatorEqual506 bool operator()(const A& x, const A& y) {
507 return (x.ilabel == y.ilabel &&
508 x.olabel == y.olabel &&
509 x.nextstate == y.nextstate);
510 }
511 };
512
513 const Fst<A> &fst_;
514 Compare comp_;
515 Equal equal_;
516 vector<A> arcs_;
517 ssize_t i_; // current arc position
518
519 void operator=(const ArcSumMapper<A> &); // disallow
520 };
521
522 template <class A>
523 class ArcUniqueMapper {
524 public:
525 typedef A FromArc;
526 typedef A ToArc;
527
528 typedef typename A::StateId StateId;
529 typedef typename A::Weight Weight;
530
ArcUniqueMapper(const Fst<A> & fst)531 explicit ArcUniqueMapper(const Fst<A> &fst) : fst_(fst), i_(0) {}
532
533 // Allows updating Fst argument; pass only if changed.
534 ArcUniqueMapper(const ArcSumMapper<A> &mapper,
535 const Fst<A> *fst = 0)
536 : fst_(fst ? *fst : mapper.fst_), i_(0) {}
537
Start()538 StateId Start() const { return fst_.Start(); }
Final(StateId s)539 Weight Final(StateId s) const { return fst_.Final(s); }
540
SetState(StateId s)541 void SetState(StateId s) {
542 i_ = 0;
543 arcs_.clear();
544 arcs_.reserve(fst_.NumArcs(s));
545 for (ArcIterator<Fst<A> > aiter(fst_, s); !aiter.Done(); aiter.Next())
546 arcs_.push_back(aiter.Value());
547
548 // First sorts the exiting arcs by input label, output label
549 // and destination state and then uniques identical arcs
550 sort(arcs_.begin(), arcs_.end(), comp_);
551 typename vector<A>::iterator unique_end =
552 unique(arcs_.begin(), arcs_.end(), equal_);
553 arcs_.resize(unique_end - arcs_.begin());
554 }
555
Done()556 bool Done() const { return i_ >= arcs_.size(); }
Value()557 const A &Value() const { return arcs_[i_]; }
Next()558 void Next() { ++i_; }
559
InputSymbolsAction()560 MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
OutputSymbolsAction()561 MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
562
Properties(uint64 props)563 uint64 Properties(uint64 props) const {
564 return props & kArcSortProperties & kDeleteArcsProperties;
565 }
566
567 private:
568 struct Compare {
operatorCompare569 bool operator()(const A& x, const A& y) {
570 if (x.ilabel < y.ilabel) return true;
571 if (x.ilabel > y.ilabel) return false;
572 if (x.olabel < y.olabel) return true;
573 if (x.olabel > y.olabel) return false;
574 if (x.nextstate < y.nextstate) return true;
575 if (x.nextstate > y.nextstate) return false;
576 return false;
577 }
578 };
579
580 struct Equal {
operatorEqual581 bool operator()(const A& x, const A& y) {
582 return (x.ilabel == y.ilabel &&
583 x.olabel == y.olabel &&
584 x.nextstate == y.nextstate &&
585 x.weight == y.weight);
586 }
587 };
588
589 const Fst<A> &fst_;
590 Compare comp_;
591 Equal equal_;
592 vector<A> arcs_;
593 ssize_t i_; // current arc position
594
595 void operator=(const ArcUniqueMapper<A> &); // disallow
596 };
597
598
599 } // namespace fst
600
601 #endif // FST_LIB_STATE_MAP_H__
602