• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // encode.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: johans@google.com (Johan Schalkwyk)
17 //
18 // \file
19 // Classes to encode and decode an fst.
20 
21 #ifndef FST_LIB_ENCODE_H__
22 #define FST_LIB_ENCODE_H__
23 
24 #include <climits>
25 #include <tr1/unordered_map>
26 using std::tr1::unordered_map;
27 using std::tr1::unordered_multimap;
28 #include <string>
29 #include <vector>
30 using std::vector;
31 
32 #include <fst/arc-map.h>
33 #include <fst/rmfinalepsilon.h>
34 
35 
36 namespace fst {
37 
38 static const uint32 kEncodeLabels      = 0x0001;
39 static const uint32 kEncodeWeights     = 0x0002;
40 static const uint32 kEncodeFlags       = 0x0003;  // All non-internal flags
41 
42 static const uint32 kEncodeHasISymbols = 0x0004;  // For internal use
43 static const uint32 kEncodeHasOSymbols = 0x0008;  // For internal use
44 
45 enum EncodeType { ENCODE = 1, DECODE = 2 };
46 
47 // Identifies stream data as an encode table (and its endianity)
48 static const int32 kEncodeMagicNumber = 2129983209;
49 
50 
51 // The following class encapsulates implementation details for the
52 // encoding and decoding of label/weight tuples used for encoding
53 // and decoding of Fsts. The EncodeTable is bidirectional. I.E it
54 // stores both the Tuple of encode labels and weights to a unique
55 // label, and the reverse.
56 template <class A>  class EncodeTable {
57  public:
58   typedef typename A::Label Label;
59   typedef typename A::Weight Weight;
60 
61   // Encoded data consists of arc input/output labels and arc weight
62   struct Tuple {
TupleTuple63     Tuple() {}
TupleTuple64     Tuple(Label ilabel_, Label olabel_, Weight weight_)
65         : ilabel(ilabel_), olabel(olabel_), weight(weight_) {}
TupleTuple66     Tuple(const Tuple& tuple)
67         : ilabel(tuple.ilabel), olabel(tuple.olabel), weight(tuple.weight) {}
68 
69     Label ilabel;
70     Label olabel;
71     Weight weight;
72   };
73 
74   // Comparison object for hashing EncodeTable Tuple(s).
75   class TupleEqual {
76    public:
operator()77     bool operator()(const Tuple* x, const Tuple* y) const {
78       return (x->ilabel == y->ilabel &&
79               x->olabel == y->olabel &&
80               x->weight == y->weight);
81     }
82   };
83 
84   // Hash function for EncodeTabe Tuples. Based on the encode flags
85   // we either hash the labels, weights or combination of them.
86   class TupleKey {
87    public:
TupleKey()88     TupleKey()
89         : encode_flags_(kEncodeLabels | kEncodeWeights) {}
90 
TupleKey(const TupleKey & key)91     TupleKey(const TupleKey& key)
92         : encode_flags_(key.encode_flags_) {}
93 
TupleKey(uint32 encode_flags)94     explicit TupleKey(uint32 encode_flags)
95         : encode_flags_(encode_flags) {}
96 
operator()97     size_t operator()(const Tuple* x) const {
98       size_t hash = x->ilabel;
99       const int lshift = 5;
100       const int rshift = CHAR_BIT * sizeof(size_t) - 5;
101       if (encode_flags_ & kEncodeLabels)
102         hash = hash << lshift ^ hash >> rshift ^ x->olabel;
103       if (encode_flags_ & kEncodeWeights)
104         hash = hash << lshift ^ hash >> rshift ^ x->weight.Hash();
105       return hash;
106     }
107 
108    private:
109     int32 encode_flags_;
110   };
111 
112   typedef unordered_map<const Tuple*,
113                    Label,
114                    TupleKey,
115                    TupleEqual> EncodeHash;
116 
EncodeTable(uint32 encode_flags)117   explicit EncodeTable(uint32 encode_flags)
118       : flags_(encode_flags),
119         encode_hash_(1024, TupleKey(encode_flags)),
120         isymbols_(0), osymbols_(0) {}
121 
~EncodeTable()122   ~EncodeTable() {
123     for (size_t i = 0; i < encode_tuples_.size(); ++i) {
124       delete encode_tuples_[i];
125     }
126     delete isymbols_;
127     delete osymbols_;
128   }
129 
130   // Given an arc encode either input/ouptut labels or input/costs or both
Encode(const A & arc)131   Label Encode(const A &arc) {
132     const Tuple tuple(arc.ilabel,
133                       flags_ & kEncodeLabels ? arc.olabel : 0,
134                       flags_ & kEncodeWeights ? arc.weight : Weight::One());
135     typename EncodeHash::const_iterator it = encode_hash_.find(&tuple);
136     if (it == encode_hash_.end()) {
137       encode_tuples_.push_back(new Tuple(tuple));
138       encode_hash_[encode_tuples_.back()] = encode_tuples_.size();
139       return encode_tuples_.size();
140     } else {
141       return it->second;
142     }
143   }
144 
145   // Given an arc, look up its encoded label. Returns kNoLabel if not found.
GetLabel(const A & arc)146   Label GetLabel(const A &arc) const {
147     const Tuple tuple(arc.ilabel,
148                       flags_ & kEncodeLabels ? arc.olabel : 0,
149                       flags_ & kEncodeWeights ? arc.weight : Weight::One());
150     typename EncodeHash::const_iterator it = encode_hash_.find(&tuple);
151     if (it == encode_hash_.end()) {
152       return kNoLabel;
153     } else {
154       return it->second;
155     }
156   }
157 
158   // Given an encode arc Label decode back to input/output labels and costs
Decode(Label key)159   const Tuple* Decode(Label key) const {
160     if (key < 1 || key > encode_tuples_.size()) {
161       LOG(ERROR) << "EncodeTable::Decode: unknown decode key: " << key;
162       return 0;
163     }
164     return encode_tuples_[key - 1];
165   }
166 
Size()167   size_t Size() const { return encode_tuples_.size(); }
168 
169   bool Write(ostream &strm, const string &source) const;
170 
171   static EncodeTable<A> *Read(istream &strm, const string &source);
172 
flags()173   const uint32 flags() const { return flags_ & kEncodeFlags; }
174 
RefCount()175   int RefCount() const { return ref_count_.count(); }
IncrRefCount()176   int IncrRefCount() { return ref_count_.Incr(); }
DecrRefCount()177   int DecrRefCount() { return ref_count_.Decr(); }
178 
179 
InputSymbols()180   SymbolTable *InputSymbols() const { return isymbols_; }
181 
OutputSymbols()182   SymbolTable *OutputSymbols() const { return osymbols_; }
183 
SetInputSymbols(const SymbolTable * syms)184   void SetInputSymbols(const SymbolTable* syms) {
185     if (isymbols_) delete isymbols_;
186     if (syms) {
187       isymbols_ = syms->Copy();
188       flags_ |= kEncodeHasISymbols;
189     } else {
190       isymbols_ = 0;
191       flags_ &= ~kEncodeHasISymbols;
192     }
193   }
194 
SetOutputSymbols(const SymbolTable * syms)195   void SetOutputSymbols(const SymbolTable* syms) {
196     if (osymbols_) delete osymbols_;
197     if (syms) {
198       osymbols_ = syms->Copy();
199       flags_ |= kEncodeHasOSymbols;
200     } else {
201       osymbols_ = 0;
202       flags_ &= ~kEncodeHasOSymbols;
203     }
204   }
205 
206  private:
207   uint32 flags_;
208   vector<Tuple*> encode_tuples_;
209   EncodeHash encode_hash_;
210   RefCounter ref_count_;
211   SymbolTable *isymbols_;       // Pre-encoded ilabel symbol table
212   SymbolTable *osymbols_;       // Pre-encoded olabel symbol table
213 
214   DISALLOW_COPY_AND_ASSIGN(EncodeTable);
215 };
216 
217 template <class A> inline
Write(ostream & strm,const string & source)218 bool EncodeTable<A>::Write(ostream &strm, const string &source) const {
219   WriteType(strm, kEncodeMagicNumber);
220   WriteType(strm, flags_);
221   int64 size = encode_tuples_.size();
222   WriteType(strm, size);
223   for (size_t i = 0;  i < size; ++i) {
224     const Tuple* tuple = encode_tuples_[i];
225     WriteType(strm, tuple->ilabel);
226     WriteType(strm, tuple->olabel);
227     tuple->weight.Write(strm);
228   }
229 
230   if (flags_ & kEncodeHasISymbols)
231     isymbols_->Write(strm);
232 
233   if (flags_ & kEncodeHasOSymbols)
234     osymbols_->Write(strm);
235 
236   strm.flush();
237   if (!strm) {
238     LOG(ERROR) << "EncodeTable::Write: write failed: " << source;
239     return false;
240   }
241   return true;
242 }
243 
244 template <class A> inline
Read(istream & strm,const string & source)245 EncodeTable<A> *EncodeTable<A>::Read(istream &strm, const string &source) {
246   int32 magic_number = 0;
247   ReadType(strm, &magic_number);
248   if (magic_number != kEncodeMagicNumber) {
249     LOG(ERROR) << "EncodeTable::Read: Bad encode table header: " << source;
250     return 0;
251   }
252   uint32 flags;
253   ReadType(strm, &flags);
254   EncodeTable<A> *table = new EncodeTable<A>(flags);
255 
256   int64 size;
257   ReadType(strm, &size);
258   if (!strm) {
259     LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
260     return 0;
261   }
262 
263   for (size_t i = 0; i < size; ++i) {
264     Tuple* tuple = new Tuple();
265     ReadType(strm, &tuple->ilabel);
266     ReadType(strm, &tuple->olabel);
267     tuple->weight.Read(strm);
268     if (!strm) {
269       LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
270       return 0;
271     }
272     table->encode_tuples_.push_back(tuple);
273     table->encode_hash_[table->encode_tuples_.back()] =
274         table->encode_tuples_.size();
275   }
276 
277   if (flags & kEncodeHasISymbols)
278     table->isymbols_ = SymbolTable::Read(strm, source);
279 
280   if (flags & kEncodeHasOSymbols)
281     table->osymbols_ = SymbolTable::Read(strm, source);
282 
283   return table;
284 }
285 
286 
287 // A mapper to encode/decode weighted transducers. Encoding of an
288 // Fst is useful for performing classical determinization or minimization
289 // on a weighted transducer by treating it as an unweighted acceptor over
290 // encoded labels.
291 //
292 // The Encode mapper stores the encoding in a local hash table (EncodeTable)
293 // This table is shared (and reference counted) between the encoder and
294 // decoder. A decoder has read only access to the EncodeTable.
295 //
296 // The EncodeMapper allows on the fly encoding of the machine. As the
297 // EncodeTable is generated the same table may by used to decode the machine
298 // on the fly. For example in the following sequence of operations
299 //
300 //  Encode -> Determinize -> Decode
301 //
302 // we will use the encoding table generated during the encode step in the
303 // decode, even though the encoding is not complete.
304 //
305 template <class A> class EncodeMapper {
306   typedef typename A::Weight Weight;
307   typedef typename A::Label  Label;
308  public:
EncodeMapper(uint32 flags,EncodeType type)309   EncodeMapper(uint32 flags, EncodeType type)
310     : flags_(flags),
311       type_(type),
312       table_(new EncodeTable<A>(flags)),
313       error_(false) {}
314 
EncodeMapper(const EncodeMapper & mapper)315   EncodeMapper(const EncodeMapper& mapper)
316       : flags_(mapper.flags_),
317         type_(mapper.type_),
318         table_(mapper.table_),
319         error_(false) {
320     table_->IncrRefCount();
321   }
322 
323   // Copy constructor but setting the type, typically to DECODE
EncodeMapper(const EncodeMapper & mapper,EncodeType type)324   EncodeMapper(const EncodeMapper& mapper, EncodeType type)
325       : flags_(mapper.flags_),
326         type_(type),
327         table_(mapper.table_),
328         error_(mapper.error_) {
329     table_->IncrRefCount();
330   }
331 
~EncodeMapper()332   ~EncodeMapper() {
333     if (!table_->DecrRefCount()) delete table_;
334   }
335 
336   A operator()(const A &arc);
337 
FinalAction()338   MapFinalAction FinalAction() const {
339     return (type_ == ENCODE && (flags_ & kEncodeWeights)) ?
340                    MAP_REQUIRE_SUPERFINAL : MAP_NO_SUPERFINAL;
341   }
342 
InputSymbolsAction()343   MapSymbolsAction InputSymbolsAction() const { return MAP_CLEAR_SYMBOLS; }
344 
OutputSymbolsAction()345   MapSymbolsAction OutputSymbolsAction() const { return MAP_CLEAR_SYMBOLS;}
346 
Properties(uint64 inprops)347   uint64 Properties(uint64 inprops) {
348     uint64 outprops = inprops;
349     if (error_) outprops |= kError;
350 
351     uint64 mask = kFstProperties;
352     if (flags_ & kEncodeLabels)
353       mask &= kILabelInvariantProperties & kOLabelInvariantProperties;
354     if (flags_ & kEncodeWeights)
355       mask &= kILabelInvariantProperties & kWeightInvariantProperties &
356           (type_ == ENCODE ? kAddSuperFinalProperties :
357            kRmSuperFinalProperties);
358 
359     return outprops & mask;
360   }
361 
flags()362   const uint32 flags() const { return flags_; }
type()363   const EncodeType type() const { return type_; }
table()364   const EncodeTable<A> &table() const { return *table_; }
365 
Write(ostream & strm,const string & source)366   bool Write(ostream &strm, const string& source) {
367     return table_->Write(strm, source);
368   }
369 
Write(const string & filename)370   bool Write(const string& filename) {
371     ofstream strm(filename.c_str(), ofstream::out | ofstream::binary);
372     if (!strm) {
373       LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
374       return false;
375     }
376     return Write(strm, filename);
377   }
378 
379   static EncodeMapper<A> *Read(istream &strm,
380                                const string& source,
381                                EncodeType type = ENCODE) {
382     EncodeTable<A> *table = EncodeTable<A>::Read(strm, source);
383     return table ? new EncodeMapper(table->flags(), type, table) : 0;
384   }
385 
386   static EncodeMapper<A> *Read(const string& filename,
387                                EncodeType type = ENCODE) {
388     ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
389     if (!strm) {
390       LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
391       return NULL;
392     }
393     return Read(strm, filename, type);
394   }
395 
InputSymbols()396   SymbolTable *InputSymbols() const { return table_->InputSymbols(); }
397 
OutputSymbols()398   SymbolTable *OutputSymbols() const { return table_->OutputSymbols(); }
399 
SetInputSymbols(const SymbolTable * syms)400   void SetInputSymbols(const SymbolTable* syms) {
401     table_->SetInputSymbols(syms);
402   }
403 
SetOutputSymbols(const SymbolTable * syms)404   void SetOutputSymbols(const SymbolTable* syms) {
405     table_->SetOutputSymbols(syms);
406   }
407 
408  private:
409   uint32 flags_;
410   EncodeType type_;
411   EncodeTable<A>* table_;
412   bool error_;
413 
EncodeMapper(uint32 flags,EncodeType type,EncodeTable<A> * table)414   explicit EncodeMapper(uint32 flags, EncodeType type, EncodeTable<A> *table)
415       : flags_(flags), type_(type), table_(table) {}
416   void operator=(const EncodeMapper &);  // Disallow.
417 };
418 
419 template <class A> inline
operator()420 A EncodeMapper<A>::operator()(const A &arc) {
421   if (type_ == ENCODE) {  // labels and/or weights to single label
422     if ((arc.nextstate == kNoStateId && !(flags_ & kEncodeWeights)) ||
423         (arc.nextstate == kNoStateId && (flags_ & kEncodeWeights) &&
424          arc.weight == Weight::Zero())) {
425       return arc;
426     } else {
427       Label label = table_->Encode(arc);
428       return A(label,
429                flags_ & kEncodeLabels ? label : arc.olabel,
430                flags_ & kEncodeWeights ? Weight::One() : arc.weight,
431                arc.nextstate);
432     }
433   } else {  // type_ == DECODE
434     if (arc.nextstate == kNoStateId) {
435       return arc;
436     } else {
437       if (arc.ilabel == 0) return arc;
438       if (flags_ & kEncodeLabels && arc.ilabel != arc.olabel) {
439         FSTERROR() << "EncodeMapper: Label-encoded arc has different "
440             "input and output labels";
441         error_ = true;
442       }
443       if (flags_ & kEncodeWeights && arc.weight != Weight::One()) {
444         FSTERROR() <<
445             "EncodeMapper: Weight-encoded arc has non-trivial weight";
446         error_ = true;
447       }
448       const typename EncodeTable<A>::Tuple* tuple = table_->Decode(arc.ilabel);
449       if (!tuple) {
450         FSTERROR() << "EncodeMapper: decode failed";
451         error_ = true;
452         return A(kNoLabel, kNoLabel, Weight::NoWeight(), arc.nextstate);
453       } else {
454         return A(tuple->ilabel,
455                  flags_ & kEncodeLabels ? tuple->olabel : arc.olabel,
456                  flags_ & kEncodeWeights ? tuple->weight : arc.weight,
457                  arc.nextstate);
458       }
459     }
460   }
461 }
462 
463 
464 // Complexity: O(nstates + narcs)
465 template<class A> inline
Encode(MutableFst<A> * fst,EncodeMapper<A> * mapper)466 void Encode(MutableFst<A> *fst, EncodeMapper<A>* mapper) {
467   mapper->SetInputSymbols(fst->InputSymbols());
468   mapper->SetOutputSymbols(fst->OutputSymbols());
469   ArcMap(fst, mapper);
470 }
471 
472 template<class A> inline
Decode(MutableFst<A> * fst,const EncodeMapper<A> & mapper)473 void Decode(MutableFst<A>* fst, const EncodeMapper<A>& mapper) {
474   ArcMap(fst, EncodeMapper<A>(mapper, DECODE));
475   RmFinalEpsilon(fst);
476   fst->SetInputSymbols(mapper.InputSymbols());
477   fst->SetOutputSymbols(mapper.OutputSymbols());
478 }
479 
480 
481 // On the fly label and/or weight encoding of input Fst
482 //
483 // Complexity:
484 // - Constructor: O(1)
485 // - Traversal: O(nstates_visited + narcs_visited), assuming constant
486 //   time to visit an input state or arc.
487 template <class A>
488 class EncodeFst : public ArcMapFst<A, A, EncodeMapper<A> > {
489  public:
490   typedef A Arc;
491   typedef EncodeMapper<A> C;
492   typedef ArcMapFstImpl< A, A, EncodeMapper<A> > Impl;
493   using ImplToFst<Impl>::GetImpl;
494 
EncodeFst(const Fst<A> & fst,EncodeMapper<A> * encoder)495   EncodeFst(const Fst<A> &fst, EncodeMapper<A>* encoder)
496       : ArcMapFst<A, A, C>(fst, encoder, ArcMapFstOptions()) {
497     encoder->SetInputSymbols(fst.InputSymbols());
498     encoder->SetOutputSymbols(fst.OutputSymbols());
499   }
500 
EncodeFst(const Fst<A> & fst,const EncodeMapper<A> & encoder)501   EncodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder)
502       : ArcMapFst<A, A, C>(fst, encoder, ArcMapFstOptions()) {}
503 
504   // See Fst<>::Copy() for doc.
505   EncodeFst(const EncodeFst<A> &fst, bool copy = false)
506       : ArcMapFst<A, A, C>(fst, copy) {}
507 
508   // Get a copy of this EncodeFst. See Fst<>::Copy() for further doc.
509   virtual EncodeFst<A> *Copy(bool safe = false) const {
510     if (safe) {
511       FSTERROR() << "EncodeFst::Copy(true): not allowed.";
512       GetImpl()->SetProperties(kError, kError);
513     }
514     return new EncodeFst(*this);
515   }
516 };
517 
518 
519 // On the fly label and/or weight encoding of input Fst
520 //
521 // Complexity:
522 // - Constructor: O(1)
523 // - Traversal: O(nstates_visited + narcs_visited), assuming constant
524 //   time to visit an input state or arc.
525 template <class A>
526 class DecodeFst : public ArcMapFst<A, A, EncodeMapper<A> > {
527  public:
528   typedef A Arc;
529   typedef EncodeMapper<A> C;
530   typedef ArcMapFstImpl< A, A, EncodeMapper<A> > Impl;
531   using ImplToFst<Impl>::GetImpl;
532 
DecodeFst(const Fst<A> & fst,const EncodeMapper<A> & encoder)533   DecodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder)
534       : ArcMapFst<A, A, C>(fst,
535                             EncodeMapper<A>(encoder, DECODE),
536                             ArcMapFstOptions()) {
537     GetImpl()->SetInputSymbols(encoder.InputSymbols());
538     GetImpl()->SetOutputSymbols(encoder.OutputSymbols());
539   }
540 
541   // See Fst<>::Copy() for doc.
542   DecodeFst(const DecodeFst<A> &fst, bool safe = false)
543       : ArcMapFst<A, A, C>(fst, safe) {}
544 
545   // Get a copy of this DecodeFst. See Fst<>::Copy() for further doc.
546   virtual DecodeFst<A> *Copy(bool safe = false) const {
547     return new DecodeFst(*this, safe);
548   }
549 };
550 
551 
552 // Specialization for EncodeFst.
553 template <class A>
554 class StateIterator< EncodeFst<A> >
555     : public StateIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
556  public:
StateIterator(const EncodeFst<A> & fst)557   explicit StateIterator(const EncodeFst<A> &fst)
558       : StateIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst) {}
559 };
560 
561 
562 // Specialization for EncodeFst.
563 template <class A>
564 class ArcIterator< EncodeFst<A> >
565     : public ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
566  public:
ArcIterator(const EncodeFst<A> & fst,typename A::StateId s)567   ArcIterator(const EncodeFst<A> &fst, typename A::StateId s)
568       : ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst, s) {}
569 };
570 
571 
572 // Specialization for DecodeFst.
573 template <class A>
574 class StateIterator< DecodeFst<A> >
575     : public StateIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
576  public:
StateIterator(const DecodeFst<A> & fst)577   explicit StateIterator(const DecodeFst<A> &fst)
578       : StateIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst) {}
579 };
580 
581 
582 // Specialization for DecodeFst.
583 template <class A>
584 class ArcIterator< DecodeFst<A> >
585     : public ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
586  public:
ArcIterator(const DecodeFst<A> & fst,typename A::StateId s)587   ArcIterator(const DecodeFst<A> &fst, typename A::StateId s)
588       : ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst, s) {}
589 };
590 
591 
592 // Useful aliases when using StdArc.
593 typedef EncodeFst<StdArc> StdEncodeFst;
594 
595 typedef DecodeFst<StdArc> StdDecodeFst;
596 
597 }  // namespace fst
598 
599 #endif  // FST_LIB_ENCODE_H__
600