1 // factor-weight.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: allauzen@google.com (Cyril Allauzen)
17 //
18 // \file
19 // Classes to factor weights in an FST.
20
21 #ifndef FST_LIB_FACTOR_WEIGHT_H__
22 #define FST_LIB_FACTOR_WEIGHT_H__
23
24 #include <algorithm>
25 #include <unordered_map>
26 using std::tr1::unordered_map;
27 using std::tr1::unordered_multimap;
28 #include <fst/slist.h>
29 #include <string>
30 #include <utility>
31 using std::pair; using std::make_pair;
32 #include <vector>
33 using std::vector;
34
35 #include <fst/cache.h>
36 #include <fst/test-properties.h>
37
38
39 namespace fst {
40
41 const uint32 kFactorFinalWeights = 0x00000001;
42 const uint32 kFactorArcWeights = 0x00000002;
43
44 template <class Arc>
45 struct FactorWeightOptions : CacheOptions {
46 typedef typename Arc::Label Label;
47 float delta;
48 uint32 mode; // factor arc weights and/or final weights
49 Label final_ilabel; // input label of arc created when factoring final w's
50 Label final_olabel; // output label of arc created when factoring final w's
51
52 FactorWeightOptions(const CacheOptions &opts, float d,
53 uint32 m = kFactorArcWeights | kFactorFinalWeights,
54 Label il = 0, Label ol = 0)
CacheOptionsFactorWeightOptions55 : CacheOptions(opts), delta(d), mode(m), final_ilabel(il),
56 final_olabel(ol) {}
57
58 explicit FactorWeightOptions(
59 float d, uint32 m = kFactorArcWeights | kFactorFinalWeights,
60 Label il = 0, Label ol = 0)
deltaFactorWeightOptions61 : delta(d), mode(m), final_ilabel(il), final_olabel(ol) {}
62
63 FactorWeightOptions(uint32 m = kFactorArcWeights | kFactorFinalWeights,
64 Label il = 0, Label ol = 0)
deltaFactorWeightOptions65 : delta(kDelta), mode(m), final_ilabel(il), final_olabel(ol) {}
66 };
67
68
// A factor iterator takes a weight w as its argument and returns a
// sequence of pairs of weights (xi, yi) such that the sum of the
// products xi times yi is equal to w. If w is fully factored,
// the iterator should return nothing.
//
// template <class W>
// class FactorIterator {
//  public:
//   FactorIterator(W w);
//   bool Done() const;
//   void Next();
//   pair<W, W> Value() const;
//   void Reset();
// };
83
84
// Trivial factor iterator: every weight is considered fully factored, so
// the iterator is done from the start and never yields a pair.
template <class W>
class IdentityFactor {
 public:
  // The weight is ignored.
  IdentityFactor(const W &) {}

  // Always true: there is nothing to iterate over.
  bool Done() const { return true; }

  // No-op.
  void Next() {}

  // Unused because Done() is always true; yields (One, One).
  std::pair<W, W> Value() const {
    return std::pair<W, W>(W::One(), W::One());
  }

  // No-op.
  void Reset() {}
};
95
96
97 // Factor a StringWeight w as 'ab' where 'a' is a label.
98 template <typename L, StringType S = STRING_LEFT>
99 class StringFactor {
100 public:
StringFactor(const StringWeight<L,S> & w)101 StringFactor(const StringWeight<L, S> &w)
102 : weight_(w), done_(w.Size() <= 1) {}
103
Done()104 bool Done() const { return done_; }
105
Next()106 void Next() { done_ = true; }
107
Value()108 pair< StringWeight<L, S>, StringWeight<L, S> > Value() const {
109 StringWeightIterator<L, S> iter(weight_);
110 StringWeight<L, S> w1(iter.Value());
111 StringWeight<L, S> w2;
112 for (iter.Next(); !iter.Done(); iter.Next())
113 w2.PushBack(iter.Value());
114 return make_pair(w1, w2);
115 }
116
Reset()117 void Reset() { done_ = weight_.Size() <= 1; }
118
119 private:
120 StringWeight<L, S> weight_;
121 bool done_;
122 };
123
124
125 // Factor a GallicWeight using StringFactor.
126 template <class L, class W, StringType S = STRING_LEFT>
127 class GallicFactor {
128 public:
GallicFactor(const GallicWeight<L,W,S> & w)129 GallicFactor(const GallicWeight<L, W, S> &w)
130 : weight_(w), done_(w.Value1().Size() <= 1) {}
131
Done()132 bool Done() const { return done_; }
133
Next()134 void Next() { done_ = true; }
135
Value()136 pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const {
137 StringFactor<L, S> iter(weight_.Value1());
138 GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2());
139 GallicWeight<L, W, S> w2(iter.Value().second, W::One());
140 return make_pair(w1, w2);
141 }
142
Reset()143 void Reset() { done_ = weight_.Value1().Size() <= 1; }
144
145 private:
146 GallicWeight<L, W, S> weight_;
147 bool done_;
148 };
149
150
151 // Implementation class for FactorWeight
152 template <class A, class F>
153 class FactorWeightFstImpl
154 : public CacheImpl<A> {
155 public:
156 using FstImpl<A>::SetType;
157 using FstImpl<A>::SetProperties;
158 using FstImpl<A>::SetInputSymbols;
159 using FstImpl<A>::SetOutputSymbols;
160
161 using CacheBaseImpl< CacheState<A> >::PushArc;
162 using CacheBaseImpl< CacheState<A> >::HasStart;
163 using CacheBaseImpl< CacheState<A> >::HasFinal;
164 using CacheBaseImpl< CacheState<A> >::HasArcs;
165 using CacheBaseImpl< CacheState<A> >::SetArcs;
166 using CacheBaseImpl< CacheState<A> >::SetFinal;
167 using CacheBaseImpl< CacheState<A> >::SetStart;
168
169 typedef A Arc;
170 typedef typename A::Label Label;
171 typedef typename A::Weight Weight;
172 typedef typename A::StateId StateId;
173 typedef F FactorIterator;
174
175 struct Element {
ElementElement176 Element() {}
177
ElementElement178 Element(StateId s, Weight w) : state(s), weight(w) {}
179
180 StateId state; // Input state Id
181 Weight weight; // Residual weight
182 };
183
FactorWeightFstImpl(const Fst<A> & fst,const FactorWeightOptions<A> & opts)184 FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions<A> &opts)
185 : CacheImpl<A>(opts),
186 fst_(fst.Copy()),
187 delta_(opts.delta),
188 mode_(opts.mode),
189 final_ilabel_(opts.final_ilabel),
190 final_olabel_(opts.final_olabel) {
191 SetType("factor_weight");
192 uint64 props = fst.Properties(kFstProperties, false);
193 SetProperties(FactorWeightProperties(props), kCopyProperties);
194
195 SetInputSymbols(fst.InputSymbols());
196 SetOutputSymbols(fst.OutputSymbols());
197
198 if (mode_ == 0)
199 LOG(WARNING) << "FactorWeightFst: factor mode is set to 0: "
200 << "factoring neither arc weights nor final weights.";
201 }
202
FactorWeightFstImpl(const FactorWeightFstImpl<A,F> & impl)203 FactorWeightFstImpl(const FactorWeightFstImpl<A, F> &impl)
204 : CacheImpl<A>(impl),
205 fst_(impl.fst_->Copy(true)),
206 delta_(impl.delta_),
207 mode_(impl.mode_),
208 final_ilabel_(impl.final_ilabel_),
209 final_olabel_(impl.final_olabel_) {
210 SetType("factor_weight");
211 SetProperties(impl.Properties(), kCopyProperties);
212 SetInputSymbols(impl.InputSymbols());
213 SetOutputSymbols(impl.OutputSymbols());
214 }
215
~FactorWeightFstImpl()216 ~FactorWeightFstImpl() {
217 delete fst_;
218 }
219
Start()220 StateId Start() {
221 if (!HasStart()) {
222 StateId s = fst_->Start();
223 if (s == kNoStateId)
224 return kNoStateId;
225 StateId start = FindState(Element(fst_->Start(), Weight::One()));
226 SetStart(start);
227 }
228 return CacheImpl<A>::Start();
229 }
230
Final(StateId s)231 Weight Final(StateId s) {
232 if (!HasFinal(s)) {
233 const Element &e = elements_[s];
234 // TODO: fix so cast is unnecessary
235 Weight w = e.state == kNoStateId
236 ? e.weight
237 : (Weight) Times(e.weight, fst_->Final(e.state));
238 FactorIterator f(w);
239 if (!(mode_ & kFactorFinalWeights) || f.Done())
240 SetFinal(s, w);
241 else
242 SetFinal(s, Weight::Zero());
243 }
244 return CacheImpl<A>::Final(s);
245 }
246
NumArcs(StateId s)247 size_t NumArcs(StateId s) {
248 if (!HasArcs(s))
249 Expand(s);
250 return CacheImpl<A>::NumArcs(s);
251 }
252
NumInputEpsilons(StateId s)253 size_t NumInputEpsilons(StateId s) {
254 if (!HasArcs(s))
255 Expand(s);
256 return CacheImpl<A>::NumInputEpsilons(s);
257 }
258
NumOutputEpsilons(StateId s)259 size_t NumOutputEpsilons(StateId s) {
260 if (!HasArcs(s))
261 Expand(s);
262 return CacheImpl<A>::NumOutputEpsilons(s);
263 }
264
Properties()265 uint64 Properties() const { return Properties(kFstProperties); }
266
267 // Set error if found; return FST impl properties.
Properties(uint64 mask)268 uint64 Properties(uint64 mask) const {
269 if ((mask & kError) && fst_->Properties(kError, false))
270 SetProperties(kError, kError);
271 return FstImpl<Arc>::Properties(mask);
272 }
273
InitArcIterator(StateId s,ArcIteratorData<A> * data)274 void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
275 if (!HasArcs(s))
276 Expand(s);
277 CacheImpl<A>::InitArcIterator(s, data);
278 }
279
280
281 // Find state corresponding to an element. Create new state
282 // if element not found.
FindState(const Element & e)283 StateId FindState(const Element &e) {
284 if (!(mode_ & kFactorArcWeights) && e.weight == Weight::One()) {
285 while (unfactored_.size() <= e.state)
286 unfactored_.push_back(kNoStateId);
287 if (unfactored_[e.state] == kNoStateId) {
288 unfactored_[e.state] = elements_.size();
289 elements_.push_back(e);
290 }
291 return unfactored_[e.state];
292 } else {
293 typename ElementMap::iterator eit = element_map_.find(e);
294 if (eit != element_map_.end()) {
295 return (*eit).second;
296 } else {
297 StateId s = elements_.size();
298 elements_.push_back(e);
299 element_map_.insert(pair<const Element, StateId>(e, s));
300 return s;
301 }
302 }
303 }
304
305 // Computes the outgoing transitions from a state, creating new destination
306 // states as needed.
Expand(StateId s)307 void Expand(StateId s) {
308 Element e = elements_[s];
309 if (e.state != kNoStateId) {
310 for (ArcIterator< Fst<A> > ait(*fst_, e.state);
311 !ait.Done();
312 ait.Next()) {
313 const A &arc = ait.Value();
314 Weight w = Times(e.weight, arc.weight);
315 FactorIterator fit(w);
316 if (!(mode_ & kFactorArcWeights) || fit.Done()) {
317 StateId d = FindState(Element(arc.nextstate, Weight::One()));
318 PushArc(s, Arc(arc.ilabel, arc.olabel, w, d));
319 } else {
320 for (; !fit.Done(); fit.Next()) {
321 const pair<Weight, Weight> &p = fit.Value();
322 StateId d = FindState(Element(arc.nextstate,
323 p.second.Quantize(delta_)));
324 PushArc(s, Arc(arc.ilabel, arc.olabel, p.first, d));
325 }
326 }
327 }
328 }
329
330 if ((mode_ & kFactorFinalWeights) &&
331 ((e.state == kNoStateId) ||
332 (fst_->Final(e.state) != Weight::Zero()))) {
333 Weight w = e.state == kNoStateId
334 ? e.weight
335 : Times(e.weight, fst_->Final(e.state));
336 for (FactorIterator fit(w);
337 !fit.Done();
338 fit.Next()) {
339 const pair<Weight, Weight> &p = fit.Value();
340 StateId d = FindState(Element(kNoStateId,
341 p.second.Quantize(delta_)));
342 PushArc(s, Arc(final_ilabel_, final_olabel_, p.first, d));
343 }
344 }
345 SetArcs(s);
346 }
347
348 private:
349 static const size_t kPrime = 7853;
350
351 // Equality function for Elements, assume weights have been quantized.
352 class ElementEqual {
353 public:
operator()354 bool operator()(const Element &x, const Element &y) const {
355 return x.state == y.state && x.weight == y.weight;
356 }
357 };
358
359 // Hash function for Elements to Fst states.
360 class ElementKey {
361 public:
operator()362 size_t operator()(const Element &x) const {
363 return static_cast<size_t>(x.state * kPrime + x.weight.Hash());
364 }
365 private:
366 };
367
368 typedef unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap;
369
370 const Fst<A> *fst_;
371 float delta_;
372 uint32 mode_; // factoring arc and/or final weights
373 Label final_ilabel_; // ilabel of arc created when factoring final w's
374 Label final_olabel_; // olabel of arc created when factoring final w's
375 vector<Element> elements_; // mapping Fst state to Elements
376 ElementMap element_map_; // mapping Elements to Fst state
377 // mapping between old/new 'StateId' for states that do not need to
378 // be factored when 'mode_' is '0' or 'kFactorFinalWeights'
379 vector<StateId> unfactored_;
380
381 void operator=(const FactorWeightFstImpl<A, F> &); // disallow
382 };
383
384 template <class A, class F> const size_t FactorWeightFstImpl<A, F>::kPrime;
385
386
387 // FactorWeightFst takes as template parameter a FactorIterator as
388 // defined above. The result of weight factoring is a transducer
389 // equivalent to the input whose path weights have been factored
390 // according to the FactorIterator. States and transitions will be
391 // added as necessary. The algorithm is a generalization to arbitrary
392 // weights of the second step of the input epsilon-normalization
393 // algorithm due to Mohri, "Generic epsilon-removal and input
394 // epsilon-normalization algorithms for weighted transducers",
395 // International Journal of Computer Science 13(1): 129-143 (2002).
396 //
397 // This class attaches interface to implementation and handles
398 // reference counting, delegating most methods to ImplToFst.
399 template <class A, class F>
400 class FactorWeightFst : public ImplToFst< FactorWeightFstImpl<A, F> > {
401 public:
402 friend class ArcIterator< FactorWeightFst<A, F> >;
403 friend class StateIterator< FactorWeightFst<A, F> >;
404
405 typedef A Arc;
406 typedef typename A::Weight Weight;
407 typedef typename A::StateId StateId;
408 typedef CacheState<A> State;
409 typedef FactorWeightFstImpl<A, F> Impl;
410
FactorWeightFst(const Fst<A> & fst)411 FactorWeightFst(const Fst<A> &fst)
412 : ImplToFst<Impl>(new Impl(fst, FactorWeightOptions<A>())) {}
413
FactorWeightFst(const Fst<A> & fst,const FactorWeightOptions<A> & opts)414 FactorWeightFst(const Fst<A> &fst, const FactorWeightOptions<A> &opts)
415 : ImplToFst<Impl>(new Impl(fst, opts)) {}
416
417 // See Fst<>::Copy() for doc.
FactorWeightFst(const FactorWeightFst<A,F> & fst,bool copy)418 FactorWeightFst(const FactorWeightFst<A, F> &fst, bool copy)
419 : ImplToFst<Impl>(fst, copy) {}
420
421 // Get a copy of this FactorWeightFst. See Fst<>::Copy() for further doc.
422 virtual FactorWeightFst<A, F> *Copy(bool copy = false) const {
423 return new FactorWeightFst<A, F>(*this, copy);
424 }
425
426 virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
427
InitArcIterator(StateId s,ArcIteratorData<A> * data)428 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
429 GetImpl()->InitArcIterator(s, data);
430 }
431
432 private:
433 // Makes visible to friends.
GetImpl()434 Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
435
436 void operator=(const FactorWeightFst<A, F> &fst); // Disallow
437 };
438
439
440 // Specialization for FactorWeightFst.
441 template<class A, class F>
442 class StateIterator< FactorWeightFst<A, F> >
443 : public CacheStateIterator< FactorWeightFst<A, F> > {
444 public:
StateIterator(const FactorWeightFst<A,F> & fst)445 explicit StateIterator(const FactorWeightFst<A, F> &fst)
446 : CacheStateIterator< FactorWeightFst<A, F> >(fst, fst.GetImpl()) {}
447 };
448
449
450 // Specialization for FactorWeightFst.
451 template <class A, class F>
452 class ArcIterator< FactorWeightFst<A, F> >
453 : public CacheArcIterator< FactorWeightFst<A, F> > {
454 public:
455 typedef typename A::StateId StateId;
456
ArcIterator(const FactorWeightFst<A,F> & fst,StateId s)457 ArcIterator(const FactorWeightFst<A, F> &fst, StateId s)
458 : CacheArcIterator< FactorWeightFst<A, F> >(fst.GetImpl(), s) {
459 if (!fst.GetImpl()->HasArcs(s))
460 fst.GetImpl()->Expand(s);
461 }
462
463 private:
464 DISALLOW_COPY_AND_ASSIGN(ArcIterator);
465 };
466
467 template <class A, class F> inline
InitStateIterator(StateIteratorData<A> * data)468 void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const
469 {
470 data->base = new StateIterator< FactorWeightFst<A, F> >(*this);
471 }
472
473
474 } // namespace fst
475
476 #endif // FST_LIB_FACTOR_WEIGHT_H__
477