# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Student t distribution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import importlib
import math

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import student_t
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def try_import(name):  # pylint: disable=invalid-name
  module = None
  try:
    module = importlib.import_module(name)
  except ImportError as e:
    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
  return module


stats = try_import("scipy.stats")


@test_util.run_all_in_graph_and_eager_modes
class StudentTTest(test.TestCase):

  def testStudentPDFAndLogPDF(self):
    batch_size = 6
    df = constant_op.constant([3.] * batch_size)
    mu = constant_op.constant([7.] * batch_size)
    sigma = constant_op.constant([8.] * batch_size)
    df_v = 3.
    mu_v = 7.
    sigma_v = 8.
    t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
    student = student_t.StudentT(df, loc=mu, scale=-sigma)

    log_pdf = student.log_prob(t)
    self.assertEqual(log_pdf.get_shape(), (6,))
    log_pdf_values = self.evaluate(log_pdf)
    pdf = student.prob(t)
    self.assertEqual(pdf.get_shape(), (6,))
    pdf_values = self.evaluate(pdf)

    if not stats:
      return

    expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
    expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
    self.assertAllClose(expected_log_pdf, log_pdf_values)
    self.assertAllClose(np.log(expected_pdf), log_pdf_values)
    self.assertAllClose(expected_pdf, pdf_values)
    self.assertAllClose(np.exp(expected_log_pdf), pdf_values)

  def testStudentLogPDFMultidimensional(self):
    batch_size = 6
    df = constant_op.constant([[1.5, 7.2]] * batch_size)
    mu = constant_op.constant([[3., -3.]] * batch_size)
    sigma = constant_op.constant(
        [[-math.sqrt(10.), math.sqrt(15.)]] * batch_size)
    df_v = np.array([1.5, 7.2])
    mu_v = np.array([3., -3.])
    sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)])
    t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T
    student = student_t.StudentT(df, loc=mu, scale=sigma)
    log_pdf = student.log_prob(t)
    log_pdf_values = self.evaluate(log_pdf)
    self.assertEqual(log_pdf.get_shape(), (6, 2))
    pdf = student.prob(t)
    pdf_values = self.evaluate(pdf)
    self.assertEqual(pdf.get_shape(), (6, 2))

    if not stats:
      return
    expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
    expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
    self.assertAllClose(expected_log_pdf, log_pdf_values)
    self.assertAllClose(np.log(expected_pdf), log_pdf_values)
    self.assertAllClose(expected_pdf, pdf_values)
    self.assertAllClose(np.exp(expected_log_pdf), pdf_values)

  def testStudentCDFAndLogCDF(self):
    batch_size = 6
    df = constant_op.constant([3.] * batch_size)
    mu = constant_op.constant([7.] * batch_size)
    sigma = constant_op.constant([-8.] * batch_size)
    df_v = 3.
    mu_v = 7.
    sigma_v = 8.
    t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
    student = student_t.StudentT(df, loc=mu, scale=sigma)

    log_cdf = student.log_cdf(t)
    self.assertEqual(log_cdf.get_shape(), (6,))
    log_cdf_values = self.evaluate(log_cdf)
    cdf = student.cdf(t)
    self.assertEqual(cdf.get_shape(), (6,))
    cdf_values = self.evaluate(cdf)

    if not stats:
      return
    expected_log_cdf = stats.t.logcdf(t, df_v, loc=mu_v, scale=sigma_v)
    expected_cdf = stats.t.cdf(t, df_v, loc=mu_v, scale=sigma_v)
    self.assertAllClose(expected_log_cdf, log_cdf_values, atol=0., rtol=1e-5)
    self.assertAllClose(
        np.log(expected_cdf), log_cdf_values, atol=0., rtol=1e-5)
    self.assertAllClose(expected_cdf, cdf_values, atol=0., rtol=1e-5)
    self.assertAllClose(
        np.exp(expected_log_cdf), cdf_values, atol=0., rtol=1e-5)

  def testStudentEntropy(self):
    df_v = np.array([[2., 3., 7.]])  # 1x3
    mu_v = np.array([[1., -1, 0]])  # 1x3
    sigma_v = np.array([[1., -2., 3.]]).T  # transposed => 3x1
    student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v)
    ent = student.entropy()
    ent_values = self.evaluate(ent)

    # Help scipy broadcast to 3x3
    ones = np.array([[1, 1, 1]])
    sigma_bc = np.abs(sigma_v) * ones
    mu_bc = ones.T * mu_v
    df_bc = ones.T * df_v
    if not stats:
      return
    expected_entropy = stats.t.entropy(
        np.reshape(df_bc, [-1]),
        loc=np.reshape(mu_bc, [-1]),
        scale=np.reshape(sigma_bc, [-1]))
    expected_entropy = np.reshape(expected_entropy, df_bc.shape)
    self.assertAllClose(expected_entropy, ent_values)

  def testStudentSample(self):
    df = constant_op.constant(4.)
    mu = constant_op.constant(3.)
    sigma = constant_op.constant(-math.sqrt(10.))
    df_v = 4.
    mu_v = 3.
    sigma_v = np.sqrt(10.)
    n = constant_op.constant(200000)
    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    samples = student.sample(n, seed=123456)
    sample_values = self.evaluate(samples)
    n_val = 200000
    self.assertEqual(sample_values.shape, (n_val,))
    self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0)
    self.assertAllClose(
        sample_values.var(), sigma_v**2 * df_v / (df_v - 2), rtol=0.1, atol=0)
    self._checkKLApprox(df_v, mu_v, sigma_v, sample_values)

  # Test that sampling with the same seed twice gives the same results.
  def testStudentSampleMultipleTimes(self):
    df = constant_op.constant(4.)
    mu = constant_op.constant(3.)
    sigma = constant_op.constant(math.sqrt(10.))
    n = constant_op.constant(100)

    random_seed.set_random_seed(654321)
    student = student_t.StudentT(df=df, loc=mu, scale=sigma, name="student_t1")
    samples1 = self.evaluate(student.sample(n, seed=123456))

    random_seed.set_random_seed(654321)
    student2 = student_t.StudentT(df=df, loc=mu, scale=sigma, name="student_t2")
    samples2 = self.evaluate(student2.sample(n, seed=123456))

    self.assertAllClose(samples1, samples2)

  def testStudentSampleSmallDfNoNan(self):
    df_v = [1e-1, 1e-5, 1e-10, 1e-20]
    df = constant_op.constant(df_v)
    n = constant_op.constant(200000)
    student = student_t.StudentT(df=df, loc=1., scale=1.)
    samples = student.sample(n, seed=123456)
    sample_values = self.evaluate(samples)
    n_val = 200000
    self.assertEqual(sample_values.shape, (n_val, 4))
    self.assertTrue(np.all(np.logical_not(np.isnan(sample_values))))

  def testStudentSampleMultiDimensional(self):
    batch_size = 7
    df = constant_op.constant([[5., 7.]] * batch_size)
    mu = constant_op.constant([[3., -3.]] * batch_size)
    sigma = constant_op.constant(
        [[math.sqrt(10.), math.sqrt(15.)]] * batch_size)
    df_v = [5., 7.]
    mu_v = [3., -3.]
    sigma_v = [np.sqrt(10.), np.sqrt(15.)]
    n = constant_op.constant(200000)
    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    samples = student.sample(n, seed=123456)
    sample_values = self.evaluate(samples)
    self.assertEqual(samples.get_shape(), (200000, batch_size, 2))
    self.assertAllClose(
        sample_values[:, 0, 0].mean(), mu_v[0], rtol=0.1, atol=0)
    self.assertAllClose(
        sample_values[:, 0, 0].var(),
        sigma_v[0]**2 * df_v[0] / (df_v[0] - 2),
        rtol=0.2,
        atol=0)
    self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0])
    self.assertAllClose(
        sample_values[:, 0, 1].mean(), mu_v[1], rtol=0.1, atol=0)
    self.assertAllClose(
        sample_values[:, 0, 1].var(),
        sigma_v[1]**2 * df_v[1] / (df_v[1] - 2),
        rtol=0.2,
        atol=0)
    self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1])

  def _checkKLApprox(self, df, mu, sigma, samples):
    # Compare a histogram of `samples` against a histogram of scipy draws
    # from the same Student t distribution via an approximate KL divergence.
    n = samples.size
    np.random.seed(137)
    if not stats:
      return
    sample_scipy = stats.t.rvs(df, loc=mu, scale=sigma, size=n)
    covg = 0.99
    r = stats.t.interval(covg, df, loc=mu, scale=sigma)
    bins = 100
    hist, _ = np.histogram(samples, bins=bins, range=r)
    hist_scipy, _ = np.histogram(sample_scipy, bins=bins, range=r)
    self.assertGreater(hist.sum(), n * (covg - .01))
    self.assertGreater(hist_scipy.sum(), n * (covg - .01))
    hist_min1 = hist + 1.  # put at least one item in each bucket
    hist_norm = hist_min1 / hist_min1.sum()
    hist_scipy_min1 = hist_scipy + 1.  # put at least one item in each bucket
    hist_scipy_norm = hist_scipy_min1 / hist_scipy_min1.sum()
    kl_appx = np.sum(np.log(hist_scipy_norm / hist_norm) * hist_scipy_norm)
    self.assertLess(kl_appx, 1)

  def testBroadcastingParams(self):

    def _check(student):
      self.assertEqual(student.mean().get_shape(), (3,))
      self.assertEqual(student.variance().get_shape(), (3,))
      self.assertEqual(student.entropy().get_shape(), (3,))
      self.assertEqual(student.log_prob(2.).get_shape(), (3,))
      self.assertEqual(student.prob(2.).get_shape(), (3,))
      self.assertEqual(student.sample(37).get_shape(), (37, 3,))

    _check(student_t.StudentT(df=[2., 3., 4.,], loc=2., scale=1.))
    _check(student_t.StudentT(df=7., loc=[2., 3., 4.,], scale=1.))
    _check(student_t.StudentT(df=7., loc=3., scale=[2., 3., 4.,]))

  def testBroadcastingPdfArgs(self):

    def _assert_shape(student, arg, shape):
      self.assertEqual(student.log_prob(arg).get_shape(), shape)
      self.assertEqual(student.prob(arg).get_shape(), shape)

    def _check(student):
      _assert_shape(student, 2., (3,))
      xs = np.array([2., 3., 4.], dtype=np.float32)
      _assert_shape(student, xs, (3,))
      xs = np.array([xs])
      _assert_shape(student, xs, (1, 3))
      xs = xs.T
      _assert_shape(student, xs, (3, 3))

    _check(student_t.StudentT(df=[2., 3., 4.,], loc=2., scale=1.))
    _check(student_t.StudentT(df=7., loc=[2., 3., 4.,], scale=1.))
    _check(student_t.StudentT(df=7., loc=3., scale=[2., 3., 4.,]))

    def _check2d(student):
      _assert_shape(student, 2., (1, 3))
      xs = np.array([2., 3., 4.], dtype=np.float32)
      _assert_shape(student, xs, (1, 3))
      xs = np.array([xs])
      _assert_shape(student, xs, (1, 3))
      xs = xs.T
      _assert_shape(student, xs, (3, 3))

    _check2d(student_t.StudentT(df=[[2., 3., 4.,]], loc=2., scale=1.))
    _check2d(student_t.StudentT(df=7., loc=[[2., 3., 4.,]], scale=1.))
    _check2d(student_t.StudentT(df=7., loc=3., scale=[[2., 3., 4.,]]))

    def _check2d_rows(student):
      _assert_shape(student, 2., (3, 1))
      xs = np.array([2., 3., 4.], dtype=np.float32)  # (3,)
      _assert_shape(student, xs, (3, 3))
      xs = np.array([xs])  # (1,3)
      _assert_shape(student, xs, (3, 3))
      xs = xs.T  # (3,1)
      _assert_shape(student, xs, (3, 1))

    _check2d_rows(student_t.StudentT(df=[[2.], [3.], [4.]], loc=2., scale=1.))
    _check2d_rows(student_t.StudentT(df=7., loc=[[2.], [3.], [4.]], scale=1.))
    _check2d_rows(student_t.StudentT(df=7., loc=3., scale=[[2.], [3.], [4.]]))

  def testMeanAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
    mu = [1., 3.3, 4.4]
    student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.])
    mean = self.evaluate(student.mean())
    self.assertAllClose([1., 3.3, 4.4], mean)

  def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self):
    mu = [1., 3.3, 4.4]
    student = student_t.StudentT(
        df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.], allow_nan_stats=False)
    with self.assertRaisesOpError("x < y"):
      self.evaluate(student.mean())

  def testMeanAllowNanStatsIsTrueReturnsNaNForUndefinedBatchMembers(self):
    mu = [-2, 0., 1., 3.3, 4.4]
    sigma = [5., 4., 3., 2., 1.]
    student = student_t.StudentT(
        df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma, allow_nan_stats=True)
    mean = self.evaluate(student.mean())
    self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean)

  def testVarianceAllowNanStatsTrueReturnsNaNforUndefinedBatchMembers(self):
    # df = 0.5 ==> undefined mean ==> undefined variance.
    # df = 1.5 ==> infinite variance.
    df = [0.5, 1.5, 3., 5., 7.]
    mu = [-2, 0., 1., 3.3, 4.4]
    sigma = [5., 4., 3., 2., 1.]
    student = student_t.StudentT(
        df=df, loc=mu, scale=sigma, allow_nan_stats=True)
    var = self.evaluate(student.variance())

    if not stats:
      return
    expected_var = [
        stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
    ]
    # Slicing off first element due to nan/inf mismatch in different SciPy
    # versions.
    self.assertAllClose(expected_var[1:], var[1:])

  def testVarianceAllowNanStatsFalseGivesCorrectValueForDefinedBatchMembers(
      self):
    # df = 1.5 ==> infinite variance.
    df = [1.5, 3., 5., 7.]
    mu = [0., 1., 3.3, 4.4]
    sigma = [4., 3., 2., 1.]
    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    var = self.evaluate(student.variance())

    if not stats:
      return
    expected_var = [
        stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
    ]
    self.assertAllClose(expected_var, var)

  def testVarianceAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
    # df <= 1 ==> variance not defined
    student = student_t.StudentT(df=1., loc=0., scale=1., allow_nan_stats=False)
    with self.assertRaisesOpError("x < y"):
      self.evaluate(student.variance())

    # df <= 1 ==> variance not defined
    student = student_t.StudentT(
        df=0.5, loc=0., scale=1., allow_nan_stats=False)
    with self.assertRaisesOpError("x < y"):
      self.evaluate(student.variance())

  def testStd(self):
    # Defined for all batch members.
    df = [3.5, 5., 3., 5., 7.]
    mu = [-2.2]
    sigma = [5., 4., 3., 2., 1.]
    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    # Test broadcast of mu across shape of df/sigma
    stddev = self.evaluate(student.stddev())
    mu *= len(df)

    if not stats:
      return
    expected_stddev = [
        stats.t.std(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
    ]
    self.assertAllClose(expected_stddev, stddev)

  def testMode(self):
    df = [0.5, 1., 3]
    mu = [-1, 0., 1]
    sigma = [5., 4., 3.]
    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    mode = self.evaluate(student.mode())
    self.assertAllClose([-1., 0, 1], mode)

  def testPdfOfSample(self):
    student = student_t.StudentT(df=3., loc=np.pi, scale=1.)
    num = 20000
    samples = student.sample(num, seed=123456)
    pdfs = student.prob(samples)
    mean = student.mean()
    mean_pdf = student.prob(student.mean())
    sample_vals, pdf_vals, mean_val, mean_pdf_val = self.evaluate(
        [samples, pdfs, student.mean(), mean_pdf])
    self.assertEqual(samples.get_shape(), (num,))
    self.assertEqual(pdfs.get_shape(), (num,))
    self.assertEqual(mean.get_shape(), ())
    self.assertNear(np.pi, np.mean(sample_vals), err=0.1)
    self.assertNear(np.pi, mean_val, err=1e-6)
    # Verify integral over sample*pdf ~= 1.
    # Tolerance increased since eager was getting a value of 1.002041.
    self._assertIntegral(sample_vals, pdf_vals, err=5e-2)
    if not stats:
      return
    self.assertNear(stats.t.pdf(np.pi, 3., loc=np.pi), mean_pdf_val, err=1e-6)

  def testFullyReparameterized(self):
    df = constant_op.constant(2.0)
    mu = constant_op.constant(1.0)
    sigma = constant_op.constant(3.0)
    with backprop.GradientTape() as tape:
      tape.watch(df)
      tape.watch(mu)
      tape.watch(sigma)
      student = student_t.StudentT(df=df, loc=mu, scale=sigma)
      samples = student.sample(100)
    grad_df, grad_mu, grad_sigma = tape.gradient(samples, [df, mu, sigma])
    self.assertIsNotNone(grad_df)
    self.assertIsNotNone(grad_mu)
    self.assertIsNotNone(grad_sigma)

  def testPdfOfSampleMultiDims(self):
    student = student_t.StudentT(df=[7., 11.], loc=[[5.], [6.]], scale=3.)
    self.assertAllEqual([], student.event_shape)
    self.assertAllEqual([], self.evaluate(student.event_shape_tensor()))
    self.assertAllEqual([2, 2], student.batch_shape)
    self.assertAllEqual([2, 2], self.evaluate(student.batch_shape_tensor()))
    num = 50000
    samples = student.sample(num, seed=123456)
    pdfs = student.prob(samples)
    sample_vals, pdf_vals = self.evaluate([samples, pdfs])
    self.assertEqual(samples.get_shape(), (num, 2, 2))
    self.assertEqual(pdfs.get_shape(), (num, 2, 2))
    self.assertNear(5., np.mean(sample_vals[:, 0, :]), err=0.1)
    self.assertNear(6., np.mean(sample_vals[:, 1, :]), err=0.1)
    self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.05)
    self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.05)
    self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.05)
    self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.05)
    if not stats:
      return
    self.assertNear(
        stats.t.var(7., loc=0., scale=3.),  # loc does not affect var
        np.var(sample_vals[:, :, 0]),
        err=1.0)
    self.assertNear(
        stats.t.var(11., loc=0., scale=3.),  # loc does not affect var
        np.var(sample_vals[:, :, 1]),
        err=1.0)

  def _assertIntegral(self, sample_vals, pdf_vals, err=1.5e-3):
    # Approximate the integral of the pdf over the sampled range with a
    # trapezoidal sum over the sorted samples; it should be close to 1.
    s_p = zip(sample_vals, pdf_vals)
    prev = (sample_vals.min() - 1000, 0)
    total = 0
    for k in sorted(s_p, key=lambda x: x[0]):
      pair_pdf = (k[1] + prev[1]) / 2
      total += (k[0] - prev[0]) * pair_pdf
      prev = k
    self.assertNear(1., total, err=err)

  def testNegativeDofFails(self):
    with self.assertRaisesOpError(r"Condition x > 0 did not hold"):
      student = student_t.StudentT(
          df=[2, -5.], loc=0., scale=1., validate_args=True, name="S")
      self.evaluate(student.mean())

  def testStudentTWithAbsDfSoftplusScale(self):
    df = constant_op.constant([-3.2, -4.6])
    mu = constant_op.constant([-4.2, 3.4])
    sigma = constant_op.constant([-6.4, -8.8])
    student = student_t.StudentTWithAbsDfSoftplusScale(
        df=df, loc=mu, scale=sigma)
    self.assertAllClose(
        math_ops.floor(self.evaluate(math_ops.abs(df))),
        self.evaluate(student.df))
    self.assertAllClose(self.evaluate(mu), self.evaluate(student.loc))
    self.assertAllClose(
        self.evaluate(nn_ops.softplus(sigma)), self.evaluate(student.scale))


if __name__ == "__main__":
  test.main()