1 // encode.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: johans@google.com (Johan Schalkwyk)
17 //
18 // \file
// Classes to encode and decode an Fst.
20
21 #ifndef FST_LIB_ENCODE_H__
22 #define FST_LIB_ENCODE_H__
23
24 #include <climits>
25 #include <tr1/unordered_map>
26 using std::tr1::unordered_map;
27 using std::tr1::unordered_multimap;
28 #include <string>
29 #include <vector>
30 using std::vector;
31
32 #include <fst/arc-map.h>
33 #include <fst/rmfinalepsilon.h>
34
35
36 namespace fst {
37
38 static const uint32 kEncodeLabels = 0x0001;
39 static const uint32 kEncodeWeights = 0x0002;
40 static const uint32 kEncodeFlags = 0x0003; // All non-internal flags
41
42 static const uint32 kEncodeHasISymbols = 0x0004; // For internal use
43 static const uint32 kEncodeHasOSymbols = 0x0008; // For internal use
44
45 enum EncodeType { ENCODE = 1, DECODE = 2 };
46
47 // Identifies stream data as an encode table (and its endianity)
48 static const int32 kEncodeMagicNumber = 2129983209;
49
50
// The following class encapsulates the implementation details of
// encoding and decoding the label/weight tuples used when encoding
// and decoding Fsts. The EncodeTable is bidirectional, i.e., it
// stores both the mapping from a tuple of labels and weight to a
// unique label, and the reverse mapping.
56 template <class A> class EncodeTable {
57 public:
58 typedef typename A::Label Label;
59 typedef typename A::Weight Weight;
60
61 // Encoded data consists of arc input/output labels and arc weight
62 struct Tuple {
TupleTuple63 Tuple() {}
TupleTuple64 Tuple(Label ilabel_, Label olabel_, Weight weight_)
65 : ilabel(ilabel_), olabel(olabel_), weight(weight_) {}
TupleTuple66 Tuple(const Tuple& tuple)
67 : ilabel(tuple.ilabel), olabel(tuple.olabel), weight(tuple.weight) {}
68
69 Label ilabel;
70 Label olabel;
71 Weight weight;
72 };
73
74 // Comparison object for hashing EncodeTable Tuple(s).
75 class TupleEqual {
76 public:
operator()77 bool operator()(const Tuple* x, const Tuple* y) const {
78 return (x->ilabel == y->ilabel &&
79 x->olabel == y->olabel &&
80 x->weight == y->weight);
81 }
82 };
83
84 // Hash function for EncodeTabe Tuples. Based on the encode flags
85 // we either hash the labels, weights or combination of them.
86 class TupleKey {
87 public:
TupleKey()88 TupleKey()
89 : encode_flags_(kEncodeLabels | kEncodeWeights) {}
90
TupleKey(const TupleKey & key)91 TupleKey(const TupleKey& key)
92 : encode_flags_(key.encode_flags_) {}
93
TupleKey(uint32 encode_flags)94 explicit TupleKey(uint32 encode_flags)
95 : encode_flags_(encode_flags) {}
96
operator()97 size_t operator()(const Tuple* x) const {
98 size_t hash = x->ilabel;
99 const int lshift = 5;
100 const int rshift = CHAR_BIT * sizeof(size_t) - 5;
101 if (encode_flags_ & kEncodeLabels)
102 hash = hash << lshift ^ hash >> rshift ^ x->olabel;
103 if (encode_flags_ & kEncodeWeights)
104 hash = hash << lshift ^ hash >> rshift ^ x->weight.Hash();
105 return hash;
106 }
107
108 private:
109 int32 encode_flags_;
110 };
111
112 typedef unordered_map<const Tuple*,
113 Label,
114 TupleKey,
115 TupleEqual> EncodeHash;
116
EncodeTable(uint32 encode_flags)117 explicit EncodeTable(uint32 encode_flags)
118 : flags_(encode_flags),
119 encode_hash_(1024, TupleKey(encode_flags)),
120 isymbols_(0), osymbols_(0) {}
121
~EncodeTable()122 ~EncodeTable() {
123 for (size_t i = 0; i < encode_tuples_.size(); ++i) {
124 delete encode_tuples_[i];
125 }
126 delete isymbols_;
127 delete osymbols_;
128 }
129
130 // Given an arc encode either input/ouptut labels or input/costs or both
Encode(const A & arc)131 Label Encode(const A &arc) {
132 const Tuple tuple(arc.ilabel,
133 flags_ & kEncodeLabels ? arc.olabel : 0,
134 flags_ & kEncodeWeights ? arc.weight : Weight::One());
135 typename EncodeHash::const_iterator it = encode_hash_.find(&tuple);
136 if (it == encode_hash_.end()) {
137 encode_tuples_.push_back(new Tuple(tuple));
138 encode_hash_[encode_tuples_.back()] = encode_tuples_.size();
139 return encode_tuples_.size();
140 } else {
141 return it->second;
142 }
143 }
144
145 // Given an arc, look up its encoded label. Returns kNoLabel if not found.
GetLabel(const A & arc)146 Label GetLabel(const A &arc) const {
147 const Tuple tuple(arc.ilabel,
148 flags_ & kEncodeLabels ? arc.olabel : 0,
149 flags_ & kEncodeWeights ? arc.weight : Weight::One());
150 typename EncodeHash::const_iterator it = encode_hash_.find(&tuple);
151 if (it == encode_hash_.end()) {
152 return kNoLabel;
153 } else {
154 return it->second;
155 }
156 }
157
158 // Given an encode arc Label decode back to input/output labels and costs
Decode(Label key)159 const Tuple* Decode(Label key) const {
160 if (key < 1 || key > encode_tuples_.size()) {
161 LOG(ERROR) << "EncodeTable::Decode: unknown decode key: " << key;
162 return 0;
163 }
164 return encode_tuples_[key - 1];
165 }
166
Size()167 size_t Size() const { return encode_tuples_.size(); }
168
169 bool Write(ostream &strm, const string &source) const;
170
171 static EncodeTable<A> *Read(istream &strm, const string &source);
172
flags()173 const uint32 flags() const { return flags_ & kEncodeFlags; }
174
RefCount()175 int RefCount() const { return ref_count_.count(); }
IncrRefCount()176 int IncrRefCount() { return ref_count_.Incr(); }
DecrRefCount()177 int DecrRefCount() { return ref_count_.Decr(); }
178
179
InputSymbols()180 SymbolTable *InputSymbols() const { return isymbols_; }
181
OutputSymbols()182 SymbolTable *OutputSymbols() const { return osymbols_; }
183
SetInputSymbols(const SymbolTable * syms)184 void SetInputSymbols(const SymbolTable* syms) {
185 if (isymbols_) delete isymbols_;
186 if (syms) {
187 isymbols_ = syms->Copy();
188 flags_ |= kEncodeHasISymbols;
189 } else {
190 isymbols_ = 0;
191 flags_ &= ~kEncodeHasISymbols;
192 }
193 }
194
SetOutputSymbols(const SymbolTable * syms)195 void SetOutputSymbols(const SymbolTable* syms) {
196 if (osymbols_) delete osymbols_;
197 if (syms) {
198 osymbols_ = syms->Copy();
199 flags_ |= kEncodeHasOSymbols;
200 } else {
201 osymbols_ = 0;
202 flags_ &= ~kEncodeHasOSymbols;
203 }
204 }
205
206 private:
207 uint32 flags_;
208 vector<Tuple*> encode_tuples_;
209 EncodeHash encode_hash_;
210 RefCounter ref_count_;
211 SymbolTable *isymbols_; // Pre-encoded ilabel symbol table
212 SymbolTable *osymbols_; // Pre-encoded olabel symbol table
213
214 DISALLOW_COPY_AND_ASSIGN(EncodeTable);
215 };
216
217 template <class A> inline
Write(ostream & strm,const string & source)218 bool EncodeTable<A>::Write(ostream &strm, const string &source) const {
219 WriteType(strm, kEncodeMagicNumber);
220 WriteType(strm, flags_);
221 int64 size = encode_tuples_.size();
222 WriteType(strm, size);
223 for (size_t i = 0; i < size; ++i) {
224 const Tuple* tuple = encode_tuples_[i];
225 WriteType(strm, tuple->ilabel);
226 WriteType(strm, tuple->olabel);
227 tuple->weight.Write(strm);
228 }
229
230 if (flags_ & kEncodeHasISymbols)
231 isymbols_->Write(strm);
232
233 if (flags_ & kEncodeHasOSymbols)
234 osymbols_->Write(strm);
235
236 strm.flush();
237 if (!strm) {
238 LOG(ERROR) << "EncodeTable::Write: write failed: " << source;
239 return false;
240 }
241 return true;
242 }
243
244 template <class A> inline
Read(istream & strm,const string & source)245 EncodeTable<A> *EncodeTable<A>::Read(istream &strm, const string &source) {
246 int32 magic_number = 0;
247 ReadType(strm, &magic_number);
248 if (magic_number != kEncodeMagicNumber) {
249 LOG(ERROR) << "EncodeTable::Read: Bad encode table header: " << source;
250 return 0;
251 }
252 uint32 flags;
253 ReadType(strm, &flags);
254 EncodeTable<A> *table = new EncodeTable<A>(flags);
255
256 int64 size;
257 ReadType(strm, &size);
258 if (!strm) {
259 LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
260 return 0;
261 }
262
263 for (size_t i = 0; i < size; ++i) {
264 Tuple* tuple = new Tuple();
265 ReadType(strm, &tuple->ilabel);
266 ReadType(strm, &tuple->olabel);
267 tuple->weight.Read(strm);
268 if (!strm) {
269 LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
270 return 0;
271 }
272 table->encode_tuples_.push_back(tuple);
273 table->encode_hash_[table->encode_tuples_.back()] =
274 table->encode_tuples_.size();
275 }
276
277 if (flags & kEncodeHasISymbols)
278 table->isymbols_ = SymbolTable::Read(strm, source);
279
280 if (flags & kEncodeHasOSymbols)
281 table->osymbols_ = SymbolTable::Read(strm, source);
282
283 return table;
284 }
285
286
287 // A mapper to encode/decode weighted transducers. Encoding of an
288 // Fst is useful for performing classical determinization or minimization
289 // on a weighted transducer by treating it as an unweighted acceptor over
290 // encoded labels.
291 //
292 // The Encode mapper stores the encoding in a local hash table (EncodeTable)
293 // This table is shared (and reference counted) between the encoder and
294 // decoder. A decoder has read only access to the EncodeTable.
295 //
// The EncodeMapper allows on-the-fly encoding of the machine. As the
// EncodeTable is generated, the same table may be used to decode the machine
// on the fly. For example, in the following sequence of operations
299 //
300 // Encode -> Determinize -> Decode
301 //
302 // we will use the encoding table generated during the encode step in the
303 // decode, even though the encoding is not complete.
304 //
305 template <class A> class EncodeMapper {
306 typedef typename A::Weight Weight;
307 typedef typename A::Label Label;
308 public:
EncodeMapper(uint32 flags,EncodeType type)309 EncodeMapper(uint32 flags, EncodeType type)
310 : flags_(flags),
311 type_(type),
312 table_(new EncodeTable<A>(flags)),
313 error_(false) {}
314
EncodeMapper(const EncodeMapper & mapper)315 EncodeMapper(const EncodeMapper& mapper)
316 : flags_(mapper.flags_),
317 type_(mapper.type_),
318 table_(mapper.table_),
319 error_(false) {
320 table_->IncrRefCount();
321 }
322
323 // Copy constructor but setting the type, typically to DECODE
EncodeMapper(const EncodeMapper & mapper,EncodeType type)324 EncodeMapper(const EncodeMapper& mapper, EncodeType type)
325 : flags_(mapper.flags_),
326 type_(type),
327 table_(mapper.table_),
328 error_(mapper.error_) {
329 table_->IncrRefCount();
330 }
331
~EncodeMapper()332 ~EncodeMapper() {
333 if (!table_->DecrRefCount()) delete table_;
334 }
335
336 A operator()(const A &arc);
337
FinalAction()338 MapFinalAction FinalAction() const {
339 return (type_ == ENCODE && (flags_ & kEncodeWeights)) ?
340 MAP_REQUIRE_SUPERFINAL : MAP_NO_SUPERFINAL;
341 }
342
InputSymbolsAction()343 MapSymbolsAction InputSymbolsAction() const { return MAP_CLEAR_SYMBOLS; }
344
OutputSymbolsAction()345 MapSymbolsAction OutputSymbolsAction() const { return MAP_CLEAR_SYMBOLS;}
346
Properties(uint64 inprops)347 uint64 Properties(uint64 inprops) {
348 uint64 outprops = inprops;
349 if (error_) outprops |= kError;
350
351 uint64 mask = kFstProperties;
352 if (flags_ & kEncodeLabels)
353 mask &= kILabelInvariantProperties & kOLabelInvariantProperties;
354 if (flags_ & kEncodeWeights)
355 mask &= kILabelInvariantProperties & kWeightInvariantProperties &
356 (type_ == ENCODE ? kAddSuperFinalProperties :
357 kRmSuperFinalProperties);
358
359 return outprops & mask;
360 }
361
flags()362 const uint32 flags() const { return flags_; }
type()363 const EncodeType type() const { return type_; }
table()364 const EncodeTable<A> &table() const { return *table_; }
365
Write(ostream & strm,const string & source)366 bool Write(ostream &strm, const string& source) {
367 return table_->Write(strm, source);
368 }
369
Write(const string & filename)370 bool Write(const string& filename) {
371 ofstream strm(filename.c_str(), ofstream::out | ofstream::binary);
372 if (!strm) {
373 LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
374 return false;
375 }
376 return Write(strm, filename);
377 }
378
379 static EncodeMapper<A> *Read(istream &strm,
380 const string& source,
381 EncodeType type = ENCODE) {
382 EncodeTable<A> *table = EncodeTable<A>::Read(strm, source);
383 return table ? new EncodeMapper(table->flags(), type, table) : 0;
384 }
385
386 static EncodeMapper<A> *Read(const string& filename,
387 EncodeType type = ENCODE) {
388 ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
389 if (!strm) {
390 LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
391 return NULL;
392 }
393 return Read(strm, filename, type);
394 }
395
InputSymbols()396 SymbolTable *InputSymbols() const { return table_->InputSymbols(); }
397
OutputSymbols()398 SymbolTable *OutputSymbols() const { return table_->OutputSymbols(); }
399
SetInputSymbols(const SymbolTable * syms)400 void SetInputSymbols(const SymbolTable* syms) {
401 table_->SetInputSymbols(syms);
402 }
403
SetOutputSymbols(const SymbolTable * syms)404 void SetOutputSymbols(const SymbolTable* syms) {
405 table_->SetOutputSymbols(syms);
406 }
407
408 private:
409 uint32 flags_;
410 EncodeType type_;
411 EncodeTable<A>* table_;
412 bool error_;
413
EncodeMapper(uint32 flags,EncodeType type,EncodeTable<A> * table)414 explicit EncodeMapper(uint32 flags, EncodeType type, EncodeTable<A> *table)
415 : flags_(flags), type_(type), table_(table) {}
416 void operator=(const EncodeMapper &); // Disallow.
417 };
418
419 template <class A> inline
operator()420 A EncodeMapper<A>::operator()(const A &arc) {
421 if (type_ == ENCODE) { // labels and/or weights to single label
422 if ((arc.nextstate == kNoStateId && !(flags_ & kEncodeWeights)) ||
423 (arc.nextstate == kNoStateId && (flags_ & kEncodeWeights) &&
424 arc.weight == Weight::Zero())) {
425 return arc;
426 } else {
427 Label label = table_->Encode(arc);
428 return A(label,
429 flags_ & kEncodeLabels ? label : arc.olabel,
430 flags_ & kEncodeWeights ? Weight::One() : arc.weight,
431 arc.nextstate);
432 }
433 } else { // type_ == DECODE
434 if (arc.nextstate == kNoStateId) {
435 return arc;
436 } else {
437 if (arc.ilabel == 0) return arc;
438 if (flags_ & kEncodeLabels && arc.ilabel != arc.olabel) {
439 FSTERROR() << "EncodeMapper: Label-encoded arc has different "
440 "input and output labels";
441 error_ = true;
442 }
443 if (flags_ & kEncodeWeights && arc.weight != Weight::One()) {
444 FSTERROR() <<
445 "EncodeMapper: Weight-encoded arc has non-trivial weight";
446 error_ = true;
447 }
448 const typename EncodeTable<A>::Tuple* tuple = table_->Decode(arc.ilabel);
449 if (!tuple) {
450 FSTERROR() << "EncodeMapper: decode failed";
451 error_ = true;
452 return A(kNoLabel, kNoLabel, Weight::NoWeight(), arc.nextstate);
453 } else {
454 return A(tuple->ilabel,
455 flags_ & kEncodeLabels ? tuple->olabel : arc.olabel,
456 flags_ & kEncodeWeights ? tuple->weight : arc.weight,
457 arc.nextstate);
458 }
459 }
460 }
461 }
462
463
464 // Complexity: O(nstates + narcs)
465 template<class A> inline
Encode(MutableFst<A> * fst,EncodeMapper<A> * mapper)466 void Encode(MutableFst<A> *fst, EncodeMapper<A>* mapper) {
467 mapper->SetInputSymbols(fst->InputSymbols());
468 mapper->SetOutputSymbols(fst->OutputSymbols());
469 ArcMap(fst, mapper);
470 }
471
472 template<class A> inline
Decode(MutableFst<A> * fst,const EncodeMapper<A> & mapper)473 void Decode(MutableFst<A>* fst, const EncodeMapper<A>& mapper) {
474 ArcMap(fst, EncodeMapper<A>(mapper, DECODE));
475 RmFinalEpsilon(fst);
476 fst->SetInputSymbols(mapper.InputSymbols());
477 fst->SetOutputSymbols(mapper.OutputSymbols());
478 }
479
480
481 // On the fly label and/or weight encoding of input Fst
482 //
483 // Complexity:
484 // - Constructor: O(1)
485 // - Traversal: O(nstates_visited + narcs_visited), assuming constant
486 // time to visit an input state or arc.
487 template <class A>
488 class EncodeFst : public ArcMapFst<A, A, EncodeMapper<A> > {
489 public:
490 typedef A Arc;
491 typedef EncodeMapper<A> C;
492 typedef ArcMapFstImpl< A, A, EncodeMapper<A> > Impl;
493 using ImplToFst<Impl>::GetImpl;
494
EncodeFst(const Fst<A> & fst,EncodeMapper<A> * encoder)495 EncodeFst(const Fst<A> &fst, EncodeMapper<A>* encoder)
496 : ArcMapFst<A, A, C>(fst, encoder, ArcMapFstOptions()) {
497 encoder->SetInputSymbols(fst.InputSymbols());
498 encoder->SetOutputSymbols(fst.OutputSymbols());
499 }
500
EncodeFst(const Fst<A> & fst,const EncodeMapper<A> & encoder)501 EncodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder)
502 : ArcMapFst<A, A, C>(fst, encoder, ArcMapFstOptions()) {}
503
504 // See Fst<>::Copy() for doc.
505 EncodeFst(const EncodeFst<A> &fst, bool copy = false)
506 : ArcMapFst<A, A, C>(fst, copy) {}
507
508 // Get a copy of this EncodeFst. See Fst<>::Copy() for further doc.
509 virtual EncodeFst<A> *Copy(bool safe = false) const {
510 if (safe) {
511 FSTERROR() << "EncodeFst::Copy(true): not allowed.";
512 GetImpl()->SetProperties(kError, kError);
513 }
514 return new EncodeFst(*this);
515 }
516 };
517
518
519 // On the fly label and/or weight encoding of input Fst
520 //
521 // Complexity:
522 // - Constructor: O(1)
523 // - Traversal: O(nstates_visited + narcs_visited), assuming constant
524 // time to visit an input state or arc.
525 template <class A>
526 class DecodeFst : public ArcMapFst<A, A, EncodeMapper<A> > {
527 public:
528 typedef A Arc;
529 typedef EncodeMapper<A> C;
530 typedef ArcMapFstImpl< A, A, EncodeMapper<A> > Impl;
531 using ImplToFst<Impl>::GetImpl;
532
DecodeFst(const Fst<A> & fst,const EncodeMapper<A> & encoder)533 DecodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder)
534 : ArcMapFst<A, A, C>(fst,
535 EncodeMapper<A>(encoder, DECODE),
536 ArcMapFstOptions()) {
537 GetImpl()->SetInputSymbols(encoder.InputSymbols());
538 GetImpl()->SetOutputSymbols(encoder.OutputSymbols());
539 }
540
541 // See Fst<>::Copy() for doc.
542 DecodeFst(const DecodeFst<A> &fst, bool safe = false)
543 : ArcMapFst<A, A, C>(fst, safe) {}
544
545 // Get a copy of this DecodeFst. See Fst<>::Copy() for further doc.
546 virtual DecodeFst<A> *Copy(bool safe = false) const {
547 return new DecodeFst(*this, safe);
548 }
549 };
550
551
552 // Specialization for EncodeFst.
553 template <class A>
554 class StateIterator< EncodeFst<A> >
555 : public StateIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
556 public:
StateIterator(const EncodeFst<A> & fst)557 explicit StateIterator(const EncodeFst<A> &fst)
558 : StateIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst) {}
559 };
560
561
562 // Specialization for EncodeFst.
563 template <class A>
564 class ArcIterator< EncodeFst<A> >
565 : public ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
566 public:
ArcIterator(const EncodeFst<A> & fst,typename A::StateId s)567 ArcIterator(const EncodeFst<A> &fst, typename A::StateId s)
568 : ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst, s) {}
569 };
570
571
572 // Specialization for DecodeFst.
573 template <class A>
574 class StateIterator< DecodeFst<A> >
575 : public StateIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
576 public:
StateIterator(const DecodeFst<A> & fst)577 explicit StateIterator(const DecodeFst<A> &fst)
578 : StateIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst) {}
579 };
580
581
582 // Specialization for DecodeFst.
583 template <class A>
584 class ArcIterator< DecodeFst<A> >
585 : public ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
586 public:
ArcIterator(const DecodeFst<A> & fst,typename A::StateId s)587 ArcIterator(const DecodeFst<A> &fst, typename A::StateId s)
588 : ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst, s) {}
589 };
590
591
592 // Useful aliases when using StdArc.
593 typedef EncodeFst<StdArc> StdEncodeFst;
594
595 typedef DecodeFst<StdArc> StdDecodeFst;
596
597 } // namespace fst
598
599 #endif // FST_LIB_ENCODE_H__
600