1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15from __future__ import absolute_import 16from __future__ import division 17from __future__ import print_function 18 19import numpy as np 20 21from tensorflow.python.eager import backprop 22from tensorflow.python.framework import constant_op 23from tensorflow.python.framework import dtypes 24from tensorflow.python.framework import tensor_shape 25from tensorflow.python.framework import test_util 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import math_ops 28from tensorflow.python.ops.distributions import dirichlet_multinomial 29from tensorflow.python.platform import test 30 31 32ds = dirichlet_multinomial 33 34 35class DirichletMultinomialTest(test.TestCase): 36 37 def setUp(self): 38 self._rng = np.random.RandomState(42) 39 40 @test_util.run_deprecated_v1 41 def testSimpleShapes(self): 42 with self.cached_session(): 43 alpha = np.random.rand(3) 44 dist = ds.DirichletMultinomial(1., alpha) 45 self.assertEqual(3, dist.event_shape_tensor().eval()) 46 self.assertAllEqual([], dist.batch_shape_tensor().eval()) 47 self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape) 48 self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape) 49 50 @test_util.run_deprecated_v1 51 def testComplexShapes(self): 52 with self.cached_session(): 53 alpha = np.random.rand(3, 2, 2) 54 n = [[3., 2], [4, 5], [6, 7]] 55 dist = ds.DirichletMultinomial(n, alpha) 56 self.assertEqual(2, dist.event_shape_tensor().eval()) 57 self.assertAllEqual([3, 2], dist.batch_shape_tensor().eval()) 58 self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape) 59 self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape) 60 61 @test_util.run_deprecated_v1 62 def testNproperty(self): 63 alpha = [[1., 2, 3]] 64 n = [[5.]] 65 with self.cached_session(): 66 dist = ds.DirichletMultinomial(n, alpha) 67 self.assertEqual([1, 1], dist.total_count.get_shape()) 68 self.assertAllClose(n, dist.total_count.eval()) 69 70 @test_util.run_deprecated_v1 71 def testAlphaProperty(self): 72 alpha = [[1., 2, 3]] 73 with self.cached_session(): 74 dist = ds.DirichletMultinomial(1, alpha) 75 self.assertEqual([1, 3], dist.concentration.get_shape()) 76 self.assertAllClose(alpha, dist.concentration.eval()) 77 78 @test_util.run_deprecated_v1 79 def testPmfNandCountsAgree(self): 80 alpha = [[1., 2, 3]] 81 n = [[5.]] 82 with self.cached_session(): 83 dist = ds.DirichletMultinomial(n, alpha, validate_args=True) 84 dist.prob([2., 3, 0]).eval() 85 dist.prob([3., 0, 2]).eval() 86 with self.assertRaisesOpError("must be non-negative"): 87 dist.prob([-1., 4, 2]).eval() 88 with self.assertRaisesOpError( 89 "last-dimension must sum to `self.total_count`"): 90 dist.prob([3., 3, 0]).eval() 91 92 @test_util.run_deprecated_v1 93 def testPmfNonIntegerCounts(self): 94 alpha = [[1., 2, 3]] 95 n = [[5.]] 96 with self.cached_session(): 97 dist = ds.DirichletMultinomial(n, alpha, validate_args=True) 98 dist.prob([2., 3, 0]).eval() 99 dist.prob([3., 0, 2]).eval() 100 dist.prob([3.0, 0, 2.0]).eval() 101 # Both equality and integer checking fail. 102 placeholder = array_ops.placeholder(dtypes.float32) 103 with self.assertRaisesOpError( 104 "cannot contain fractional components"): 105 dist.prob(placeholder).eval(feed_dict={placeholder: [1.0, 2.5, 1.5]}) 106 dist = ds.DirichletMultinomial(n, alpha, validate_args=False) 107 dist.prob([1., 2., 3.]).eval() 108 # Non-integer arguments work. 109 dist.prob([1.0, 2.5, 1.5]).eval() 110 111 def testPmfBothZeroBatches(self): 112 # The probabilities of one vote falling into class k is the mean for class 113 # k. 114 with self.cached_session(): 115 # Both zero-batches. No broadcast 116 alpha = [1., 2] 117 counts = [1., 0] 118 dist = ds.DirichletMultinomial(1., alpha) 119 pmf = dist.prob(counts) 120 self.assertAllClose(1 / 3., self.evaluate(pmf)) 121 self.assertEqual((), pmf.get_shape()) 122 123 def testPmfBothZeroBatchesNontrivialN(self): 124 # The probabilities of one vote falling into class k is the mean for class 125 # k. 126 with self.cached_session(): 127 # Both zero-batches. No broadcast 128 alpha = [1., 2] 129 counts = [3., 2] 130 dist = ds.DirichletMultinomial(5., alpha) 131 pmf = dist.prob(counts) 132 self.assertAllClose(1 / 7., self.evaluate(pmf)) 133 self.assertEqual((), pmf.get_shape()) 134 135 def testPmfBothZeroBatchesMultidimensionalN(self): 136 # The probabilities of one vote falling into class k is the mean for class 137 # k. 138 with self.cached_session(): 139 alpha = [1., 2] 140 counts = [3., 2] 141 n = np.full([4, 3], 5., dtype=np.float32) 142 dist = ds.DirichletMultinomial(n, alpha) 143 pmf = dist.prob(counts) 144 self.assertAllClose([[1 / 7., 1 / 7., 1 / 7.]] * 4, self.evaluate(pmf)) 145 self.assertEqual((4, 3), pmf.get_shape()) 146 147 def testPmfAlphaStretchedInBroadcastWhenSameRank(self): 148 # The probabilities of one vote falling into class k is the mean for class 149 # k. 150 with self.cached_session(): 151 alpha = [[1., 2]] 152 counts = [[1., 0], [0., 1]] 153 dist = ds.DirichletMultinomial([1.], alpha) 154 pmf = dist.prob(counts) 155 self.assertAllClose([1 / 3., 2 / 3.], self.evaluate(pmf)) 156 self.assertAllEqual([2], pmf.get_shape()) 157 158 def testPmfAlphaStretchedInBroadcastWhenLowerRank(self): 159 # The probabilities of one vote falling into class k is the mean for class 160 # k. 161 with self.cached_session(): 162 alpha = [1., 2] 163 counts = [[1., 0], [0., 1]] 164 pmf = ds.DirichletMultinomial(1., alpha).prob(counts) 165 self.assertAllClose([1 / 3., 2 / 3.], self.evaluate(pmf)) 166 self.assertAllEqual([2], pmf.get_shape()) 167 168 def testPmfCountsStretchedInBroadcastWhenSameRank(self): 169 # The probabilities of one vote falling into class k is the mean for class 170 # k. 171 with self.cached_session(): 172 alpha = [[1., 2], [2., 3]] 173 counts = [[1., 0]] 174 pmf = ds.DirichletMultinomial([1., 1.], alpha).prob(counts) 175 self.assertAllClose([1 / 3., 2 / 5.], self.evaluate(pmf)) 176 self.assertAllEqual([2], pmf.get_shape()) 177 178 def testPmfCountsStretchedInBroadcastWhenLowerRank(self): 179 # The probabilities of one vote falling into class k is the mean for class 180 # k. 181 with self.cached_session(): 182 alpha = [[1., 2], [2., 3]] 183 counts = [1., 0] 184 pmf = ds.DirichletMultinomial(1., alpha).prob(counts) 185 self.assertAllClose([1 / 3., 2 / 5.], self.evaluate(pmf)) 186 self.assertAllEqual([2], pmf.get_shape()) 187 188 @test_util.run_deprecated_v1 189 def testPmfForOneVoteIsTheMeanWithOneRecordInput(self): 190 # The probabilities of one vote falling into class k is the mean for class 191 # k. 192 alpha = [1., 2, 3] 193 with self.cached_session(): 194 for class_num in range(3): 195 counts = np.zeros([3], dtype=np.float32) 196 counts[class_num] = 1 197 dist = ds.DirichletMultinomial(1., alpha) 198 mean = dist.mean().eval() 199 pmf = dist.prob(counts).eval() 200 201 self.assertAllClose(mean[class_num], pmf) 202 self.assertAllEqual([3], mean.shape) 203 self.assertAllEqual([], pmf.shape) 204 205 @test_util.run_deprecated_v1 206 def testMeanDoubleTwoVotes(self): 207 # The probabilities of two votes falling into class k for 208 # DirichletMultinomial(2, alpha) is twice as much as the probability of one 209 # vote falling into class k for DirichletMultinomial(1, alpha) 210 alpha = [1., 2, 3] 211 with self.cached_session(): 212 for class_num in range(3): 213 counts_one = np.zeros([3], dtype=np.float32) 214 counts_one[class_num] = 1. 215 counts_two = np.zeros([3], dtype=np.float32) 216 counts_two[class_num] = 2 217 218 dist1 = ds.DirichletMultinomial(1., alpha) 219 dist2 = ds.DirichletMultinomial(2., alpha) 220 221 mean1 = dist1.mean().eval() 222 mean2 = dist2.mean().eval() 223 224 self.assertAllClose(mean2[class_num], 2 * mean1[class_num]) 225 self.assertAllEqual([3], mean1.shape) 226 227 @test_util.run_deprecated_v1 228 def testCovarianceFromSampling(self): 229 # We will test mean, cov, var, stddev on a DirichletMultinomial constructed 230 # via broadcast between alpha, n. 231 alpha = np.array([[1., 2, 3], 232 [2.5, 4, 0.01]], dtype=np.float32) 233 # Ideally we'd be able to test broadcasting but, the multinomial sampler 234 # doesn't support different total counts. 235 n = np.float32(5) 236 with self.cached_session() as sess: 237 # batch_shape=[2], event_shape=[3] 238 dist = ds.DirichletMultinomial(n, alpha) 239 x = dist.sample(int(250e3), seed=1) 240 sample_mean = math_ops.reduce_mean(x, 0) 241 x_centered = x - sample_mean[array_ops.newaxis, ...] 242 sample_cov = math_ops.reduce_mean(math_ops.matmul( 243 x_centered[..., array_ops.newaxis], 244 x_centered[..., array_ops.newaxis, :]), 0) 245 sample_var = array_ops.matrix_diag_part(sample_cov) 246 sample_stddev = math_ops.sqrt(sample_var) 247 [ 248 sample_mean_, 249 sample_cov_, 250 sample_var_, 251 sample_stddev_, 252 analytic_mean, 253 analytic_cov, 254 analytic_var, 255 analytic_stddev, 256 ] = sess.run([ 257 sample_mean, 258 sample_cov, 259 sample_var, 260 sample_stddev, 261 dist.mean(), 262 dist.covariance(), 263 dist.variance(), 264 dist.stddev(), 265 ]) 266 self.assertAllClose(sample_mean_, analytic_mean, atol=0.04, rtol=0.) 267 self.assertAllClose(sample_cov_, analytic_cov, atol=0.05, rtol=0.) 268 self.assertAllClose(sample_var_, analytic_var, atol=0.05, rtol=0.) 269 self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.) 270 271 def testCovariance(self): 272 # Shape [2] 273 alpha = [1., 2] 274 ns = [2., 3., 4., 5.] 275 alpha_0 = np.sum(alpha) 276 277 # Diagonal entries are of the form: 278 # Var(X_i) = n * alpha_i / alpha_sum * (1 - alpha_i / alpha_sum) * 279 # (alpha_sum + n) / (alpha_sum + 1) 280 variance_entry = lambda a, a_sum: a / a_sum * (1 - a / a_sum) 281 # Off diagonal entries are of the form: 282 # Cov(X_i, X_j) = -n * alpha_i * alpha_j / (alpha_sum ** 2) * 283 # (alpha_sum + n) / (alpha_sum + 1) 284 covariance_entry = lambda a, b, a_sum: -a * b / a_sum**2 285 # Shape [2, 2]. 286 shared_matrix = np.array([[ 287 variance_entry(alpha[0], alpha_0), 288 covariance_entry(alpha[0], alpha[1], alpha_0) 289 ], [ 290 covariance_entry(alpha[1], alpha[0], alpha_0), 291 variance_entry(alpha[1], alpha_0) 292 ]]) 293 294 with self.cached_session(): 295 for n in ns: 296 # n is shape [] and alpha is shape [2]. 297 dist = ds.DirichletMultinomial(n, alpha) 298 covariance = dist.covariance() 299 expected_covariance = n * (n + alpha_0) / (1 + alpha_0) * shared_matrix 300 301 self.assertEqual([2, 2], covariance.get_shape()) 302 self.assertAllClose(expected_covariance, self.evaluate(covariance)) 303 304 def testCovarianceNAlphaBroadcast(self): 305 alpha_v = [1., 2, 3] 306 alpha_0 = 6. 307 308 # Shape [4, 3] 309 alpha = np.array(4 * [alpha_v], dtype=np.float32) 310 # Shape [4, 1] 311 ns = np.array([[2.], [3.], [4.], [5.]], dtype=np.float32) 312 313 variance_entry = lambda a, a_sum: a / a_sum * (1 - a / a_sum) 314 covariance_entry = lambda a, b, a_sum: -a * b / a_sum**2 315 # Shape [4, 3, 3] 316 shared_matrix = np.array( 317 4 * [[[ 318 variance_entry(alpha_v[0], alpha_0), 319 covariance_entry(alpha_v[0], alpha_v[1], alpha_0), 320 covariance_entry(alpha_v[0], alpha_v[2], alpha_0) 321 ], [ 322 covariance_entry(alpha_v[1], alpha_v[0], alpha_0), 323 variance_entry(alpha_v[1], alpha_0), 324 covariance_entry(alpha_v[1], alpha_v[2], alpha_0) 325 ], [ 326 covariance_entry(alpha_v[2], alpha_v[0], alpha_0), 327 covariance_entry(alpha_v[2], alpha_v[1], alpha_0), 328 variance_entry(alpha_v[2], alpha_0) 329 ]]], 330 dtype=np.float32) 331 332 with self.cached_session(): 333 # ns is shape [4, 1], and alpha is shape [4, 3]. 334 dist = ds.DirichletMultinomial(ns, alpha) 335 covariance = dist.covariance() 336 expected_covariance = shared_matrix * ( 337 ns * (ns + alpha_0) / (1 + alpha_0))[..., array_ops.newaxis] 338 339 self.assertEqual([4, 3, 3], covariance.get_shape()) 340 self.assertAllClose(expected_covariance, self.evaluate(covariance)) 341 342 def testCovarianceMultidimensional(self): 343 alpha = np.random.rand(3, 5, 4).astype(np.float32) 344 alpha2 = np.random.rand(6, 3, 3).astype(np.float32) 345 346 ns = np.random.randint(low=1, high=11, size=[3, 5, 1]).astype(np.float32) 347 ns2 = np.random.randint(low=1, high=11, size=[6, 1, 1]).astype(np.float32) 348 349 with self.cached_session(): 350 dist = ds.DirichletMultinomial(ns, alpha) 351 dist2 = ds.DirichletMultinomial(ns2, alpha2) 352 353 covariance = dist.covariance() 354 covariance2 = dist2.covariance() 355 self.assertEqual([3, 5, 4, 4], covariance.get_shape()) 356 self.assertEqual([6, 3, 3, 3], covariance2.get_shape()) 357 358 def testZeroCountsResultsInPmfEqualToOne(self): 359 # There is only one way for zero items to be selected, and this happens with 360 # probability 1. 361 alpha = [5, 0.5] 362 counts = [0., 0] 363 with self.cached_session(): 364 dist = ds.DirichletMultinomial(0., alpha) 365 pmf = dist.prob(counts) 366 self.assertAllClose(1.0, self.evaluate(pmf)) 367 self.assertEqual((), pmf.get_shape()) 368 369 def testLargeTauGivesPreciseProbabilities(self): 370 # If tau is large, we are doing coin flips with probability mu. 371 mu = np.array([0.1, 0.1, 0.8], dtype=np.float32) 372 tau = np.array([100.], dtype=np.float32) 373 alpha = tau * mu 374 375 # One (three sided) coin flip. Prob[coin 3] = 0.8. 376 # Note that since it was one flip, value of tau didn't matter. 377 counts = [0., 0, 1] 378 with self.cached_session(): 379 dist = ds.DirichletMultinomial(1., alpha) 380 pmf = dist.prob(counts) 381 self.assertAllClose(0.8, self.evaluate(pmf), atol=1e-4) 382 self.assertEqual((), pmf.get_shape()) 383 384 # Two (three sided) coin flips. Prob[coin 3] = 0.8. 385 counts = [0., 0, 2] 386 with self.cached_session(): 387 dist = ds.DirichletMultinomial(2., alpha) 388 pmf = dist.prob(counts) 389 self.assertAllClose(0.8**2, self.evaluate(pmf), atol=1e-2) 390 self.assertEqual((), pmf.get_shape()) 391 392 # Three (three sided) coin flips. 393 counts = [1., 0, 2] 394 with self.cached_session(): 395 dist = ds.DirichletMultinomial(3., alpha) 396 pmf = dist.prob(counts) 397 self.assertAllClose(3 * 0.1 * 0.8 * 0.8, self.evaluate(pmf), atol=1e-2) 398 self.assertEqual((), pmf.get_shape()) 399 400 def testSmallTauPrefersCorrelatedResults(self): 401 # If tau is small, then correlation between draws is large, so draws that 402 # are both of the same class are more likely. 403 mu = np.array([0.5, 0.5], dtype=np.float32) 404 tau = np.array([0.1], dtype=np.float32) 405 alpha = tau * mu 406 407 # If there is only one draw, it is still a coin flip, even with small tau. 408 counts = [1., 0] 409 with self.cached_session(): 410 dist = ds.DirichletMultinomial(1., alpha) 411 pmf = dist.prob(counts) 412 self.assertAllClose(0.5, self.evaluate(pmf)) 413 self.assertEqual((), pmf.get_shape()) 414 415 # If there are two draws, it is much more likely that they are the same. 416 counts_same = [2., 0] 417 counts_different = [1, 1.] 418 with self.cached_session(): 419 dist = ds.DirichletMultinomial(2., alpha) 420 pmf_same = dist.prob(counts_same) 421 pmf_different = dist.prob(counts_different) 422 self.assertLess(5 * self.evaluate(pmf_different), self.evaluate(pmf_same)) 423 self.assertEqual((), pmf_same.get_shape()) 424 425 @test_util.run_deprecated_v1 426 def testNonStrictTurnsOffAllChecks(self): 427 # Make totally invalid input. 428 with self.cached_session(): 429 alpha = [[-1., 2]] # alpha should be positive. 430 counts = [[1., 0], [0., -1]] # counts should be non-negative. 431 n = [-5.3] # n should be a non negative integer equal to counts.sum. 432 dist = ds.DirichletMultinomial(n, alpha, validate_args=False) 433 dist.prob(counts).eval() # Should not raise. 434 435 @test_util.run_deprecated_v1 436 def testSampleUnbiasedNonScalarBatch(self): 437 with self.cached_session() as sess: 438 dist = ds.DirichletMultinomial( 439 total_count=5., 440 concentration=1. + 2. * self._rng.rand(4, 3, 2).astype(np.float32)) 441 n = int(3e3) 442 x = dist.sample(n, seed=0) 443 sample_mean = math_ops.reduce_mean(x, 0) 444 # Cyclically rotate event dims left. 445 x_centered = array_ops.transpose(x - sample_mean, [1, 2, 3, 0]) 446 sample_covariance = math_ops.matmul( 447 x_centered, x_centered, adjoint_b=True) / n 448 [ 449 sample_mean_, 450 sample_covariance_, 451 actual_mean_, 452 actual_covariance_, 453 ] = sess.run([ 454 sample_mean, 455 sample_covariance, 456 dist.mean(), 457 dist.covariance(), 458 ]) 459 self.assertAllEqual([4, 3, 2], sample_mean.get_shape()) 460 self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.20) 461 self.assertAllEqual([4, 3, 2, 2], sample_covariance.get_shape()) 462 self.assertAllClose( 463 actual_covariance_, sample_covariance_, atol=0., rtol=0.20) 464 465 @test_util.run_deprecated_v1 466 def testSampleUnbiasedScalarBatch(self): 467 with self.cached_session() as sess: 468 dist = ds.DirichletMultinomial( 469 total_count=5., 470 concentration=1. + 2. * self._rng.rand(4).astype(np.float32)) 471 n = int(5e3) 472 x = dist.sample(n, seed=0) 473 sample_mean = math_ops.reduce_mean(x, 0) 474 x_centered = x - sample_mean # Already transposed to [n, 2]. 475 sample_covariance = math_ops.matmul( 476 x_centered, x_centered, adjoint_a=True) / n 477 [ 478 sample_mean_, 479 sample_covariance_, 480 actual_mean_, 481 actual_covariance_, 482 ] = sess.run([ 483 sample_mean, 484 sample_covariance, 485 dist.mean(), 486 dist.covariance(), 487 ]) 488 self.assertAllEqual([4], sample_mean.get_shape()) 489 self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.20) 490 self.assertAllEqual([4, 4], sample_covariance.get_shape()) 491 self.assertAllClose( 492 actual_covariance_, sample_covariance_, atol=0., rtol=0.20) 493 494 def testNotReparameterized(self): 495 total_count = constant_op.constant(5.0) 496 concentration = constant_op.constant([0.1, 0.1, 0.1]) 497 with backprop.GradientTape() as tape: 498 tape.watch(total_count) 499 tape.watch(concentration) 500 dist = ds.DirichletMultinomial( 501 total_count=total_count, 502 concentration=concentration) 503 samples = dist.sample(100) 504 grad_total_count, grad_concentration = tape.gradient( 505 samples, [total_count, concentration]) 506 self.assertIsNone(grad_total_count) 507 self.assertIsNone(grad_concentration) 508 509 510if __name__ == "__main__": 511 test.main() 512