• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 //
14 // Copyright 2005-2010 Google, Inc.
15 // Author: sorenj@google.com (Jeffrey Sorensen)
16 //
17 #ifndef FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
18 #define FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
19 
20 #include <stddef.h>
21 #include <string.h>
22 #include <algorithm>
23 #include <string>
24 #include <vector>
25 using std::vector;
26 
27 #include <fst/compat.h>
28 #include <fst/fstlib.h>
29 #include <fst/mapped-file.h>
30 #include <fst/extensions/ngram/bitmap-index.h>
31 
32 // NgramFst implements an n-gram language model based upon the LOUDS data
33 // structure.  Please refer to "Unary Data Structures for Language Models"
34 // http://research.google.com/pubs/archive/37218.pdf
35 
36 namespace fst {
37 template <class A> class NGramFst;
38 template <class A> class NGramFstMatcher;
39 
40 // Instance data containing mutable state for bookkeeping repeated access to
41 // the same state.
42 template <class A>
43 struct NGramFstInst {
44   typedef typename A::Label Label;
45   typedef typename A::StateId StateId;
46   typedef typename A::Weight Weight;
47   StateId state_;
48   size_t num_futures_;
49   size_t offset_;
50   size_t node_;
51   StateId node_state_;
52   vector<Label> context_;
53   StateId context_state_;
NGramFstInstNGramFstInst54   NGramFstInst()
55       : state_(kNoStateId), node_state_(kNoStateId),
56         context_state_(kNoStateId) { }
57 };
58 
59 // Implementation class for LOUDS based NgramFst interface
60 template <class A>
61 class NGramFstImpl : public FstImpl<A> {
62   using FstImpl<A>::SetInputSymbols;
63   using FstImpl<A>::SetOutputSymbols;
64   using FstImpl<A>::SetType;
65   using FstImpl<A>::WriteHeader;
66 
67   friend class ArcIterator<NGramFst<A> >;
68   friend class NGramFstMatcher<A>;
69 
70  public:
71   using FstImpl<A>::InputSymbols;
72   using FstImpl<A>::SetProperties;
73   using FstImpl<A>::Properties;
74 
75   typedef A Arc;
76   typedef typename A::Label Label;
77   typedef typename A::StateId StateId;
78   typedef typename A::Weight Weight;
79 
NGramFstImpl()80   NGramFstImpl() : data_region_(0), data_(0), owned_(false) {
81     SetType("ngram");
82     SetInputSymbols(NULL);
83     SetOutputSymbols(NULL);
84     SetProperties(kStaticProperties);
85   }
86 
87   NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out);
88 
~NGramFstImpl()89   ~NGramFstImpl() {
90     if (owned_) {
91       delete [] data_;
92     }
93     delete data_region_;
94   }
95 
Read(istream & strm,const FstReadOptions & opts)96   static NGramFstImpl<A>* Read(istream &strm,  // NOLINT
97                                const FstReadOptions &opts) {
98     NGramFstImpl<A>* impl = new NGramFstImpl();
99     FstHeader hdr;
100     if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return 0;
101     uint64 num_states, num_futures, num_final;
102     const size_t offset = sizeof(num_states) + sizeof(num_futures) +
103         sizeof(num_final);
104     // Peek at num_states and num_futures to see how much more needs to be read.
105     strm.read(reinterpret_cast<char *>(&num_states), sizeof(num_states));
106     strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures));
107     strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final));
108     size_t size = Storage(num_states, num_futures, num_final);
109     MappedFile *data_region = MappedFile::Allocate(size);
110     char *data = reinterpret_cast<char *>(data_region->mutable_data());
111     // Copy num_states, num_futures and num_final back into data.
112     memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states));
113     memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures),
114            sizeof(num_futures));
115     memcpy(data + sizeof(num_states) + sizeof(num_futures),
116            reinterpret_cast<char *>(&num_final), sizeof(num_final));
117     strm.read(data + offset, size - offset);
118     if (!strm) {
119       delete impl;
120       return NULL;
121     }
122     impl->Init(data, false, data_region);
123     return impl;
124   }
125 
Write(ostream & strm,const FstWriteOptions & opts)126   bool Write(ostream &strm,   // NOLINT
127              const FstWriteOptions &opts) const {
128     FstHeader hdr;
129     hdr.SetStart(Start());
130     hdr.SetNumStates(num_states_);
131     WriteHeader(strm, opts, kFileVersion, &hdr);
132     strm.write(data_, StorageSize());
133     return strm;
134   }
135 
Start()136   StateId Start() const {
137     return 1;
138   }
139 
Final(StateId state)140   Weight Final(StateId state) const {
141     if (final_index_.Get(state)) {
142       return final_probs_[final_index_.Rank1(state)];
143     } else {
144       return Weight::Zero();
145     }
146   }
147 
148   size_t NumArcs(StateId state, NGramFstInst<A> *inst = NULL) const {
149     if (inst == NULL) {
150       const size_t next_zero = future_index_.Select0(state + 1);
151       const size_t this_zero = future_index_.Select0(state);
152       return next_zero - this_zero - 1;
153     }
154     SetInstFuture(state, inst);
155     return inst->num_futures_ + ((state == 0) ? 0 : 1);
156   }
157 
NumInputEpsilons(StateId state)158   size_t NumInputEpsilons(StateId state) const {
159     // State 0 has no parent, thus no backoff.
160     if (state == 0) return 0;
161     return 1;
162   }
163 
NumOutputEpsilons(StateId state)164   size_t NumOutputEpsilons(StateId state) const {
165     return NumInputEpsilons(state);
166   }
167 
NumStates()168   StateId NumStates() const {
169     return num_states_;
170   }
171 
InitStateIterator(StateIteratorData<A> * data)172   void InitStateIterator(StateIteratorData<A>* data) const {
173     data->base = 0;
174     data->nstates = num_states_;
175   }
176 
Storage(uint64 num_states,uint64 num_futures,uint64 num_final)177   static size_t Storage(uint64 num_states, uint64 num_futures,
178                         uint64 num_final) {
179     uint64 b64;
180     Weight weight;
181     Label label;
182     size_t offset = sizeof(num_states) + sizeof(num_futures) +
183         sizeof(num_final);
184     offset += sizeof(b64) * (
185         BitmapIndex::StorageSize(num_states * 2 + 1) +
186         BitmapIndex::StorageSize(num_futures + num_states + 1) +
187         BitmapIndex::StorageSize(num_states));
188     offset += (num_states + 1) * sizeof(label) + num_futures * sizeof(label);
189     // Pad for alignemnt, see
190     // http://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding
191     offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1);
192     offset += (num_states + 1) * sizeof(weight) + num_final * sizeof(weight) +
193         (num_futures + 1) * sizeof(weight);
194     return offset;
195   }
196 
SetInstFuture(StateId state,NGramFstInst<A> * inst)197   void SetInstFuture(StateId state, NGramFstInst<A> *inst) const {
198     if (inst->state_ != state) {
199       inst->state_ = state;
200       const size_t next_zero = future_index_.Select0(state + 1);
201       const size_t this_zero = future_index_.Select0(state);
202       inst->num_futures_ = next_zero - this_zero - 1;
203       inst->offset_ = future_index_.Rank1(future_index_.Select0(state) + 1);
204     }
205   }
206 
SetInstNode(NGramFstInst<A> * inst)207   void SetInstNode(NGramFstInst<A> *inst) const {
208     if (inst->node_state_ != inst->state_) {
209       inst->node_state_ = inst->state_;
210       inst->node_ = context_index_.Select1(inst->state_);
211     }
212   }
213 
SetInstContext(NGramFstInst<A> * inst)214   void SetInstContext(NGramFstInst<A> *inst) const {
215     SetInstNode(inst);
216     if (inst->context_state_ != inst->state_) {
217       inst->context_state_ = inst->state_;
218       inst->context_.clear();
219       size_t node = inst->node_;
220       while (node != 0) {
221         inst->context_.push_back(context_words_[context_index_.Rank1(node)]);
222         node = context_index_.Select1(context_index_.Rank0(node) - 1);
223       }
224     }
225   }
226 
227   // Access to the underlying representation
GetData(size_t * data_size)228   const char* GetData(size_t* data_size) const {
229     *data_size = StorageSize();
230     return data_;
231   }
232 
233   void Init(const char* data, bool owned, MappedFile *file = 0);
234 
GetContext(StateId s,NGramFstInst<A> * inst)235   const vector<Label> &GetContext(StateId s, NGramFstInst<A> *inst) const {
236     SetInstFuture(s, inst);
237     SetInstContext(inst);
238     return inst->context_;
239   }
240 
StorageSize()241   size_t StorageSize() const {
242     return Storage(num_states_, num_futures_, num_final_);
243   }
244 
245   void GetStates(const vector<Label>& context, vector<StateId> *states) const;
246 
247  private:
248   StateId Transition(const vector<Label> &context, Label future) const;
249 
250   // Properties always true for this Fst class.
251   static const uint64 kStaticProperties = kAcceptor | kIDeterministic |
252       kODeterministic | kEpsilons | kIEpsilons | kOEpsilons | kILabelSorted |
253       kOLabelSorted | kWeighted | kCyclic | kInitialAcyclic | kNotTopSorted |
254       kAccessible | kCoAccessible | kNotString | kExpanded;
255   // Current file format version.
256   static const int kFileVersion = 4;
257   // Minimum file format version supported.
258   static const int kMinFileVersion = 4;
259 
260   MappedFile *data_region_;
261   const char* data_;
262   bool owned_;  // True if we own data_
263   uint64 num_states_, num_futures_, num_final_;
264   size_t root_num_children_;
265   const Label *root_children_;
266   size_t root_first_child_;
267   // borrowed references
268   const uint64 *context_, *future_, *final_;
269   const Label *context_words_, *future_words_;
270   const Weight *backoff_, *final_probs_, *future_probs_;
271   BitmapIndex context_index_;
272   BitmapIndex future_index_;
273   BitmapIndex final_index_;
274 
275   void operator=(const NGramFstImpl<A> &);  // Disallow
276 };
277 
278 template<typename A>
NGramFstImpl(const Fst<A> & fst,vector<StateId> * order_out)279 NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out)
280     : data_region_(0), data_(0), owned_(false) {
281   typedef A Arc;
282   typedef typename Arc::Label Label;
283   typedef typename Arc::Weight Weight;
284   typedef typename Arc::StateId StateId;
285   SetType("ngram");
286   SetInputSymbols(fst.InputSymbols());
287   SetOutputSymbols(fst.OutputSymbols());
288   SetProperties(kStaticProperties);
289 
290   // Check basic requirements for an OpenGRM language model Fst.
291   int64 props = kAcceptor | kIDeterministic | kIEpsilons | kILabelSorted;
292   if (fst.Properties(props, true) != props) {
293     FSTERROR() << "NGramFst only accepts OpenGRM langauge models as input";
294     SetProperties(kError, kError);
295     return;
296   }
297 
298   int64 num_states = CountStates(fst);
299   Label* context = new Label[num_states];
300 
301   // Find the unigram state by starting from the start state, following
302   // epsilons.
303   StateId unigram = fst.Start();
304   while (1) {
305     if (unigram == kNoStateId) {
306       FSTERROR() << "Could not identify unigram state.";
307       SetProperties(kError, kError);
308       return;
309     }
310     ArcIterator<Fst<A> > aiter(fst, unigram);
311     if (aiter.Done()) {
312       LOG(WARNING) << "Unigram state " << unigram << " has no arcs.";
313       break;
314     }
315     if (aiter.Value().ilabel != 0) break;
316     unigram = aiter.Value().nextstate;
317   }
318 
319   // Each state's context is determined by the subtree it is under from the
320   // unigram state.
321   queue<pair<StateId, Label> > label_queue;
322   vector<bool> visited(num_states);
323   // Force an epsilon link to the start state.
324   label_queue.push(make_pair(fst.Start(), 0));
325   for (ArcIterator<Fst<A> > aiter(fst, unigram);
326        !aiter.Done(); aiter.Next()) {
327     label_queue.push(make_pair(aiter.Value().nextstate, aiter.Value().ilabel));
328   }
329   // investigate states in breadth first fashion to assign context words.
330   while (!label_queue.empty()) {
331     pair<StateId, Label> &now = label_queue.front();
332     if (!visited[now.first]) {
333       context[now.first] = now.second;
334       visited[now.first] = true;
335       for (ArcIterator<Fst<A> > aiter(fst, now.first);
336            !aiter.Done(); aiter.Next()) {
337         const Arc &arc = aiter.Value();
338         if (arc.ilabel != 0) {
339           label_queue.push(make_pair(arc.nextstate, now.second));
340         }
341       }
342     }
343     label_queue.pop();
344   }
345   visited.clear();
346 
347   // The arc from the start state should be assigned an epsilon to put it
348   // in front of the all other labels (which makes Start state 1 after
349   // unigram which is state 0).
350   context[fst.Start()] = 0;
351 
352   // Build the tree of contexts fst by reversing the epsilon arcs from fst.
353   VectorFst<Arc> context_fst;
354   uint64 num_final = 0;
355   for (int i = 0; i < num_states; ++i) {
356     if (fst.Final(i) != Weight::Zero()) {
357       ++num_final;
358     }
359     context_fst.SetFinal(context_fst.AddState(), fst.Final(i));
360   }
361   context_fst.SetStart(unigram);
362   context_fst.SetInputSymbols(fst.InputSymbols());
363   context_fst.SetOutputSymbols(fst.OutputSymbols());
364   int64 num_context_arcs = 0;
365   int64 num_futures = 0;
366   for (StateIterator<Fst<A> > siter(fst); !siter.Done(); siter.Next()) {
367     const StateId &state = siter.Value();
368     num_futures += fst.NumArcs(state) - fst.NumInputEpsilons(state);
369     ArcIterator<Fst<A> > aiter(fst, state);
370     if (!aiter.Done()) {
371       const Arc &arc = aiter.Value();
372       // this arc goes from state to arc.nextstate, so create an arc from
373       // arc.nextstate to state to reverse it.
374       if (arc.ilabel == 0) {
375         context_fst.AddArc(arc.nextstate, Arc(context[state], context[state],
376                                               arc.weight, state));
377         num_context_arcs++;
378       }
379     }
380   }
381   if (num_context_arcs != context_fst.NumStates() - 1) {
382     FSTERROR() << "Number of contexts arcs != number of states - 1";
383     SetProperties(kError, kError);
384     return;
385   }
386   if (context_fst.NumStates() != num_states) {
387     FSTERROR() << "Number of contexts != number of states";
388     SetProperties(kError, kError);
389     return;
390   }
391   int64 context_props = context_fst.Properties(kIDeterministic |
392                                                kILabelSorted, true);
393   if (!(context_props & kIDeterministic)) {
394     FSTERROR() << "Input fst is not structured properly";
395     SetProperties(kError, kError);
396     return;
397   }
398   if (!(context_props & kILabelSorted)) {
399      ArcSort(&context_fst, ILabelCompare<Arc>());
400   }
401 
402   delete [] context;
403 
404   uint64 b64;
405   Weight weight;
406   Label label = kNoLabel;
407   const size_t storage = Storage(num_states, num_futures, num_final);
408   MappedFile *data_region = MappedFile::Allocate(storage);
409   char *data = reinterpret_cast<char *>(data_region->mutable_data());
410   memset(data, 0, storage);
411   size_t offset = 0;
412   memcpy(data + offset, reinterpret_cast<char *>(&num_states),
413          sizeof(num_states));
414   offset += sizeof(num_states);
415   memcpy(data + offset, reinterpret_cast<char *>(&num_futures),
416          sizeof(num_futures));
417   offset += sizeof(num_futures);
418   memcpy(data + offset, reinterpret_cast<char *>(&num_final),
419          sizeof(num_final));
420   offset += sizeof(num_final);
421   uint64* context_bits = reinterpret_cast<uint64*>(data + offset);
422   offset += BitmapIndex::StorageSize(num_states * 2 + 1) * sizeof(b64);
423   uint64* future_bits = reinterpret_cast<uint64*>(data + offset);
424   offset +=
425       BitmapIndex::StorageSize(num_futures + num_states + 1) * sizeof(b64);
426   uint64* final_bits = reinterpret_cast<uint64*>(data + offset);
427   offset += BitmapIndex::StorageSize(num_states) * sizeof(b64);
428   Label* context_words = reinterpret_cast<Label*>(data + offset);
429   offset += (num_states + 1) * sizeof(label);
430   Label* future_words = reinterpret_cast<Label*>(data + offset);
431   offset += num_futures * sizeof(label);
432   offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1);
433   Weight* backoff = reinterpret_cast<Weight*>(data + offset);
434   offset += (num_states + 1) * sizeof(weight);
435   Weight* final_probs = reinterpret_cast<Weight*>(data + offset);
436   offset += num_final * sizeof(weight);
437   Weight* future_probs = reinterpret_cast<Weight*>(data + offset);
438   int64 context_arc = 0, future_arc = 0, context_bit = 0, future_bit = 0,
439         final_bit = 0;
440 
441   // pseudo-root bits
442   BitmapIndex::Set(context_bits, context_bit++);
443   ++context_bit;
444   context_words[context_arc] = label;
445   backoff[context_arc] = Weight::Zero();
446   context_arc++;
447 
448   ++future_bit;
449   if (order_out) {
450     order_out->clear();
451     order_out->resize(num_states);
452   }
453 
454   queue<StateId> context_q;
455   context_q.push(context_fst.Start());
456   StateId state_number = 0;
457   while (!context_q.empty()) {
458     const StateId &state = context_q.front();
459     if (order_out) {
460       (*order_out)[state] = state_number;
461     }
462 
463     const Weight &final = context_fst.Final(state);
464     if (final != Weight::Zero()) {
465       BitmapIndex::Set(final_bits, state_number);
466       final_probs[final_bit] = final;
467       ++final_bit;
468     }
469 
470     for (ArcIterator<VectorFst<A> > aiter(context_fst, state);
471          !aiter.Done(); aiter.Next()) {
472       const Arc &arc = aiter.Value();
473       context_words[context_arc] = arc.ilabel;
474       backoff[context_arc] = arc.weight;
475       ++context_arc;
476       BitmapIndex::Set(context_bits, context_bit++);
477       context_q.push(arc.nextstate);
478     }
479     ++context_bit;
480 
481     for (ArcIterator<Fst<A> > aiter(fst, state); !aiter.Done(); aiter.Next()) {
482       const Arc &arc = aiter.Value();
483       if (arc.ilabel != 0) {
484         future_words[future_arc] = arc.ilabel;
485         future_probs[future_arc] = arc.weight;
486         ++future_arc;
487         BitmapIndex::Set(future_bits, future_bit++);
488       }
489     }
490     ++future_bit;
491     ++state_number;
492     context_q.pop();
493   }
494 
495   if ((state_number !=  num_states) ||
496       (context_bit != num_states * 2 + 1) ||
497       (context_arc != num_states) ||
498       (future_arc != num_futures) ||
499       (future_bit != num_futures + num_states + 1) ||
500       (final_bit != num_final)) {
501     FSTERROR() << "Structure problems detected during construction";
502     SetProperties(kError, kError);
503     return;
504   }
505 
506   Init(data, false, data_region);
507 }
508 
509 template<typename A>
Init(const char * data,bool owned,MappedFile * data_region)510 inline void NGramFstImpl<A>::Init(const char* data, bool owned,
511                                   MappedFile *data_region) {
512   if (owned_) {
513     delete [] data_;
514   }
515   delete data_region_;
516   data_region_ = data_region;
517   owned_ = owned;
518   data_ = data;
519   size_t offset = 0;
520   num_states_ = *(reinterpret_cast<const uint64*>(data_ + offset));
521   offset += sizeof(num_states_);
522   num_futures_ = *(reinterpret_cast<const uint64*>(data_ + offset));
523   offset += sizeof(num_futures_);
524   num_final_ = *(reinterpret_cast<const uint64*>(data_ + offset));
525   offset += sizeof(num_final_);
526   uint64 bits;
527   size_t context_bits = num_states_ * 2 + 1;
528   size_t future_bits = num_futures_ + num_states_ + 1;
529   context_ = reinterpret_cast<const uint64*>(data_ + offset);
530   offset += BitmapIndex::StorageSize(context_bits) * sizeof(bits);
531   future_ = reinterpret_cast<const uint64*>(data_ + offset);
532   offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits);
533   final_ = reinterpret_cast<const uint64*>(data_ + offset);
534   offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits);
535   context_words_ = reinterpret_cast<const Label*>(data_ + offset);
536   offset += (num_states_ + 1) * sizeof(*context_words_);
537   future_words_ = reinterpret_cast<const Label*>(data_ + offset);
538   offset += num_futures_ * sizeof(*future_words_);
539   offset = (offset + sizeof(*backoff_) - 1) & ~(sizeof(*backoff_) - 1);
540   backoff_ = reinterpret_cast<const Weight*>(data_ + offset);
541   offset += (num_states_ + 1) * sizeof(*backoff_);
542   final_probs_ = reinterpret_cast<const Weight*>(data_ + offset);
543   offset += num_final_ * sizeof(*final_probs_);
544   future_probs_ = reinterpret_cast<const Weight*>(data_ + offset);
545 
546   context_index_.BuildIndex(context_, context_bits);
547   future_index_.BuildIndex(future_, future_bits);
548   final_index_.BuildIndex(final_, num_states_);
549 
550   const size_t node_rank = context_index_.Rank1(0);
551   root_first_child_ = context_index_.Select0(node_rank) + 1;
552   if (context_index_.Get(root_first_child_) == false) {
553     FSTERROR() << "Missing unigrams";
554     SetProperties(kError, kError);
555     return;
556   }
557   const size_t last_child = context_index_.Select0(node_rank + 1) - 1;
558   root_num_children_ = last_child - root_first_child_ + 1;
559   root_children_ = context_words_ + context_index_.Rank1(root_first_child_);
560 }
561 
562 template<typename A>
Transition(const vector<Label> & context,Label future)563 inline typename A::StateId NGramFstImpl<A>::Transition(
564         const vector<Label> &context, Label future) const {
565   const Label *children = root_children_;
566   const Label *loc = lower_bound(children, children + root_num_children_,
567                                  future);
568   if (loc == children + root_num_children_ || *loc != future) {
569     return context_index_.Rank1(0);
570   }
571   size_t node = root_first_child_ + loc - children;
572   size_t node_rank = context_index_.Rank1(node);
573   size_t first_child = context_index_.Select0(node_rank) + 1;
574   if (context_index_.Get(first_child) == false) {
575     return context_index_.Rank1(node);
576   }
577   size_t last_child = context_index_.Select0(node_rank + 1) - 1;
578   for (int word = context.size() - 1; word >= 0; --word) {
579     children = context_words_ + context_index_.Rank1(first_child);
580     loc = lower_bound(children, children + last_child - first_child + 1,
581                       context[word]);
582     if (loc == children + last_child - first_child + 1 ||
583         *loc != context[word]) {
584       break;
585     }
586     node = first_child + loc - children;
587     node_rank = context_index_.Rank1(node);
588     first_child = context_index_.Select0(node_rank) + 1;
589     if (context_index_.Get(first_child) == false) break;
590     last_child = context_index_.Select0(node_rank + 1) - 1;
591   }
592   return context_index_.Rank1(node);
593 }
594 
595 template<typename A>
GetStates(const vector<Label> & context,vector<typename A::StateId> * states)596 inline void NGramFstImpl<A>::GetStates(
597     const vector<Label> &context,
598     vector<typename A::StateId>* states) const {
599   states->clear();
600   states->push_back(0);
601   typename vector<Label>::const_reverse_iterator cit = context.rbegin();
602   const Label *children = root_children_;
603   const Label *loc = lower_bound(children, children + root_num_children_, *cit);
604   if (loc == children + root_num_children_ || *loc != *cit) return;
605   size_t node = root_first_child_ + loc - children;
606   states->push_back(context_index_.Rank1(node));
607   if (context.size() == 1) return;
608   size_t node_rank = context_index_.Rank1(node);
609   size_t first_child = context_index_.Select0(node_rank) + 1;
610   ++cit;
611   if (context_index_.Get(first_child) != false) {
612     size_t last_child = context_index_.Select0(node_rank + 1) - 1;
613     while (cit != context.rend()) {
614       children = context_words_ + context_index_.Rank1(first_child);
615       loc = lower_bound(children, children + last_child - first_child + 1,
616                         *cit);
617       if (loc == children + last_child - first_child + 1 || *loc != *cit) {
618         break;
619       }
620       ++cit;
621       node = first_child + loc - children;
622       states->push_back(context_index_.Rank1(node));
623       node_rank = context_index_.Rank1(node);
624       first_child = context_index_.Select0(node_rank) + 1;
625       if (context_index_.Get(first_child) == false) break;
626       last_child = context_index_.Select0(node_rank + 1) - 1;
627     }
628   }
629 }
630 
631 /*****************************************************************************/
632 template<class A>
633 class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > {
634   friend class ArcIterator<NGramFst<A> >;
635   friend class NGramFstMatcher<A>;
636 
637  public:
638   typedef A Arc;
639   typedef typename A::StateId StateId;
640   typedef typename A::Label Label;
641   typedef typename A::Weight Weight;
642   typedef NGramFstImpl<A> Impl;
643 
NGramFst(const Fst<A> & dst)644   explicit NGramFst(const Fst<A> &dst)
645       : ImplToExpandedFst<Impl>(new Impl(dst, NULL)) {}
646 
NGramFst(const Fst<A> & fst,vector<StateId> * order_out)647   NGramFst(const Fst<A> &fst, vector<StateId>* order_out)
648       : ImplToExpandedFst<Impl>(new Impl(fst, order_out)) {}
649 
650   // Because the NGramFstImpl is a const stateless data structure, there
651   // is never a need to do anything beside copy the reference.
652   NGramFst(const NGramFst<A> &fst, bool safe = false)
653       : ImplToExpandedFst<Impl>(fst, false) {}
654 
NGramFst()655   NGramFst() : ImplToExpandedFst<Impl>(new Impl()) {}
656 
657   // Non-standard constructor to initialize NGramFst directly from data.
NGramFst(const char * data,bool owned)658   NGramFst(const char* data, bool owned) : ImplToExpandedFst<Impl>(new Impl()) {
659     GetImpl()->Init(data, owned, NULL);
660   }
661 
662   // Get method that gets the data associated with Init().
GetData(size_t * data_size)663   const char* GetData(size_t* data_size) const {
664     return GetImpl()->GetData(data_size);
665   }
666 
GetContext(StateId s)667   const vector<Label> GetContext(StateId s) const {
668     return GetImpl()->GetContext(s, &inst_);
669   }
670 
671   // Consumes as much as possible of context from right to left, returns the
672   // the states corresponding to the increasingly conditioned input sequence.
GetStates(const vector<Label> & context,vector<StateId> * state)673   void GetStates(const vector<Label>& context, vector<StateId> *state) const {
674     return GetImpl()->GetStates(context, state);
675   }
676 
NumArcs(StateId s)677   virtual size_t NumArcs(StateId s) const {
678     return GetImpl()->NumArcs(s, &inst_);
679   }
680 
681   virtual NGramFst<A>* Copy(bool safe = false) const {
682     return new NGramFst(*this, safe);
683   }
684 
Read(istream & strm,const FstReadOptions & opts)685   static NGramFst<A>* Read(istream &strm, const FstReadOptions &opts) {
686     Impl* impl = Impl::Read(strm, opts);
687     return impl ? new NGramFst<A>(impl) : 0;
688   }
689 
Read(const string & filename)690   static NGramFst<A>* Read(const string &filename) {
691     if (!filename.empty()) {
692       ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
693       if (!strm) {
694         LOG(ERROR) << "NGramFst::Read: Can't open file: " << filename;
695         return 0;
696       }
697       return Read(strm, FstReadOptions(filename));
698     } else {
699       return Read(cin, FstReadOptions("standard input"));
700     }
701   }
702 
Write(ostream & strm,const FstWriteOptions & opts)703   virtual bool Write(ostream &strm, const FstWriteOptions &opts) const {
704     return GetImpl()->Write(strm, opts);
705   }
706 
Write(const string & filename)707   virtual bool Write(const string &filename) const {
708     return Fst<A>::WriteFile(filename);
709   }
710 
InitStateIterator(StateIteratorData<A> * data)711   virtual inline void InitStateIterator(StateIteratorData<A>* data) const {
712     GetImpl()->InitStateIterator(data);
713   }
714 
715   virtual inline void InitArcIterator(
716       StateId s, ArcIteratorData<A>* data) const;
717 
InitMatcher(MatchType match_type)718   virtual MatcherBase<A>* InitMatcher(MatchType match_type) const {
719     return new NGramFstMatcher<A>(*this, match_type);
720   }
721 
StorageSize()722   size_t StorageSize() const {
723     return GetImpl()->StorageSize();
724   }
725 
726  private:
NGramFst(Impl * impl)727   explicit NGramFst(Impl* impl) : ImplToExpandedFst<Impl>(impl) {}
728 
GetImpl()729   Impl* GetImpl() const {
730     return
731         ImplToExpandedFst<Impl, ExpandedFst<A> >::GetImpl();
732   }
733 
734   void SetImpl(Impl* impl, bool own_impl = true) {
735     ImplToExpandedFst<Impl, Fst<A> >::SetImpl(impl, own_impl);
736   }
737 
738   mutable NGramFstInst<A> inst_;
739 };
740 
741 template <class A> inline void
InitArcIterator(StateId s,ArcIteratorData<A> * data)742 NGramFst<A>::InitArcIterator(StateId s, ArcIteratorData<A>* data) const {
743   GetImpl()->SetInstFuture(s, &inst_);
744   GetImpl()->SetInstNode(&inst_);
745   data->base = new ArcIterator<NGramFst<A> >(*this, s);
746 }
747 
748 /*****************************************************************************/
749 template <class A>
750 class NGramFstMatcher : public MatcherBase<A> {
751  public:
752   typedef A Arc;
753   typedef typename A::Label Label;
754   typedef typename A::StateId StateId;
755   typedef typename A::Weight Weight;
756 
NGramFstMatcher(const NGramFst<A> & fst,MatchType match_type)757   NGramFstMatcher(const NGramFst<A> &fst, MatchType match_type)
758       : fst_(fst), inst_(fst.inst_), match_type_(match_type),
759         current_loop_(false),
760         loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
761     if (match_type_ == MATCH_OUTPUT) {
762       swap(loop_.ilabel, loop_.olabel);
763     }
764   }
765 
766   NGramFstMatcher(const NGramFstMatcher<A> &matcher, bool safe = false)
767       : fst_(matcher.fst_), inst_(matcher.inst_),
768         match_type_(matcher.match_type_), current_loop_(false),
769         loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
770     if (match_type_ == MATCH_OUTPUT) {
771       swap(loop_.ilabel, loop_.olabel);
772     }
773   }
774 
775   virtual NGramFstMatcher<A>* Copy(bool safe = false) const {
776     return new NGramFstMatcher<A>(*this, safe);
777   }
778 
Type(bool test)779   virtual MatchType Type(bool test) const {
780     return match_type_;
781   }
782 
GetFst()783   virtual const Fst<A> &GetFst() const {
784     return fst_;
785   }
786 
Properties(uint64 props)787   virtual uint64 Properties(uint64 props) const {
788     return props;
789   }
790 
791  private:
SetState_(StateId s)792   virtual void SetState_(StateId s) {
793     fst_.GetImpl()->SetInstFuture(s, &inst_);
794     current_loop_ = false;
795   }
796 
Find_(Label label)797   virtual bool Find_(Label label) {
798     const Label nolabel = kNoLabel;
799     done_ = true;
800     if (label == 0 || label == nolabel) {
801       if (label == 0) {
802         current_loop_ = true;
803         loop_.nextstate = inst_.state_;
804       }
805       // The unigram state has no epsilon arc.
806       if (inst_.state_ != 0) {
807         arc_.ilabel = arc_.olabel = 0;
808         fst_.GetImpl()->SetInstNode(&inst_);
809         arc_.nextstate = fst_.GetImpl()->context_index_.Rank1(
810             fst_.GetImpl()->context_index_.Select1(
811                 fst_.GetImpl()->context_index_.Rank0(inst_.node_) - 1));
812         arc_.weight = fst_.GetImpl()->backoff_[inst_.state_];
813         done_ = false;
814       }
815     } else {
816       const Label *start = fst_.GetImpl()->future_words_ + inst_.offset_;
817       const Label *end = start + inst_.num_futures_;
818       const Label* search = lower_bound(start, end, label);
819       if (search != end && *search == label) {
820         size_t state = search - start;
821         arc_.ilabel = arc_.olabel = label;
822         arc_.weight = fst_.GetImpl()->future_probs_[inst_.offset_ + state];
823         fst_.GetImpl()->SetInstContext(&inst_);
824         arc_.nextstate = fst_.GetImpl()->Transition(inst_.context_, label);
825         done_ = false;
826       }
827     }
828     return !Done_();
829   }
830 
Done_()831   virtual bool Done_() const {
832     return !current_loop_ && done_;
833   }
834 
Value_()835   virtual const Arc& Value_() const {
836     return (current_loop_) ? loop_ : arc_;
837   }
838 
Next_()839   virtual void Next_() {
840     if (current_loop_) {
841       current_loop_ = false;
842     } else {
843       done_ = true;
844     }
845   }
846 
847   const NGramFst<A>& fst_;
848   NGramFstInst<A> inst_;
849   MatchType match_type_;             // Supplied by caller
850   bool done_;
851   Arc arc_;
852   bool current_loop_;                // Current arc is the implicit loop
853   Arc loop_;
854 };
855 
856 /*****************************************************************************/
857 template<class A>
858 class ArcIterator<NGramFst<A> > : public ArcIteratorBase<A> {
859  public:
860   typedef A Arc;
861   typedef typename A::Label Label;
862   typedef typename A::StateId StateId;
863   typedef typename A::Weight Weight;
864 
ArcIterator(const NGramFst<A> & fst,StateId state)865   ArcIterator(const NGramFst<A> &fst, StateId state)
866       : lazy_(~0), impl_(fst.GetImpl()), i_(0), flags_(kArcValueFlags) {
867     inst_ = fst.inst_;
868     impl_->SetInstFuture(state, &inst_);
869     impl_->SetInstNode(&inst_);
870   }
871 
Done()872   bool Done() const {
873     return i_ >= ((inst_.node_ == 0) ? inst_.num_futures_ :
874                   inst_.num_futures_ + 1);
875   }
876 
Value()877   const Arc &Value() const {
878     bool eps = (inst_.node_ != 0 && i_ == 0);
879     StateId state = (inst_.node_ == 0) ? i_ : i_ - 1;
880     if (flags_ & lazy_ & (kArcILabelValue | kArcOLabelValue)) {
881       arc_.ilabel =
882           arc_.olabel = eps ? 0 : impl_->future_words_[inst_.offset_ + state];
883       lazy_ &= ~(kArcILabelValue | kArcOLabelValue);
884     }
885     if (flags_ & lazy_ & kArcNextStateValue) {
886       if (eps) {
887         arc_.nextstate = impl_->context_index_.Rank1(
888             impl_->context_index_.Select1(
889                 impl_->context_index_.Rank0(inst_.node_) - 1));
890       } else {
891         if (lazy_ & kArcNextStateValue) {
892           impl_->SetInstContext(&inst_);  // first time only.
893         }
894         arc_.nextstate =
895             impl_->Transition(inst_.context_,
896                               impl_->future_words_[inst_.offset_ + state]);
897       }
898       lazy_ &= ~kArcNextStateValue;
899     }
900     if (flags_ & lazy_ & kArcWeightValue) {
901       arc_.weight = eps ?  impl_->backoff_[inst_.state_] :
902           impl_->future_probs_[inst_.offset_ + state];
903       lazy_ &= ~kArcWeightValue;
904     }
905     return arc_;
906   }
907 
Next()908   void Next() {
909     ++i_;
910     lazy_ = ~0;
911   }
912 
Position()913   size_t Position() const { return i_; }
914 
Reset()915   void Reset() {
916     i_ = 0;
917     lazy_ = ~0;
918   }
919 
Seek(size_t a)920   void Seek(size_t a) {
921     if (i_ != a) {
922       i_ = a;
923       lazy_ = ~0;
924     }
925   }
926 
Flags()927   uint32 Flags() const {
928     return flags_;
929   }
930 
SetFlags(uint32 f,uint32 m)931   void SetFlags(uint32 f, uint32 m) {
932     flags_ &= ~m;
933     flags_ |= (f & kArcValueFlags);
934   }
935 
936  private:
Done_()937   virtual bool Done_() const { return Done(); }
Value_()938   virtual const Arc& Value_() const { return Value(); }
Next_()939   virtual void Next_() { Next(); }
Position_()940   virtual size_t Position_() const { return Position(); }
Reset_()941   virtual void Reset_() { Reset(); }
Seek_(size_t a)942   virtual void Seek_(size_t a) { Seek(a); }
Flags_()943   uint32 Flags_() const { return Flags(); }
SetFlags_(uint32 f,uint32 m)944   void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); }
945 
946   mutable Arc arc_;
947   mutable uint32 lazy_;
948   const NGramFstImpl<A> *impl_;
949   mutable NGramFstInst<A> inst_;
950 
951   size_t i_;
952   uint32 flags_;
953 
954   DISALLOW_COPY_AND_ASSIGN(ArcIterator);
955 };
956 
957 /*****************************************************************************/
// Specialization of StateIterator for NGramFst; see the generic version in
// fst.h for sample usage (instantiated with the NGramFst type). This version
// should inline.
961 template <class A>
962 class StateIterator<NGramFst<A> > : public StateIteratorBase<A> {
963   public:
964   typedef typename A::StateId StateId;
965 
StateIterator(const NGramFst<A> & fst)966   explicit StateIterator(const NGramFst<A> &fst)
967     : s_(0), num_states_(fst.NumStates()) { }
968 
Done()969   bool Done() const { return s_ >= num_states_; }
Value()970   StateId Value() const { return s_; }
Next()971   void Next() { ++s_; }
Reset()972   void Reset() { s_ = 0; }
973 
974  private:
Done_()975   virtual bool Done_() const { return Done(); }
Value_()976   virtual StateId Value_() const { return Value(); }
Next_()977   virtual void Next_() { Next(); }
Reset_()978   virtual void Reset_() { Reset(); }
979 
980   StateId s_, num_states_;
981 
982   DISALLOW_COPY_AND_ASSIGN(StateIterator);
983 };
984 }  // namespace fst
985 #endif  // FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
986