// factor-weight.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: allauzen@google.com (Cyril Allauzen)
//
// \file
// Classes to factor weights in an FST.
20 
21 #ifndef FST_LIB_FACTOR_WEIGHT_H__
22 #define FST_LIB_FACTOR_WEIGHT_H__
23 
24 #include <algorithm>
25 #include <unordered_map>
26 using std::tr1::unordered_map;
27 using std::tr1::unordered_multimap;
28 #include <fst/slist.h>
29 #include <string>
30 #include <utility>
31 using std::pair; using std::make_pair;
32 #include <vector>
33 using std::vector;
34 
35 #include <fst/cache.h>
36 #include <fst/test-properties.h>
37 
38 
39 namespace fst {
40 
41 const uint32 kFactorFinalWeights = 0x00000001;
42 const uint32 kFactorArcWeights   = 0x00000002;
43 
44 template <class Arc>
45 struct FactorWeightOptions : CacheOptions {
46   typedef typename Arc::Label Label;
47   float delta;
48   uint32 mode;         // factor arc weights and/or final weights
49   Label final_ilabel;  // input label of arc created when factoring final w's
50   Label final_olabel;  // output label of arc created when factoring final w's
51 
52   FactorWeightOptions(const CacheOptions &opts, float d,
53                       uint32 m = kFactorArcWeights | kFactorFinalWeights,
54                       Label il = 0, Label ol = 0)
CacheOptionsFactorWeightOptions55       : CacheOptions(opts), delta(d), mode(m), final_ilabel(il),
56         final_olabel(ol) {}
57 
58   explicit FactorWeightOptions(
59       float d, uint32 m = kFactorArcWeights | kFactorFinalWeights,
60       Label il = 0, Label ol = 0)
deltaFactorWeightOptions61       : delta(d), mode(m), final_ilabel(il), final_olabel(ol) {}
62 
63   FactorWeightOptions(uint32 m = kFactorArcWeights | kFactorFinalWeights,
64                       Label il = 0, Label ol = 0)
deltaFactorWeightOptions65       : delta(kDelta), mode(m), final_ilabel(il), final_olabel(ol) {}
66 };
67 
68 
// A factor iterator takes as argument a weight w and returns a
// sequence of pairs of weights (xi,yi) such that the sum of the
// products xi times yi is equal to w. If w is fully factored,
// the iterator should return nothing.
//
// template <class W>
// class FactorIterator {
//  public:
//   FactorIterator(W w);
//   bool Done() const;
//   void Next();
//   pair<W, W> Value() const;
//   void Reset();
// };
83 
84 
// Factor trivially: every weight is reported as already fully factored, so
// the iterator is exhausted from the moment it is constructed.
template <class W>
class IdentityFactor {
 public:
  // The weight is ignored; there is nothing to factor.
  IdentityFactor(const W &) {}

  // Always true: no factorization is ever produced.
  bool Done() const { return true; }

  // No-ops, provided to satisfy the factor iterator interface.
  void Next() {}
  void Reset() {}

  // Unused in practice because Done() is always true; yields (One, One).
  pair<W, W> Value() const {
    return pair<W, W>(W::One(), W::One());
  }
};
95 
96 
97 // Factor a StringWeight w as 'ab' where 'a' is a label.
98 template <typename L, StringType S = STRING_LEFT>
99 class StringFactor {
100  public:
StringFactor(const StringWeight<L,S> & w)101   StringFactor(const StringWeight<L, S> &w)
102       : weight_(w), done_(w.Size() <= 1) {}
103 
Done()104   bool Done() const { return done_; }
105 
Next()106   void Next() { done_ = true; }
107 
Value()108   pair< StringWeight<L, S>, StringWeight<L, S> > Value() const {
109     StringWeightIterator<L, S> iter(weight_);
110     StringWeight<L, S> w1(iter.Value());
111     StringWeight<L, S> w2;
112     for (iter.Next(); !iter.Done(); iter.Next())
113       w2.PushBack(iter.Value());
114     return make_pair(w1, w2);
115   }
116 
Reset()117   void Reset() { done_ = weight_.Size() <= 1; }
118 
119  private:
120   StringWeight<L, S> weight_;
121   bool done_;
122 };
123 
124 
125 // Factor a GallicWeight using StringFactor.
126 template <class L, class W, StringType S = STRING_LEFT>
127 class GallicFactor {
128  public:
GallicFactor(const GallicWeight<L,W,S> & w)129   GallicFactor(const GallicWeight<L, W, S> &w)
130       : weight_(w), done_(w.Value1().Size() <= 1) {}
131 
Done()132   bool Done() const { return done_; }
133 
Next()134   void Next() { done_ = true; }
135 
Value()136   pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const {
137     StringFactor<L, S> iter(weight_.Value1());
138     GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2());
139     GallicWeight<L, W, S> w2(iter.Value().second, W::One());
140     return make_pair(w1, w2);
141   }
142 
Reset()143   void Reset() { done_ = weight_.Value1().Size() <= 1; }
144 
145  private:
146   GallicWeight<L, W, S> weight_;
147   bool done_;
148 };
149 
150 
151 // Implementation class for FactorWeight
152 template <class A, class F>
153 class FactorWeightFstImpl
154     : public CacheImpl<A> {
155  public:
156   using FstImpl<A>::SetType;
157   using FstImpl<A>::SetProperties;
158   using FstImpl<A>::SetInputSymbols;
159   using FstImpl<A>::SetOutputSymbols;
160 
161   using CacheBaseImpl< CacheState<A> >::PushArc;
162   using CacheBaseImpl< CacheState<A> >::HasStart;
163   using CacheBaseImpl< CacheState<A> >::HasFinal;
164   using CacheBaseImpl< CacheState<A> >::HasArcs;
165   using CacheBaseImpl< CacheState<A> >::SetArcs;
166   using CacheBaseImpl< CacheState<A> >::SetFinal;
167   using CacheBaseImpl< CacheState<A> >::SetStart;
168 
169   typedef A Arc;
170   typedef typename A::Label Label;
171   typedef typename A::Weight Weight;
172   typedef typename A::StateId StateId;
173   typedef F FactorIterator;
174 
175   struct Element {
ElementElement176     Element() {}
177 
ElementElement178     Element(StateId s, Weight w) : state(s), weight(w) {}
179 
180     StateId state;     // Input state Id
181     Weight weight;     // Residual weight
182   };
183 
FactorWeightFstImpl(const Fst<A> & fst,const FactorWeightOptions<A> & opts)184   FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions<A> &opts)
185       : CacheImpl<A>(opts),
186         fst_(fst.Copy()),
187         delta_(opts.delta),
188         mode_(opts.mode),
189         final_ilabel_(opts.final_ilabel),
190         final_olabel_(opts.final_olabel) {
191     SetType("factor_weight");
192     uint64 props = fst.Properties(kFstProperties, false);
193     SetProperties(FactorWeightProperties(props), kCopyProperties);
194 
195     SetInputSymbols(fst.InputSymbols());
196     SetOutputSymbols(fst.OutputSymbols());
197 
198     if (mode_ == 0)
199       LOG(WARNING) << "FactorWeightFst: factor mode is set to 0: "
200                    << "factoring neither arc weights nor final weights.";
201   }
202 
FactorWeightFstImpl(const FactorWeightFstImpl<A,F> & impl)203   FactorWeightFstImpl(const FactorWeightFstImpl<A, F> &impl)
204       : CacheImpl<A>(impl),
205         fst_(impl.fst_->Copy(true)),
206         delta_(impl.delta_),
207         mode_(impl.mode_),
208         final_ilabel_(impl.final_ilabel_),
209         final_olabel_(impl.final_olabel_) {
210     SetType("factor_weight");
211     SetProperties(impl.Properties(), kCopyProperties);
212     SetInputSymbols(impl.InputSymbols());
213     SetOutputSymbols(impl.OutputSymbols());
214   }
215 
~FactorWeightFstImpl()216   ~FactorWeightFstImpl() {
217     delete fst_;
218   }
219 
Start()220   StateId Start() {
221     if (!HasStart()) {
222       StateId s = fst_->Start();
223       if (s == kNoStateId)
224         return kNoStateId;
225       StateId start = FindState(Element(fst_->Start(), Weight::One()));
226       SetStart(start);
227     }
228     return CacheImpl<A>::Start();
229   }
230 
Final(StateId s)231   Weight Final(StateId s) {
232     if (!HasFinal(s)) {
233       const Element &e = elements_[s];
234       // TODO: fix so cast is unnecessary
235       Weight w = e.state == kNoStateId
236                  ? e.weight
237                  : (Weight) Times(e.weight, fst_->Final(e.state));
238       FactorIterator f(w);
239       if (!(mode_ & kFactorFinalWeights) || f.Done())
240         SetFinal(s, w);
241       else
242         SetFinal(s, Weight::Zero());
243     }
244     return CacheImpl<A>::Final(s);
245   }
246 
NumArcs(StateId s)247   size_t NumArcs(StateId s) {
248     if (!HasArcs(s))
249       Expand(s);
250     return CacheImpl<A>::NumArcs(s);
251   }
252 
NumInputEpsilons(StateId s)253   size_t NumInputEpsilons(StateId s) {
254     if (!HasArcs(s))
255       Expand(s);
256     return CacheImpl<A>::NumInputEpsilons(s);
257   }
258 
NumOutputEpsilons(StateId s)259   size_t NumOutputEpsilons(StateId s) {
260     if (!HasArcs(s))
261       Expand(s);
262     return CacheImpl<A>::NumOutputEpsilons(s);
263   }
264 
Properties()265   uint64 Properties() const { return Properties(kFstProperties); }
266 
267   // Set error if found; return FST impl properties.
Properties(uint64 mask)268   uint64 Properties(uint64 mask) const {
269     if ((mask & kError) && fst_->Properties(kError, false))
270       SetProperties(kError, kError);
271     return FstImpl<Arc>::Properties(mask);
272   }
273 
InitArcIterator(StateId s,ArcIteratorData<A> * data)274   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
275     if (!HasArcs(s))
276       Expand(s);
277     CacheImpl<A>::InitArcIterator(s, data);
278   }
279 
280 
281   // Find state corresponding to an element. Create new state
282   // if element not found.
FindState(const Element & e)283   StateId FindState(const Element &e) {
284     if (!(mode_ & kFactorArcWeights) && e.weight == Weight::One()) {
285       while (unfactored_.size() <= e.state)
286         unfactored_.push_back(kNoStateId);
287       if (unfactored_[e.state] == kNoStateId) {
288         unfactored_[e.state] = elements_.size();
289         elements_.push_back(e);
290       }
291       return unfactored_[e.state];
292     } else {
293       typename ElementMap::iterator eit = element_map_.find(e);
294       if (eit != element_map_.end()) {
295         return (*eit).second;
296       } else {
297         StateId s = elements_.size();
298         elements_.push_back(e);
299         element_map_.insert(pair<const Element, StateId>(e, s));
300         return s;
301       }
302     }
303   }
304 
305   // Computes the outgoing transitions from a state, creating new destination
306   // states as needed.
Expand(StateId s)307   void Expand(StateId s) {
308     Element e = elements_[s];
309     if (e.state != kNoStateId) {
310       for (ArcIterator< Fst<A> > ait(*fst_, e.state);
311            !ait.Done();
312            ait.Next()) {
313         const A &arc = ait.Value();
314         Weight w = Times(e.weight, arc.weight);
315         FactorIterator fit(w);
316         if (!(mode_ & kFactorArcWeights) || fit.Done()) {
317           StateId d = FindState(Element(arc.nextstate, Weight::One()));
318           PushArc(s, Arc(arc.ilabel, arc.olabel, w, d));
319         } else {
320           for (; !fit.Done(); fit.Next()) {
321             const pair<Weight, Weight> &p = fit.Value();
322             StateId d = FindState(Element(arc.nextstate,
323                                           p.second.Quantize(delta_)));
324             PushArc(s, Arc(arc.ilabel, arc.olabel, p.first, d));
325           }
326         }
327       }
328     }
329 
330     if ((mode_ & kFactorFinalWeights) &&
331         ((e.state == kNoStateId) ||
332          (fst_->Final(e.state) != Weight::Zero()))) {
333       Weight w = e.state == kNoStateId
334                  ? e.weight
335                  : Times(e.weight, fst_->Final(e.state));
336       for (FactorIterator fit(w);
337            !fit.Done();
338            fit.Next()) {
339         const pair<Weight, Weight> &p = fit.Value();
340         StateId d = FindState(Element(kNoStateId,
341                                       p.second.Quantize(delta_)));
342         PushArc(s, Arc(final_ilabel_, final_olabel_, p.first, d));
343       }
344     }
345     SetArcs(s);
346   }
347 
348  private:
349   static const size_t kPrime = 7853;
350 
351   // Equality function for Elements, assume weights have been quantized.
352   class ElementEqual {
353    public:
operator()354     bool operator()(const Element &x, const Element &y) const {
355       return x.state == y.state && x.weight == y.weight;
356     }
357   };
358 
359   // Hash function for Elements to Fst states.
360   class ElementKey {
361    public:
operator()362     size_t operator()(const Element &x) const {
363       return static_cast<size_t>(x.state * kPrime + x.weight.Hash());
364     }
365    private:
366   };
367 
368   typedef unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap;
369 
370   const Fst<A> *fst_;
371   float delta_;
372   uint32 mode_;               // factoring arc and/or final weights
373   Label final_ilabel_;        // ilabel of arc created when factoring final w's
374   Label final_olabel_;        // olabel of arc created when factoring final w's
375   vector<Element> elements_;  // mapping Fst state to Elements
376   ElementMap element_map_;    // mapping Elements to Fst state
377   // mapping between old/new 'StateId' for states that do not need to
378   // be factored when 'mode_' is '0' or 'kFactorFinalWeights'
379   vector<StateId> unfactored_;
380 
381   void operator=(const FactorWeightFstImpl<A, F> &);  // disallow
382 };
383 
384 template <class A, class F> const size_t FactorWeightFstImpl<A, F>::kPrime;
385 
386 
387 // FactorWeightFst takes as template parameter a FactorIterator as
388 // defined above. The result of weight factoring is a transducer
389 // equivalent to the input whose path weights have been factored
390 // according to the FactorIterator. States and transitions will be
391 // added as necessary. The algorithm is a generalization to arbitrary
392 // weights of the second step of the input epsilon-normalization
393 // algorithm due to Mohri, "Generic epsilon-removal and input
394 // epsilon-normalization algorithms for weighted transducers",
395 // International Journal of Computer Science 13(1): 129-143 (2002).
396 //
397 // This class attaches interface to implementation and handles
398 // reference counting, delegating most methods to ImplToFst.
399 template <class A, class F>
400 class FactorWeightFst : public ImplToFst< FactorWeightFstImpl<A, F> > {
401  public:
402   friend class ArcIterator< FactorWeightFst<A, F> >;
403   friend class StateIterator< FactorWeightFst<A, F> >;
404 
405   typedef A Arc;
406   typedef typename A::Weight Weight;
407   typedef typename A::StateId StateId;
408   typedef CacheState<A> State;
409   typedef FactorWeightFstImpl<A, F> Impl;
410 
FactorWeightFst(const Fst<A> & fst)411   FactorWeightFst(const Fst<A> &fst)
412       : ImplToFst<Impl>(new Impl(fst, FactorWeightOptions<A>())) {}
413 
FactorWeightFst(const Fst<A> & fst,const FactorWeightOptions<A> & opts)414   FactorWeightFst(const Fst<A> &fst,  const FactorWeightOptions<A> &opts)
415       : ImplToFst<Impl>(new Impl(fst, opts)) {}
416 
417   // See Fst<>::Copy() for doc.
FactorWeightFst(const FactorWeightFst<A,F> & fst,bool copy)418   FactorWeightFst(const FactorWeightFst<A, F> &fst, bool copy)
419       : ImplToFst<Impl>(fst, copy) {}
420 
421   // Get a copy of this FactorWeightFst. See Fst<>::Copy() for further doc.
422   virtual FactorWeightFst<A, F> *Copy(bool copy = false) const {
423     return new FactorWeightFst<A, F>(*this, copy);
424   }
425 
426   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
427 
InitArcIterator(StateId s,ArcIteratorData<A> * data)428   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
429     GetImpl()->InitArcIterator(s, data);
430   }
431 
432  private:
433   // Makes visible to friends.
GetImpl()434   Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
435 
436   void operator=(const FactorWeightFst<A, F> &fst);  // Disallow
437 };
438 
439 
440 // Specialization for FactorWeightFst.
441 template<class A, class F>
442 class StateIterator< FactorWeightFst<A, F> >
443     : public CacheStateIterator< FactorWeightFst<A, F> > {
444  public:
StateIterator(const FactorWeightFst<A,F> & fst)445   explicit StateIterator(const FactorWeightFst<A, F> &fst)
446       : CacheStateIterator< FactorWeightFst<A, F> >(fst, fst.GetImpl()) {}
447 };
448 
449 
450 // Specialization for FactorWeightFst.
451 template <class A, class F>
452 class ArcIterator< FactorWeightFst<A, F> >
453     : public CacheArcIterator< FactorWeightFst<A, F> > {
454  public:
455   typedef typename A::StateId StateId;
456 
ArcIterator(const FactorWeightFst<A,F> & fst,StateId s)457   ArcIterator(const FactorWeightFst<A, F> &fst, StateId s)
458       : CacheArcIterator< FactorWeightFst<A, F> >(fst.GetImpl(), s) {
459     if (!fst.GetImpl()->HasArcs(s))
460       fst.GetImpl()->Expand(s);
461   }
462 
463  private:
464   DISALLOW_COPY_AND_ASSIGN(ArcIterator);
465 };
466 
467 template <class A, class F> inline
InitStateIterator(StateIteratorData<A> * data)468 void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const
469 {
470   data->base = new StateIterator< FactorWeightFst<A, F> >(*this);
471 }
472 
473 
474 }  // namespace fst
475 
476 #endif // FST_LIB_FACTOR_WEIGHT_H__
477