• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // relabel.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 //
16 // \file
17 // Functions and classes to relabel an Fst (either on input or output)
18 //
19 #ifndef FST_LIB_RELABEL_H__
20 #define FST_LIB_RELABEL_H__
21 
22 #include <unordered_map>
23 
24 #include "fst/lib/cache.h"
25 #include "fst/lib/test-properties.h"
26 
27 
28 namespace fst {
29 
30 //
31 // Relabels either the input labels or output labels. The old to
32 // new labels are specified using a vector of pair<Label,Label>.
33 // Any label associations not specified are assumed to be identity
34 // mapping.
35 //
36 // \param fst input fst, must be mutable
37 // \param relabel_pairs vector of pairs indicating old to new mapping
38 // \param relabel_flags whether to relabel input or output
39 //
40 template <class A>
Relabel(MutableFst<A> * fst,const vector<pair<typename A::Label,typename A::Label>> & ipairs,const vector<pair<typename A::Label,typename A::Label>> & opairs)41 void Relabel(
42     MutableFst<A> *fst,
43     const vector<pair<typename A::Label, typename A::Label> >& ipairs,
44     const vector<pair<typename A::Label, typename A::Label> >& opairs) {
45   typedef typename A::StateId StateId;
46   typedef typename A::Label   Label;
47 
48   uint64 props = fst->Properties(kFstProperties, false);
49 
50   // construct label to label hash. Could
51   std::unordered_map<Label, Label> input_map;
52   for (size_t i = 0; i < ipairs.size(); ++i) {
53     input_map[ipairs[i].first] = ipairs[i].second;
54   }
55 
56   std::unordered_map<Label, Label> output_map;
57   for (size_t i = 0; i < opairs.size(); ++i) {
58     output_map[opairs[i].first] = opairs[i].second;
59   }
60 
61   for (StateIterator<MutableFst<A> > siter(*fst);
62        !siter.Done(); siter.Next()) {
63     StateId s = siter.Value();
64     for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
65          !aiter.Done(); aiter.Next()) {
66       A arc = aiter.Value();
67 
68       // relabel input
69       // only relabel if relabel pair defined
70       typename std::unordered_map<Label, Label>::iterator it =
71         input_map.find(arc.ilabel);
72       if (it != input_map.end()) {arc.ilabel = it->second; }
73 
74       // relabel output
75       it = output_map.find(arc.olabel);
76       if (it != output_map.end()) { arc.olabel = it->second; }
77 
78       aiter.SetValue(arc);
79     }
80   }
81 
82   fst->SetProperties(RelabelProperties(props), kFstProperties);
83 }
84 
85 
86 
87 //
88 // Relabels either the input labels or output labels. The old to
89 // new labels mappings are specified using an input Symbol set.
90 // Any label associations not specified are assumed to be identity
91 // mapping.
92 //
93 // \param fst input fst, must be mutable
94 // \param new_symbols symbol set indicating new mapping
95 // \param relabel_flags whether to relabel input or output
96 //
97 template<class A>
Relabel(MutableFst<A> * fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols)98 void Relabel(MutableFst<A> *fst,
99              const SymbolTable* new_isymbols,
100              const SymbolTable* new_osymbols) {
101   typedef typename A::StateId StateId;
102   typedef typename A::Label   Label;
103 
104   const SymbolTable* old_isymbols = fst->InputSymbols();
105   const SymbolTable* old_osymbols = fst->OutputSymbols();
106 
107   vector<pair<Label, Label> > ipairs;
108   if (old_isymbols && new_isymbols) {
109     for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
110          syms_iter.Next()) {
111       ipairs.push_back(make_pair(syms_iter.Value(),
112                                  new_isymbols->Find(syms_iter.Symbol())));
113     }
114     fst->SetInputSymbols(new_isymbols);
115   }
116 
117   vector<pair<Label, Label> > opairs;
118   if (old_osymbols && new_osymbols) {
119     for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
120          syms_iter.Next()) {
121       opairs.push_back(make_pair(syms_iter.Value(),
122                                  new_osymbols->Find(syms_iter.Symbol())));
123     }
124     fst->SetOutputSymbols(new_osymbols);
125   }
126 
127   // call relabel using vector of relabel pairs.
128   Relabel(fst, ipairs, opairs);
129 }
130 
131 
132 typedef CacheOptions RelabelFstOptions;
133 
134 template <class A> class RelabelFst;
135 
136 //
137 // \class RelabelFstImpl
138 // \brief Implementation for delayed relabeling
139 //
140 // Relabels an FST from one symbol set to another. Relabeling
141 // can either be on input or output space. RelabelFst implements
142 // a delayed version of the relabel. Arcs are relabeled on the fly
143 // and not cached. I.e each request is recomputed.
144 //
145 template<class A>
146 class RelabelFstImpl : public CacheImpl<A> {
147   friend class StateIterator< RelabelFst<A> >;
148  public:
149   using FstImpl<A>::SetType;
150   using FstImpl<A>::SetProperties;
151   using FstImpl<A>::Properties;
152   using FstImpl<A>::SetInputSymbols;
153   using FstImpl<A>::SetOutputSymbols;
154 
155   using CacheImpl<A>::HasStart;
156   using CacheImpl<A>::HasArcs;
157 
158   typedef typename A::Label   Label;
159   typedef typename A::Weight  Weight;
160   typedef typename A::StateId StateId;
161   typedef CacheState<A> State;
162 
RelabelFstImpl(const Fst<A> & fst,const vector<pair<Label,Label>> & ipairs,const vector<pair<Label,Label>> & opairs,const RelabelFstOptions & opts)163   RelabelFstImpl(const Fst<A>& fst,
164                  const vector<pair<Label, Label> >& ipairs,
165                  const vector<pair<Label, Label> >& opairs,
166                  const RelabelFstOptions &opts)
167       : CacheImpl<A>(opts), fst_(fst.Copy()),
168         relabel_input_(false), relabel_output_(false) {
169     uint64 props = fst.Properties(kCopyProperties, false);
170     SetProperties(RelabelProperties(props));
171     SetType("relabel");
172 
173     // create input label map
174     if (ipairs.size() > 0) {
175       for (size_t i = 0; i < ipairs.size(); ++i) {
176         input_map_[ipairs[i].first] = ipairs[i].second;
177       }
178       relabel_input_ = true;
179     }
180 
181     // create output label map
182     if (opairs.size() > 0) {
183       for (size_t i = 0; i < opairs.size(); ++i) {
184         output_map_[opairs[i].first] = opairs[i].second;
185       }
186       relabel_output_ = true;
187     }
188   }
189 
RelabelFstImpl(const Fst<A> & fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols,const RelabelFstOptions & opts)190   RelabelFstImpl(const Fst<A>& fst,
191                  const SymbolTable* new_isymbols,
192                  const SymbolTable* new_osymbols,
193                  const RelabelFstOptions &opts)
194       : CacheImpl<A>(opts), fst_(fst.Copy()),
195         relabel_input_(false), relabel_output_(false) {
196     SetType("relabel");
197 
198     uint64 props = fst.Properties(kCopyProperties, false);
199     SetProperties(RelabelProperties(props));
200     SetInputSymbols(fst.InputSymbols());
201     SetOutputSymbols(fst.OutputSymbols());
202 
203     const SymbolTable* old_isymbols = fst.InputSymbols();
204     const SymbolTable* old_osymbols = fst.OutputSymbols();
205 
206     if (old_isymbols && new_isymbols &&
207         old_isymbols->CheckSum() != new_isymbols->CheckSum()) {
208       for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
209            syms_iter.Next()) {
210         input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol());
211       }
212       SetInputSymbols(new_isymbols);
213       relabel_input_ = true;
214     }
215 
216     if (old_osymbols && new_osymbols &&
217         old_osymbols->CheckSum() != new_osymbols->CheckSum()) {
218       for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
219            syms_iter.Next()) {
220         output_map_[syms_iter.Value()] =
221           new_osymbols->Find(syms_iter.Symbol());
222       }
223       SetOutputSymbols(new_osymbols);
224       relabel_output_ = true;
225     }
226   }
227 
~RelabelFstImpl()228   ~RelabelFstImpl() { delete fst_; }
229 
Start()230   StateId Start() {
231     if (!HasStart()) {
232       StateId s = fst_->Start();
233       SetStart(s);
234     }
235     return CacheImpl<A>::Start();
236   }
237 
Final(StateId s)238   Weight Final(StateId s) {
239     if (!HasFinal(s)) {
240       SetFinal(s, fst_->Final(s));
241     }
242     return CacheImpl<A>::Final(s);
243   }
244 
NumArcs(StateId s)245   size_t NumArcs(StateId s) {
246     if (!HasArcs(s)) {
247       Expand(s);
248     }
249     return CacheImpl<A>::NumArcs(s);
250   }
251 
NumInputEpsilons(StateId s)252   size_t NumInputEpsilons(StateId s) {
253     if (!HasArcs(s)) {
254       Expand(s);
255     }
256     return CacheImpl<A>::NumInputEpsilons(s);
257   }
258 
NumOutputEpsilons(StateId s)259   size_t NumOutputEpsilons(StateId s) {
260     if (!HasArcs(s)) {
261       Expand(s);
262     }
263     return CacheImpl<A>::NumOutputEpsilons(s);
264   }
265 
InitArcIterator(StateId s,ArcIteratorData<A> * data)266   void InitArcIterator(StateId s, ArcIteratorData<A>* data) {
267     if (!HasArcs(s)) {
268       Expand(s);
269     }
270     CacheImpl<A>::InitArcIterator(s, data);
271   }
272 
Expand(StateId s)273   void Expand(StateId s) {
274     for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
275       A arc = aiter.Value();
276 
277       // relabel input
278       if (relabel_input_) {
279         typename std::unordered_map<Label, Label>::iterator it =
280           input_map_.find(arc.ilabel);
281         if (it != input_map_.end()) { arc.ilabel = it->second; }
282       }
283 
284       // relabel output
285       if (relabel_output_) {
286         typename std::unordered_map<Label, Label>::iterator it =
287           output_map_.find(arc.olabel);
288         if (it != output_map_.end()) { arc.olabel = it->second; }
289       }
290 
291       AddArc(s, arc);
292     }
293     SetArcs(s);
294   }
295 
296 
297  private:
298   const Fst<A> *fst_;
299 
300   std::unordered_map<Label, Label> input_map_;
301   std::unordered_map<Label, Label> output_map_;
302   bool relabel_input_;
303   bool relabel_output_;
304 
305   DISALLOW_EVIL_CONSTRUCTORS(RelabelFstImpl);
306 };
307 
308 
309 //
310 // \class RelabelFst
311 // \brief Delayed implementation of arc relabeling
312 //
313 // This class attaches interface to implementation and handles
314 // reference counting.
315 template <class A>
316 class RelabelFst : public Fst<A> {
317  public:
318   friend class ArcIterator< RelabelFst<A> >;
319   friend class StateIterator< RelabelFst<A> >;
320   friend class CacheArcIterator< RelabelFst<A> >;
321 
322   typedef A Arc;
323   typedef typename A::Label   Label;
324   typedef typename A::Weight  Weight;
325   typedef typename A::StateId StateId;
326   typedef CacheState<A> State;
327 
RelabelFst(const Fst<A> & fst,const vector<pair<Label,Label>> & ipairs,const vector<pair<Label,Label>> & opairs)328   RelabelFst(const Fst<A>& fst,
329              const vector<pair<Label, Label> >& ipairs,
330              const vector<pair<Label, Label> >& opairs) :
331       impl_(new RelabelFstImpl<A>(fst, ipairs, opairs, RelabelFstOptions())) {}
332 
RelabelFst(const Fst<A> & fst,const vector<pair<Label,Label>> & ipairs,const vector<pair<Label,Label>> & opairs,const RelabelFstOptions & opts)333   RelabelFst(const Fst<A>& fst,
334              const vector<pair<Label, Label> >& ipairs,
335              const vector<pair<Label, Label> >& opairs,
336              const RelabelFstOptions &opts)
337       : impl_(new RelabelFstImpl<A>(fst, ipairs, opairs, opts)) {}
338 
RelabelFst(const Fst<A> & fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols)339   RelabelFst(const Fst<A>& fst,
340              const SymbolTable* new_isymbols,
341              const SymbolTable* new_osymbols) :
342       impl_(new RelabelFstImpl<A>(fst, new_isymbols, new_osymbols,
343                                   RelabelFstOptions())) {}
344 
RelabelFst(const Fst<A> & fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols,const RelabelFstOptions & opts)345   RelabelFst(const Fst<A>& fst,
346              const SymbolTable* new_isymbols,
347              const SymbolTable* new_osymbols,
348              const RelabelFstOptions &opts)
349     : impl_(new RelabelFstImpl<A>(fst, new_isymbols, new_osymbols, opts)) {}
350 
RelabelFst(const RelabelFst<A> & fst)351   RelabelFst(const RelabelFst<A> &fst) : impl_(fst.impl_) {
352     impl_->IncrRefCount();
353   }
354 
~RelabelFst()355   virtual ~RelabelFst() { if (!impl_->DecrRefCount()) delete impl_;  }
356 
Start()357   virtual StateId Start() const { return impl_->Start(); }
358 
Final(StateId s)359   virtual Weight Final(StateId s) const { return impl_->Final(s); }
360 
NumArcs(StateId s)361   virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
362 
NumInputEpsilons(StateId s)363   virtual size_t NumInputEpsilons(StateId s) const {
364     return impl_->NumInputEpsilons(s);
365   }
366 
NumOutputEpsilons(StateId s)367   virtual size_t NumOutputEpsilons(StateId s) const {
368     return impl_->NumOutputEpsilons(s);
369   }
370 
Properties(uint64 mask,bool test)371   virtual uint64 Properties(uint64 mask, bool test) const {
372     if (test) {
373       uint64 known, test = TestProperties(*this, mask, &known);
374       impl_->SetProperties(test, known);
375       return test & mask;
376     } else {
377       return impl_->Properties(mask);
378     }
379   }
380 
Type()381   virtual const string& Type() const { return impl_->Type(); }
382 
Copy()383   virtual RelabelFst<A> *Copy() const {
384     return new RelabelFst<A>(*this);
385   }
386 
InputSymbols()387   virtual const SymbolTable* InputSymbols() const {
388     return impl_->InputSymbols();
389   }
390 
OutputSymbols()391   virtual const SymbolTable* OutputSymbols() const {
392     return impl_->OutputSymbols();
393   }
394 
395   virtual void InitStateIterator(StateIteratorData<A> *data) const;
396 
InitArcIterator(StateId s,ArcIteratorData<A> * data)397   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
398     return impl_->InitArcIterator(s, data);
399   }
400 
401  private:
402   RelabelFstImpl<A> *impl_;
403 
404   void operator=(const RelabelFst<A> &fst);  // disallow
405 };
406 
407 // Specialization for RelabelFst.
408 template<class A>
409 class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> {
410  public:
411   typedef typename A::StateId StateId;
412 
StateIterator(const RelabelFst<A> & fst)413   explicit StateIterator(const RelabelFst<A> &fst)
414       : impl_(fst.impl_), siter_(*impl_->fst_), s_(0) {}
415 
Done()416   bool Done() const { return siter_.Done(); }
417 
Value()418   StateId Value() const { return s_; }
419 
Next()420   void Next() {
421     if (!siter_.Done()) {
422       ++s_;
423       siter_.Next();
424     }
425   }
426 
Reset()427   void Reset() {
428     s_ = 0;
429     siter_.Reset();
430   }
431 
432  private:
433   const RelabelFstImpl<A> *impl_;
434   StateIterator< Fst<A> > siter_;
435   StateId s_;
436 
437   DISALLOW_EVIL_CONSTRUCTORS(StateIterator);
438 };
439 
440 
441 // Specialization for RelabelFst.
442 template <class A>
443 class ArcIterator< RelabelFst<A> >
444     : public CacheArcIterator< RelabelFst<A> > {
445  public:
446   typedef typename A::StateId StateId;
447 
ArcIterator(const RelabelFst<A> & fst,StateId s)448   ArcIterator(const RelabelFst<A> &fst, StateId s)
449       : CacheArcIterator< RelabelFst<A> >(fst, s) {
450     if (!fst.impl_->HasArcs(s))
451       fst.impl_->Expand(s);
452   }
453 
454  private:
455   DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
456 };
457 
458 template <class A> inline
InitStateIterator(StateIteratorData<A> * data)459 void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
460   data->base = new StateIterator< RelabelFst<A> >(*this);
461 }
462 
463 // Useful alias when using StdArc.
464 typedef RelabelFst<StdArc> StdRelabelFst;
465 
466 }  // namespace fst
467 
468 #endif  // FST_LIB_RELABEL_H__
469