1 // encode.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 //
16 // \file
// Classes to encode and decode an Fst.
18
19 #ifndef FST_LIB_ENCODE_H__
20 #define FST_LIB_ENCODE_H__
21
#include <climits>

#include "fst/lib/map.h"
#include "fst/lib/rmfinalepsilon.h"
24
25 namespace fst {
26
// Flags selecting what is folded into an encoded label. The arc's input
// label is always part of the encoding; kEncodeLabels additionally folds in
// the output label and kEncodeWeights additionally folds in the arc weight.
// The two flags may be combined with '|'.
static const uint32 kEncodeLabels = 0x00001;
static const uint32 kEncodeWeights = 0x00002;

// Direction of an EncodeMapper: ENCODE replaces arcs by encoded labels,
// DECODE maps encoded labels back to the fields they stand for.
enum EncodeType { ENCODE = 1, DECODE = 2 };

// Identifies stream data as an encode table (and its endianity).
static const int32 kEncodeMagicNumber = 2129983209;
34
35
36 // The following class encapsulates implementation details for the
37 // encoding and decoding of label/weight tuples used for encoding
38 // and decoding of Fsts. The EncodeTable is bidirectional. I.E it
39 // stores both the Tuple of encode labels and weights to a unique
40 // label, and the reverse.
41 template <class A> class EncodeTable {
42 public:
43 typedef typename A::Label Label;
44 typedef typename A::Weight Weight;
45
46 // Encoded data consists of arc input/output labels and arc weight
47 struct Tuple {
TupleTuple48 Tuple() {}
TupleTuple49 Tuple(Label ilabel_, Label olabel_, Weight weight_)
50 : ilabel(ilabel_), olabel(olabel_), weight(weight_) {}
TupleTuple51 Tuple(const Tuple& tuple)
52 : ilabel(tuple.ilabel), olabel(tuple.olabel), weight(tuple.weight) {}
53
54 Label ilabel;
55 Label olabel;
56 Weight weight;
57 };
58
59 // Comparison object for hashing EncodeTable Tuple(s).
60 class TupleEqual {
61 public:
operator()62 bool operator()(const Tuple* x, const Tuple* y) const {
63 return (x->ilabel == y->ilabel &&
64 x->olabel == y->olabel &&
65 x->weight == y->weight);
66 }
67 };
68
69 // Hash function for EncodeTabe Tuples. Based on the encode flags
70 // we either hash the labels, weights or compbination of them.
71 class TupleKey {
72 static const int kPrime = 7853;
73 public:
TupleKey()74 TupleKey()
75 : encode_flags_(kEncodeLabels | kEncodeWeights) {}
76
TupleKey(const TupleKey & key)77 TupleKey(const TupleKey& key)
78 : encode_flags_(key.encode_flags_) {}
79
TupleKey(uint32 encode_flags)80 explicit TupleKey(uint32 encode_flags)
81 : encode_flags_(encode_flags) {}
82
operator()83 size_t operator()(const Tuple* x) const {
84 int lshift = x->ilabel % kPrime;
85 int rshift = sizeof(size_t) - lshift;
86 size_t hash = x->ilabel << lshift;
87 if (encode_flags_ & kEncodeLabels) hash ^= x->olabel >> rshift;
88 if (encode_flags_ & kEncodeWeights) hash ^= x->weight.Hash();
89 return hash;
90 }
91
92 private:
93 int32 encode_flags_;
94 };
95
96 typedef hash_map<const Tuple*,
97 Label,
98 TupleKey,
99 TupleEqual> EncodeHash;
100
EncodeTable(uint32 encode_flags)101 explicit EncodeTable(uint32 encode_flags)
102 : flags_(encode_flags),
103 encode_hash_(1024, TupleKey(encode_flags)) {}
104
~EncodeTable()105 ~EncodeTable() {
106 for (size_t i = 0; i < encode_tuples_.size(); ++i) {
107 delete encode_tuples_[i];
108 }
109 }
110
111 // Given an arc encode either input/ouptut labels or input/costs or both
Encode(const A & arc)112 Label Encode(const A &arc) {
113 const Tuple tuple(arc.ilabel,
114 flags_ & kEncodeLabels ? arc.olabel : 0,
115 flags_ & kEncodeWeights ? arc.weight : Weight::One());
116 typename EncodeHash::const_iterator it = encode_hash_.find(&tuple);
117 if (it == encode_hash_.end()) {
118 encode_tuples_.push_back(new Tuple(tuple));
119 encode_hash_[encode_tuples_.back()] = encode_tuples_.size();
120 return encode_tuples_.size();
121 } else {
122 return it->second;
123 }
124 }
125
126 // Given an encode arc Label decode back to input/output labels and costs
Decode(Label key)127 const Tuple* Decode(Label key) {
128 return key <= (Label)encode_tuples_.size() ? encode_tuples_[key - 1] : 0;
129 }
130
Write(ostream & strm,const string & source)131 bool Write(ostream &strm, const string &source) const {
132 WriteType(strm, kEncodeMagicNumber);
133 WriteType(strm, flags_);
134 int64 size = encode_tuples_.size();
135 WriteType(strm, size);
136 for (size_t i = 0; i < size; ++i) {
137 const Tuple* tuple = encode_tuples_[i];
138 WriteType(strm, tuple->ilabel);
139 WriteType(strm, tuple->olabel);
140 tuple->weight.Write(strm);
141 }
142 strm.flush();
143 if (!strm)
144 LOG(ERROR) << "EncodeTable::Write: write failed: " << source;
145 return strm;
146 }
147
Read(istream & strm,const string & source)148 bool Read(istream &strm, const string &source) {
149 encode_tuples_.clear();
150 encode_hash_.clear();
151 int32 magic_number = 0;
152 ReadType(strm, &magic_number);
153 if (magic_number != kEncodeMagicNumber) {
154 LOG(ERROR) << "EncodeTable::Read: Bad encode table header: " << source;
155 return false;
156 }
157 ReadType(strm, &flags_);
158 int64 size;
159 ReadType(strm, &size);
160 if (!strm) {
161 LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
162 return false;
163 }
164 for (size_t i = 0; i < size; ++i) {
165 Tuple* tuple = new Tuple();
166 ReadType(strm, &tuple->ilabel);
167 ReadType(strm, &tuple->olabel);
168 tuple->weight.Read(strm);
169 encode_tuples_.push_back(tuple);
170 encode_hash_[encode_tuples_.back()] = encode_tuples_.size();
171 }
172 if (!strm)
173 LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
174 return strm;
175 }
176
flags()177 const uint32 flags() const { return flags_; }
178 private:
179 uint32 flags_;
180 vector<Tuple*> encode_tuples_;
181 EncodeHash encode_hash_;
182
183 DISALLOW_EVIL_CONSTRUCTORS(EncodeTable);
184 };
185
186
187 // A mapper to encode/decode weighted transducers. Encoding of an
188 // Fst is useful for performing classical determinization or minimization
189 // on a weighted transducer by treating it as an unweighted acceptor over
190 // encoded labels.
191 //
192 // The Encode mapper stores the encoding in a local hash table (EncodeTable)
193 // This table is shared (and reference counted) between the encoder and
194 // decoder. A decoder has read only access to the EncodeTable.
195 //
196 // The EncodeMapper allows on the fly encoding of the machine. As the
197 // EncodeTable is generated the same table may by used to decode the machine
198 // on the fly. For example in the following sequence of operations
199 //
200 // Encode -> Determinize -> Decode
201 //
202 // we will use the encoding table generated during the encode step in the
203 // decode, even though the encoding is not complete.
204 //
205 template <class A> class EncodeMapper {
206 typedef typename A::Weight Weight;
207 typedef typename A::Label Label;
208 public:
EncodeMapper(uint32 flags,EncodeType type)209 EncodeMapper(uint32 flags, EncodeType type)
210 : ref_count_(1), flags_(flags), type_(type),
211 table_(new EncodeTable<A>(flags)) {}
212
EncodeMapper(const EncodeMapper & mapper)213 EncodeMapper(const EncodeMapper& mapper)
214 : ref_count_(mapper.ref_count_ + 1),
215 flags_(mapper.flags_),
216 type_(mapper.type_),
217 table_(mapper.table_) { }
218
219 // Copy constructor but setting the type, typically to DECODE
EncodeMapper(const EncodeMapper & mapper,EncodeType type)220 EncodeMapper(const EncodeMapper& mapper, EncodeType type)
221 : ref_count_(mapper.ref_count_ + 1),
222 flags_(mapper.flags_),
223 type_(type),
224 table_(mapper.table_) { }
225
~EncodeMapper()226 ~EncodeMapper() {
227 if (--ref_count_ == 0) delete table_;
228 }
229
operator()230 A operator()(const A &arc) {
231 if (type_ == ENCODE) { // labels and/or weights to single label
232 if ((arc.nextstate == kNoStateId && !(flags_ & kEncodeWeights)) ||
233 (arc.nextstate == kNoStateId && (flags_ & kEncodeWeights) &&
234 arc.weight == Weight::Zero())) {
235 return arc;
236 } else {
237 Label label = table_->Encode(arc);
238 return A(label,
239 flags_ & kEncodeLabels ? label : arc.olabel,
240 flags_ & kEncodeWeights ? Weight::One() : arc.weight,
241 arc.nextstate);
242 }
243 } else {
244 if (arc.nextstate == kNoStateId) {
245 return arc;
246 } else {
247 const typename EncodeTable<A>::Tuple* tuple =
248 table_->Decode(arc.ilabel);
249 return A(tuple->ilabel,
250 flags_ & kEncodeLabels ? tuple->olabel : arc.olabel,
251 flags_ & kEncodeWeights ? tuple->weight : arc.weight,
252 arc.nextstate);;
253 }
254 }
255 }
256
Properties(uint64 props)257 uint64 Properties(uint64 props) {
258 uint64 mask = kFstProperties;
259 if (flags_ & kEncodeLabels)
260 mask &= kILabelInvariantProperties & kOLabelInvariantProperties;
261 if (flags_ & kEncodeWeights)
262 mask &= kILabelInvariantProperties & kWeightInvariantProperties &
263 (type_ == ENCODE ? kAddSuperFinalProperties :
264 kRmSuperFinalProperties);
265 return props & mask;
266 }
267
268
FinalAction()269 MapFinalAction FinalAction() const {
270 return (type_ == ENCODE && (flags_ & kEncodeWeights)) ?
271 MAP_REQUIRE_SUPERFINAL : MAP_NO_SUPERFINAL;
272 }
273
flags()274 const uint32 flags() const { return flags_; }
type()275 const EncodeType type() const { return type_; }
276
Write(ostream & strm,const string & source)277 bool Write(ostream &strm, const string& source) {
278 return table_->Write(strm, source);
279 }
280
Write(const string & filename)281 bool Write(const string& filename) {
282 ofstream strm(filename.c_str());
283 if (!strm) {
284 LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
285 return false;
286 }
287 return Write(strm, filename);
288 }
289
Read(istream & strm,const string & source,EncodeType type)290 static EncodeMapper<A> *Read(istream &strm,
291 const string& source, EncodeType type) {
292 EncodeTable<A> *table = new EncodeTable<A>(0);
293 bool r = table->Read(strm, source);
294 return r ? new EncodeMapper(table->flags(), type, table) : 0;
295 }
296
Read(const string & filename,EncodeType type)297 static EncodeMapper<A> *Read(const string& filename, EncodeType type) {
298 ifstream strm(filename.c_str());
299 if (!strm) {
300 LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
301 return false;
302 }
303 return Read(strm, filename, type);
304 }
305
306 private:
307 uint32 ref_count_;
308 uint32 flags_;
309 EncodeType type_;
310 EncodeTable<A>* table_;
311
EncodeMapper(uint32 flags,EncodeType type,EncodeTable<A> * table)312 explicit EncodeMapper(uint32 flags, EncodeType type, EncodeTable<A> *table)
313 : ref_count_(1), flags_(flags), type_(type), table_(table) {}
314 void operator=(const EncodeMapper &); // Disallow.
315 };
316
317
318 // Complexity: O(nstates + narcs)
319 template<class A> inline
Encode(MutableFst<A> * fst,EncodeMapper<A> * mapper)320 void Encode(MutableFst<A> *fst, EncodeMapper<A>* mapper) {
321 Map(fst, mapper);
322 }
323
324
325 template<class A> inline
Decode(MutableFst<A> * fst,const EncodeMapper<A> & mapper)326 void Decode(MutableFst<A>* fst, const EncodeMapper<A>& mapper) {
327 Map(fst, EncodeMapper<A>(mapper, DECODE));
328 RmFinalEpsilon(fst);
329 }
330
331
332 // On the fly label and/or weight encoding of input Fst
333 //
334 // Complexity:
335 // - Constructor: O(1)
336 // - Traversal: O(nstates_visited + narcs_visited), assuming constant
337 // time to visit an input state or arc.
338 template <class A>
339 class EncodeFst : public MapFst<A, A, EncodeMapper<A> > {
340 public:
341 typedef A Arc;
342 typedef EncodeMapper<A> C;
343
EncodeFst(const Fst<A> & fst,EncodeMapper<A> * encoder)344 EncodeFst(const Fst<A> &fst, EncodeMapper<A>* encoder)
345 : MapFst<A, A, C>(fst, encoder, MapFstOptions()) {}
346
EncodeFst(const Fst<A> & fst,const EncodeMapper<A> & encoder)347 EncodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder)
348 : MapFst<A, A, C>(fst, encoder, MapFstOptions()) {}
349
EncodeFst(const EncodeFst<A> & fst)350 EncodeFst(const EncodeFst<A> &fst)
351 : MapFst<A, A, C>(fst) {}
352
Copy()353 virtual EncodeFst<A> *Copy() const { return new EncodeFst(*this); }
354 };
355
356
357 // On the fly label and/or weight encoding of input Fst
358 //
359 // Complexity:
360 // - Constructor: O(1)
361 // - Traversal: O(nstates_visited + narcs_visited), assuming constant
362 // time to visit an input state or arc.
363 template <class A>
364 class DecodeFst : public MapFst<A, A, EncodeMapper<A> > {
365 public:
366 typedef A Arc;
367 typedef EncodeMapper<A> C;
368
DecodeFst(const Fst<A> & fst,const EncodeMapper<A> & encoder)369 DecodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder)
370 : MapFst<A, A, C>(fst,
371 EncodeMapper<A>(encoder, DECODE),
372 MapFstOptions()) {}
373
DecodeFst(const EncodeFst<A> & fst)374 DecodeFst(const EncodeFst<A> &fst)
375 : MapFst<A, A, C>(fst) {}
376
Copy()377 virtual DecodeFst<A> *Copy() const { return new DecodeFst(*this); }
378 };
379
380
381 // Specialization for EncodeFst.
382 template <class A>
383 class StateIterator< EncodeFst<A> >
384 : public StateIterator< MapFst<A, A, EncodeMapper<A> > > {
385 public:
StateIterator(const EncodeFst<A> & fst)386 explicit StateIterator(const EncodeFst<A> &fst)
387 : StateIterator< MapFst<A, A, EncodeMapper<A> > >(fst) {}
388 };
389
390
391 // Specialization for EncodeFst.
392 template <class A>
393 class ArcIterator< EncodeFst<A> >
394 : public ArcIterator< MapFst<A, A, EncodeMapper<A> > > {
395 public:
ArcIterator(const EncodeFst<A> & fst,typename A::StateId s)396 ArcIterator(const EncodeFst<A> &fst, typename A::StateId s)
397 : ArcIterator< MapFst<A, A, EncodeMapper<A> > >(fst, s) {}
398 };
399
400
401 // Specialization for DecodeFst.
402 template <class A>
403 class StateIterator< DecodeFst<A> >
404 : public StateIterator< MapFst<A, A, EncodeMapper<A> > > {
405 public:
StateIterator(const DecodeFst<A> & fst)406 explicit StateIterator(const DecodeFst<A> &fst)
407 : StateIterator< MapFst<A, A, EncodeMapper<A> > >(fst) {}
408 };
409
410
411 // Specialization for DecodeFst.
412 template <class A>
413 class ArcIterator< DecodeFst<A> >
414 : public ArcIterator< MapFst<A, A, EncodeMapper<A> > > {
415 public:
ArcIterator(const DecodeFst<A> & fst,typename A::StateId s)416 ArcIterator(const DecodeFst<A> &fst, typename A::StateId s)
417 : ArcIterator< MapFst<A, A, EncodeMapper<A> > >(fst, s) {}
418 };
419
420
// Useful aliases when using StdArc.

// On the fly encoder over StdArc.
typedef EncodeFst<StdArc> StdEncodeFst;

// On the fly decoder over StdArc.
typedef DecodeFst<StdArc> StdDecodeFst;
425
426 }
427
428 #endif // FST_LIB_ENCODE_H__
429