• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // cache.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // An Fst implementation that caches FST elements of a delayed
20 // computation.
21 
22 #ifndef FST_LIB_CACHE_H__
23 #define FST_LIB_CACHE_H__
24 
25 #include <vector>
26 using std::vector;
27 #include <list>
28 
29 #include <fst/vector-fst.h>
30 
31 
32 DECLARE_bool(fst_default_cache_gc);
33 DECLARE_int64(fst_default_cache_gc_limit);
34 
35 namespace fst {
36 
37 struct CacheOptions {
38   bool gc;          // enable GC
39   size_t gc_limit;  // # of bytes allowed before GC
40 
CacheOptionsCacheOptions41   CacheOptions(bool g, size_t l) : gc(g), gc_limit(l) {}
CacheOptionsCacheOptions42   CacheOptions()
43       : gc(FLAGS_fst_default_cache_gc),
44         gc_limit(FLAGS_fst_default_cache_gc_limit) {}
45 };
46 
// A CacheStateAllocator allocates and frees CacheStates. It must provide:
//
// template <class S>
// struct CacheStateAllocator {
//   S *Allocate(StateId s);
//   void Free(S *state, StateId s);
// };

// A simple allocator class that can be overridden as needed. It maintains a
// single-entry cache: the most recently freed state is kept and handed back
// (after S::Reset()) by the next Allocate() call, which avoids a delete/new
// pair when a state is freed and another one allocated right away.
template <class S>
struct DefaultCacheStateAllocator {
  typedef typename S::Arc::StateId StateId;

  DefaultCacheStateAllocator() : mru_(NULL) {}

  ~DefaultCacheStateAllocator() { delete mru_; }

  // Returns a state ready for use: the cached one (reset) if available,
  // otherwise a newly allocated S. 's' is not used by this allocator.
  S *Allocate(StateId s) {
    if (mru_) {
      S *state = mru_;
      mru_ = NULL;
      state->Reset();
      return state;
    }
    return new S();
  }

  // Takes ownership of 'state' (which may be NULL). It becomes the cached
  // entry and any previously cached state is deleted. Freeing the state that
  // is already cached is a no-op; previously it deleted the state while
  // keeping the dangling pointer, which led to a double delete.
  void Free(S *state, StateId s) {
    if (state == mru_)
      return;
    delete mru_;  // deleting NULL is a no-op
    mru_ = state;
  }

 private:
  // Copying would leave two allocators owning (and deleting) the same cached
  // state, so copy construction and assignment are disallowed.
  DefaultCacheStateAllocator(const DefaultCacheStateAllocator &);
  void operator=(const DefaultCacheStateAllocator &);

  S *mru_;  // Most recently freed state, or NULL.
};
87 
88 // VectorState but additionally has a flags data member (see
89 // CacheState below). This class is used to cache FST elements with
90 // the flags used to indicate what has been cached. Use HasStart()
91 // HasFinal(), and HasArcs() to determine if cached and SetStart(),
92 // SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache. Note you
93 // must set the final weight even if the state is non-final to mark it as
94 // cached. If the 'gc' option is 'false', cached items have the extent
95 // of the FST - minimizing computation. If the 'gc' option is 'true',
96 // garbage collection of states (not in use in an arc iterator) is
97 // performed, in a rough approximation of LRU order, when 'gc_limit'
98 // bytes is reached - controlling memory use. When 'gc_limit' is 0,
99 // special optimizations apply - minimizing memory use.
100 
101 template <class S, class C = DefaultCacheStateAllocator<S> >
102 class CacheBaseImpl : public VectorFstBaseImpl<S> {
103  public:
104   typedef S State;
105   typedef C Allocator;
106   typedef typename State::Arc Arc;
107   typedef typename Arc::Weight Weight;
108   typedef typename Arc::StateId StateId;
109 
110   using FstImpl<Arc>::Type;
111   using FstImpl<Arc>::Properties;
112   using FstImpl<Arc>::SetProperties;
113   using VectorFstBaseImpl<State>::NumStates;
114   using VectorFstBaseImpl<State>::AddState;
115   using VectorFstBaseImpl<State>::SetState;
116 
117   explicit CacheBaseImpl(C *allocator = 0)
cache_start_(false)118       : cache_start_(false), nknown_states_(0), min_unexpanded_state_id_(0),
119         cache_first_state_id_(kNoStateId), cache_first_state_(0),
120         cache_gc_(FLAGS_fst_default_cache_gc),  cache_size_(0),
121         cache_limit_(FLAGS_fst_default_cache_gc_limit > kMinCacheLimit ||
122                      FLAGS_fst_default_cache_gc_limit == 0 ?
123                      FLAGS_fst_default_cache_gc_limit : kMinCacheLimit) {
124           allocator_ = allocator ? allocator : new C();
125         }
126 
127   explicit CacheBaseImpl(const CacheOptions &opts, C *allocator = 0)
cache_start_(false)128       : cache_start_(false), nknown_states_(0),
129         min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
130         cache_first_state_(0), cache_gc_(opts.gc), cache_size_(0),
131         cache_limit_(opts.gc_limit > kMinCacheLimit || opts.gc_limit == 0 ?
132                      opts.gc_limit : kMinCacheLimit) {
133           allocator_ = allocator ? allocator : new C();
134         }
135 
136   // Preserve gc parameters, but initially cache nothing.
CacheBaseImpl(const CacheBaseImpl & impl)137   CacheBaseImpl(const CacheBaseImpl &impl)
138     : cache_start_(false), nknown_states_(0),
139       min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
140       cache_first_state_(0), cache_gc_(impl.cache_gc_), cache_size_(0),
141       cache_limit_(impl.cache_limit_) {
142         allocator_ = new C();
143       }
144 
~CacheBaseImpl()145   ~CacheBaseImpl() {
146     allocator_->Free(cache_first_state_, cache_first_state_id_);
147     delete allocator_;
148   }
149 
150   // Gets a state from its ID; state must exist.
GetState(StateId s)151   const S *GetState(StateId s) const {
152     if (s == cache_first_state_id_)
153       return cache_first_state_;
154     else
155       return VectorFstBaseImpl<S>::GetState(s);
156   }
157 
158   // Gets a state from its ID; state must exist.
GetState(StateId s)159   S *GetState(StateId s) {
160     if (s == cache_first_state_id_)
161       return cache_first_state_;
162     else
163       return VectorFstBaseImpl<S>::GetState(s);
164   }
165 
166   // Gets a state from its ID; return 0 if it doesn't exist.
CheckState(StateId s)167   const S *CheckState(StateId s) const {
168     if (s == cache_first_state_id_)
169       return cache_first_state_;
170     else if (s < NumStates())
171       return VectorFstBaseImpl<S>::GetState(s);
172     else
173       return 0;
174   }
175 
176   // Gets a state from its ID; add it if necessary.
ExtendState(StateId s)177   S *ExtendState(StateId s) {
178     if (s == cache_first_state_id_) {
179       return cache_first_state_;                   // Return 1st cached state
180     } else if (cache_limit_ == 0 && cache_first_state_id_ == kNoStateId) {
181       cache_first_state_id_ = s;                   // Remember 1st cached state
182       cache_first_state_ = allocator_->Allocate(s);
183       return cache_first_state_;
184     } else if (cache_first_state_id_ != kNoStateId &&
185                cache_first_state_->ref_count == 0) {
186       // With Default allocator, the Free and Allocate will reuse the same S*.
187       allocator_->Free(cache_first_state_, cache_first_state_id_);
188       cache_first_state_id_ = s;
189       cache_first_state_ = allocator_->Allocate(s);
190       return cache_first_state_;                   // Return 1st cached state
191     } else {
192       while (NumStates() <= s)                     // Add state to main cache
193         AddState(0);
194       if (!VectorFstBaseImpl<S>::GetState(s)) {
195         SetState(s, allocator_->Allocate(s));
196         if (cache_first_state_id_ != kNoStateId) {  // Forget 1st cached state
197           while (NumStates() <= cache_first_state_id_)
198             AddState(0);
199           SetState(cache_first_state_id_, cache_first_state_);
200           if (cache_gc_) {
201             cache_states_.push_back(cache_first_state_id_);
202             cache_size_ += sizeof(S) +
203                            cache_first_state_->arcs.capacity() * sizeof(Arc);
204           }
205           cache_limit_ = kMinCacheLimit;
206           cache_first_state_id_ = kNoStateId;
207           cache_first_state_ = 0;
208         }
209         if (cache_gc_) {
210           cache_states_.push_back(s);
211           cache_size_ += sizeof(S);
212           if (cache_size_ > cache_limit_)
213             GC(s, false);
214         }
215       }
216       S *state = VectorFstBaseImpl<S>::GetState(s);
217       return state;
218     }
219   }
220 
SetStart(StateId s)221   void SetStart(StateId s) {
222     VectorFstBaseImpl<S>::SetStart(s);
223     cache_start_ = true;
224     if (s >= nknown_states_)
225       nknown_states_ = s + 1;
226   }
227 
SetFinal(StateId s,Weight w)228   void SetFinal(StateId s, Weight w) {
229     S *state = ExtendState(s);
230     state->final = w;
231     state->flags |= kCacheFinal | kCacheRecent | kCacheModified;
232   }
233 
234   // AddArc adds a single arc to state s and does incremental cache
235   // book-keeping.  For efficiency, prefer PushArc and SetArcs below
236   // when possible.
AddArc(StateId s,const Arc & arc)237   void AddArc(StateId s, const Arc &arc) {
238     S *state = ExtendState(s);
239     state->arcs.push_back(arc);
240     if (arc.ilabel == 0) {
241       ++state->niepsilons;
242     }
243     if (arc.olabel == 0) {
244       ++state->noepsilons;
245     }
246     const Arc *parc = state->arcs.empty() ? 0 : &(state->arcs.back());
247     SetProperties(AddArcProperties(Properties(), s, arc, parc));
248     state->flags |= kCacheModified;
249     if (cache_gc_ && s != cache_first_state_id_) {
250       cache_size_ += sizeof(Arc);
251       if (cache_size_ > cache_limit_)
252         GC(s, false);
253     }
254   }
255 
256   // Adds a single arc to state s but delays cache book-keeping.
257   // SetArcs must be called when all PushArc calls at a state are
258   // complete.  Do not mix with calls to AddArc.
PushArc(StateId s,const Arc & arc)259   void PushArc(StateId s, const Arc &arc) {
260     S *state = ExtendState(s);
261     state->arcs.push_back(arc);
262   }
263 
264   // Marks arcs of state s as cached and does cache book-keeping after all
265   // calls to PushArc have been completed.  Do not mix with calls to AddArc.
SetArcs(StateId s)266   void SetArcs(StateId s) {
267     S *state = ExtendState(s);
268     vector<Arc> &arcs = state->arcs;
269     state->niepsilons = state->noepsilons = 0;
270     for (size_t a = 0; a < arcs.size(); ++a) {
271       const Arc &arc = arcs[a];
272       if (arc.nextstate >= nknown_states_)
273         nknown_states_ = arc.nextstate + 1;
274       if (arc.ilabel == 0)
275         ++state->niepsilons;
276       if (arc.olabel == 0)
277         ++state->noepsilons;
278     }
279     ExpandedState(s);
280     state->flags |= kCacheArcs | kCacheRecent | kCacheModified;
281     if (cache_gc_ && s != cache_first_state_id_) {
282       cache_size_ += arcs.capacity() * sizeof(Arc);
283       if (cache_size_ > cache_limit_)
284         GC(s, false);
285     }
286   };
287 
ReserveArcs(StateId s,size_t n)288   void ReserveArcs(StateId s, size_t n) {
289     S *state = ExtendState(s);
290     state->arcs.reserve(n);
291   }
292 
DeleteArcs(StateId s,size_t n)293   void DeleteArcs(StateId s, size_t n) {
294     S *state = ExtendState(s);
295     const vector<Arc> &arcs = GetState(s)->arcs;
296     for (size_t i = 0; i < n; ++i) {
297       size_t j = arcs.size() - i - 1;
298       if (arcs[j].ilabel == 0)
299         --GetState(s)->niepsilons;
300       if (arcs[j].olabel == 0)
301         --GetState(s)->noepsilons;
302     }
303     state->arcs.resize(arcs.size() - n);
304     SetProperties(DeleteArcsProperties(Properties()));
305     state->flags |= kCacheModified;
306   }
307 
DeleteArcs(StateId s)308   void DeleteArcs(StateId s) {
309     S *state = ExtendState(s);
310     state->niepsilons = 0;
311     state->noepsilons = 0;
312     state->arcs.clear();
313     SetProperties(DeleteArcsProperties(Properties()));
314     state->flags |= kCacheModified;
315   }
316 
317   // Is the start state cached?
HasStart()318   bool HasStart() const {
319     if (!cache_start_ && Properties(kError))
320       cache_start_ = true;
321     return cache_start_;
322   }
323 
324   // Is the final weight of state s cached?
HasFinal(StateId s)325   bool HasFinal(StateId s) const {
326     const S *state = CheckState(s);
327     if (state && state->flags & kCacheFinal) {
328       state->flags |= kCacheRecent;
329       return true;
330     } else {
331       return false;
332     }
333   }
334 
335   // Are arcs of state s cached?
HasArcs(StateId s)336   bool HasArcs(StateId s) const {
337     const S *state = CheckState(s);
338     if (state && state->flags & kCacheArcs) {
339       state->flags |= kCacheRecent;
340       return true;
341     } else {
342       return false;
343     }
344   }
345 
Final(StateId s)346   Weight Final(StateId s) const {
347     const S *state = GetState(s);
348     return state->final;
349   }
350 
NumArcs(StateId s)351   size_t NumArcs(StateId s) const {
352     const S *state = GetState(s);
353     return state->arcs.size();
354   }
355 
NumInputEpsilons(StateId s)356   size_t NumInputEpsilons(StateId s) const {
357     const S *state = GetState(s);
358     return state->niepsilons;
359   }
360 
NumOutputEpsilons(StateId s)361   size_t NumOutputEpsilons(StateId s) const {
362     const S *state = GetState(s);
363     return state->noepsilons;
364   }
365 
366   // Provides information needed for generic arc iterator.
InitArcIterator(StateId s,ArcIteratorData<Arc> * data)367   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
368     const S *state = GetState(s);
369     data->base = 0;
370     data->narcs = state->arcs.size();
371     data->arcs = data->narcs > 0 ? &(state->arcs[0]) : 0;
372     data->ref_count = &(state->ref_count);
373     ++(*data->ref_count);
374   }
375 
376   // Number of known states.
NumKnownStates()377   StateId NumKnownStates() const { return nknown_states_; }
378 
379   // Update number of known states taking in account the existence of state s.
UpdateNumKnownStates(StateId s)380   void UpdateNumKnownStates(StateId s) {
381     if (s >= nknown_states_)
382       nknown_states_ = s + 1;
383   }
384 
385   // Find the mininum never-expanded state Id
MinUnexpandedState()386   StateId MinUnexpandedState() const {
387     while (min_unexpanded_state_id_ < expanded_states_.size() &&
388           expanded_states_[min_unexpanded_state_id_])
389       ++min_unexpanded_state_id_;
390     return min_unexpanded_state_id_;
391   }
392 
393   // Removes from cache_states_ and uncaches (not referenced-counted)
394   // states that have not been accessed since the last GC until
395   // cache_limit_/3 bytes are uncached.  If that fails to free enough,
396   // recurs uncaching recently visited states as well. If still
397   // unable to free enough memory, then widens cache_limit_.
GC(StateId current,bool free_recent)398   void GC(StateId current, bool free_recent) {
399     if (!cache_gc_)
400       return;
401     VLOG(2) << "CacheImpl: Enter GC: object = " << Type() << "(" << this
402             << "), free recently cached = " << free_recent
403             << ", cache size = " << cache_size_
404             << ", cache limit = " << cache_limit_ << "\n";
405     typename list<StateId>::iterator siter = cache_states_.begin();
406 
407     size_t cache_target = (2 * cache_limit_)/3 + 1;
408     while (siter != cache_states_.end()) {
409       StateId s = *siter;
410       S* state = VectorFstBaseImpl<S>::GetState(s);
411       if (cache_size_ > cache_target && state->ref_count == 0 &&
412           (free_recent || !(state->flags & kCacheRecent)) && s != current) {
413         cache_size_ -= sizeof(S) + state->arcs.capacity() * sizeof(Arc);
414         allocator_->Free(state, s);
415         SetState(s, 0);
416         cache_states_.erase(siter++);
417       } else {
418         state->flags &= ~kCacheRecent;
419         ++siter;
420       }
421     }
422     if (!free_recent && cache_size_ > cache_target) {
423       GC(current, true);
424     } else {
425       while (cache_size_ > cache_target) {
426         cache_limit_ *= 2;
427         cache_target *= 2;
428       }
429     }
430     VLOG(2) << "CacheImpl: Exit GC: object = " << Type() << "(" << this
431             << "), free recently cached = " << free_recent
432             << ", cache size = " << cache_size_
433             << ", cache limit = " << cache_limit_ << "\n";
434   }
435 
ExpandedState(StateId s)436   void ExpandedState(StateId s) {
437     if (s < min_unexpanded_state_id_)
438       return;
439     while (expanded_states_.size() <= s)
440       expanded_states_.push_back(false);
441     expanded_states_[s] = true;
442   }
443 
444   // Caching on/off switch, limit and size accessors.
GetCacheGc()445   bool GetCacheGc() const { return cache_gc_; }
GetCacheLimit()446   size_t GetCacheLimit() const { return cache_limit_; }
GetCacheSize()447   size_t GetCacheSize() const { return cache_size_; }
448 
449  private:
450   static const size_t kMinCacheLimit = 8096;  // Minimum (non-zero) cache limit
451   static const uint32 kCacheFinal =  0x0001;  // Final weight has been cached
452   static const uint32 kCacheArcs =   0x0002;  // Arcs have been cached
453   static const uint32 kCacheRecent = 0x0004;  // Mark as visited since GC
454 
455  public:
456   static const uint32 kCacheModified = 0x0008;  // Mark state as modified
457   static const uint32 kCacheFlags = kCacheFinal | kCacheArcs | kCacheRecent
458                                     | kCacheModified;
459 
460  protected:
461   C *allocator_;                             // used to allocate new states
462 
463  private:
464   mutable bool cache_start_;                 // Is the start state cached?
465   StateId nknown_states_;                    // # of known states
466   vector<bool> expanded_states_;             // states that have been expanded
467   mutable StateId min_unexpanded_state_id_;  // minimum never-expanded state Id
468   StateId cache_first_state_id_;             // First cached state id
469   S *cache_first_state_;                     // First cached state
470   list<StateId> cache_states_;               // list of currently cached states
471   bool cache_gc_;                            // enable GC
472   size_t cache_size_;                        // # of bytes cached
473   size_t cache_limit_;                       // # of bytes allowed before GC
474 
475   void operator=(const CacheBaseImpl<S> &impl);    // disallow
476 };
477 
478 template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheFinal;
479 template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheArcs;
480 template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheRecent;
481 template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheModified;
482 template <class S, class C> const size_t CacheBaseImpl<S, C>::kMinCacheLimit;
483 
484 // Arcs implemented by an STL vector per state. Similar to VectorState
485 // but adds flags and ref count to keep track of what has been cached.
486 template <class A>
487 struct CacheState {
488   typedef A Arc;
489   typedef typename A::Weight Weight;
490   typedef typename A::StateId StateId;
491 
CacheStateCacheState492   CacheState() :  final(Weight::Zero()), flags(0), ref_count(0) {}
493 
ResetCacheState494   void Reset() {
495     flags = 0;
496     ref_count = 0;
497     arcs.resize(0);
498   }
499 
500   Weight final;              // Final weight
501   vector<A> arcs;            // Arcs represenation
502   size_t niepsilons;         // # of input epsilons
503   size_t noepsilons;         // # of output epsilons
504   mutable uint32 flags;
505   mutable int ref_count;
506 
507  private:
508   DISALLOW_COPY_AND_ASSIGN(CacheState);
509 };
510 
511 // A CacheBaseImpl with a commonly used CacheState.
512 template <class A>
513 class CacheImpl : public CacheBaseImpl< CacheState<A> > {
514  public:
515   typedef CacheState<A> State;
516 
CacheImpl()517   CacheImpl() {}
518 
CacheImpl(const CacheOptions & opts)519   explicit CacheImpl(const CacheOptions &opts)
520       : CacheBaseImpl< CacheState<A> >(opts) {}
521 
CacheImpl(const CacheImpl<State> & impl)522   CacheImpl(const CacheImpl<State> &impl) : CacheBaseImpl<State>(impl) {}
523 
524  private:
525   void operator=(const CacheImpl<State> &impl);    // disallow
526 };
527 
528 
529 // Use this to make a state iterator for a CacheBaseImpl-derived Fst,
530 // which must have type 'State' defined.  Note this iterator only
531 // returns those states reachable from the initial state, so consider
532 // implementing a class-specific one.
533 template <class F>
534 class CacheStateIterator : public StateIteratorBase<typename F::Arc> {
535  public:
536   typedef typename F::Arc Arc;
537   typedef typename Arc::StateId StateId;
538   typedef typename F::State State;
539   typedef CacheBaseImpl<State> Impl;
540 
CacheStateIterator(const F & fst,Impl * impl)541   CacheStateIterator(const F &fst, Impl *impl)
542       : fst_(fst), impl_(impl), s_(0) {}
543 
Done()544   bool Done() const {
545     if (s_ < impl_->NumKnownStates())
546       return false;
547     fst_.Start();  // force start state
548     if (s_ < impl_->NumKnownStates())
549       return false;
550     for (StateId u = impl_->MinUnexpandedState();
551          u < impl_->NumKnownStates();
552          u = impl_->MinUnexpandedState()) {
553       // force state expansion
554       ArcIterator<F> aiter(fst_, u);
555       aiter.SetFlags(kArcValueFlags, kArcValueFlags | kArcNoCache);
556       for (; !aiter.Done(); aiter.Next())
557         impl_->UpdateNumKnownStates(aiter.Value().nextstate);
558       impl_->ExpandedState(u);
559       if (s_ < impl_->NumKnownStates())
560         return false;
561     }
562     return true;
563   }
564 
Value()565   StateId Value() const { return s_; }
566 
Next()567   void Next() { ++s_; }
568 
Reset()569   void Reset() { s_ = 0; }
570 
571  private:
572   // This allows base class virtual access to non-virtual derived-
573   // class members of the same name. It makes the derived class more
574   // efficient to use but unsafe to further derive.
Done_()575   virtual bool Done_() const { return Done(); }
Value_()576   virtual StateId Value_() const { return Value(); }
Next_()577   virtual void Next_() { Next(); }
Reset_()578   virtual void Reset_() { Reset(); }
579 
580   const F &fst_;
581   Impl *impl_;
582   StateId s_;
583 };
584 
585 
586 // Use this to make an arc iterator for a CacheBaseImpl-derived Fst,
587 // which must have types 'Arc' and 'State' defined.
588 template <class F,
589           class C = DefaultCacheStateAllocator<CacheState<typename F::Arc> > >
590 class CacheArcIterator {
591  public:
592   typedef typename F::Arc Arc;
593   typedef typename F::State State;
594   typedef typename Arc::StateId StateId;
595   typedef CacheBaseImpl<State, C> Impl;
596 
CacheArcIterator(Impl * impl,StateId s)597   CacheArcIterator(Impl *impl, StateId s) : i_(0) {
598     state_ = impl->ExtendState(s);
599     ++state_->ref_count;
600   }
601 
~CacheArcIterator()602   ~CacheArcIterator() { --state_->ref_count;  }
603 
Done()604   bool Done() const { return i_ >= state_->arcs.size(); }
605 
Value()606   const Arc& Value() const { return state_->arcs[i_]; }
607 
Next()608   void Next() { ++i_; }
609 
Position()610   size_t Position() const { return i_; }
611 
Reset()612   void Reset() { i_ = 0; }
613 
Seek(size_t a)614   void Seek(size_t a) { i_ = a; }
615 
Flags()616   uint32 Flags() const {
617     return kArcValueFlags;
618   }
619 
SetFlags(uint32 flags,uint32 mask)620   void SetFlags(uint32 flags, uint32 mask) {}
621 
622  private:
623   const State *state_;
624   size_t i_;
625 
626   DISALLOW_COPY_AND_ASSIGN(CacheArcIterator);
627 };
628 
629 // Use this to make a mutable arc iterator for a CacheBaseImpl-derived Fst,
630 // which must have types 'Arc' and 'State' defined.
631 template <class F,
632           class C = DefaultCacheStateAllocator<CacheState<typename F::Arc> > >
633 class CacheMutableArcIterator
634     : public MutableArcIteratorBase<typename F::Arc> {
635  public:
636   typedef typename F::State State;
637   typedef typename F::Arc Arc;
638   typedef typename Arc::StateId StateId;
639   typedef typename Arc::Weight Weight;
640   typedef CacheBaseImpl<State, C> Impl;
641 
642   // You will need to call MutateCheck() in the constructor.
CacheMutableArcIterator(Impl * impl,StateId s)643   CacheMutableArcIterator(Impl *impl, StateId s) : i_(0), s_(s), impl_(impl) {
644     state_ = impl_->ExtendState(s_);
645     ++state_->ref_count;
646   };
647 
~CacheMutableArcIterator()648   ~CacheMutableArcIterator() {
649     --state_->ref_count;
650   }
651 
Done()652   bool Done() const { return i_ >= state_->arcs.size(); }
653 
Value()654   const Arc& Value() const { return state_->arcs[i_]; }
655 
Next()656   void Next() { ++i_; }
657 
Position()658   size_t Position() const { return i_; }
659 
Reset()660   void Reset() { i_ = 0; }
661 
Seek(size_t a)662   void Seek(size_t a) { i_ = a; }
663 
SetValue(const Arc & arc)664   void SetValue(const Arc& arc) {
665     state_->flags |= CacheBaseImpl<State, C>::kCacheModified;
666     uint64 properties = impl_->Properties();
667     Arc& oarc = state_->arcs[i_];
668     if (oarc.ilabel != oarc.olabel)
669       properties &= ~kNotAcceptor;
670     if (oarc.ilabel == 0) {
671       --state_->niepsilons;
672       properties &= ~kIEpsilons;
673       if (oarc.olabel == 0)
674         properties &= ~kEpsilons;
675     }
676     if (oarc.olabel == 0) {
677       --state_->noepsilons;
678       properties &= ~kOEpsilons;
679     }
680     if (oarc.weight != Weight::Zero() && oarc.weight != Weight::One())
681       properties &= ~kWeighted;
682     oarc = arc;
683     if (arc.ilabel != arc.olabel) {
684       properties |= kNotAcceptor;
685       properties &= ~kAcceptor;
686     }
687     if (arc.ilabel == 0) {
688       ++state_->niepsilons;
689       properties |= kIEpsilons;
690       properties &= ~kNoIEpsilons;
691       if (arc.olabel == 0) {
692         properties |= kEpsilons;
693         properties &= ~kNoEpsilons;
694       }
695     }
696     if (arc.olabel == 0) {
697       ++state_->noepsilons;
698       properties |= kOEpsilons;
699       properties &= ~kNoOEpsilons;
700     }
701     if (arc.weight != Weight::Zero() && arc.weight != Weight::One()) {
702       properties |= kWeighted;
703       properties &= ~kUnweighted;
704     }
705     properties &= kSetArcProperties | kAcceptor | kNotAcceptor |
706         kEpsilons | kNoEpsilons | kIEpsilons | kNoIEpsilons |
707         kOEpsilons | kNoOEpsilons | kWeighted | kUnweighted;
708     impl_->SetProperties(properties);
709   }
710 
Flags()711   uint32 Flags() const {
712     return kArcValueFlags;
713   }
714 
SetFlags(uint32 f,uint32 m)715   void SetFlags(uint32 f, uint32 m) {}
716 
717  private:
Done_()718   virtual bool Done_() const { return Done(); }
Value_()719   virtual const Arc& Value_() const { return Value(); }
Next_()720   virtual void Next_() { Next(); }
Position_()721   virtual size_t Position_() const { return Position(); }
Reset_()722   virtual void Reset_() { Reset(); }
Seek_(size_t a)723   virtual void Seek_(size_t a) { Seek(a); }
SetValue_(const Arc & a)724   virtual void SetValue_(const Arc &a) { SetValue(a); }
Flags_()725   uint32 Flags_() const { return Flags(); }
SetFlags_(uint32 f,uint32 m)726   void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); }
727 
728   size_t i_;
729   StateId s_;
730   Impl *impl_;
731   State *state_;
732 
733   DISALLOW_COPY_AND_ASSIGN(CacheMutableArcIterator);
734 };
735 
736 }  // namespace fst
737 
738 #endif  // FST_LIB_CACHE_H__
739