• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // float-weight.h
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 //
16 // \file
17 // Float weight set and associated semiring operation definitions.
18 //
19 
20 #ifndef FST_LIB_FLOAT_WEIGHT_H__
21 #define FST_LIB_FLOAT_WEIGHT_H__
22 
#include <cmath>
#include <cstring>
#include <limits>

#include "fst/lib/weight.h"
26 
27 namespace fst {
28 
29 static const float kPosInfinity = numeric_limits<float>::infinity();
30 static const float kNegInfinity = -kPosInfinity;
31 
32 // Single precision floating point weight base class
33 class FloatWeight {
34  public:
FloatWeight()35   FloatWeight() {}
36 
FloatWeight(float f)37   FloatWeight(float f) : value_(f) {}
38 
FloatWeight(const FloatWeight & w)39   FloatWeight(const FloatWeight &w) : value_(w.value_) {}
40 
41   FloatWeight &operator=(const FloatWeight &w) {
42     value_ = w.value_;
43     return *this;
44   }
45 
Read(istream & strm)46   istream &Read(istream &strm) {
47     return ReadType(strm, &value_);
48   }
49 
Write(ostream & strm)50   ostream &Write(ostream &strm) const {
51     return WriteType(strm, value_);
52   }
53 
Hash()54   ssize_t Hash() const {
55     union {
56       float f;
57       ssize_t s;
58     } u = { value_ };
59     return u.s;
60   }
61 
Value()62   const float &Value() const { return value_; }
63 
64  protected:
65   float value_;
66 };
67 
68 inline bool operator==(const FloatWeight &w1, const FloatWeight &w2) {
69   // Volatile qualifier thwarts over-aggressive compiler optimizations
70   // that lead to problems esp. with NaturalLess().
71   volatile float v1 = w1.Value();
72   volatile float v2 = w2.Value();
73   return v1 == v2;
74 }
75 
76 inline bool operator!=(const FloatWeight &w1, const FloatWeight &w2) {
77   return !(w1 == w2);
78 }
79 
80 inline bool ApproxEqual(const FloatWeight &w1, const FloatWeight &w2,
81                         float delta = kDelta) {
82   return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta;
83 }
84 
85 inline ostream &operator<<(ostream &strm, const FloatWeight &w) {
86   if (w.Value() == kPosInfinity)
87     return strm << "Infinity";
88   else if (w.Value() == kNegInfinity)
89     return strm << "-Infinity";
90   else if (w.Value() != w.Value())   // Fails for NaN
91     return strm << "BadFloat";
92   else
93     return strm << w.Value();
94 }
95 
96 inline istream &operator>>(istream &strm, FloatWeight &w) {
97   string s;
98   strm >> s;
99   if (s == "Infinity") {
100     w = FloatWeight(kPosInfinity);
101   } else if (s == "-Infinity") {
102     w = FloatWeight(kNegInfinity);
103   } else {
104     char *p;
105     float f = strtod(s.c_str(), &p);
106     if (p < s.c_str() + s.size())
107       strm.clear(std::ios::badbit);
108     else
109       w = FloatWeight(f);
110   }
111   return strm;
112 }
113 
114 
115 // Tropical semiring: (min, +, inf, 0)
116 class TropicalWeight : public FloatWeight {
117  public:
118   typedef TropicalWeight ReverseWeight;
119 
TropicalWeight()120   TropicalWeight() : FloatWeight() {}
121 
TropicalWeight(float f)122   TropicalWeight(float f) : FloatWeight(f) {}
123 
TropicalWeight(const TropicalWeight & w)124   TropicalWeight(const TropicalWeight &w) : FloatWeight(w) {}
125 
Zero()126   static const TropicalWeight Zero() { return TropicalWeight(kPosInfinity); }
127 
One()128   static const TropicalWeight One() { return TropicalWeight(0.0F); }
129 
Type()130   static const string &Type() {
131     static const string type = "tropical";
132     return type;
133   }
134 
Member()135   bool Member() const {
136     // First part fails for IEEE NaN
137     return Value() == Value() && Value() != kNegInfinity;
138   }
139 
140   TropicalWeight Quantize(float delta = kDelta) const {
141     return TropicalWeight(floor(Value()/delta + 0.5F) * delta);
142   }
143 
Reverse()144   TropicalWeight Reverse() const { return *this; }
145 
Properties()146   static uint64 Properties() {
147     return kLeftSemiring | kRightSemiring | kCommutative |
148       kPath | kIdempotent;
149   }
150 };
151 
// Tropical Plus: returns w1 when its value is strictly smaller, else w2.
inline TropicalWeight Plus(const TropicalWeight &w1,
                           const TropicalWeight &w2) {
  if (w1.Value() < w2.Value())
    return w1;
  return w2;
}
156 
// Tropical Times: adds the two values. An operand equal to +infinity is
// returned as-is, with w1 checked first.
inline TropicalWeight Times(const TropicalWeight &w1,
                            const TropicalWeight &w2) {
  const float lhs = w1.Value();
  if (lhs == kPosInfinity)
    return w1;
  const float rhs = w2.Value();
  if (rhs == kPosInfinity)
    return w2;
  return TropicalWeight(lhs + rhs);
}
167 
168 inline TropicalWeight Divide(const TropicalWeight &w1,
169                              const TropicalWeight &w2,
170                              DivideType typ = DIVIDE_ANY) {
171   float f1 = w1.Value(), f2 = w2.Value();
172   if (f2 == kPosInfinity)
173     return kNegInfinity;
174   else if (f1 == kPosInfinity)
175     return kPosInfinity;
176   else
177     return TropicalWeight(f1 - f2);
178 }
179 
180 
181 // Log semiring: (log(e^-x + e^y), +, inf, 0)
182 class LogWeight : public FloatWeight {
183  public:
184   typedef LogWeight ReverseWeight;
185 
LogWeight()186   LogWeight() : FloatWeight() {}
187 
LogWeight(float f)188   LogWeight(float f) : FloatWeight(f) {}
189 
LogWeight(const LogWeight & w)190   LogWeight(const LogWeight &w) : FloatWeight(w) {}
191 
Zero()192   static const LogWeight Zero() {   return LogWeight(kPosInfinity); }
193 
One()194   static const LogWeight One() { return LogWeight(0.0F); }
195 
Type()196   static const string &Type() {
197     static const string type = "log";
198     return type;
199   }
200 
Member()201   bool Member() const {
202     // First part fails for IEEE NaN
203     return Value() == Value() && Value() != kNegInfinity;
204   }
205 
206   LogWeight Quantize(float delta = kDelta) const {
207     return LogWeight(floor(Value()/delta + 0.5F) * delta);
208   }
209 
Reverse()210   LogWeight Reverse() const { return *this; }
211 
Properties()212   static uint64 Properties() {
213     return kLeftSemiring | kRightSemiring | kCommutative;
214   }
215 };
216 
// Returns log(1 + e^-x).
//
// log1p is used instead of log(1 + ...) because for large x the term e^-x
// falls below double epsilon, 1 + e^-x rounds to exactly 1, and the result
// would collapse to 0 instead of approximately e^-x.
inline double LogExp(double x) { return std::log1p(std::exp(-x)); }
218 
// Log Plus: -log(e^-f1 + e^-f2), where f1 and f2 are the operand values.
// If either operand is +infinity, the other operand is returned unchanged.
inline LogWeight Plus(const LogWeight &w1, const LogWeight &w2) {
  const float a = w1.Value();
  const float b = w2.Value();
  if (a == kPosInfinity)
    return w2;
  if (b == kPosInfinity)
    return w1;
  // Subtract from the smaller value so LogExp always sees a non-negative
  // difference.
  if (a > b)
    return LogWeight(b - LogExp(a - b));
  return LogWeight(a - LogExp(b - a));
}
230 
// Log Times: adds the two values. An operand equal to +infinity is returned
// as-is, with w1 checked first.
inline LogWeight Times(const LogWeight &w1, const LogWeight &w2) {
  const float lhs = w1.Value();
  if (lhs == kPosInfinity)
    return w1;
  const float rhs = w2.Value();
  if (rhs == kPosInfinity)
    return w2;
  return LogWeight(lhs + rhs);
}
240 
241 inline LogWeight Divide(const LogWeight &w1,
242                              const LogWeight &w2,
243                              DivideType typ = DIVIDE_ANY) {
244   float f1 = w1.Value(), f2 = w2.Value();
245   if (f2 == kPosInfinity)
246     return kNegInfinity;
247   else if (f1 == kPosInfinity)
248     return kPosInfinity;
249   else
250     return LogWeight(f1 - f2);
251 }
252 
253 }  // namespace fst;
254 
255 #endif  // FST_LIB_FLOAT_WEIGHT_H__
256