// factor-weight.h
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Author: allauzen@cs.nyu.edu (Cyril Allauzen)
//
// \file
// Classes to factor weights in an FST.
19
20 #ifndef FST_LIB_FACTOR_WEIGHT_H__
21 #define FST_LIB_FACTOR_WEIGHT_H__
22
23 #include <algorithm>
24
25 #include <ext/hash_map>
26 using __gnu_cxx::hash_map;
27 #include <ext/slist>
28 using __gnu_cxx::slist;
29
30 #include "fst/lib/cache.h"
31 #include "fst/lib/test-properties.h"
32
33 namespace fst {
34
35 struct FactorWeightOptions : CacheOptions {
36 float delta;
37 bool final_only; // only factor final weights when true
38
FactorWeightOptionsFactorWeightOptions39 FactorWeightOptions(const CacheOptions &opts, float d, bool of)
40 : CacheOptions(opts), delta(d), final_only(of) {}
41
42 explicit FactorWeightOptions(float d, bool of = false)
deltaFactorWeightOptions43 : delta(d), final_only(of) {}
44
45 FactorWeightOptions(bool of = false)
deltaFactorWeightOptions46 : delta(kDelta), final_only(of) {}
47 };
48
49
// A factor iterator takes a weight w as argument and enumerates a
// sequence of weight pairs (x_i, y_i) such that the sum over i of the
// products x_i * y_i equals w. If w is already fully factored, the
// iterator yields nothing (Done() is true immediately).
//
// template <class W>
// class FactorIterator {
//  public:
//   FactorIterator(W w);
//   bool Done() const;
//   void Next();
//   pair<W, W> Value() const;
//   void Reset();
// };
64
65
// Trivial factorization: every weight is treated as already fully
// factored, so the iterator is empty from the start.
template <class W>
class IdentityFactor {
 public:
  IdentityFactor(const W & /*w*/) {}

  // Always true: there is nothing to enumerate.
  bool Done() const { return true; }

  // No-op; the sequence is empty.
  void Next() {}

  // Unused by callers that honor Done(); yields (One, One).
  pair<W, W> Value() const {
    const W one = W::One();
    return make_pair(one, one);
  }

  // No-op; there is no position to rewind.
  void Reset() {}
};
76
77
78 // Factor a StringWeight w as 'ab' where 'a' is a label.
79 template <typename L, StringType S = STRING_LEFT>
80 class StringFactor {
81 public:
StringFactor(const StringWeight<L,S> & w)82 StringFactor(const StringWeight<L, S> &w)
83 : weight_(w), done_(w.Size() <= 1) {}
84
Done()85 bool Done() const { return done_; }
86
Next()87 void Next() { done_ = true; }
88
Value()89 pair< StringWeight<L, S>, StringWeight<L, S> > Value() const {
90 StringWeightIterator<L, S> iter(weight_);
91 StringWeight<L, S> w1(iter.Value());
92 StringWeight<L, S> w2;
93 for (iter.Next(); !iter.Done(); iter.Next())
94 w2.PushBack(iter.Value());
95 return make_pair(w1, w2);
96 }
97
Reset()98 void Reset() { done_ = weight_.Size() <= 1; }
99
100 private:
101 StringWeight<L, S> weight_;
102 bool done_;
103 };
104
105
106 // Factor a GallicWeight using StringFactor.
107 template <class L, class W, StringType S = STRING_LEFT>
108 class GallicFactor {
109 public:
GallicFactor(const GallicWeight<L,W,S> & w)110 GallicFactor(const GallicWeight<L, W, S> &w)
111 : weight_(w), done_(w.Value1().Size() <= 1) {}
112
Done()113 bool Done() const { return done_; }
114
Next()115 void Next() { done_ = true; }
116
Value()117 pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const {
118 StringFactor<L, S> iter(weight_.Value1());
119 GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2());
120 GallicWeight<L, W, S> w2(iter.Value().second, W::One());
121 return make_pair(w1, w2);
122 }
123
Reset()124 void Reset() { done_ = weight_.Value1().Size() <= 1; }
125
126 private:
127 GallicWeight<L, W, S> weight_;
128 bool done_;
129 };
130
131
132 // Implementation class for FactorWeight
133 template <class A, class F>
134 class FactorWeightFstImpl
135 : public CacheImpl<A> {
136 public:
137 using FstImpl<A>::SetType;
138 using FstImpl<A>::SetProperties;
139 using FstImpl<A>::Properties;
140 using FstImpl<A>::SetInputSymbols;
141 using FstImpl<A>::SetOutputSymbols;
142
143 using CacheBaseImpl< CacheState<A> >::HasStart;
144 using CacheBaseImpl< CacheState<A> >::HasFinal;
145 using CacheBaseImpl< CacheState<A> >::HasArcs;
146
147 typedef A Arc;
148 typedef typename A::Label Label;
149 typedef typename A::Weight Weight;
150 typedef typename A::StateId StateId;
151 typedef F FactorIterator;
152
153 struct Element {
ElementElement154 Element() {}
155
ElementElement156 Element(StateId s, Weight w) : state(s), weight(w) {}
157
158 StateId state; // Input state Id
159 Weight weight; // Residual weight
160 };
161
FactorWeightFstImpl(const Fst<A> & fst,const FactorWeightOptions & opts)162 FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions &opts)
163 : CacheImpl<A>(opts), fst_(fst.Copy()), delta_(opts.delta),
164 final_only_(opts.final_only) {
165 SetType("factor-weight");
166 uint64 props = fst.Properties(kFstProperties, false);
167 SetProperties(FactorWeightProperties(props), kCopyProperties);
168
169 SetInputSymbols(fst.InputSymbols());
170 SetOutputSymbols(fst.OutputSymbols());
171 }
172
~FactorWeightFstImpl()173 ~FactorWeightFstImpl() {
174 delete fst_;
175 }
176
Start()177 StateId Start() {
178 if (!HasStart()) {
179 StateId s = fst_->Start();
180 if (s == kNoStateId)
181 return kNoStateId;
182 StateId start = FindState(Element(fst_->Start(), Weight::One()));
183 SetStart(start);
184 }
185 return CacheImpl<A>::Start();
186 }
187
Final(StateId s)188 Weight Final(StateId s) {
189 if (!HasFinal(s)) {
190 const Element &e = elements_[s];
191 // TODO: fix so cast is unnecessary
192 Weight w = e.state == kNoStateId
193 ? e.weight
194 : (Weight) Times(e.weight, fst_->Final(e.state));
195 FactorIterator f(w);
196 if (w != Weight::Zero() && f.Done())
197 SetFinal(s, w);
198 else
199 SetFinal(s, Weight::Zero());
200 }
201 return CacheImpl<A>::Final(s);
202 }
203
NumArcs(StateId s)204 size_t NumArcs(StateId s) {
205 if (!HasArcs(s))
206 Expand(s);
207 return CacheImpl<A>::NumArcs(s);
208 }
209
NumInputEpsilons(StateId s)210 size_t NumInputEpsilons(StateId s) {
211 if (!HasArcs(s))
212 Expand(s);
213 return CacheImpl<A>::NumInputEpsilons(s);
214 }
215
NumOutputEpsilons(StateId s)216 size_t NumOutputEpsilons(StateId s) {
217 if (!HasArcs(s))
218 Expand(s);
219 return CacheImpl<A>::NumOutputEpsilons(s);
220 }
221
InitArcIterator(StateId s,ArcIteratorData<A> * data)222 void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
223 if (!HasArcs(s))
224 Expand(s);
225 CacheImpl<A>::InitArcIterator(s, data);
226 }
227
228
229 // Find state corresponding to an element. Create new state
230 // if element not found.
FindState(const Element & e)231 StateId FindState(const Element &e) {
232 if (final_only_ && e.weight == Weight::One()) {
233 while (unfactored_.size() <= (unsigned int)e.state)
234 unfactored_.push_back(kNoStateId);
235 if (unfactored_[e.state] == kNoStateId) {
236 unfactored_[e.state] = elements_.size();
237 elements_.push_back(e);
238 }
239 return unfactored_[e.state];
240 } else {
241 typename ElementMap::iterator eit = element_map_.find(e);
242 if (eit != element_map_.end()) {
243 return (*eit).second;
244 } else {
245 StateId s = elements_.size();
246 elements_.push_back(e);
247 element_map_.insert(pair<const Element, StateId>(e, s));
248 return s;
249 }
250 }
251 }
252
253 // Computes the outgoing transitions from a state, creating new destination
254 // states as needed.
Expand(StateId s)255 void Expand(StateId s) {
256 Element e = elements_[s];
257 if (e.state != kNoStateId) {
258 for (ArcIterator< Fst<A> > ait(*fst_, e.state);
259 !ait.Done();
260 ait.Next()) {
261 const A &arc = ait.Value();
262 Weight w = Times(e.weight, arc.weight);
263 FactorIterator fit(w);
264 if (final_only_ || fit.Done()) {
265 StateId d = FindState(Element(arc.nextstate, Weight::One()));
266 AddArc(s, Arc(arc.ilabel, arc.olabel, w, d));
267 } else {
268 for (; !fit.Done(); fit.Next()) {
269 const pair<Weight, Weight> &p = fit.Value();
270 StateId d = FindState(Element(arc.nextstate,
271 p.second.Quantize(delta_)));
272 AddArc(s, Arc(arc.ilabel, arc.olabel, p.first, d));
273 }
274 }
275 }
276 }
277 if ((e.state == kNoStateId) ||
278 (fst_->Final(e.state) != Weight::Zero())) {
279 Weight w = e.state == kNoStateId
280 ? e.weight
281 : Times(e.weight, fst_->Final(e.state));
282 for (FactorIterator fit(w);
283 !fit.Done();
284 fit.Next()) {
285 const pair<Weight, Weight> &p = fit.Value();
286 StateId d = FindState(Element(kNoStateId,
287 p.second.Quantize(delta_)));
288 AddArc(s, Arc(0, 0, p.first, d));
289 }
290 }
291 SetArcs(s);
292 }
293
294 private:
295 // Equality function for Elements, assume weights have been quantized.
296 class ElementEqual {
297 public:
operator()298 bool operator()(const Element &x, const Element &y) const {
299 return x.state == y.state && x.weight == y.weight;
300 }
301 };
302
303 // Hash function for Elements to Fst states.
304 class ElementKey {
305 public:
operator()306 size_t operator()(const Element &x) const {
307 return static_cast<size_t>(x.state * kPrime + x.weight.Hash());
308 }
309 private:
310 static const int kPrime = 7853;
311 };
312
313 typedef hash_map<Element, StateId, ElementKey, ElementEqual> ElementMap;
314
315 const Fst<A> *fst_;
316 float delta_;
317 bool final_only_;
318 vector<Element> elements_; // mapping Fst state to Elements
319 ElementMap element_map_; // mapping Elements to Fst state
320 // mapping between old/new 'StateId' for states that do not need to
321 // be factored when 'final_only_' is true
322 vector<StateId> unfactored_;
323
324 DISALLOW_EVIL_CONSTRUCTORS(FactorWeightFstImpl);
325 };
326
327
328 // FactorWeightFst takes as template parameter a FactorIterator as
329 // defined above. The result of weight factoring is a transducer
330 // equivalent to the input whose path weights have been factored
331 // according to the FactorIterator. States and transitions will be
332 // added as necessary. The algorithm is a generalization to arbitrary
333 // weights of the second step of the input epsilon-normalization
334 // algorithm due to Mohri, "Generic epsilon-removal and input
335 // epsilon-normalization algorithms for weighted transducers",
336 // International Journal of Computer Science 13(1): 129-143 (2002).
337 template <class A, class F>
338 class FactorWeightFst : public Fst<A> {
339 public:
340 friend class ArcIterator< FactorWeightFst<A, F> >;
341 friend class CacheStateIterator< FactorWeightFst<A, F> >;
342 friend class CacheArcIterator< FactorWeightFst<A, F> >;
343
344 typedef A Arc;
345 typedef typename A::Weight Weight;
346 typedef typename A::StateId StateId;
347 typedef CacheState<A> State;
348
FactorWeightFst(const Fst<A> & fst)349 FactorWeightFst(const Fst<A> &fst)
350 : impl_(new FactorWeightFstImpl<A, F>(fst, FactorWeightOptions())) {}
351
FactorWeightFst(const Fst<A> & fst,const FactorWeightOptions & opts)352 FactorWeightFst(const Fst<A> &fst, const FactorWeightOptions &opts)
353 : impl_(new FactorWeightFstImpl<A, F>(fst, opts)) {}
FactorWeightFst(const FactorWeightFst<A,F> & fst)354 FactorWeightFst(const FactorWeightFst<A, F> &fst) : Fst<A>(fst), impl_(fst.impl_) {
355 impl_->IncrRefCount();
356 }
357
~FactorWeightFst()358 virtual ~FactorWeightFst() { if (!impl_->DecrRefCount()) delete impl_; }
359
Start()360 virtual StateId Start() const { return impl_->Start(); }
361
Final(StateId s)362 virtual Weight Final(StateId s) const { return impl_->Final(s); }
363
NumArcs(StateId s)364 virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
365
NumInputEpsilons(StateId s)366 virtual size_t NumInputEpsilons(StateId s) const {
367 return impl_->NumInputEpsilons(s);
368 }
369
NumOutputEpsilons(StateId s)370 virtual size_t NumOutputEpsilons(StateId s) const {
371 return impl_->NumOutputEpsilons(s);
372 }
373
Properties(uint64 mask,bool test)374 virtual uint64 Properties(uint64 mask, bool test) const {
375 if (test) {
376 uint64 known, test = TestProperties(*this, mask, &known);
377 impl_->SetProperties(test, known);
378 return test & mask;
379 } else {
380 return impl_->Properties(mask);
381 }
382 }
383
Type()384 virtual const string& Type() const { return impl_->Type(); }
385
Copy()386 virtual FactorWeightFst<A, F> *Copy() const {
387 return new FactorWeightFst<A, F>(*this);
388 }
389
InputSymbols()390 virtual const SymbolTable* InputSymbols() const {
391 return impl_->InputSymbols();
392 }
393
OutputSymbols()394 virtual const SymbolTable* OutputSymbols() const {
395 return impl_->OutputSymbols();
396 }
397
398 virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
399
InitArcIterator(StateId s,ArcIteratorData<A> * data)400 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
401 impl_->InitArcIterator(s, data);
402 }
403
404 private:
Impl()405 FactorWeightFstImpl<A, F> *Impl() { return impl_; }
406
407 FactorWeightFstImpl<A, F> *impl_;
408
409 void operator=(const FactorWeightFst<A, F> &fst); // Disallow
410 };
411
412
413 // Specialization for FactorWeightFst.
414 template<class A, class F>
415 class StateIterator< FactorWeightFst<A, F> >
416 : public CacheStateIterator< FactorWeightFst<A, F> > {
417 public:
StateIterator(const FactorWeightFst<A,F> & fst)418 explicit StateIterator(const FactorWeightFst<A, F> &fst)
419 : CacheStateIterator< FactorWeightFst<A, F> >(fst) {}
420 };
421
422
423 // Specialization for FactorWeightFst.
424 template <class A, class F>
425 class ArcIterator< FactorWeightFst<A, F> >
426 : public CacheArcIterator< FactorWeightFst<A, F> > {
427 public:
428 typedef typename A::StateId StateId;
429
ArcIterator(const FactorWeightFst<A,F> & fst,StateId s)430 ArcIterator(const FactorWeightFst<A, F> &fst, StateId s)
431 : CacheArcIterator< FactorWeightFst<A, F> >(fst, s) {
432 if (!fst.impl_->HasArcs(s))
433 fst.impl_->Expand(s);
434 }
435
436 private:
437 DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
438 };
439
440 template <class A, class F> inline
InitStateIterator(StateIteratorData<A> * data)441 void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const
442 {
443 data->base = new StateIterator< FactorWeightFst<A, F> >(*this);
444 }
445
446
447 } // namespace fst
448
449 #endif // FST_LIB_FACTOR_WEIGHT_H__
450