• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2016 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #include "main.h"
11 #include "../Eigen/SpecialFunctions"
12 
13 template<typename X, typename Y>
verify_component_wise(const X & x,const Y & y)14 void verify_component_wise(const X& x, const Y& y)
15 {
16   for(Index i=0; i<x.size(); ++i)
17   {
18     if((numext::isfinite)(y(i)))
19       VERIFY_IS_APPROX( x(i), y(i) );
20     else if((numext::isnan)(y(i)))
21       VERIFY((numext::isnan)(x(i)));
22     else
23       VERIFY_IS_EQUAL( x(i), y(i) );
24   }
25 }
26 
array_special_functions()27 template<typename ArrayType> void array_special_functions()
28 {
29   using std::abs;
30   using std::sqrt;
31   typedef typename ArrayType::Scalar Scalar;
32   typedef typename NumTraits<Scalar>::Real RealScalar;
33 
34   Scalar plusinf = std::numeric_limits<Scalar>::infinity();
35   Scalar nan = std::numeric_limits<Scalar>::quiet_NaN();
36 
37   Index rows = internal::random<Index>(1,30);
38   Index cols = 1;
39 
40   // API
41   {
42     ArrayType m1 = ArrayType::Random(rows,cols);
43 #if EIGEN_HAS_C99_MATH
44     VERIFY_IS_APPROX(m1.lgamma(), lgamma(m1));
45     VERIFY_IS_APPROX(m1.digamma(), digamma(m1));
46     VERIFY_IS_APPROX(m1.erf(), erf(m1));
47     VERIFY_IS_APPROX(m1.erfc(), erfc(m1));
48 #endif  // EIGEN_HAS_C99_MATH
49   }
50 
51 
52 #if EIGEN_HAS_C99_MATH
53   // check special functions (comparing against numpy implementation)
54   if (!NumTraits<Scalar>::IsComplex)
55   {
56 
57     {
58       ArrayType m1 = ArrayType::Random(rows,cols);
59       ArrayType m2 = ArrayType::Random(rows,cols);
60 
61       // Test various propreties of igamma & igammac.  These are normalized
62       // gamma integrals where
63       //   igammac(a, x) = Gamma(a, x) / Gamma(a)
64       //   igamma(a, x) = gamma(a, x) / Gamma(a)
65       // where Gamma and gamma are considered the standard unnormalized
66       // upper and lower incomplete gamma functions, respectively.
67       ArrayType a = m1.abs() + 2;
68       ArrayType x = m2.abs() + 2;
69       ArrayType zero = ArrayType::Zero(rows, cols);
70       ArrayType one = ArrayType::Constant(rows, cols, Scalar(1.0));
71       ArrayType a_m1 = a - one;
72       ArrayType Gamma_a_x = Eigen::igammac(a, x) * a.lgamma().exp();
73       ArrayType Gamma_a_m1_x = Eigen::igammac(a_m1, x) * a_m1.lgamma().exp();
74       ArrayType gamma_a_x = Eigen::igamma(a, x) * a.lgamma().exp();
75       ArrayType gamma_a_m1_x = Eigen::igamma(a_m1, x) * a_m1.lgamma().exp();
76 
77       // Gamma(a, 0) == Gamma(a)
78       VERIFY_IS_APPROX(Eigen::igammac(a, zero), one);
79 
80       // Gamma(a, x) + gamma(a, x) == Gamma(a)
81       VERIFY_IS_APPROX(Gamma_a_x + gamma_a_x, a.lgamma().exp());
82 
83       // Gamma(a, x) == (a - 1) * Gamma(a-1, x) + x^(a-1) * exp(-x)
84       VERIFY_IS_APPROX(Gamma_a_x, (a - 1) * Gamma_a_m1_x + x.pow(a-1) * (-x).exp());
85 
86       // gamma(a, x) == (a - 1) * gamma(a-1, x) - x^(a-1) * exp(-x)
87       VERIFY_IS_APPROX(gamma_a_x, (a - 1) * gamma_a_m1_x - x.pow(a-1) * (-x).exp());
88     }
89 
90     {
91       // Check exact values of igamma and igammac against a third party calculation.
92       Scalar a_s[] = {Scalar(0), Scalar(1), Scalar(1.5), Scalar(4), Scalar(0.0001), Scalar(1000.5)};
93       Scalar x_s[] = {Scalar(0), Scalar(1), Scalar(1.5), Scalar(4), Scalar(0.0001), Scalar(1000.5)};
94 
95       // location i*6+j corresponds to a_s[i], x_s[j].
96       Scalar igamma_s[][6] = {{0.0, nan, nan, nan, nan, nan},
97                               {0.0, 0.6321205588285578, 0.7768698398515702,
98                               0.9816843611112658, 9.999500016666262e-05, 1.0},
99                               {0.0, 0.4275932955291202, 0.608374823728911,
100                               0.9539882943107686, 7.522076445089201e-07, 1.0},
101                               {0.0, 0.01898815687615381, 0.06564245437845008,
102                               0.5665298796332909, 4.166333347221828e-18, 1.0},
103                               {0.0, 0.9999780593618628, 0.9999899967080838,
104                               0.9999996219837988, 0.9991370418689945, 1.0},
105                               {0.0, 0.0, 0.0, 0.0, 0.0, 0.5042041932513908}};
106       Scalar igammac_s[][6] = {{nan, nan, nan, nan, nan, nan},
107                               {1.0, 0.36787944117144233, 0.22313016014842982,
108                                 0.018315638888734182, 0.9999000049998333, 0.0},
109                               {1.0, 0.5724067044708798, 0.3916251762710878,
110                                 0.04601170568923136, 0.9999992477923555, 0.0},
111                               {1.0, 0.9810118431238462, 0.9343575456215499,
112                                 0.4334701203667089, 1.0, 0.0},
113                               {1.0, 2.1940638138146658e-05, 1.0003291916285e-05,
114                                 3.7801620118431334e-07, 0.0008629581310054535,
115                                 0.0},
116                               {1.0, 1.0, 1.0, 1.0, 1.0, 0.49579580674813944}};
117       for (int i = 0; i < 6; ++i) {
118         for (int j = 0; j < 6; ++j) {
119           if ((std::isnan)(igamma_s[i][j])) {
120             VERIFY((std::isnan)(numext::igamma(a_s[i], x_s[j])));
121           } else {
122             VERIFY_IS_APPROX(numext::igamma(a_s[i], x_s[j]), igamma_s[i][j]);
123           }
124 
125           if ((std::isnan)(igammac_s[i][j])) {
126             VERIFY((std::isnan)(numext::igammac(a_s[i], x_s[j])));
127           } else {
128             VERIFY_IS_APPROX(numext::igammac(a_s[i], x_s[j]), igammac_s[i][j]);
129           }
130         }
131       }
132     }
133   }
134 #endif  // EIGEN_HAS_C99_MATH
135 
136   // Check the zeta function against scipy.special.zeta
137   {
138     ArrayType x(7), q(7), res(7), ref(7);
139     x << 1.5,   4, 10.5, 10000.5,    3, 1,        0.9;
140     q << 2,   1.5,    3,  1.0001, -2.5, 1.2345, 1.2345;
141     ref << 1.61237534869, 0.234848505667, 1.03086757337e-5, 0.367879440865, 0.054102025820864097, plusinf, nan;
142     CALL_SUBTEST( verify_component_wise(ref, ref); );
143     CALL_SUBTEST( res = x.zeta(q); verify_component_wise(res, ref); );
144     CALL_SUBTEST( res = zeta(x,q); verify_component_wise(res, ref); );
145   }
146 
147   // digamma
148   {
149     ArrayType x(7), res(7), ref(7);
150     x << 1, 1.5, 4, -10.5, 10000.5, 0, -1;
151     ref << -0.5772156649015329, 0.03648997397857645, 1.2561176684318, 2.398239129535781, 9.210340372392849, plusinf, plusinf;
152     CALL_SUBTEST( verify_component_wise(ref, ref); );
153 
154     CALL_SUBTEST( res = x.digamma(); verify_component_wise(res, ref); );
155     CALL_SUBTEST( res = digamma(x);  verify_component_wise(res, ref); );
156   }
157 
158 
159 #if EIGEN_HAS_C99_MATH
160   {
161     ArrayType n(11), x(11), res(11), ref(11);
162     n << 1, 1,    1, 1.5,   17,   31,   28,    8, 42, 147, 170;
163     x << 2, 3, 25.5, 1.5,  4.7, 11.8, 17.7, 30.2, 15.8, 54.1, 64;
164     ref << 0.644934066848, 0.394934066848, 0.0399946696496, nan, 293.334565435, 0.445487887616, -2.47810300902e-07, -8.29668781082e-09, -0.434562276666, 0.567742190178, -0.0108615497927;
165     CALL_SUBTEST( verify_component_wise(ref, ref); );
166 
167     if(sizeof(RealScalar)>=8) {  // double
168       // Reason for commented line: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1232
169       //       CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res, ref); );
170       CALL_SUBTEST( res = polygamma(n,x);  verify_component_wise(res, ref); );
171     }
172     else {
173       //       CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res.head(8), ref.head(8)); );
174       CALL_SUBTEST( res = polygamma(n,x); verify_component_wise(res.head(8), ref.head(8)); );
175     }
176   }
177 #endif
178 
179 #if EIGEN_HAS_C99_MATH
180   {
181     // Inputs and ground truth generated with scipy via:
182     //   a = np.logspace(-3, 3, 5) - 1e-3
183     //   b = np.logspace(-3, 3, 5) - 1e-3
184     //   x = np.linspace(-0.1, 1.1, 5)
185     //   (full_a, full_b, full_x) = np.vectorize(lambda a, b, x: (a, b, x))(*np.ix_(a, b, x))
186     //   full_a = full_a.flatten().tolist()  # same for full_b, full_x
187     //   v = scipy.special.betainc(full_a, full_b, full_x).flatten().tolist()
188     //
189     // Note in Eigen, we call betainc with arguments in the order (x, a, b).
190     ArrayType a(125);
191     ArrayType b(125);
192     ArrayType x(125);
193     ArrayType v(125);
194     ArrayType res(125);
195 
196     a << 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
197         0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
198         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
199         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
200         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
201         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
202         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
203         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
204         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
205         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
206         0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
207         0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
208         0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
209         31.62177660168379, 31.62177660168379, 31.62177660168379,
210         31.62177660168379, 31.62177660168379, 31.62177660168379,
211         31.62177660168379, 31.62177660168379, 31.62177660168379,
212         31.62177660168379, 31.62177660168379, 31.62177660168379,
213         31.62177660168379, 31.62177660168379, 31.62177660168379,
214         31.62177660168379, 31.62177660168379, 31.62177660168379,
215         31.62177660168379, 31.62177660168379, 31.62177660168379,
216         31.62177660168379, 31.62177660168379, 31.62177660168379,
217         31.62177660168379, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
218         999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
219         999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
220         999.999, 999.999, 999.999;
221 
222     b << 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379, 0.03062277660168379,
223         0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 0.999,
224         0.999, 0.999, 0.999, 0.999, 31.62177660168379, 31.62177660168379,
225         31.62177660168379, 31.62177660168379, 31.62177660168379, 999.999,
226         999.999, 999.999, 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0,
227         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
228         0.03062277660168379, 0.03062277660168379, 0.999, 0.999, 0.999, 0.999,
229         0.999, 31.62177660168379, 31.62177660168379, 31.62177660168379,
230         31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
231         999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
232         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
233         0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
234         31.62177660168379, 31.62177660168379, 31.62177660168379,
235         31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
236         999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
237         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
238         0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
239         31.62177660168379, 31.62177660168379, 31.62177660168379,
240         31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
241         999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
242         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
243         0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
244         31.62177660168379, 31.62177660168379, 31.62177660168379,
245         31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
246         999.999, 999.999;
247 
248     x << -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
249         0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
250         0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1,
251         0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1,
252         -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8,
253         1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
254         0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
255         0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1,
256         0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
257         0.8, 1.1;
258 
259     v << nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
260         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
261         nan, nan, nan, 0.47972119876364683, 0.5, 0.5202788012363533, nan, nan,
262         0.9518683957740043, 0.9789663010413743, 0.9931729188073435, nan, nan,
263         0.999995949033062, 0.9999999999993698, 0.9999999999999999, nan, nan,
264         0.9999999999999999, 0.9999999999999999, 0.9999999999999999, nan, nan,
265         nan, nan, nan, nan, nan, 0.006827081192655869, 0.0210336989586256,
266         0.04813160422599567, nan, nan, 0.20014344256217678, 0.5000000000000001,
267         0.7998565574378232, nan, nan, 0.9991401428435834, 0.999999999698403,
268         0.9999999999999999, nan, nan, 0.9999999999999999, 0.9999999999999999,
269         0.9999999999999999, nan, nan, nan, nan, nan, nan, nan,
270         1.0646600232370887e-25, 6.301722877826246e-13, 4.050966937974938e-06,
271         nan, nan, 7.864342668429763e-23, 3.015969667594166e-10,
272         0.0008598571564165444, nan, nan, 6.031987710123844e-08,
273         0.5000000000000007, 0.9999999396801229, nan, nan, 0.9999999999999999,
274         0.9999999999999999, 0.9999999999999999, nan, nan, nan, nan, nan, nan,
275         nan, 0.0, 7.029920380986636e-306, 2.2450728208591345e-101, nan, nan,
276         0.0, 9.275871147869727e-302, 1.2232913026152827e-97, nan, nan, 0.0,
277         3.0891393081932924e-252, 2.9303043666183996e-60, nan, nan,
278         2.248913486879199e-196, 0.5000000000004947, 0.9999999999999999, nan;
279 
280     CALL_SUBTEST(res = betainc(a, b, x);
281                  verify_component_wise(res, v););
282   }
283 
284   // Test various properties of betainc
285   {
286     ArrayType m1 = ArrayType::Random(32);
287     ArrayType m2 = ArrayType::Random(32);
288     ArrayType m3 = ArrayType::Random(32);
289     ArrayType one = ArrayType::Constant(32, Scalar(1.0));
290     const Scalar eps = std::numeric_limits<Scalar>::epsilon();
291     ArrayType a = (m1 * 4.0).exp();
292     ArrayType b = (m2 * 4.0).exp();
293     ArrayType x = m3.abs();
294 
295     // betainc(a, 1, x) == x**a
296     CALL_SUBTEST(
297         ArrayType test = betainc(a, one, x);
298         ArrayType expected = x.pow(a);
299         verify_component_wise(test, expected););
300 
301     // betainc(1, b, x) == 1 - (1 - x)**b
302     CALL_SUBTEST(
303         ArrayType test = betainc(one, b, x);
304         ArrayType expected = one - (one - x).pow(b);
305         verify_component_wise(test, expected););
306 
307     // betainc(a, b, x) == 1 - betainc(b, a, 1-x)
308     CALL_SUBTEST(
309         ArrayType test = betainc(a, b, x) + betainc(b, a, one - x);
310         ArrayType expected = one;
311         verify_component_wise(test, expected););
312 
313     // betainc(a+1, b, x) = betainc(a, b, x) - x**a * (1 - x)**b / (a * beta(a, b))
314     CALL_SUBTEST(
315         ArrayType num = x.pow(a) * (one - x).pow(b);
316         ArrayType denom = a * (a.lgamma() + b.lgamma() - (a + b).lgamma()).exp();
317         // Add eps to rhs and lhs so that component-wise test doesn't result in
318         // nans when both outputs are zeros.
319         ArrayType expected = betainc(a, b, x) - num / denom + eps;
320         ArrayType test = betainc(a + one, b, x) + eps;
321         if (sizeof(Scalar) >= 8) { // double
322           verify_component_wise(test, expected);
323         } else {
324           // Reason for limited test: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1232
325           verify_component_wise(test.head(8), expected.head(8));
326         });
327 
328     // betainc(a, b+1, x) = betainc(a, b, x) + x**a * (1 - x)**b / (b * beta(a, b))
329     CALL_SUBTEST(
330         // Add eps to rhs and lhs so that component-wise test doesn't result in
331         // nans when both outputs are zeros.
332         ArrayType num = x.pow(a) * (one - x).pow(b);
333         ArrayType denom = b * (a.lgamma() + b.lgamma() - (a + b).lgamma()).exp();
334         ArrayType expected = betainc(a, b, x) + num / denom + eps;
335         ArrayType test = betainc(a, b + one, x) + eps;
336         verify_component_wise(test, expected););
337   }
338 #endif
339 }
340 
test_special_functions()341 void test_special_functions()
342 {
343   CALL_SUBTEST_1(array_special_functions<ArrayXf>());
344   CALL_SUBTEST_2(array_special_functions<ArrayXd>());
345 }
346