1 // factor-weight.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Author: allauzen@cs.nyu.edu (Cyril Allauzen)
16 //
17 // \file
18 // Classes to factor weights in an FST.
19
20 #ifndef FST_LIB_FACTOR_WEIGHT_H__
21 #define FST_LIB_FACTOR_WEIGHT_H__
22
23 #include <algorithm>
24
25 #include <unordered_map>
26 #include <forward_list>
27
28 #include "fst/lib/cache.h"
29 #include "fst/lib/test-properties.h"
30
31 namespace fst {
32
33 struct FactorWeightOptions : CacheOptions {
34 float delta;
35 bool final_only; // only factor final weights when true
36
FactorWeightOptionsFactorWeightOptions37 FactorWeightOptions(const CacheOptions &opts, float d, bool of)
38 : CacheOptions(opts), delta(d), final_only(of) {}
39
40 explicit FactorWeightOptions(float d, bool of = false)
deltaFactorWeightOptions41 : delta(d), final_only(of) {}
42
43 FactorWeightOptions(bool of = false)
deltaFactorWeightOptions44 : delta(kDelta), final_only(of) {}
45 };
46
47
// A factor iterator takes as argument a weight w and returns a
// sequence of pairs of weights (xi, yi) such that the sum of the
// products xi times yi is equal to w. If w is fully factored,
// the iterator should return nothing.
//
// template <class W>
// class FactorIterator {
//  public:
//   FactorIterator(W w);
//   bool Done() const;
//   void Next();
//   pair<W, W> Value() const;
//   void Reset();
// };
62
63
// Trivial factor iterator: every weight is treated as already fully
// factored, so the factor sequence is always empty.
template <class W>
class IdentityFactor {
 public:
  // The weight is ignored; there is nothing to factor.
  IdentityFactor(const W & /*w*/) {}

  // Always true: no factors are ever produced.
  bool Done() const { return true; }

  // No-op: there is no next factor.
  void Next() {}

  // Never meaningful since Done() is always true; returns (One, One).
  std::pair<W, W> Value() const {
    return std::pair<W, W>(W::One(), W::One());
  }

  // No-op: there is no iteration state to rewind.
  void Reset() {}
};
74
75
76 // Factor a StringWeight w as 'ab' where 'a' is a label.
77 template <typename L, StringType S = STRING_LEFT>
78 class StringFactor {
79 public:
StringFactor(const StringWeight<L,S> & w)80 StringFactor(const StringWeight<L, S> &w)
81 : weight_(w), done_(w.Size() <= 1) {}
82
Done()83 bool Done() const { return done_; }
84
Next()85 void Next() { done_ = true; }
86
Value()87 pair< StringWeight<L, S>, StringWeight<L, S> > Value() const {
88 StringWeightIterator<L, S> iter(weight_);
89 StringWeight<L, S> w1(iter.Value());
90 StringWeight<L, S> w2;
91 for (iter.Next(); !iter.Done(); iter.Next())
92 w2.PushBack(iter.Value());
93 return make_pair(w1, w2);
94 }
95
Reset()96 void Reset() { done_ = weight_.Size() <= 1; }
97
98 private:
99 StringWeight<L, S> weight_;
100 bool done_;
101 };
102
103
104 // Factor a GallicWeight using StringFactor.
105 template <class L, class W, StringType S = STRING_LEFT>
106 class GallicFactor {
107 public:
GallicFactor(const GallicWeight<L,W,S> & w)108 GallicFactor(const GallicWeight<L, W, S> &w)
109 : weight_(w), done_(w.Value1().Size() <= 1) {}
110
Done()111 bool Done() const { return done_; }
112
Next()113 void Next() { done_ = true; }
114
Value()115 pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const {
116 StringFactor<L, S> iter(weight_.Value1());
117 GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2());
118 GallicWeight<L, W, S> w2(iter.Value().second, W::One());
119 return make_pair(w1, w2);
120 }
121
Reset()122 void Reset() { done_ = weight_.Value1().Size() <= 1; }
123
124 private:
125 GallicWeight<L, W, S> weight_;
126 bool done_;
127 };
128
129
130 // Implementation class for FactorWeight
131 template <class A, class F>
132 class FactorWeightFstImpl
133 : public CacheImpl<A> {
134 public:
135 using FstImpl<A>::SetType;
136 using FstImpl<A>::SetProperties;
137 using FstImpl<A>::Properties;
138 using FstImpl<A>::SetInputSymbols;
139 using FstImpl<A>::SetOutputSymbols;
140
141 using CacheBaseImpl< CacheState<A> >::HasStart;
142 using CacheBaseImpl< CacheState<A> >::HasFinal;
143 using CacheBaseImpl< CacheState<A> >::HasArcs;
144
145 typedef A Arc;
146 typedef typename A::Label Label;
147 typedef typename A::Weight Weight;
148 typedef typename A::StateId StateId;
149 typedef F FactorIterator;
150
151 struct Element {
ElementElement152 Element() {}
153
ElementElement154 Element(StateId s, Weight w) : state(s), weight(w) {}
155
156 StateId state; // Input state Id
157 Weight weight; // Residual weight
158 };
159
FactorWeightFstImpl(const Fst<A> & fst,const FactorWeightOptions & opts)160 FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions &opts)
161 : CacheImpl<A>(opts), fst_(fst.Copy()), delta_(opts.delta),
162 final_only_(opts.final_only) {
163 SetType("factor-weight");
164 uint64 props = fst.Properties(kFstProperties, false);
165 SetProperties(FactorWeightProperties(props), kCopyProperties);
166
167 SetInputSymbols(fst.InputSymbols());
168 SetOutputSymbols(fst.OutputSymbols());
169 }
170
~FactorWeightFstImpl()171 ~FactorWeightFstImpl() {
172 delete fst_;
173 }
174
Start()175 StateId Start() {
176 if (!HasStart()) {
177 StateId s = fst_->Start();
178 if (s == kNoStateId)
179 return kNoStateId;
180 StateId start = FindState(Element(fst_->Start(), Weight::One()));
181 this->SetStart(start);
182 }
183 return CacheImpl<A>::Start();
184 }
185
Final(StateId s)186 Weight Final(StateId s) {
187 if (!HasFinal(s)) {
188 const Element &e = elements_[s];
189 // TODO: fix so cast is unnecessary
190 Weight w = e.state == kNoStateId
191 ? e.weight
192 : (Weight) Times(e.weight, fst_->Final(e.state));
193 FactorIterator f(w);
194 if (w != Weight::Zero() && f.Done())
195 this->SetFinal(s, w);
196 else
197 this->SetFinal(s, Weight::Zero());
198 }
199 return CacheImpl<A>::Final(s);
200 }
201
NumArcs(StateId s)202 size_t NumArcs(StateId s) {
203 if (!HasArcs(s))
204 Expand(s);
205 return CacheImpl<A>::NumArcs(s);
206 }
207
NumInputEpsilons(StateId s)208 size_t NumInputEpsilons(StateId s) {
209 if (!HasArcs(s))
210 Expand(s);
211 return CacheImpl<A>::NumInputEpsilons(s);
212 }
213
NumOutputEpsilons(StateId s)214 size_t NumOutputEpsilons(StateId s) {
215 if (!HasArcs(s))
216 Expand(s);
217 return CacheImpl<A>::NumOutputEpsilons(s);
218 }
219
InitArcIterator(StateId s,ArcIteratorData<A> * data)220 void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
221 if (!HasArcs(s))
222 Expand(s);
223 CacheImpl<A>::InitArcIterator(s, data);
224 }
225
226
227 // Find state corresponding to an element. Create new state
228 // if element not found.
FindState(const Element & e)229 StateId FindState(const Element &e) {
230 if (final_only_ && e.weight == Weight::One()) {
231 while (unfactored_.size() <= (unsigned int)e.state)
232 unfactored_.push_back(kNoStateId);
233 if (unfactored_[e.state] == kNoStateId) {
234 unfactored_[e.state] = elements_.size();
235 elements_.push_back(e);
236 }
237 return unfactored_[e.state];
238 } else {
239 typename ElementMap::iterator eit = element_map_.find(e);
240 if (eit != element_map_.end()) {
241 return (*eit).second;
242 } else {
243 StateId s = elements_.size();
244 elements_.push_back(e);
245 element_map_.insert(pair<const Element, StateId>(e, s));
246 return s;
247 }
248 }
249 }
250
251 // Computes the outgoing transitions from a state, creating new destination
252 // states as needed.
Expand(StateId s)253 void Expand(StateId s) {
254 Element e = elements_[s];
255 if (e.state != kNoStateId) {
256 for (ArcIterator< Fst<A> > ait(*fst_, e.state);
257 !ait.Done();
258 ait.Next()) {
259 const A &arc = ait.Value();
260 Weight w = Times(e.weight, arc.weight);
261 FactorIterator fit(w);
262 if (final_only_ || fit.Done()) {
263 StateId d = FindState(Element(arc.nextstate, Weight::One()));
264 this->AddArc(s, Arc(arc.ilabel, arc.olabel, w, d));
265 } else {
266 for (; !fit.Done(); fit.Next()) {
267 const pair<Weight, Weight> &p = fit.Value();
268 StateId d = FindState(Element(arc.nextstate,
269 p.second.Quantize(delta_)));
270 this->AddArc(s, Arc(arc.ilabel, arc.olabel, p.first, d));
271 }
272 }
273 }
274 }
275 if ((e.state == kNoStateId) ||
276 (fst_->Final(e.state) != Weight::Zero())) {
277 Weight w = e.state == kNoStateId
278 ? e.weight
279 : Times(e.weight, fst_->Final(e.state));
280 for (FactorIterator fit(w);
281 !fit.Done();
282 fit.Next()) {
283 const pair<Weight, Weight> &p = fit.Value();
284 StateId d = FindState(Element(kNoStateId,
285 p.second.Quantize(delta_)));
286 this->AddArc(s, Arc(0, 0, p.first, d));
287 }
288 }
289 this->SetArcs(s);
290 }
291
292 private:
293 // Equality function for Elements, assume weights have been quantized.
294 class ElementEqual {
295 public:
operator()296 bool operator()(const Element &x, const Element &y) const {
297 return x.state == y.state && x.weight == y.weight;
298 }
299 };
300
301 // Hash function for Elements to Fst states.
302 class ElementKey {
303 public:
operator()304 size_t operator()(const Element &x) const {
305 return static_cast<size_t>(x.state * kPrime + x.weight.Hash());
306 }
307 private:
308 static const int kPrime = 7853;
309 };
310
311 typedef std::unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap;
312
313 const Fst<A> *fst_;
314 float delta_;
315 bool final_only_;
316 vector<Element> elements_; // mapping Fst state to Elements
317 ElementMap element_map_; // mapping Elements to Fst state
318 // mapping between old/new 'StateId' for states that do not need to
319 // be factored when 'final_only_' is true
320 vector<StateId> unfactored_;
321
322 DISALLOW_EVIL_CONSTRUCTORS(FactorWeightFstImpl);
323 };
324
325
326 // FactorWeightFst takes as template parameter a FactorIterator as
327 // defined above. The result of weight factoring is a transducer
328 // equivalent to the input whose path weights have been factored
329 // according to the FactorIterator. States and transitions will be
330 // added as necessary. The algorithm is a generalization to arbitrary
331 // weights of the second step of the input epsilon-normalization
332 // algorithm due to Mohri, "Generic epsilon-removal and input
333 // epsilon-normalization algorithms for weighted transducers",
334 // International Journal of Computer Science 13(1): 129-143 (2002).
335 template <class A, class F>
336 class FactorWeightFst : public Fst<A> {
337 public:
338 friend class ArcIterator< FactorWeightFst<A, F> >;
339 friend class CacheStateIterator< FactorWeightFst<A, F> >;
340 friend class CacheArcIterator< FactorWeightFst<A, F> >;
341
342 typedef A Arc;
343 typedef typename A::Weight Weight;
344 typedef typename A::StateId StateId;
345 typedef CacheState<A> State;
346
FactorWeightFst(const Fst<A> & fst)347 FactorWeightFst(const Fst<A> &fst)
348 : impl_(new FactorWeightFstImpl<A, F>(fst, FactorWeightOptions())) {}
349
FactorWeightFst(const Fst<A> & fst,const FactorWeightOptions & opts)350 FactorWeightFst(const Fst<A> &fst, const FactorWeightOptions &opts)
351 : impl_(new FactorWeightFstImpl<A, F>(fst, opts)) {}
FactorWeightFst(const FactorWeightFst<A,F> & fst)352 FactorWeightFst(const FactorWeightFst<A, F> &fst) : Fst<A>(fst), impl_(fst.impl_) {
353 impl_->IncrRefCount();
354 }
355
~FactorWeightFst()356 virtual ~FactorWeightFst() { if (!impl_->DecrRefCount()) delete impl_; }
357
Start()358 virtual StateId Start() const { return impl_->Start(); }
359
Final(StateId s)360 virtual Weight Final(StateId s) const { return impl_->Final(s); }
361
NumArcs(StateId s)362 virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
363
NumInputEpsilons(StateId s)364 virtual size_t NumInputEpsilons(StateId s) const {
365 return impl_->NumInputEpsilons(s);
366 }
367
NumOutputEpsilons(StateId s)368 virtual size_t NumOutputEpsilons(StateId s) const {
369 return impl_->NumOutputEpsilons(s);
370 }
371
Properties(uint64 mask,bool test)372 virtual uint64 Properties(uint64 mask, bool test) const {
373 if (test) {
374 uint64 known, test = TestProperties(*this, mask, &known);
375 impl_->SetProperties(test, known);
376 return test & mask;
377 } else {
378 return impl_->Properties(mask);
379 }
380 }
381
Type()382 virtual const string& Type() const { return impl_->Type(); }
383
Copy()384 virtual FactorWeightFst<A, F> *Copy() const {
385 return new FactorWeightFst<A, F>(*this);
386 }
387
InputSymbols()388 virtual const SymbolTable* InputSymbols() const {
389 return impl_->InputSymbols();
390 }
391
OutputSymbols()392 virtual const SymbolTable* OutputSymbols() const {
393 return impl_->OutputSymbols();
394 }
395
396 virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
397
InitArcIterator(StateId s,ArcIteratorData<A> * data)398 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
399 impl_->InitArcIterator(s, data);
400 }
401
402 private:
Impl()403 FactorWeightFstImpl<A, F> *Impl() { return impl_; }
404
405 FactorWeightFstImpl<A, F> *impl_;
406
407 void operator=(const FactorWeightFst<A, F> &fst); // Disallow
408 };
409
410
411 // Specialization for FactorWeightFst.
412 template<class A, class F>
413 class StateIterator< FactorWeightFst<A, F> >
414 : public CacheStateIterator< FactorWeightFst<A, F> > {
415 public:
StateIterator(const FactorWeightFst<A,F> & fst)416 explicit StateIterator(const FactorWeightFst<A, F> &fst)
417 : CacheStateIterator< FactorWeightFst<A, F> >(fst) {}
418 };
419
420
421 // Specialization for FactorWeightFst.
422 template <class A, class F>
423 class ArcIterator< FactorWeightFst<A, F> >
424 : public CacheArcIterator< FactorWeightFst<A, F> > {
425 public:
426 typedef typename A::StateId StateId;
427
ArcIterator(const FactorWeightFst<A,F> & fst,StateId s)428 ArcIterator(const FactorWeightFst<A, F> &fst, StateId s)
429 : CacheArcIterator< FactorWeightFst<A, F> >(fst, s) {
430 if (!fst.impl_->HasArcs(s))
431 fst.impl_->Expand(s);
432 }
433
434 private:
435 DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
436 };
437
438 template <class A, class F> inline
InitStateIterator(StateIteratorData<A> * data)439 void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const
440 {
441 data->base = new StateIterator< FactorWeightFst<A, F> >(*this);
442 }
443
444
445 } // namespace fst
446
447 #endif // FST_LIB_FACTOR_WEIGHT_H__
448