1 // relabel.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: johans@google.com (Johan Schalkwyk)
17 //
18 // \file
19 // Functions and classes to relabel an Fst (either on input or output)
20 //
21 #ifndef FST_LIB_RELABEL_H__
22 #define FST_LIB_RELABEL_H__
23
24 #include <tr1/unordered_map>
25 using std::tr1::unordered_map;
26 using std::tr1::unordered_multimap;
27 #include <string>
28 #include <utility>
29 using std::pair; using std::make_pair;
30 #include <vector>
31 using std::vector;
32
33 #include <fst/cache.h>
34 #include <fst/test-properties.h>
35
36
37 #include <tr1/unordered_map>
38 using std::tr1::unordered_map;
39 using std::tr1::unordered_multimap;
40
41 namespace fst {
42
//
// Relabels the input labels and/or the output labels of an Fst in
// place. The old to new label mappings are specified using vectors
// of pair<Label, Label>. Any label with no mapping specified is left
// unchanged (identity mapping).
//
// \param fst input fst, must be mutable
// \param ipairs vector of input label pairs indicating old to new mapping
// \param opairs vector of output label pairs indicating old to new mapping
//
53 template <class A>
Relabel(MutableFst<A> * fst,const vector<pair<typename A::Label,typename A::Label>> & ipairs,const vector<pair<typename A::Label,typename A::Label>> & opairs)54 void Relabel(
55 MutableFst<A> *fst,
56 const vector<pair<typename A::Label, typename A::Label> >& ipairs,
57 const vector<pair<typename A::Label, typename A::Label> >& opairs) {
58 typedef typename A::StateId StateId;
59 typedef typename A::Label Label;
60
61 uint64 props = fst->Properties(kFstProperties, false);
62
63 // construct label to label hash.
64 unordered_map<Label, Label> input_map;
65 for (size_t i = 0; i < ipairs.size(); ++i) {
66 input_map[ipairs[i].first] = ipairs[i].second;
67 }
68
69 unordered_map<Label, Label> output_map;
70 for (size_t i = 0; i < opairs.size(); ++i) {
71 output_map[opairs[i].first] = opairs[i].second;
72 }
73
74 for (StateIterator<MutableFst<A> > siter(*fst);
75 !siter.Done(); siter.Next()) {
76 StateId s = siter.Value();
77 for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
78 !aiter.Done(); aiter.Next()) {
79 A arc = aiter.Value();
80
81 // relabel input
82 // only relabel if relabel pair defined
83 typename unordered_map<Label, Label>::iterator it =
84 input_map.find(arc.ilabel);
85 if (it != input_map.end()) {
86 if (it->second == kNoLabel) {
87 FSTERROR() << "Input symbol id " << arc.ilabel
88 << " missing from target vocabulary";
89 fst->SetProperties(kError, kError);
90 return;
91 }
92 arc.ilabel = it->second;
93 }
94
95 // relabel output
96 it = output_map.find(arc.olabel);
97 if (it != output_map.end()) {
98 if (it->second == kNoLabel) {
99 FSTERROR() << "Output symbol id " << arc.olabel
100 << " missing from target vocabulary";
101 fst->SetProperties(kError, kError);
102 return;
103 }
104 arc.olabel = it->second;
105 }
106
107 aiter.SetValue(arc);
108 }
109 }
110
111 fst->SetProperties(RelabelProperties(props), kFstProperties);
112 }
113
//
// Relabels the input labels and/or the output labels of an Fst in
// place. The old to new label mappings are derived by looking up
// each symbol of the Fst's current symbol tables in the new symbol
// tables given here.
//
// \param fst input fst, must be mutable
// \param new_isymbols symbol set indicating new mapping of input symbols
// \param new_osymbols symbol set indicating new mapping of output symbols
//
124 template<class A>
Relabel(MutableFst<A> * fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols)125 void Relabel(MutableFst<A> *fst,
126 const SymbolTable* new_isymbols,
127 const SymbolTable* new_osymbols) {
128 Relabel(fst,
129 fst->InputSymbols(), new_isymbols, true,
130 fst->OutputSymbols(), new_osymbols, true);
131 }
132
133 template<class A>
Relabel(MutableFst<A> * fst,const SymbolTable * old_isymbols,const SymbolTable * new_isymbols,bool attach_new_isymbols,const SymbolTable * old_osymbols,const SymbolTable * new_osymbols,bool attach_new_osymbols)134 void Relabel(MutableFst<A> *fst,
135 const SymbolTable* old_isymbols,
136 const SymbolTable* new_isymbols,
137 bool attach_new_isymbols,
138 const SymbolTable* old_osymbols,
139 const SymbolTable* new_osymbols,
140 bool attach_new_osymbols) {
141 typedef typename A::StateId StateId;
142 typedef typename A::Label Label;
143
144 vector<pair<Label, Label> > ipairs;
145 if (old_isymbols && new_isymbols) {
146 for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
147 syms_iter.Next()) {
148 string isymbol = syms_iter.Symbol();
149 int isymbol_val = syms_iter.Value();
150 int new_isymbol_val = new_isymbols->Find(isymbol);
151 ipairs.push_back(make_pair(isymbol_val, new_isymbol_val));
152 }
153 if (attach_new_isymbols)
154 fst->SetInputSymbols(new_isymbols);
155 }
156
157 vector<pair<Label, Label> > opairs;
158 if (old_osymbols && new_osymbols) {
159 for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
160 syms_iter.Next()) {
161 string osymbol = syms_iter.Symbol();
162 int osymbol_val = syms_iter.Value();
163 int new_osymbol_val = new_osymbols->Find(osymbol);
164 opairs.push_back(make_pair(osymbol_val, new_osymbol_val));
165 }
166 if (attach_new_osymbols)
167 fst->SetOutputSymbols(new_osymbols);
168 }
169
170 // call relabel using vector of relabel pairs.
171 Relabel(fst, ipairs, opairs);
172 }
173
174
175 typedef CacheOptions RelabelFstOptions;
176
177 template <class A> class RelabelFst;
178
//
// \class RelabelFstImpl
// \brief Implementation for delayed relabeling
//
// Relabels an FST from one symbol set to another. Relabeling
// can be on the input space, the output space, or both. RelabelFst
// implements a delayed version of the relabel: the arcs of a state
// are relabeled when that state is first expanded and are handed to
// the CacheImpl base class, so whether they are recomputed later
// depends on the cache options.
//
188 template<class A>
189 class RelabelFstImpl : public CacheImpl<A> {
190 friend class StateIterator< RelabelFst<A> >;
191 public:
192 using FstImpl<A>::SetType;
193 using FstImpl<A>::SetProperties;
194 using FstImpl<A>::WriteHeader;
195 using FstImpl<A>::SetInputSymbols;
196 using FstImpl<A>::SetOutputSymbols;
197
198 using CacheImpl<A>::PushArc;
199 using CacheImpl<A>::HasArcs;
200 using CacheImpl<A>::HasFinal;
201 using CacheImpl<A>::HasStart;
202 using CacheImpl<A>::SetArcs;
203 using CacheImpl<A>::SetFinal;
204 using CacheImpl<A>::SetStart;
205
206 typedef A Arc;
207 typedef typename A::Label Label;
208 typedef typename A::Weight Weight;
209 typedef typename A::StateId StateId;
210 typedef CacheState<A> State;
211
RelabelFstImpl(const Fst<A> & fst,const vector<pair<Label,Label>> & ipairs,const vector<pair<Label,Label>> & opairs,const RelabelFstOptions & opts)212 RelabelFstImpl(const Fst<A>& fst,
213 const vector<pair<Label, Label> >& ipairs,
214 const vector<pair<Label, Label> >& opairs,
215 const RelabelFstOptions &opts)
216 : CacheImpl<A>(opts), fst_(fst.Copy()),
217 relabel_input_(false), relabel_output_(false) {
218 uint64 props = fst.Properties(kCopyProperties, false);
219 SetProperties(RelabelProperties(props));
220 SetType("relabel");
221
222 // create input label map
223 if (ipairs.size() > 0) {
224 for (size_t i = 0; i < ipairs.size(); ++i) {
225 input_map_[ipairs[i].first] = ipairs[i].second;
226 }
227 relabel_input_ = true;
228 }
229
230 // create output label map
231 if (opairs.size() > 0) {
232 for (size_t i = 0; i < opairs.size(); ++i) {
233 output_map_[opairs[i].first] = opairs[i].second;
234 }
235 relabel_output_ = true;
236 }
237 }
238
RelabelFstImpl(const Fst<A> & fst,const SymbolTable * old_isymbols,const SymbolTable * new_isymbols,const SymbolTable * old_osymbols,const SymbolTable * new_osymbols,const RelabelFstOptions & opts)239 RelabelFstImpl(const Fst<A>& fst,
240 const SymbolTable* old_isymbols,
241 const SymbolTable* new_isymbols,
242 const SymbolTable* old_osymbols,
243 const SymbolTable* new_osymbols,
244 const RelabelFstOptions &opts)
245 : CacheImpl<A>(opts), fst_(fst.Copy()),
246 relabel_input_(false), relabel_output_(false) {
247 SetType("relabel");
248
249 uint64 props = fst.Properties(kCopyProperties, false);
250 SetProperties(RelabelProperties(props));
251 SetInputSymbols(old_isymbols);
252 SetOutputSymbols(old_osymbols);
253
254 if (old_isymbols && new_isymbols &&
255 old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) {
256 for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
257 syms_iter.Next()) {
258 input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol());
259 }
260 SetInputSymbols(new_isymbols);
261 relabel_input_ = true;
262 }
263
264 if (old_osymbols && new_osymbols &&
265 old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) {
266 for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
267 syms_iter.Next()) {
268 output_map_[syms_iter.Value()] =
269 new_osymbols->Find(syms_iter.Symbol());
270 }
271 SetOutputSymbols(new_osymbols);
272 relabel_output_ = true;
273 }
274 }
275
RelabelFstImpl(const RelabelFstImpl<A> & impl)276 RelabelFstImpl(const RelabelFstImpl<A>& impl)
277 : CacheImpl<A>(impl),
278 fst_(impl.fst_->Copy(true)),
279 input_map_(impl.input_map_),
280 output_map_(impl.output_map_),
281 relabel_input_(impl.relabel_input_),
282 relabel_output_(impl.relabel_output_) {
283 SetType("relabel");
284 SetProperties(impl.Properties(), kCopyProperties);
285 SetInputSymbols(impl.InputSymbols());
286 SetOutputSymbols(impl.OutputSymbols());
287 }
288
~RelabelFstImpl()289 ~RelabelFstImpl() { delete fst_; }
290
Start()291 StateId Start() {
292 if (!HasStart()) {
293 StateId s = fst_->Start();
294 SetStart(s);
295 }
296 return CacheImpl<A>::Start();
297 }
298
Final(StateId s)299 Weight Final(StateId s) {
300 if (!HasFinal(s)) {
301 SetFinal(s, fst_->Final(s));
302 }
303 return CacheImpl<A>::Final(s);
304 }
305
NumArcs(StateId s)306 size_t NumArcs(StateId s) {
307 if (!HasArcs(s)) {
308 Expand(s);
309 }
310 return CacheImpl<A>::NumArcs(s);
311 }
312
NumInputEpsilons(StateId s)313 size_t NumInputEpsilons(StateId s) {
314 if (!HasArcs(s)) {
315 Expand(s);
316 }
317 return CacheImpl<A>::NumInputEpsilons(s);
318 }
319
NumOutputEpsilons(StateId s)320 size_t NumOutputEpsilons(StateId s) {
321 if (!HasArcs(s)) {
322 Expand(s);
323 }
324 return CacheImpl<A>::NumOutputEpsilons(s);
325 }
326
Properties()327 uint64 Properties() const { return Properties(kFstProperties); }
328
329 // Set error if found; return FST impl properties.
Properties(uint64 mask)330 uint64 Properties(uint64 mask) const {
331 if ((mask & kError) && fst_->Properties(kError, false))
332 SetProperties(kError, kError);
333 return FstImpl<Arc>::Properties(mask);
334 }
335
InitArcIterator(StateId s,ArcIteratorData<A> * data)336 void InitArcIterator(StateId s, ArcIteratorData<A>* data) {
337 if (!HasArcs(s)) {
338 Expand(s);
339 }
340 CacheImpl<A>::InitArcIterator(s, data);
341 }
342
Expand(StateId s)343 void Expand(StateId s) {
344 for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
345 A arc = aiter.Value();
346
347 // relabel input
348 if (relabel_input_) {
349 typename unordered_map<Label, Label>::iterator it =
350 input_map_.find(arc.ilabel);
351 if (it != input_map_.end()) { arc.ilabel = it->second; }
352 }
353
354 // relabel output
355 if (relabel_output_) {
356 typename unordered_map<Label, Label>::iterator it =
357 output_map_.find(arc.olabel);
358 if (it != output_map_.end()) { arc.olabel = it->second; }
359 }
360
361 PushArc(s, arc);
362 }
363 SetArcs(s);
364 }
365
366
367 private:
368 const Fst<A> *fst_;
369
370 unordered_map<Label, Label> input_map_;
371 unordered_map<Label, Label> output_map_;
372 bool relabel_input_;
373 bool relabel_output_;
374
375 void operator=(const RelabelFstImpl<A> &); // disallow
376 };
377
378
379 //
380 // \class RelabelFst
381 // \brief Delayed implementation of arc relabeling
382 //
383 // This class attaches interface to implementation and handles
384 // reference counting, delegating most methods to ImplToFst.
385 template <class A>
386 class RelabelFst : public ImplToFst< RelabelFstImpl<A> > {
387 public:
388 friend class ArcIterator< RelabelFst<A> >;
389 friend class StateIterator< RelabelFst<A> >;
390
391 typedef A Arc;
392 typedef typename A::Label Label;
393 typedef typename A::Weight Weight;
394 typedef typename A::StateId StateId;
395 typedef CacheState<A> State;
396 typedef RelabelFstImpl<A> Impl;
397
RelabelFst(const Fst<A> & fst,const vector<pair<Label,Label>> & ipairs,const vector<pair<Label,Label>> & opairs)398 RelabelFst(const Fst<A>& fst,
399 const vector<pair<Label, Label> >& ipairs,
400 const vector<pair<Label, Label> >& opairs)
401 : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, RelabelFstOptions())) {}
402
RelabelFst(const Fst<A> & fst,const vector<pair<Label,Label>> & ipairs,const vector<pair<Label,Label>> & opairs,const RelabelFstOptions & opts)403 RelabelFst(const Fst<A>& fst,
404 const vector<pair<Label, Label> >& ipairs,
405 const vector<pair<Label, Label> >& opairs,
406 const RelabelFstOptions &opts)
407 : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, opts)) {}
408
RelabelFst(const Fst<A> & fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols)409 RelabelFst(const Fst<A>& fst,
410 const SymbolTable* new_isymbols,
411 const SymbolTable* new_osymbols)
412 : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols,
413 fst.OutputSymbols(), new_osymbols,
414 RelabelFstOptions())) {}
415
RelabelFst(const Fst<A> & fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols,const RelabelFstOptions & opts)416 RelabelFst(const Fst<A>& fst,
417 const SymbolTable* new_isymbols,
418 const SymbolTable* new_osymbols,
419 const RelabelFstOptions &opts)
420 : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols,
421 fst.OutputSymbols(), new_osymbols, opts)) {}
422
RelabelFst(const Fst<A> & fst,const SymbolTable * old_isymbols,const SymbolTable * new_isymbols,const SymbolTable * old_osymbols,const SymbolTable * new_osymbols)423 RelabelFst(const Fst<A>& fst,
424 const SymbolTable* old_isymbols,
425 const SymbolTable* new_isymbols,
426 const SymbolTable* old_osymbols,
427 const SymbolTable* new_osymbols)
428 : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols,
429 new_osymbols, RelabelFstOptions())) {}
430
RelabelFst(const Fst<A> & fst,const SymbolTable * old_isymbols,const SymbolTable * new_isymbols,const SymbolTable * old_osymbols,const SymbolTable * new_osymbols,const RelabelFstOptions & opts)431 RelabelFst(const Fst<A>& fst,
432 const SymbolTable* old_isymbols,
433 const SymbolTable* new_isymbols,
434 const SymbolTable* old_osymbols,
435 const SymbolTable* new_osymbols,
436 const RelabelFstOptions &opts)
437 : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols,
438 new_osymbols, opts)) {}
439
440 // See Fst<>::Copy() for doc.
441 RelabelFst(const RelabelFst<A> &fst, bool safe = false)
442 : ImplToFst<Impl>(fst, safe) {}
443
444 // Get a copy of this RelabelFst. See Fst<>::Copy() for further doc.
445 virtual RelabelFst<A> *Copy(bool safe = false) const {
446 return new RelabelFst<A>(*this, safe);
447 }
448
449 virtual void InitStateIterator(StateIteratorData<A> *data) const;
450
InitArcIterator(StateId s,ArcIteratorData<A> * data)451 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
452 return GetImpl()->InitArcIterator(s, data);
453 }
454
455 private:
456 // Makes visible to friends.
GetImpl()457 Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
458
459 void operator=(const RelabelFst<A> &fst); // disallow
460 };
461
462 // Specialization for RelabelFst.
463 template<class A>
464 class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> {
465 public:
466 typedef typename A::StateId StateId;
467
StateIterator(const RelabelFst<A> & fst)468 explicit StateIterator(const RelabelFst<A> &fst)
469 : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {}
470
Done()471 bool Done() const { return siter_.Done(); }
472
Value()473 StateId Value() const { return s_; }
474
Next()475 void Next() {
476 if (!siter_.Done()) {
477 ++s_;
478 siter_.Next();
479 }
480 }
481
Reset()482 void Reset() {
483 s_ = 0;
484 siter_.Reset();
485 }
486
487 private:
Done_()488 bool Done_() const { return Done(); }
Value_()489 StateId Value_() const { return Value(); }
Next_()490 void Next_() { Next(); }
Reset_()491 void Reset_() { Reset(); }
492
493 const RelabelFstImpl<A> *impl_;
494 StateIterator< Fst<A> > siter_;
495 StateId s_;
496
497 DISALLOW_COPY_AND_ASSIGN(StateIterator);
498 };
499
500
501 // Specialization for RelabelFst.
502 template <class A>
503 class ArcIterator< RelabelFst<A> >
504 : public CacheArcIterator< RelabelFst<A> > {
505 public:
506 typedef typename A::StateId StateId;
507
ArcIterator(const RelabelFst<A> & fst,StateId s)508 ArcIterator(const RelabelFst<A> &fst, StateId s)
509 : CacheArcIterator< RelabelFst<A> >(fst.GetImpl(), s) {
510 if (!fst.GetImpl()->HasArcs(s))
511 fst.GetImpl()->Expand(s);
512 }
513
514 private:
515 DISALLOW_COPY_AND_ASSIGN(ArcIterator);
516 };
517
518 template <class A> inline
InitStateIterator(StateIteratorData<A> * data)519 void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
520 data->base = new StateIterator< RelabelFst<A> >(*this);
521 }
522
523 // Useful alias when using StdArc.
524 typedef RelabelFst<StdArc> StdRelabelFst;
525
526 } // namespace fst
527
528 #endif // FST_LIB_RELABEL_H__
529