# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the `DirichletMultinomial` distribution."""

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import dirichlet_multinomial
from tensorflow.python.platform import test


ds = dirichlet_multinomial


class DirichletMultinomialTest(test.TestCase):
  """Shape, pmf, moment and sampling checks for `ds.DirichletMultinomial`."""

  def setUp(self):
    # Run the base-class fixture first so that whatever it prepares is not
    # silently skipped by this override.
    super(DirichletMultinomialTest, self).setUp()
    # Seeded generator: randomly drawn parameters are reproducible per test.
    self._rng = np.random.RandomState(42)

  @test_util.run_deprecated_v1
  def testSimpleShapes(self):
    with self.cached_session():
      alpha = self._rng.rand(3)
      dist = ds.DirichletMultinomial(1., alpha)
      self.assertAllEqual([3], dist.event_shape_tensor().eval())
      self.assertAllEqual([], dist.batch_shape_tensor())
      self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape)
      self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)

  @test_util.run_deprecated_v1
  def testComplexShapes(self):
    with self.cached_session():
      alpha = self._rng.rand(3, 2, 2)
      n = [[3., 2], [4, 5], [6, 7]]
      dist = ds.DirichletMultinomial(n, alpha)
      self.assertAllEqual([2], dist.event_shape_tensor().eval())
      self.assertAllEqual([3, 2], dist.batch_shape_tensor())
      self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape)
      self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape)

  @test_util.run_deprecated_v1
  def testNproperty(self):
    alpha = [[1., 2, 3]]
    n = [[5.]]
    with self.cached_session():
      dist = ds.DirichletMultinomial(n, alpha)
      self.assertEqual([1, 1], dist.total_count.get_shape())
      self.assertAllClose(n, dist.total_count)

  @test_util.run_deprecated_v1
  def testAlphaProperty(self):
    alpha = [[1., 2, 3]]
    with self.cached_session():
      dist = ds.DirichletMultinomial(1, alpha)
      self.assertEqual([1, 3], dist.concentration.get_shape())
      self.assertAllClose(alpha, dist.concentration)

  @test_util.run_deprecated_v1
  def testPmfNandCountsAgree(self):
    alpha = [[1., 2, 3]]
    n = [[5.]]
    with self.cached_session():
      dist = ds.DirichletMultinomial(n, alpha, validate_args=True)
      # Counts that are non-negative and sum to n are accepted.
      dist.prob([2., 3, 0]).eval()
      dist.prob([3., 0, 2]).eval()
      with self.assertRaisesOpError("must be non-negative"):
        dist.prob([-1., 4, 2]).eval()
      with self.assertRaisesOpError(
          "last-dimension must sum to `self.total_count`"):
        dist.prob([3., 3, 0]).eval()

  @test_util.run_deprecated_v1
  def testPmfNonIntegerCounts(self):
    alpha = [[1., 2, 3]]
    n = [[5.]]
    with self.cached_session():
      dist = ds.DirichletMultinomial(n, alpha, validate_args=True)
      dist.prob([2., 3, 0]).eval()
      dist.prob([3., 0, 2]).eval()
      dist.prob([3.0, 0, 2.0]).eval()
      # Both equality and integer checking fail.
      placeholder = array_ops.placeholder(dtypes.float32)
      with self.assertRaisesOpError(
          "cannot contain fractional components"):
        dist.prob(placeholder).eval(feed_dict={placeholder: [1.0, 2.5, 1.5]})
      dist = ds.DirichletMultinomial(n, alpha, validate_args=False)
      dist.prob([1., 2., 3.]).eval()
      # Non-integer arguments work.
      dist.prob([1.0, 2.5, 1.5]).eval()

  def testPmfBothZeroBatches(self):
    # The probability of one vote falling into class k is the mean for class
    # k.
    with self.cached_session():
      # Both zero-batches.  No broadcast
      alpha = [1., 2]
      counts = [1., 0]
      dist = ds.DirichletMultinomial(1., alpha)
      pmf = dist.prob(counts)
      self.assertAllClose(1 / 3., self.evaluate(pmf))
      self.assertEqual((), pmf.get_shape())

  def testPmfBothZeroBatchesNontrivialN(self):
    # With n = 5, alpha = [1, 2] and counts = [3, 2] the pmf is
    # C(5, 3) * B(4, 4) / B(1, 2) = 10 * (1 / 140) * 2 = 1 / 7.
    with self.cached_session():
      # Both zero-batches.  No broadcast
      alpha = [1., 2]
      counts = [3., 2]
      dist = ds.DirichletMultinomial(5., alpha)
      pmf = dist.prob(counts)
      self.assertAllClose(1 / 7., self.evaluate(pmf))
      self.assertEqual((), pmf.get_shape())

  def testPmfBothZeroBatchesMultidimensionalN(self):
    # Same pmf value (1 / 7) as the scalar n = 5 case, repeated for every
    # entry of a total_count of shape [4, 3].
    with self.cached_session():
      alpha = [1., 2]
      counts = [3., 2]
      n = np.full([4, 3], 5., dtype=np.float32)
      dist = ds.DirichletMultinomial(n, alpha)
      pmf = dist.prob(counts)
      self.assertAllClose([[1 / 7., 1 / 7., 1 / 7.]] * 4, self.evaluate(pmf))
      self.assertEqual((4, 3), pmf.get_shape())

  def testPmfAlphaStretchedInBroadcastWhenSameRank(self):
    # The probability of one vote falling into class k is the mean for class
    # k.
    with self.cached_session():
      alpha = [[1., 2]]
      counts = [[1., 0], [0., 1]]
      dist = ds.DirichletMultinomial([1.], alpha)
      pmf = dist.prob(counts)
      self.assertAllClose([1 / 3., 2 / 3.], self.evaluate(pmf))
      self.assertAllEqual([2], pmf.get_shape())

  def testPmfAlphaStretchedInBroadcastWhenLowerRank(self):
    # The probability of one vote falling into class k is the mean for class
    # k.
    with self.cached_session():
      alpha = [1., 2]
      counts = [[1., 0], [0., 1]]
      pmf = ds.DirichletMultinomial(1., alpha).prob(counts)
      self.assertAllClose([1 / 3., 2 / 3.], self.evaluate(pmf))
      self.assertAllEqual([2], pmf.get_shape())

  def testPmfCountsStretchedInBroadcastWhenSameRank(self):
    # The probability of one vote falling into class k is the mean for class
    # k.
    with self.cached_session():
      alpha = [[1., 2], [2., 3]]
      counts = [[1., 0]]
      pmf = ds.DirichletMultinomial([1., 1.], alpha).prob(counts)
      self.assertAllClose([1 / 3., 2 / 5.], self.evaluate(pmf))
      self.assertAllEqual([2], pmf.get_shape())

  def testPmfCountsStretchedInBroadcastWhenLowerRank(self):
    # The probability of one vote falling into class k is the mean for class
    # k.
    with self.cached_session():
      alpha = [[1., 2], [2., 3]]
      counts = [1., 0]
      pmf = ds.DirichletMultinomial(1., alpha).prob(counts)
      self.assertAllClose([1 / 3., 2 / 5.], self.evaluate(pmf))
      self.assertAllEqual([2], pmf.get_shape())

  @test_util.run_deprecated_v1
  def testPmfForOneVoteIsTheMeanWithOneRecordInput(self):
    # The probability of one vote falling into class k is the mean for class
    # k.
    alpha = [1., 2, 3]
    with self.cached_session():
      # The distribution and its mean do not depend on the class being
      # checked, so build and evaluate them once.
      dist = ds.DirichletMultinomial(1., alpha)
      mean = dist.mean().eval()
      self.assertAllEqual([3], mean.shape)

      for class_num in range(3):
        # One-hot counts: a single vote for `class_num`.
        counts = np.zeros([3], dtype=np.float32)
        counts[class_num] = 1
        pmf = dist.prob(counts).eval()

        self.assertAllClose(mean[class_num], pmf)
        self.assertAllEqual([], pmf.shape)

  @test_util.run_deprecated_v1
  def testMeanDoubleTwoVotes(self):
    # The mean of DirichletMultinomial(2, alpha) is, for every class, twice
    # the mean of DirichletMultinomial(1, alpha).
    alpha = [1., 2, 3]
    with self.cached_session():
      mean1 = ds.DirichletMultinomial(1., alpha).mean().eval()
      mean2 = ds.DirichletMultinomial(2., alpha).mean().eval()

      self.assertAllEqual([3], mean1.shape)
      self.assertAllClose(mean2, 2 * mean1)

  @test_util.run_deprecated_v1
  def testCovarianceFromSampling(self):
    # We will test mean, cov, var, stddev on a DirichletMultinomial constructed
    # via broadcast between alpha, n.
    alpha = np.array([[1., 2, 3],
                      [2.5, 4, 0.01]], dtype=np.float32)
    # Ideally we'd be able to test broadcasting, but the multinomial sampler
    # doesn't support different total counts.
    n = np.float32(5)
    with self.cached_session() as sess:
      # batch_shape=[2], event_shape=[3]
      dist = ds.DirichletMultinomial(n, alpha)
      x = dist.sample(int(250e3), seed=1)
      sample_mean = math_ops.reduce_mean(x, 0)
      x_centered = x - sample_mean[array_ops.newaxis, ...]
      # Average of per-sample outer products of the centered draws.
      sample_cov = math_ops.reduce_mean(math_ops.matmul(
          x_centered[..., array_ops.newaxis],
          x_centered[..., array_ops.newaxis, :]), 0)
      sample_var = array_ops.matrix_diag_part(sample_cov)
      sample_stddev = math_ops.sqrt(sample_var)
      [
          sample_mean_,
          sample_cov_,
          sample_var_,
          sample_stddev_,
          analytic_mean,
          analytic_cov,
          analytic_var,
          analytic_stddev,
      ] = sess.run([
          sample_mean,
          sample_cov,
          sample_var,
          sample_stddev,
          dist.mean(),
          dist.covariance(),
          dist.variance(),
          dist.stddev(),
      ])
      self.assertAllClose(sample_mean_, analytic_mean, atol=0.04, rtol=0.)
      self.assertAllClose(sample_cov_, analytic_cov, atol=0.05, rtol=0.)
      self.assertAllClose(sample_var_, analytic_var, atol=0.05, rtol=0.)
      self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.)

  @test_util.run_without_tensor_float_32(
      "Tests DirichletMultinomial.covariance, which calls matmul")
  def testCovariance(self):
    # Shape [2]
    alpha = [1., 2]
    ns = [2., 3., 4., 5.]
    alpha_0 = np.sum(alpha)

    # Diagonal entries are of the form:
    #  Var(X_i) = n * alpha_i / alpha_sum * (1 - alpha_i / alpha_sum) *
    #  (alpha_sum + n) / (alpha_sum + 1)
    # The helper below computes only the n-independent factor; the factor
    # n * (alpha_sum + n) / (alpha_sum + 1) is applied per n further down.
    variance_entry = lambda a, a_sum: a / a_sum * (1 - a / a_sum)
    # Off diagonal entries are of the form:
    #  Cov(X_i, X_j) = -n * alpha_i * alpha_j / (alpha_sum ** 2) *
    #  (alpha_sum + n) / (alpha_sum + 1)
    # Again only the n-independent factor is computed here.
    covariance_entry = lambda a, b, a_sum: -a * b / a_sum**2
    # Shape [2, 2].
    shared_matrix = np.array([[
        variance_entry(alpha[0], alpha_0),
        covariance_entry(alpha[0], alpha[1], alpha_0)
    ], [
        covariance_entry(alpha[1], alpha[0], alpha_0),
        variance_entry(alpha[1], alpha_0)
    ]])

    with self.cached_session():
      for n in ns:
        # n is shape [] and alpha is shape [2].
        dist = ds.DirichletMultinomial(n, alpha)
        covariance = dist.covariance()
        expected_covariance = n * (n + alpha_0) / (1 + alpha_0) * shared_matrix

        self.assertEqual([2, 2], covariance.get_shape())
        self.assertAllClose(expected_covariance, self.evaluate(covariance))

  def testCovarianceNAlphaBroadcast(self):
    alpha_v = [1., 2, 3]
    # Sum of alpha_v.
    alpha_0 = 6.

    # Shape [4, 3]
    alpha = np.array(4 * [alpha_v], dtype=np.float32)
    # Shape [4, 1]
    ns = np.array([[2.], [3.], [4.], [5.]], dtype=np.float32)

    # n-independent factors of the variance / covariance entries.
    variance_entry = lambda a, a_sum: a / a_sum * (1 - a / a_sum)
    covariance_entry = lambda a, b, a_sum: -a * b / a_sum**2
    # Shape [4, 3, 3]
    shared_matrix = np.array(
        4 * [[[
            variance_entry(alpha_v[0], alpha_0),
            covariance_entry(alpha_v[0], alpha_v[1], alpha_0),
            covariance_entry(alpha_v[0], alpha_v[2], alpha_0)
        ], [
            covariance_entry(alpha_v[1], alpha_v[0], alpha_0),
            variance_entry(alpha_v[1], alpha_0),
            covariance_entry(alpha_v[1], alpha_v[2], alpha_0)
        ], [
            covariance_entry(alpha_v[2], alpha_v[0], alpha_0),
            covariance_entry(alpha_v[2], alpha_v[1], alpha_0),
            variance_entry(alpha_v[2], alpha_0)
        ]]],
        dtype=np.float32)

    with self.cached_session():
      # ns is shape [4, 1], and alpha is shape [4, 3].
      dist = ds.DirichletMultinomial(ns, alpha)
      covariance = dist.covariance()
      # Scale each [3, 3] matrix by its own n-dependent factor.
      expected_covariance = shared_matrix * (
          ns * (ns + alpha_0) / (1 + alpha_0))[..., array_ops.newaxis]

      self.assertEqual([4, 3, 3], covariance.get_shape())
      self.assertAllClose(expected_covariance, self.evaluate(covariance))

  def testCovarianceMultidimensional(self):
    # Only static shapes are checked, so the parameter values are arbitrary;
    # they are drawn from the seeded fixture generator for reproducibility.
    alpha = self._rng.rand(3, 5, 4).astype(np.float32)
    alpha2 = self._rng.rand(6, 3, 3).astype(np.float32)

    ns = self._rng.randint(low=1, high=11, size=[3, 5, 1]).astype(np.float32)
    ns2 = self._rng.randint(low=1, high=11, size=[6, 1, 1]).astype(np.float32)

    with self.cached_session():
      dist = ds.DirichletMultinomial(ns, alpha)
      dist2 = ds.DirichletMultinomial(ns2, alpha2)

      covariance = dist.covariance()
      covariance2 = dist2.covariance()
      self.assertEqual([3, 5, 4, 4], covariance.get_shape())
      self.assertEqual([6, 3, 3, 3], covariance2.get_shape())

  def testZeroCountsResultsInPmfEqualToOne(self):
    # There is only one way for zero items to be selected, and this happens with
    # probability 1.
    alpha = [5, 0.5]
    counts = [0., 0]
    with self.cached_session():
      dist = ds.DirichletMultinomial(0., alpha)
      pmf = dist.prob(counts)
      self.assertAllClose(1.0, self.evaluate(pmf))
      self.assertEqual((), pmf.get_shape())

  def testLargeTauGivesPreciseProbabilities(self):
    # If tau is large, we are doing coin flips with probability mu.
    mu = np.array([0.1, 0.1, 0.8], dtype=np.float32)
    tau = np.array([100.], dtype=np.float32)
    alpha = tau * mu

    # One (three sided) coin flip.  Prob[coin 3] = 0.8.
    # Note that since it was one flip, value of tau didn't matter.
    counts = [0., 0, 1]
    with self.cached_session():
      dist = ds.DirichletMultinomial(1., alpha)
      pmf = dist.prob(counts)
      self.assertAllClose(0.8, self.evaluate(pmf), atol=1e-4)
      self.assertEqual((), pmf.get_shape())

    # Two (three sided) coin flips.  Prob[coin 3] = 0.8.
    counts = [0., 0, 2]
    with self.cached_session():
      dist = ds.DirichletMultinomial(2., alpha)
      pmf = dist.prob(counts)
      self.assertAllClose(0.8**2, self.evaluate(pmf), atol=1e-2)
      self.assertEqual((), pmf.get_shape())

    # Three (three sided) coin flips.
    counts = [1., 0, 2]
    with self.cached_session():
      dist = ds.DirichletMultinomial(3., alpha)
      pmf = dist.prob(counts)
      self.assertAllClose(3 * 0.1 * 0.8 * 0.8, self.evaluate(pmf), atol=1e-2)
      self.assertEqual((), pmf.get_shape())

  def testSmallTauPrefersCorrelatedResults(self):
    # If tau is small, then correlation between draws is large, so draws that
    # are both of the same class are more likely.
    mu = np.array([0.5, 0.5], dtype=np.float32)
    tau = np.array([0.1], dtype=np.float32)
    alpha = tau * mu

    # If there is only one draw, it is still a coin flip, even with small tau.
    counts = [1., 0]
    with self.cached_session():
      dist = ds.DirichletMultinomial(1., alpha)
      pmf = dist.prob(counts)
      self.assertAllClose(0.5, self.evaluate(pmf))
      self.assertEqual((), pmf.get_shape())

    # If there are two draws, it is much more likely that they are the same.
    counts_same = [2., 0]
    counts_different = [1, 1.]
    with self.cached_session():
      dist = ds.DirichletMultinomial(2., alpha)
      pmf_same = dist.prob(counts_same)
      pmf_different = dist.prob(counts_different)
      self.assertLess(5 * self.evaluate(pmf_different), self.evaluate(pmf_same))
      self.assertEqual((), pmf_same.get_shape())

  @test_util.run_deprecated_v1
  def testNonStrictTurnsOffAllChecks(self):
    # Make totally invalid input.
    with self.cached_session():
      alpha = [[-1., 2]]  # alpha should be positive.
      counts = [[1., 0], [0., -1]]  # counts should be non-negative.
      n = [-5.3]  # n should be a non negative integer equal to counts.sum.
      dist = ds.DirichletMultinomial(n, alpha, validate_args=False)
      dist.prob(counts).eval()  # Should not raise.

  @test_util.run_deprecated_v1
  def testSampleUnbiasedNonScalarBatch(self):
    with self.cached_session() as sess:
      dist = ds.DirichletMultinomial(
          total_count=5.,
          concentration=1. + 2. * self._rng.rand(4, 3, 2).astype(np.float32))
      n = int(3e3)
      x = dist.sample(n, seed=0)
      sample_mean = math_ops.reduce_mean(x, 0)
      # Move the sample dimension last: [n, 4, 3, 2] -> [4, 3, 2, n], so the
      # matmul below contracts over the samples.
      x_centered = array_ops.transpose(x - sample_mean, [1, 2, 3, 0])
      sample_covariance = math_ops.matmul(
          x_centered, x_centered, adjoint_b=True) / n
      [
          sample_mean_,
          sample_covariance_,
          actual_mean_,
          actual_covariance_,
      ] = sess.run([
          sample_mean,
          sample_covariance,
          dist.mean(),
          dist.covariance(),
      ])
      self.assertAllEqual([4, 3, 2], sample_mean.get_shape())
      self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.20)
      self.assertAllEqual([4, 3, 2, 2], sample_covariance.get_shape())
      self.assertAllClose(
          actual_covariance_, sample_covariance_, atol=0., rtol=0.20)

  @test_util.run_deprecated_v1
  def testSampleUnbiasedScalarBatch(self):
    with self.cached_session() as sess:
      dist = ds.DirichletMultinomial(
          total_count=5.,
          concentration=1. + 2. * self._rng.rand(4).astype(np.float32))
      n = int(5e3)
      x = dist.sample(n, seed=0)
      sample_mean = math_ops.reduce_mean(x, 0)
      # x is [n, 4]; no transpose is needed because adjoint_a below already
      # contracts over the sample dimension, giving a [4, 4] result.
      x_centered = x - sample_mean
      sample_covariance = math_ops.matmul(
          x_centered, x_centered, adjoint_a=True) / n
      [
          sample_mean_,
          sample_covariance_,
          actual_mean_,
          actual_covariance_,
      ] = sess.run([
          sample_mean,
          sample_covariance,
          dist.mean(),
          dist.covariance(),
      ])
      self.assertAllEqual([4], sample_mean.get_shape())
      self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.20)
      self.assertAllEqual([4, 4], sample_covariance.get_shape())
      self.assertAllClose(
          actual_covariance_, sample_covariance_, atol=0., rtol=0.20)

  def testNotReparameterized(self):
    # Samples must not propagate gradients back to either parameter.
    total_count = constant_op.constant(5.0)
    concentration = constant_op.constant([0.1, 0.1, 0.1])
    with backprop.GradientTape() as tape:
      tape.watch(total_count)
      tape.watch(concentration)
      dist = ds.DirichletMultinomial(
          total_count=total_count,
          concentration=concentration)
      samples = dist.sample(100)
    grad_total_count, grad_concentration = tape.gradient(
        samples, [total_count, concentration])
    self.assertIsNone(grad_total_count)
    self.assertIsNone(grad_concentration)


if __name__ == "__main__":
  test.main()