• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // cache.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 //
16 // \file
17 // An Fst implementation that caches FST elements of a delayed
18 // computation.
19 
20 #ifndef FST_LIB_CACHE_H__
21 #define FST_LIB_CACHE_H__
22 
23 #include <list>
24 
25 #include "fst/lib/vector-fst.h"
26 
27 DECLARE_bool(fst_default_cache_gc);
28 DECLARE_int64(fst_default_cache_gc_limit);
29 
30 namespace fst {
31 
32 struct CacheOptions {
33   bool gc;          // enable GC
34   size_t gc_limit;  // # of bytes allowed before GC
35 
36 
CacheOptionsCacheOptions37   CacheOptions(bool g, size_t l) : gc(g), gc_limit(l) {}
CacheOptionsCacheOptions38   CacheOptions()
39       : gc(FLAGS_fst_default_cache_gc),
40         gc_limit(FLAGS_fst_default_cache_gc_limit) {}
41 };
42 
43 
44 // This is a VectorFstBaseImpl container that holds a State similar to
45 // VectorState but additionally has a flags data member (see
46 // CacheState below). This class is used to cache FST elements with
47 // the flags used to indicate what has been cached. Use HasStart()
48 // HasFinal(), and HasArcs() to determine if cached and SetStart(),
49 // SetFinal(), AddArc(), and SetArcs() to cache. Note you must set the
50 // final weight even if the state is non-final to mark it as
51 // cached. If the 'gc' option is 'false', cached items have the extent
52 // of the FST - minimizing computation. If the 'gc' option is 'true',
53 // garbage collection of states (not in use in an arc iterator) is
54 // performed, in a rough approximation of LRU order, when 'gc_limit'
55 // bytes is reached - controlling memory use. When 'gc_limit' is 0,
56 // special optimizations apply - minimizing memory use.
57 
58 template <class S>
59 class CacheBaseImpl : public VectorFstBaseImpl<S> {
60  public:
61   using FstImpl<typename S::Arc>::Type;
62   using VectorFstBaseImpl<S>::NumStates;
63   using VectorFstBaseImpl<S>::AddState;
64 
65   typedef S State;
66   typedef typename S::Arc Arc;
67   typedef typename Arc::Weight Weight;
68   typedef typename Arc::StateId StateId;
69 
CacheBaseImpl()70   CacheBaseImpl()
71       : cache_start_(false), nknown_states_(0), min_unexpanded_state_id_(0),
72         cache_first_state_id_(kNoStateId), cache_first_state_(0),
73         cache_gc_(FLAGS_fst_default_cache_gc),  cache_size_(0),
74         cache_limit_(FLAGS_fst_default_cache_gc_limit > kMinCacheLimit ||
75                      FLAGS_fst_default_cache_gc_limit == 0 ?
76                      FLAGS_fst_default_cache_gc_limit : kMinCacheLimit) {}
77 
CacheBaseImpl(const CacheOptions & opts)78   explicit CacheBaseImpl(const CacheOptions &opts)
79       : cache_start_(false), nknown_states_(0),
80         min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
81         cache_first_state_(0), cache_gc_(opts.gc), cache_size_(0),
82         cache_limit_(opts.gc_limit > kMinCacheLimit || opts.gc_limit == 0 ?
83                      opts.gc_limit : kMinCacheLimit) {}
84 
~CacheBaseImpl()85   ~CacheBaseImpl() {
86     delete cache_first_state_;
87   }
88 
89   // Gets a state from its ID; state must exist.
GetState(StateId s)90   const S *GetState(StateId s) const {
91     if (s == cache_first_state_id_)
92       return cache_first_state_;
93     else
94       return VectorFstBaseImpl<S>::GetState(s);
95   }
96 
97   // Gets a state from its ID; state must exist.
GetState(StateId s)98   S *GetState(StateId s) {
99     if (s == cache_first_state_id_)
100       return cache_first_state_;
101     else
102       return VectorFstBaseImpl<S>::GetState(s);
103   }
104 
105   // Gets a state from its ID; return 0 if it doesn't exist.
CheckState(StateId s)106   const S *CheckState(StateId s) const {
107     if (s == cache_first_state_id_)
108       return cache_first_state_;
109     else if (s < NumStates())
110       return VectorFstBaseImpl<S>::GetState(s);
111     else
112       return 0;
113   }
114 
115   // Gets a state from its ID; add it if necessary.
ExtendState(StateId s)116   S *ExtendState(StateId s) {
117     if (s == cache_first_state_id_) {
118       return cache_first_state_;                   // Return 1st cached state
119     } else if (cache_limit_ == 0 && cache_first_state_id_ == kNoStateId) {
120       cache_first_state_id_ = s;                   // Remember 1st cached state
121       cache_first_state_ = new S;
122       return cache_first_state_;
123     } else if (cache_first_state_id_ != kNoStateId &&
124                cache_first_state_->ref_count == 0) {
125       cache_first_state_id_ = s;                   // Reuse 1st cached state
126       cache_first_state_->Reset();
127       return cache_first_state_;                   // Return 1st cached state
128     } else {
129       while (NumStates() <= s)                     // Add state to main cache
130         AddState(0);
131       if (!VectorFstBaseImpl<S>::GetState(s)) {
132         SetState(s, new S);
133         if (cache_first_state_id_ != kNoStateId) {  // Forget 1st cached state
134           while (NumStates() <= cache_first_state_id_)
135             AddState(0);
136           SetState(cache_first_state_id_, cache_first_state_);
137           if (cache_gc_) {
138             cache_states_.push_back(cache_first_state_id_);
139             cache_size_ += sizeof(S) +
140                            cache_first_state_->arcs.capacity() * sizeof(Arc);
141             cache_limit_ = kMinCacheLimit;
142           }
143           cache_first_state_id_ = kNoStateId;
144           cache_first_state_ = 0;
145         }
146         if (cache_gc_) {
147           cache_states_.push_back(s);
148           cache_size_ += sizeof(S);
149           if (cache_size_ > cache_limit_)
150             GC(s, false);
151         }
152       }
153       return VectorFstBaseImpl<S>::GetState(s);
154     }
155   }
156 
SetStart(StateId s)157   void SetStart(StateId s) {
158     VectorFstBaseImpl<S>::SetStart(s);
159     cache_start_ = true;
160     if (s >= nknown_states_)
161       nknown_states_ = s + 1;
162   }
163 
SetFinal(StateId s,Weight w)164   void SetFinal(StateId s, Weight w) {
165     S *state = ExtendState(s);
166     state->final = w;
167     state->flags |= kCacheFinal | kCacheRecent;
168   }
169 
AddArc(StateId s,const Arc & arc)170   void AddArc(StateId s, const Arc &arc) {
171     S *state = ExtendState(s);
172     state->arcs.push_back(arc);
173   }
174 
175   // Marks arcs of state s as cached.
SetArcs(StateId s)176   void SetArcs(StateId s) {
177     S *state = ExtendState(s);
178     vector<Arc> &arcs = state->arcs;
179     state->niepsilons = state->noepsilons = 0;
180     for (unsigned int a = 0; a < arcs.size(); ++a) {
181       const Arc &arc = arcs[a];
182       if (arc.nextstate >= nknown_states_)
183         nknown_states_ = arc.nextstate + 1;
184       if (arc.ilabel == 0)
185         ++state->niepsilons;
186       if (arc.olabel == 0)
187         ++state->noepsilons;
188     }
189     ExpandedState(s);
190     state->flags |= kCacheArcs | kCacheRecent;
191     if (cache_gc_ && s != cache_first_state_id_) {
192       cache_size_ += arcs.capacity() * sizeof(Arc);
193       if (cache_size_ > cache_limit_)
194         GC(s, false);
195     }
196   };
197 
ReserveArcs(StateId s,size_t n)198   void ReserveArcs(StateId s, size_t n) {
199     S *state = ExtendState(s);
200     state->arcs.reserve(n);
201   }
202 
203   // Is the start state cached?
HasStart()204   bool HasStart() const { return cache_start_; }
205   // Is the final weight of state s cached?
206 
HasFinal(StateId s)207   bool HasFinal(StateId s) const {
208     const S *state = CheckState(s);
209     if (state && state->flags & kCacheFinal) {
210       state->flags |= kCacheRecent;
211       return true;
212     } else {
213       return false;
214     }
215   }
216 
217   // Are arcs of state s cached?
HasArcs(StateId s)218   bool HasArcs(StateId s) const {
219     const S *state = CheckState(s);
220     if (state && state->flags & kCacheArcs) {
221       state->flags |= kCacheRecent;
222       return true;
223     } else {
224       return false;
225     }
226   }
227 
Final(StateId s)228   Weight Final(StateId s) const {
229     const S *state = GetState(s);
230     return state->final;
231   }
232 
NumArcs(StateId s)233   size_t NumArcs(StateId s) const {
234     const S *state = GetState(s);
235     return state->arcs.size();
236   }
237 
NumInputEpsilons(StateId s)238   size_t NumInputEpsilons(StateId s) const {
239     const S *state = GetState(s);
240     return state->niepsilons;
241   }
242 
NumOutputEpsilons(StateId s)243   size_t NumOutputEpsilons(StateId s) const {
244     const S *state = GetState(s);
245     return state->noepsilons;
246   }
247 
248   // Provides information needed for generic arc iterator.
InitArcIterator(StateId s,ArcIteratorData<Arc> * data)249   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
250     const S *state = GetState(s);
251     data->base = 0;
252     data->narcs = state->arcs.size();
253     data->arcs = data->narcs > 0 ? &(state->arcs[0]) : 0;
254     data->ref_count = &(state->ref_count);
255     ++(*data->ref_count);
256   }
257 
258   // Number of known states.
NumKnownStates()259   StateId NumKnownStates() const { return nknown_states_; }
260   // Find the mininum never-expanded state Id
MinUnexpandedState()261   StateId MinUnexpandedState() const {
262     while (min_unexpanded_state_id_ < (StateId)expanded_states_.size() &&
263           expanded_states_[min_unexpanded_state_id_])
264       ++min_unexpanded_state_id_;
265     return min_unexpanded_state_id_;
266   }
267 
268   // Removes from cache_states_ and uncaches (not referenced-counted)
269   // states that have not been accessed since the last GC until
270   // cache_limit_/3 bytes are uncached.  If that fails to free enough,
271   // recurs uncaching recently visited states as well. If still
272   // unable to free enough memory, then widens cache_limit_.
GC(StateId current,bool free_recent)273   void GC(StateId current, bool free_recent) {
274     if (!cache_gc_)
275       return;
276     VLOG(2) << "CacheImpl: Enter GC: object = " << Type() << "(" << this
277             << "), free recently cached = " << free_recent
278             << ", cache size = " << cache_size_
279             << ", cache limit = " << cache_limit_ << "\n";
280     typename list<StateId>::iterator siter = cache_states_.begin();
281 
282     size_t cache_target = (2 * cache_limit_)/3 + 1;
283     while (siter != cache_states_.end()) {
284       StateId s = *siter;
285       S* state = VectorFstBaseImpl<S>::GetState(s);
286       if (cache_size_ > cache_target && state->ref_count == 0 &&
287           (free_recent || !(state->flags & kCacheRecent)) && s != current) {
288         cache_size_ -= sizeof(S) + state->arcs.capacity() * sizeof(Arc);
289         delete state;
290         SetState(s, 0);
291         cache_states_.erase(siter++);
292       } else {
293         state->flags &= ~kCacheRecent;
294         ++siter;
295       }
296     }
297     if (!free_recent && cache_size_ > cache_target) {
298       GC(current, true);
299     } else {
300       while (cache_size_ > cache_target) {
301         cache_limit_ *= 2;
302         cache_target *= 2;
303       }
304     }
305     VLOG(2) << "CacheImpl: Exit GC: object = " << Type() << "(" << this
306             << "), free recently cached = " << free_recent
307             << ", cache size = " << cache_size_
308             << ", cache limit = " << cache_limit_ << "\n";
309   }
310 
311  private:
312   static const uint32 kCacheFinal =  0x0001;  // Final weight has been cached
313   static const uint32 kCacheArcs =   0x0002;  // Arcs have been cached
314   static const uint32 kCacheRecent = 0x0004;  // Mark as visited since GC
315 
316   static const size_t kMinCacheLimit;         // Minimum (non-zero) cache limit
317 
ExpandedState(StateId s)318   void ExpandedState(StateId s) {
319     if (s < min_unexpanded_state_id_)
320       return;
321     while ((StateId)expanded_states_.size() <= s)
322       expanded_states_.push_back(false);
323     expanded_states_[s] = true;
324   }
325 
326   bool cache_start_;                         // Is the start state cached?
327   StateId nknown_states_;                    // # of known states
328   vector<bool> expanded_states_;             // states that have been expanded
329   mutable StateId min_unexpanded_state_id_;  // minimum never-expanded state Id
330   StateId cache_first_state_id_;             // First cached state id
331   S *cache_first_state_;                     // First cached state
332   list<StateId> cache_states_;               // list of currently cached states
333   bool cache_gc_;                            // enable GC
334   size_t cache_size_;                        // # of bytes cached
335   size_t cache_limit_;                       // # of bytes allowed before GC
336 
337   void InitStateIterator(StateIteratorData<Arc> *);  // disallow
338   DISALLOW_EVIL_CONSTRUCTORS(CacheBaseImpl);
339 };
340 
341 template <class S>
342 const size_t CacheBaseImpl<S>::kMinCacheLimit = 8096;
343 
344 
345 // Arcs implemented by an STL vector per state. Similar to VectorState
346 // but adds flags and ref count to keep track of what has been cached.
347 template <class A>
348 struct CacheState {
349   typedef A Arc;
350   typedef typename A::Weight Weight;
351   typedef typename A::StateId StateId;
352 
CacheStateCacheState353   CacheState() :  final(Weight::Zero()), flags(0), ref_count(0) {}
354 
ResetCacheState355   void Reset() {
356     flags = 0;
357     ref_count = 0;
358     arcs.resize(0);
359   }
360 
361   Weight final;              // Final weight
362   vector<A> arcs;            // Arcs represenation
363   size_t niepsilons;         // # of input epsilons
364   size_t noepsilons;         // # of output epsilons
365   mutable uint32 flags;
366   mutable int ref_count;
367 };
368 
369 // A CacheBaseImpl with a commonly used CacheState.
370 template <class A>
371 class CacheImpl : public CacheBaseImpl< CacheState<A> > {
372  public:
373   typedef CacheState<A> State;
374 
CacheImpl()375   CacheImpl() {}
376 
CacheImpl(const CacheOptions & opts)377   explicit CacheImpl(const CacheOptions &opts)
378       : CacheBaseImpl< CacheState<A> >(opts) {}
379 
380  private:
381   DISALLOW_EVIL_CONSTRUCTORS(CacheImpl);
382 };
383 
384 
385 // Use this to make a state iterator for a CacheBaseImpl-derived Fst.
386 // You'll need to make this class a friend of your derived Fst.
387 // Note this iterator only returns those states reachable from
388 // the initial state, so consider implementing a class-specific one.
389 template <class F>
390 class CacheStateIterator : public StateIteratorBase<typename F::Arc> {
391  public:
392   typedef typename F::Arc Arc;
393   typedef typename Arc::StateId StateId;
394 
CacheStateIterator(const F & fst)395   explicit CacheStateIterator(const F &fst) : fst_(fst), s_(0) {}
396 
Done()397   virtual bool Done() const {
398     if (s_ < fst_.impl_->NumKnownStates())
399       return false;
400     fst_.Start();  // force start state
401     if (s_ < fst_.impl_->NumKnownStates())
402       return false;
403     for (int u = fst_.impl_->MinUnexpandedState();
404          u < fst_.impl_->NumKnownStates();
405          u = fst_.impl_->MinUnexpandedState()) {
406       ArcIterator<F>(fst_, u);  // force state expansion
407       if (s_ < fst_.impl_->NumKnownStates())
408         return false;
409     }
410     return true;
411   }
412 
Value()413   virtual StateId Value() const { return s_; }
414 
Next()415   virtual void Next() { ++s_; }
416 
Reset()417   virtual void Reset() { s_ = 0; }
418 
419  private:
420   const F &fst_;
421   StateId s_;
422 };
423 
424 
425 // Use this to make an arc iterator for a CacheBaseImpl-derived Fst.
426 // You'll need to make this class a friend of your derived Fst and
427 // define types Arc and State.
428 template <class F>
429 class CacheArcIterator {
430  public:
431   typedef typename F::Arc Arc;
432   typedef typename F::State State;
433   typedef typename Arc::StateId StateId;
434 
CacheArcIterator(const F & fst,StateId s)435   CacheArcIterator(const F &fst, StateId s) : i_(0) {
436     state_ = fst.impl_->ExtendState(s);
437     ++state_->ref_count;
438   }
439 
~CacheArcIterator()440   ~CacheArcIterator() { --state_->ref_count;  }
441 
Done()442   bool Done() const { return i_ >= state_->arcs.size(); }
443 
Value()444   const Arc& Value() const { return state_->arcs[i_]; }
445 
Next()446   void Next() { ++i_; }
447 
Reset()448   void Reset() { i_ = 0; }
449 
Seek(size_t a)450   void Seek(size_t a) { i_ = a; }
451 
452  private:
453   const State *state_;
454   size_t i_;
455 
456   DISALLOW_EVIL_CONSTRUCTORS(CacheArcIterator);
457 };
458 
459 }  // namespace fst
460 
461 #endif  // FST_LIB_CACHE_H__
462