• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // relabel.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: johans@google.com (Johan Schalkwyk)
17 //
18 // \file
19 // Functions and classes to relabel an Fst (either on input or output)
20 //
21 #ifndef FST_LIB_RELABEL_H__
22 #define FST_LIB_RELABEL_H__
23 
24 #include <unordered_map>
25 using std::tr1::unordered_map;
26 using std::tr1::unordered_multimap;
27 #include <string>
28 #include <utility>
29 using std::pair; using std::make_pair;
30 #include <vector>
31 using std::vector;
32 
33 #include <fst/cache.h>
34 #include <fst/test-properties.h>
35 
36 
37 namespace fst {
38 
39 //
40 // Relabels either the input labels or output labels. The old to
41 // new labels are specified using a vector of pair<Label,Label>.
42 // Any label associations not specified are assumed to be identity
43 // mapping.
44 //
45 // \param fst input fst, must be mutable
46 // \param ipairs vector of input label pairs indicating old to new mapping
47 // \param opairs vector of output label pairs indicating old to new mapping
48 //
49 template <class A>
Relabel(MutableFst<A> * fst,const vector<pair<typename A::Label,typename A::Label>> & ipairs,const vector<pair<typename A::Label,typename A::Label>> & opairs)50 void Relabel(
51     MutableFst<A> *fst,
52     const vector<pair<typename A::Label, typename A::Label> >& ipairs,
53     const vector<pair<typename A::Label, typename A::Label> >& opairs) {
54   typedef typename A::StateId StateId;
55   typedef typename A::Label   Label;
56 
57   uint64 props = fst->Properties(kFstProperties, false);
58 
59   // construct label to label hash.
60   unordered_map<Label, Label> input_map;
61   for (size_t i = 0; i < ipairs.size(); ++i) {
62     input_map[ipairs[i].first] = ipairs[i].second;
63   }
64 
65   unordered_map<Label, Label> output_map;
66   for (size_t i = 0; i < opairs.size(); ++i) {
67     output_map[opairs[i].first] = opairs[i].second;
68   }
69 
70   for (StateIterator<MutableFst<A> > siter(*fst);
71        !siter.Done(); siter.Next()) {
72     StateId s = siter.Value();
73     for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
74          !aiter.Done(); aiter.Next()) {
75       A arc = aiter.Value();
76 
77       // relabel input
78       // only relabel if relabel pair defined
79       typename unordered_map<Label, Label>::iterator it =
80         input_map.find(arc.ilabel);
81       if (it != input_map.end()) {
82         if (it->second == kNoLabel) {
83           FSTERROR() << "Input symbol id " << arc.ilabel
84                      << " missing from target vocabulary";
85           fst->SetProperties(kError, kError);
86           return;
87         }
88         arc.ilabel = it->second;
89       }
90 
91       // relabel output
92       it = output_map.find(arc.olabel);
93       if (it != output_map.end()) {
94         if (it->second == kNoLabel) {
95           FSTERROR() << "Output symbol id " << arc.olabel
96                      << " missing from target vocabulary";
97           fst->SetProperties(kError, kError);
98           return;
99         }
100         arc.olabel = it->second;
101       }
102 
103       aiter.SetValue(arc);
104     }
105   }
106 
107   fst->SetProperties(RelabelProperties(props), kFstProperties);
108 }
109 
110 //
111 // Relabels either the input labels or output labels. The old to
112 // new labels mappings are specified using an input Symbol set.
113 // Any label associations not specified are assumed to be identity
114 // mapping.
115 //
116 // \param fst input fst, must be mutable
117 // \param new_isymbols symbol set indicating new mapping of input symbols
118 // \param new_osymbols symbol set indicating new mapping of output symbols
119 //
120 template<class A>
Relabel(MutableFst<A> * fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols)121 void Relabel(MutableFst<A> *fst,
122              const SymbolTable* new_isymbols,
123              const SymbolTable* new_osymbols) {
124   Relabel(fst,
125           fst->InputSymbols(), new_isymbols, true,
126           fst->OutputSymbols(), new_osymbols, true);
127 }
128 
129 template<class A>
Relabel(MutableFst<A> * fst,const SymbolTable * old_isymbols,const SymbolTable * new_isymbols,bool attach_new_isymbols,const SymbolTable * old_osymbols,const SymbolTable * new_osymbols,bool attach_new_osymbols)130 void Relabel(MutableFst<A> *fst,
131              const SymbolTable* old_isymbols,
132              const SymbolTable* new_isymbols,
133              bool attach_new_isymbols,
134              const SymbolTable* old_osymbols,
135              const SymbolTable* new_osymbols,
136              bool attach_new_osymbols) {
137   typedef typename A::StateId StateId;
138   typedef typename A::Label   Label;
139 
140   vector<pair<Label, Label> > ipairs;
141   if (old_isymbols && new_isymbols) {
142     for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
143          syms_iter.Next()) {
144       string isymbol = syms_iter.Symbol();
145       int isymbol_val = syms_iter.Value();
146       int new_isymbol_val = new_isymbols->Find(isymbol);
147       ipairs.push_back(make_pair(isymbol_val, new_isymbol_val));
148     }
149     if (attach_new_isymbols)
150       fst->SetInputSymbols(new_isymbols);
151   }
152 
153   vector<pair<Label, Label> > opairs;
154   if (old_osymbols && new_osymbols) {
155     for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
156          syms_iter.Next()) {
157       string osymbol = syms_iter.Symbol();
158       int osymbol_val = syms_iter.Value();
159       int new_osymbol_val = new_osymbols->Find(osymbol);
160       opairs.push_back(make_pair(osymbol_val, new_osymbol_val));
161     }
162     if (attach_new_osymbols)
163       fst->SetOutputSymbols(new_osymbols);
164   }
165 
166   // call relabel using vector of relabel pairs.
167   Relabel(fst, ipairs, opairs);
168 }
169 
170 
171 typedef CacheOptions RelabelFstOptions;
172 
173 template <class A> class RelabelFst;
174 
175 //
176 // \class RelabelFstImpl
177 // \brief Implementation for delayed relabeling
178 //
179 // Relabels an FST from one symbol set to another. Relabeling
180 // can either be on input or output space. RelabelFst implements
181 // a delayed version of the relabel. Arcs are relabeled on the fly
182 // and not cached. I.e each request is recomputed.
183 //
184 template<class A>
185 class RelabelFstImpl : public CacheImpl<A> {
186   friend class StateIterator< RelabelFst<A> >;
187  public:
188   using FstImpl<A>::SetType;
189   using FstImpl<A>::SetProperties;
190   using FstImpl<A>::WriteHeader;
191   using FstImpl<A>::SetInputSymbols;
192   using FstImpl<A>::SetOutputSymbols;
193 
194   using CacheImpl<A>::PushArc;
195   using CacheImpl<A>::HasArcs;
196   using CacheImpl<A>::HasFinal;
197   using CacheImpl<A>::HasStart;
198   using CacheImpl<A>::SetArcs;
199   using CacheImpl<A>::SetFinal;
200   using CacheImpl<A>::SetStart;
201 
202   typedef A Arc;
203   typedef typename A::Label   Label;
204   typedef typename A::Weight  Weight;
205   typedef typename A::StateId StateId;
206   typedef CacheState<A> State;
207 
RelabelFstImpl(const Fst<A> & fst,const vector<pair<Label,Label>> & ipairs,const vector<pair<Label,Label>> & opairs,const RelabelFstOptions & opts)208   RelabelFstImpl(const Fst<A>& fst,
209                  const vector<pair<Label, Label> >& ipairs,
210                  const vector<pair<Label, Label> >& opairs,
211                  const RelabelFstOptions &opts)
212       : CacheImpl<A>(opts), fst_(fst.Copy()),
213         relabel_input_(false), relabel_output_(false) {
214     uint64 props = fst.Properties(kCopyProperties, false);
215     SetProperties(RelabelProperties(props));
216     SetType("relabel");
217 
218     // create input label map
219     if (ipairs.size() > 0) {
220       for (size_t i = 0; i < ipairs.size(); ++i) {
221         input_map_[ipairs[i].first] = ipairs[i].second;
222       }
223       relabel_input_ = true;
224     }
225 
226     // create output label map
227     if (opairs.size() > 0) {
228       for (size_t i = 0; i < opairs.size(); ++i) {
229         output_map_[opairs[i].first] = opairs[i].second;
230       }
231       relabel_output_ = true;
232     }
233   }
234 
RelabelFstImpl(const Fst<A> & fst,const SymbolTable * old_isymbols,const SymbolTable * new_isymbols,const SymbolTable * old_osymbols,const SymbolTable * new_osymbols,const RelabelFstOptions & opts)235   RelabelFstImpl(const Fst<A>& fst,
236                  const SymbolTable* old_isymbols,
237                  const SymbolTable* new_isymbols,
238                  const SymbolTable* old_osymbols,
239                  const SymbolTable* new_osymbols,
240                  const RelabelFstOptions &opts)
241       : CacheImpl<A>(opts), fst_(fst.Copy()),
242         relabel_input_(false), relabel_output_(false) {
243     SetType("relabel");
244 
245     uint64 props = fst.Properties(kCopyProperties, false);
246     SetProperties(RelabelProperties(props));
247     SetInputSymbols(old_isymbols);
248     SetOutputSymbols(old_osymbols);
249 
250     if (old_isymbols && new_isymbols &&
251         old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) {
252       for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
253            syms_iter.Next()) {
254         input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol());
255       }
256       SetInputSymbols(new_isymbols);
257       relabel_input_ = true;
258     }
259 
260     if (old_osymbols && new_osymbols &&
261         old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) {
262       for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
263            syms_iter.Next()) {
264         output_map_[syms_iter.Value()] =
265           new_osymbols->Find(syms_iter.Symbol());
266       }
267       SetOutputSymbols(new_osymbols);
268       relabel_output_ = true;
269     }
270   }
271 
RelabelFstImpl(const RelabelFstImpl<A> & impl)272   RelabelFstImpl(const RelabelFstImpl<A>& impl)
273       : CacheImpl<A>(impl),
274         fst_(impl.fst_->Copy(true)),
275         input_map_(impl.input_map_),
276         output_map_(impl.output_map_),
277         relabel_input_(impl.relabel_input_),
278         relabel_output_(impl.relabel_output_) {
279     SetType("relabel");
280     SetProperties(impl.Properties(), kCopyProperties);
281     SetInputSymbols(impl.InputSymbols());
282     SetOutputSymbols(impl.OutputSymbols());
283   }
284 
~RelabelFstImpl()285   ~RelabelFstImpl() { delete fst_; }
286 
Start()287   StateId Start() {
288     if (!HasStart()) {
289       StateId s = fst_->Start();
290       SetStart(s);
291     }
292     return CacheImpl<A>::Start();
293   }
294 
Final(StateId s)295   Weight Final(StateId s) {
296     if (!HasFinal(s)) {
297       SetFinal(s, fst_->Final(s));
298     }
299     return CacheImpl<A>::Final(s);
300   }
301 
NumArcs(StateId s)302   size_t NumArcs(StateId s) {
303     if (!HasArcs(s)) {
304       Expand(s);
305     }
306     return CacheImpl<A>::NumArcs(s);
307   }
308 
NumInputEpsilons(StateId s)309   size_t NumInputEpsilons(StateId s) {
310     if (!HasArcs(s)) {
311       Expand(s);
312     }
313     return CacheImpl<A>::NumInputEpsilons(s);
314   }
315 
NumOutputEpsilons(StateId s)316   size_t NumOutputEpsilons(StateId s) {
317     if (!HasArcs(s)) {
318       Expand(s);
319     }
320     return CacheImpl<A>::NumOutputEpsilons(s);
321   }
322 
Properties()323   uint64 Properties() const { return Properties(kFstProperties); }
324 
325   // Set error if found; return FST impl properties.
Properties(uint64 mask)326   uint64 Properties(uint64 mask) const {
327     if ((mask & kError) && fst_->Properties(kError, false))
328       SetProperties(kError, kError);
329     return FstImpl<Arc>::Properties(mask);
330   }
331 
InitArcIterator(StateId s,ArcIteratorData<A> * data)332   void InitArcIterator(StateId s, ArcIteratorData<A>* data) {
333     if (!HasArcs(s)) {
334       Expand(s);
335     }
336     CacheImpl<A>::InitArcIterator(s, data);
337   }
338 
Expand(StateId s)339   void Expand(StateId s) {
340     for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
341       A arc = aiter.Value();
342 
343       // relabel input
344       if (relabel_input_) {
345         typename unordered_map<Label, Label>::iterator it =
346           input_map_.find(arc.ilabel);
347         if (it != input_map_.end()) { arc.ilabel = it->second; }
348       }
349 
350       // relabel output
351       if (relabel_output_) {
352         typename unordered_map<Label, Label>::iterator it =
353           output_map_.find(arc.olabel);
354         if (it != output_map_.end()) { arc.olabel = it->second; }
355       }
356 
357       PushArc(s, arc);
358     }
359     SetArcs(s);
360   }
361 
362 
363  private:
364   const Fst<A> *fst_;
365 
366   unordered_map<Label, Label> input_map_;
367   unordered_map<Label, Label> output_map_;
368   bool relabel_input_;
369   bool relabel_output_;
370 
371   void operator=(const RelabelFstImpl<A> &);  // disallow
372 };
373 
374 
375 //
376 // \class RelabelFst
377 // \brief Delayed implementation of arc relabeling
378 //
379 // This class attaches interface to implementation and handles
380 // reference counting, delegating most methods to ImplToFst.
381 template <class A>
382 class RelabelFst : public ImplToFst< RelabelFstImpl<A> > {
383  public:
384   friend class ArcIterator< RelabelFst<A> >;
385   friend class StateIterator< RelabelFst<A> >;
386 
387   typedef A Arc;
388   typedef typename A::Label   Label;
389   typedef typename A::Weight  Weight;
390   typedef typename A::StateId StateId;
391   typedef CacheState<A> State;
392   typedef RelabelFstImpl<A> Impl;
393 
RelabelFst(const Fst<A> & fst,const vector<pair<Label,Label>> & ipairs,const vector<pair<Label,Label>> & opairs)394   RelabelFst(const Fst<A>& fst,
395              const vector<pair<Label, Label> >& ipairs,
396              const vector<pair<Label, Label> >& opairs)
397       : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, RelabelFstOptions())) {}
398 
RelabelFst(const Fst<A> & fst,const vector<pair<Label,Label>> & ipairs,const vector<pair<Label,Label>> & opairs,const RelabelFstOptions & opts)399   RelabelFst(const Fst<A>& fst,
400              const vector<pair<Label, Label> >& ipairs,
401              const vector<pair<Label, Label> >& opairs,
402              const RelabelFstOptions &opts)
403       : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, opts)) {}
404 
RelabelFst(const Fst<A> & fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols)405   RelabelFst(const Fst<A>& fst,
406              const SymbolTable* new_isymbols,
407              const SymbolTable* new_osymbols)
408       : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols,
409                                  fst.OutputSymbols(), new_osymbols,
410                                  RelabelFstOptions())) {}
411 
RelabelFst(const Fst<A> & fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols,const RelabelFstOptions & opts)412   RelabelFst(const Fst<A>& fst,
413              const SymbolTable* new_isymbols,
414              const SymbolTable* new_osymbols,
415              const RelabelFstOptions &opts)
416       : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols,
417                                  fst.OutputSymbols(), new_osymbols, opts)) {}
418 
RelabelFst(const Fst<A> & fst,const SymbolTable * old_isymbols,const SymbolTable * new_isymbols,const SymbolTable * old_osymbols,const SymbolTable * new_osymbols)419   RelabelFst(const Fst<A>& fst,
420              const SymbolTable* old_isymbols,
421              const SymbolTable* new_isymbols,
422              const SymbolTable* old_osymbols,
423              const SymbolTable* new_osymbols)
424     : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols,
425                                new_osymbols, RelabelFstOptions())) {}
426 
RelabelFst(const Fst<A> & fst,const SymbolTable * old_isymbols,const SymbolTable * new_isymbols,const SymbolTable * old_osymbols,const SymbolTable * new_osymbols,const RelabelFstOptions & opts)427   RelabelFst(const Fst<A>& fst,
428              const SymbolTable* old_isymbols,
429              const SymbolTable* new_isymbols,
430              const SymbolTable* old_osymbols,
431              const SymbolTable* new_osymbols,
432              const RelabelFstOptions &opts)
433     : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols,
434                                new_osymbols, opts)) {}
435 
436   // See Fst<>::Copy() for doc.
437   RelabelFst(const RelabelFst<A> &fst, bool safe = false)
438     : ImplToFst<Impl>(fst, safe) {}
439 
440   // Get a copy of this RelabelFst. See Fst<>::Copy() for further doc.
441   virtual RelabelFst<A> *Copy(bool safe = false) const {
442     return new RelabelFst<A>(*this, safe);
443   }
444 
445   virtual void InitStateIterator(StateIteratorData<A> *data) const;
446 
InitArcIterator(StateId s,ArcIteratorData<A> * data)447   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
448     return GetImpl()->InitArcIterator(s, data);
449   }
450 
451  private:
452   // Makes visible to friends.
GetImpl()453   Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
454 
455   void operator=(const RelabelFst<A> &fst);  // disallow
456 };
457 
458 // Specialization for RelabelFst.
459 template<class A>
460 class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> {
461  public:
462   typedef typename A::StateId StateId;
463 
StateIterator(const RelabelFst<A> & fst)464   explicit StateIterator(const RelabelFst<A> &fst)
465       : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {}
466 
Done()467   bool Done() const { return siter_.Done(); }
468 
Value()469   StateId Value() const { return s_; }
470 
Next()471   void Next() {
472     if (!siter_.Done()) {
473       ++s_;
474       siter_.Next();
475     }
476   }
477 
Reset()478   void Reset() {
479     s_ = 0;
480     siter_.Reset();
481   }
482 
483  private:
Done_()484   bool Done_() const { return Done(); }
Value_()485   StateId Value_() const { return Value(); }
Next_()486   void Next_() { Next(); }
Reset_()487   void Reset_() { Reset(); }
488 
489   const RelabelFstImpl<A> *impl_;
490   StateIterator< Fst<A> > siter_;
491   StateId s_;
492 
493   DISALLOW_COPY_AND_ASSIGN(StateIterator);
494 };
495 
496 
497 // Specialization for RelabelFst.
498 template <class A>
499 class ArcIterator< RelabelFst<A> >
500     : public CacheArcIterator< RelabelFst<A> > {
501  public:
502   typedef typename A::StateId StateId;
503 
ArcIterator(const RelabelFst<A> & fst,StateId s)504   ArcIterator(const RelabelFst<A> &fst, StateId s)
505       : CacheArcIterator< RelabelFst<A> >(fst.GetImpl(), s) {
506     if (!fst.GetImpl()->HasArcs(s))
507       fst.GetImpl()->Expand(s);
508   }
509 
510  private:
511   DISALLOW_COPY_AND_ASSIGN(ArcIterator);
512 };
513 
514 template <class A> inline
InitStateIterator(StateIteratorData<A> * data)515 void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
516   data->base = new StateIterator< RelabelFst<A> >(*this);
517 }
518 
519 // Useful alias when using StdArc.
520 typedef RelabelFst<StdArc> StdRelabelFst;
521 
522 }  // namespace fst
523 
524 #endif  // FST_LIB_RELABEL_H__
525