1 // float-weight.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 //
16 // \file
17 // Float weight set and associated semiring operation definitions.
18 //
19
20 #ifndef FST_LIB_FLOAT_WEIGHT_H__
21 #define FST_LIB_FLOAT_WEIGHT_H__
22
#include <cmath>
#include <cstring>
#include <limits>

#include "fst/lib/weight.h"
26
27 namespace fst {
28
// Single-precision floating point infinities. kPosInfinity serves as Zero()
// for the tropical and log semirings below; kNegInfinity is the value their
// Member() checks reject.
static const float kPosInfinity = std::numeric_limits<float>::infinity();
static const float kNegInfinity = -std::numeric_limits<float>::infinity();
31
32 // Single precision floating point weight base class
33 class FloatWeight {
34 public:
FloatWeight()35 FloatWeight() {}
36
FloatWeight(float f)37 FloatWeight(float f) : value_(f) {}
38
FloatWeight(const FloatWeight & w)39 FloatWeight(const FloatWeight &w) : value_(w.value_) {}
40
41 FloatWeight &operator=(const FloatWeight &w) {
42 value_ = w.value_;
43 return *this;
44 }
45
Read(istream & strm)46 istream &Read(istream &strm) {
47 return ReadType(strm, &value_);
48 }
49
Write(ostream & strm)50 ostream &Write(ostream &strm) const {
51 return WriteType(strm, value_);
52 }
53
Hash()54 ssize_t Hash() const {
55 union {
56 float f;
57 ssize_t s;
58 } u = { value_ };
59 return u.s;
60 }
61
Value()62 const float &Value() const { return value_; }
63
64 protected:
65 float value_;
66 };
67
68 inline bool operator==(const FloatWeight &w1, const FloatWeight &w2) {
69 // Volatile qualifier thwarts over-aggressive compiler optimizations
70 // that lead to problems esp. with NaturalLess().
71 volatile float v1 = w1.Value();
72 volatile float v2 = w2.Value();
73 return v1 == v2;
74 }
75
76 inline bool operator!=(const FloatWeight &w1, const FloatWeight &w2) {
77 return !(w1 == w2);
78 }
79
80 inline bool ApproxEqual(const FloatWeight &w1, const FloatWeight &w2,
81 float delta = kDelta) {
82 return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta;
83 }
84
85 inline ostream &operator<<(ostream &strm, const FloatWeight &w) {
86 if (w.Value() == kPosInfinity)
87 return strm << "Infinity";
88 else if (w.Value() == kNegInfinity)
89 return strm << "-Infinity";
90 else if (w.Value() != w.Value()) // Fails for NaN
91 return strm << "BadFloat";
92 else
93 return strm << w.Value();
94 }
95
96 inline istream &operator>>(istream &strm, FloatWeight &w) {
97 string s;
98 strm >> s;
99 if (s == "Infinity") {
100 w = FloatWeight(kPosInfinity);
101 } else if (s == "-Infinity") {
102 w = FloatWeight(kNegInfinity);
103 } else {
104 char *p;
105 float f = strtod(s.c_str(), &p);
106 if (p < s.c_str() + s.size())
107 strm.clear(std::ios::badbit);
108 else
109 w = FloatWeight(f);
110 }
111 return strm;
112 }
113
114
115 // Tropical semiring: (min, +, inf, 0)
116 class TropicalWeight : public FloatWeight {
117 public:
118 typedef TropicalWeight ReverseWeight;
119
TropicalWeight()120 TropicalWeight() : FloatWeight() {}
121
TropicalWeight(float f)122 TropicalWeight(float f) : FloatWeight(f) {}
123
TropicalWeight(const TropicalWeight & w)124 TropicalWeight(const TropicalWeight &w) : FloatWeight(w) {}
125
Zero()126 static const TropicalWeight Zero() { return TropicalWeight(kPosInfinity); }
127
One()128 static const TropicalWeight One() { return TropicalWeight(0.0F); }
129
Type()130 static const string &Type() {
131 static const string type = "tropical";
132 return type;
133 }
134
Member()135 bool Member() const {
136 // First part fails for IEEE NaN
137 return Value() == Value() && Value() != kNegInfinity;
138 }
139
140 TropicalWeight Quantize(float delta = kDelta) const {
141 return TropicalWeight(floor(Value()/delta + 0.5F) * delta);
142 }
143
Reverse()144 TropicalWeight Reverse() const { return *this; }
145
Properties()146 static uint64 Properties() {
147 return kLeftSemiring | kRightSemiring | kCommutative |
148 kPath | kIdempotent;
149 }
150 };
151
// Tropical addition is min. Ties and unordered comparisons (NaN) yield w2.
inline TropicalWeight Plus(const TropicalWeight &w1,
                           const TropicalWeight &w2) {
  if (w1.Value() < w2.Value())
    return w1;
  return w2;
}
156
// Tropical multiplication is ordinary addition. Zero() (+infinity) is
// returned unchanged whenever either operand is Zero(), before the other
// operand is ever added.
inline TropicalWeight Times(const TropicalWeight &w1,
                            const TropicalWeight &w2) {
  const float a = w1.Value();
  const float b = w2.Value();
  if (a == kPosInfinity)
    return w1;
  if (b == kPosInfinity)
    return w2;
  return TropicalWeight(a + b);
}
167
168 inline TropicalWeight Divide(const TropicalWeight &w1,
169 const TropicalWeight &w2,
170 DivideType typ = DIVIDE_ANY) {
171 float f1 = w1.Value(), f2 = w2.Value();
172 if (f2 == kPosInfinity)
173 return kNegInfinity;
174 else if (f1 == kPosInfinity)
175 return kPosInfinity;
176 else
177 return TropicalWeight(f1 - f2);
178 }
179
180
// Log semiring: (-log(e^-x + e^-y), +, inf, 0)
182 class LogWeight : public FloatWeight {
183 public:
184 typedef LogWeight ReverseWeight;
185
LogWeight()186 LogWeight() : FloatWeight() {}
187
LogWeight(float f)188 LogWeight(float f) : FloatWeight(f) {}
189
LogWeight(const LogWeight & w)190 LogWeight(const LogWeight &w) : FloatWeight(w) {}
191
Zero()192 static const LogWeight Zero() { return LogWeight(kPosInfinity); }
193
One()194 static const LogWeight One() { return LogWeight(0.0F); }
195
Type()196 static const string &Type() {
197 static const string type = "log";
198 return type;
199 }
200
Member()201 bool Member() const {
202 // First part fails for IEEE NaN
203 return Value() == Value() && Value() != kNegInfinity;
204 }
205
206 LogWeight Quantize(float delta = kDelta) const {
207 return LogWeight(floor(Value()/delta + 0.5F) * delta);
208 }
209
Reverse()210 LogWeight Reverse() const { return *this; }
211
Properties()212 static uint64 Properties() {
213 return kLeftSemiring | kRightSemiring | kCommutative;
214 }
215 };
216
// Returns log(1 + e^-x). Within this file, Plus() calls it with the
// non-negative difference between its two operands, so e^-x <= 1.
//
// log1p keeps full relative precision when e^-x is tiny. The previous form,
// log(1 + e^-x), rounds 1 + e^-x to exactly 1 once e^-x drops below about
// 1e-16, and then returns 0.
inline double LogExp(double x) { return log1p(exp(-x)); }
218
// Log-semiring addition: -log(e^-f1 + e^-f2), computed as
// lo - log(1 + e^-(hi - lo)) with lo <= hi. This avoids exponentiating
// large magnitudes. Zero() (+infinity) is the identity and short-circuits.
inline LogWeight Plus(const LogWeight &w1, const LogWeight &w2) {
  const float f1 = w1.Value();
  const float f2 = w2.Value();
  if (f1 == kPosInfinity)
    return w2;
  if (f2 == kPosInfinity)
    return w1;
  const bool swap = f1 > f2;
  const float lo = swap ? f2 : f1;
  const float hi = swap ? f1 : f2;
  return LogWeight(lo - LogExp(hi - lo));
}
230
// Log-semiring multiplication is ordinary addition. Zero() (+infinity) is
// returned unchanged whenever either operand is Zero().
inline LogWeight Times(const LogWeight &w1, const LogWeight &w2) {
  if (w1.Value() == kPosInfinity)
    return w1;
  if (w2.Value() == kPosInfinity)
    return w2;
  return LogWeight(w1.Value() + w2.Value());
}
240
241 inline LogWeight Divide(const LogWeight &w1,
242 const LogWeight &w2,
243 DivideType typ = DIVIDE_ANY) {
244 float f1 = w1.Value(), f2 = w2.Value();
245 if (f2 == kPosInfinity)
246 return kNegInfinity;
247 else if (f1 == kPosInfinity)
248 return kPosInfinity;
249 else
250 return LogWeight(f1 - f2);
251 }
252
253 } // namespace fst;
254
255 #endif // FST_LIB_FLOAT_WEIGHT_H__
256