• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 //
14 // Copyright 2005-2010 Google, Inc.
15 // All Rights Reserved.
16 //
17 // Author : Johan Schalkwyk
18 //
19 // \file
20 // Classes to provide symbol-to-integer and integer-to-symbol mappings.
21 
22 #ifndef FST_LIB_SYMBOL_TABLE_H__
23 #define FST_LIB_SYMBOL_TABLE_H__
24 
25 #include <cstring>
26 #include <string>
27 #include <utility>
28 using std::pair; using std::make_pair;
29 #include <vector>
30 using std::vector;
31 
32 
33 #include <fst/compat.h>
34 #include <iostream>
35 #include <fstream>
36 
37 
38 #include <map>
39 
40 DECLARE_bool(fst_compat_symbols);
41 
42 namespace fst {
43 
44 // WARNING: Reading via symbol table read options should
45 //          not be used. This is a temporary work around for
46 //          reading symbol ranges of previously stored symbol sets.
47 struct SymbolTableReadOptions {
SymbolTableReadOptionsSymbolTableReadOptions48   SymbolTableReadOptions() { }
49 
SymbolTableReadOptionsSymbolTableReadOptions50   SymbolTableReadOptions(vector<pair<int64, int64> > string_hash_ranges_,
51                          const string& source_)
52       : string_hash_ranges(string_hash_ranges_),
53         source(source_) { }
54 
55   vector<pair<int64, int64> > string_hash_ranges;
56   string source;
57 };
58 
59 class SymbolTableImpl {
60  public:
SymbolTableImpl(const string & name)61   SymbolTableImpl(const string &name)
62       : name_(name),
63         available_key_(0),
64         dense_key_limit_(0),
65         check_sum_finalized_(false) {}
66 
SymbolTableImpl(const SymbolTableImpl & impl)67   explicit SymbolTableImpl(const SymbolTableImpl& impl)
68       : name_(impl.name_),
69         available_key_(0),
70         dense_key_limit_(0),
71         check_sum_finalized_(false) {
72     for (size_t i = 0; i < impl.symbols_.size(); ++i) {
73       AddSymbol(impl.symbols_[i], impl.Find(impl.symbols_[i]));
74     }
75   }
76 
~SymbolTableImpl()77   ~SymbolTableImpl() {
78     for (size_t i = 0; i < symbols_.size(); ++i)
79       delete[] symbols_[i];
80   }
81 
82   // TODO(johans): Add flag to specify whether the symbol
83   //               should be indexed as string or int or both.
84   int64 AddSymbol(const string& symbol, int64 key);
85 
AddSymbol(const string & symbol)86   int64 AddSymbol(const string& symbol) {
87     int64 key = Find(symbol);
88     return (key == -1) ? AddSymbol(symbol, available_key_++) : key;
89   }
90 
91   static SymbolTableImpl* ReadText(istream &strm,
92                                    const string &name,
93                                    bool allow_negative = false);
94 
95   static SymbolTableImpl* Read(istream &strm,
96                                const SymbolTableReadOptions& opts);
97 
98   bool Write(ostream &strm) const;
99 
100   //
101   // Return the string associated with the key. If the key is out of
102   // range (<0, >max), return an empty string.
Find(int64 key)103   string Find(int64 key) const {
104     if (key >=0 && key < dense_key_limit_)
105       return string(symbols_[key]);
106 
107     map<int64, const char*>::const_iterator it =
108         key_map_.find(key);
109     if (it == key_map_.end()) {
110       return "";
111     }
112     return string(it->second);
113   }
114 
115   //
116   // Return the key associated with the symbol. If the symbol
117   // does not exists, return SymbolTable::kNoSymbol.
Find(const string & symbol)118   int64 Find(const string& symbol) const {
119     return Find(symbol.c_str());
120   }
121 
122   //
123   // Return the key associated with the symbol. If the symbol
124   // does not exists, return SymbolTable::kNoSymbol.
Find(const char * symbol)125   int64 Find(const char* symbol) const {
126     map<const char *, int64, StrCmp>::const_iterator it =
127         symbol_map_.find(symbol);
128     if (it == symbol_map_.end()) {
129       return -1;
130     }
131     return it->second;
132   }
133 
GetNthKey(ssize_t pos)134   int64 GetNthKey(ssize_t pos) const {
135     if ((pos < 0) || (pos >= symbols_.size())) return -1;
136     else return Find(symbols_[pos]);
137   }
138 
Name()139   const string& Name() const { return name_; }
140 
IncrRefCount()141   int IncrRefCount() const {
142     return ref_count_.Incr();
143   }
DecrRefCount()144   int DecrRefCount() const {
145     return ref_count_.Decr();
146   }
RefCount()147   int RefCount() const {
148     return ref_count_.count();
149   }
150 
CheckSum()151   string CheckSum() const {
152     MutexLock check_sum_lock(&check_sum_mutex_);
153     MaybeRecomputeCheckSum();
154     return check_sum_string_;
155   }
156 
LabeledCheckSum()157   string LabeledCheckSum() const {
158     MutexLock check_sum_lock(&check_sum_mutex_);
159     MaybeRecomputeCheckSum();
160     return labeled_check_sum_string_;
161   }
162 
AvailableKey()163   int64 AvailableKey() const {
164     return available_key_;
165   }
166 
NumSymbols()167   size_t NumSymbols() const {
168     return symbols_.size();
169   }
170 
171  private:
172   // Recomputes the checksums (both of them) if we've had changes since the last
173   // computation (i.e., if check_sum_finalized_ is false).
174   void MaybeRecomputeCheckSum() const;
175 
176   struct StrCmp {
operatorStrCmp177     bool operator()(const char *s1, const char *s2) const {
178       return strcmp(s1, s2) < 0;
179     }
180   };
181 
182   string name_;
183   int64 available_key_;
184   int64 dense_key_limit_;
185   vector<const char *> symbols_;
186   map<int64, const char*> key_map_;
187   map<const char *, int64, StrCmp> symbol_map_;
188 
189   mutable RefCounter ref_count_;
190   mutable bool check_sum_finalized_;
191   mutable CheckSummer check_sum_;
192   mutable CheckSummer labeled_check_sum_;
193   mutable string check_sum_string_;
194   mutable string labeled_check_sum_string_;
195   mutable Mutex check_sum_mutex_;
196 };
197 
198 //
199 // \class SymbolTable
200 // \brief Symbol (string) to int and reverse mapping
201 //
202 // The SymbolTable implements the mappings of labels to strings and reverse.
203 // SymbolTables are used to describe the alphabet of the input and output
204 // labels for arcs in a Finite State Transducer.
205 //
206 // SymbolTables are reference counted and can therefore be shared across
207 // multiple machines. For example a language model grammar G, with a
208 // SymbolTable for the words in the language model can share this symbol
209 // table with the lexical representation L o G.
210 //
211 class SymbolTable {
212  public:
213   static const int64 kNoSymbol = -1;
214 
215   // Construct symbol table with a unique name.
SymbolTable(const string & name)216   SymbolTable(const string& name) : impl_(new SymbolTableImpl(name)) {}
217 
218   // Create a reference counted copy.
SymbolTable(const SymbolTable & table)219   SymbolTable(const SymbolTable& table) : impl_(table.impl_) {
220     impl_->IncrRefCount();
221   }
222 
223   // Derefence implentation object. When reference count hits 0, delete
224   // implementation.
~SymbolTable()225   virtual ~SymbolTable() {
226     if (!impl_->DecrRefCount()) delete impl_;
227   }
228 
229   // Read an ascii representation of the symbol table from an istream. Pass a
230   // name to give the resulting SymbolTable.
231   static SymbolTable* ReadText(istream &strm,
232                                const string& name,
233                                bool allow_negative = false) {
234     SymbolTableImpl* impl = SymbolTableImpl::ReadText(strm,
235                                                       name,
236                                                       allow_negative);
237     if (!impl)
238       return 0;
239     else
240       return new SymbolTable(impl);
241   }
242 
243   // read an ascii representation of the symbol table
244   static SymbolTable* ReadText(const string& filename,
245                                bool allow_negative = false) {
246     ifstream strm(filename.c_str(), ifstream::in);
247     if (!strm) {
248       LOG(ERROR) << "SymbolTable::ReadText: Can't open file " << filename;
249       return 0;
250     }
251     return ReadText(strm, filename, allow_negative);
252   }
253 
254 
255   // WARNING: Reading via symbol table read options should
256   //          not be used. This is a temporary work around.
Read(istream & strm,const SymbolTableReadOptions & opts)257   static SymbolTable* Read(istream &strm,
258                            const SymbolTableReadOptions& opts) {
259     SymbolTableImpl* impl = SymbolTableImpl::Read(strm, opts);
260     if (!impl)
261       return 0;
262     else
263       return new SymbolTable(impl);
264   }
265 
266   // read a binary dump of the symbol table from a stream
Read(istream & strm,const string & source)267   static SymbolTable* Read(istream &strm, const string& source) {
268     SymbolTableReadOptions opts;
269     opts.source = source;
270     return Read(strm, opts);
271   }
272 
273   // read a binary dump of the symbol table
Read(const string & filename)274   static SymbolTable* Read(const string& filename) {
275     ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
276     if (!strm) {
277       LOG(ERROR) << "SymbolTable::Read: Can't open file " << filename;
278       return 0;
279     }
280     return Read(strm, filename);
281   }
282 
283   //--------------------------------------------------------
284   // Derivable Interface (final)
285   //--------------------------------------------------------
286   // create a reference counted copy
Copy()287   virtual SymbolTable* Copy() const {
288     return new SymbolTable(*this);
289   }
290 
291   // Add a symbol with given key to table. A symbol table also
292   // keeps track of the last available key (highest key value in
293   // the symbol table).
AddSymbol(const string & symbol,int64 key)294   virtual int64 AddSymbol(const string& symbol, int64 key) {
295     MutateCheck();
296     return impl_->AddSymbol(symbol, key);
297   }
298 
299   // Add a symbol to the table. The associated value key is automatically
300   // assigned by the symbol table.
AddSymbol(const string & symbol)301   virtual int64 AddSymbol(const string& symbol) {
302     MutateCheck();
303     return impl_->AddSymbol(symbol);
304   }
305 
306   // Add another symbol table to this table. All key values will be offset
307   // by the current available key (highest key value in the symbol table).
308   // Note string symbols with the same key value with still have the same
309   // key value after the symbol table has been merged, but a different
310   // value. Adding symbol tables do not result in changes in the base table.
311   virtual void AddTable(const SymbolTable& table);
312 
313   // return the name of the symbol table
Name()314   virtual const string& Name() const {
315     return impl_->Name();
316   }
317 
318   // Return the label-agnostic MD5 check-sum for this table.  All new symbols
319   // added to the table will result in an updated checksum.
320   // DEPRECATED.
CheckSum()321   virtual string CheckSum() const {
322     return impl_->CheckSum();
323   }
324 
325   // Same as CheckSum(), but this returns an label-dependent version.
LabeledCheckSum()326   virtual string LabeledCheckSum() const {
327     return impl_->LabeledCheckSum();
328   }
329 
Write(ostream & strm)330   virtual bool Write(ostream &strm) const {
331     return impl_->Write(strm);
332   }
333 
Write(const string & filename)334   bool Write(const string& filename) const {
335     ofstream strm(filename.c_str(), ofstream::out | ofstream::binary);
336     if (!strm) {
337       LOG(ERROR) << "SymbolTable::Write: Can't open file " << filename;
338       return false;
339     }
340     return Write(strm);
341   }
342 
343   // Dump an ascii text representation of the symbol table via a stream
344   virtual bool WriteText(ostream &strm) const;
345 
346   // Dump an ascii text representation of the symbol table
WriteText(const string & filename)347   bool WriteText(const string& filename) const {
348     ofstream strm(filename.c_str());
349     if (!strm) {
350       LOG(ERROR) << "SymbolTable::WriteText: Can't open file " << filename;
351       return false;
352     }
353     return WriteText(strm);
354   }
355 
356   // Return the string associated with the key. If the key is out of
357   // range (<0, >max), log error and return an empty string.
Find(int64 key)358   virtual string Find(int64 key) const {
359     return impl_->Find(key);
360   }
361 
362   // Return the key associated with the symbol. If the symbol
363   // does not exists, log error and  return SymbolTable::kNoSymbol
Find(const string & symbol)364   virtual int64 Find(const string& symbol) const {
365     return impl_->Find(symbol);
366   }
367 
368   // Return the key associated with the symbol. If the symbol
369   // does not exists, log error and  return SymbolTable::kNoSymbol
Find(const char * symbol)370   virtual int64 Find(const char* symbol) const {
371     return impl_->Find(symbol);
372   }
373 
374   // Return the current available key (i.e highest key number+1) in
375   // the symbol table
AvailableKey(void)376   virtual int64 AvailableKey(void) const {
377     return impl_->AvailableKey();
378   }
379 
380   // Return the current number of symbols in table (not necessarily
381   // equal to AvailableKey())
NumSymbols(void)382   virtual size_t NumSymbols(void) const {
383     return impl_->NumSymbols();
384   }
385 
GetNthKey(ssize_t pos)386   virtual int64 GetNthKey(ssize_t pos) const {
387     return impl_->GetNthKey(pos);
388   }
389 
390  private:
SymbolTable(SymbolTableImpl * impl)391   explicit SymbolTable(SymbolTableImpl* impl) : impl_(impl) {}
392 
MutateCheck()393   void MutateCheck() {
394     // Copy on write
395     if (impl_->RefCount() > 1) {
396       impl_->DecrRefCount();
397       impl_ = new SymbolTableImpl(*impl_);
398     }
399   }
400 
Impl()401   const SymbolTableImpl* Impl() const {
402     return impl_;
403   }
404 
405  private:
406   SymbolTableImpl* impl_;
407 
408   void operator=(const SymbolTable &table);  // disallow
409 };
410 
411 
412 //
413 // \class SymbolTableIterator
414 // \brief Iterator class for symbols in a symbol table
415 class SymbolTableIterator {
416  public:
SymbolTableIterator(const SymbolTable & table)417   SymbolTableIterator(const SymbolTable& table)
418       : table_(table),
419         pos_(0),
420         nsymbols_(table.NumSymbols()),
421         key_(table.GetNthKey(0)) { }
422 
~SymbolTableIterator()423   ~SymbolTableIterator() { }
424 
425   // is iterator done
Done(void)426   bool Done(void) {
427     return (pos_ == nsymbols_);
428   }
429 
430   // return the Value() of the current symbol (int64 key)
Value(void)431   int64 Value(void) {
432     return key_;
433   }
434 
435   // return the string of the current symbol
Symbol(void)436   string Symbol(void) {
437     return table_.Find(key_);
438   }
439 
440   // advance iterator forward
Next(void)441   void Next(void) {
442     ++pos_;
443     if (pos_ < nsymbols_) key_ = table_.GetNthKey(pos_);
444   }
445 
446   // reset iterator
Reset(void)447   void Reset(void) {
448     pos_ = 0;
449     key_ = table_.GetNthKey(0);
450   }
451 
452  private:
453   const SymbolTable& table_;
454   ssize_t pos_;
455   size_t nsymbols_;
456   int64 key_;
457 };
458 
459 
460 // Tests compatibilty between two sets of symbol tables
461 inline bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
462                           bool warning = true) {
463   if (!FLAGS_fst_compat_symbols) {
464     return true;
465   } else if (!syms1 && !syms2) {
466     return true;
467   } else if (syms1 && !syms2) {
468     if (warning)
469       LOG(WARNING) <<
470           "CompatSymbols: first symbol table present but second missing";
471     return false;
472   } else if (!syms1 && syms2) {
473     if (warning)
474       LOG(WARNING) <<
475           "CompatSymbols: second symbol table present but first missing";
476     return false;
477   } else if (syms1->LabeledCheckSum() != syms2->LabeledCheckSum()) {
478     if (warning)
479       LOG(WARNING) << "CompatSymbols: Symbol table check sums do not match";
480     return false;
481   } else {
482     return true;
483   }
484 }
485 
486 
487 // Relabels a symbol table as specified by the input vector of pairs
488 // (old label, new label). The new symbol table only retains symbols
489 // for which a relabeling is *explicitely* specified.
490 // TODO(allauzen): consider adding options to allow for some form
491 // of implicit identity relabeling.
492 template <class Label>
RelabelSymbolTable(const SymbolTable * table,const vector<pair<Label,Label>> & pairs)493 SymbolTable *RelabelSymbolTable(const SymbolTable *table,
494                                 const vector<pair<Label, Label> > &pairs) {
495   SymbolTable *new_table = new SymbolTable(
496       table->Name().empty() ? string() :
497       (string("relabeled_") + table->Name()));
498 
499   for (size_t i = 0; i < pairs.size(); ++i)
500     new_table->AddSymbol(table->Find(pairs[i].first), pairs[i].second);
501 
502   return new_table;
503 }
504 
505 }  // namespace fst
506 
507 #endif  // FST_LIB_SYMBOL_TABLE_H__
508