• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // relabel.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: johans@google.com (Johan Schalkwyk)
17 //
18 // \file
19 // Functions and classes to relabel an Fst (either on input or output)
20 //
21 #ifndef FST_LIB_RELABEL_H__
22 #define FST_LIB_RELABEL_H__
23 
24 #include <tr1/unordered_map>
25 using std::tr1::unordered_map;
26 using std::tr1::unordered_multimap;
27 #include <string>
28 #include <utility>
29 using std::pair; using std::make_pair;
30 #include <vector>
31 using std::vector;
32 
33 #include <fst/cache.h>
34 #include <fst/test-properties.h>
35 
36 
37 #include <tr1/unordered_map>
38 using std::tr1::unordered_map;
39 using std::tr1::unordered_multimap;
40 
41 namespace fst {
42 
43 //
44 // Relabels either the input labels or output labels. The old to
45 // new labels are specified using a vector of pair<Label,Label>.
46 // Any label associations not specified are assumed to be identity
47 // mapping.
48 //
49 // \param fst input fst, must be mutable
50 // \param ipairs vector of input label pairs indicating old to new mapping
51 // \param opairs vector of output label pairs indicating old to new mapping
52 //
53 template <class A>
Relabel(MutableFst<A> * fst,const vector<pair<typename A::Label,typename A::Label>> & ipairs,const vector<pair<typename A::Label,typename A::Label>> & opairs)54 void Relabel(
55     MutableFst<A> *fst,
56     const vector<pair<typename A::Label, typename A::Label> >& ipairs,
57     const vector<pair<typename A::Label, typename A::Label> >& opairs) {
58   typedef typename A::StateId StateId;
59   typedef typename A::Label   Label;
60 
61   uint64 props = fst->Properties(kFstProperties, false);
62 
63   // construct label to label hash.
64   unordered_map<Label, Label> input_map;
65   for (size_t i = 0; i < ipairs.size(); ++i) {
66     input_map[ipairs[i].first] = ipairs[i].second;
67   }
68 
69   unordered_map<Label, Label> output_map;
70   for (size_t i = 0; i < opairs.size(); ++i) {
71     output_map[opairs[i].first] = opairs[i].second;
72   }
73 
74   for (StateIterator<MutableFst<A> > siter(*fst);
75        !siter.Done(); siter.Next()) {
76     StateId s = siter.Value();
77     for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
78          !aiter.Done(); aiter.Next()) {
79       A arc = aiter.Value();
80 
81       // relabel input
82       // only relabel if relabel pair defined
83       typename unordered_map<Label, Label>::iterator it =
84         input_map.find(arc.ilabel);
85       if (it != input_map.end()) {
86         if (it->second == kNoLabel) {
87           FSTERROR() << "Input symbol id " << arc.ilabel
88                      << " missing from target vocabulary";
89           fst->SetProperties(kError, kError);
90           return;
91         }
92         arc.ilabel = it->second;
93       }
94 
95       // relabel output
96       it = output_map.find(arc.olabel);
97       if (it != output_map.end()) {
98         if (it->second == kNoLabel) {
99           FSTERROR() << "Output symbol id " << arc.olabel
100                      << " missing from target vocabulary";
101           fst->SetProperties(kError, kError);
102           return;
103         }
104         arc.olabel = it->second;
105       }
106 
107       aiter.SetValue(arc);
108     }
109   }
110 
111   fst->SetProperties(RelabelProperties(props), kFstProperties);
112 }
113 
114 //
115 // Relabels either the input labels or output labels. The old to
116 // new labels mappings are specified using an input Symbol set.
117 // Any label associations not specified are assumed to be identity
118 // mapping.
119 //
120 // \param fst input fst, must be mutable
121 // \param new_isymbols symbol set indicating new mapping of input symbols
122 // \param new_osymbols symbol set indicating new mapping of output symbols
123 //
124 template<class A>
Relabel(MutableFst<A> * fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols)125 void Relabel(MutableFst<A> *fst,
126              const SymbolTable* new_isymbols,
127              const SymbolTable* new_osymbols) {
128   Relabel(fst,
129           fst->InputSymbols(), new_isymbols, true,
130           fst->OutputSymbols(), new_osymbols, true);
131 }
132 
133 template<class A>
Relabel(MutableFst<A> * fst,const SymbolTable * old_isymbols,const SymbolTable * new_isymbols,bool attach_new_isymbols,const SymbolTable * old_osymbols,const SymbolTable * new_osymbols,bool attach_new_osymbols)134 void Relabel(MutableFst<A> *fst,
135              const SymbolTable* old_isymbols,
136              const SymbolTable* new_isymbols,
137              bool attach_new_isymbols,
138              const SymbolTable* old_osymbols,
139              const SymbolTable* new_osymbols,
140              bool attach_new_osymbols) {
141   typedef typename A::StateId StateId;
142   typedef typename A::Label   Label;
143 
144   vector<pair<Label, Label> > ipairs;
145   if (old_isymbols && new_isymbols) {
146     for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
147          syms_iter.Next()) {
148       string isymbol = syms_iter.Symbol();
149       int isymbol_val = syms_iter.Value();
150       int new_isymbol_val = new_isymbols->Find(isymbol);
151       ipairs.push_back(make_pair(isymbol_val, new_isymbol_val));
152     }
153     if (attach_new_isymbols)
154       fst->SetInputSymbols(new_isymbols);
155   }
156 
157   vector<pair<Label, Label> > opairs;
158   if (old_osymbols && new_osymbols) {
159     for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
160          syms_iter.Next()) {
161       string osymbol = syms_iter.Symbol();
162       int osymbol_val = syms_iter.Value();
163       int new_osymbol_val = new_osymbols->Find(osymbol);
164       opairs.push_back(make_pair(osymbol_val, new_osymbol_val));
165     }
166     if (attach_new_osymbols)
167       fst->SetOutputSymbols(new_osymbols);
168   }
169 
170   // call relabel using vector of relabel pairs.
171   Relabel(fst, ipairs, opairs);
172 }
173 
174 
175 typedef CacheOptions RelabelFstOptions;
176 
177 template <class A> class RelabelFst;
178 
//
// \class RelabelFstImpl
// \brief Implementation for delayed relabeling
//
// Relabels an FST from one symbol set to another. Relabeling
// can be applied to the input labels, the output labels, or both.
// RelabelFst implements a delayed version of the relabel: the arcs
// of a state are relabeled the first time that state is expanded
// and are then stored in the cache inherited from CacheImpl.
//
188 template<class A>
189 class RelabelFstImpl : public CacheImpl<A> {
190   friend class StateIterator< RelabelFst<A> >;
191  public:
192   using FstImpl<A>::SetType;
193   using FstImpl<A>::SetProperties;
194   using FstImpl<A>::WriteHeader;
195   using FstImpl<A>::SetInputSymbols;
196   using FstImpl<A>::SetOutputSymbols;
197 
198   using CacheImpl<A>::PushArc;
199   using CacheImpl<A>::HasArcs;
200   using CacheImpl<A>::HasFinal;
201   using CacheImpl<A>::HasStart;
202   using CacheImpl<A>::SetArcs;
203   using CacheImpl<A>::SetFinal;
204   using CacheImpl<A>::SetStart;
205 
206   typedef A Arc;
207   typedef typename A::Label   Label;
208   typedef typename A::Weight  Weight;
209   typedef typename A::StateId StateId;
210   typedef CacheState<A> State;
211 
RelabelFstImpl(const Fst<A> & fst,const vector<pair<Label,Label>> & ipairs,const vector<pair<Label,Label>> & opairs,const RelabelFstOptions & opts)212   RelabelFstImpl(const Fst<A>& fst,
213                  const vector<pair<Label, Label> >& ipairs,
214                  const vector<pair<Label, Label> >& opairs,
215                  const RelabelFstOptions &opts)
216       : CacheImpl<A>(opts), fst_(fst.Copy()),
217         relabel_input_(false), relabel_output_(false) {
218     uint64 props = fst.Properties(kCopyProperties, false);
219     SetProperties(RelabelProperties(props));
220     SetType("relabel");
221 
222     // create input label map
223     if (ipairs.size() > 0) {
224       for (size_t i = 0; i < ipairs.size(); ++i) {
225         input_map_[ipairs[i].first] = ipairs[i].second;
226       }
227       relabel_input_ = true;
228     }
229 
230     // create output label map
231     if (opairs.size() > 0) {
232       for (size_t i = 0; i < opairs.size(); ++i) {
233         output_map_[opairs[i].first] = opairs[i].second;
234       }
235       relabel_output_ = true;
236     }
237   }
238 
RelabelFstImpl(const Fst<A> & fst,const SymbolTable * old_isymbols,const SymbolTable * new_isymbols,const SymbolTable * old_osymbols,const SymbolTable * new_osymbols,const RelabelFstOptions & opts)239   RelabelFstImpl(const Fst<A>& fst,
240                  const SymbolTable* old_isymbols,
241                  const SymbolTable* new_isymbols,
242                  const SymbolTable* old_osymbols,
243                  const SymbolTable* new_osymbols,
244                  const RelabelFstOptions &opts)
245       : CacheImpl<A>(opts), fst_(fst.Copy()),
246         relabel_input_(false), relabel_output_(false) {
247     SetType("relabel");
248 
249     uint64 props = fst.Properties(kCopyProperties, false);
250     SetProperties(RelabelProperties(props));
251     SetInputSymbols(old_isymbols);
252     SetOutputSymbols(old_osymbols);
253 
254     if (old_isymbols && new_isymbols &&
255         old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) {
256       for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
257            syms_iter.Next()) {
258         input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol());
259       }
260       SetInputSymbols(new_isymbols);
261       relabel_input_ = true;
262     }
263 
264     if (old_osymbols && new_osymbols &&
265         old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) {
266       for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
267            syms_iter.Next()) {
268         output_map_[syms_iter.Value()] =
269           new_osymbols->Find(syms_iter.Symbol());
270       }
271       SetOutputSymbols(new_osymbols);
272       relabel_output_ = true;
273     }
274   }
275 
RelabelFstImpl(const RelabelFstImpl<A> & impl)276   RelabelFstImpl(const RelabelFstImpl<A>& impl)
277       : CacheImpl<A>(impl),
278         fst_(impl.fst_->Copy(true)),
279         input_map_(impl.input_map_),
280         output_map_(impl.output_map_),
281         relabel_input_(impl.relabel_input_),
282         relabel_output_(impl.relabel_output_) {
283     SetType("relabel");
284     SetProperties(impl.Properties(), kCopyProperties);
285     SetInputSymbols(impl.InputSymbols());
286     SetOutputSymbols(impl.OutputSymbols());
287   }
288 
~RelabelFstImpl()289   ~RelabelFstImpl() { delete fst_; }
290 
Start()291   StateId Start() {
292     if (!HasStart()) {
293       StateId s = fst_->Start();
294       SetStart(s);
295     }
296     return CacheImpl<A>::Start();
297   }
298 
Final(StateId s)299   Weight Final(StateId s) {
300     if (!HasFinal(s)) {
301       SetFinal(s, fst_->Final(s));
302     }
303     return CacheImpl<A>::Final(s);
304   }
305 
NumArcs(StateId s)306   size_t NumArcs(StateId s) {
307     if (!HasArcs(s)) {
308       Expand(s);
309     }
310     return CacheImpl<A>::NumArcs(s);
311   }
312 
NumInputEpsilons(StateId s)313   size_t NumInputEpsilons(StateId s) {
314     if (!HasArcs(s)) {
315       Expand(s);
316     }
317     return CacheImpl<A>::NumInputEpsilons(s);
318   }
319 
NumOutputEpsilons(StateId s)320   size_t NumOutputEpsilons(StateId s) {
321     if (!HasArcs(s)) {
322       Expand(s);
323     }
324     return CacheImpl<A>::NumOutputEpsilons(s);
325   }
326 
Properties()327   uint64 Properties() const { return Properties(kFstProperties); }
328 
329   // Set error if found; return FST impl properties.
Properties(uint64 mask)330   uint64 Properties(uint64 mask) const {
331     if ((mask & kError) && fst_->Properties(kError, false))
332       SetProperties(kError, kError);
333     return FstImpl<Arc>::Properties(mask);
334   }
335 
InitArcIterator(StateId s,ArcIteratorData<A> * data)336   void InitArcIterator(StateId s, ArcIteratorData<A>* data) {
337     if (!HasArcs(s)) {
338       Expand(s);
339     }
340     CacheImpl<A>::InitArcIterator(s, data);
341   }
342 
Expand(StateId s)343   void Expand(StateId s) {
344     for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
345       A arc = aiter.Value();
346 
347       // relabel input
348       if (relabel_input_) {
349         typename unordered_map<Label, Label>::iterator it =
350           input_map_.find(arc.ilabel);
351         if (it != input_map_.end()) { arc.ilabel = it->second; }
352       }
353 
354       // relabel output
355       if (relabel_output_) {
356         typename unordered_map<Label, Label>::iterator it =
357           output_map_.find(arc.olabel);
358         if (it != output_map_.end()) { arc.olabel = it->second; }
359       }
360 
361       PushArc(s, arc);
362     }
363     SetArcs(s);
364   }
365 
366 
367  private:
368   const Fst<A> *fst_;
369 
370   unordered_map<Label, Label> input_map_;
371   unordered_map<Label, Label> output_map_;
372   bool relabel_input_;
373   bool relabel_output_;
374 
375   void operator=(const RelabelFstImpl<A> &);  // disallow
376 };
377 
378 
379 //
380 // \class RelabelFst
381 // \brief Delayed implementation of arc relabeling
382 //
383 // This class attaches interface to implementation and handles
384 // reference counting, delegating most methods to ImplToFst.
385 template <class A>
386 class RelabelFst : public ImplToFst< RelabelFstImpl<A> > {
387  public:
388   friend class ArcIterator< RelabelFst<A> >;
389   friend class StateIterator< RelabelFst<A> >;
390 
391   typedef A Arc;
392   typedef typename A::Label   Label;
393   typedef typename A::Weight  Weight;
394   typedef typename A::StateId StateId;
395   typedef CacheState<A> State;
396   typedef RelabelFstImpl<A> Impl;
397 
RelabelFst(const Fst<A> & fst,const vector<pair<Label,Label>> & ipairs,const vector<pair<Label,Label>> & opairs)398   RelabelFst(const Fst<A>& fst,
399              const vector<pair<Label, Label> >& ipairs,
400              const vector<pair<Label, Label> >& opairs)
401       : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, RelabelFstOptions())) {}
402 
RelabelFst(const Fst<A> & fst,const vector<pair<Label,Label>> & ipairs,const vector<pair<Label,Label>> & opairs,const RelabelFstOptions & opts)403   RelabelFst(const Fst<A>& fst,
404              const vector<pair<Label, Label> >& ipairs,
405              const vector<pair<Label, Label> >& opairs,
406              const RelabelFstOptions &opts)
407       : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, opts)) {}
408 
RelabelFst(const Fst<A> & fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols)409   RelabelFst(const Fst<A>& fst,
410              const SymbolTable* new_isymbols,
411              const SymbolTable* new_osymbols)
412       : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols,
413                                  fst.OutputSymbols(), new_osymbols,
414                                  RelabelFstOptions())) {}
415 
RelabelFst(const Fst<A> & fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols,const RelabelFstOptions & opts)416   RelabelFst(const Fst<A>& fst,
417              const SymbolTable* new_isymbols,
418              const SymbolTable* new_osymbols,
419              const RelabelFstOptions &opts)
420       : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols,
421                                  fst.OutputSymbols(), new_osymbols, opts)) {}
422 
RelabelFst(const Fst<A> & fst,const SymbolTable * old_isymbols,const SymbolTable * new_isymbols,const SymbolTable * old_osymbols,const SymbolTable * new_osymbols)423   RelabelFst(const Fst<A>& fst,
424              const SymbolTable* old_isymbols,
425              const SymbolTable* new_isymbols,
426              const SymbolTable* old_osymbols,
427              const SymbolTable* new_osymbols)
428     : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols,
429                                new_osymbols, RelabelFstOptions())) {}
430 
RelabelFst(const Fst<A> & fst,const SymbolTable * old_isymbols,const SymbolTable * new_isymbols,const SymbolTable * old_osymbols,const SymbolTable * new_osymbols,const RelabelFstOptions & opts)431   RelabelFst(const Fst<A>& fst,
432              const SymbolTable* old_isymbols,
433              const SymbolTable* new_isymbols,
434              const SymbolTable* old_osymbols,
435              const SymbolTable* new_osymbols,
436              const RelabelFstOptions &opts)
437     : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols,
438                                new_osymbols, opts)) {}
439 
440   // See Fst<>::Copy() for doc.
441   RelabelFst(const RelabelFst<A> &fst, bool safe = false)
442     : ImplToFst<Impl>(fst, safe) {}
443 
444   // Get a copy of this RelabelFst. See Fst<>::Copy() for further doc.
445   virtual RelabelFst<A> *Copy(bool safe = false) const {
446     return new RelabelFst<A>(*this, safe);
447   }
448 
449   virtual void InitStateIterator(StateIteratorData<A> *data) const;
450 
InitArcIterator(StateId s,ArcIteratorData<A> * data)451   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
452     return GetImpl()->InitArcIterator(s, data);
453   }
454 
455  private:
456   // Makes visible to friends.
GetImpl()457   Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
458 
459   void operator=(const RelabelFst<A> &fst);  // disallow
460 };
461 
462 // Specialization for RelabelFst.
463 template<class A>
464 class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> {
465  public:
466   typedef typename A::StateId StateId;
467 
StateIterator(const RelabelFst<A> & fst)468   explicit StateIterator(const RelabelFst<A> &fst)
469       : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {}
470 
Done()471   bool Done() const { return siter_.Done(); }
472 
Value()473   StateId Value() const { return s_; }
474 
Next()475   void Next() {
476     if (!siter_.Done()) {
477       ++s_;
478       siter_.Next();
479     }
480   }
481 
Reset()482   void Reset() {
483     s_ = 0;
484     siter_.Reset();
485   }
486 
487  private:
Done_()488   bool Done_() const { return Done(); }
Value_()489   StateId Value_() const { return Value(); }
Next_()490   void Next_() { Next(); }
Reset_()491   void Reset_() { Reset(); }
492 
493   const RelabelFstImpl<A> *impl_;
494   StateIterator< Fst<A> > siter_;
495   StateId s_;
496 
497   DISALLOW_COPY_AND_ASSIGN(StateIterator);
498 };
499 
500 
501 // Specialization for RelabelFst.
502 template <class A>
503 class ArcIterator< RelabelFst<A> >
504     : public CacheArcIterator< RelabelFst<A> > {
505  public:
506   typedef typename A::StateId StateId;
507 
ArcIterator(const RelabelFst<A> & fst,StateId s)508   ArcIterator(const RelabelFst<A> &fst, StateId s)
509       : CacheArcIterator< RelabelFst<A> >(fst.GetImpl(), s) {
510     if (!fst.GetImpl()->HasArcs(s))
511       fst.GetImpl()->Expand(s);
512   }
513 
514  private:
515   DISALLOW_COPY_AND_ASSIGN(ArcIterator);
516 };
517 
518 template <class A> inline
InitStateIterator(StateIteratorData<A> * data)519 void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
520   data->base = new StateIterator< RelabelFst<A> >(*this);
521 }
522 
523 // Useful alias when using StdArc.
524 typedef RelabelFst<StdArc> StdRelabelFst;
525 
526 }  // namespace fst
527 
528 #endif  // FST_LIB_RELABEL_H__
529