• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // factor-weight.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Author: allauzen@cs.nyu.edu (Cyril Allauzen)
16 //
17 // \file
18 // Classes to factor weights in an FST.
19 
20 #ifndef FST_LIB_FACTOR_WEIGHT_H__
21 #define FST_LIB_FACTOR_WEIGHT_H__
22 
23 #include <algorithm>
24 
25 #include <unordered_map>
26 #include <forward_list>
27 
28 #include "fst/lib/cache.h"
29 #include "fst/lib/test-properties.h"
30 
31 namespace fst {
32 
33 struct FactorWeightOptions : CacheOptions {
34   float delta;
35   bool final_only;  // only factor final weights when true
36 
FactorWeightOptionsFactorWeightOptions37   FactorWeightOptions(const CacheOptions &opts, float d, bool of)
38       : CacheOptions(opts), delta(d), final_only(of) {}
39 
40   explicit FactorWeightOptions(float d, bool of = false)
deltaFactorWeightOptions41       : delta(d), final_only(of) {}
42 
43   FactorWeightOptions(bool of = false)
deltaFactorWeightOptions44       : delta(kDelta), final_only(of) {}
45 };
46 
47 
// A factor iterator takes as argument a weight w and returns a
// sequence of pairs of weights (xi, yi) such that the sum (Plus) of the
// products xi times yi (Times) is equal to w. If w is fully factored,
// the iterator should return nothing.
//
// template <class W>
// class FactorIterator {
//  public:
//   FactorIterator(W w);
//   bool Done() const;
//   void Next();
//   pair<W, W> Value() const;
//   void Reset();
// };
62 
63 
// Factor trivially: every weight is treated as already fully factored, so
// the iterator is empty from the start.
template <class W>
class IdentityFactor {
 public:
  // The weight is ignored; there is nothing to factor.
  IdentityFactor(const W & /*w*/) {}

  bool Done() const { return true; }

  void Next() {}

  // Not expected to be called, since Done() is always true.
  pair<W, W> Value() const { return pair<W, W>(W::One(), W::One()); }

  void Reset() {}
};
74 
75 
76 // Factor a StringWeight w as 'ab' where 'a' is a label.
77 template <typename L, StringType S = STRING_LEFT>
78 class StringFactor {
79  public:
StringFactor(const StringWeight<L,S> & w)80   StringFactor(const StringWeight<L, S> &w)
81       : weight_(w), done_(w.Size() <= 1) {}
82 
Done()83   bool Done() const { return done_; }
84 
Next()85   void Next() { done_ = true; }
86 
Value()87   pair< StringWeight<L, S>, StringWeight<L, S> > Value() const {
88     StringWeightIterator<L, S> iter(weight_);
89     StringWeight<L, S> w1(iter.Value());
90     StringWeight<L, S> w2;
91     for (iter.Next(); !iter.Done(); iter.Next())
92       w2.PushBack(iter.Value());
93     return make_pair(w1, w2);
94   }
95 
Reset()96   void Reset() { done_ = weight_.Size() <= 1; }
97 
98  private:
99   StringWeight<L, S> weight_;
100   bool done_;
101 };
102 
103 
104 // Factor a GallicWeight using StringFactor.
105 template <class L, class W, StringType S = STRING_LEFT>
106 class GallicFactor {
107  public:
GallicFactor(const GallicWeight<L,W,S> & w)108   GallicFactor(const GallicWeight<L, W, S> &w)
109       : weight_(w), done_(w.Value1().Size() <= 1) {}
110 
Done()111   bool Done() const { return done_; }
112 
Next()113   void Next() { done_ = true; }
114 
Value()115   pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const {
116     StringFactor<L, S> iter(weight_.Value1());
117     GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2());
118     GallicWeight<L, W, S> w2(iter.Value().second, W::One());
119     return make_pair(w1, w2);
120   }
121 
Reset()122   void Reset() { done_ = weight_.Value1().Size() <= 1; }
123 
124  private:
125   GallicWeight<L, W, S> weight_;
126   bool done_;
127 };
128 
129 
130 // Implementation class for FactorWeight
131 template <class A, class F>
132 class FactorWeightFstImpl
133     : public CacheImpl<A> {
134  public:
135   using FstImpl<A>::SetType;
136   using FstImpl<A>::SetProperties;
137   using FstImpl<A>::Properties;
138   using FstImpl<A>::SetInputSymbols;
139   using FstImpl<A>::SetOutputSymbols;
140 
141   using CacheBaseImpl< CacheState<A> >::HasStart;
142   using CacheBaseImpl< CacheState<A> >::HasFinal;
143   using CacheBaseImpl< CacheState<A> >::HasArcs;
144 
145   typedef A Arc;
146   typedef typename A::Label Label;
147   typedef typename A::Weight Weight;
148   typedef typename A::StateId StateId;
149   typedef F FactorIterator;
150 
151   struct Element {
ElementElement152     Element() {}
153 
ElementElement154     Element(StateId s, Weight w) : state(s), weight(w) {}
155 
156     StateId state;     // Input state Id
157     Weight weight;     // Residual weight
158   };
159 
FactorWeightFstImpl(const Fst<A> & fst,const FactorWeightOptions & opts)160   FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions &opts)
161       : CacheImpl<A>(opts), fst_(fst.Copy()), delta_(opts.delta),
162         final_only_(opts.final_only) {
163     SetType("factor-weight");
164     uint64 props = fst.Properties(kFstProperties, false);
165     SetProperties(FactorWeightProperties(props), kCopyProperties);
166 
167     SetInputSymbols(fst.InputSymbols());
168     SetOutputSymbols(fst.OutputSymbols());
169   }
170 
~FactorWeightFstImpl()171   ~FactorWeightFstImpl() {
172     delete fst_;
173   }
174 
Start()175   StateId Start() {
176     if (!HasStart()) {
177       StateId s = fst_->Start();
178       if (s == kNoStateId)
179         return kNoStateId;
180       StateId start = FindState(Element(fst_->Start(), Weight::One()));
181       this->SetStart(start);
182     }
183     return CacheImpl<A>::Start();
184   }
185 
Final(StateId s)186   Weight Final(StateId s) {
187     if (!HasFinal(s)) {
188       const Element &e = elements_[s];
189       // TODO: fix so cast is unnecessary
190       Weight w = e.state == kNoStateId
191                  ? e.weight
192                  : (Weight) Times(e.weight, fst_->Final(e.state));
193       FactorIterator f(w);
194       if (w != Weight::Zero() && f.Done())
195         this->SetFinal(s, w);
196       else
197         this->SetFinal(s, Weight::Zero());
198     }
199     return CacheImpl<A>::Final(s);
200   }
201 
NumArcs(StateId s)202   size_t NumArcs(StateId s) {
203     if (!HasArcs(s))
204       Expand(s);
205     return CacheImpl<A>::NumArcs(s);
206   }
207 
NumInputEpsilons(StateId s)208   size_t NumInputEpsilons(StateId s) {
209     if (!HasArcs(s))
210       Expand(s);
211     return CacheImpl<A>::NumInputEpsilons(s);
212   }
213 
NumOutputEpsilons(StateId s)214   size_t NumOutputEpsilons(StateId s) {
215     if (!HasArcs(s))
216       Expand(s);
217     return CacheImpl<A>::NumOutputEpsilons(s);
218   }
219 
InitArcIterator(StateId s,ArcIteratorData<A> * data)220   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
221     if (!HasArcs(s))
222       Expand(s);
223     CacheImpl<A>::InitArcIterator(s, data);
224   }
225 
226 
227   // Find state corresponding to an element. Create new state
228   // if element not found.
FindState(const Element & e)229   StateId FindState(const Element &e) {
230     if (final_only_ && e.weight == Weight::One()) {
231       while (unfactored_.size() <= (unsigned int)e.state)
232         unfactored_.push_back(kNoStateId);
233       if (unfactored_[e.state] == kNoStateId) {
234         unfactored_[e.state] = elements_.size();
235         elements_.push_back(e);
236       }
237       return unfactored_[e.state];
238     } else {
239       typename ElementMap::iterator eit = element_map_.find(e);
240       if (eit != element_map_.end()) {
241         return (*eit).second;
242       } else {
243         StateId s = elements_.size();
244         elements_.push_back(e);
245         element_map_.insert(pair<const Element, StateId>(e, s));
246         return s;
247       }
248     }
249   }
250 
251   // Computes the outgoing transitions from a state, creating new destination
252   // states as needed.
Expand(StateId s)253   void Expand(StateId s) {
254     Element e = elements_[s];
255     if (e.state != kNoStateId) {
256       for (ArcIterator< Fst<A> > ait(*fst_, e.state);
257            !ait.Done();
258            ait.Next()) {
259         const A &arc = ait.Value();
260         Weight w = Times(e.weight, arc.weight);
261         FactorIterator fit(w);
262         if (final_only_ || fit.Done()) {
263           StateId d = FindState(Element(arc.nextstate, Weight::One()));
264           this->AddArc(s, Arc(arc.ilabel, arc.olabel, w, d));
265         } else {
266           for (; !fit.Done(); fit.Next()) {
267             const pair<Weight, Weight> &p = fit.Value();
268             StateId d = FindState(Element(arc.nextstate,
269                                           p.second.Quantize(delta_)));
270             this->AddArc(s, Arc(arc.ilabel, arc.olabel, p.first, d));
271           }
272         }
273       }
274     }
275     if ((e.state == kNoStateId) ||
276         (fst_->Final(e.state) != Weight::Zero())) {
277       Weight w = e.state == kNoStateId
278                  ? e.weight
279                  : Times(e.weight, fst_->Final(e.state));
280       for (FactorIterator fit(w);
281            !fit.Done();
282            fit.Next()) {
283         const pair<Weight, Weight> &p = fit.Value();
284         StateId d = FindState(Element(kNoStateId,
285                                       p.second.Quantize(delta_)));
286         this->AddArc(s, Arc(0, 0, p.first, d));
287       }
288     }
289     this->SetArcs(s);
290   }
291 
292  private:
293   // Equality function for Elements, assume weights have been quantized.
294   class ElementEqual {
295    public:
operator()296     bool operator()(const Element &x, const Element &y) const {
297       return x.state == y.state && x.weight == y.weight;
298     }
299   };
300 
301   // Hash function for Elements to Fst states.
302   class ElementKey {
303    public:
operator()304     size_t operator()(const Element &x) const {
305       return static_cast<size_t>(x.state * kPrime + x.weight.Hash());
306     }
307    private:
308     static const int kPrime = 7853;
309   };
310 
311   typedef std::unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap;
312 
313   const Fst<A> *fst_;
314   float delta_;
315   bool final_only_;
316   vector<Element> elements_;  // mapping Fst state to Elements
317   ElementMap element_map_;    // mapping Elements to Fst state
318   // mapping between old/new 'StateId' for states that do not need to
319   // be factored when 'final_only_' is true
320   vector<StateId> unfactored_;
321 
322   DISALLOW_EVIL_CONSTRUCTORS(FactorWeightFstImpl);
323 };
324 
325 
326 // FactorWeightFst takes as template parameter a FactorIterator as
327 // defined above. The result of weight factoring is a transducer
328 // equivalent to the input whose path weights have been factored
329 // according to the FactorIterator. States and transitions will be
330 // added as necessary. The algorithm is a generalization to arbitrary
331 // weights of the second step of the input epsilon-normalization
332 // algorithm due to Mohri, "Generic epsilon-removal and input
333 // epsilon-normalization algorithms for weighted transducers",
334 // International Journal of Computer Science 13(1): 129-143 (2002).
335 template <class A, class F>
336 class FactorWeightFst : public Fst<A> {
337  public:
338   friend class ArcIterator< FactorWeightFst<A, F> >;
339   friend class CacheStateIterator< FactorWeightFst<A, F> >;
340   friend class CacheArcIterator< FactorWeightFst<A, F> >;
341 
342   typedef A Arc;
343   typedef typename A::Weight Weight;
344   typedef typename A::StateId StateId;
345   typedef CacheState<A> State;
346 
FactorWeightFst(const Fst<A> & fst)347   FactorWeightFst(const Fst<A> &fst)
348       : impl_(new FactorWeightFstImpl<A, F>(fst, FactorWeightOptions())) {}
349 
FactorWeightFst(const Fst<A> & fst,const FactorWeightOptions & opts)350   FactorWeightFst(const Fst<A> &fst,  const FactorWeightOptions &opts)
351       : impl_(new FactorWeightFstImpl<A, F>(fst, opts)) {}
FactorWeightFst(const FactorWeightFst<A,F> & fst)352   FactorWeightFst(const FactorWeightFst<A, F> &fst) : Fst<A>(fst), impl_(fst.impl_) {
353     impl_->IncrRefCount();
354   }
355 
~FactorWeightFst()356   virtual ~FactorWeightFst() { if (!impl_->DecrRefCount()) delete impl_;  }
357 
Start()358   virtual StateId Start() const { return impl_->Start(); }
359 
Final(StateId s)360   virtual Weight Final(StateId s) const { return impl_->Final(s); }
361 
NumArcs(StateId s)362   virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
363 
NumInputEpsilons(StateId s)364   virtual size_t NumInputEpsilons(StateId s) const {
365     return impl_->NumInputEpsilons(s);
366   }
367 
NumOutputEpsilons(StateId s)368   virtual size_t NumOutputEpsilons(StateId s) const {
369     return impl_->NumOutputEpsilons(s);
370   }
371 
Properties(uint64 mask,bool test)372   virtual uint64 Properties(uint64 mask, bool test) const {
373     if (test) {
374       uint64 known, test = TestProperties(*this, mask, &known);
375       impl_->SetProperties(test, known);
376       return test & mask;
377     } else {
378       return impl_->Properties(mask);
379     }
380   }
381 
Type()382   virtual const string& Type() const { return impl_->Type(); }
383 
Copy()384   virtual FactorWeightFst<A, F> *Copy() const {
385     return new FactorWeightFst<A, F>(*this);
386   }
387 
InputSymbols()388   virtual const SymbolTable* InputSymbols() const {
389     return impl_->InputSymbols();
390   }
391 
OutputSymbols()392   virtual const SymbolTable* OutputSymbols() const {
393     return impl_->OutputSymbols();
394   }
395 
396   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
397 
InitArcIterator(StateId s,ArcIteratorData<A> * data)398   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
399     impl_->InitArcIterator(s, data);
400   }
401 
402  private:
Impl()403   FactorWeightFstImpl<A, F> *Impl() { return impl_; }
404 
405   FactorWeightFstImpl<A, F> *impl_;
406 
407   void operator=(const FactorWeightFst<A, F> &fst);  // Disallow
408 };
409 
410 
411 // Specialization for FactorWeightFst.
412 template<class A, class F>
413 class StateIterator< FactorWeightFst<A, F> >
414     : public CacheStateIterator< FactorWeightFst<A, F> > {
415  public:
StateIterator(const FactorWeightFst<A,F> & fst)416   explicit StateIterator(const FactorWeightFst<A, F> &fst)
417       : CacheStateIterator< FactorWeightFst<A, F> >(fst) {}
418 };
419 
420 
421 // Specialization for FactorWeightFst.
422 template <class A, class F>
423 class ArcIterator< FactorWeightFst<A, F> >
424     : public CacheArcIterator< FactorWeightFst<A, F> > {
425  public:
426   typedef typename A::StateId StateId;
427 
ArcIterator(const FactorWeightFst<A,F> & fst,StateId s)428   ArcIterator(const FactorWeightFst<A, F> &fst, StateId s)
429       : CacheArcIterator< FactorWeightFst<A, F> >(fst, s) {
430     if (!fst.impl_->HasArcs(s))
431       fst.impl_->Expand(s);
432   }
433 
434  private:
435   DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
436 };
437 
438 template <class A, class F> inline
InitStateIterator(StateIteratorData<A> * data)439 void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const
440 {
441   data->base = new StateIterator< FactorWeightFst<A, F> >(*this);
442 }
443 
444 
445 }  // namespace fst
446 
447 #endif // FST_LIB_FACTOR_WEIGHT_H__
448