1 // float-weight.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Float weight set and associated semiring operation definitions.
20 //
21
22 #ifndef FST_LIB_FLOAT_WEIGHT_H__
23 #define FST_LIB_FLOAT_WEIGHT_H__
24
#include <climits>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <sstream>
#include <string>

#include <fst/util.h>
#include <fst/weight.h>
32
33
34 namespace fst {
35
// Numeric limits for the floating-point types used as weight values.
//
// Every constant is initialized directly from std::numeric_limits rather
// than from another member of this template. Static members of a class
// template have unordered dynamic initialization, so defining kNegInfinity
// as -kPosInfinity could let static initializers in other translation units
// read kNegInfinity before it is set. With C++11 and later these
// initializers are constant expressions, so all three members receive
// constant initialization.
template <class T>
class FloatLimits {
 public:
  static const T kPosInfinity;  // +infinity
  static const T kNegInfinity;  // -infinity
  static const T kNumberBad;    // quiet NaN (the value behind NoWeight())
};

template <class T>
const T FloatLimits<T>::kPosInfinity = std::numeric_limits<T>::infinity();

template <class T>
const T FloatLimits<T>::kNegInfinity = -std::numeric_limits<T>::infinity();

template <class T>
const T FloatLimits<T>::kNumberBad = std::numeric_limits<T>::quiet_NaN();
53
54 // weight class to be templated on floating-points types
55 template <class T = float>
56 class FloatWeightTpl {
57 public:
FloatWeightTpl()58 FloatWeightTpl() {}
59
FloatWeightTpl(T f)60 FloatWeightTpl(T f) : value_(f) {}
61
FloatWeightTpl(const FloatWeightTpl<T> & w)62 FloatWeightTpl(const FloatWeightTpl<T> &w) : value_(w.value_) {}
63
64 FloatWeightTpl<T> &operator=(const FloatWeightTpl<T> &w) {
65 value_ = w.value_;
66 return *this;
67 }
68
Read(istream & strm)69 istream &Read(istream &strm) {
70 return ReadType(strm, &value_);
71 }
72
Write(ostream & strm)73 ostream &Write(ostream &strm) const {
74 return WriteType(strm, value_);
75 }
76
Hash()77 size_t Hash() const {
78 union {
79 T f;
80 size_t s;
81 } u;
82 u.s = 0;
83 u.f = value_;
84 return u.s;
85 }
86
Value()87 const T &Value() const { return value_; }
88
89 protected:
SetValue(const T & f)90 void SetValue(const T &f) { value_ = f; }
91
GetPrecisionString()92 inline static string GetPrecisionString() {
93 int64 size = sizeof(T);
94 if (size == sizeof(float)) return "";
95 size *= CHAR_BIT;
96
97 string result;
98 Int64ToStr(size, &result);
99 return result;
100 }
101
102 private:
103 T value_;
104 };
105
106 // Single-precision float weight
107 typedef FloatWeightTpl<float> FloatWeight;
108
109 template <class T>
110 inline bool operator==(const FloatWeightTpl<T> &w1,
111 const FloatWeightTpl<T> &w2) {
112 // Volatile qualifier thwarts over-aggressive compiler optimizations
113 // that lead to problems esp. with NaturalLess().
114 volatile T v1 = w1.Value();
115 volatile T v2 = w2.Value();
116 return v1 == v2;
117 }
118
119 inline bool operator==(const FloatWeightTpl<double> &w1,
120 const FloatWeightTpl<double> &w2) {
121 return operator==<double>(w1, w2);
122 }
123
124 inline bool operator==(const FloatWeightTpl<float> &w1,
125 const FloatWeightTpl<float> &w2) {
126 return operator==<float>(w1, w2);
127 }
128
129 template <class T>
130 inline bool operator!=(const FloatWeightTpl<T> &w1,
131 const FloatWeightTpl<T> &w2) {
132 return !(w1 == w2);
133 }
134
135 inline bool operator!=(const FloatWeightTpl<double> &w1,
136 const FloatWeightTpl<double> &w2) {
137 return operator!=<double>(w1, w2);
138 }
139
140 inline bool operator!=(const FloatWeightTpl<float> &w1,
141 const FloatWeightTpl<float> &w2) {
142 return operator!=<float>(w1, w2);
143 }
144
145 template <class T>
146 inline bool ApproxEqual(const FloatWeightTpl<T> &w1,
147 const FloatWeightTpl<T> &w2,
148 float delta = kDelta) {
149 return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta;
150 }
151
152 template <class T>
153 inline ostream &operator<<(ostream &strm, const FloatWeightTpl<T> &w) {
154 if (w.Value() == FloatLimits<T>::kPosInfinity)
155 return strm << "Infinity";
156 else if (w.Value() == FloatLimits<T>::kNegInfinity)
157 return strm << "-Infinity";
158 else if (w.Value() != w.Value()) // Fails for NaN
159 return strm << "BadNumber";
160 else
161 return strm << w.Value();
162 }
163
164 template <class T>
165 inline istream &operator>>(istream &strm, FloatWeightTpl<T> &w) {
166 string s;
167 strm >> s;
168 if (s == "Infinity") {
169 w = FloatWeightTpl<T>(FloatLimits<T>::kPosInfinity);
170 } else if (s == "-Infinity") {
171 w = FloatWeightTpl<T>(FloatLimits<T>::kNegInfinity);
172 } else {
173 char *p;
174 T f = strtod(s.c_str(), &p);
175 if (p < s.c_str() + s.size())
176 strm.clear(std::ios::badbit);
177 else
178 w = FloatWeightTpl<T>(f);
179 }
180 return strm;
181 }
182
183
184 // Tropical semiring: (min, +, inf, 0)
185 template <class T>
186 class TropicalWeightTpl : public FloatWeightTpl<T> {
187 public:
188 using FloatWeightTpl<T>::Value;
189
190 typedef TropicalWeightTpl<T> ReverseWeight;
191
TropicalWeightTpl()192 TropicalWeightTpl() : FloatWeightTpl<T>() {}
193
TropicalWeightTpl(T f)194 TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {}
195
TropicalWeightTpl(const TropicalWeightTpl<T> & w)196 TropicalWeightTpl(const TropicalWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
197
Zero()198 static const TropicalWeightTpl<T> Zero() {
199 return TropicalWeightTpl<T>(FloatLimits<T>::kPosInfinity); }
200
One()201 static const TropicalWeightTpl<T> One() {
202 return TropicalWeightTpl<T>(0.0F); }
203
NoWeight()204 static const TropicalWeightTpl<T> NoWeight() {
205 return TropicalWeightTpl<T>(FloatLimits<T>::kNumberBad); }
206
Type()207 static const string &Type() {
208 static const string type = "tropical" +
209 FloatWeightTpl<T>::GetPrecisionString();
210 return type;
211 }
212
Member()213 bool Member() const {
214 // First part fails for IEEE NaN
215 return Value() == Value() && Value() != FloatLimits<T>::kNegInfinity;
216 }
217
218 TropicalWeightTpl<T> Quantize(float delta = kDelta) const {
219 if (Value() == FloatLimits<T>::kNegInfinity ||
220 Value() == FloatLimits<T>::kPosInfinity ||
221 Value() != Value())
222 return *this;
223 else
224 return TropicalWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
225 }
226
Reverse()227 TropicalWeightTpl<T> Reverse() const { return *this; }
228
Properties()229 static uint64 Properties() {
230 return kLeftSemiring | kRightSemiring | kCommutative |
231 kPath | kIdempotent;
232 }
233 };
234
235 // Single precision tropical weight
236 typedef TropicalWeightTpl<float> TropicalWeight;
237
238 template <class T>
Plus(const TropicalWeightTpl<T> & w1,const TropicalWeightTpl<T> & w2)239 inline TropicalWeightTpl<T> Plus(const TropicalWeightTpl<T> &w1,
240 const TropicalWeightTpl<T> &w2) {
241 if (!w1.Member() || !w2.Member())
242 return TropicalWeightTpl<T>::NoWeight();
243 return w1.Value() < w2.Value() ? w1 : w2;
244 }
245
Plus(const TropicalWeightTpl<float> & w1,const TropicalWeightTpl<float> & w2)246 inline TropicalWeightTpl<float> Plus(const TropicalWeightTpl<float> &w1,
247 const TropicalWeightTpl<float> &w2) {
248 return Plus<float>(w1, w2);
249 }
250
Plus(const TropicalWeightTpl<double> & w1,const TropicalWeightTpl<double> & w2)251 inline TropicalWeightTpl<double> Plus(const TropicalWeightTpl<double> &w1,
252 const TropicalWeightTpl<double> &w2) {
253 return Plus<double>(w1, w2);
254 }
255
256 template <class T>
Times(const TropicalWeightTpl<T> & w1,const TropicalWeightTpl<T> & w2)257 inline TropicalWeightTpl<T> Times(const TropicalWeightTpl<T> &w1,
258 const TropicalWeightTpl<T> &w2) {
259 if (!w1.Member() || !w2.Member())
260 return TropicalWeightTpl<T>::NoWeight();
261 T f1 = w1.Value(), f2 = w2.Value();
262 if (f1 == FloatLimits<T>::kPosInfinity)
263 return w1;
264 else if (f2 == FloatLimits<T>::kPosInfinity)
265 return w2;
266 else
267 return TropicalWeightTpl<T>(f1 + f2);
268 }
269
Times(const TropicalWeightTpl<float> & w1,const TropicalWeightTpl<float> & w2)270 inline TropicalWeightTpl<float> Times(const TropicalWeightTpl<float> &w1,
271 const TropicalWeightTpl<float> &w2) {
272 return Times<float>(w1, w2);
273 }
274
Times(const TropicalWeightTpl<double> & w1,const TropicalWeightTpl<double> & w2)275 inline TropicalWeightTpl<double> Times(const TropicalWeightTpl<double> &w1,
276 const TropicalWeightTpl<double> &w2) {
277 return Times<double>(w1, w2);
278 }
279
280 template <class T>
281 inline TropicalWeightTpl<T> Divide(const TropicalWeightTpl<T> &w1,
282 const TropicalWeightTpl<T> &w2,
283 DivideType typ = DIVIDE_ANY) {
284 if (!w1.Member() || !w2.Member())
285 return TropicalWeightTpl<T>::NoWeight();
286 T f1 = w1.Value(), f2 = w2.Value();
287 if (f2 == FloatLimits<T>::kPosInfinity)
288 return FloatLimits<T>::kNumberBad;
289 else if (f1 == FloatLimits<T>::kPosInfinity)
290 return FloatLimits<T>::kPosInfinity;
291 else
292 return TropicalWeightTpl<T>(f1 - f2);
293 }
294
295 inline TropicalWeightTpl<float> Divide(const TropicalWeightTpl<float> &w1,
296 const TropicalWeightTpl<float> &w2,
297 DivideType typ = DIVIDE_ANY) {
298 return Divide<float>(w1, w2, typ);
299 }
300
301 inline TropicalWeightTpl<double> Divide(const TropicalWeightTpl<double> &w1,
302 const TropicalWeightTpl<double> &w2,
303 DivideType typ = DIVIDE_ANY) {
304 return Divide<double>(w1, w2, typ);
305 }
306
307
// Log semiring: (-log(e^-x + e^-y), +, inf, 0)
309 template <class T>
310 class LogWeightTpl : public FloatWeightTpl<T> {
311 public:
312 using FloatWeightTpl<T>::Value;
313
314 typedef LogWeightTpl ReverseWeight;
315
LogWeightTpl()316 LogWeightTpl() : FloatWeightTpl<T>() {}
317
LogWeightTpl(T f)318 LogWeightTpl(T f) : FloatWeightTpl<T>(f) {}
319
LogWeightTpl(const LogWeightTpl<T> & w)320 LogWeightTpl(const LogWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
321
Zero()322 static const LogWeightTpl<T> Zero() {
323 return LogWeightTpl<T>(FloatLimits<T>::kPosInfinity);
324 }
325
One()326 static const LogWeightTpl<T> One() {
327 return LogWeightTpl<T>(0.0F);
328 }
329
NoWeight()330 static const LogWeightTpl<T> NoWeight() {
331 return LogWeightTpl<T>(FloatLimits<T>::kNumberBad); }
332
Type()333 static const string &Type() {
334 static const string type = "log" + FloatWeightTpl<T>::GetPrecisionString();
335 return type;
336 }
337
Member()338 bool Member() const {
339 // First part fails for IEEE NaN
340 return Value() == Value() && Value() != FloatLimits<T>::kNegInfinity;
341 }
342
343 LogWeightTpl<T> Quantize(float delta = kDelta) const {
344 if (Value() == FloatLimits<T>::kNegInfinity ||
345 Value() == FloatLimits<T>::kPosInfinity ||
346 Value() != Value())
347 return *this;
348 else
349 return LogWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
350 }
351
Reverse()352 LogWeightTpl<T> Reverse() const { return *this; }
353
Properties()354 static uint64 Properties() {
355 return kLeftSemiring | kRightSemiring | kCommutative;
356 }
357 };
358
359 // Single-precision log weight
360 typedef LogWeightTpl<float> LogWeight;
361 // Double-precision log weight
362 typedef LogWeightTpl<double> Log64Weight;
363
// Returns log(1 + exp(-x)), the correction term used by log-semiring Plus
// (which calls it only with x >= 0).
//
// log1p() is used instead of log(1 + ...). For large x, exp(-x) is tiny,
// and adding it to 1 first discards most of its significant digits.
template <class T>
inline T LogExp(T x) { return log1p(exp(-x)); }
366
367 template <class T>
Plus(const LogWeightTpl<T> & w1,const LogWeightTpl<T> & w2)368 inline LogWeightTpl<T> Plus(const LogWeightTpl<T> &w1,
369 const LogWeightTpl<T> &w2) {
370 T f1 = w1.Value(), f2 = w2.Value();
371 if (f1 == FloatLimits<T>::kPosInfinity)
372 return w2;
373 else if (f2 == FloatLimits<T>::kPosInfinity)
374 return w1;
375 else if (f1 > f2)
376 return LogWeightTpl<T>(f2 - LogExp(f1 - f2));
377 else
378 return LogWeightTpl<T>(f1 - LogExp(f2 - f1));
379 }
380
Plus(const LogWeightTpl<float> & w1,const LogWeightTpl<float> & w2)381 inline LogWeightTpl<float> Plus(const LogWeightTpl<float> &w1,
382 const LogWeightTpl<float> &w2) {
383 return Plus<float>(w1, w2);
384 }
385
Plus(const LogWeightTpl<double> & w1,const LogWeightTpl<double> & w2)386 inline LogWeightTpl<double> Plus(const LogWeightTpl<double> &w1,
387 const LogWeightTpl<double> &w2) {
388 return Plus<double>(w1, w2);
389 }
390
391 template <class T>
Times(const LogWeightTpl<T> & w1,const LogWeightTpl<T> & w2)392 inline LogWeightTpl<T> Times(const LogWeightTpl<T> &w1,
393 const LogWeightTpl<T> &w2) {
394 if (!w1.Member() || !w2.Member())
395 return LogWeightTpl<T>::NoWeight();
396 T f1 = w1.Value(), f2 = w2.Value();
397 if (f1 == FloatLimits<T>::kPosInfinity)
398 return w1;
399 else if (f2 == FloatLimits<T>::kPosInfinity)
400 return w2;
401 else
402 return LogWeightTpl<T>(f1 + f2);
403 }
404
Times(const LogWeightTpl<float> & w1,const LogWeightTpl<float> & w2)405 inline LogWeightTpl<float> Times(const LogWeightTpl<float> &w1,
406 const LogWeightTpl<float> &w2) {
407 return Times<float>(w1, w2);
408 }
409
Times(const LogWeightTpl<double> & w1,const LogWeightTpl<double> & w2)410 inline LogWeightTpl<double> Times(const LogWeightTpl<double> &w1,
411 const LogWeightTpl<double> &w2) {
412 return Times<double>(w1, w2);
413 }
414
415 template <class T>
416 inline LogWeightTpl<T> Divide(const LogWeightTpl<T> &w1,
417 const LogWeightTpl<T> &w2,
418 DivideType typ = DIVIDE_ANY) {
419 if (!w1.Member() || !w2.Member())
420 return LogWeightTpl<T>::NoWeight();
421 T f1 = w1.Value(), f2 = w2.Value();
422 if (f2 == FloatLimits<T>::kPosInfinity)
423 return FloatLimits<T>::kNumberBad;
424 else if (f1 == FloatLimits<T>::kPosInfinity)
425 return FloatLimits<T>::kPosInfinity;
426 else
427 return LogWeightTpl<T>(f1 - f2);
428 }
429
430 inline LogWeightTpl<float> Divide(const LogWeightTpl<float> &w1,
431 const LogWeightTpl<float> &w2,
432 DivideType typ = DIVIDE_ANY) {
433 return Divide<float>(w1, w2, typ);
434 }
435
436 inline LogWeightTpl<double> Divide(const LogWeightTpl<double> &w1,
437 const LogWeightTpl<double> &w2,
438 DivideType typ = DIVIDE_ANY) {
439 return Divide<double>(w1, w2, typ);
440 }
441
442 // MinMax semiring: (min, max, inf, -inf)
443 template <class T>
444 class MinMaxWeightTpl : public FloatWeightTpl<T> {
445 public:
446 using FloatWeightTpl<T>::Value;
447
448 typedef MinMaxWeightTpl<T> ReverseWeight;
449
MinMaxWeightTpl()450 MinMaxWeightTpl() : FloatWeightTpl<T>() {}
451
MinMaxWeightTpl(T f)452 MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {}
453
MinMaxWeightTpl(const MinMaxWeightTpl<T> & w)454 MinMaxWeightTpl(const MinMaxWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
455
Zero()456 static const MinMaxWeightTpl<T> Zero() {
457 return MinMaxWeightTpl<T>(FloatLimits<T>::kPosInfinity);
458 }
459
One()460 static const MinMaxWeightTpl<T> One() {
461 return MinMaxWeightTpl<T>(FloatLimits<T>::kNegInfinity);
462 }
463
NoWeight()464 static const MinMaxWeightTpl<T> NoWeight() {
465 return MinMaxWeightTpl<T>(FloatLimits<T>::kNumberBad); }
466
Type()467 static const string &Type() {
468 static const string type = "minmax" +
469 FloatWeightTpl<T>::GetPrecisionString();
470 return type;
471 }
472
Member()473 bool Member() const {
474 // Fails for IEEE NaN
475 return Value() == Value();
476 }
477
478 MinMaxWeightTpl<T> Quantize(float delta = kDelta) const {
479 // If one of infinities, or a NaN
480 if (Value() == FloatLimits<T>::kNegInfinity ||
481 Value() == FloatLimits<T>::kPosInfinity ||
482 Value() != Value())
483 return *this;
484 else
485 return MinMaxWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
486 }
487
Reverse()488 MinMaxWeightTpl<T> Reverse() const { return *this; }
489
Properties()490 static uint64 Properties() {
491 return kLeftSemiring | kRightSemiring | kCommutative | kIdempotent | kPath;
492 }
493 };
494
495 // Single-precision min-max weight
496 typedef MinMaxWeightTpl<float> MinMaxWeight;
497
498 // Min
499 template <class T>
Plus(const MinMaxWeightTpl<T> & w1,const MinMaxWeightTpl<T> & w2)500 inline MinMaxWeightTpl<T> Plus(
501 const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) {
502 if (!w1.Member() || !w2.Member())
503 return MinMaxWeightTpl<T>::NoWeight();
504 return w1.Value() < w2.Value() ? w1 : w2;
505 }
506
Plus(const MinMaxWeightTpl<float> & w1,const MinMaxWeightTpl<float> & w2)507 inline MinMaxWeightTpl<float> Plus(
508 const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) {
509 return Plus<float>(w1, w2);
510 }
511
Plus(const MinMaxWeightTpl<double> & w1,const MinMaxWeightTpl<double> & w2)512 inline MinMaxWeightTpl<double> Plus(
513 const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) {
514 return Plus<double>(w1, w2);
515 }
516
517 // Max
518 template <class T>
Times(const MinMaxWeightTpl<T> & w1,const MinMaxWeightTpl<T> & w2)519 inline MinMaxWeightTpl<T> Times(
520 const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) {
521 if (!w1.Member() || !w2.Member())
522 return MinMaxWeightTpl<T>::NoWeight();
523 return w1.Value() >= w2.Value() ? w1 : w2;
524 }
525
Times(const MinMaxWeightTpl<float> & w1,const MinMaxWeightTpl<float> & w2)526 inline MinMaxWeightTpl<float> Times(
527 const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) {
528 return Times<float>(w1, w2);
529 }
530
Times(const MinMaxWeightTpl<double> & w1,const MinMaxWeightTpl<double> & w2)531 inline MinMaxWeightTpl<double> Times(
532 const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) {
533 return Times<double>(w1, w2);
534 }
535
536 // Defined only for special cases
537 template <class T>
538 inline MinMaxWeightTpl<T> Divide(const MinMaxWeightTpl<T> &w1,
539 const MinMaxWeightTpl<T> &w2,
540 DivideType typ = DIVIDE_ANY) {
541 if (!w1.Member() || !w2.Member())
542 return MinMaxWeightTpl<T>::NoWeight();
543 // min(w1, x) = w2, w1 >= w2 => min(w1, x) = w2, x = w2
544 return w1.Value() >= w2.Value() ? w1 : FloatLimits<T>::kNumberBad;
545 }
546
547 inline MinMaxWeightTpl<float> Divide(const MinMaxWeightTpl<float> &w1,
548 const MinMaxWeightTpl<float> &w2,
549 DivideType typ = DIVIDE_ANY) {
550 return Divide<float>(w1, w2, typ);
551 }
552
553 inline MinMaxWeightTpl<double> Divide(const MinMaxWeightTpl<double> &w1,
554 const MinMaxWeightTpl<double> &w2,
555 DivideType typ = DIVIDE_ANY) {
556 return Divide<double>(w1, w2, typ);
557 }
558
559 //
560 // WEIGHT CONVERTER SPECIALIZATIONS.
561 //
562
563 // Convert to tropical
564 template <>
565 struct WeightConvert<LogWeight, TropicalWeight> {
566 TropicalWeight operator()(LogWeight w) const { return w.Value(); }
567 };
568
569 template <>
570 struct WeightConvert<Log64Weight, TropicalWeight> {
571 TropicalWeight operator()(Log64Weight w) const { return w.Value(); }
572 };
573
574 // Convert to log
575 template <>
576 struct WeightConvert<TropicalWeight, LogWeight> {
577 LogWeight operator()(TropicalWeight w) const { return w.Value(); }
578 };
579
580 template <>
581 struct WeightConvert<Log64Weight, LogWeight> {
582 LogWeight operator()(Log64Weight w) const { return w.Value(); }
583 };
584
585 // Convert to log64
586 template <>
587 struct WeightConvert<TropicalWeight, Log64Weight> {
588 Log64Weight operator()(TropicalWeight w) const { return w.Value(); }
589 };
590
591 template <>
592 struct WeightConvert<LogWeight, Log64Weight> {
593 Log64Weight operator()(LogWeight w) const { return w.Value(); }
594 };
595
596 } // namespace fst
597
598 #endif // FST_LIB_FLOAT_WEIGHT_H__
599