# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Student t distribution."""

import importlib
import math

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import student_t
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def try_import(name):  # pylint: disable=invalid-name
  module = None
  try:
    module = importlib.import_module(name)
  except ImportError as e:
    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
  return module


stats = try_import("scipy.stats")


@test_util.run_all_in_graph_and_eager_modes
class StudentTTest(test.TestCase):

  def testStudentPDFAndLogPDF(self):
    batch_size = 6
    df = constant_op.constant([3.] * batch_size)
    mu = constant_op.constant([7.] * batch_size)
    sigma = constant_op.constant([8.] * batch_size)
    df_v = 3.
    mu_v = 7.
    sigma_v = 8.
    t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
    student = student_t.StudentT(df, loc=mu, scale=-sigma)  # pylint: disable=invalid-unary-operand-type

    log_pdf = student.log_prob(t)
    self.assertEqual(log_pdf.get_shape(), (6,))
    log_pdf_values = self.evaluate(log_pdf)
    pdf = student.prob(t)
    self.assertEqual(pdf.get_shape(), (6,))
    pdf_values = self.evaluate(pdf)

    if not stats:
      return

    expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
    expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
    self.assertAllClose(expected_log_pdf, log_pdf_values)
    self.assertAllClose(np.log(expected_pdf), log_pdf_values)
    self.assertAllClose(expected_pdf, pdf_values)
    self.assertAllClose(np.exp(expected_log_pdf), pdf_values)

  def testStudentLogPDFMultidimensional(self):
    batch_size = 6
    df = constant_op.constant([[1.5, 7.2]] * batch_size)
    mu = constant_op.constant([[3., -3.]] * batch_size)
    sigma = constant_op.constant(
        [[-math.sqrt(10.), math.sqrt(15.)]] * batch_size)
    df_v = np.array([1.5, 7.2])
    mu_v = np.array([3., -3.])
    sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)])
    t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T
    student = student_t.StudentT(df, loc=mu, scale=sigma)
    log_pdf = student.log_prob(t)
    log_pdf_values = self.evaluate(log_pdf)
    self.assertEqual(log_pdf.get_shape(), (6, 2))
    pdf = student.prob(t)
    pdf_values = self.evaluate(pdf)
    self.assertEqual(pdf.get_shape(), (6, 2))

    if not stats:
      return
    expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
    expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
    self.assertAllClose(expected_log_pdf, log_pdf_values)
    self.assertAllClose(np.log(expected_pdf), log_pdf_values)
    self.assertAllClose(expected_pdf, pdf_values)
    self.assertAllClose(np.exp(expected_log_pdf), pdf_values)

  def testStudentCDFAndLogCDF(self):
    batch_size = 6
    df = constant_op.constant([3.] * batch_size)
    mu = constant_op.constant([7.] * batch_size)
    sigma = constant_op.constant([-8.] * batch_size)
    df_v = 3.
    mu_v = 7.
    sigma_v = 8.
    t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
    student = student_t.StudentT(df, loc=mu, scale=sigma)

    log_cdf = student.log_cdf(t)
    self.assertEqual(log_cdf.get_shape(), (6,))
    log_cdf_values = self.evaluate(log_cdf)
    cdf = student.cdf(t)
    self.assertEqual(cdf.get_shape(), (6,))
    cdf_values = self.evaluate(cdf)

    if not stats:
      return
    expected_log_cdf = stats.t.logcdf(t, df_v, loc=mu_v, scale=sigma_v)
    expected_cdf = stats.t.cdf(t, df_v, loc=mu_v, scale=sigma_v)
    self.assertAllClose(expected_log_cdf, log_cdf_values, atol=0., rtol=1e-5)
    self.assertAllClose(
        np.log(expected_cdf), log_cdf_values, atol=0., rtol=1e-5)
    self.assertAllClose(expected_cdf, cdf_values, atol=0., rtol=1e-5)
    self.assertAllClose(
        np.exp(expected_log_cdf), cdf_values, atol=0., rtol=1e-5)

  def testStudentEntropy(self):
    df_v = np.array([[2., 3., 7.]])  # 1x3
    mu_v = np.array([[1., -1, 0]])  # 1x3
    sigma_v = np.array([[1., -2., 3.]]).T  # transposed => 3x1
    student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v)
    ent = student.entropy()
    ent_values = self.evaluate(ent)

    # Help scipy broadcast to 3x3
    ones = np.array([[1, 1, 1]])
    sigma_bc = np.abs(sigma_v) * ones
    mu_bc = ones.T * mu_v
    df_bc = ones.T * df_v
    if not stats:
      return
    expected_entropy = stats.t.entropy(
        np.reshape(df_bc, [-1]),
        loc=np.reshape(mu_bc, [-1]),
        scale=np.reshape(sigma_bc, [-1]))
    expected_entropy = np.reshape(expected_entropy, df_bc.shape)
    self.assertAllClose(expected_entropy, ent_values)

  def testStudentSample(self):
    df = constant_op.constant(4.)
    mu = constant_op.constant(3.)
    sigma = constant_op.constant(-math.sqrt(10.))
    df_v = 4.
    mu_v = 3.
    sigma_v = np.sqrt(10.)
    n = constant_op.constant(200000)
    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    samples = student.sample(n, seed=123456)
    sample_values = self.evaluate(samples)
    n_val = 200000
    self.assertEqual(sample_values.shape, (n_val,))
    self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0)
    self.assertAllClose(
        sample_values.var(), sigma_v**2 * df_v / (df_v - 2), rtol=0.1, atol=0)
    self._checkKLApprox(df_v, mu_v, sigma_v, sample_values)

  # Test that sampling with the same seed twice gives the same results.
  def testStudentSampleMultipleTimes(self):
    df = constant_op.constant(4.)
    mu = constant_op.constant(3.)
    sigma = constant_op.constant(math.sqrt(10.))
    n = constant_op.constant(100)

    random_seed.set_random_seed(654321)
    student = student_t.StudentT(df=df, loc=mu, scale=sigma, name="student_t1")
    samples1 = self.evaluate(student.sample(n, seed=123456))

    random_seed.set_random_seed(654321)
    student2 = student_t.StudentT(df=df, loc=mu, scale=sigma, name="student_t2")
    samples2 = self.evaluate(student2.sample(n, seed=123456))

    self.assertAllClose(samples1, samples2)

  def testStudentSampleSmallDfNoNan(self):
    df_v = [1e-1, 1e-5, 1e-10, 1e-20]
    df = constant_op.constant(df_v)
    n = constant_op.constant(200000)
    student = student_t.StudentT(df=df, loc=1., scale=1.)
    samples = student.sample(n, seed=123456)
    sample_values = self.evaluate(samples)
    n_val = 200000
    self.assertEqual(sample_values.shape, (n_val, 4))
    self.assertTrue(np.all(np.logical_not(np.isnan(sample_values))))

  def testStudentSampleMultiDimensional(self):
    batch_size = 7
    df = constant_op.constant([[5., 7.]] * batch_size)
    mu = constant_op.constant([[3., -3.]] * batch_size)
    sigma = constant_op.constant(
        [[math.sqrt(10.), math.sqrt(15.)]] * batch_size)
    df_v = [5., 7.]
    mu_v = [3., -3.]
    sigma_v = [np.sqrt(10.), np.sqrt(15.)]
    n = constant_op.constant(200000)
    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    samples = student.sample(n, seed=123456)
    sample_values = self.evaluate(samples)
    self.assertEqual(samples.get_shape(), (200000, batch_size, 2))
    self.assertAllClose(
        sample_values[:, 0, 0].mean(), mu_v[0], rtol=0.1, atol=0)
    self.assertAllClose(
        sample_values[:, 0, 0].var(),
        sigma_v[0]**2 * df_v[0] / (df_v[0] - 2),
        rtol=0.2,
        atol=0)
    self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0])
    self.assertAllClose(
        sample_values[:, 0, 1].mean(), mu_v[1], rtol=0.1, atol=0)
    self.assertAllClose(
        sample_values[:, 0, 1].var(),
        sigma_v[1]**2 * df_v[1] / (df_v[1] - 2),
        rtol=0.2,
        atol=0)
    self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1])

  def _checkKLApprox(self, df, mu, sigma, samples):
    # Compare a histogram of `samples` against a histogram of scipy-drawn
    # samples via an approximate KL divergence over the central 99% interval.
    n = samples.size
    np.random.seed(137)
    if not stats:
      return
    sample_scipy = stats.t.rvs(df, loc=mu, scale=sigma, size=n)
    covg = 0.99
    r = stats.t.interval(covg, df, loc=mu, scale=sigma)
    bins = 100
    hist, _ = np.histogram(samples, bins=bins, range=r)
    hist_scipy, _ = np.histogram(sample_scipy, bins=bins, range=r)
    self.assertGreater(hist.sum(), n * (covg - .01))
    self.assertGreater(hist_scipy.sum(), n * (covg - .01))
    hist_min1 = hist + 1.  # put at least one item in each bucket
    hist_norm = hist_min1 / hist_min1.sum()
    hist_scipy_min1 = hist_scipy + 1.  # put at least one item in each bucket
    hist_scipy_norm = hist_scipy_min1 / hist_scipy_min1.sum()
    kl_appx = np.sum(np.log(hist_scipy_norm / hist_norm) * hist_scipy_norm)
    self.assertLess(kl_appx, 1)

  def testBroadcastingParams(self):

    def _check(student):
      self.assertEqual(student.mean().get_shape(), (3,))
      self.assertEqual(student.variance().get_shape(), (3,))
      self.assertEqual(student.entropy().get_shape(), (3,))
      self.assertEqual(student.log_prob(2.).get_shape(), (3,))
      self.assertEqual(student.prob(2.).get_shape(), (3,))
      self.assertEqual(student.sample(37).get_shape(), (37, 3,))

    _check(student_t.StudentT(df=[2., 3., 4.,], loc=2., scale=1.))
    _check(student_t.StudentT(df=7., loc=[2., 3., 4.,], scale=1.))
    _check(student_t.StudentT(df=7., loc=3., scale=[2., 3., 4.,]))

  def testBroadcastingPdfArgs(self):

    def _assert_shape(student, arg, shape):
      self.assertEqual(student.log_prob(arg).get_shape(), shape)
      self.assertEqual(student.prob(arg).get_shape(), shape)

    def _check(student):
      _assert_shape(student, 2., (3,))
      xs = np.array([2., 3., 4.], dtype=np.float32)
      _assert_shape(student, xs, (3,))
      xs = np.array([xs])
      _assert_shape(student, xs, (1, 3))
      xs = xs.T
      _assert_shape(student, xs, (3, 3))

    _check(student_t.StudentT(df=[2., 3., 4.,], loc=2., scale=1.))
    _check(student_t.StudentT(df=7., loc=[2., 3., 4.,], scale=1.))
    _check(student_t.StudentT(df=7., loc=3., scale=[2., 3., 4.,]))

    def _check2d(student):
      _assert_shape(student, 2., (1, 3))
      xs = np.array([2., 3., 4.], dtype=np.float32)
      _assert_shape(student, xs, (1, 3))
      xs = np.array([xs])
      _assert_shape(student, xs, (1, 3))
      xs = xs.T
      _assert_shape(student, xs, (3, 3))

    _check2d(student_t.StudentT(df=[[2., 3., 4.,]], loc=2., scale=1.))
    _check2d(student_t.StudentT(df=7., loc=[[2., 3., 4.,]], scale=1.))
    _check2d(student_t.StudentT(df=7., loc=3., scale=[[2., 3., 4.,]]))

    def _check2d_rows(student):
      _assert_shape(student, 2., (3, 1))
      xs = np.array([2., 3., 4.], dtype=np.float32)  # (3,)
      _assert_shape(student, xs, (3, 3))
      xs = np.array([xs])  # (1,3)
      _assert_shape(student, xs, (3, 3))
      xs = xs.T  # (3,1)
      _assert_shape(student, xs, (3, 1))

    _check2d_rows(student_t.StudentT(df=[[2.], [3.], [4.]], loc=2., scale=1.))
    _check2d_rows(student_t.StudentT(df=7., loc=[[2.], [3.], [4.]], scale=1.))
    _check2d_rows(student_t.StudentT(df=7., loc=3., scale=[[2.], [3.], [4.]]))

  def testMeanAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
    mu = [1., 3.3, 4.4]
    student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.])
    mean = self.evaluate(student.mean())
    self.assertAllClose([1., 3.3, 4.4], mean)

  def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self):
    mu = [1., 3.3, 4.4]
    student = student_t.StudentT(
        df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.], allow_nan_stats=False)
    with self.assertRaisesOpError("x < y"):
      self.evaluate(student.mean())

  def testMeanAllowNanStatsIsTrueReturnsNaNForUndefinedBatchMembers(self):
    mu = [-2, 0., 1., 3.3, 4.4]
    sigma = [5., 4., 3., 2., 1.]
    student = student_t.StudentT(
        df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma, allow_nan_stats=True)
    mean = self.evaluate(student.mean())
    self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean)

  def testVarianceAllowNanStatsTrueReturnsNaNforUndefinedBatchMembers(self):
    # df = 0.5 ==> undefined mean ==> undefined variance.
    # df = 1.5 ==> infinite variance.
    df = [0.5, 1.5, 3., 5., 7.]
    mu = [-2, 0., 1., 3.3, 4.4]
    sigma = [5., 4., 3., 2., 1.]
    student = student_t.StudentT(
        df=df, loc=mu, scale=sigma, allow_nan_stats=True)
    var = self.evaluate(student.variance())

    if not stats:
      return
    expected_var = [
        stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
    ]
    # Slicing off first element due to nan/inf mismatch in different SciPy
    # versions.
    self.assertAllClose(expected_var[1:], var[1:])

  def testVarianceAllowNanStatsFalseGivesCorrectValueForDefinedBatchMembers(
      self):
    # df = 1.5 ==> infinite variance.
    df = [1.5, 3., 5., 7.]
    mu = [0., 1., 3.3, 4.4]
    sigma = [4., 3., 2., 1.]
    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    var = self.evaluate(student.variance())

    if not stats:
      return
    expected_var = [
        stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
    ]
    self.assertAllClose(expected_var, var)

  def testVarianceAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
    # df <= 1 ==> variance not defined
    student = student_t.StudentT(df=1., loc=0., scale=1., allow_nan_stats=False)
    with self.assertRaisesOpError("x < y"):
      self.evaluate(student.variance())

    # df <= 1 ==> variance not defined
    student = student_t.StudentT(
        df=0.5, loc=0., scale=1., allow_nan_stats=False)
    with self.assertRaisesOpError("x < y"):
      self.evaluate(student.variance())

  def testStd(self):
    # Defined for all batch members.
    df = [3.5, 5., 3., 5., 7.]
    mu = [-2.2]
    sigma = [5., 4., 3., 2., 1.]
    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    # Test broadcast of mu across shape of df/sigma
    stddev = self.evaluate(student.stddev())
    mu *= len(df)

    if not stats:
      return
    expected_stddev = [
        stats.t.std(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
    ]
    self.assertAllClose(expected_stddev, stddev)

  def testMode(self):
    df = [0.5, 1., 3]
    mu = [-1, 0., 1]
    sigma = [5., 4., 3.]
    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    # Test broadcast of mu across shape of df/sigma
    mode = self.evaluate(student.mode())
    self.assertAllClose([-1., 0, 1], mode)

  def testPdfOfSample(self):
    student = student_t.StudentT(df=3., loc=np.pi, scale=1.)
    num = 20000
    samples = student.sample(num, seed=123456)
    pdfs = student.prob(samples)
    mean = student.mean()
    mean_pdf = student.prob(student.mean())
    sample_vals, pdf_vals, mean_val, mean_pdf_val = self.evaluate(
        [samples, pdfs, student.mean(), mean_pdf])
    self.assertEqual(samples.get_shape(), (num,))
    self.assertEqual(pdfs.get_shape(), (num,))
    self.assertEqual(mean.get_shape(), ())
    self.assertNear(np.pi, np.mean(sample_vals), err=0.1)
    self.assertNear(np.pi, mean_val, err=1e-6)
    # Verify integral over sample*pdf ~= 1.
    # Tolerance increased since eager was getting a value of 1.002041.
    self._assertIntegral(sample_vals, pdf_vals, err=5e-2)
    if not stats:
      return
    self.assertNear(stats.t.pdf(np.pi, 3., loc=np.pi), mean_pdf_val, err=1e-6)

  def testFullyReparameterized(self):
    df = constant_op.constant(2.0)
    mu = constant_op.constant(1.0)
    sigma = constant_op.constant(3.0)
    with backprop.GradientTape() as tape:
      tape.watch(df)
      tape.watch(mu)
      tape.watch(sigma)
      student = student_t.StudentT(df=df, loc=mu, scale=sigma)
      samples = student.sample(100)
    grad_df, grad_mu, grad_sigma = tape.gradient(samples, [df, mu, sigma])
    self.assertIsNotNone(grad_df)
    self.assertIsNotNone(grad_mu)
    self.assertIsNotNone(grad_sigma)

  def testPdfOfSampleMultiDims(self):
    student = student_t.StudentT(df=[7., 11.], loc=[[5.], [6.]], scale=3.)
    self.assertAllEqual([], student.event_shape)
    self.assertAllEqual([], self.evaluate(student.event_shape_tensor()))
    self.assertAllEqual([2, 2], student.batch_shape)
    self.assertAllEqual([2, 2], self.evaluate(student.batch_shape_tensor()))
    num = 50000
    samples = student.sample(num, seed=123456)
    pdfs = student.prob(samples)
    sample_vals, pdf_vals = self.evaluate([samples, pdfs])
    self.assertEqual(samples.get_shape(), (num, 2, 2))
    self.assertEqual(pdfs.get_shape(), (num, 2, 2))
    self.assertNear(5., np.mean(sample_vals[:, 0, :]), err=0.1)
    self.assertNear(6., np.mean(sample_vals[:, 1, :]), err=0.1)
    self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.05)
    self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.05)
    self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.05)
    self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.05)
    if not stats:
      return
    self.assertNear(
        stats.t.var(7., loc=0., scale=3.),  # loc does not affect var
        np.var(sample_vals[:, :, 0]),
        err=1.0)
    self.assertNear(
        stats.t.var(11., loc=0., scale=3.),  # loc does not affect var
        np.var(sample_vals[:, :, 1]),
        err=1.0)

  def _assertIntegral(self, sample_vals, pdf_vals, err=1.5e-3):
    # Trapezoidal approximation of the integral of the pdf over the sampled
    # points; the total should be close to 1.
    s_p = zip(sample_vals, pdf_vals)
    prev = (sample_vals.min() - 1000, 0)
    total = 0
    for k in sorted(s_p, key=lambda x: x[0]):
      pair_pdf = (k[1] + prev[1]) / 2
      total += (k[0] - prev[0]) * pair_pdf
      prev = k
    self.assertNear(1., total, err=err)

  def testNegativeDofFails(self):
    with self.assertRaisesOpError(r"Condition x > 0 did not hold"):
      student = student_t.StudentT(
          df=[2, -5.], loc=0., scale=1., validate_args=True, name="S")
      self.evaluate(student.mean())

  def testStudentTWithAbsDfSoftplusScale(self):
    df = constant_op.constant([-3.2, -4.6])
    mu = constant_op.constant([-4.2, 3.4])
    sigma = constant_op.constant([-6.4, -8.8])
    student = student_t.StudentTWithAbsDfSoftplusScale(
        df=df, loc=mu, scale=sigma)
    self.assertAllClose(
        math_ops.floor(self.evaluate(math_ops.abs(df))),
        self.evaluate(student.df))
    self.assertAllClose(self.evaluate(mu), self.evaluate(student.loc))
    self.assertAllClose(
        self.evaluate(nn_ops.softplus(sigma)), self.evaluate(student.scale))


if __name__ == "__main__":
  test.main()