• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // accumulator.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Classes to accumulate arc weights. Useful for weight lookahead.
20 
21 #ifndef FST_LIB_ACCUMULATOR_H__
22 #define FST_LIB_ACCUMULATOR_H__
23 
24 #include <algorithm>
25 #include <functional>
26 #include <tr1/unordered_map>
27 using std::tr1::unordered_map;
28 using std::tr1::unordered_multimap;
29 #include <vector>
30 using std::vector;
31 
32 #include <fst/arcfilter.h>
33 #include <fst/arcsort.h>
34 #include <fst/dfs-visit.h>
35 #include <fst/expanded-fst.h>
36 #include <fst/replace.h>
37 
38 namespace fst {
39 
40 // This class accumulates arc weights using the semiring Plus().
41 template <class A>
42 class DefaultAccumulator {
43  public:
44   typedef A Arc;
45   typedef typename A::StateId StateId;
46   typedef typename A::Weight Weight;
47 
DefaultAccumulator()48   DefaultAccumulator() {}
49 
DefaultAccumulator(const DefaultAccumulator<A> & acc)50   DefaultAccumulator(const DefaultAccumulator<A> &acc) {}
51 
52   void Init(const Fst<A>& fst, bool copy = false) {}
53 
SetState(StateId)54   void SetState(StateId) {}
55 
Sum(Weight w,Weight v)56   Weight Sum(Weight w, Weight v) {
57     return Plus(w, v);
58   }
59 
60   template <class ArcIterator>
Sum(Weight w,ArcIterator * aiter,ssize_t begin,ssize_t end)61   Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
62              ssize_t end) {
63     Weight sum = w;
64     aiter->Seek(begin);
65     for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos)
66       sum = Plus(sum, aiter->Value().weight);
67     return sum;
68   }
69 
Error()70   bool Error() const { return false; }
71 
72  private:
73   void operator=(const DefaultAccumulator<A> &);   // Disallow
74 };
75 
76 
77 // This class accumulates arc weights using the log semiring Plus()
78 // assuming an arc weight has a WeightConvert specialization to
79 // and from log64 weights.
80 template <class A>
81 class LogAccumulator {
82  public:
83   typedef A Arc;
84   typedef typename A::StateId StateId;
85   typedef typename A::Weight Weight;
86 
LogAccumulator()87   LogAccumulator() {}
88 
LogAccumulator(const LogAccumulator<A> & acc)89   LogAccumulator(const LogAccumulator<A> &acc) {}
90 
91   void Init(const Fst<A>& fst, bool copy = false) {}
92 
SetState(StateId)93   void SetState(StateId) {}
94 
Sum(Weight w,Weight v)95   Weight Sum(Weight w, Weight v) {
96     return LogPlus(w, v);
97   }
98 
99   template <class ArcIterator>
Sum(Weight w,ArcIterator * aiter,ssize_t begin,ssize_t end)100   Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
101              ssize_t end) {
102     Weight sum = w;
103     aiter->Seek(begin);
104     for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos)
105       sum = LogPlus(sum, aiter->Value().weight);
106     return sum;
107   }
108 
Error()109   bool Error() const { return false; }
110 
111  private:
LogPosExp(double x)112   double LogPosExp(double x) { return log(1.0F + exp(-x)); }
113 
LogPlus(Weight w,Weight v)114   Weight LogPlus(Weight w, Weight v) {
115     double f1 = to_log_weight_(w).Value();
116     double f2 = to_log_weight_(v).Value();
117     if (f1 > f2)
118       return to_weight_(f2 - LogPosExp(f1 - f2));
119     else
120       return to_weight_(f1 - LogPosExp(f2 - f1));
121   }
122 
123   WeightConvert<Weight, Log64Weight> to_log_weight_;
124   WeightConvert<Log64Weight, Weight> to_weight_;
125 
126   void operator=(const LogAccumulator<A> &);   // Disallow
127 };
128 
129 
130 // Stores shareable data for fast log accumulator copies.
131 class FastLogAccumulatorData {
132  public:
FastLogAccumulatorData()133   FastLogAccumulatorData() {}
134 
Weights()135   vector<double> *Weights() { return &weights_; }
WeightPositions()136   vector<ssize_t> *WeightPositions() { return &weight_positions_; }
WeightEnd()137   double *WeightEnd() { return &(weights_[weights_.size() - 1]); };
RefCount()138   int RefCount() const { return ref_count_.count(); }
IncrRefCount()139   int IncrRefCount() { return ref_count_.Incr(); }
DecrRefCount()140   int DecrRefCount() { return ref_count_.Decr(); }
141 
142  private:
143   // Cummulative weight per state for all states s.t. # of arcs >
144   // arc_limit_ with arcs in order. Special first element per state
145   // being Log64Weight::Zero();
146   vector<double> weights_;
147   // Maps from state to corresponding beginning weight position in
148   // weights_. Position -1 means no pre-computed weights for that
149   // state.
150   vector<ssize_t> weight_positions_;
151   RefCounter ref_count_;                  // Reference count.
152 
153   DISALLOW_COPY_AND_ASSIGN(FastLogAccumulatorData);
154 };
155 
156 
157 // This class accumulates arc weights using the log semiring Plus()
158 // assuming an arc weight has a WeightConvert specialization to and
159 // from log64 weights. The member function Init(fst) has to be called
160 // to setup pre-computed weight information.
161 template <class A>
162 class FastLogAccumulator {
163  public:
164   typedef A Arc;
165   typedef typename A::StateId StateId;
166   typedef typename A::Weight Weight;
167 
168   explicit FastLogAccumulator(ssize_t arc_limit = 20, ssize_t arc_period = 10)
arc_limit_(arc_limit)169       : arc_limit_(arc_limit),
170         arc_period_(arc_period),
171         data_(new FastLogAccumulatorData()),
172         error_(false) {}
173 
FastLogAccumulator(const FastLogAccumulator<A> & acc)174   FastLogAccumulator(const FastLogAccumulator<A> &acc)
175       : arc_limit_(acc.arc_limit_),
176         arc_period_(acc.arc_period_),
177         data_(acc.data_),
178         error_(acc.error_) {
179     data_->IncrRefCount();
180   }
181 
~FastLogAccumulator()182   ~FastLogAccumulator() {
183     if (!data_->DecrRefCount())
184       delete data_;
185   }
186 
SetState(StateId s)187   void SetState(StateId s) {
188     vector<double> &weights = *data_->Weights();
189     vector<ssize_t> &weight_positions = *data_->WeightPositions();
190 
191     if (weight_positions.size() <= s) {
192       FSTERROR() << "FastLogAccumulator::SetState: invalid state id.";
193       error_ = true;
194       return;
195     }
196 
197     ssize_t pos = weight_positions[s];
198     if (pos >= 0)
199       state_weights_ = &(weights[pos]);
200     else
201       state_weights_ = 0;
202   }
203 
Sum(Weight w,Weight v)204   Weight Sum(Weight w, Weight v) {
205     return LogPlus(w, v);
206   }
207 
208   template <class ArcIterator>
Sum(Weight w,ArcIterator * aiter,ssize_t begin,ssize_t end)209   Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
210              ssize_t end) {
211     if (error_) return Weight::NoWeight();
212     Weight sum = w;
213     // Finds begin and end of pre-stored weights
214     ssize_t index_begin = -1, index_end = -1;
215     ssize_t stored_begin = end, stored_end = end;
216     if (state_weights_ != 0) {
217       index_begin = begin > 0 ? (begin - 1)/ arc_period_ + 1 : 0;
218       index_end = end / arc_period_;
219       stored_begin = index_begin * arc_period_;
220       stored_end = index_end * arc_period_;
221     }
222     // Computes sum before pre-stored weights
223     if (begin < stored_begin) {
224       ssize_t pos_end = min(stored_begin, end);
225       aiter->Seek(begin);
226       for (ssize_t pos = begin; pos < pos_end; aiter->Next(), ++pos)
227         sum = LogPlus(sum, aiter->Value().weight);
228     }
229     // Computes sum between pre-stored weights
230     if (stored_begin < stored_end) {
231       sum = LogPlus(sum, LogMinus(state_weights_[index_end],
232                                   state_weights_[index_begin]));
233     }
234     // Computes sum after pre-stored weights
235     if (stored_end < end) {
236       ssize_t pos_start = max(stored_begin, stored_end);
237       aiter->Seek(pos_start);
238       for (ssize_t pos = pos_start; pos < end; aiter->Next(), ++pos)
239         sum = LogPlus(sum, aiter->Value().weight);
240     }
241     return sum;
242   }
243 
244   template <class F>
245   void Init(const F &fst, bool copy = false) {
246     if (copy)
247       return;
248     vector<double> &weights = *data_->Weights();
249     vector<ssize_t> &weight_positions = *data_->WeightPositions();
250     if (!weights.empty() || arc_limit_ < arc_period_) {
251       FSTERROR() << "FastLogAccumulator: initialization error.";
252       error_ = true;
253       return;
254     }
255     weight_positions.reserve(CountStates(fst));
256 
257     ssize_t weight_position = 0;
258     for(StateIterator<F> siter(fst); !siter.Done(); siter.Next()) {
259       StateId s = siter.Value();
260       if (fst.NumArcs(s) >= arc_limit_) {
261         double sum = FloatLimits<double>::PosInfinity();
262         weight_positions.push_back(weight_position);
263         weights.push_back(sum);
264         ++weight_position;
265         ssize_t narcs = 0;
266         for(ArcIterator<F> aiter(fst, s); !aiter.Done(); aiter.Next()) {
267           const A &arc = aiter.Value();
268           sum = LogPlus(sum, arc.weight);
269           // Stores cumulative weight distribution per arc_period_.
270           if (++narcs % arc_period_ == 0) {
271             weights.push_back(sum);
272             ++weight_position;
273           }
274         }
275       } else {
276         weight_positions.push_back(-1);
277       }
278     }
279   }
280 
Error()281   bool Error() const { return error_; }
282 
283  private:
LogPosExp(double x)284   double LogPosExp(double x) {
285     return x == FloatLimits<double>::PosInfinity() ?
286         0.0 : log(1.0F + exp(-x));
287   }
288 
LogMinusExp(double x)289   double LogMinusExp(double x) {
290     return x == FloatLimits<double>::PosInfinity() ?
291         0.0 : log(1.0F - exp(-x));
292   }
293 
LogPlus(Weight w,Weight v)294   Weight LogPlus(Weight w, Weight v) {
295     double f1 = to_log_weight_(w).Value();
296     double f2 = to_log_weight_(v).Value();
297     if (f1 > f2)
298       return to_weight_(f2 - LogPosExp(f1 - f2));
299     else
300       return to_weight_(f1 - LogPosExp(f2 - f1));
301   }
302 
LogPlus(double f1,Weight v)303   double LogPlus(double f1, Weight v) {
304     double f2 = to_log_weight_(v).Value();
305     if (f1 == FloatLimits<double>::PosInfinity())
306       return f2;
307     else if (f1 > f2)
308       return f2 - LogPosExp(f1 - f2);
309     else
310       return f1 - LogPosExp(f2 - f1);
311   }
312 
LogMinus(double f1,double f2)313   Weight LogMinus(double f1, double f2) {
314     if (f1 >= f2) {
315       FSTERROR() << "FastLogAcumulator::LogMinus: f1 >= f2 with f1 = " << f1
316                  << " and f2 = " << f2;
317       error_ = true;
318       return Weight::NoWeight();
319     }
320     if (f2 == FloatLimits<double>::PosInfinity())
321       return to_weight_(f1);
322     else
323       return to_weight_(f1 - LogMinusExp(f2 - f1));
324   }
325 
326   WeightConvert<Weight, Log64Weight> to_log_weight_;
327   WeightConvert<Log64Weight, Weight> to_weight_;
328 
329   ssize_t arc_limit_;     // Minimum # of arcs to pre-compute state
330   ssize_t arc_period_;    // Save cumulative weights per 'arc_period_'.
331   bool init_;             // Cumulative weights initialized?
332   FastLogAccumulatorData *data_;
333   double *state_weights_;
334   bool error_;
335 
336   void operator=(const FastLogAccumulator<A> &);   // Disallow
337 };
338 
339 
340 // Stores shareable data for cache log accumulator copies.
341 // All copies share the same cache.
342 template <class A>
343 class CacheLogAccumulatorData {
344  public:
345   typedef A Arc;
346   typedef typename A::StateId StateId;
347   typedef typename A::Weight Weight;
348 
CacheLogAccumulatorData(bool gc,size_t gc_limit)349   CacheLogAccumulatorData(bool gc, size_t gc_limit)
350       : cache_gc_(gc), cache_limit_(gc_limit), cache_size_(0) {}
351 
~CacheLogAccumulatorData()352   ~CacheLogAccumulatorData() {
353     for(typename unordered_map<StateId, CacheState>::iterator it = cache_.begin();
354         it != cache_.end();
355         ++it)
356       delete it->second.weights;
357   }
358 
CacheDisabled()359   bool CacheDisabled() const { return cache_gc_ && cache_limit_ == 0; }
360 
GetWeights(StateId s)361   vector<double> *GetWeights(StateId s) {
362     typename unordered_map<StateId, CacheState>::iterator it = cache_.find(s);
363     if (it != cache_.end()) {
364       it->second.recent = true;
365       return it->second.weights;
366     } else {
367       return 0;
368     }
369   }
370 
AddWeights(StateId s,vector<double> * weights)371   void AddWeights(StateId s, vector<double> *weights) {
372     if (cache_gc_ && cache_size_ >= cache_limit_)
373       GC(false);
374     cache_.insert(make_pair(s, CacheState(weights, true)));
375     if (cache_gc_)
376       cache_size_ += weights->capacity() * sizeof(double);
377   }
378 
RefCount()379   int RefCount() const { return ref_count_.count(); }
IncrRefCount()380   int IncrRefCount() { return ref_count_.Incr(); }
DecrRefCount()381   int DecrRefCount() { return ref_count_.Decr(); }
382 
383  private:
384   // Cached information for a given state.
385   struct CacheState {
386     vector<double>* weights;  // Accumulated weights for this state.
387     bool recent;              // Has this state been accessed since last GC?
388 
CacheStateCacheState389     CacheState(vector<double> *w, bool r) : weights(w), recent(r) {}
390   };
391 
392   // Garbage collect: Delete from cache states that have not been
393   // accessed since the last GC ('free_recent = false') until
394   // 'cache_size_' is 2/3 of 'cache_limit_'. If it does not free enough
395   // memory, start deleting recently accessed states.
GC(bool free_recent)396   void GC(bool free_recent) {
397     size_t cache_target = (2 * cache_limit_)/3 + 1;
398     typename unordered_map<StateId, CacheState>::iterator it = cache_.begin();
399     while (it != cache_.end() && cache_size_ > cache_target) {
400       CacheState &cs = it->second;
401       if (free_recent || !cs.recent) {
402         cache_size_ -= cs.weights->capacity() * sizeof(double);
403         delete cs.weights;
404         cache_.erase(it++);
405       } else {
406         cs.recent = false;
407         ++it;
408       }
409     }
410     if (!free_recent && cache_size_ > cache_target)
411       GC(true);
412   }
413 
414   unordered_map<StateId, CacheState> cache_;  // Cache
415   bool cache_gc_;                        // Enable garbage collection
416   size_t cache_limit_;                   // # of bytes cached
417   size_t cache_size_;                    // # of bytes allowed before GC
418   RefCounter ref_count_;
419 
420   DISALLOW_COPY_AND_ASSIGN(CacheLogAccumulatorData);
421 };
422 
423 // This class accumulates arc weights using the log semiring Plus()
424 //  has a WeightConvert specialization to and from log64 weights.  It
425 //  is similar to the FastLogAccumator. However here, the accumulated
426 //  weights are pre-computed and stored only for the states that are
427 //  visited. The member function Init(fst) has to be called to setup
428 //  this accumulator.
429 template <class A>
430 class CacheLogAccumulator {
431  public:
432   typedef A Arc;
433   typedef typename A::StateId StateId;
434   typedef typename A::Weight Weight;
435 
436   explicit CacheLogAccumulator(ssize_t arc_limit = 10, bool gc = false,
437                                size_t gc_limit = 10 * 1024 * 1024)
arc_limit_(arc_limit)438       : arc_limit_(arc_limit), fst_(0), data_(
439           new CacheLogAccumulatorData<A>(gc, gc_limit)), s_(kNoStateId),
440         error_(false) {}
441 
CacheLogAccumulator(const CacheLogAccumulator<A> & acc)442   CacheLogAccumulator(const CacheLogAccumulator<A> &acc)
443       : arc_limit_(acc.arc_limit_), fst_(acc.fst_ ? acc.fst_->Copy() : 0),
444         data_(acc.data_), s_(kNoStateId), error_(acc.error_) {
445     data_->IncrRefCount();
446   }
447 
~CacheLogAccumulator()448   ~CacheLogAccumulator() {
449     if (fst_)
450       delete fst_;
451     if (!data_->DecrRefCount())
452       delete data_;
453   }
454 
455   // Arg 'arc_limit' specifies minimum # of arcs to pre-compute state.
456   void Init(const Fst<A> &fst, bool copy = false) {
457     if (copy) {
458       delete fst_;
459     } else if (fst_) {
460       FSTERROR() << "CacheLogAccumulator: initialization error.";
461       error_ = true;
462       return;
463     }
464     fst_ = fst.Copy();
465   }
466 
467   void SetState(StateId s, int depth = 0) {
468     if (s == s_)
469       return;
470     s_ = s;
471 
472     if (data_->CacheDisabled() || error_) {
473       weights_ = 0;
474       return;
475     }
476 
477     if (!fst_) {
478       FSTERROR() << "CacheLogAccumulator::SetState: incorrectly initialized.";
479       error_ = true;
480       weights_ = 0;
481       return;
482     }
483 
484     weights_ = data_->GetWeights(s);
485     if ((weights_ == 0) && (fst_->NumArcs(s) >= arc_limit_)) {
486       weights_ = new vector<double>;
487       weights_->reserve(fst_->NumArcs(s) + 1);
488       weights_->push_back(FloatLimits<double>::PosInfinity());
489       data_->AddWeights(s, weights_);
490     }
491   }
492 
Sum(Weight w,Weight v)493   Weight Sum(Weight w, Weight v) {
494     return LogPlus(w, v);
495   }
496 
497   template <class Iterator>
Sum(Weight w,Iterator * aiter,ssize_t begin,ssize_t end)498   Weight Sum(Weight w, Iterator *aiter, ssize_t begin,
499              ssize_t end) {
500     if (weights_ == 0) {
501       Weight sum = w;
502       aiter->Seek(begin);
503       for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos)
504         sum = LogPlus(sum, aiter->Value().weight);
505       return sum;
506     } else {
507       if (weights_->size() <= end)
508         for (aiter->Seek(weights_->size() - 1);
509              weights_->size() <= end;
510              aiter->Next())
511           weights_->push_back(LogPlus(weights_->back(),
512                                       aiter->Value().weight));
513       return LogPlus(w, LogMinus((*weights_)[end], (*weights_)[begin]));
514     }
515   }
516 
517   template <class Iterator>
LowerBound(double w,Iterator * aiter)518   size_t LowerBound(double w, Iterator *aiter) {
519     if (weights_ != 0) {
520       return lower_bound(weights_->begin() + 1,
521                          weights_->end(),
522                          w,
523                          std::greater<double>())
524           - weights_->begin() - 1;
525     } else {
526       size_t n = 0;
527       double x =  FloatLimits<double>::PosInfinity();
528       for(aiter->Reset(); !aiter->Done(); aiter->Next(), ++n) {
529         x = LogPlus(x, aiter->Value().weight);
530         if (x < w) break;
531       }
532       return n;
533     }
534   }
535 
Error()536   bool Error() const { return error_; }
537 
538  private:
LogPosExp(double x)539   double LogPosExp(double x) {
540     return x == FloatLimits<double>::PosInfinity() ?
541         0.0 : log(1.0F + exp(-x));
542   }
543 
LogMinusExp(double x)544   double LogMinusExp(double x) {
545     return x == FloatLimits<double>::PosInfinity() ?
546         0.0 : log(1.0F - exp(-x));
547   }
548 
LogPlus(Weight w,Weight v)549   Weight LogPlus(Weight w, Weight v) {
550     double f1 = to_log_weight_(w).Value();
551     double f2 = to_log_weight_(v).Value();
552     if (f1 > f2)
553       return to_weight_(f2 - LogPosExp(f1 - f2));
554     else
555       return to_weight_(f1 - LogPosExp(f2 - f1));
556   }
557 
LogPlus(double f1,Weight v)558   double LogPlus(double f1, Weight v) {
559     double f2 = to_log_weight_(v).Value();
560     if (f1 == FloatLimits<double>::PosInfinity())
561       return f2;
562     else if (f1 > f2)
563       return f2 - LogPosExp(f1 - f2);
564     else
565       return f1 - LogPosExp(f2 - f1);
566   }
567 
LogMinus(double f1,double f2)568   Weight LogMinus(double f1, double f2) {
569     if (f1 >= f2) {
570       FSTERROR() << "CacheLogAcumulator::LogMinus: f1 >= f2 with f1 = " << f1
571                  << " and f2 = " << f2;
572       error_ = true;
573       return Weight::NoWeight();
574     }
575     if (f2 == FloatLimits<double>::PosInfinity())
576       return to_weight_(f1);
577     else
578       return to_weight_(f1 - LogMinusExp(f2 - f1));
579   }
580 
581   WeightConvert<Weight, Log64Weight> to_log_weight_;
582   WeightConvert<Log64Weight, Weight> to_weight_;
583 
584   ssize_t arc_limit_;                    // Minimum # of arcs to cache a state
585   vector<double> *weights_;              // Accumulated weights for cur. state
586   const Fst<A>* fst_;                    // Input fst
587   CacheLogAccumulatorData<A> *data_;     // Cache data
588   StateId s_;                            // Current state
589   bool error_;
590 
591   void operator=(const CacheLogAccumulator<A> &);   // Disallow
592 };
593 
594 
595 // Stores shareable data for replace accumulator copies.
596 template <class Accumulator, class T>
597 class ReplaceAccumulatorData {
598  public:
599   typedef typename Accumulator::Arc Arc;
600   typedef typename Arc::StateId StateId;
601   typedef typename Arc::Label Label;
602   typedef T StateTable;
603   typedef typename T::StateTuple StateTuple;
604 
ReplaceAccumulatorData()605   ReplaceAccumulatorData() : state_table_(0) {}
606 
ReplaceAccumulatorData(const vector<Accumulator * > & accumulators)607   ReplaceAccumulatorData(const vector<Accumulator*> &accumulators)
608       : state_table_(0), accumulators_(accumulators) {}
609 
~ReplaceAccumulatorData()610   ~ReplaceAccumulatorData() {
611     for (size_t i = 0; i < fst_array_.size(); ++i)
612       delete fst_array_[i];
613     for (size_t i = 0; i < accumulators_.size(); ++i)
614       delete accumulators_[i];
615   }
616 
Init(const vector<pair<Label,const Fst<Arc> * >> & fst_tuples,const StateTable * state_table)617   void Init(const vector<pair<Label, const Fst<Arc>*> > &fst_tuples,
618        const StateTable *state_table) {
619     state_table_ = state_table;
620     accumulators_.resize(fst_tuples.size());
621     for (size_t i = 0; i < accumulators_.size(); ++i) {
622       if (!accumulators_[i])
623         accumulators_[i] = new Accumulator;
624       accumulators_[i]->Init(*(fst_tuples[i].second));
625       fst_array_.push_back(fst_tuples[i].second->Copy());
626     }
627   }
628 
GetTuple(StateId s)629   const StateTuple &GetTuple(StateId s) const {
630     return state_table_->Tuple(s);
631   }
632 
GetAccumulator(size_t i)633   Accumulator *GetAccumulator(size_t i) { return accumulators_[i]; }
634 
GetFst(size_t i)635   const Fst<Arc> *GetFst(size_t i) const { return fst_array_[i]; }
636 
RefCount()637   int RefCount() const { return ref_count_.count(); }
IncrRefCount()638   int IncrRefCount() { return ref_count_.Incr(); }
DecrRefCount()639   int DecrRefCount() { return ref_count_.Decr(); }
640 
641  private:
642   const T * state_table_;
643   vector<Accumulator*> accumulators_;
644   vector<const Fst<Arc>*> fst_array_;
645   RefCounter ref_count_;
646 
647   DISALLOW_COPY_AND_ASSIGN(ReplaceAccumulatorData);
648 };
649 
650 // This class accumulates weights in a ReplaceFst.  The 'Init' method
651 // takes as input the argument used to build the ReplaceFst and the
652 // ReplaceFst state table. It uses accumulators of type 'Accumulator'
653 // in the underlying FSTs.
654 template <class Accumulator,
655           class T = DefaultReplaceStateTable<typename Accumulator::Arc> >
656 class ReplaceAccumulator {
657  public:
658   typedef typename Accumulator::Arc Arc;
659   typedef typename Arc::StateId StateId;
660   typedef typename Arc::Label Label;
661   typedef typename Arc::Weight Weight;
662   typedef T StateTable;
663   typedef typename T::StateTuple StateTuple;
664 
ReplaceAccumulator()665   ReplaceAccumulator()
666       : init_(false), data_(new ReplaceAccumulatorData<Accumulator, T>()),
667         error_(false) {}
668 
ReplaceAccumulator(const vector<Accumulator * > & accumulators)669   ReplaceAccumulator(const vector<Accumulator*> &accumulators)
670       : init_(false),
671         data_(new ReplaceAccumulatorData<Accumulator, T>(accumulators)),
672         error_(false) {}
673 
ReplaceAccumulator(const ReplaceAccumulator<Accumulator,T> & acc)674   ReplaceAccumulator(const ReplaceAccumulator<Accumulator, T> &acc)
675       : init_(acc.init_), data_(acc.data_), error_(acc.error_) {
676     if (!init_)
677       FSTERROR() << "ReplaceAccumulator: can't copy unintialized accumulator";
678     data_->IncrRefCount();
679   }
680 
~ReplaceAccumulator()681   ~ReplaceAccumulator() {
682      if (!data_->DecrRefCount())
683       delete data_;
684   }
685 
686   // Does not take ownership of the state table, the state table
687   // is own by the ReplaceFst
Init(const vector<pair<Label,const Fst<Arc> * >> & fst_tuples,const StateTable * state_table)688   void Init(const vector<pair<Label, const Fst<Arc>*> > &fst_tuples,
689             const StateTable *state_table) {
690     init_ = true;
691     data_->Init(fst_tuples, state_table);
692   }
693 
SetState(StateId s)694   void SetState(StateId s) {
695     if (!init_) {
696       FSTERROR() << "ReplaceAccumulator::SetState: incorrectly initialized.";
697       error_ = true;
698       return;
699     }
700     StateTuple tuple = data_->GetTuple(s);
701     fst_id_ = tuple.fst_id - 1;  // Replace FST ID is 1-based
702     data_->GetAccumulator(fst_id_)->SetState(tuple.fst_state);
703     if ((tuple.prefix_id != 0) &&
704         (data_->GetFst(fst_id_)->Final(tuple.fst_state) != Weight::Zero())) {
705       offset_ = 1;
706       offset_weight_ = data_->GetFst(fst_id_)->Final(tuple.fst_state);
707     } else {
708       offset_ = 0;
709       offset_weight_ = Weight::Zero();
710     }
711   }
712 
Sum(Weight w,Weight v)713   Weight Sum(Weight w, Weight v) {
714     if (error_) return Weight::NoWeight();
715     return data_->GetAccumulator(fst_id_)->Sum(w, v);
716   }
717 
718   template <class ArcIterator>
Sum(Weight w,ArcIterator * aiter,ssize_t begin,ssize_t end)719   Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
720              ssize_t end) {
721     if (error_) return Weight::NoWeight();
722     Weight sum = begin == end ? Weight::Zero()
723         : data_->GetAccumulator(fst_id_)->Sum(
724             w, aiter, begin ? begin - offset_ : 0, end - offset_);
725     if (begin == 0 && end != 0 && offset_ > 0)
726       sum = Sum(offset_weight_, sum);
727     return sum;
728   }
729 
Error()730   bool Error() const { return error_; }
731 
732  private:
733   bool init_;
734   ReplaceAccumulatorData<Accumulator, T> *data_;
735   Label fst_id_;
736   size_t offset_;
737   Weight offset_weight_;
738   bool error_;
739 
740   void operator=(const ReplaceAccumulator<Accumulator, T> &);   // Disallow
741 };
742 
743 }  // namespace fst
744 
745 #endif  // FST_LIB_ACCUMULATOR_H__
746