1 // symbol-table.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 //
16 // \file
17 // Classes to provide symbol-to-integer and integer-to-symbol mappings.
18
19 #ifndef FST_LIB_SYMBOL_TABLE_H__
20 #define FST_LIB_SYMBOL_TABLE_H__
21
22 #include <fstream>
23 #include <iostream>
24 #include <string>
25 #include <unordered_map>
26 #include <vector>
27
28 #include "fst/lib/compat.h"
29
30
31
32 DECLARE_bool(fst_compat_symbols);
33
34 namespace fst {
35
36 class SymbolTableImpl {
37 friend class SymbolTableIterator;
38 public:
SymbolTableImpl(const string & name)39 SymbolTableImpl(const string &name)
40 : name_(name), available_key_(0), ref_count_(1),
41 check_sum_finalized_(false) {}
~SymbolTableImpl()42 ~SymbolTableImpl() {
43 for (size_t i = 0; i < symbols_.size(); ++i)
44 delete[] symbols_[i];
45 }
46
47 int64 AddSymbol(const string& symbol, int64 key);
48
AddSymbol(const string & symbol)49 int64 AddSymbol(const string& symbol) {
50 int64 key = Find(symbol);
51 return (key == -1) ? AddSymbol(symbol, available_key_++) : key;
52 }
53
AddTable(SymbolTableImpl * table)54 void AddTable(SymbolTableImpl* table) {
55 for (size_t i = 0; i < table->symbols_.size(); ++i) {
56 AddSymbol(table->symbols_[i]);
57 }
58 }
59
60 static SymbolTableImpl* ReadText(const string& filename);
61
62 static SymbolTableImpl* Read(istream &strm, const string& source);
63
64 bool Write(ostream &strm) const;
65
66 bool WriteText(ostream &strm) const;
67
68 //
69 // Return the string associated with the key. If the key is out of
70 // range (<0, >max), return an empty string.
Find(int64 key)71 string Find(int64 key) const {
72 std::unordered_map<int64, string>::const_iterator it =
73 key_map_.find(key);
74 if (it == key_map_.end()) {
75 return "";
76 }
77 return it->second;
78 }
79
80 //
81 // Return the key associated with the symbol. If the symbol
82 // does not exists, return -1.
Find(const string & symbol)83 int64 Find(const string& symbol) const {
84 return Find(symbol.c_str());
85 }
86
87 //
88 // Return the key associated with the symbol. If the symbol
89 // does not exists, return -1.
Find(const char * symbol)90 int64 Find(const char* symbol) const {
91 unordered_map<string, int64>::const_iterator it =
92 symbol_map_.find(symbol);
93 if (it == symbol_map_.end()) {
94 return -1;
95 }
96 return it->second;
97 }
98
Name()99 const string& Name() const { return name_; }
100
IncrRefCount()101 int IncrRefCount() const {
102 return ++ref_count_;
103 }
DecrRefCount()104 int DecrRefCount() const {
105 return --ref_count_;
106 }
107
CheckSum()108 string CheckSum() const {
109 if (!check_sum_finalized_) {
110 RecomputeCheckSum();
111 check_sum_string_ = check_sum_.Digest();
112 }
113 return check_sum_string_;
114 }
115
AvailableKey()116 int64 AvailableKey() const {
117 return available_key_;
118 }
119
120 // private support methods
121 private:
122 void RecomputeCheckSum() const;
123 static SymbolTableImpl* Read1(istream &, const string &);
124
125 string name_;
126 int64 available_key_;
127 vector<const char *> symbols_;
128 std::unordered_map<int64, string> key_map_;
129 std::unordered_map<string, int64> symbol_map_;
130
131 mutable int ref_count_;
132 mutable bool check_sum_finalized_;
133 mutable MD5 check_sum_;
134 mutable string check_sum_string_;
135
136 DISALLOW_EVIL_CONSTRUCTORS(SymbolTableImpl);
137 };
138
139
140 class SymbolTableIterator;
141
142 //
143 // \class SymbolTable
144 // \brief Symbol (string) to int and reverse mapping
145 //
146 // The SymbolTable implements the mappings of labels to strings and reverse.
147 // SymbolTables are used to describe the alphabet of the input and output
148 // labels for arcs in a Finite State Transducer.
149 //
150 // SymbolTables are reference counted and can therefore be shared across
151 // multiple machines. For example a language model grammar G, with a
152 // SymbolTable for the words in the language model can share this symbol
153 // table with the lexical representation L o G.
154 //
155 class SymbolTable {
156 friend class SymbolTableIterator;
157 public:
158 static const int64 kNoSymbol = -1;
159
160 // Construct symbol table with a unique name.
SymbolTable(const string & name)161 SymbolTable(const string& name) : impl_(new SymbolTableImpl(name)) {}
162
163 // Create a reference counted copy.
SymbolTable(const SymbolTable & table)164 SymbolTable(const SymbolTable& table) : impl_(table.impl_) {
165 impl_->IncrRefCount();
166 }
167
168 // Derefence implentation object. When reference count hits 0, delete
169 // implementation.
~SymbolTable()170 ~SymbolTable() {
171 if (!impl_->DecrRefCount()) delete impl_;
172 }
173
174 // create a reference counted copy
Copy()175 SymbolTable* Copy() const {
176 return new SymbolTable(*this);
177 }
178
179 // Add a symbol with given key to table. A symbol table also
180 // keeps track of the last available key (highest key value in
181 // the symbol table).
182 //
183 // \param symbol string symbol to add
184 // \param key associated key for string symbol
185 // \return the key created by the symbol table. Symbols allready added to
186 // the symbol table will not get a different key.
AddSymbol(const string & symbol,int64 key)187 int64 AddSymbol(const string& symbol, int64 key) {
188 return impl_->AddSymbol(symbol, key);
189 }
190
191 // Add a symbol to the table. The associated value key is automatically
192 // assigned by the symbol table.
193 //
194 // \param symbol string to add to the table
195 // \return the value key assigned to the associated string symbol
AddSymbol(const string & symbol)196 int64 AddSymbol(const string& symbol) {
197 return impl_->AddSymbol(symbol);
198 }
199
200 // Add another symbol table to this table. All key values will be offset
201 // by the current available key (highest key value in the symbol table).
202 // Note string symbols with the same key value with still have the same
203 // key value after the symbol table has been merged, but a different
204 // value. Adding symbol tables do not result in changes in the base table.
205 //
206 // Merging N symbol tables is often useful when combining the various
207 // name spaces of transducers to a unified representation.
208 //
209 // \param table the symbol table to add to this table
AddTable(const SymbolTable & table)210 void AddTable(const SymbolTable& table) {
211 return impl_->AddTable(table.impl_);
212 }
213
214 // return the name of the symbol table
Name()215 const string& Name() const {
216 return impl_->Name();
217 }
218
219 // return the MD5 check-sum for this table. All new symbols added to
220 // the table will result in an updated checksum.
CheckSum()221 string CheckSum() const {
222 return impl_->CheckSum();
223 }
224
225 // read an ascii representation of the symbol table
ReadText(const string & filename)226 static SymbolTable* ReadText(const string& filename) {
227 SymbolTableImpl* impl = SymbolTableImpl::ReadText(filename);
228 if (!impl)
229 return 0;
230 else
231 return new SymbolTable(impl);
232 }
233
234 // read a binary dump of the symbol table
Read(istream & strm,const string & source)235 static SymbolTable* Read(istream &strm, const string& source) {
236 SymbolTableImpl* impl = SymbolTableImpl::Read(strm, source);
237 if (!impl)
238 return 0;
239 else
240 return new SymbolTable(impl);
241 }
242
243 // read a binary dump of the symbol table
Read(const string & filename)244 static SymbolTable* Read(const string& filename) {
245 ifstream strm(filename.c_str());
246 if (!strm) {
247 LOG(ERROR) << "SymbolTable::Read: Can't open file " << filename;
248 return 0;
249 }
250 return Read(strm, filename);
251 }
252
Write(ostream & strm)253 bool Write(ostream &strm) const {
254 return impl_->Write(strm);
255 }
256
Write(const string & filename)257 bool Write(const string& filename) const {
258 ofstream strm(filename.c_str());
259 if (!strm) {
260 LOG(ERROR) << "SymbolTable::Write: Can't open file " << filename;
261 return false;
262 }
263 return Write(strm);
264 }
265
266 // Dump an ascii text representation of the symbol table
WriteText(ostream & strm)267 bool WriteText(ostream &strm) const {
268 return impl_->WriteText(strm);
269 }
270
271 // Dump an ascii text representation of the symbol table
WriteText(const string & filename)272 bool WriteText(const string& filename) const {
273 ofstream strm(filename.c_str());
274 if (!strm) {
275 LOG(ERROR) << "SymbolTable::WriteText: Can't open file " << filename;
276 return false;
277 }
278 return WriteText(strm);
279 }
280
281 // Return the string associated with the key. If the key is out of
282 // range (<0, >max), log error and return an empty string.
Find(int64 key)283 string Find(int64 key) const {
284 return impl_->Find(key);
285 }
286
287 // Return the key associated with the symbol. If the symbol
288 // does not exists, log error and return -1
Find(const string & symbol)289 int64 Find(const string& symbol) const {
290 return impl_->Find(symbol);
291 }
292
293 // Return the key associated with the symbol. If the symbol
294 // does not exists, log error and return -1
Find(const char * symbol)295 int64 Find(const char* symbol) const {
296 return impl_->Find(symbol);
297 }
298
299 // return the current available key (i.e highest key number) in
300 // the symbol table
AvailableKey(void)301 int64 AvailableKey(void) const {
302 return impl_->AvailableKey();
303 }
304
305 protected:
SymbolTable(SymbolTableImpl * impl)306 explicit SymbolTable(SymbolTableImpl* impl) : impl_(impl) {}
307
Impl()308 const SymbolTableImpl* Impl() const {
309 return impl_;
310 }
311
312 private:
313 SymbolTableImpl* impl_;
314
315
316 void operator=(const SymbolTable &table); // disallow
317 };
318
319
320 //
321 // \class SymbolTableIterator
322 // \brief Iterator class for symbols in a symbol table
323 class SymbolTableIterator {
324 public:
325 // Constructor creates a refcounted copy of underlying implementation
SymbolTableIterator(const SymbolTable & symbol_table)326 SymbolTableIterator(const SymbolTable& symbol_table) {
327 impl_ = symbol_table.Impl();
328 impl_->IncrRefCount();
329 pos_ = 0;
330 size_ = impl_->symbols_.size();
331 }
332
333 // decrement implementation refcount, and delete if 0
~SymbolTableIterator()334 ~SymbolTableIterator() {
335 if (!impl_->DecrRefCount()) delete impl_;
336 }
337
338 // is iterator done
Done(void)339 bool Done(void) {
340 return (pos_ == size_);
341 }
342
343 // return the Value() of the current symbol (in64 key)
Value(void)344 int64 Value(void) {
345 return impl_->Find(impl_->symbols_[pos_]);
346 }
347
348 // return the string of the current symbol
Symbol(void)349 const char* Symbol(void) {
350 return impl_->symbols_[pos_];
351 }
352
353 // advance iterator forward
Next(void)354 void Next(void) {
355 if (Done()) return;
356 ++pos_;
357 }
358
359 // reset iterator
Reset(void)360 void Reset(void) {
361 pos_ = 0;
362 }
363
364 private:
365 const SymbolTableImpl* impl_;
366 size_t pos_;
367 size_t size_;
368 };
369
370
371 // Tests compatibilty between two sets of symbol tables
CompatSymbols(const SymbolTable * syms1,const SymbolTable * syms2)372 inline bool CompatSymbols(const SymbolTable *syms1,
373 const SymbolTable *syms2) {
374 if (!FLAGS_fst_compat_symbols)
375 return true;
376 else if (!syms1 && !syms2)
377 return true;
378 else if ((syms1 && !syms2) || (!syms1 && syms2))
379 return false;
380 else
381 return syms1->CheckSum() == syms2->CheckSum();
382 }
383
384 } // namespace fst
385
386 #endif // FST_LIB_SYMBOL_TABLE_H__
387