1 // relabel.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: johans@google.com (Johan Schalkwyk)
17 //
18 // \file
19 // Functions and classes to relabel an Fst (either on input or output)
20 //
21 #ifndef FST_LIB_RELABEL_H__
22 #define FST_LIB_RELABEL_H__
23
24 #include <unordered_map>
25 using std::tr1::unordered_map;
26 using std::tr1::unordered_multimap;
27 #include <string>
28 #include <utility>
29 using std::pair; using std::make_pair;
30 #include <vector>
31 using std::vector;
32
33 #include <fst/cache.h>
34 #include <fst/test-properties.h>
35
36
37 namespace fst {
38
39 //
40 // Relabels either the input labels or output labels. The old to
41 // new labels are specified using a vector of pair<Label,Label>.
42 // Any label associations not specified are assumed to be identity
43 // mapping.
44 //
45 // \param fst input fst, must be mutable
46 // \param ipairs vector of input label pairs indicating old to new mapping
47 // \param opairs vector of output label pairs indicating old to new mapping
48 //
49 template <class A>
Relabel(MutableFst<A> * fst,const vector<pair<typename A::Label,typename A::Label>> & ipairs,const vector<pair<typename A::Label,typename A::Label>> & opairs)50 void Relabel(
51 MutableFst<A> *fst,
52 const vector<pair<typename A::Label, typename A::Label> >& ipairs,
53 const vector<pair<typename A::Label, typename A::Label> >& opairs) {
54 typedef typename A::StateId StateId;
55 typedef typename A::Label Label;
56
57 uint64 props = fst->Properties(kFstProperties, false);
58
59 // construct label to label hash.
60 unordered_map<Label, Label> input_map;
61 for (size_t i = 0; i < ipairs.size(); ++i) {
62 input_map[ipairs[i].first] = ipairs[i].second;
63 }
64
65 unordered_map<Label, Label> output_map;
66 for (size_t i = 0; i < opairs.size(); ++i) {
67 output_map[opairs[i].first] = opairs[i].second;
68 }
69
70 for (StateIterator<MutableFst<A> > siter(*fst);
71 !siter.Done(); siter.Next()) {
72 StateId s = siter.Value();
73 for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
74 !aiter.Done(); aiter.Next()) {
75 A arc = aiter.Value();
76
77 // relabel input
78 // only relabel if relabel pair defined
79 typename unordered_map<Label, Label>::iterator it =
80 input_map.find(arc.ilabel);
81 if (it != input_map.end()) {
82 if (it->second == kNoLabel) {
83 FSTERROR() << "Input symbol id " << arc.ilabel
84 << " missing from target vocabulary";
85 fst->SetProperties(kError, kError);
86 return;
87 }
88 arc.ilabel = it->second;
89 }
90
91 // relabel output
92 it = output_map.find(arc.olabel);
93 if (it != output_map.end()) {
94 if (it->second == kNoLabel) {
95 FSTERROR() << "Output symbol id " << arc.olabel
96 << " missing from target vocabulary";
97 fst->SetProperties(kError, kError);
98 return;
99 }
100 arc.olabel = it->second;
101 }
102
103 aiter.SetValue(arc);
104 }
105 }
106
107 fst->SetProperties(RelabelProperties(props), kFstProperties);
108 }
109
110 //
111 // Relabels either the input labels or output labels. The old to
112 // new labels mappings are specified using an input Symbol set.
113 // Any label associations not specified are assumed to be identity
114 // mapping.
115 //
116 // \param fst input fst, must be mutable
117 // \param new_isymbols symbol set indicating new mapping of input symbols
118 // \param new_osymbols symbol set indicating new mapping of output symbols
119 //
120 template<class A>
Relabel(MutableFst<A> * fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols)121 void Relabel(MutableFst<A> *fst,
122 const SymbolTable* new_isymbols,
123 const SymbolTable* new_osymbols) {
124 Relabel(fst,
125 fst->InputSymbols(), new_isymbols, true,
126 fst->OutputSymbols(), new_osymbols, true);
127 }
128
129 template<class A>
Relabel(MutableFst<A> * fst,const SymbolTable * old_isymbols,const SymbolTable * new_isymbols,bool attach_new_isymbols,const SymbolTable * old_osymbols,const SymbolTable * new_osymbols,bool attach_new_osymbols)130 void Relabel(MutableFst<A> *fst,
131 const SymbolTable* old_isymbols,
132 const SymbolTable* new_isymbols,
133 bool attach_new_isymbols,
134 const SymbolTable* old_osymbols,
135 const SymbolTable* new_osymbols,
136 bool attach_new_osymbols) {
137 typedef typename A::StateId StateId;
138 typedef typename A::Label Label;
139
140 vector<pair<Label, Label> > ipairs;
141 if (old_isymbols && new_isymbols) {
142 for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
143 syms_iter.Next()) {
144 string isymbol = syms_iter.Symbol();
145 int isymbol_val = syms_iter.Value();
146 int new_isymbol_val = new_isymbols->Find(isymbol);
147 ipairs.push_back(make_pair(isymbol_val, new_isymbol_val));
148 }
149 if (attach_new_isymbols)
150 fst->SetInputSymbols(new_isymbols);
151 }
152
153 vector<pair<Label, Label> > opairs;
154 if (old_osymbols && new_osymbols) {
155 for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
156 syms_iter.Next()) {
157 string osymbol = syms_iter.Symbol();
158 int osymbol_val = syms_iter.Value();
159 int new_osymbol_val = new_osymbols->Find(osymbol);
160 opairs.push_back(make_pair(osymbol_val, new_osymbol_val));
161 }
162 if (attach_new_osymbols)
163 fst->SetOutputSymbols(new_osymbols);
164 }
165
166 // call relabel using vector of relabel pairs.
167 Relabel(fst, ipairs, opairs);
168 }
169
170
171 typedef CacheOptions RelabelFstOptions;
172
173 template <class A> class RelabelFst;
174
175 //
176 // \class RelabelFstImpl
177 // \brief Implementation for delayed relabeling
178 //
179 // Relabels an FST from one symbol set to another. Relabeling
180 // can either be on input or output space. RelabelFst implements
181 // a delayed version of the relabel. Arcs are relabeled on the fly
182 // and not cached. I.e each request is recomputed.
183 //
184 template<class A>
185 class RelabelFstImpl : public CacheImpl<A> {
186 friend class StateIterator< RelabelFst<A> >;
187 public:
188 using FstImpl<A>::SetType;
189 using FstImpl<A>::SetProperties;
190 using FstImpl<A>::WriteHeader;
191 using FstImpl<A>::SetInputSymbols;
192 using FstImpl<A>::SetOutputSymbols;
193
194 using CacheImpl<A>::PushArc;
195 using CacheImpl<A>::HasArcs;
196 using CacheImpl<A>::HasFinal;
197 using CacheImpl<A>::HasStart;
198 using CacheImpl<A>::SetArcs;
199 using CacheImpl<A>::SetFinal;
200 using CacheImpl<A>::SetStart;
201
202 typedef A Arc;
203 typedef typename A::Label Label;
204 typedef typename A::Weight Weight;
205 typedef typename A::StateId StateId;
206 typedef CacheState<A> State;
207
RelabelFstImpl(const Fst<A> & fst,const vector<pair<Label,Label>> & ipairs,const vector<pair<Label,Label>> & opairs,const RelabelFstOptions & opts)208 RelabelFstImpl(const Fst<A>& fst,
209 const vector<pair<Label, Label> >& ipairs,
210 const vector<pair<Label, Label> >& opairs,
211 const RelabelFstOptions &opts)
212 : CacheImpl<A>(opts), fst_(fst.Copy()),
213 relabel_input_(false), relabel_output_(false) {
214 uint64 props = fst.Properties(kCopyProperties, false);
215 SetProperties(RelabelProperties(props));
216 SetType("relabel");
217
218 // create input label map
219 if (ipairs.size() > 0) {
220 for (size_t i = 0; i < ipairs.size(); ++i) {
221 input_map_[ipairs[i].first] = ipairs[i].second;
222 }
223 relabel_input_ = true;
224 }
225
226 // create output label map
227 if (opairs.size() > 0) {
228 for (size_t i = 0; i < opairs.size(); ++i) {
229 output_map_[opairs[i].first] = opairs[i].second;
230 }
231 relabel_output_ = true;
232 }
233 }
234
RelabelFstImpl(const Fst<A> & fst,const SymbolTable * old_isymbols,const SymbolTable * new_isymbols,const SymbolTable * old_osymbols,const SymbolTable * new_osymbols,const RelabelFstOptions & opts)235 RelabelFstImpl(const Fst<A>& fst,
236 const SymbolTable* old_isymbols,
237 const SymbolTable* new_isymbols,
238 const SymbolTable* old_osymbols,
239 const SymbolTable* new_osymbols,
240 const RelabelFstOptions &opts)
241 : CacheImpl<A>(opts), fst_(fst.Copy()),
242 relabel_input_(false), relabel_output_(false) {
243 SetType("relabel");
244
245 uint64 props = fst.Properties(kCopyProperties, false);
246 SetProperties(RelabelProperties(props));
247 SetInputSymbols(old_isymbols);
248 SetOutputSymbols(old_osymbols);
249
250 if (old_isymbols && new_isymbols &&
251 old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) {
252 for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
253 syms_iter.Next()) {
254 input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol());
255 }
256 SetInputSymbols(new_isymbols);
257 relabel_input_ = true;
258 }
259
260 if (old_osymbols && new_osymbols &&
261 old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) {
262 for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
263 syms_iter.Next()) {
264 output_map_[syms_iter.Value()] =
265 new_osymbols->Find(syms_iter.Symbol());
266 }
267 SetOutputSymbols(new_osymbols);
268 relabel_output_ = true;
269 }
270 }
271
RelabelFstImpl(const RelabelFstImpl<A> & impl)272 RelabelFstImpl(const RelabelFstImpl<A>& impl)
273 : CacheImpl<A>(impl),
274 fst_(impl.fst_->Copy(true)),
275 input_map_(impl.input_map_),
276 output_map_(impl.output_map_),
277 relabel_input_(impl.relabel_input_),
278 relabel_output_(impl.relabel_output_) {
279 SetType("relabel");
280 SetProperties(impl.Properties(), kCopyProperties);
281 SetInputSymbols(impl.InputSymbols());
282 SetOutputSymbols(impl.OutputSymbols());
283 }
284
~RelabelFstImpl()285 ~RelabelFstImpl() { delete fst_; }
286
Start()287 StateId Start() {
288 if (!HasStart()) {
289 StateId s = fst_->Start();
290 SetStart(s);
291 }
292 return CacheImpl<A>::Start();
293 }
294
Final(StateId s)295 Weight Final(StateId s) {
296 if (!HasFinal(s)) {
297 SetFinal(s, fst_->Final(s));
298 }
299 return CacheImpl<A>::Final(s);
300 }
301
NumArcs(StateId s)302 size_t NumArcs(StateId s) {
303 if (!HasArcs(s)) {
304 Expand(s);
305 }
306 return CacheImpl<A>::NumArcs(s);
307 }
308
NumInputEpsilons(StateId s)309 size_t NumInputEpsilons(StateId s) {
310 if (!HasArcs(s)) {
311 Expand(s);
312 }
313 return CacheImpl<A>::NumInputEpsilons(s);
314 }
315
NumOutputEpsilons(StateId s)316 size_t NumOutputEpsilons(StateId s) {
317 if (!HasArcs(s)) {
318 Expand(s);
319 }
320 return CacheImpl<A>::NumOutputEpsilons(s);
321 }
322
Properties()323 uint64 Properties() const { return Properties(kFstProperties); }
324
325 // Set error if found; return FST impl properties.
Properties(uint64 mask)326 uint64 Properties(uint64 mask) const {
327 if ((mask & kError) && fst_->Properties(kError, false))
328 SetProperties(kError, kError);
329 return FstImpl<Arc>::Properties(mask);
330 }
331
InitArcIterator(StateId s,ArcIteratorData<A> * data)332 void InitArcIterator(StateId s, ArcIteratorData<A>* data) {
333 if (!HasArcs(s)) {
334 Expand(s);
335 }
336 CacheImpl<A>::InitArcIterator(s, data);
337 }
338
Expand(StateId s)339 void Expand(StateId s) {
340 for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
341 A arc = aiter.Value();
342
343 // relabel input
344 if (relabel_input_) {
345 typename unordered_map<Label, Label>::iterator it =
346 input_map_.find(arc.ilabel);
347 if (it != input_map_.end()) { arc.ilabel = it->second; }
348 }
349
350 // relabel output
351 if (relabel_output_) {
352 typename unordered_map<Label, Label>::iterator it =
353 output_map_.find(arc.olabel);
354 if (it != output_map_.end()) { arc.olabel = it->second; }
355 }
356
357 PushArc(s, arc);
358 }
359 SetArcs(s);
360 }
361
362
363 private:
364 const Fst<A> *fst_;
365
366 unordered_map<Label, Label> input_map_;
367 unordered_map<Label, Label> output_map_;
368 bool relabel_input_;
369 bool relabel_output_;
370
371 void operator=(const RelabelFstImpl<A> &); // disallow
372 };
373
374
375 //
376 // \class RelabelFst
377 // \brief Delayed implementation of arc relabeling
378 //
379 // This class attaches interface to implementation and handles
380 // reference counting, delegating most methods to ImplToFst.
381 template <class A>
382 class RelabelFst : public ImplToFst< RelabelFstImpl<A> > {
383 public:
384 friend class ArcIterator< RelabelFst<A> >;
385 friend class StateIterator< RelabelFst<A> >;
386
387 typedef A Arc;
388 typedef typename A::Label Label;
389 typedef typename A::Weight Weight;
390 typedef typename A::StateId StateId;
391 typedef CacheState<A> State;
392 typedef RelabelFstImpl<A> Impl;
393
RelabelFst(const Fst<A> & fst,const vector<pair<Label,Label>> & ipairs,const vector<pair<Label,Label>> & opairs)394 RelabelFst(const Fst<A>& fst,
395 const vector<pair<Label, Label> >& ipairs,
396 const vector<pair<Label, Label> >& opairs)
397 : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, RelabelFstOptions())) {}
398
RelabelFst(const Fst<A> & fst,const vector<pair<Label,Label>> & ipairs,const vector<pair<Label,Label>> & opairs,const RelabelFstOptions & opts)399 RelabelFst(const Fst<A>& fst,
400 const vector<pair<Label, Label> >& ipairs,
401 const vector<pair<Label, Label> >& opairs,
402 const RelabelFstOptions &opts)
403 : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, opts)) {}
404
RelabelFst(const Fst<A> & fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols)405 RelabelFst(const Fst<A>& fst,
406 const SymbolTable* new_isymbols,
407 const SymbolTable* new_osymbols)
408 : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols,
409 fst.OutputSymbols(), new_osymbols,
410 RelabelFstOptions())) {}
411
RelabelFst(const Fst<A> & fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols,const RelabelFstOptions & opts)412 RelabelFst(const Fst<A>& fst,
413 const SymbolTable* new_isymbols,
414 const SymbolTable* new_osymbols,
415 const RelabelFstOptions &opts)
416 : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols,
417 fst.OutputSymbols(), new_osymbols, opts)) {}
418
RelabelFst(const Fst<A> & fst,const SymbolTable * old_isymbols,const SymbolTable * new_isymbols,const SymbolTable * old_osymbols,const SymbolTable * new_osymbols)419 RelabelFst(const Fst<A>& fst,
420 const SymbolTable* old_isymbols,
421 const SymbolTable* new_isymbols,
422 const SymbolTable* old_osymbols,
423 const SymbolTable* new_osymbols)
424 : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols,
425 new_osymbols, RelabelFstOptions())) {}
426
RelabelFst(const Fst<A> & fst,const SymbolTable * old_isymbols,const SymbolTable * new_isymbols,const SymbolTable * old_osymbols,const SymbolTable * new_osymbols,const RelabelFstOptions & opts)427 RelabelFst(const Fst<A>& fst,
428 const SymbolTable* old_isymbols,
429 const SymbolTable* new_isymbols,
430 const SymbolTable* old_osymbols,
431 const SymbolTable* new_osymbols,
432 const RelabelFstOptions &opts)
433 : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols,
434 new_osymbols, opts)) {}
435
436 // See Fst<>::Copy() for doc.
437 RelabelFst(const RelabelFst<A> &fst, bool safe = false)
438 : ImplToFst<Impl>(fst, safe) {}
439
440 // Get a copy of this RelabelFst. See Fst<>::Copy() for further doc.
441 virtual RelabelFst<A> *Copy(bool safe = false) const {
442 return new RelabelFst<A>(*this, safe);
443 }
444
445 virtual void InitStateIterator(StateIteratorData<A> *data) const;
446
InitArcIterator(StateId s,ArcIteratorData<A> * data)447 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
448 return GetImpl()->InitArcIterator(s, data);
449 }
450
451 private:
452 // Makes visible to friends.
GetImpl()453 Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
454
455 void operator=(const RelabelFst<A> &fst); // disallow
456 };
457
458 // Specialization for RelabelFst.
459 template<class A>
460 class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> {
461 public:
462 typedef typename A::StateId StateId;
463
StateIterator(const RelabelFst<A> & fst)464 explicit StateIterator(const RelabelFst<A> &fst)
465 : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {}
466
Done()467 bool Done() const { return siter_.Done(); }
468
Value()469 StateId Value() const { return s_; }
470
Next()471 void Next() {
472 if (!siter_.Done()) {
473 ++s_;
474 siter_.Next();
475 }
476 }
477
Reset()478 void Reset() {
479 s_ = 0;
480 siter_.Reset();
481 }
482
483 private:
Done_()484 bool Done_() const { return Done(); }
Value_()485 StateId Value_() const { return Value(); }
Next_()486 void Next_() { Next(); }
Reset_()487 void Reset_() { Reset(); }
488
489 const RelabelFstImpl<A> *impl_;
490 StateIterator< Fst<A> > siter_;
491 StateId s_;
492
493 DISALLOW_COPY_AND_ASSIGN(StateIterator);
494 };
495
496
497 // Specialization for RelabelFst.
498 template <class A>
499 class ArcIterator< RelabelFst<A> >
500 : public CacheArcIterator< RelabelFst<A> > {
501 public:
502 typedef typename A::StateId StateId;
503
ArcIterator(const RelabelFst<A> & fst,StateId s)504 ArcIterator(const RelabelFst<A> &fst, StateId s)
505 : CacheArcIterator< RelabelFst<A> >(fst.GetImpl(), s) {
506 if (!fst.GetImpl()->HasArcs(s))
507 fst.GetImpl()->Expand(s);
508 }
509
510 private:
511 DISALLOW_COPY_AND_ASSIGN(ArcIterator);
512 };
513
514 template <class A> inline
InitStateIterator(StateIteratorData<A> * data)515 void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
516 data->base = new StateIterator< RelabelFst<A> >(*this);
517 }
518
519 // Useful alias when using StdArc.
520 typedef RelabelFst<StdArc> StdRelabelFst;
521
522 } // namespace fst
523
524 #endif // FST_LIB_RELABEL_H__
525