1
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 //
14 // Copyright 2005-2010 Google, Inc.
15 // All Rights Reserved.
16 //
17 // Author : Johan Schalkwyk
18 //
19 // \file
20 // Classes to provide symbol-to-integer and integer-to-symbol mappings.
21
22 #ifndef FST_LIB_SYMBOL_TABLE_H__
23 #define FST_LIB_SYMBOL_TABLE_H__
24
25 #include <cstring>
26 #include <string>
27 #include <utility>
28 using std::pair; using std::make_pair;
29 #include <vector>
30 using std::vector;
31
32
33 #include <fst/compat.h>
34 #include <iostream>
35 #include <fstream>
36
37
38 #include <map>
39
40 DECLARE_bool(fst_compat_symbols);
41
42 namespace fst {
43
44 // WARNING: Reading via symbol table read options should
45 // not be used. This is a temporary work around for
46 // reading symbol ranges of previously stored symbol sets.
47 struct SymbolTableReadOptions {
SymbolTableReadOptionsSymbolTableReadOptions48 SymbolTableReadOptions() { }
49
SymbolTableReadOptionsSymbolTableReadOptions50 SymbolTableReadOptions(vector<pair<int64, int64> > string_hash_ranges_,
51 const string& source_)
52 : string_hash_ranges(string_hash_ranges_),
53 source(source_) { }
54
55 vector<pair<int64, int64> > string_hash_ranges;
56 string source;
57 };
58
59 class SymbolTableImpl {
60 public:
SymbolTableImpl(const string & name)61 SymbolTableImpl(const string &name)
62 : name_(name),
63 available_key_(0),
64 dense_key_limit_(0),
65 check_sum_finalized_(false) {}
66
SymbolTableImpl(const SymbolTableImpl & impl)67 explicit SymbolTableImpl(const SymbolTableImpl& impl)
68 : name_(impl.name_),
69 available_key_(0),
70 dense_key_limit_(0),
71 check_sum_finalized_(false) {
72 for (size_t i = 0; i < impl.symbols_.size(); ++i) {
73 AddSymbol(impl.symbols_[i], impl.Find(impl.symbols_[i]));
74 }
75 }
76
~SymbolTableImpl()77 ~SymbolTableImpl() {
78 for (size_t i = 0; i < symbols_.size(); ++i)
79 delete[] symbols_[i];
80 }
81
82 // TODO(johans): Add flag to specify whether the symbol
83 // should be indexed as string or int or both.
84 int64 AddSymbol(const string& symbol, int64 key);
85
AddSymbol(const string & symbol)86 int64 AddSymbol(const string& symbol) {
87 int64 key = Find(symbol);
88 return (key == -1) ? AddSymbol(symbol, available_key_++) : key;
89 }
90
91 static SymbolTableImpl* ReadText(istream &strm,
92 const string &name,
93 bool allow_negative = false);
94
95 static SymbolTableImpl* Read(istream &strm,
96 const SymbolTableReadOptions& opts);
97
98 bool Write(ostream &strm) const;
99
100 //
101 // Return the string associated with the key. If the key is out of
102 // range (<0, >max), return an empty string.
Find(int64 key)103 string Find(int64 key) const {
104 if (key >=0 && key < dense_key_limit_)
105 return string(symbols_[key]);
106
107 map<int64, const char*>::const_iterator it =
108 key_map_.find(key);
109 if (it == key_map_.end()) {
110 return "";
111 }
112 return string(it->second);
113 }
114
115 //
116 // Return the key associated with the symbol. If the symbol
117 // does not exists, return SymbolTable::kNoSymbol.
Find(const string & symbol)118 int64 Find(const string& symbol) const {
119 return Find(symbol.c_str());
120 }
121
122 //
123 // Return the key associated with the symbol. If the symbol
124 // does not exists, return SymbolTable::kNoSymbol.
Find(const char * symbol)125 int64 Find(const char* symbol) const {
126 map<const char *, int64, StrCmp>::const_iterator it =
127 symbol_map_.find(symbol);
128 if (it == symbol_map_.end()) {
129 return -1;
130 }
131 return it->second;
132 }
133
GetNthKey(ssize_t pos)134 int64 GetNthKey(ssize_t pos) const {
135 if ((pos < 0) || (pos >= symbols_.size())) return -1;
136 else return Find(symbols_[pos]);
137 }
138
Name()139 const string& Name() const { return name_; }
140
IncrRefCount()141 int IncrRefCount() const {
142 return ref_count_.Incr();
143 }
DecrRefCount()144 int DecrRefCount() const {
145 return ref_count_.Decr();
146 }
RefCount()147 int RefCount() const {
148 return ref_count_.count();
149 }
150
CheckSum()151 string CheckSum() const {
152 MutexLock check_sum_lock(&check_sum_mutex_);
153 MaybeRecomputeCheckSum();
154 return check_sum_string_;
155 }
156
LabeledCheckSum()157 string LabeledCheckSum() const {
158 MutexLock check_sum_lock(&check_sum_mutex_);
159 MaybeRecomputeCheckSum();
160 return labeled_check_sum_string_;
161 }
162
AvailableKey()163 int64 AvailableKey() const {
164 return available_key_;
165 }
166
NumSymbols()167 size_t NumSymbols() const {
168 return symbols_.size();
169 }
170
171 private:
172 // Recomputes the checksums (both of them) if we've had changes since the last
173 // computation (i.e., if check_sum_finalized_ is false).
174 void MaybeRecomputeCheckSum() const;
175
176 struct StrCmp {
operatorStrCmp177 bool operator()(const char *s1, const char *s2) const {
178 return strcmp(s1, s2) < 0;
179 }
180 };
181
182 string name_;
183 int64 available_key_;
184 int64 dense_key_limit_;
185 vector<const char *> symbols_;
186 map<int64, const char*> key_map_;
187 map<const char *, int64, StrCmp> symbol_map_;
188
189 mutable RefCounter ref_count_;
190 mutable bool check_sum_finalized_;
191 mutable CheckSummer check_sum_;
192 mutable CheckSummer labeled_check_sum_;
193 mutable string check_sum_string_;
194 mutable string labeled_check_sum_string_;
195 mutable Mutex check_sum_mutex_;
196 };
197
198 //
199 // \class SymbolTable
200 // \brief Symbol (string) to int and reverse mapping
201 //
202 // The SymbolTable implements the mappings of labels to strings and reverse.
203 // SymbolTables are used to describe the alphabet of the input and output
204 // labels for arcs in a Finite State Transducer.
205 //
206 // SymbolTables are reference counted and can therefore be shared across
207 // multiple machines. For example a language model grammar G, with a
208 // SymbolTable for the words in the language model can share this symbol
209 // table with the lexical representation L o G.
210 //
211 class SymbolTable {
212 public:
213 static const int64 kNoSymbol = -1;
214
215 // Construct symbol table with a unique name.
SymbolTable(const string & name)216 SymbolTable(const string& name) : impl_(new SymbolTableImpl(name)) {}
217
218 // Create a reference counted copy.
SymbolTable(const SymbolTable & table)219 SymbolTable(const SymbolTable& table) : impl_(table.impl_) {
220 impl_->IncrRefCount();
221 }
222
223 // Derefence implentation object. When reference count hits 0, delete
224 // implementation.
~SymbolTable()225 virtual ~SymbolTable() {
226 if (!impl_->DecrRefCount()) delete impl_;
227 }
228
229 // Read an ascii representation of the symbol table from an istream. Pass a
230 // name to give the resulting SymbolTable.
231 static SymbolTable* ReadText(istream &strm,
232 const string& name,
233 bool allow_negative = false) {
234 SymbolTableImpl* impl = SymbolTableImpl::ReadText(strm,
235 name,
236 allow_negative);
237 if (!impl)
238 return 0;
239 else
240 return new SymbolTable(impl);
241 }
242
243 // read an ascii representation of the symbol table
244 static SymbolTable* ReadText(const string& filename,
245 bool allow_negative = false) {
246 ifstream strm(filename.c_str(), ifstream::in);
247 if (!strm) {
248 LOG(ERROR) << "SymbolTable::ReadText: Can't open file " << filename;
249 return 0;
250 }
251 return ReadText(strm, filename, allow_negative);
252 }
253
254
255 // WARNING: Reading via symbol table read options should
256 // not be used. This is a temporary work around.
Read(istream & strm,const SymbolTableReadOptions & opts)257 static SymbolTable* Read(istream &strm,
258 const SymbolTableReadOptions& opts) {
259 SymbolTableImpl* impl = SymbolTableImpl::Read(strm, opts);
260 if (!impl)
261 return 0;
262 else
263 return new SymbolTable(impl);
264 }
265
266 // read a binary dump of the symbol table from a stream
Read(istream & strm,const string & source)267 static SymbolTable* Read(istream &strm, const string& source) {
268 SymbolTableReadOptions opts;
269 opts.source = source;
270 return Read(strm, opts);
271 }
272
273 // read a binary dump of the symbol table
Read(const string & filename)274 static SymbolTable* Read(const string& filename) {
275 ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
276 if (!strm) {
277 LOG(ERROR) << "SymbolTable::Read: Can't open file " << filename;
278 return 0;
279 }
280 return Read(strm, filename);
281 }
282
283 //--------------------------------------------------------
284 // Derivable Interface (final)
285 //--------------------------------------------------------
286 // create a reference counted copy
Copy()287 virtual SymbolTable* Copy() const {
288 return new SymbolTable(*this);
289 }
290
291 // Add a symbol with given key to table. A symbol table also
292 // keeps track of the last available key (highest key value in
293 // the symbol table).
AddSymbol(const string & symbol,int64 key)294 virtual int64 AddSymbol(const string& symbol, int64 key) {
295 MutateCheck();
296 return impl_->AddSymbol(symbol, key);
297 }
298
299 // Add a symbol to the table. The associated value key is automatically
300 // assigned by the symbol table.
AddSymbol(const string & symbol)301 virtual int64 AddSymbol(const string& symbol) {
302 MutateCheck();
303 return impl_->AddSymbol(symbol);
304 }
305
306 // Add another symbol table to this table. All key values will be offset
307 // by the current available key (highest key value in the symbol table).
308 // Note string symbols with the same key value with still have the same
309 // key value after the symbol table has been merged, but a different
310 // value. Adding symbol tables do not result in changes in the base table.
311 virtual void AddTable(const SymbolTable& table);
312
313 // return the name of the symbol table
Name()314 virtual const string& Name() const {
315 return impl_->Name();
316 }
317
318 // Return the label-agnostic MD5 check-sum for this table. All new symbols
319 // added to the table will result in an updated checksum.
320 // DEPRECATED.
CheckSum()321 virtual string CheckSum() const {
322 return impl_->CheckSum();
323 }
324
325 // Same as CheckSum(), but this returns an label-dependent version.
LabeledCheckSum()326 virtual string LabeledCheckSum() const {
327 return impl_->LabeledCheckSum();
328 }
329
Write(ostream & strm)330 virtual bool Write(ostream &strm) const {
331 return impl_->Write(strm);
332 }
333
Write(const string & filename)334 bool Write(const string& filename) const {
335 ofstream strm(filename.c_str(), ofstream::out | ofstream::binary);
336 if (!strm) {
337 LOG(ERROR) << "SymbolTable::Write: Can't open file " << filename;
338 return false;
339 }
340 return Write(strm);
341 }
342
343 // Dump an ascii text representation of the symbol table via a stream
344 virtual bool WriteText(ostream &strm) const;
345
346 // Dump an ascii text representation of the symbol table
WriteText(const string & filename)347 bool WriteText(const string& filename) const {
348 ofstream strm(filename.c_str());
349 if (!strm) {
350 LOG(ERROR) << "SymbolTable::WriteText: Can't open file " << filename;
351 return false;
352 }
353 return WriteText(strm);
354 }
355
356 // Return the string associated with the key. If the key is out of
357 // range (<0, >max), log error and return an empty string.
Find(int64 key)358 virtual string Find(int64 key) const {
359 return impl_->Find(key);
360 }
361
362 // Return the key associated with the symbol. If the symbol
363 // does not exists, log error and return SymbolTable::kNoSymbol
Find(const string & symbol)364 virtual int64 Find(const string& symbol) const {
365 return impl_->Find(symbol);
366 }
367
368 // Return the key associated with the symbol. If the symbol
369 // does not exists, log error and return SymbolTable::kNoSymbol
Find(const char * symbol)370 virtual int64 Find(const char* symbol) const {
371 return impl_->Find(symbol);
372 }
373
374 // Return the current available key (i.e highest key number+1) in
375 // the symbol table
AvailableKey(void)376 virtual int64 AvailableKey(void) const {
377 return impl_->AvailableKey();
378 }
379
380 // Return the current number of symbols in table (not necessarily
381 // equal to AvailableKey())
NumSymbols(void)382 virtual size_t NumSymbols(void) const {
383 return impl_->NumSymbols();
384 }
385
GetNthKey(ssize_t pos)386 virtual int64 GetNthKey(ssize_t pos) const {
387 return impl_->GetNthKey(pos);
388 }
389
390 private:
SymbolTable(SymbolTableImpl * impl)391 explicit SymbolTable(SymbolTableImpl* impl) : impl_(impl) {}
392
MutateCheck()393 void MutateCheck() {
394 // Copy on write
395 if (impl_->RefCount() > 1) {
396 impl_->DecrRefCount();
397 impl_ = new SymbolTableImpl(*impl_);
398 }
399 }
400
Impl()401 const SymbolTableImpl* Impl() const {
402 return impl_;
403 }
404
405 private:
406 SymbolTableImpl* impl_;
407
408 void operator=(const SymbolTable &table); // disallow
409 };
410
411
412 //
413 // \class SymbolTableIterator
414 // \brief Iterator class for symbols in a symbol table
415 class SymbolTableIterator {
416 public:
SymbolTableIterator(const SymbolTable & table)417 SymbolTableIterator(const SymbolTable& table)
418 : table_(table),
419 pos_(0),
420 nsymbols_(table.NumSymbols()),
421 key_(table.GetNthKey(0)) { }
422
~SymbolTableIterator()423 ~SymbolTableIterator() { }
424
425 // is iterator done
Done(void)426 bool Done(void) {
427 return (pos_ == nsymbols_);
428 }
429
430 // return the Value() of the current symbol (int64 key)
Value(void)431 int64 Value(void) {
432 return key_;
433 }
434
435 // return the string of the current symbol
Symbol(void)436 string Symbol(void) {
437 return table_.Find(key_);
438 }
439
440 // advance iterator forward
Next(void)441 void Next(void) {
442 ++pos_;
443 if (pos_ < nsymbols_) key_ = table_.GetNthKey(pos_);
444 }
445
446 // reset iterator
Reset(void)447 void Reset(void) {
448 pos_ = 0;
449 key_ = table_.GetNthKey(0);
450 }
451
452 private:
453 const SymbolTable& table_;
454 ssize_t pos_;
455 size_t nsymbols_;
456 int64 key_;
457 };
458
459
460 // Tests compatibilty between two sets of symbol tables
461 inline bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
462 bool warning = true) {
463 if (!FLAGS_fst_compat_symbols) {
464 return true;
465 } else if (!syms1 && !syms2) {
466 return true;
467 } else if (syms1 && !syms2) {
468 if (warning)
469 LOG(WARNING) <<
470 "CompatSymbols: first symbol table present but second missing";
471 return false;
472 } else if (!syms1 && syms2) {
473 if (warning)
474 LOG(WARNING) <<
475 "CompatSymbols: second symbol table present but first missing";
476 return false;
477 } else if (syms1->LabeledCheckSum() != syms2->LabeledCheckSum()) {
478 if (warning)
479 LOG(WARNING) << "CompatSymbols: Symbol table check sums do not match";
480 return false;
481 } else {
482 return true;
483 }
484 }
485
486
487 // Relabels a symbol table as specified by the input vector of pairs
488 // (old label, new label). The new symbol table only retains symbols
489 // for which a relabeling is *explicitely* specified.
490 // TODO(allauzen): consider adding options to allow for some form
491 // of implicit identity relabeling.
492 template <class Label>
RelabelSymbolTable(const SymbolTable * table,const vector<pair<Label,Label>> & pairs)493 SymbolTable *RelabelSymbolTable(const SymbolTable *table,
494 const vector<pair<Label, Label> > &pairs) {
495 SymbolTable *new_table = new SymbolTable(
496 table->Name().empty() ? string() :
497 (string("relabeled_") + table->Name()));
498
499 for (size_t i = 0; i < pairs.size(); ++i)
500 new_table->AddSymbol(table->Find(pairs[i].first), pairs[i].second);
501
502 return new_table;
503 }
504
505 } // namespace fst
506
507 #endif // FST_LIB_SYMBOL_TABLE_H__
508