• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 //
14 // Copyright 2005-2010 Google, Inc.
15 // Author: krr@google.com (Kasturi Rangan Raghavan)
16 // \file
17 // LogWeight along with sign information that represents the value X in the
18 // linear domain as <sign(X), -ln(|X|)>
19 // The sign is a TropicalWeight:
20 //  positive, TropicalWeight.Value() > 0.0, recommended value 1.0
21 //  negative, TropicalWeight.Value() <= 0.0, recommended value -1.0
22 
23 #ifndef FST_LIB_SIGNED_LOG_WEIGHT_H_
24 #define FST_LIB_SIGNED_LOG_WEIGHT_H_
25 
#include <cmath>

#include <fst/float-weight.h>
#include <fst/pair-weight.h>
28 
29 
30 namespace fst {
31 template <class T>
32 class SignedLogWeightTpl
33     : public PairWeight<TropicalWeight, LogWeightTpl<T> > {
34  public:
35   typedef TropicalWeight X1;
36   typedef LogWeightTpl<T> X2;
37   using PairWeight<X1, X2>::Value1;
38   using PairWeight<X1, X2>::Value2;
39 
40   using PairWeight<X1, X2>::Reverse;
41   using PairWeight<X1, X2>::Quantize;
42   using PairWeight<X1, X2>::Member;
43 
44   typedef SignedLogWeightTpl<T> ReverseWeight;
45 
SignedLogWeightTpl()46   SignedLogWeightTpl() : PairWeight<X1, X2>() {}
47 
SignedLogWeightTpl(const SignedLogWeightTpl<T> & w)48   SignedLogWeightTpl(const SignedLogWeightTpl<T>& w)
49       : PairWeight<X1, X2> (w) { }
50 
SignedLogWeightTpl(const PairWeight<X1,X2> & w)51   SignedLogWeightTpl(const PairWeight<X1, X2>& w)
52       : PairWeight<X1, X2> (w) { }
53 
SignedLogWeightTpl(const X1 & x1,const X2 & x2)54   SignedLogWeightTpl(const X1& x1, const X2& x2)
55       : PairWeight<X1, X2>(x1, x2) { }
56 
Zero()57   static const SignedLogWeightTpl<T> &Zero() {
58     static const SignedLogWeightTpl<T> zero(X1(1.0), X2::Zero());
59     return zero;
60   }
61 
One()62   static const SignedLogWeightTpl<T> &One() {
63     static const SignedLogWeightTpl<T> one(X1(1.0), X2::One());
64     return one;
65   }
66 
NoWeight()67   static const SignedLogWeightTpl<T> &NoWeight() {
68     static const SignedLogWeightTpl<T> no_weight(X1(1.0), X2::NoWeight());
69     return no_weight;
70   }
71 
Type()72   static const string &Type() {
73     static const string type = "signed_log_" + X1::Type() + "_" + X2::Type();
74     return type;
75   }
76 
77   ProductWeight<X1, X2> Quantize(float delta = kDelta) const {
78     return PairWeight<X1, X2>::Quantize();
79   }
80 
Reverse()81   ReverseWeight Reverse() const {
82     return PairWeight<X1, X2>::Reverse();
83   }
84 
Member()85   bool Member() const {
86     return PairWeight<X1, X2>::Member();
87   }
88 
Properties()89   static uint64 Properties() {
90     // not idempotent nor path
91     return kLeftSemiring | kRightSemiring | kCommutative;
92   }
93 
Hash()94   size_t Hash() const {
95     size_t h1;
96     if (Value2() == X2::Zero() || Value1().Value() > 0.0)
97       h1 = TropicalWeight(1.0).Hash();
98     else
99       h1 = TropicalWeight(-1.0).Hash();
100     size_t h2 = Value2().Hash();
101     const int lshift = 5;
102     const int rshift = CHAR_BIT * sizeof(size_t) - 5;
103     return h1 << lshift ^ h1 >> rshift ^ h2;
104   }
105 };
106 
107 template <class T>
Plus(const SignedLogWeightTpl<T> & w1,const SignedLogWeightTpl<T> & w2)108 inline SignedLogWeightTpl<T> Plus(const SignedLogWeightTpl<T> &w1,
109                                   const SignedLogWeightTpl<T> &w2) {
110   if (!w1.Member() || !w2.Member())
111     return SignedLogWeightTpl<T>::NoWeight();
112   bool s1 = w1.Value1().Value() > 0.0;
113   bool s2 = w2.Value1().Value() > 0.0;
114   T f1 = w1.Value2().Value();
115   T f2 = w2.Value2().Value();
116   if (f1 == FloatLimits<T>::kPosInfinity)
117     return w2;
118   else if (f2 == FloatLimits<T>::kPosInfinity)
119     return w1;
120   else if (f1 == f2) {
121     if (s1 == s2)
122       return SignedLogWeightTpl<T>(w1.Value1(), (f2 - log(2.0F)));
123     else
124       return SignedLogWeightTpl<T>::Zero();
125   } else if (f1 > f2) {
126     if (s1 == s2) {
127       return SignedLogWeightTpl<T>(
128         w1.Value1(), (f2 - log(1.0F + exp(f2 - f1))));
129     } else {
130       return SignedLogWeightTpl<T>(
131         w2.Value1(), (f2 - log(1.0F - exp(f2 - f1))));
132     }
133   } else {
134     if (s2 == s1) {
135       return SignedLogWeightTpl<T>(
136         w2.Value1(), (f1 - log(1.0F + exp(f1 - f2))));
137     } else {
138       return SignedLogWeightTpl<T>(
139         w1.Value1(), (f1 - log(1.0F - exp(f1 - f2))));
140     }
141   }
142 }
143 
144 template <class T>
Minus(const SignedLogWeightTpl<T> & w1,const SignedLogWeightTpl<T> & w2)145 inline SignedLogWeightTpl<T> Minus(const SignedLogWeightTpl<T> &w1,
146                                    const SignedLogWeightTpl<T> &w2) {
147   SignedLogWeightTpl<T> minus_w2(-w2.Value1().Value(), w2.Value2());
148   return Plus(w1, minus_w2);
149 }
150 
151 template <class T>
Times(const SignedLogWeightTpl<T> & w1,const SignedLogWeightTpl<T> & w2)152 inline SignedLogWeightTpl<T> Times(const SignedLogWeightTpl<T> &w1,
153                                    const SignedLogWeightTpl<T> &w2) {
154   if (!w1.Member() || !w2.Member())
155     return SignedLogWeightTpl<T>::NoWeight();
156   bool s1 = w1.Value1().Value() > 0.0;
157   bool s2 = w2.Value1().Value() > 0.0;
158   T f1 = w1.Value2().Value();
159   T f2 = w2.Value2().Value();
160   if (s1 == s2)
161     return SignedLogWeightTpl<T>(TropicalWeight(1.0), (f1 + f2));
162   else
163     return SignedLogWeightTpl<T>(TropicalWeight(-1.0), (f1 + f2));
164 }
165 
166 template <class T>
167 inline SignedLogWeightTpl<T> Divide(const SignedLogWeightTpl<T> &w1,
168                                     const SignedLogWeightTpl<T> &w2,
169                                     DivideType typ = DIVIDE_ANY) {
170   if (!w1.Member() || !w2.Member())
171     return SignedLogWeightTpl<T>::NoWeight();
172   bool s1 = w1.Value1().Value() > 0.0;
173   bool s2 = w2.Value1().Value() > 0.0;
174   T f1 = w1.Value2().Value();
175   T f2 = w2.Value2().Value();
176   if (f2 == FloatLimits<T>::kPosInfinity)
177     return SignedLogWeightTpl<T>(TropicalWeight(1.0),
178       FloatLimits<T>::kNumberBad);
179   else if (f1 == FloatLimits<T>::kPosInfinity)
180     return SignedLogWeightTpl<T>(TropicalWeight(1.0),
181       FloatLimits<T>::kPosInfinity);
182   else if (s1 == s2)
183     return SignedLogWeightTpl<T>(TropicalWeight(1.0), (f1 - f2));
184   else
185     return SignedLogWeightTpl<T>(TropicalWeight(-1.0), (f1 - f2));
186 }
187 
188 template <class T>
189 inline bool ApproxEqual(const SignedLogWeightTpl<T> &w1,
190                         const SignedLogWeightTpl<T> &w2,
191                         float delta = kDelta) {
192   bool s1 = w1.Value1().Value() > 0.0;
193   bool s2 = w2.Value1().Value() > 0.0;
194   if (s1 == s2) {
195     return ApproxEqual(w1.Value2(), w2.Value2(), delta);
196   } else {
197     return w1.Value2() == LogWeightTpl<T>::Zero()
198         && w2.Value2() == LogWeightTpl<T>::Zero();
199   }
200 }
201 
202 template <class T>
203 inline bool operator==(const SignedLogWeightTpl<T> &w1,
204                        const SignedLogWeightTpl<T> &w2) {
205   bool s1 = w1.Value1().Value() > 0.0;
206   bool s2 = w2.Value1().Value() > 0.0;
207   if (s1 == s2)
208     return w1.Value2() == w2.Value2();
209   else
210     return (w1.Value2() == LogWeightTpl<T>::Zero()) &&
211            (w2.Value2() == LogWeightTpl<T>::Zero());
212 }
213 
214 
215 // Single-precision signed-log weight
216 typedef SignedLogWeightTpl<float> SignedLogWeight;
217 // Double-precision signed-log weight
218 typedef SignedLogWeightTpl<double> SignedLog64Weight;
219 
220 //
221 // WEIGHT CONVERTER SPECIALIZATIONS.
222 //
223 
224 template <class W1, class W2>
SignedLogConvertCheck(W1 w)225 bool SignedLogConvertCheck(W1 w) {
226   if (w.Value1().Value() < 0.0) {
227     FSTERROR() << "WeightConvert: can't convert weight from \""
228                << W1::Type() << "\" to \"" << W2::Type();
229     return false;
230   }
231   return true;
232 }
233 
234 // Convert to tropical
235 template <>
236 struct WeightConvert<SignedLogWeight, TropicalWeight> {
237   TropicalWeight operator()(SignedLogWeight w) const {
238     if (!SignedLogConvertCheck<SignedLogWeight, TropicalWeight>(w))
239       return TropicalWeight::NoWeight();
240     return w.Value2().Value();
241   }
242 };
243 
244 template <>
245 struct WeightConvert<SignedLog64Weight, TropicalWeight> {
246   TropicalWeight operator()(SignedLog64Weight w) const {
247     if (!SignedLogConvertCheck<SignedLog64Weight, TropicalWeight>(w))
248       return TropicalWeight::NoWeight();
249     return w.Value2().Value();
250   }
251 };
252 
253 // Convert to log
254 template <>
255 struct WeightConvert<SignedLogWeight, LogWeight> {
256   LogWeight operator()(SignedLogWeight w) const {
257     if (!SignedLogConvertCheck<SignedLogWeight, LogWeight>(w))
258       return LogWeight::NoWeight();
259     return w.Value2().Value();
260   }
261 };
262 
263 template <>
264 struct WeightConvert<SignedLog64Weight, LogWeight> {
265   LogWeight operator()(SignedLog64Weight w) const {
266     if (!SignedLogConvertCheck<SignedLog64Weight, LogWeight>(w))
267       return LogWeight::NoWeight();
268     return w.Value2().Value();
269   }
270 };
271 
272 // Convert to log64
273 template <>
274 struct WeightConvert<SignedLogWeight, Log64Weight> {
275   Log64Weight operator()(SignedLogWeight w) const {
276     if (!SignedLogConvertCheck<SignedLogWeight, Log64Weight>(w))
277       return Log64Weight::NoWeight();
278     return w.Value2().Value();
279   }
280 };
281 
282 template <>
283 struct WeightConvert<SignedLog64Weight, Log64Weight> {
284   Log64Weight operator()(SignedLog64Weight w) const {
285     if (!SignedLogConvertCheck<SignedLog64Weight, Log64Weight>(w))
286       return Log64Weight::NoWeight();
287     return w.Value2().Value();
288   }
289 };
290 
291 // Convert to signed log
292 template <>
293 struct WeightConvert<TropicalWeight, SignedLogWeight> {
294   SignedLogWeight operator()(TropicalWeight w) const {
295     TropicalWeight x1 = 1.0;
296     LogWeight x2 = w.Value();
297     return SignedLogWeight(x1, x2);
298   }
299 };
300 
301 template <>
302 struct WeightConvert<LogWeight, SignedLogWeight> {
303   SignedLogWeight operator()(LogWeight w) const {
304     TropicalWeight x1 = 1.0;
305     LogWeight x2 = w.Value();
306     return SignedLogWeight(x1, x2);
307   }
308 };
309 
310 template <>
311 struct WeightConvert<Log64Weight, SignedLogWeight> {
312   SignedLogWeight operator()(Log64Weight w) const {
313     TropicalWeight x1 = 1.0;
314     LogWeight x2 = w.Value();
315     return SignedLogWeight(x1, x2);
316   }
317 };
318 
319 template <>
320 struct WeightConvert<SignedLog64Weight, SignedLogWeight> {
321   SignedLogWeight operator()(SignedLog64Weight w) const {
322     TropicalWeight x1 = w.Value1();
323     LogWeight x2 = w.Value2().Value();
324     return SignedLogWeight(x1, x2);
325   }
326 };
327 
328 // Convert to signed log64
329 template <>
330 struct WeightConvert<TropicalWeight, SignedLog64Weight> {
331   SignedLog64Weight operator()(TropicalWeight w) const {
332     TropicalWeight x1 = 1.0;
333     Log64Weight x2 = w.Value();
334     return SignedLog64Weight(x1, x2);
335   }
336 };
337 
338 template <>
339 struct WeightConvert<LogWeight, SignedLog64Weight> {
340   SignedLog64Weight operator()(LogWeight w) const {
341     TropicalWeight x1 = 1.0;
342     Log64Weight x2 = w.Value();
343     return SignedLog64Weight(x1, x2);
344   }
345 };
346 
347 template <>
348 struct WeightConvert<Log64Weight, SignedLog64Weight> {
349   SignedLog64Weight operator()(Log64Weight w) const {
350     TropicalWeight x1 = 1.0;
351     Log64Weight x2 = w.Value();
352     return SignedLog64Weight(x1, x2);
353   }
354 };
355 
356 template <>
357 struct WeightConvert<SignedLogWeight, SignedLog64Weight> {
358   SignedLog64Weight operator()(SignedLogWeight w) const {
359     TropicalWeight x1 = w.Value1();
360     Log64Weight x2 = w.Value2().Value();
361     return SignedLog64Weight(x1, x2);
362   }
363 };
364 
365 }  // namespace fst
366 
367 #endif  // FST_LIB_SIGNED_LOG_WEIGHT_H_
368