1 // relabel.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 //
16 // \file
17 // Functions and classes to relabel an Fst (either on input or output)
18 //
19 #ifndef FST_LIB_RELABEL_H__
20 #define FST_LIB_RELABEL_H__
21
22 #include <unordered_map>
23
24 #include "fst/lib/cache.h"
25 #include "fst/lib/test-properties.h"
26
27
28 namespace fst {
29
30 //
31 // Relabels either the input labels or output labels. The old to
32 // new labels are specified using a vector of pair<Label,Label>.
33 // Any label associations not specified are assumed to be identity
34 // mapping.
35 //
36 // \param fst input fst, must be mutable
37 // \param relabel_pairs vector of pairs indicating old to new mapping
38 // \param relabel_flags whether to relabel input or output
39 //
40 template <class A>
Relabel(MutableFst<A> * fst,const vector<pair<typename A::Label,typename A::Label>> & ipairs,const vector<pair<typename A::Label,typename A::Label>> & opairs)41 void Relabel(
42 MutableFst<A> *fst,
43 const vector<pair<typename A::Label, typename A::Label> >& ipairs,
44 const vector<pair<typename A::Label, typename A::Label> >& opairs) {
45 typedef typename A::StateId StateId;
46 typedef typename A::Label Label;
47
48 uint64 props = fst->Properties(kFstProperties, false);
49
50 // construct label to label hash. Could
51 std::unordered_map<Label, Label> input_map;
52 for (size_t i = 0; i < ipairs.size(); ++i) {
53 input_map[ipairs[i].first] = ipairs[i].second;
54 }
55
56 std::unordered_map<Label, Label> output_map;
57 for (size_t i = 0; i < opairs.size(); ++i) {
58 output_map[opairs[i].first] = opairs[i].second;
59 }
60
61 for (StateIterator<MutableFst<A> > siter(*fst);
62 !siter.Done(); siter.Next()) {
63 StateId s = siter.Value();
64 for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
65 !aiter.Done(); aiter.Next()) {
66 A arc = aiter.Value();
67
68 // relabel input
69 // only relabel if relabel pair defined
70 typename std::unordered_map<Label, Label>::iterator it =
71 input_map.find(arc.ilabel);
72 if (it != input_map.end()) {arc.ilabel = it->second; }
73
74 // relabel output
75 it = output_map.find(arc.olabel);
76 if (it != output_map.end()) { arc.olabel = it->second; }
77
78 aiter.SetValue(arc);
79 }
80 }
81
82 fst->SetProperties(RelabelProperties(props), kFstProperties);
83 }
84
85
86
87 //
88 // Relabels either the input labels or output labels. The old to
89 // new labels mappings are specified using an input Symbol set.
90 // Any label associations not specified are assumed to be identity
91 // mapping.
92 //
93 // \param fst input fst, must be mutable
94 // \param new_symbols symbol set indicating new mapping
95 // \param relabel_flags whether to relabel input or output
96 //
97 template<class A>
Relabel(MutableFst<A> * fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols)98 void Relabel(MutableFst<A> *fst,
99 const SymbolTable* new_isymbols,
100 const SymbolTable* new_osymbols) {
101 typedef typename A::StateId StateId;
102 typedef typename A::Label Label;
103
104 const SymbolTable* old_isymbols = fst->InputSymbols();
105 const SymbolTable* old_osymbols = fst->OutputSymbols();
106
107 vector<pair<Label, Label> > ipairs;
108 if (old_isymbols && new_isymbols) {
109 for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
110 syms_iter.Next()) {
111 ipairs.push_back(make_pair(syms_iter.Value(),
112 new_isymbols->Find(syms_iter.Symbol())));
113 }
114 fst->SetInputSymbols(new_isymbols);
115 }
116
117 vector<pair<Label, Label> > opairs;
118 if (old_osymbols && new_osymbols) {
119 for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
120 syms_iter.Next()) {
121 opairs.push_back(make_pair(syms_iter.Value(),
122 new_osymbols->Find(syms_iter.Symbol())));
123 }
124 fst->SetOutputSymbols(new_osymbols);
125 }
126
127 // call relabel using vector of relabel pairs.
128 Relabel(fst, ipairs, opairs);
129 }
130
131
132 typedef CacheOptions RelabelFstOptions;
133
134 template <class A> class RelabelFst;
135
136 //
137 // \class RelabelFstImpl
138 // \brief Implementation for delayed relabeling
139 //
140 // Relabels an FST from one symbol set to another. Relabeling
141 // can either be on input or output space. RelabelFst implements
142 // a delayed version of the relabel. Arcs are relabeled on the fly
143 // and not cached. I.e each request is recomputed.
144 //
145 template<class A>
146 class RelabelFstImpl : public CacheImpl<A> {
147 friend class StateIterator< RelabelFst<A> >;
148 public:
149 using FstImpl<A>::SetType;
150 using FstImpl<A>::SetProperties;
151 using FstImpl<A>::Properties;
152 using FstImpl<A>::SetInputSymbols;
153 using FstImpl<A>::SetOutputSymbols;
154
155 using CacheImpl<A>::HasStart;
156 using CacheImpl<A>::HasArcs;
157
158 typedef typename A::Label Label;
159 typedef typename A::Weight Weight;
160 typedef typename A::StateId StateId;
161 typedef CacheState<A> State;
162
RelabelFstImpl(const Fst<A> & fst,const vector<pair<Label,Label>> & ipairs,const vector<pair<Label,Label>> & opairs,const RelabelFstOptions & opts)163 RelabelFstImpl(const Fst<A>& fst,
164 const vector<pair<Label, Label> >& ipairs,
165 const vector<pair<Label, Label> >& opairs,
166 const RelabelFstOptions &opts)
167 : CacheImpl<A>(opts), fst_(fst.Copy()),
168 relabel_input_(false), relabel_output_(false) {
169 uint64 props = fst.Properties(kCopyProperties, false);
170 SetProperties(RelabelProperties(props));
171 SetType("relabel");
172
173 // create input label map
174 if (ipairs.size() > 0) {
175 for (size_t i = 0; i < ipairs.size(); ++i) {
176 input_map_[ipairs[i].first] = ipairs[i].second;
177 }
178 relabel_input_ = true;
179 }
180
181 // create output label map
182 if (opairs.size() > 0) {
183 for (size_t i = 0; i < opairs.size(); ++i) {
184 output_map_[opairs[i].first] = opairs[i].second;
185 }
186 relabel_output_ = true;
187 }
188 }
189
RelabelFstImpl(const Fst<A> & fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols,const RelabelFstOptions & opts)190 RelabelFstImpl(const Fst<A>& fst,
191 const SymbolTable* new_isymbols,
192 const SymbolTable* new_osymbols,
193 const RelabelFstOptions &opts)
194 : CacheImpl<A>(opts), fst_(fst.Copy()),
195 relabel_input_(false), relabel_output_(false) {
196 SetType("relabel");
197
198 uint64 props = fst.Properties(kCopyProperties, false);
199 SetProperties(RelabelProperties(props));
200 SetInputSymbols(fst.InputSymbols());
201 SetOutputSymbols(fst.OutputSymbols());
202
203 const SymbolTable* old_isymbols = fst.InputSymbols();
204 const SymbolTable* old_osymbols = fst.OutputSymbols();
205
206 if (old_isymbols && new_isymbols &&
207 old_isymbols->CheckSum() != new_isymbols->CheckSum()) {
208 for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
209 syms_iter.Next()) {
210 input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol());
211 }
212 SetInputSymbols(new_isymbols);
213 relabel_input_ = true;
214 }
215
216 if (old_osymbols && new_osymbols &&
217 old_osymbols->CheckSum() != new_osymbols->CheckSum()) {
218 for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
219 syms_iter.Next()) {
220 output_map_[syms_iter.Value()] =
221 new_osymbols->Find(syms_iter.Symbol());
222 }
223 SetOutputSymbols(new_osymbols);
224 relabel_output_ = true;
225 }
226 }
227
~RelabelFstImpl()228 ~RelabelFstImpl() { delete fst_; }
229
Start()230 StateId Start() {
231 if (!HasStart()) {
232 StateId s = fst_->Start();
233 SetStart(s);
234 }
235 return CacheImpl<A>::Start();
236 }
237
Final(StateId s)238 Weight Final(StateId s) {
239 if (!HasFinal(s)) {
240 SetFinal(s, fst_->Final(s));
241 }
242 return CacheImpl<A>::Final(s);
243 }
244
NumArcs(StateId s)245 size_t NumArcs(StateId s) {
246 if (!HasArcs(s)) {
247 Expand(s);
248 }
249 return CacheImpl<A>::NumArcs(s);
250 }
251
NumInputEpsilons(StateId s)252 size_t NumInputEpsilons(StateId s) {
253 if (!HasArcs(s)) {
254 Expand(s);
255 }
256 return CacheImpl<A>::NumInputEpsilons(s);
257 }
258
NumOutputEpsilons(StateId s)259 size_t NumOutputEpsilons(StateId s) {
260 if (!HasArcs(s)) {
261 Expand(s);
262 }
263 return CacheImpl<A>::NumOutputEpsilons(s);
264 }
265
InitArcIterator(StateId s,ArcIteratorData<A> * data)266 void InitArcIterator(StateId s, ArcIteratorData<A>* data) {
267 if (!HasArcs(s)) {
268 Expand(s);
269 }
270 CacheImpl<A>::InitArcIterator(s, data);
271 }
272
Expand(StateId s)273 void Expand(StateId s) {
274 for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
275 A arc = aiter.Value();
276
277 // relabel input
278 if (relabel_input_) {
279 typename std::unordered_map<Label, Label>::iterator it =
280 input_map_.find(arc.ilabel);
281 if (it != input_map_.end()) { arc.ilabel = it->second; }
282 }
283
284 // relabel output
285 if (relabel_output_) {
286 typename std::unordered_map<Label, Label>::iterator it =
287 output_map_.find(arc.olabel);
288 if (it != output_map_.end()) { arc.olabel = it->second; }
289 }
290
291 AddArc(s, arc);
292 }
293 SetArcs(s);
294 }
295
296
297 private:
298 const Fst<A> *fst_;
299
300 std::unordered_map<Label, Label> input_map_;
301 std::unordered_map<Label, Label> output_map_;
302 bool relabel_input_;
303 bool relabel_output_;
304
305 DISALLOW_EVIL_CONSTRUCTORS(RelabelFstImpl);
306 };
307
308
309 //
310 // \class RelabelFst
311 // \brief Delayed implementation of arc relabeling
312 //
313 // This class attaches interface to implementation and handles
314 // reference counting.
315 template <class A>
316 class RelabelFst : public Fst<A> {
317 public:
318 friend class ArcIterator< RelabelFst<A> >;
319 friend class StateIterator< RelabelFst<A> >;
320 friend class CacheArcIterator< RelabelFst<A> >;
321
322 typedef A Arc;
323 typedef typename A::Label Label;
324 typedef typename A::Weight Weight;
325 typedef typename A::StateId StateId;
326 typedef CacheState<A> State;
327
RelabelFst(const Fst<A> & fst,const vector<pair<Label,Label>> & ipairs,const vector<pair<Label,Label>> & opairs)328 RelabelFst(const Fst<A>& fst,
329 const vector<pair<Label, Label> >& ipairs,
330 const vector<pair<Label, Label> >& opairs) :
331 impl_(new RelabelFstImpl<A>(fst, ipairs, opairs, RelabelFstOptions())) {}
332
RelabelFst(const Fst<A> & fst,const vector<pair<Label,Label>> & ipairs,const vector<pair<Label,Label>> & opairs,const RelabelFstOptions & opts)333 RelabelFst(const Fst<A>& fst,
334 const vector<pair<Label, Label> >& ipairs,
335 const vector<pair<Label, Label> >& opairs,
336 const RelabelFstOptions &opts)
337 : impl_(new RelabelFstImpl<A>(fst, ipairs, opairs, opts)) {}
338
RelabelFst(const Fst<A> & fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols)339 RelabelFst(const Fst<A>& fst,
340 const SymbolTable* new_isymbols,
341 const SymbolTable* new_osymbols) :
342 impl_(new RelabelFstImpl<A>(fst, new_isymbols, new_osymbols,
343 RelabelFstOptions())) {}
344
RelabelFst(const Fst<A> & fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols,const RelabelFstOptions & opts)345 RelabelFst(const Fst<A>& fst,
346 const SymbolTable* new_isymbols,
347 const SymbolTable* new_osymbols,
348 const RelabelFstOptions &opts)
349 : impl_(new RelabelFstImpl<A>(fst, new_isymbols, new_osymbols, opts)) {}
350
RelabelFst(const RelabelFst<A> & fst)351 RelabelFst(const RelabelFst<A> &fst) : impl_(fst.impl_) {
352 impl_->IncrRefCount();
353 }
354
~RelabelFst()355 virtual ~RelabelFst() { if (!impl_->DecrRefCount()) delete impl_; }
356
Start()357 virtual StateId Start() const { return impl_->Start(); }
358
Final(StateId s)359 virtual Weight Final(StateId s) const { return impl_->Final(s); }
360
NumArcs(StateId s)361 virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
362
NumInputEpsilons(StateId s)363 virtual size_t NumInputEpsilons(StateId s) const {
364 return impl_->NumInputEpsilons(s);
365 }
366
NumOutputEpsilons(StateId s)367 virtual size_t NumOutputEpsilons(StateId s) const {
368 return impl_->NumOutputEpsilons(s);
369 }
370
Properties(uint64 mask,bool test)371 virtual uint64 Properties(uint64 mask, bool test) const {
372 if (test) {
373 uint64 known, test = TestProperties(*this, mask, &known);
374 impl_->SetProperties(test, known);
375 return test & mask;
376 } else {
377 return impl_->Properties(mask);
378 }
379 }
380
Type()381 virtual const string& Type() const { return impl_->Type(); }
382
Copy()383 virtual RelabelFst<A> *Copy() const {
384 return new RelabelFst<A>(*this);
385 }
386
InputSymbols()387 virtual const SymbolTable* InputSymbols() const {
388 return impl_->InputSymbols();
389 }
390
OutputSymbols()391 virtual const SymbolTable* OutputSymbols() const {
392 return impl_->OutputSymbols();
393 }
394
395 virtual void InitStateIterator(StateIteratorData<A> *data) const;
396
InitArcIterator(StateId s,ArcIteratorData<A> * data)397 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
398 return impl_->InitArcIterator(s, data);
399 }
400
401 private:
402 RelabelFstImpl<A> *impl_;
403
404 void operator=(const RelabelFst<A> &fst); // disallow
405 };
406
407 // Specialization for RelabelFst.
408 template<class A>
409 class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> {
410 public:
411 typedef typename A::StateId StateId;
412
StateIterator(const RelabelFst<A> & fst)413 explicit StateIterator(const RelabelFst<A> &fst)
414 : impl_(fst.impl_), siter_(*impl_->fst_), s_(0) {}
415
Done()416 bool Done() const { return siter_.Done(); }
417
Value()418 StateId Value() const { return s_; }
419
Next()420 void Next() {
421 if (!siter_.Done()) {
422 ++s_;
423 siter_.Next();
424 }
425 }
426
Reset()427 void Reset() {
428 s_ = 0;
429 siter_.Reset();
430 }
431
432 private:
433 const RelabelFstImpl<A> *impl_;
434 StateIterator< Fst<A> > siter_;
435 StateId s_;
436
437 DISALLOW_EVIL_CONSTRUCTORS(StateIterator);
438 };
439
440
441 // Specialization for RelabelFst.
442 template <class A>
443 class ArcIterator< RelabelFst<A> >
444 : public CacheArcIterator< RelabelFst<A> > {
445 public:
446 typedef typename A::StateId StateId;
447
ArcIterator(const RelabelFst<A> & fst,StateId s)448 ArcIterator(const RelabelFst<A> &fst, StateId s)
449 : CacheArcIterator< RelabelFst<A> >(fst, s) {
450 if (!fst.impl_->HasArcs(s))
451 fst.impl_->Expand(s);
452 }
453
454 private:
455 DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
456 };
457
458 template <class A> inline
InitStateIterator(StateIteratorData<A> * data)459 void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
460 data->base = new StateIterator< RelabelFst<A> >(*this);
461 }
462
463 // Useful alias when using StdArc.
464 typedef RelabelFst<StdArc> StdRelabelFst;
465
466 } // namespace fst
467
468 #endif // FST_LIB_RELABEL_H__
469