• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // factor-weight.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Author: allauzen@cs.nyu.edu (Cyril Allauzen)
16 //
17 // \file
18 // Classes to factor weights in an FST.
19 
20 #ifndef FST_LIB_FACTOR_WEIGHT_H__
21 #define FST_LIB_FACTOR_WEIGHT_H__
22 
#include <algorithm>
#include <utility>
#include <vector>

#include <ext/hash_map>
using __gnu_cxx::hash_map;
#include <ext/slist>
using __gnu_cxx::slist;

#include "fst/lib/cache.h"
#include "fst/lib/test-properties.h"
32 
33 namespace fst {
34 
35 struct FactorWeightOptions : CacheOptions {
36   float delta;
37   bool final_only;  // only factor final weights when true
38 
FactorWeightOptionsFactorWeightOptions39   FactorWeightOptions(const CacheOptions &opts, float d, bool of)
40       : CacheOptions(opts), delta(d), final_only(of) {}
41 
42   explicit FactorWeightOptions(float d, bool of = false)
deltaFactorWeightOptions43       : delta(d), final_only(of) {}
44 
45   FactorWeightOptions(bool of = false)
deltaFactorWeightOptions46       : delta(kDelta), final_only(of) {}
47 };
48 
49 
50 // A factor iterator takes as argument a weight w and returns a
51 // sequence of pairs of weights (xi,yi) such that the sum of the
52 // products xi times yi is equal to w. If w is fully factored,
53 // the iterator should return nothing.
54 //
55 // template <class W>
56 // class FactorIterator {
57 //  public:
58 //   FactorIterator(W w);
59 //   bool Done() const;
60 //   void Next();
61 //   pair<W, W> Value() const;
62 //   void Reset();
63 // }
64 
65 
// Factor trivially: every weight is treated as already fully factored, so
// the sequence of factor pairs is always empty.
template <class W>
class IdentityFactor {
 public:
  // The weight itself is never inspected.
  IdentityFactor(const W &) {}

  // Always true: there is never a factor pair to return.
  bool Done() const { return true; }

  // No-op; present only to satisfy the FactorIterator interface.
  void Next() {}

  // Unused in practice because Done() is always true; returns (One, One) so
  // that the interface is complete.
  std::pair<W, W> Value() const {
    return std::make_pair(W::One(), W::One());
  }

  // No-op; there is no iteration state to rewind.
  void Reset() {}
};
76 
77 
78 // Factor a StringWeight w as 'ab' where 'a' is a label.
79 template <typename L, StringType S = STRING_LEFT>
80 class StringFactor {
81  public:
StringFactor(const StringWeight<L,S> & w)82   StringFactor(const StringWeight<L, S> &w)
83       : weight_(w), done_(w.Size() <= 1) {}
84 
Done()85   bool Done() const { return done_; }
86 
Next()87   void Next() { done_ = true; }
88 
Value()89   pair< StringWeight<L, S>, StringWeight<L, S> > Value() const {
90     StringWeightIterator<L, S> iter(weight_);
91     StringWeight<L, S> w1(iter.Value());
92     StringWeight<L, S> w2;
93     for (iter.Next(); !iter.Done(); iter.Next())
94       w2.PushBack(iter.Value());
95     return make_pair(w1, w2);
96   }
97 
Reset()98   void Reset() { done_ = weight_.Size() <= 1; }
99 
100  private:
101   StringWeight<L, S> weight_;
102   bool done_;
103 };
104 
105 
106 // Factor a GallicWeight using StringFactor.
107 template <class L, class W, StringType S = STRING_LEFT>
108 class GallicFactor {
109  public:
GallicFactor(const GallicWeight<L,W,S> & w)110   GallicFactor(const GallicWeight<L, W, S> &w)
111       : weight_(w), done_(w.Value1().Size() <= 1) {}
112 
Done()113   bool Done() const { return done_; }
114 
Next()115   void Next() { done_ = true; }
116 
Value()117   pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const {
118     StringFactor<L, S> iter(weight_.Value1());
119     GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2());
120     GallicWeight<L, W, S> w2(iter.Value().second, W::One());
121     return make_pair(w1, w2);
122   }
123 
Reset()124   void Reset() { done_ = weight_.Value1().Size() <= 1; }
125 
126  private:
127   GallicWeight<L, W, S> weight_;
128   bool done_;
129 };
130 
131 
132 // Implementation class for FactorWeight
133 template <class A, class F>
134 class FactorWeightFstImpl
135     : public CacheImpl<A> {
136  public:
137   using FstImpl<A>::SetType;
138   using FstImpl<A>::SetProperties;
139   using FstImpl<A>::Properties;
140   using FstImpl<A>::SetInputSymbols;
141   using FstImpl<A>::SetOutputSymbols;
142 
143   using CacheBaseImpl< CacheState<A> >::HasStart;
144   using CacheBaseImpl< CacheState<A> >::HasFinal;
145   using CacheBaseImpl< CacheState<A> >::HasArcs;
146 
147   typedef A Arc;
148   typedef typename A::Label Label;
149   typedef typename A::Weight Weight;
150   typedef typename A::StateId StateId;
151   typedef F FactorIterator;
152 
153   struct Element {
ElementElement154     Element() {}
155 
ElementElement156     Element(StateId s, Weight w) : state(s), weight(w) {}
157 
158     StateId state;     // Input state Id
159     Weight weight;     // Residual weight
160   };
161 
FactorWeightFstImpl(const Fst<A> & fst,const FactorWeightOptions & opts)162   FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions &opts)
163       : CacheImpl<A>(opts), fst_(fst.Copy()), delta_(opts.delta),
164         final_only_(opts.final_only) {
165     SetType("factor-weight");
166     uint64 props = fst.Properties(kFstProperties, false);
167     SetProperties(FactorWeightProperties(props), kCopyProperties);
168 
169     SetInputSymbols(fst.InputSymbols());
170     SetOutputSymbols(fst.OutputSymbols());
171   }
172 
~FactorWeightFstImpl()173   ~FactorWeightFstImpl() {
174     delete fst_;
175   }
176 
Start()177   StateId Start() {
178     if (!HasStart()) {
179       StateId s = fst_->Start();
180       if (s == kNoStateId)
181         return kNoStateId;
182       StateId start = FindState(Element(fst_->Start(), Weight::One()));
183       SetStart(start);
184     }
185     return CacheImpl<A>::Start();
186   }
187 
Final(StateId s)188   Weight Final(StateId s) {
189     if (!HasFinal(s)) {
190       const Element &e = elements_[s];
191       // TODO: fix so cast is unnecessary
192       Weight w = e.state == kNoStateId
193                  ? e.weight
194                  : (Weight) Times(e.weight, fst_->Final(e.state));
195       FactorIterator f(w);
196       if (w != Weight::Zero() && f.Done())
197         SetFinal(s, w);
198       else
199         SetFinal(s, Weight::Zero());
200     }
201     return CacheImpl<A>::Final(s);
202   }
203 
NumArcs(StateId s)204   size_t NumArcs(StateId s) {
205     if (!HasArcs(s))
206       Expand(s);
207     return CacheImpl<A>::NumArcs(s);
208   }
209 
NumInputEpsilons(StateId s)210   size_t NumInputEpsilons(StateId s) {
211     if (!HasArcs(s))
212       Expand(s);
213     return CacheImpl<A>::NumInputEpsilons(s);
214   }
215 
NumOutputEpsilons(StateId s)216   size_t NumOutputEpsilons(StateId s) {
217     if (!HasArcs(s))
218       Expand(s);
219     return CacheImpl<A>::NumOutputEpsilons(s);
220   }
221 
InitArcIterator(StateId s,ArcIteratorData<A> * data)222   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
223     if (!HasArcs(s))
224       Expand(s);
225     CacheImpl<A>::InitArcIterator(s, data);
226   }
227 
228 
229   // Find state corresponding to an element. Create new state
230   // if element not found.
FindState(const Element & e)231   StateId FindState(const Element &e) {
232     if (final_only_ && e.weight == Weight::One()) {
233       while (unfactored_.size() <= (unsigned int)e.state)
234         unfactored_.push_back(kNoStateId);
235       if (unfactored_[e.state] == kNoStateId) {
236         unfactored_[e.state] = elements_.size();
237         elements_.push_back(e);
238       }
239       return unfactored_[e.state];
240     } else {
241       typename ElementMap::iterator eit = element_map_.find(e);
242       if (eit != element_map_.end()) {
243         return (*eit).second;
244       } else {
245         StateId s = elements_.size();
246         elements_.push_back(e);
247         element_map_.insert(pair<const Element, StateId>(e, s));
248         return s;
249       }
250     }
251   }
252 
253   // Computes the outgoing transitions from a state, creating new destination
254   // states as needed.
Expand(StateId s)255   void Expand(StateId s) {
256     Element e = elements_[s];
257     if (e.state != kNoStateId) {
258       for (ArcIterator< Fst<A> > ait(*fst_, e.state);
259            !ait.Done();
260            ait.Next()) {
261         const A &arc = ait.Value();
262         Weight w = Times(e.weight, arc.weight);
263         FactorIterator fit(w);
264         if (final_only_ || fit.Done()) {
265           StateId d = FindState(Element(arc.nextstate, Weight::One()));
266           AddArc(s, Arc(arc.ilabel, arc.olabel, w, d));
267         } else {
268           for (; !fit.Done(); fit.Next()) {
269             const pair<Weight, Weight> &p = fit.Value();
270             StateId d = FindState(Element(arc.nextstate,
271                                           p.second.Quantize(delta_)));
272             AddArc(s, Arc(arc.ilabel, arc.olabel, p.first, d));
273           }
274         }
275       }
276     }
277     if ((e.state == kNoStateId) ||
278         (fst_->Final(e.state) != Weight::Zero())) {
279       Weight w = e.state == kNoStateId
280                  ? e.weight
281                  : Times(e.weight, fst_->Final(e.state));
282       for (FactorIterator fit(w);
283            !fit.Done();
284            fit.Next()) {
285         const pair<Weight, Weight> &p = fit.Value();
286         StateId d = FindState(Element(kNoStateId,
287                                       p.second.Quantize(delta_)));
288         AddArc(s, Arc(0, 0, p.first, d));
289       }
290     }
291     SetArcs(s);
292   }
293 
294  private:
295   // Equality function for Elements, assume weights have been quantized.
296   class ElementEqual {
297    public:
operator()298     bool operator()(const Element &x, const Element &y) const {
299       return x.state == y.state && x.weight == y.weight;
300     }
301   };
302 
303   // Hash function for Elements to Fst states.
304   class ElementKey {
305    public:
operator()306     size_t operator()(const Element &x) const {
307       return static_cast<size_t>(x.state * kPrime + x.weight.Hash());
308     }
309    private:
310     static const int kPrime = 7853;
311   };
312 
313   typedef hash_map<Element, StateId, ElementKey, ElementEqual> ElementMap;
314 
315   const Fst<A> *fst_;
316   float delta_;
317   bool final_only_;
318   vector<Element> elements_;  // mapping Fst state to Elements
319   ElementMap element_map_;    // mapping Elements to Fst state
320   // mapping between old/new 'StateId' for states that do not need to
321   // be factored when 'final_only_' is true
322   vector<StateId> unfactored_;
323 
324   DISALLOW_EVIL_CONSTRUCTORS(FactorWeightFstImpl);
325 };
326 
327 
328 // FactorWeightFst takes as template parameter a FactorIterator as
329 // defined above. The result of weight factoring is a transducer
330 // equivalent to the input whose path weights have been factored
331 // according to the FactorIterator. States and transitions will be
332 // added as necessary. The algorithm is a generalization to arbitrary
333 // weights of the second step of the input epsilon-normalization
334 // algorithm due to Mohri, "Generic epsilon-removal and input
335 // epsilon-normalization algorithms for weighted transducers",
336 // International Journal of Computer Science 13(1): 129-143 (2002).
337 template <class A, class F>
338 class FactorWeightFst : public Fst<A> {
339  public:
340   friend class ArcIterator< FactorWeightFst<A, F> >;
341   friend class CacheStateIterator< FactorWeightFst<A, F> >;
342   friend class CacheArcIterator< FactorWeightFst<A, F> >;
343 
344   typedef A Arc;
345   typedef typename A::Weight Weight;
346   typedef typename A::StateId StateId;
347   typedef CacheState<A> State;
348 
FactorWeightFst(const Fst<A> & fst)349   FactorWeightFst(const Fst<A> &fst)
350       : impl_(new FactorWeightFstImpl<A, F>(fst, FactorWeightOptions())) {}
351 
FactorWeightFst(const Fst<A> & fst,const FactorWeightOptions & opts)352   FactorWeightFst(const Fst<A> &fst,  const FactorWeightOptions &opts)
353       : impl_(new FactorWeightFstImpl<A, F>(fst, opts)) {}
FactorWeightFst(const FactorWeightFst<A,F> & fst)354   FactorWeightFst(const FactorWeightFst<A, F> &fst) : Fst<A>(fst), impl_(fst.impl_) {
355     impl_->IncrRefCount();
356   }
357 
~FactorWeightFst()358   virtual ~FactorWeightFst() { if (!impl_->DecrRefCount()) delete impl_;  }
359 
Start()360   virtual StateId Start() const { return impl_->Start(); }
361 
Final(StateId s)362   virtual Weight Final(StateId s) const { return impl_->Final(s); }
363 
NumArcs(StateId s)364   virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
365 
NumInputEpsilons(StateId s)366   virtual size_t NumInputEpsilons(StateId s) const {
367     return impl_->NumInputEpsilons(s);
368   }
369 
NumOutputEpsilons(StateId s)370   virtual size_t NumOutputEpsilons(StateId s) const {
371     return impl_->NumOutputEpsilons(s);
372   }
373 
Properties(uint64 mask,bool test)374   virtual uint64 Properties(uint64 mask, bool test) const {
375     if (test) {
376       uint64 known, test = TestProperties(*this, mask, &known);
377       impl_->SetProperties(test, known);
378       return test & mask;
379     } else {
380       return impl_->Properties(mask);
381     }
382   }
383 
Type()384   virtual const string& Type() const { return impl_->Type(); }
385 
Copy()386   virtual FactorWeightFst<A, F> *Copy() const {
387     return new FactorWeightFst<A, F>(*this);
388   }
389 
InputSymbols()390   virtual const SymbolTable* InputSymbols() const {
391     return impl_->InputSymbols();
392   }
393 
OutputSymbols()394   virtual const SymbolTable* OutputSymbols() const {
395     return impl_->OutputSymbols();
396   }
397 
398   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
399 
InitArcIterator(StateId s,ArcIteratorData<A> * data)400   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
401     impl_->InitArcIterator(s, data);
402   }
403 
404  private:
Impl()405   FactorWeightFstImpl<A, F> *Impl() { return impl_; }
406 
407   FactorWeightFstImpl<A, F> *impl_;
408 
409   void operator=(const FactorWeightFst<A, F> &fst);  // Disallow
410 };
411 
412 
413 // Specialization for FactorWeightFst.
414 template<class A, class F>
415 class StateIterator< FactorWeightFst<A, F> >
416     : public CacheStateIterator< FactorWeightFst<A, F> > {
417  public:
StateIterator(const FactorWeightFst<A,F> & fst)418   explicit StateIterator(const FactorWeightFst<A, F> &fst)
419       : CacheStateIterator< FactorWeightFst<A, F> >(fst) {}
420 };
421 
422 
423 // Specialization for FactorWeightFst.
424 template <class A, class F>
425 class ArcIterator< FactorWeightFst<A, F> >
426     : public CacheArcIterator< FactorWeightFst<A, F> > {
427  public:
428   typedef typename A::StateId StateId;
429 
ArcIterator(const FactorWeightFst<A,F> & fst,StateId s)430   ArcIterator(const FactorWeightFst<A, F> &fst, StateId s)
431       : CacheArcIterator< FactorWeightFst<A, F> >(fst, s) {
432     if (!fst.impl_->HasArcs(s))
433       fst.impl_->Expand(s);
434   }
435 
436  private:
437   DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
438 };
439 
440 template <class A, class F> inline
InitStateIterator(StateIteratorData<A> * data)441 void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const
442 {
443   data->base = new StateIterator< FactorWeightFst<A, F> >(*this);
444 }
445 
446 
447 }  // namespace fst
448 
449 #endif // FST_LIB_FACTOR_WEIGHT_H__
450