1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2016 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10 #include "main.h"
11 #include "../Eigen/SpecialFunctions"
12
13 template<typename X, typename Y>
verify_component_wise(const X & x,const Y & y)14 void verify_component_wise(const X& x, const Y& y)
15 {
16 for(Index i=0; i<x.size(); ++i)
17 {
18 if((numext::isfinite)(y(i)))
19 VERIFY_IS_APPROX( x(i), y(i) );
20 else if((numext::isnan)(y(i)))
21 VERIFY((numext::isnan)(x(i)));
22 else
23 VERIFY_IS_EQUAL( x(i), y(i) );
24 }
25 }
26
array_special_functions()27 template<typename ArrayType> void array_special_functions()
28 {
29 using std::abs;
30 using std::sqrt;
31 typedef typename ArrayType::Scalar Scalar;
32 typedef typename NumTraits<Scalar>::Real RealScalar;
33
34 Scalar plusinf = std::numeric_limits<Scalar>::infinity();
35 Scalar nan = std::numeric_limits<Scalar>::quiet_NaN();
36
37 Index rows = internal::random<Index>(1,30);
38 Index cols = 1;
39
40 // API
41 {
42 ArrayType m1 = ArrayType::Random(rows,cols);
43 #if EIGEN_HAS_C99_MATH
44 VERIFY_IS_APPROX(m1.lgamma(), lgamma(m1));
45 VERIFY_IS_APPROX(m1.digamma(), digamma(m1));
46 VERIFY_IS_APPROX(m1.erf(), erf(m1));
47 VERIFY_IS_APPROX(m1.erfc(), erfc(m1));
48 #endif // EIGEN_HAS_C99_MATH
49 }
50
51
52 #if EIGEN_HAS_C99_MATH
53 // check special functions (comparing against numpy implementation)
54 if (!NumTraits<Scalar>::IsComplex)
55 {
56
57 {
58 ArrayType m1 = ArrayType::Random(rows,cols);
59 ArrayType m2 = ArrayType::Random(rows,cols);
60
61 // Test various propreties of igamma & igammac. These are normalized
62 // gamma integrals where
63 // igammac(a, x) = Gamma(a, x) / Gamma(a)
64 // igamma(a, x) = gamma(a, x) / Gamma(a)
65 // where Gamma and gamma are considered the standard unnormalized
66 // upper and lower incomplete gamma functions, respectively.
67 ArrayType a = m1.abs() + 2;
68 ArrayType x = m2.abs() + 2;
69 ArrayType zero = ArrayType::Zero(rows, cols);
70 ArrayType one = ArrayType::Constant(rows, cols, Scalar(1.0));
71 ArrayType a_m1 = a - one;
72 ArrayType Gamma_a_x = Eigen::igammac(a, x) * a.lgamma().exp();
73 ArrayType Gamma_a_m1_x = Eigen::igammac(a_m1, x) * a_m1.lgamma().exp();
74 ArrayType gamma_a_x = Eigen::igamma(a, x) * a.lgamma().exp();
75 ArrayType gamma_a_m1_x = Eigen::igamma(a_m1, x) * a_m1.lgamma().exp();
76
77 // Gamma(a, 0) == Gamma(a)
78 VERIFY_IS_APPROX(Eigen::igammac(a, zero), one);
79
80 // Gamma(a, x) + gamma(a, x) == Gamma(a)
81 VERIFY_IS_APPROX(Gamma_a_x + gamma_a_x, a.lgamma().exp());
82
83 // Gamma(a, x) == (a - 1) * Gamma(a-1, x) + x^(a-1) * exp(-x)
84 VERIFY_IS_APPROX(Gamma_a_x, (a - 1) * Gamma_a_m1_x + x.pow(a-1) * (-x).exp());
85
86 // gamma(a, x) == (a - 1) * gamma(a-1, x) - x^(a-1) * exp(-x)
87 VERIFY_IS_APPROX(gamma_a_x, (a - 1) * gamma_a_m1_x - x.pow(a-1) * (-x).exp());
88 }
89
90 {
91 // Check exact values of igamma and igammac against a third party calculation.
92 Scalar a_s[] = {Scalar(0), Scalar(1), Scalar(1.5), Scalar(4), Scalar(0.0001), Scalar(1000.5)};
93 Scalar x_s[] = {Scalar(0), Scalar(1), Scalar(1.5), Scalar(4), Scalar(0.0001), Scalar(1000.5)};
94
95 // location i*6+j corresponds to a_s[i], x_s[j].
96 Scalar igamma_s[][6] = {{0.0, nan, nan, nan, nan, nan},
97 {0.0, 0.6321205588285578, 0.7768698398515702,
98 0.9816843611112658, 9.999500016666262e-05, 1.0},
99 {0.0, 0.4275932955291202, 0.608374823728911,
100 0.9539882943107686, 7.522076445089201e-07, 1.0},
101 {0.0, 0.01898815687615381, 0.06564245437845008,
102 0.5665298796332909, 4.166333347221828e-18, 1.0},
103 {0.0, 0.9999780593618628, 0.9999899967080838,
104 0.9999996219837988, 0.9991370418689945, 1.0},
105 {0.0, 0.0, 0.0, 0.0, 0.0, 0.5042041932513908}};
106 Scalar igammac_s[][6] = {{nan, nan, nan, nan, nan, nan},
107 {1.0, 0.36787944117144233, 0.22313016014842982,
108 0.018315638888734182, 0.9999000049998333, 0.0},
109 {1.0, 0.5724067044708798, 0.3916251762710878,
110 0.04601170568923136, 0.9999992477923555, 0.0},
111 {1.0, 0.9810118431238462, 0.9343575456215499,
112 0.4334701203667089, 1.0, 0.0},
113 {1.0, 2.1940638138146658e-05, 1.0003291916285e-05,
114 3.7801620118431334e-07, 0.0008629581310054535,
115 0.0},
116 {1.0, 1.0, 1.0, 1.0, 1.0, 0.49579580674813944}};
117 for (int i = 0; i < 6; ++i) {
118 for (int j = 0; j < 6; ++j) {
119 if ((std::isnan)(igamma_s[i][j])) {
120 VERIFY((std::isnan)(numext::igamma(a_s[i], x_s[j])));
121 } else {
122 VERIFY_IS_APPROX(numext::igamma(a_s[i], x_s[j]), igamma_s[i][j]);
123 }
124
125 if ((std::isnan)(igammac_s[i][j])) {
126 VERIFY((std::isnan)(numext::igammac(a_s[i], x_s[j])));
127 } else {
128 VERIFY_IS_APPROX(numext::igammac(a_s[i], x_s[j]), igammac_s[i][j]);
129 }
130 }
131 }
132 }
133 }
134 #endif // EIGEN_HAS_C99_MATH
135
136 // Check the zeta function against scipy.special.zeta
137 {
138 ArrayType x(7), q(7), res(7), ref(7);
139 x << 1.5, 4, 10.5, 10000.5, 3, 1, 0.9;
140 q << 2, 1.5, 3, 1.0001, -2.5, 1.2345, 1.2345;
141 ref << 1.61237534869, 0.234848505667, 1.03086757337e-5, 0.367879440865, 0.054102025820864097, plusinf, nan;
142 CALL_SUBTEST( verify_component_wise(ref, ref); );
143 CALL_SUBTEST( res = x.zeta(q); verify_component_wise(res, ref); );
144 CALL_SUBTEST( res = zeta(x,q); verify_component_wise(res, ref); );
145 }
146
147 // digamma
148 {
149 ArrayType x(7), res(7), ref(7);
150 x << 1, 1.5, 4, -10.5, 10000.5, 0, -1;
151 ref << -0.5772156649015329, 0.03648997397857645, 1.2561176684318, 2.398239129535781, 9.210340372392849, plusinf, plusinf;
152 CALL_SUBTEST( verify_component_wise(ref, ref); );
153
154 CALL_SUBTEST( res = x.digamma(); verify_component_wise(res, ref); );
155 CALL_SUBTEST( res = digamma(x); verify_component_wise(res, ref); );
156 }
157
158
159 #if EIGEN_HAS_C99_MATH
160 {
161 ArrayType n(11), x(11), res(11), ref(11);
162 n << 1, 1, 1, 1.5, 17, 31, 28, 8, 42, 147, 170;
163 x << 2, 3, 25.5, 1.5, 4.7, 11.8, 17.7, 30.2, 15.8, 54.1, 64;
164 ref << 0.644934066848, 0.394934066848, 0.0399946696496, nan, 293.334565435, 0.445487887616, -2.47810300902e-07, -8.29668781082e-09, -0.434562276666, 0.567742190178, -0.0108615497927;
165 CALL_SUBTEST( verify_component_wise(ref, ref); );
166
167 if(sizeof(RealScalar)>=8) { // double
168 // Reason for commented line: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1232
169 // CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res, ref); );
170 CALL_SUBTEST( res = polygamma(n,x); verify_component_wise(res, ref); );
171 }
172 else {
173 // CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res.head(8), ref.head(8)); );
174 CALL_SUBTEST( res = polygamma(n,x); verify_component_wise(res.head(8), ref.head(8)); );
175 }
176 }
177 #endif
178
179 #if EIGEN_HAS_C99_MATH
180 {
181 // Inputs and ground truth generated with scipy via:
182 // a = np.logspace(-3, 3, 5) - 1e-3
183 // b = np.logspace(-3, 3, 5) - 1e-3
184 // x = np.linspace(-0.1, 1.1, 5)
185 // (full_a, full_b, full_x) = np.vectorize(lambda a, b, x: (a, b, x))(*np.ix_(a, b, x))
186 // full_a = full_a.flatten().tolist() # same for full_b, full_x
187 // v = scipy.special.betainc(full_a, full_b, full_x).flatten().tolist()
188 //
189 // Note in Eigen, we call betainc with arguments in the order (x, a, b).
190 ArrayType a(125);
191 ArrayType b(125);
192 ArrayType x(125);
193 ArrayType v(125);
194 ArrayType res(125);
195
196 a << 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
197 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
198 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
199 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
200 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
201 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
202 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
203 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
204 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
205 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
206 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
207 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
208 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
209 31.62177660168379, 31.62177660168379, 31.62177660168379,
210 31.62177660168379, 31.62177660168379, 31.62177660168379,
211 31.62177660168379, 31.62177660168379, 31.62177660168379,
212 31.62177660168379, 31.62177660168379, 31.62177660168379,
213 31.62177660168379, 31.62177660168379, 31.62177660168379,
214 31.62177660168379, 31.62177660168379, 31.62177660168379,
215 31.62177660168379, 31.62177660168379, 31.62177660168379,
216 31.62177660168379, 31.62177660168379, 31.62177660168379,
217 31.62177660168379, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
218 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
219 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
220 999.999, 999.999, 999.999;
221
222 b << 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379, 0.03062277660168379,
223 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 0.999,
224 0.999, 0.999, 0.999, 0.999, 31.62177660168379, 31.62177660168379,
225 31.62177660168379, 31.62177660168379, 31.62177660168379, 999.999,
226 999.999, 999.999, 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0,
227 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
228 0.03062277660168379, 0.03062277660168379, 0.999, 0.999, 0.999, 0.999,
229 0.999, 31.62177660168379, 31.62177660168379, 31.62177660168379,
230 31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
231 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
232 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
233 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
234 31.62177660168379, 31.62177660168379, 31.62177660168379,
235 31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
236 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
237 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
238 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
239 31.62177660168379, 31.62177660168379, 31.62177660168379,
240 31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
241 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
242 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
243 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
244 31.62177660168379, 31.62177660168379, 31.62177660168379,
245 31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
246 999.999, 999.999;
247
248 x << -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
249 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
250 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1,
251 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1,
252 -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8,
253 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
254 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
255 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1,
256 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
257 0.8, 1.1;
258
259 v << nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
260 nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
261 nan, nan, nan, 0.47972119876364683, 0.5, 0.5202788012363533, nan, nan,
262 0.9518683957740043, 0.9789663010413743, 0.9931729188073435, nan, nan,
263 0.999995949033062, 0.9999999999993698, 0.9999999999999999, nan, nan,
264 0.9999999999999999, 0.9999999999999999, 0.9999999999999999, nan, nan,
265 nan, nan, nan, nan, nan, 0.006827081192655869, 0.0210336989586256,
266 0.04813160422599567, nan, nan, 0.20014344256217678, 0.5000000000000001,
267 0.7998565574378232, nan, nan, 0.9991401428435834, 0.999999999698403,
268 0.9999999999999999, nan, nan, 0.9999999999999999, 0.9999999999999999,
269 0.9999999999999999, nan, nan, nan, nan, nan, nan, nan,
270 1.0646600232370887e-25, 6.301722877826246e-13, 4.050966937974938e-06,
271 nan, nan, 7.864342668429763e-23, 3.015969667594166e-10,
272 0.0008598571564165444, nan, nan, 6.031987710123844e-08,
273 0.5000000000000007, 0.9999999396801229, nan, nan, 0.9999999999999999,
274 0.9999999999999999, 0.9999999999999999, nan, nan, nan, nan, nan, nan,
275 nan, 0.0, 7.029920380986636e-306, 2.2450728208591345e-101, nan, nan,
276 0.0, 9.275871147869727e-302, 1.2232913026152827e-97, nan, nan, 0.0,
277 3.0891393081932924e-252, 2.9303043666183996e-60, nan, nan,
278 2.248913486879199e-196, 0.5000000000004947, 0.9999999999999999, nan;
279
280 CALL_SUBTEST(res = betainc(a, b, x);
281 verify_component_wise(res, v););
282 }
283
284 // Test various properties of betainc
285 {
286 ArrayType m1 = ArrayType::Random(32);
287 ArrayType m2 = ArrayType::Random(32);
288 ArrayType m3 = ArrayType::Random(32);
289 ArrayType one = ArrayType::Constant(32, Scalar(1.0));
290 const Scalar eps = std::numeric_limits<Scalar>::epsilon();
291 ArrayType a = (m1 * 4.0).exp();
292 ArrayType b = (m2 * 4.0).exp();
293 ArrayType x = m3.abs();
294
295 // betainc(a, 1, x) == x**a
296 CALL_SUBTEST(
297 ArrayType test = betainc(a, one, x);
298 ArrayType expected = x.pow(a);
299 verify_component_wise(test, expected););
300
301 // betainc(1, b, x) == 1 - (1 - x)**b
302 CALL_SUBTEST(
303 ArrayType test = betainc(one, b, x);
304 ArrayType expected = one - (one - x).pow(b);
305 verify_component_wise(test, expected););
306
307 // betainc(a, b, x) == 1 - betainc(b, a, 1-x)
308 CALL_SUBTEST(
309 ArrayType test = betainc(a, b, x) + betainc(b, a, one - x);
310 ArrayType expected = one;
311 verify_component_wise(test, expected););
312
313 // betainc(a+1, b, x) = betainc(a, b, x) - x**a * (1 - x)**b / (a * beta(a, b))
314 CALL_SUBTEST(
315 ArrayType num = x.pow(a) * (one - x).pow(b);
316 ArrayType denom = a * (a.lgamma() + b.lgamma() - (a + b).lgamma()).exp();
317 // Add eps to rhs and lhs so that component-wise test doesn't result in
318 // nans when both outputs are zeros.
319 ArrayType expected = betainc(a, b, x) - num / denom + eps;
320 ArrayType test = betainc(a + one, b, x) + eps;
321 if (sizeof(Scalar) >= 8) { // double
322 verify_component_wise(test, expected);
323 } else {
324 // Reason for limited test: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1232
325 verify_component_wise(test.head(8), expected.head(8));
326 });
327
328 // betainc(a, b+1, x) = betainc(a, b, x) + x**a * (1 - x)**b / (b * beta(a, b))
329 CALL_SUBTEST(
330 // Add eps to rhs and lhs so that component-wise test doesn't result in
331 // nans when both outputs are zeros.
332 ArrayType num = x.pow(a) * (one - x).pow(b);
333 ArrayType denom = b * (a.lgamma() + b.lgamma() - (a + b).lgamma()).exp();
334 ArrayType expected = betainc(a, b, x) + num / denom + eps;
335 ArrayType test = betainc(a, b + one, x) + eps;
336 verify_component_wise(test, expected););
337 }
338 #endif
339 }
340
test_special_functions()341 void test_special_functions()
342 {
343 CALL_SUBTEST_1(array_special_functions<ArrayXf>());
344 CALL_SUBTEST_2(array_special_functions<ArrayXd>());
345 }
346