# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the Multinomial distribution (`tf.python.ops.distributions`)."""
import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import multinomial
from tensorflow.python.platform import test


class MultinomialTest(test.TestCase):
  """Unit tests for `multinomial.Multinomial`: shapes, pmf, moments, sampling."""

  def setUp(self):
    # Fixed seed so the sampling-based tests below are reproducible.
    self._rng = np.random.RandomState(42)

  @test_util.run_v1_only("b/120545219")
  def testSimpleShapes(self):
    # Scalar total_count + vector probs => empty batch shape, event shape [3].
    with self.cached_session():
      p = [.1, .3, .6]
      dist = multinomial.Multinomial(total_count=1., probs=p)
      self.assertEqual(3, dist.event_shape_tensor().eval())
      self.assertAllEqual([], dist.batch_shape_tensor())
      self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape)
      self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)

  @test_util.run_v1_only("b/120545219")
  def testComplexShapes(self):
    # probs is [3, 2, 2] and total_count is [3, 2]: the trailing axis of
    # probs is the event dim, so batch shape broadcasts to [3, 2].
    with self.cached_session():
      p = 0.5 * np.ones([3, 2, 2], dtype=np.float32)
      n = [[3., 2], [4, 5], [6, 7]]
      dist = multinomial.Multinomial(total_count=n, probs=p)
      self.assertEqual(2, dist.event_shape_tensor().eval())
      self.assertAllEqual([3, 2], dist.batch_shape_tensor())
      self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape)
      self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape)

  @test_util.run_v1_only("b/120545219")
  def testN(self):
    # total_count is stored as given (shape [2, 1]).
    p = [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]]
    n = [[3.], [4]]
    with self.cached_session():
      dist = multinomial.Multinomial(total_count=n, probs=p)
      self.assertEqual((2, 1), dist.total_count.get_shape())
      self.assertAllClose(n, dist.total_count)

  @test_util.run_v1_only("b/120545219")
  def testP(self):
    # Constructing from probs exposes both probs and logits, same shape.
    p = [[0.1, 0.2, 0.7]]
    with self.cached_session():
      dist = multinomial.Multinomial(total_count=3., probs=p)
      self.assertEqual((1, 3), dist.probs.get_shape())
      self.assertEqual((1, 3), dist.logits.get_shape())
      self.assertAllClose(p, dist.probs)

  @test_util.run_v1_only("b/120545219")
  def testLogits(self):
    # Shifting logits by a constant (-50) must not change the normalized
    # probs, while the logits themselves are preserved verbatim.
    p = np.array([[0.1, 0.2, 0.7]], dtype=np.float32)
    logits = np.log(p) - 50.
    with self.cached_session():
      multinom = multinomial.Multinomial(total_count=3., logits=logits)
      self.assertEqual((1, 3), multinom.probs.get_shape())
      self.assertEqual((1, 3), multinom.logits.get_shape())
      self.assertAllClose(p, multinom.probs)
      self.assertAllClose(logits, multinom.logits)

  @test_util.run_v1_only("b/120545219")
  def testPmfUnderflow(self):
    # With logit -200 the probability underflows in linear space; log_prob
    # must still return ~-200 rather than -inf.
    logits = np.array([[-200, 0]], dtype=np.float32)
    with self.cached_session():
      dist = multinomial.Multinomial(total_count=1., logits=logits)
      lp = dist.log_prob([1., 0.]).eval()[0]
      self.assertAllClose(-200, lp, atol=0, rtol=1e-6)

  @test_util.run_v1_only("b/120545219")
  def testPmfandCountsAgree(self):
    # validate_args=True rejects negative counts and counts that do not sum
    # to total_count.
    p = [[0.1, 0.2, 0.7]]
    n = [[5.]]
    with self.cached_session():
      dist = multinomial.Multinomial(total_count=n, probs=p, validate_args=True)
      dist.prob([2., 3, 0]).eval()
      dist.prob([3., 0, 2]).eval()
      with self.assertRaisesOpError("must be non-negative"):
        dist.prob([-1., 4, 2]).eval()
      with self.assertRaisesOpError("counts must sum to `self.total_count`"):
        dist.prob([3., 3, 0]).eval()

  @test_util.run_v1_only("b/120545219")
  def testPmfNonIntegerCounts(self):
    p = [[0.1, 0.2, 0.7]]
    n = [[5.]]
    with self.cached_session():
      # No errors with integer n.
      multinom = multinomial.Multinomial(
          total_count=n, probs=p, validate_args=True)
      multinom.prob([2., 1, 2]).eval()
      multinom.prob([3., 0, 2]).eval()
      # Counts don't sum to n.
      with self.assertRaisesOpError("counts must sum to `self.total_count`"):
        multinom.prob([2., 3, 2]).eval()
      # Counts are non-integers.
      x = array_ops.placeholder(dtypes.float32)
      with self.assertRaisesOpError(
          "cannot contain fractional components."):
        multinom.prob(x).eval(feed_dict={x: [1.0, 2.5, 1.5]})

      # With validate_args=False the same fractional input is accepted.
      multinom = multinomial.Multinomial(
          total_count=n, probs=p, validate_args=False)
      multinom.prob([1., 2., 2.]).eval()
      # Non-integer arguments work.
      multinom.prob([1.0, 2.5, 1.5]).eval()

  def testPmfBothZeroBatches(self):
    with self.cached_session():
      # Both zero-batches. No broadcast
      p = [0.5, 0.5]
      counts = [1., 0]
      pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
      self.assertAllClose(0.5, self.evaluate(pmf))
      self.assertEqual((), pmf.get_shape())

  def testPmfBothZeroBatchesNontrivialN(self):
    with self.cached_session():
      # Both zero-batches. No broadcast
      p = [0.1, 0.9]
      counts = [3., 2]
      dist = multinomial.Multinomial(total_count=5., probs=p)
      pmf = dist.prob(counts)
      # 5 choose 3 = 5 choose 2 = 10. 10 * (.9)^2 * (.1)^3 = 81/10000.
      self.assertAllClose(81. / 10000, self.evaluate(pmf))
      self.assertEqual((), pmf.get_shape())

  def testPmfPStretchedInBroadcastWhenSameRank(self):
    # probs [1, 2] broadcasts against counts [2, 2] -> pmf batch of 2.
    with self.cached_session():
      p = [[0.1, 0.9]]
      counts = [[1., 0], [0, 1]]
      pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
      self.assertAllClose([0.1, 0.9], self.evaluate(pmf))
      self.assertEqual((2), pmf.get_shape())

  def testPmfPStretchedInBroadcastWhenLowerRank(self):
    # probs of rank 1 broadcasts up against rank-2 counts.
    with self.cached_session():
      p = [0.1, 0.9]
      counts = [[1., 0], [0, 1]]
      pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
      self.assertAllClose([0.1, 0.9], self.evaluate(pmf))
      self.assertEqual((2), pmf.get_shape())

  @test_util.run_v1_only("b/120545219")
  def testPmfCountsStretchedInBroadcastWhenSameRank(self):
    # One counts row evaluated against a batch of two probs rows.
    with self.cached_session():
      p = [[0.1, 0.9], [0.7, 0.3]]
      counts = [[1., 0]]
      pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
      self.assertAllClose(pmf, [0.1, 0.7])
      self.assertEqual((2), pmf.get_shape())

  @test_util.run_v1_only("b/120545219")
  def testPmfCountsStretchedInBroadcastWhenLowerRank(self):
    with self.cached_session():
      p = [[0.1, 0.9], [0.7, 0.3]]
      counts = [1., 0]
      pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
      self.assertAllClose(pmf, [0.1, 0.7])
      self.assertEqual(pmf.get_shape(), (2))

  def testPmfShapeCountsStretchedN(self):
    with self.cached_session():
      # [2, 2, 2]
      p = [[[0.1, 0.9], [0.1, 0.9]], [[0.7, 0.3], [0.7, 0.3]]]
      # [2, 2]
      n = [[3., 3], [3, 3]]
      # [2]
      counts = [2., 1]
      pmf = multinomial.Multinomial(total_count=n, probs=p).prob(counts)
      self.evaluate(pmf)
      self.assertEqual(pmf.get_shape(), (2, 2))

  def testPmfShapeCountsPStretchedN(self):
    with self.cached_session():
      p = [0.1, 0.9]
      counts = [3., 2]
      n = np.full([4, 3], 5., dtype=np.float32)
      pmf = multinomial.Multinomial(total_count=n, probs=p).prob(counts)
      self.evaluate(pmf)
      self.assertEqual((4, 3), pmf.get_shape())

  @test_util.run_v1_only("b/120545219")
  def testMultinomialMean(self):
    # Mean of Multinomial(n, p) is n * p.
    with self.cached_session():
      n = 5.
      p = [0.1, 0.2, 0.7]
      dist = multinomial.Multinomial(total_count=n, probs=p)
      expected_means = 5 * np.array(p, dtype=np.float32)
      self.assertEqual((3,), dist.mean().get_shape())
      self.assertAllClose(expected_means, dist.mean())

  @test_util.run_v1_only("b/120545219")
  def testMultinomialCovariance(self):
    # Covariance entries: n*p_i*(1-p_i) on the diagonal, -n*p_i*p_j off it.
    with self.cached_session():
      n = 5.
      p = [0.1, 0.2, 0.7]
      dist = multinomial.Multinomial(total_count=n, probs=p)
      expected_covariances = [[9. / 20, -1 / 10, -7 / 20],
                              [-1 / 10, 4 / 5, -7 / 10],
                              [-7 / 20, -7 / 10, 21 / 20]]
      self.assertEqual((3, 3), dist.covariance().get_shape())
      self.assertAllClose(expected_covariances, dist.covariance())

  @test_util.run_v1_only("b/120545219")
  def testMultinomialCovarianceBatch(self):
    with self.cached_session():
      # Shape [2]
      n = [5.] * 2
      # Shape [4, 1, 2]
      p = [[[0.1, 0.9]], [[0.1, 0.9]]] * 2
      dist = multinomial.Multinomial(total_count=n, probs=p)
      # Shape [2, 2]
      inner_var = [[9. / 20, -9 / 20], [-9 / 20, 9 / 20]]
      # Shape [4, 2, 2, 2]
      expected_covariances = [[inner_var, inner_var]] * 4
      self.assertEqual((4, 2, 2, 2), dist.covariance().get_shape())
      self.assertAllClose(expected_covariances, dist.covariance())

  def testCovarianceMultidimensional(self):
    # Shape [3, 5, 4]
    p = np.random.dirichlet([.25, .25, .25, .25], [3, 5]).astype(np.float32)
    # Shape [6, 3, 3]
    p2 = np.random.dirichlet([.3, .3, .4], [6, 3]).astype(np.float32)

    ns = np.random.randint(low=1, high=11, size=[3, 5]).astype(np.float32)
    ns2 = np.random.randint(low=1, high=11, size=[6, 1]).astype(np.float32)

    with self.cached_session():
      dist = multinomial.Multinomial(ns, p)
      dist2 = multinomial.Multinomial(ns2, p2)

      # Covariance appends [k, k] to the batch shape.
      covariance = dist.covariance()
      covariance2 = dist2.covariance()
      self.assertEqual((3, 5, 4, 4), covariance.get_shape())
      self.assertEqual((6, 3, 3, 3), covariance2.get_shape())

  @test_util.run_v1_only("b/120545219")
  def testCovarianceFromSampling(self):
    # We will test mean, cov, var, stddev on a Multinomial constructed via
    # broadcast between theta (probs) and n (total_count).
    theta = np.array([[1., 2, 3],
                      [2.5, 4, 0.01]], dtype=np.float32)
    theta /= np.sum(theta, 1)[..., array_ops.newaxis]
    n = np.array([[10., 9.], [8., 7.], [6., 5.]], dtype=np.float32)
    with self.cached_session() as sess:
      # batch_shape=[3, 2], event_shape=[3]
      dist = multinomial.Multinomial(n, theta)
      # 1e6 draws; fixed seed keeps the sample moments within the tolerances.
      x = dist.sample(int(1000e3), seed=1)
      sample_mean = math_ops.reduce_mean(x, 0)
      x_centered = x - sample_mean[array_ops.newaxis, ...]
      # Outer product of centered samples, averaged over the sample axis.
      sample_cov = math_ops.reduce_mean(math_ops.matmul(
          x_centered[..., array_ops.newaxis],
          x_centered[..., array_ops.newaxis, :]), 0)
      sample_var = array_ops.matrix_diag_part(sample_cov)
      sample_stddev = math_ops.sqrt(sample_var)
      [
          sample_mean_,
          sample_cov_,
          sample_var_,
          sample_stddev_,
          analytic_mean,
          analytic_cov,
          analytic_var,
          analytic_stddev,
      ] = sess.run([
          sample_mean,
          sample_cov,
          sample_var,
          sample_stddev,
          dist.mean(),
          dist.covariance(),
          dist.variance(),
          dist.stddev(),
      ])
      self.assertAllClose(sample_mean_, analytic_mean, atol=0.01, rtol=0.01)
      self.assertAllClose(sample_cov_, analytic_cov, atol=0.01, rtol=0.01)
      self.assertAllClose(sample_var_, analytic_var, atol=0.01, rtol=0.01)
      self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.01, rtol=0.01)

  @test_util.run_v1_only("b/120545219")
  def testSampleUnbiasedNonScalarBatch(self):
    # Sample moments must match analytic moments for a [4, 3] batch.
    with self.cached_session() as sess:
      dist = multinomial.Multinomial(
          total_count=[7., 6., 5.],
          logits=math_ops.log(2. * self._rng.rand(4, 3, 2).astype(np.float32)))
      n = int(3e4)
      x = dist.sample(n, seed=0)
      sample_mean = math_ops.reduce_mean(x, 0)
      # Cyclically rotate event dims left.
      x_centered = array_ops.transpose(x - sample_mean, [1, 2, 3, 0])
      sample_covariance = math_ops.matmul(
          x_centered, x_centered, adjoint_b=True) / n
      [
          sample_mean_,
          sample_covariance_,
          actual_mean_,
          actual_covariance_,
      ] = sess.run([
          sample_mean,
          sample_covariance,
          dist.mean(),
          dist.covariance(),
      ])
      self.assertAllEqual([4, 3, 2], sample_mean.get_shape())
      self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.10)
      self.assertAllEqual([4, 3, 2, 2], sample_covariance.get_shape())
      self.assertAllClose(
          actual_covariance_, sample_covariance_, atol=0., rtol=0.20)

  @test_util.run_v1_only("b/120545219")
  def testSampleUnbiasedScalarBatch(self):
    # Same check as above, scalar batch with a 4-category event.
    with self.cached_session() as sess:
      dist = multinomial.Multinomial(
          total_count=5.,
          logits=math_ops.log(2. * self._rng.rand(4).astype(np.float32)))
      n = int(5e3)
      x = dist.sample(n, seed=0)
      sample_mean = math_ops.reduce_mean(x, 0)
      x_centered = x - sample_mean  # Already transposed to [n, 2].
      sample_covariance = math_ops.matmul(
          x_centered, x_centered, adjoint_a=True) / n
      [
          sample_mean_,
          sample_covariance_,
          actual_mean_,
          actual_covariance_,
      ] = sess.run([
          sample_mean,
          sample_covariance,
          dist.mean(),
          dist.covariance(),
      ])
      self.assertAllEqual([4], sample_mean.get_shape())
      self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.10)
      self.assertAllEqual([4, 4], sample_covariance.get_shape())
      self.assertAllClose(
          actual_covariance_, sample_covariance_, atol=0., rtol=0.20)

  def testNotReparameterized(self):
    # Sampling is not reparameterized: gradients of samples w.r.t. the
    # distribution parameters must be None.
    total_count = constant_op.constant(5.0)
    p = constant_op.constant([0.2, 0.6])
    with backprop.GradientTape() as tape:
      tape.watch(total_count)
      tape.watch(p)
      dist = multinomial.Multinomial(
          total_count=total_count,
          probs=p)
      samples = dist.sample(100)
    grad_total_count, grad_p = tape.gradient(samples, [total_count, p])
    self.assertIsNone(grad_total_count)
    self.assertIsNone(grad_p)


if __name__ == "__main__":
  test.main()