1 // cache.h 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // 16 // \file 17 // An Fst implementation that caches FST elements of a delayed 18 // computation. 19 20 #ifndef FST_LIB_CACHE_H__ 21 #define FST_LIB_CACHE_H__ 22 23 #include <list> 24 25 #include "fst/lib/vector-fst.h" 26 27 DECLARE_bool(fst_default_cache_gc); 28 DECLARE_int64(fst_default_cache_gc_limit); 29 30 namespace fst { 31 32 struct CacheOptions { 33 bool gc; // enable GC 34 size_t gc_limit; // # of bytes allowed before GC 35 36 CacheOptionsCacheOptions37 CacheOptions(bool g, size_t l) : gc(g), gc_limit(l) {} CacheOptionsCacheOptions38 CacheOptions() 39 : gc(FLAGS_fst_default_cache_gc), 40 gc_limit(FLAGS_fst_default_cache_gc_limit) {} 41 }; 42 43 44 // This is a VectorFstBaseImpl container that holds a State similar to 45 // VectorState but additionally has a flags data member (see 46 // CacheState below). This class is used to cache FST elements with 47 // the flags used to indicate what has been cached. Use HasStart() 48 // HasFinal(), and HasArcs() to determine if cached and SetStart(), 49 // SetFinal(), AddArc(), and SetArcs() to cache. Note you must set the 50 // final weight even if the state is non-final to mark it as 51 // cached. If the 'gc' option is 'false', cached items have the extent 52 // of the FST - minimizing computation. If the 'gc' option is 'true', 53 // garbage collection of states (not in use in an arc iterator) is 54 // performed, in a rough approximation of LRU order, when 'gc_limit' 55 // bytes is reached - controlling memory use. When 'gc_limit' is 0, 56 // special optimizations apply - minimizing memory use. 57 58 template <class S> 59 class CacheBaseImpl : public VectorFstBaseImpl<S> { 60 public: 61 using FstImpl<typename S::Arc>::Type; 62 using VectorFstBaseImpl<S>::NumStates; 63 using VectorFstBaseImpl<S>::AddState; 64 65 typedef S State; 66 typedef typename S::Arc Arc; 67 typedef typename Arc::Weight Weight; 68 typedef typename Arc::StateId StateId; 69 CacheBaseImpl()70 CacheBaseImpl() 71 : cache_start_(false), nknown_states_(0), min_unexpanded_state_id_(0), 72 cache_first_state_id_(kNoStateId), cache_first_state_(0), 73 cache_gc_(FLAGS_fst_default_cache_gc), cache_size_(0), 74 cache_limit_(FLAGS_fst_default_cache_gc_limit > kMinCacheLimit || 75 FLAGS_fst_default_cache_gc_limit == 0 ? 76 FLAGS_fst_default_cache_gc_limit : kMinCacheLimit) {} 77 CacheBaseImpl(const CacheOptions & opts)78 explicit CacheBaseImpl(const CacheOptions &opts) 79 : cache_start_(false), nknown_states_(0), 80 min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId), 81 cache_first_state_(0), cache_gc_(opts.gc), cache_size_(0), 82 cache_limit_(opts.gc_limit > kMinCacheLimit || opts.gc_limit == 0 ? 83 opts.gc_limit : kMinCacheLimit) {} 84 ~CacheBaseImpl()85 ~CacheBaseImpl() { 86 delete cache_first_state_; 87 } 88 89 // Gets a state from its ID; state must exist. GetState(StateId s)90 const S *GetState(StateId s) const { 91 if (s == cache_first_state_id_) 92 return cache_first_state_; 93 else 94 return VectorFstBaseImpl<S>::GetState(s); 95 } 96 97 // Gets a state from its ID; state must exist. GetState(StateId s)98 S *GetState(StateId s) { 99 if (s == cache_first_state_id_) 100 return cache_first_state_; 101 else 102 return VectorFstBaseImpl<S>::GetState(s); 103 } 104 105 // Gets a state from its ID; return 0 if it doesn't exist. CheckState(StateId s)106 const S *CheckState(StateId s) const { 107 if (s == cache_first_state_id_) 108 return cache_first_state_; 109 else if (s < NumStates()) 110 return VectorFstBaseImpl<S>::GetState(s); 111 else 112 return 0; 113 } 114 115 // Gets a state from its ID; add it if necessary. ExtendState(StateId s)116 S *ExtendState(StateId s) { 117 if (s == cache_first_state_id_) { 118 return cache_first_state_; // Return 1st cached state 119 } else if (cache_limit_ == 0 && cache_first_state_id_ == kNoStateId) { 120 cache_first_state_id_ = s; // Remember 1st cached state 121 cache_first_state_ = new S; 122 return cache_first_state_; 123 } else if (cache_first_state_id_ != kNoStateId && 124 cache_first_state_->ref_count == 0) { 125 cache_first_state_id_ = s; // Reuse 1st cached state 126 cache_first_state_->Reset(); 127 return cache_first_state_; // Return 1st cached state 128 } else { 129 while (NumStates() <= s) // Add state to main cache 130 AddState(0); 131 if (!VectorFstBaseImpl<S>::GetState(s)) { 132 SetState(s, new S); 133 if (cache_first_state_id_ != kNoStateId) { // Forget 1st cached state 134 while (NumStates() <= cache_first_state_id_) 135 AddState(0); 136 SetState(cache_first_state_id_, cache_first_state_); 137 if (cache_gc_) { 138 cache_states_.push_back(cache_first_state_id_); 139 cache_size_ += sizeof(S) + 140 cache_first_state_->arcs.capacity() * sizeof(Arc); 141 cache_limit_ = kMinCacheLimit; 142 } 143 cache_first_state_id_ = kNoStateId; 144 cache_first_state_ = 0; 145 } 146 if (cache_gc_) { 147 cache_states_.push_back(s); 148 cache_size_ += sizeof(S); 149 if (cache_size_ > cache_limit_) 150 GC(s, false); 151 } 152 } 153 return VectorFstBaseImpl<S>::GetState(s); 154 } 155 } 156 SetStart(StateId s)157 void SetStart(StateId s) { 158 VectorFstBaseImpl<S>::SetStart(s); 159 cache_start_ = true; 160 if (s >= nknown_states_) 161 nknown_states_ = s + 1; 162 } 163 SetFinal(StateId s,Weight w)164 void SetFinal(StateId s, Weight w) { 165 S *state = ExtendState(s); 166 state->final = w; 167 state->flags |= kCacheFinal | kCacheRecent; 168 } 169 AddArc(StateId s,const Arc & arc)170 void AddArc(StateId s, const Arc &arc) { 171 S *state = ExtendState(s); 172 state->arcs.push_back(arc); 173 } 174 175 // Marks arcs of state s as cached. SetArcs(StateId s)176 void SetArcs(StateId s) { 177 S *state = ExtendState(s); 178 vector<Arc> &arcs = state->arcs; 179 state->niepsilons = state->noepsilons = 0; 180 for (unsigned int a = 0; a < arcs.size(); ++a) { 181 const Arc &arc = arcs[a]; 182 if (arc.nextstate >= nknown_states_) 183 nknown_states_ = arc.nextstate + 1; 184 if (arc.ilabel == 0) 185 ++state->niepsilons; 186 if (arc.olabel == 0) 187 ++state->noepsilons; 188 } 189 ExpandedState(s); 190 state->flags |= kCacheArcs | kCacheRecent; 191 if (cache_gc_ && s != cache_first_state_id_) { 192 cache_size_ += arcs.capacity() * sizeof(Arc); 193 if (cache_size_ > cache_limit_) 194 GC(s, false); 195 } 196 }; 197 ReserveArcs(StateId s,size_t n)198 void ReserveArcs(StateId s, size_t n) { 199 S *state = ExtendState(s); 200 state->arcs.reserve(n); 201 } 202 203 // Is the start state cached? HasStart()204 bool HasStart() const { return cache_start_; } 205 // Is the final weight of state s cached? 206 HasFinal(StateId s)207 bool HasFinal(StateId s) const { 208 const S *state = CheckState(s); 209 if (state && state->flags & kCacheFinal) { 210 state->flags |= kCacheRecent; 211 return true; 212 } else { 213 return false; 214 } 215 } 216 217 // Are arcs of state s cached? HasArcs(StateId s)218 bool HasArcs(StateId s) const { 219 const S *state = CheckState(s); 220 if (state && state->flags & kCacheArcs) { 221 state->flags |= kCacheRecent; 222 return true; 223 } else { 224 return false; 225 } 226 } 227 Final(StateId s)228 Weight Final(StateId s) const { 229 const S *state = GetState(s); 230 return state->final; 231 } 232 NumArcs(StateId s)233 size_t NumArcs(StateId s) const { 234 const S *state = GetState(s); 235 return state->arcs.size(); 236 } 237 NumInputEpsilons(StateId s)238 size_t NumInputEpsilons(StateId s) const { 239 const S *state = GetState(s); 240 return state->niepsilons; 241 } 242 NumOutputEpsilons(StateId s)243 size_t NumOutputEpsilons(StateId s) const { 244 const S *state = GetState(s); 245 return state->noepsilons; 246 } 247 248 // Provides information needed for generic arc iterator. InitArcIterator(StateId s,ArcIteratorData<Arc> * data)249 void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const { 250 const S *state = GetState(s); 251 data->base = 0; 252 data->narcs = state->arcs.size(); 253 data->arcs = data->narcs > 0 ? &(state->arcs[0]) : 0; 254 data->ref_count = &(state->ref_count); 255 ++(*data->ref_count); 256 } 257 258 // Number of known states. NumKnownStates()259 StateId NumKnownStates() const { return nknown_states_; } 260 // Find the mininum never-expanded state Id MinUnexpandedState()261 StateId MinUnexpandedState() const { 262 while (min_unexpanded_state_id_ < (StateId)expanded_states_.size() && 263 expanded_states_[min_unexpanded_state_id_]) 264 ++min_unexpanded_state_id_; 265 return min_unexpanded_state_id_; 266 } 267 268 // Removes from cache_states_ and uncaches (not referenced-counted) 269 // states that have not been accessed since the last GC until 270 // cache_limit_/3 bytes are uncached. If that fails to free enough, 271 // recurs uncaching recently visited states as well. If still 272 // unable to free enough memory, then widens cache_limit_. GC(StateId current,bool free_recent)273 void GC(StateId current, bool free_recent) { 274 if (!cache_gc_) 275 return; 276 VLOG(2) << "CacheImpl: Enter GC: object = " << Type() << "(" << this 277 << "), free recently cached = " << free_recent 278 << ", cache size = " << cache_size_ 279 << ", cache limit = " << cache_limit_ << "\n"; 280 typename list<StateId>::iterator siter = cache_states_.begin(); 281 282 size_t cache_target = (2 * cache_limit_)/3 + 1; 283 while (siter != cache_states_.end()) { 284 StateId s = *siter; 285 S* state = VectorFstBaseImpl<S>::GetState(s); 286 if (cache_size_ > cache_target && state->ref_count == 0 && 287 (free_recent || !(state->flags & kCacheRecent)) && s != current) { 288 cache_size_ -= sizeof(S) + state->arcs.capacity() * sizeof(Arc); 289 delete state; 290 SetState(s, 0); 291 cache_states_.erase(siter++); 292 } else { 293 state->flags &= ~kCacheRecent; 294 ++siter; 295 } 296 } 297 if (!free_recent && cache_size_ > cache_target) { 298 GC(current, true); 299 } else { 300 while (cache_size_ > cache_target) { 301 cache_limit_ *= 2; 302 cache_target *= 2; 303 } 304 } 305 VLOG(2) << "CacheImpl: Exit GC: object = " << Type() << "(" << this 306 << "), free recently cached = " << free_recent 307 << ", cache size = " << cache_size_ 308 << ", cache limit = " << cache_limit_ << "\n"; 309 } 310 311 private: 312 static const uint32 kCacheFinal = 0x0001; // Final weight has been cached 313 static const uint32 kCacheArcs = 0x0002; // Arcs have been cached 314 static const uint32 kCacheRecent = 0x0004; // Mark as visited since GC 315 316 static const size_t kMinCacheLimit; // Minimum (non-zero) cache limit 317 ExpandedState(StateId s)318 void ExpandedState(StateId s) { 319 if (s < min_unexpanded_state_id_) 320 return; 321 while ((StateId)expanded_states_.size() <= s) 322 expanded_states_.push_back(false); 323 expanded_states_[s] = true; 324 } 325 326 bool cache_start_; // Is the start state cached? 327 StateId nknown_states_; // # of known states 328 vector<bool> expanded_states_; // states that have been expanded 329 mutable StateId min_unexpanded_state_id_; // minimum never-expanded state Id 330 StateId cache_first_state_id_; // First cached state id 331 S *cache_first_state_; // First cached state 332 list<StateId> cache_states_; // list of currently cached states 333 bool cache_gc_; // enable GC 334 size_t cache_size_; // # of bytes cached 335 size_t cache_limit_; // # of bytes allowed before GC 336 337 void InitStateIterator(StateIteratorData<Arc> *); // disallow 338 DISALLOW_EVIL_CONSTRUCTORS(CacheBaseImpl); 339 }; 340 341 template <class S> 342 const size_t CacheBaseImpl<S>::kMinCacheLimit = 8096; 343 344 345 // Arcs implemented by an STL vector per state. Similar to VectorState 346 // but adds flags and ref count to keep track of what has been cached. 347 template <class A> 348 struct CacheState { 349 typedef A Arc; 350 typedef typename A::Weight Weight; 351 typedef typename A::StateId StateId; 352 CacheStateCacheState353 CacheState() : final(Weight::Zero()), flags(0), ref_count(0) {} 354 ResetCacheState355 void Reset() { 356 flags = 0; 357 ref_count = 0; 358 arcs.resize(0); 359 } 360 361 Weight final; // Final weight 362 vector<A> arcs; // Arcs represenation 363 size_t niepsilons; // # of input epsilons 364 size_t noepsilons; // # of output epsilons 365 mutable uint32 flags; 366 mutable int ref_count; 367 }; 368 369 // A CacheBaseImpl with a commonly used CacheState. 370 template <class A> 371 class CacheImpl : public CacheBaseImpl< CacheState<A> > { 372 public: 373 typedef CacheState<A> State; 374 CacheImpl()375 CacheImpl() {} 376 CacheImpl(const CacheOptions & opts)377 explicit CacheImpl(const CacheOptions &opts) 378 : CacheBaseImpl< CacheState<A> >(opts) {} 379 380 private: 381 DISALLOW_EVIL_CONSTRUCTORS(CacheImpl); 382 }; 383 384 385 // Use this to make a state iterator for a CacheBaseImpl-derived Fst. 386 // You'll need to make this class a friend of your derived Fst. 387 // Note this iterator only returns those states reachable from 388 // the initial state, so consider implementing a class-specific one. 389 template <class F> 390 class CacheStateIterator : public StateIteratorBase<typename F::Arc> { 391 public: 392 typedef typename F::Arc Arc; 393 typedef typename Arc::StateId StateId; 394 CacheStateIterator(const F & fst)395 explicit CacheStateIterator(const F &fst) : fst_(fst), s_(0) {} 396 Done()397 virtual bool Done() const { 398 if (s_ < fst_.impl_->NumKnownStates()) 399 return false; 400 fst_.Start(); // force start state 401 if (s_ < fst_.impl_->NumKnownStates()) 402 return false; 403 for (int u = fst_.impl_->MinUnexpandedState(); 404 u < fst_.impl_->NumKnownStates(); 405 u = fst_.impl_->MinUnexpandedState()) { 406 ArcIterator<F>(fst_, u); // force state expansion 407 if (s_ < fst_.impl_->NumKnownStates()) 408 return false; 409 } 410 return true; 411 } 412 Value()413 virtual StateId Value() const { return s_; } 414 Next()415 virtual void Next() { ++s_; } 416 Reset()417 virtual void Reset() { s_ = 0; } 418 419 private: 420 const F &fst_; 421 StateId s_; 422 }; 423 424 425 // Use this to make an arc iterator for a CacheBaseImpl-derived Fst. 426 // You'll need to make this class a friend of your derived Fst and 427 // define types Arc and State. 428 template <class F> 429 class CacheArcIterator { 430 public: 431 typedef typename F::Arc Arc; 432 typedef typename F::State State; 433 typedef typename Arc::StateId StateId; 434 CacheArcIterator(const F & fst,StateId s)435 CacheArcIterator(const F &fst, StateId s) : i_(0) { 436 state_ = fst.impl_->ExtendState(s); 437 ++state_->ref_count; 438 } 439 ~CacheArcIterator()440 ~CacheArcIterator() { --state_->ref_count; } 441 Done()442 bool Done() const { return i_ >= state_->arcs.size(); } 443 Value()444 const Arc& Value() const { return state_->arcs[i_]; } 445 Next()446 void Next() { ++i_; } 447 Reset()448 void Reset() { i_ = 0; } 449 Seek(size_t a)450 void Seek(size_t a) { i_ = a; } 451 452 private: 453 const State *state_; 454 size_t i_; 455 456 DISALLOW_EVIL_CONSTRUCTORS(CacheArcIterator); 457 }; 458 459 } // namespace fst 460 461 #endif // FST_LIB_CACHE_H__ 462