• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // encode.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 //
16 // \file
17 // Classes to encode and decode an Fst.
18 
19 #ifndef FST_LIB_ENCODE_H__
20 #define FST_LIB_ENCODE_H__
21 
22 #include "fst/lib/map.h"
23 #include "fst/lib/rmfinalepsilon.h"
24 
25 namespace fst {
26 
27 static const uint32 kEncodeLabels = 0x00001;
28 static const uint32 kEncodeWeights  = 0x00002;
29 
30 enum EncodeType { ENCODE = 1, DECODE = 2 };
31 
32 // Identifies stream data as an encode table (and its endianity)
33 static const int32 kEncodeMagicNumber = 2129983209;
34 
35 
36 // The following class encapsulates implementation details for the
37 // encoding and decoding of label/weight tuples used for encoding
38 // and decoding of Fsts. The EncodeTable is bidirectional. I.E it
39 // stores both the Tuple of encode labels and weights to a unique
40 // label, and the reverse.
41 template <class A>  class EncodeTable {
42  public:
43   typedef typename A::Label Label;
44   typedef typename A::Weight Weight;
45 
46   // Encoded data consists of arc input/output labels and arc weight
47   struct Tuple {
TupleTuple48     Tuple() {}
TupleTuple49     Tuple(Label ilabel_, Label olabel_, Weight weight_)
50         : ilabel(ilabel_), olabel(olabel_), weight(weight_) {}
TupleTuple51     Tuple(const Tuple& tuple)
52         : ilabel(tuple.ilabel), olabel(tuple.olabel), weight(tuple.weight) {}
53 
54     Label ilabel;
55     Label olabel;
56     Weight weight;
57   };
58 
59   // Comparison object for hashing EncodeTable Tuple(s).
60   class TupleEqual {
61    public:
operator()62     bool operator()(const Tuple* x, const Tuple* y) const {
63       return (x->ilabel == y->ilabel &&
64               x->olabel == y->olabel &&
65               x->weight == y->weight);
66     }
67   };
68 
69   // Hash function for EncodeTabe Tuples. Based on the encode flags
70   // we either hash the labels, weights or compbination of them.
71   class TupleKey {
72     static const int kPrime = 7853;
73    public:
TupleKey()74     TupleKey()
75         : encode_flags_(kEncodeLabels | kEncodeWeights) {}
76 
TupleKey(const TupleKey & key)77     TupleKey(const TupleKey& key)
78         : encode_flags_(key.encode_flags_) {}
79 
TupleKey(uint32 encode_flags)80     explicit TupleKey(uint32 encode_flags)
81         : encode_flags_(encode_flags) {}
82 
operator()83     size_t operator()(const Tuple* x) const {
84       int lshift = x->ilabel % kPrime;
85       int rshift = sizeof(size_t) - lshift;
86       size_t hash = x->ilabel << lshift;
87       if (encode_flags_ & kEncodeLabels) hash ^= x->olabel >> rshift;
88       if (encode_flags_ & kEncodeWeights)  hash ^= x->weight.Hash();
89       return hash;
90     }
91 
92    private:
93     int32 encode_flags_;
94   };
95 
96   typedef hash_map<const Tuple*,
97                    Label,
98                    TupleKey,
99                    TupleEqual> EncodeHash;
100 
EncodeTable(uint32 encode_flags)101   explicit EncodeTable(uint32 encode_flags)
102       : flags_(encode_flags),
103         encode_hash_(1024, TupleKey(encode_flags)) {}
104 
~EncodeTable()105   ~EncodeTable() {
106     for (size_t i = 0; i < encode_tuples_.size(); ++i) {
107       delete encode_tuples_[i];
108     }
109   }
110 
111   // Given an arc encode either input/ouptut labels or input/costs or both
Encode(const A & arc)112   Label Encode(const A &arc) {
113     const Tuple tuple(arc.ilabel,
114                       flags_ & kEncodeLabels ? arc.olabel : 0,
115                       flags_ & kEncodeWeights ? arc.weight : Weight::One());
116     typename EncodeHash::const_iterator it = encode_hash_.find(&tuple);
117     if (it == encode_hash_.end()) {
118       encode_tuples_.push_back(new Tuple(tuple));
119       encode_hash_[encode_tuples_.back()] = encode_tuples_.size();
120       return encode_tuples_.size();
121     } else {
122       return it->second;
123     }
124   }
125 
126   // Given an encode arc Label decode back to input/output labels and costs
Decode(Label key)127   const Tuple* Decode(Label key) {
128     return key <= (Label)encode_tuples_.size() ? encode_tuples_[key - 1] : 0;
129   }
130 
Write(ostream & strm,const string & source)131   bool Write(ostream &strm, const string &source) const {
132     WriteType(strm, kEncodeMagicNumber);
133     WriteType(strm, flags_);
134     int64 size = encode_tuples_.size();
135     WriteType(strm, size);
136     for (size_t i = 0;  i < size; ++i) {
137       const Tuple* tuple = encode_tuples_[i];
138       WriteType(strm, tuple->ilabel);
139       WriteType(strm, tuple->olabel);
140       tuple->weight.Write(strm);
141     }
142     strm.flush();
143     if (!strm)
144       LOG(ERROR) << "EncodeTable::Write: write failed: " << source;
145     return strm;
146   }
147 
Read(istream & strm,const string & source)148   bool Read(istream &strm, const string &source) {
149     encode_tuples_.clear();
150     encode_hash_.clear();
151     int32 magic_number = 0;
152     ReadType(strm, &magic_number);
153     if (magic_number != kEncodeMagicNumber) {
154       LOG(ERROR) << "EncodeTable::Read: Bad encode table header: " << source;
155       return false;
156     }
157     ReadType(strm, &flags_);
158     int64 size;
159     ReadType(strm, &size);
160     if (!strm) {
161       LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
162       return false;
163     }
164     for (size_t i = 0; i < size; ++i) {
165       Tuple* tuple = new Tuple();
166       ReadType(strm, &tuple->ilabel);
167       ReadType(strm, &tuple->olabel);
168       tuple->weight.Read(strm);
169       encode_tuples_.push_back(tuple);
170       encode_hash_[encode_tuples_.back()] = encode_tuples_.size();
171     }
172     if (!strm)
173       LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
174     return strm;
175   }
176 
flags()177   const uint32 flags() const { return flags_; }
178  private:
179   uint32 flags_;
180   vector<Tuple*> encode_tuples_;
181   EncodeHash encode_hash_;
182 
183   DISALLOW_EVIL_CONSTRUCTORS(EncodeTable);
184 };
185 
186 
187 // A mapper to encode/decode weighted transducers. Encoding of an
188 // Fst is useful for performing classical determinization or minimization
189 // on a weighted transducer by treating it as an unweighted acceptor over
190 // encoded labels.
191 //
192 // The Encode mapper stores the encoding in a local hash table (EncodeTable)
193 // This table is shared (and reference counted) between the encoder and
194 // decoder. A decoder has read only access to the EncodeTable.
195 //
196 // The EncodeMapper allows on the fly encoding of the machine. As the
197 // EncodeTable is generated the same table may by used to decode the machine
198 // on the fly. For example in the following sequence of operations
199 //
200 //  Encode -> Determinize -> Decode
201 //
202 // we will use the encoding table generated during the encode step in the
203 // decode, even though the encoding is not complete.
204 //
205 template <class A> class EncodeMapper {
206   typedef typename A::Weight Weight;
207   typedef typename A::Label  Label;
208  public:
EncodeMapper(uint32 flags,EncodeType type)209   EncodeMapper(uint32 flags, EncodeType type)
210     : ref_count_(1), flags_(flags), type_(type),
211       table_(new EncodeTable<A>(flags)) {}
212 
EncodeMapper(const EncodeMapper & mapper)213   EncodeMapper(const EncodeMapper& mapper)
214       : ref_count_(mapper.ref_count_ + 1),
215         flags_(mapper.flags_),
216         type_(mapper.type_),
217         table_(mapper.table_) { }
218 
219   // Copy constructor but setting the type, typically to DECODE
EncodeMapper(const EncodeMapper & mapper,EncodeType type)220   EncodeMapper(const EncodeMapper& mapper, EncodeType type)
221       : ref_count_(mapper.ref_count_ + 1),
222         flags_(mapper.flags_),
223         type_(type),
224         table_(mapper.table_) { }
225 
~EncodeMapper()226   ~EncodeMapper() {
227     if (--ref_count_ == 0) delete table_;
228   }
229 
operator()230   A operator()(const A &arc) {
231     if (type_ == ENCODE) {  // labels and/or weights to single label
232       if ((arc.nextstate == kNoStateId && !(flags_ & kEncodeWeights)) ||
233           (arc.nextstate == kNoStateId && (flags_ & kEncodeWeights) &&
234            arc.weight == Weight::Zero())) {
235         return arc;
236       } else {
237         Label label = table_->Encode(arc);
238         return A(label,
239                  flags_ & kEncodeLabels ? label : arc.olabel,
240                  flags_ & kEncodeWeights ? Weight::One() : arc.weight,
241                  arc.nextstate);
242       }
243     } else {
244       if (arc.nextstate == kNoStateId) {
245         return arc;
246       } else {
247         const typename EncodeTable<A>::Tuple* tuple =
248           table_->Decode(arc.ilabel);
249         return A(tuple->ilabel,
250                  flags_ & kEncodeLabels ? tuple->olabel : arc.olabel,
251                  flags_ & kEncodeWeights ? tuple->weight : arc.weight,
252                  arc.nextstate);;
253       }
254     }
255   }
256 
Properties(uint64 props)257   uint64 Properties(uint64 props) {
258     uint64 mask = kFstProperties;
259     if (flags_ & kEncodeLabels)
260       mask &= kILabelInvariantProperties & kOLabelInvariantProperties;
261     if (flags_ & kEncodeWeights)
262       mask &= kILabelInvariantProperties & kWeightInvariantProperties &
263           (type_ == ENCODE ? kAddSuperFinalProperties :
264            kRmSuperFinalProperties);
265     return props & mask;
266   }
267 
268 
FinalAction()269   MapFinalAction FinalAction() const {
270     return (type_ == ENCODE && (flags_ & kEncodeWeights)) ?
271                    MAP_REQUIRE_SUPERFINAL : MAP_NO_SUPERFINAL;
272   }
273 
flags()274   const uint32 flags() const { return flags_; }
type()275   const EncodeType type() const { return type_; }
276 
Write(ostream & strm,const string & source)277   bool Write(ostream &strm, const string& source) {
278     return table_->Write(strm, source);
279   }
280 
Write(const string & filename)281   bool Write(const string& filename) {
282     ofstream strm(filename.c_str());
283     if (!strm) {
284       LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
285       return false;
286     }
287     return Write(strm, filename);
288   }
289 
Read(istream & strm,const string & source,EncodeType type)290   static EncodeMapper<A> *Read(istream &strm,
291                                const string& source, EncodeType type) {
292     EncodeTable<A> *table = new EncodeTable<A>(0);
293     bool r = table->Read(strm, source);
294     return r ? new EncodeMapper(table->flags(), type, table) : 0;
295   }
296 
Read(const string & filename,EncodeType type)297   static EncodeMapper<A> *Read(const string& filename, EncodeType type) {
298     ifstream strm(filename.c_str());
299     if (!strm) {
300       LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
301       return false;
302     }
303     return Read(strm, filename, type);
304   }
305 
306  private:
307   uint32  ref_count_;
308   uint32  flags_;
309   EncodeType type_;
310   EncodeTable<A>* table_;
311 
EncodeMapper(uint32 flags,EncodeType type,EncodeTable<A> * table)312   explicit EncodeMapper(uint32 flags, EncodeType type, EncodeTable<A> *table)
313       : ref_count_(1), flags_(flags), type_(type), table_(table) {}
314   void operator=(const EncodeMapper &);  // Disallow.
315 };
316 
317 
318 // Complexity: O(nstates + narcs)
319 template<class A> inline
Encode(MutableFst<A> * fst,EncodeMapper<A> * mapper)320 void Encode(MutableFst<A> *fst, EncodeMapper<A>* mapper) {
321   Map(fst, mapper);
322 }
323 
324 
325 template<class A> inline
Decode(MutableFst<A> * fst,const EncodeMapper<A> & mapper)326 void Decode(MutableFst<A>* fst, const EncodeMapper<A>& mapper) {
327   Map(fst, EncodeMapper<A>(mapper, DECODE));
328   RmFinalEpsilon(fst);
329 }
330 
331 
332 // On the fly label and/or weight encoding of input Fst
333 //
334 // Complexity:
335 // - Constructor: O(1)
336 // - Traversal: O(nstates_visited + narcs_visited), assuming constant
337 //   time to visit an input state or arc.
338 template <class A>
339 class EncodeFst : public MapFst<A, A, EncodeMapper<A> > {
340  public:
341   typedef A Arc;
342   typedef EncodeMapper<A> C;
343 
EncodeFst(const Fst<A> & fst,EncodeMapper<A> * encoder)344   EncodeFst(const Fst<A> &fst, EncodeMapper<A>* encoder)
345       : MapFst<A, A, C>(fst, encoder, MapFstOptions()) {}
346 
EncodeFst(const Fst<A> & fst,const EncodeMapper<A> & encoder)347   EncodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder)
348       : MapFst<A, A, C>(fst, encoder, MapFstOptions()) {}
349 
EncodeFst(const EncodeFst<A> & fst)350   EncodeFst(const EncodeFst<A> &fst)
351       : MapFst<A, A, C>(fst) {}
352 
Copy()353   virtual EncodeFst<A> *Copy() const { return new EncodeFst(*this); }
354 };
355 
356 
357 // On the fly label and/or weight encoding of input Fst
358 //
359 // Complexity:
360 // - Constructor: O(1)
361 // - Traversal: O(nstates_visited + narcs_visited), assuming constant
362 //   time to visit an input state or arc.
363 template <class A>
364 class DecodeFst : public MapFst<A, A, EncodeMapper<A> > {
365  public:
366   typedef A Arc;
367   typedef EncodeMapper<A> C;
368 
DecodeFst(const Fst<A> & fst,const EncodeMapper<A> & encoder)369   DecodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder)
370       : MapFst<A, A, C>(fst,
371                             EncodeMapper<A>(encoder, DECODE),
372                             MapFstOptions()) {}
373 
DecodeFst(const EncodeFst<A> & fst)374   DecodeFst(const EncodeFst<A> &fst)
375       : MapFst<A, A, C>(fst) {}
376 
Copy()377   virtual DecodeFst<A> *Copy() const { return new DecodeFst(*this); }
378 };
379 
380 
381 // Specialization for EncodeFst.
382 template <class A>
383 class StateIterator< EncodeFst<A> >
384     : public StateIterator< MapFst<A, A, EncodeMapper<A> > > {
385  public:
StateIterator(const EncodeFst<A> & fst)386   explicit StateIterator(const EncodeFst<A> &fst)
387       : StateIterator< MapFst<A, A, EncodeMapper<A> > >(fst) {}
388 };
389 
390 
391 // Specialization for EncodeFst.
392 template <class A>
393 class ArcIterator< EncodeFst<A> >
394     : public ArcIterator< MapFst<A, A, EncodeMapper<A> > > {
395  public:
ArcIterator(const EncodeFst<A> & fst,typename A::StateId s)396   ArcIterator(const EncodeFst<A> &fst, typename A::StateId s)
397       : ArcIterator< MapFst<A, A, EncodeMapper<A> > >(fst, s) {}
398 };
399 
400 
401 // Specialization for DecodeFst.
402 template <class A>
403 class StateIterator< DecodeFst<A> >
404     : public StateIterator< MapFst<A, A, EncodeMapper<A> > > {
405  public:
StateIterator(const DecodeFst<A> & fst)406   explicit StateIterator(const DecodeFst<A> &fst)
407       : StateIterator< MapFst<A, A, EncodeMapper<A> > >(fst) {}
408 };
409 
410 
411 // Specialization for DecodeFst.
412 template <class A>
413 class ArcIterator< DecodeFst<A> >
414     : public ArcIterator< MapFst<A, A, EncodeMapper<A> > > {
415  public:
ArcIterator(const DecodeFst<A> & fst,typename A::StateId s)416   ArcIterator(const DecodeFst<A> &fst, typename A::StateId s)
417       : ArcIterator< MapFst<A, A, EncodeMapper<A> > >(fst, s) {}
418 };
419 
420 
421 // Useful aliases when using StdArc.
422 typedef EncodeFst<StdArc> StdEncodeFst;
423 
424 typedef DecodeFst<StdArc> StdDecodeFst;
425 
426 }
427 
428 #endif  // FST_LIB_ENCODE_H__
429