1
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 //
14 // Copyright 2005-2010 Google, Inc.
15 // Author: sorenj@google.com (Jeffrey Sorensen)
16 //
17 #ifndef FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
18 #define FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
19
20 #include <stddef.h>
21 #include <string.h>
22 #include <algorithm>
23 #include <string>
24 #include <vector>
25 using std::vector;
26
27 #include <fst/compat.h>
28 #include <fst/fstlib.h>
29 #include <fst/mapped-file.h>
30 #include <fst/extensions/ngram/bitmap-index.h>
31
// NGramFst implements an n-gram language model based upon the LOUDS data
// structure. Please refer to "Unary Data Structures for Language Models"
// http://research.google.com/pubs/archive/37218.pdf
35
36 namespace fst {
37 template <class A> class NGramFst;
38 template <class A> class NGramFstMatcher;
39
40 // Instance data containing mutable state for bookkeeping repeated access to
41 // the same state.
42 template <class A>
43 struct NGramFstInst {
44 typedef typename A::Label Label;
45 typedef typename A::StateId StateId;
46 typedef typename A::Weight Weight;
47 StateId state_;
48 size_t num_futures_;
49 size_t offset_;
50 size_t node_;
51 StateId node_state_;
52 vector<Label> context_;
53 StateId context_state_;
NGramFstInstNGramFstInst54 NGramFstInst()
55 : state_(kNoStateId), node_state_(kNoStateId),
56 context_state_(kNoStateId) { }
57 };
58
59 // Implementation class for LOUDS based NgramFst interface
60 template <class A>
61 class NGramFstImpl : public FstImpl<A> {
62 using FstImpl<A>::SetInputSymbols;
63 using FstImpl<A>::SetOutputSymbols;
64 using FstImpl<A>::SetType;
65 using FstImpl<A>::WriteHeader;
66
67 friend class ArcIterator<NGramFst<A> >;
68 friend class NGramFstMatcher<A>;
69
70 public:
71 using FstImpl<A>::InputSymbols;
72 using FstImpl<A>::SetProperties;
73 using FstImpl<A>::Properties;
74
75 typedef A Arc;
76 typedef typename A::Label Label;
77 typedef typename A::StateId StateId;
78 typedef typename A::Weight Weight;
79
NGramFstImpl()80 NGramFstImpl() : data_region_(0), data_(0), owned_(false) {
81 SetType("ngram");
82 SetInputSymbols(NULL);
83 SetOutputSymbols(NULL);
84 SetProperties(kStaticProperties);
85 }
86
87 NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out);
88
~NGramFstImpl()89 ~NGramFstImpl() {
90 if (owned_) {
91 delete [] data_;
92 }
93 delete data_region_;
94 }
95
Read(istream & strm,const FstReadOptions & opts)96 static NGramFstImpl<A>* Read(istream &strm, // NOLINT
97 const FstReadOptions &opts) {
98 NGramFstImpl<A>* impl = new NGramFstImpl();
99 FstHeader hdr;
100 if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return 0;
101 uint64 num_states, num_futures, num_final;
102 const size_t offset = sizeof(num_states) + sizeof(num_futures) +
103 sizeof(num_final);
104 // Peek at num_states and num_futures to see how much more needs to be read.
105 strm.read(reinterpret_cast<char *>(&num_states), sizeof(num_states));
106 strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures));
107 strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final));
108 size_t size = Storage(num_states, num_futures, num_final);
109 MappedFile *data_region = MappedFile::Allocate(size);
110 char *data = reinterpret_cast<char *>(data_region->mutable_data());
111 // Copy num_states, num_futures and num_final back into data.
112 memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states));
113 memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures),
114 sizeof(num_futures));
115 memcpy(data + sizeof(num_states) + sizeof(num_futures),
116 reinterpret_cast<char *>(&num_final), sizeof(num_final));
117 strm.read(data + offset, size - offset);
118 if (!strm) {
119 delete impl;
120 return NULL;
121 }
122 impl->Init(data, false, data_region);
123 return impl;
124 }
125
Write(ostream & strm,const FstWriteOptions & opts)126 bool Write(ostream &strm, // NOLINT
127 const FstWriteOptions &opts) const {
128 FstHeader hdr;
129 hdr.SetStart(Start());
130 hdr.SetNumStates(num_states_);
131 WriteHeader(strm, opts, kFileVersion, &hdr);
132 strm.write(data_, StorageSize());
133 return strm;
134 }
135
Start()136 StateId Start() const {
137 return 1;
138 }
139
Final(StateId state)140 Weight Final(StateId state) const {
141 if (final_index_.Get(state)) {
142 return final_probs_[final_index_.Rank1(state)];
143 } else {
144 return Weight::Zero();
145 }
146 }
147
148 size_t NumArcs(StateId state, NGramFstInst<A> *inst = NULL) const {
149 if (inst == NULL) {
150 const size_t next_zero = future_index_.Select0(state + 1);
151 const size_t this_zero = future_index_.Select0(state);
152 return next_zero - this_zero - 1;
153 }
154 SetInstFuture(state, inst);
155 return inst->num_futures_ + ((state == 0) ? 0 : 1);
156 }
157
NumInputEpsilons(StateId state)158 size_t NumInputEpsilons(StateId state) const {
159 // State 0 has no parent, thus no backoff.
160 if (state == 0) return 0;
161 return 1;
162 }
163
NumOutputEpsilons(StateId state)164 size_t NumOutputEpsilons(StateId state) const {
165 return NumInputEpsilons(state);
166 }
167
NumStates()168 StateId NumStates() const {
169 return num_states_;
170 }
171
InitStateIterator(StateIteratorData<A> * data)172 void InitStateIterator(StateIteratorData<A>* data) const {
173 data->base = 0;
174 data->nstates = num_states_;
175 }
176
Storage(uint64 num_states,uint64 num_futures,uint64 num_final)177 static size_t Storage(uint64 num_states, uint64 num_futures,
178 uint64 num_final) {
179 uint64 b64;
180 Weight weight;
181 Label label;
182 size_t offset = sizeof(num_states) + sizeof(num_futures) +
183 sizeof(num_final);
184 offset += sizeof(b64) * (
185 BitmapIndex::StorageSize(num_states * 2 + 1) +
186 BitmapIndex::StorageSize(num_futures + num_states + 1) +
187 BitmapIndex::StorageSize(num_states));
188 offset += (num_states + 1) * sizeof(label) + num_futures * sizeof(label);
189 // Pad for alignemnt, see
190 // http://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding
191 offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1);
192 offset += (num_states + 1) * sizeof(weight) + num_final * sizeof(weight) +
193 (num_futures + 1) * sizeof(weight);
194 return offset;
195 }
196
SetInstFuture(StateId state,NGramFstInst<A> * inst)197 void SetInstFuture(StateId state, NGramFstInst<A> *inst) const {
198 if (inst->state_ != state) {
199 inst->state_ = state;
200 const size_t next_zero = future_index_.Select0(state + 1);
201 const size_t this_zero = future_index_.Select0(state);
202 inst->num_futures_ = next_zero - this_zero - 1;
203 inst->offset_ = future_index_.Rank1(future_index_.Select0(state) + 1);
204 }
205 }
206
SetInstNode(NGramFstInst<A> * inst)207 void SetInstNode(NGramFstInst<A> *inst) const {
208 if (inst->node_state_ != inst->state_) {
209 inst->node_state_ = inst->state_;
210 inst->node_ = context_index_.Select1(inst->state_);
211 }
212 }
213
SetInstContext(NGramFstInst<A> * inst)214 void SetInstContext(NGramFstInst<A> *inst) const {
215 SetInstNode(inst);
216 if (inst->context_state_ != inst->state_) {
217 inst->context_state_ = inst->state_;
218 inst->context_.clear();
219 size_t node = inst->node_;
220 while (node != 0) {
221 inst->context_.push_back(context_words_[context_index_.Rank1(node)]);
222 node = context_index_.Select1(context_index_.Rank0(node) - 1);
223 }
224 }
225 }
226
227 // Access to the underlying representation
GetData(size_t * data_size)228 const char* GetData(size_t* data_size) const {
229 *data_size = StorageSize();
230 return data_;
231 }
232
233 void Init(const char* data, bool owned, MappedFile *file = 0);
234
GetContext(StateId s,NGramFstInst<A> * inst)235 const vector<Label> &GetContext(StateId s, NGramFstInst<A> *inst) const {
236 SetInstFuture(s, inst);
237 SetInstContext(inst);
238 return inst->context_;
239 }
240
StorageSize()241 size_t StorageSize() const {
242 return Storage(num_states_, num_futures_, num_final_);
243 }
244
245 void GetStates(const vector<Label>& context, vector<StateId> *states) const;
246
247 private:
248 StateId Transition(const vector<Label> &context, Label future) const;
249
250 // Properties always true for this Fst class.
251 static const uint64 kStaticProperties = kAcceptor | kIDeterministic |
252 kODeterministic | kEpsilons | kIEpsilons | kOEpsilons | kILabelSorted |
253 kOLabelSorted | kWeighted | kCyclic | kInitialAcyclic | kNotTopSorted |
254 kAccessible | kCoAccessible | kNotString | kExpanded;
255 // Current file format version.
256 static const int kFileVersion = 4;
257 // Minimum file format version supported.
258 static const int kMinFileVersion = 4;
259
260 MappedFile *data_region_;
261 const char* data_;
262 bool owned_; // True if we own data_
263 uint64 num_states_, num_futures_, num_final_;
264 size_t root_num_children_;
265 const Label *root_children_;
266 size_t root_first_child_;
267 // borrowed references
268 const uint64 *context_, *future_, *final_;
269 const Label *context_words_, *future_words_;
270 const Weight *backoff_, *final_probs_, *future_probs_;
271 BitmapIndex context_index_;
272 BitmapIndex future_index_;
273 BitmapIndex final_index_;
274
275 void operator=(const NGramFstImpl<A> &); // Disallow
276 };
277
278 template<typename A>
NGramFstImpl(const Fst<A> & fst,vector<StateId> * order_out)279 NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out)
280 : data_region_(0), data_(0), owned_(false) {
281 typedef A Arc;
282 typedef typename Arc::Label Label;
283 typedef typename Arc::Weight Weight;
284 typedef typename Arc::StateId StateId;
285 SetType("ngram");
286 SetInputSymbols(fst.InputSymbols());
287 SetOutputSymbols(fst.OutputSymbols());
288 SetProperties(kStaticProperties);
289
290 // Check basic requirements for an OpenGRM language model Fst.
291 int64 props = kAcceptor | kIDeterministic | kIEpsilons | kILabelSorted;
292 if (fst.Properties(props, true) != props) {
293 FSTERROR() << "NGramFst only accepts OpenGRM langauge models as input";
294 SetProperties(kError, kError);
295 return;
296 }
297
298 int64 num_states = CountStates(fst);
299 Label* context = new Label[num_states];
300
301 // Find the unigram state by starting from the start state, following
302 // epsilons.
303 StateId unigram = fst.Start();
304 while (1) {
305 if (unigram == kNoStateId) {
306 FSTERROR() << "Could not identify unigram state.";
307 SetProperties(kError, kError);
308 return;
309 }
310 ArcIterator<Fst<A> > aiter(fst, unigram);
311 if (aiter.Done()) {
312 LOG(WARNING) << "Unigram state " << unigram << " has no arcs.";
313 break;
314 }
315 if (aiter.Value().ilabel != 0) break;
316 unigram = aiter.Value().nextstate;
317 }
318
319 // Each state's context is determined by the subtree it is under from the
320 // unigram state.
321 queue<pair<StateId, Label> > label_queue;
322 vector<bool> visited(num_states);
323 // Force an epsilon link to the start state.
324 label_queue.push(make_pair(fst.Start(), 0));
325 for (ArcIterator<Fst<A> > aiter(fst, unigram);
326 !aiter.Done(); aiter.Next()) {
327 label_queue.push(make_pair(aiter.Value().nextstate, aiter.Value().ilabel));
328 }
329 // investigate states in breadth first fashion to assign context words.
330 while (!label_queue.empty()) {
331 pair<StateId, Label> &now = label_queue.front();
332 if (!visited[now.first]) {
333 context[now.first] = now.second;
334 visited[now.first] = true;
335 for (ArcIterator<Fst<A> > aiter(fst, now.first);
336 !aiter.Done(); aiter.Next()) {
337 const Arc &arc = aiter.Value();
338 if (arc.ilabel != 0) {
339 label_queue.push(make_pair(arc.nextstate, now.second));
340 }
341 }
342 }
343 label_queue.pop();
344 }
345 visited.clear();
346
347 // The arc from the start state should be assigned an epsilon to put it
348 // in front of the all other labels (which makes Start state 1 after
349 // unigram which is state 0).
350 context[fst.Start()] = 0;
351
352 // Build the tree of contexts fst by reversing the epsilon arcs from fst.
353 VectorFst<Arc> context_fst;
354 uint64 num_final = 0;
355 for (int i = 0; i < num_states; ++i) {
356 if (fst.Final(i) != Weight::Zero()) {
357 ++num_final;
358 }
359 context_fst.SetFinal(context_fst.AddState(), fst.Final(i));
360 }
361 context_fst.SetStart(unigram);
362 context_fst.SetInputSymbols(fst.InputSymbols());
363 context_fst.SetOutputSymbols(fst.OutputSymbols());
364 int64 num_context_arcs = 0;
365 int64 num_futures = 0;
366 for (StateIterator<Fst<A> > siter(fst); !siter.Done(); siter.Next()) {
367 const StateId &state = siter.Value();
368 num_futures += fst.NumArcs(state) - fst.NumInputEpsilons(state);
369 ArcIterator<Fst<A> > aiter(fst, state);
370 if (!aiter.Done()) {
371 const Arc &arc = aiter.Value();
372 // this arc goes from state to arc.nextstate, so create an arc from
373 // arc.nextstate to state to reverse it.
374 if (arc.ilabel == 0) {
375 context_fst.AddArc(arc.nextstate, Arc(context[state], context[state],
376 arc.weight, state));
377 num_context_arcs++;
378 }
379 }
380 }
381 if (num_context_arcs != context_fst.NumStates() - 1) {
382 FSTERROR() << "Number of contexts arcs != number of states - 1";
383 SetProperties(kError, kError);
384 return;
385 }
386 if (context_fst.NumStates() != num_states) {
387 FSTERROR() << "Number of contexts != number of states";
388 SetProperties(kError, kError);
389 return;
390 }
391 int64 context_props = context_fst.Properties(kIDeterministic |
392 kILabelSorted, true);
393 if (!(context_props & kIDeterministic)) {
394 FSTERROR() << "Input fst is not structured properly";
395 SetProperties(kError, kError);
396 return;
397 }
398 if (!(context_props & kILabelSorted)) {
399 ArcSort(&context_fst, ILabelCompare<Arc>());
400 }
401
402 delete [] context;
403
404 uint64 b64;
405 Weight weight;
406 Label label = kNoLabel;
407 const size_t storage = Storage(num_states, num_futures, num_final);
408 MappedFile *data_region = MappedFile::Allocate(storage);
409 char *data = reinterpret_cast<char *>(data_region->mutable_data());
410 memset(data, 0, storage);
411 size_t offset = 0;
412 memcpy(data + offset, reinterpret_cast<char *>(&num_states),
413 sizeof(num_states));
414 offset += sizeof(num_states);
415 memcpy(data + offset, reinterpret_cast<char *>(&num_futures),
416 sizeof(num_futures));
417 offset += sizeof(num_futures);
418 memcpy(data + offset, reinterpret_cast<char *>(&num_final),
419 sizeof(num_final));
420 offset += sizeof(num_final);
421 uint64* context_bits = reinterpret_cast<uint64*>(data + offset);
422 offset += BitmapIndex::StorageSize(num_states * 2 + 1) * sizeof(b64);
423 uint64* future_bits = reinterpret_cast<uint64*>(data + offset);
424 offset +=
425 BitmapIndex::StorageSize(num_futures + num_states + 1) * sizeof(b64);
426 uint64* final_bits = reinterpret_cast<uint64*>(data + offset);
427 offset += BitmapIndex::StorageSize(num_states) * sizeof(b64);
428 Label* context_words = reinterpret_cast<Label*>(data + offset);
429 offset += (num_states + 1) * sizeof(label);
430 Label* future_words = reinterpret_cast<Label*>(data + offset);
431 offset += num_futures * sizeof(label);
432 offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1);
433 Weight* backoff = reinterpret_cast<Weight*>(data + offset);
434 offset += (num_states + 1) * sizeof(weight);
435 Weight* final_probs = reinterpret_cast<Weight*>(data + offset);
436 offset += num_final * sizeof(weight);
437 Weight* future_probs = reinterpret_cast<Weight*>(data + offset);
438 int64 context_arc = 0, future_arc = 0, context_bit = 0, future_bit = 0,
439 final_bit = 0;
440
441 // pseudo-root bits
442 BitmapIndex::Set(context_bits, context_bit++);
443 ++context_bit;
444 context_words[context_arc] = label;
445 backoff[context_arc] = Weight::Zero();
446 context_arc++;
447
448 ++future_bit;
449 if (order_out) {
450 order_out->clear();
451 order_out->resize(num_states);
452 }
453
454 queue<StateId> context_q;
455 context_q.push(context_fst.Start());
456 StateId state_number = 0;
457 while (!context_q.empty()) {
458 const StateId &state = context_q.front();
459 if (order_out) {
460 (*order_out)[state] = state_number;
461 }
462
463 const Weight &final = context_fst.Final(state);
464 if (final != Weight::Zero()) {
465 BitmapIndex::Set(final_bits, state_number);
466 final_probs[final_bit] = final;
467 ++final_bit;
468 }
469
470 for (ArcIterator<VectorFst<A> > aiter(context_fst, state);
471 !aiter.Done(); aiter.Next()) {
472 const Arc &arc = aiter.Value();
473 context_words[context_arc] = arc.ilabel;
474 backoff[context_arc] = arc.weight;
475 ++context_arc;
476 BitmapIndex::Set(context_bits, context_bit++);
477 context_q.push(arc.nextstate);
478 }
479 ++context_bit;
480
481 for (ArcIterator<Fst<A> > aiter(fst, state); !aiter.Done(); aiter.Next()) {
482 const Arc &arc = aiter.Value();
483 if (arc.ilabel != 0) {
484 future_words[future_arc] = arc.ilabel;
485 future_probs[future_arc] = arc.weight;
486 ++future_arc;
487 BitmapIndex::Set(future_bits, future_bit++);
488 }
489 }
490 ++future_bit;
491 ++state_number;
492 context_q.pop();
493 }
494
495 if ((state_number != num_states) ||
496 (context_bit != num_states * 2 + 1) ||
497 (context_arc != num_states) ||
498 (future_arc != num_futures) ||
499 (future_bit != num_futures + num_states + 1) ||
500 (final_bit != num_final)) {
501 FSTERROR() << "Structure problems detected during construction";
502 SetProperties(kError, kError);
503 return;
504 }
505
506 Init(data, false, data_region);
507 }
508
509 template<typename A>
Init(const char * data,bool owned,MappedFile * data_region)510 inline void NGramFstImpl<A>::Init(const char* data, bool owned,
511 MappedFile *data_region) {
512 if (owned_) {
513 delete [] data_;
514 }
515 delete data_region_;
516 data_region_ = data_region;
517 owned_ = owned;
518 data_ = data;
519 size_t offset = 0;
520 num_states_ = *(reinterpret_cast<const uint64*>(data_ + offset));
521 offset += sizeof(num_states_);
522 num_futures_ = *(reinterpret_cast<const uint64*>(data_ + offset));
523 offset += sizeof(num_futures_);
524 num_final_ = *(reinterpret_cast<const uint64*>(data_ + offset));
525 offset += sizeof(num_final_);
526 uint64 bits;
527 size_t context_bits = num_states_ * 2 + 1;
528 size_t future_bits = num_futures_ + num_states_ + 1;
529 context_ = reinterpret_cast<const uint64*>(data_ + offset);
530 offset += BitmapIndex::StorageSize(context_bits) * sizeof(bits);
531 future_ = reinterpret_cast<const uint64*>(data_ + offset);
532 offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits);
533 final_ = reinterpret_cast<const uint64*>(data_ + offset);
534 offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits);
535 context_words_ = reinterpret_cast<const Label*>(data_ + offset);
536 offset += (num_states_ + 1) * sizeof(*context_words_);
537 future_words_ = reinterpret_cast<const Label*>(data_ + offset);
538 offset += num_futures_ * sizeof(*future_words_);
539 offset = (offset + sizeof(*backoff_) - 1) & ~(sizeof(*backoff_) - 1);
540 backoff_ = reinterpret_cast<const Weight*>(data_ + offset);
541 offset += (num_states_ + 1) * sizeof(*backoff_);
542 final_probs_ = reinterpret_cast<const Weight*>(data_ + offset);
543 offset += num_final_ * sizeof(*final_probs_);
544 future_probs_ = reinterpret_cast<const Weight*>(data_ + offset);
545
546 context_index_.BuildIndex(context_, context_bits);
547 future_index_.BuildIndex(future_, future_bits);
548 final_index_.BuildIndex(final_, num_states_);
549
550 const size_t node_rank = context_index_.Rank1(0);
551 root_first_child_ = context_index_.Select0(node_rank) + 1;
552 if (context_index_.Get(root_first_child_) == false) {
553 FSTERROR() << "Missing unigrams";
554 SetProperties(kError, kError);
555 return;
556 }
557 const size_t last_child = context_index_.Select0(node_rank + 1) - 1;
558 root_num_children_ = last_child - root_first_child_ + 1;
559 root_children_ = context_words_ + context_index_.Rank1(root_first_child_);
560 }
561
562 template<typename A>
Transition(const vector<Label> & context,Label future)563 inline typename A::StateId NGramFstImpl<A>::Transition(
564 const vector<Label> &context, Label future) const {
565 const Label *children = root_children_;
566 const Label *loc = lower_bound(children, children + root_num_children_,
567 future);
568 if (loc == children + root_num_children_ || *loc != future) {
569 return context_index_.Rank1(0);
570 }
571 size_t node = root_first_child_ + loc - children;
572 size_t node_rank = context_index_.Rank1(node);
573 size_t first_child = context_index_.Select0(node_rank) + 1;
574 if (context_index_.Get(first_child) == false) {
575 return context_index_.Rank1(node);
576 }
577 size_t last_child = context_index_.Select0(node_rank + 1) - 1;
578 for (int word = context.size() - 1; word >= 0; --word) {
579 children = context_words_ + context_index_.Rank1(first_child);
580 loc = lower_bound(children, children + last_child - first_child + 1,
581 context[word]);
582 if (loc == children + last_child - first_child + 1 ||
583 *loc != context[word]) {
584 break;
585 }
586 node = first_child + loc - children;
587 node_rank = context_index_.Rank1(node);
588 first_child = context_index_.Select0(node_rank) + 1;
589 if (context_index_.Get(first_child) == false) break;
590 last_child = context_index_.Select0(node_rank + 1) - 1;
591 }
592 return context_index_.Rank1(node);
593 }
594
595 template<typename A>
GetStates(const vector<Label> & context,vector<typename A::StateId> * states)596 inline void NGramFstImpl<A>::GetStates(
597 const vector<Label> &context,
598 vector<typename A::StateId>* states) const {
599 states->clear();
600 states->push_back(0);
601 typename vector<Label>::const_reverse_iterator cit = context.rbegin();
602 const Label *children = root_children_;
603 const Label *loc = lower_bound(children, children + root_num_children_, *cit);
604 if (loc == children + root_num_children_ || *loc != *cit) return;
605 size_t node = root_first_child_ + loc - children;
606 states->push_back(context_index_.Rank1(node));
607 if (context.size() == 1) return;
608 size_t node_rank = context_index_.Rank1(node);
609 size_t first_child = context_index_.Select0(node_rank) + 1;
610 ++cit;
611 if (context_index_.Get(first_child) != false) {
612 size_t last_child = context_index_.Select0(node_rank + 1) - 1;
613 while (cit != context.rend()) {
614 children = context_words_ + context_index_.Rank1(first_child);
615 loc = lower_bound(children, children + last_child - first_child + 1,
616 *cit);
617 if (loc == children + last_child - first_child + 1 || *loc != *cit) {
618 break;
619 }
620 ++cit;
621 node = first_child + loc - children;
622 states->push_back(context_index_.Rank1(node));
623 node_rank = context_index_.Rank1(node);
624 first_child = context_index_.Select0(node_rank) + 1;
625 if (context_index_.Get(first_child) == false) break;
626 last_child = context_index_.Select0(node_rank + 1) - 1;
627 }
628 }
629 }
630
631 /*****************************************************************************/
632 template<class A>
633 class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > {
634 friend class ArcIterator<NGramFst<A> >;
635 friend class NGramFstMatcher<A>;
636
637 public:
638 typedef A Arc;
639 typedef typename A::StateId StateId;
640 typedef typename A::Label Label;
641 typedef typename A::Weight Weight;
642 typedef NGramFstImpl<A> Impl;
643
NGramFst(const Fst<A> & dst)644 explicit NGramFst(const Fst<A> &dst)
645 : ImplToExpandedFst<Impl>(new Impl(dst, NULL)) {}
646
NGramFst(const Fst<A> & fst,vector<StateId> * order_out)647 NGramFst(const Fst<A> &fst, vector<StateId>* order_out)
648 : ImplToExpandedFst<Impl>(new Impl(fst, order_out)) {}
649
650 // Because the NGramFstImpl is a const stateless data structure, there
651 // is never a need to do anything beside copy the reference.
652 NGramFst(const NGramFst<A> &fst, bool safe = false)
653 : ImplToExpandedFst<Impl>(fst, false) {}
654
NGramFst()655 NGramFst() : ImplToExpandedFst<Impl>(new Impl()) {}
656
657 // Non-standard constructor to initialize NGramFst directly from data.
NGramFst(const char * data,bool owned)658 NGramFst(const char* data, bool owned) : ImplToExpandedFst<Impl>(new Impl()) {
659 GetImpl()->Init(data, owned, NULL);
660 }
661
662 // Get method that gets the data associated with Init().
GetData(size_t * data_size)663 const char* GetData(size_t* data_size) const {
664 return GetImpl()->GetData(data_size);
665 }
666
GetContext(StateId s)667 const vector<Label> GetContext(StateId s) const {
668 return GetImpl()->GetContext(s, &inst_);
669 }
670
671 // Consumes as much as possible of context from right to left, returns the
672 // the states corresponding to the increasingly conditioned input sequence.
GetStates(const vector<Label> & context,vector<StateId> * state)673 void GetStates(const vector<Label>& context, vector<StateId> *state) const {
674 return GetImpl()->GetStates(context, state);
675 }
676
NumArcs(StateId s)677 virtual size_t NumArcs(StateId s) const {
678 return GetImpl()->NumArcs(s, &inst_);
679 }
680
681 virtual NGramFst<A>* Copy(bool safe = false) const {
682 return new NGramFst(*this, safe);
683 }
684
Read(istream & strm,const FstReadOptions & opts)685 static NGramFst<A>* Read(istream &strm, const FstReadOptions &opts) {
686 Impl* impl = Impl::Read(strm, opts);
687 return impl ? new NGramFst<A>(impl) : 0;
688 }
689
Read(const string & filename)690 static NGramFst<A>* Read(const string &filename) {
691 if (!filename.empty()) {
692 ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
693 if (!strm) {
694 LOG(ERROR) << "NGramFst::Read: Can't open file: " << filename;
695 return 0;
696 }
697 return Read(strm, FstReadOptions(filename));
698 } else {
699 return Read(cin, FstReadOptions("standard input"));
700 }
701 }
702
Write(ostream & strm,const FstWriteOptions & opts)703 virtual bool Write(ostream &strm, const FstWriteOptions &opts) const {
704 return GetImpl()->Write(strm, opts);
705 }
706
Write(const string & filename)707 virtual bool Write(const string &filename) const {
708 return Fst<A>::WriteFile(filename);
709 }
710
InitStateIterator(StateIteratorData<A> * data)711 virtual inline void InitStateIterator(StateIteratorData<A>* data) const {
712 GetImpl()->InitStateIterator(data);
713 }
714
715 virtual inline void InitArcIterator(
716 StateId s, ArcIteratorData<A>* data) const;
717
InitMatcher(MatchType match_type)718 virtual MatcherBase<A>* InitMatcher(MatchType match_type) const {
719 return new NGramFstMatcher<A>(*this, match_type);
720 }
721
StorageSize()722 size_t StorageSize() const {
723 return GetImpl()->StorageSize();
724 }
725
726 private:
NGramFst(Impl * impl)727 explicit NGramFst(Impl* impl) : ImplToExpandedFst<Impl>(impl) {}
728
GetImpl()729 Impl* GetImpl() const {
730 return
731 ImplToExpandedFst<Impl, ExpandedFst<A> >::GetImpl();
732 }
733
734 void SetImpl(Impl* impl, bool own_impl = true) {
735 ImplToExpandedFst<Impl, Fst<A> >::SetImpl(impl, own_impl);
736 }
737
738 mutable NGramFstInst<A> inst_;
739 };
740
741 template <class A> inline void
InitArcIterator(StateId s,ArcIteratorData<A> * data)742 NGramFst<A>::InitArcIterator(StateId s, ArcIteratorData<A>* data) const {
743 GetImpl()->SetInstFuture(s, &inst_);
744 GetImpl()->SetInstNode(&inst_);
745 data->base = new ArcIterator<NGramFst<A> >(*this, s);
746 }
747
748 /*****************************************************************************/
749 template <class A>
750 class NGramFstMatcher : public MatcherBase<A> {
751 public:
752 typedef A Arc;
753 typedef typename A::Label Label;
754 typedef typename A::StateId StateId;
755 typedef typename A::Weight Weight;
756
NGramFstMatcher(const NGramFst<A> & fst,MatchType match_type)757 NGramFstMatcher(const NGramFst<A> &fst, MatchType match_type)
758 : fst_(fst), inst_(fst.inst_), match_type_(match_type),
759 current_loop_(false),
760 loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
761 if (match_type_ == MATCH_OUTPUT) {
762 swap(loop_.ilabel, loop_.olabel);
763 }
764 }
765
766 NGramFstMatcher(const NGramFstMatcher<A> &matcher, bool safe = false)
767 : fst_(matcher.fst_), inst_(matcher.inst_),
768 match_type_(matcher.match_type_), current_loop_(false),
769 loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
770 if (match_type_ == MATCH_OUTPUT) {
771 swap(loop_.ilabel, loop_.olabel);
772 }
773 }
774
775 virtual NGramFstMatcher<A>* Copy(bool safe = false) const {
776 return new NGramFstMatcher<A>(*this, safe);
777 }
778
Type(bool test)779 virtual MatchType Type(bool test) const {
780 return match_type_;
781 }
782
GetFst()783 virtual const Fst<A> &GetFst() const {
784 return fst_;
785 }
786
Properties(uint64 props)787 virtual uint64 Properties(uint64 props) const {
788 return props;
789 }
790
791 private:
SetState_(StateId s)792 virtual void SetState_(StateId s) {
793 fst_.GetImpl()->SetInstFuture(s, &inst_);
794 current_loop_ = false;
795 }
796
Find_(Label label)797 virtual bool Find_(Label label) {
798 const Label nolabel = kNoLabel;
799 done_ = true;
800 if (label == 0 || label == nolabel) {
801 if (label == 0) {
802 current_loop_ = true;
803 loop_.nextstate = inst_.state_;
804 }
805 // The unigram state has no epsilon arc.
806 if (inst_.state_ != 0) {
807 arc_.ilabel = arc_.olabel = 0;
808 fst_.GetImpl()->SetInstNode(&inst_);
809 arc_.nextstate = fst_.GetImpl()->context_index_.Rank1(
810 fst_.GetImpl()->context_index_.Select1(
811 fst_.GetImpl()->context_index_.Rank0(inst_.node_) - 1));
812 arc_.weight = fst_.GetImpl()->backoff_[inst_.state_];
813 done_ = false;
814 }
815 } else {
816 const Label *start = fst_.GetImpl()->future_words_ + inst_.offset_;
817 const Label *end = start + inst_.num_futures_;
818 const Label* search = lower_bound(start, end, label);
819 if (search != end && *search == label) {
820 size_t state = search - start;
821 arc_.ilabel = arc_.olabel = label;
822 arc_.weight = fst_.GetImpl()->future_probs_[inst_.offset_ + state];
823 fst_.GetImpl()->SetInstContext(&inst_);
824 arc_.nextstate = fst_.GetImpl()->Transition(inst_.context_, label);
825 done_ = false;
826 }
827 }
828 return !Done_();
829 }
830
Done_()831 virtual bool Done_() const {
832 return !current_loop_ && done_;
833 }
834
Value_()835 virtual const Arc& Value_() const {
836 return (current_loop_) ? loop_ : arc_;
837 }
838
Next_()839 virtual void Next_() {
840 if (current_loop_) {
841 current_loop_ = false;
842 } else {
843 done_ = true;
844 }
845 }
846
847 const NGramFst<A>& fst_;
848 NGramFstInst<A> inst_;
849 MatchType match_type_; // Supplied by caller
850 bool done_;
851 Arc arc_;
852 bool current_loop_; // Current arc is the implicit loop
853 Arc loop_;
854 };
855
856 /*****************************************************************************/
857 template<class A>
858 class ArcIterator<NGramFst<A> > : public ArcIteratorBase<A> {
859 public:
860 typedef A Arc;
861 typedef typename A::Label Label;
862 typedef typename A::StateId StateId;
863 typedef typename A::Weight Weight;
864
ArcIterator(const NGramFst<A> & fst,StateId state)865 ArcIterator(const NGramFst<A> &fst, StateId state)
866 : lazy_(~0), impl_(fst.GetImpl()), i_(0), flags_(kArcValueFlags) {
867 inst_ = fst.inst_;
868 impl_->SetInstFuture(state, &inst_);
869 impl_->SetInstNode(&inst_);
870 }
871
Done()872 bool Done() const {
873 return i_ >= ((inst_.node_ == 0) ? inst_.num_futures_ :
874 inst_.num_futures_ + 1);
875 }
876
Value()877 const Arc &Value() const {
878 bool eps = (inst_.node_ != 0 && i_ == 0);
879 StateId state = (inst_.node_ == 0) ? i_ : i_ - 1;
880 if (flags_ & lazy_ & (kArcILabelValue | kArcOLabelValue)) {
881 arc_.ilabel =
882 arc_.olabel = eps ? 0 : impl_->future_words_[inst_.offset_ + state];
883 lazy_ &= ~(kArcILabelValue | kArcOLabelValue);
884 }
885 if (flags_ & lazy_ & kArcNextStateValue) {
886 if (eps) {
887 arc_.nextstate = impl_->context_index_.Rank1(
888 impl_->context_index_.Select1(
889 impl_->context_index_.Rank0(inst_.node_) - 1));
890 } else {
891 if (lazy_ & kArcNextStateValue) {
892 impl_->SetInstContext(&inst_); // first time only.
893 }
894 arc_.nextstate =
895 impl_->Transition(inst_.context_,
896 impl_->future_words_[inst_.offset_ + state]);
897 }
898 lazy_ &= ~kArcNextStateValue;
899 }
900 if (flags_ & lazy_ & kArcWeightValue) {
901 arc_.weight = eps ? impl_->backoff_[inst_.state_] :
902 impl_->future_probs_[inst_.offset_ + state];
903 lazy_ &= ~kArcWeightValue;
904 }
905 return arc_;
906 }
907
Next()908 void Next() {
909 ++i_;
910 lazy_ = ~0;
911 }
912
Position()913 size_t Position() const { return i_; }
914
Reset()915 void Reset() {
916 i_ = 0;
917 lazy_ = ~0;
918 }
919
Seek(size_t a)920 void Seek(size_t a) {
921 if (i_ != a) {
922 i_ = a;
923 lazy_ = ~0;
924 }
925 }
926
Flags()927 uint32 Flags() const {
928 return flags_;
929 }
930
SetFlags(uint32 f,uint32 m)931 void SetFlags(uint32 f, uint32 m) {
932 flags_ &= ~m;
933 flags_ |= (f & kArcValueFlags);
934 }
935
936 private:
Done_()937 virtual bool Done_() const { return Done(); }
Value_()938 virtual const Arc& Value_() const { return Value(); }
Next_()939 virtual void Next_() { Next(); }
Position_()940 virtual size_t Position_() const { return Position(); }
Reset_()941 virtual void Reset_() { Reset(); }
Seek_(size_t a)942 virtual void Seek_(size_t a) { Seek(a); }
Flags_()943 uint32 Flags_() const { return Flags(); }
SetFlags_(uint32 f,uint32 m)944 void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); }
945
946 mutable Arc arc_;
947 mutable uint32 lazy_;
948 const NGramFstImpl<A> *impl_;
949 mutable NGramFstInst<A> inst_;
950
951 size_t i_;
952 uint32 flags_;
953
954 DISALLOW_COPY_AND_ASSIGN(ArcIterator);
955 };
956
957 /*****************************************************************************/
// Specialization for NGramFst; see generic version in fst.h
// for sample usage (but use the NGramFst type!). This version
// should inline.
961 template <class A>
962 class StateIterator<NGramFst<A> > : public StateIteratorBase<A> {
963 public:
964 typedef typename A::StateId StateId;
965
StateIterator(const NGramFst<A> & fst)966 explicit StateIterator(const NGramFst<A> &fst)
967 : s_(0), num_states_(fst.NumStates()) { }
968
Done()969 bool Done() const { return s_ >= num_states_; }
Value()970 StateId Value() const { return s_; }
Next()971 void Next() { ++s_; }
Reset()972 void Reset() { s_ = 0; }
973
974 private:
Done_()975 virtual bool Done_() const { return Done(); }
Value_()976 virtual StateId Value_() const { return Value(); }
Next_()977 virtual void Next_() { Next(); }
Reset_()978 virtual void Reset_() { Reset(); }
979
980 StateId s_, num_states_;
981
982 DISALLOW_COPY_AND_ASSIGN(StateIterator);
983 };
984 } // namespace fst
985 #endif // FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
986