1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15from __future__ import absolute_import 16from __future__ import division 17from __future__ import print_function 18 19import numpy as np 20 21from tensorflow.python.eager import backprop 22from tensorflow.python.framework import constant_op 23from tensorflow.python.framework import dtypes 24from tensorflow.python.framework import tensor_shape 25from tensorflow.python.framework import test_util 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import math_ops 28from tensorflow.python.ops.distributions import multinomial 29from tensorflow.python.platform import test 30 31 32class MultinomialTest(test.TestCase): 33 34 def setUp(self): 35 self._rng = np.random.RandomState(42) 36 37 @test_util.run_v1_only("b/120545219") 38 def testSimpleShapes(self): 39 with self.cached_session(): 40 p = [.1, .3, .6] 41 dist = multinomial.Multinomial(total_count=1., probs=p) 42 self.assertEqual(3, dist.event_shape_tensor().eval()) 43 self.assertAllEqual([], dist.batch_shape_tensor().eval()) 44 self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape) 45 self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape) 46 47 @test_util.run_v1_only("b/120545219") 48 def testComplexShapes(self): 49 with self.cached_session(): 50 p = 0.5 * np.ones([3, 2, 2], dtype=np.float32) 51 n = [[3., 2], [4, 5], [6, 7]] 52 dist = multinomial.Multinomial(total_count=n, probs=p) 53 self.assertEqual(2, dist.event_shape_tensor().eval()) 54 self.assertAllEqual([3, 2], dist.batch_shape_tensor().eval()) 55 self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape) 56 self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape) 57 58 @test_util.run_v1_only("b/120545219") 59 def testN(self): 60 p = [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]] 61 n = [[3.], [4]] 62 with self.cached_session(): 63 dist = multinomial.Multinomial(total_count=n, probs=p) 64 self.assertEqual((2, 1), dist.total_count.get_shape()) 65 self.assertAllClose(n, dist.total_count.eval()) 66 67 @test_util.run_v1_only("b/120545219") 68 def testP(self): 69 p = [[0.1, 0.2, 0.7]] 70 with self.cached_session(): 71 dist = multinomial.Multinomial(total_count=3., probs=p) 72 self.assertEqual((1, 3), dist.probs.get_shape()) 73 self.assertEqual((1, 3), dist.logits.get_shape()) 74 self.assertAllClose(p, dist.probs.eval()) 75 76 @test_util.run_v1_only("b/120545219") 77 def testLogits(self): 78 p = np.array([[0.1, 0.2, 0.7]], dtype=np.float32) 79 logits = np.log(p) - 50. 80 with self.cached_session(): 81 multinom = multinomial.Multinomial(total_count=3., logits=logits) 82 self.assertEqual((1, 3), multinom.probs.get_shape()) 83 self.assertEqual((1, 3), multinom.logits.get_shape()) 84 self.assertAllClose(p, multinom.probs.eval()) 85 self.assertAllClose(logits, multinom.logits.eval()) 86 87 @test_util.run_v1_only("b/120545219") 88 def testPmfUnderflow(self): 89 logits = np.array([[-200, 0]], dtype=np.float32) 90 with self.cached_session(): 91 dist = multinomial.Multinomial(total_count=1., logits=logits) 92 lp = dist.log_prob([1., 0.]).eval()[0] 93 self.assertAllClose(-200, lp, atol=0, rtol=1e-6) 94 95 @test_util.run_v1_only("b/120545219") 96 def testPmfandCountsAgree(self): 97 p = [[0.1, 0.2, 0.7]] 98 n = [[5.]] 99 with self.cached_session(): 100 dist = multinomial.Multinomial(total_count=n, probs=p, validate_args=True) 101 dist.prob([2., 3, 0]).eval() 102 dist.prob([3., 0, 2]).eval() 103 with self.assertRaisesOpError("must be non-negative"): 104 dist.prob([-1., 4, 2]).eval() 105 with self.assertRaisesOpError("counts must sum to `self.total_count`"): 106 dist.prob([3., 3, 0]).eval() 107 108 @test_util.run_v1_only("b/120545219") 109 def testPmfNonIntegerCounts(self): 110 p = [[0.1, 0.2, 0.7]] 111 n = [[5.]] 112 with self.cached_session(): 113 # No errors with integer n. 114 multinom = multinomial.Multinomial( 115 total_count=n, probs=p, validate_args=True) 116 multinom.prob([2., 1, 2]).eval() 117 multinom.prob([3., 0, 2]).eval() 118 # Counts don't sum to n. 119 with self.assertRaisesOpError("counts must sum to `self.total_count`"): 120 multinom.prob([2., 3, 2]).eval() 121 # Counts are non-integers. 122 x = array_ops.placeholder(dtypes.float32) 123 with self.assertRaisesOpError( 124 "cannot contain fractional components."): 125 multinom.prob(x).eval(feed_dict={x: [1.0, 2.5, 1.5]}) 126 127 multinom = multinomial.Multinomial( 128 total_count=n, probs=p, validate_args=False) 129 multinom.prob([1., 2., 2.]).eval() 130 # Non-integer arguments work. 131 multinom.prob([1.0, 2.5, 1.5]).eval() 132 133 def testPmfBothZeroBatches(self): 134 with self.cached_session(): 135 # Both zero-batches. No broadcast 136 p = [0.5, 0.5] 137 counts = [1., 0] 138 pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts) 139 self.assertAllClose(0.5, self.evaluate(pmf)) 140 self.assertEqual((), pmf.get_shape()) 141 142 def testPmfBothZeroBatchesNontrivialN(self): 143 with self.cached_session(): 144 # Both zero-batches. No broadcast 145 p = [0.1, 0.9] 146 counts = [3., 2] 147 dist = multinomial.Multinomial(total_count=5., probs=p) 148 pmf = dist.prob(counts) 149 # 5 choose 3 = 5 choose 2 = 10. 10 * (.9)^2 * (.1)^3 = 81/10000. 150 self.assertAllClose(81. / 10000, self.evaluate(pmf)) 151 self.assertEqual((), pmf.get_shape()) 152 153 def testPmfPStretchedInBroadcastWhenSameRank(self): 154 with self.cached_session(): 155 p = [[0.1, 0.9]] 156 counts = [[1., 0], [0, 1]] 157 pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts) 158 self.assertAllClose([0.1, 0.9], self.evaluate(pmf)) 159 self.assertEqual((2), pmf.get_shape()) 160 161 def testPmfPStretchedInBroadcastWhenLowerRank(self): 162 with self.cached_session(): 163 p = [0.1, 0.9] 164 counts = [[1., 0], [0, 1]] 165 pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts) 166 self.assertAllClose([0.1, 0.9], self.evaluate(pmf)) 167 self.assertEqual((2), pmf.get_shape()) 168 169 @test_util.run_v1_only("b/120545219") 170 def testPmfCountsStretchedInBroadcastWhenSameRank(self): 171 with self.cached_session(): 172 p = [[0.1, 0.9], [0.7, 0.3]] 173 counts = [[1., 0]] 174 pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts) 175 self.assertAllClose(pmf.eval(), [0.1, 0.7]) 176 self.assertEqual((2), pmf.get_shape()) 177 178 @test_util.run_v1_only("b/120545219") 179 def testPmfCountsStretchedInBroadcastWhenLowerRank(self): 180 with self.cached_session(): 181 p = [[0.1, 0.9], [0.7, 0.3]] 182 counts = [1., 0] 183 pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts) 184 self.assertAllClose(pmf.eval(), [0.1, 0.7]) 185 self.assertEqual(pmf.get_shape(), (2)) 186 187 def testPmfShapeCountsStretchedN(self): 188 with self.cached_session(): 189 # [2, 2, 2] 190 p = [[[0.1, 0.9], [0.1, 0.9]], [[0.7, 0.3], [0.7, 0.3]]] 191 # [2, 2] 192 n = [[3., 3], [3, 3]] 193 # [2] 194 counts = [2., 1] 195 pmf = multinomial.Multinomial(total_count=n, probs=p).prob(counts) 196 self.evaluate(pmf) 197 self.assertEqual(pmf.get_shape(), (2, 2)) 198 199 def testPmfShapeCountsPStretchedN(self): 200 with self.cached_session(): 201 p = [0.1, 0.9] 202 counts = [3., 2] 203 n = np.full([4, 3], 5., dtype=np.float32) 204 pmf = multinomial.Multinomial(total_count=n, probs=p).prob(counts) 205 self.evaluate(pmf) 206 self.assertEqual((4, 3), pmf.get_shape()) 207 208 @test_util.run_v1_only("b/120545219") 209 def testMultinomialMean(self): 210 with self.cached_session(): 211 n = 5. 212 p = [0.1, 0.2, 0.7] 213 dist = multinomial.Multinomial(total_count=n, probs=p) 214 expected_means = 5 * np.array(p, dtype=np.float32) 215 self.assertEqual((3,), dist.mean().get_shape()) 216 self.assertAllClose(expected_means, dist.mean().eval()) 217 218 @test_util.run_v1_only("b/120545219") 219 def testMultinomialCovariance(self): 220 with self.cached_session(): 221 n = 5. 222 p = [0.1, 0.2, 0.7] 223 dist = multinomial.Multinomial(total_count=n, probs=p) 224 expected_covariances = [[9. / 20, -1 / 10, -7 / 20], 225 [-1 / 10, 4 / 5, -7 / 10], 226 [-7 / 20, -7 / 10, 21 / 20]] 227 self.assertEqual((3, 3), dist.covariance().get_shape()) 228 self.assertAllClose(expected_covariances, dist.covariance().eval()) 229 230 @test_util.run_v1_only("b/120545219") 231 def testMultinomialCovarianceBatch(self): 232 with self.cached_session(): 233 # Shape [2] 234 n = [5.] * 2 235 # Shape [4, 1, 2] 236 p = [[[0.1, 0.9]], [[0.1, 0.9]]] * 2 237 dist = multinomial.Multinomial(total_count=n, probs=p) 238 # Shape [2, 2] 239 inner_var = [[9. / 20, -9 / 20], [-9 / 20, 9 / 20]] 240 # Shape [4, 2, 2, 2] 241 expected_covariances = [[inner_var, inner_var]] * 4 242 self.assertEqual((4, 2, 2, 2), dist.covariance().get_shape()) 243 self.assertAllClose(expected_covariances, dist.covariance().eval()) 244 245 def testCovarianceMultidimensional(self): 246 # Shape [3, 5, 4] 247 p = np.random.dirichlet([.25, .25, .25, .25], [3, 5]).astype(np.float32) 248 # Shape [6, 3, 3] 249 p2 = np.random.dirichlet([.3, .3, .4], [6, 3]).astype(np.float32) 250 251 ns = np.random.randint(low=1, high=11, size=[3, 5]).astype(np.float32) 252 ns2 = np.random.randint(low=1, high=11, size=[6, 1]).astype(np.float32) 253 254 with self.cached_session(): 255 dist = multinomial.Multinomial(ns, p) 256 dist2 = multinomial.Multinomial(ns2, p2) 257 258 covariance = dist.covariance() 259 covariance2 = dist2.covariance() 260 self.assertEqual((3, 5, 4, 4), covariance.get_shape()) 261 self.assertEqual((6, 3, 3, 3), covariance2.get_shape()) 262 263 @test_util.run_v1_only("b/120545219") 264 def testCovarianceFromSampling(self): 265 # We will test mean, cov, var, stddev on a DirichletMultinomial constructed 266 # via broadcast between alpha, n. 267 theta = np.array([[1., 2, 3], 268 [2.5, 4, 0.01]], dtype=np.float32) 269 theta /= np.sum(theta, 1)[..., array_ops.newaxis] 270 n = np.array([[10., 9.], [8., 7.], [6., 5.]], dtype=np.float32) 271 with self.cached_session() as sess: 272 # batch_shape=[3, 2], event_shape=[3] 273 dist = multinomial.Multinomial(n, theta) 274 x = dist.sample(int(1000e3), seed=1) 275 sample_mean = math_ops.reduce_mean(x, 0) 276 x_centered = x - sample_mean[array_ops.newaxis, ...] 277 sample_cov = math_ops.reduce_mean(math_ops.matmul( 278 x_centered[..., array_ops.newaxis], 279 x_centered[..., array_ops.newaxis, :]), 0) 280 sample_var = array_ops.matrix_diag_part(sample_cov) 281 sample_stddev = math_ops.sqrt(sample_var) 282 [ 283 sample_mean_, 284 sample_cov_, 285 sample_var_, 286 sample_stddev_, 287 analytic_mean, 288 analytic_cov, 289 analytic_var, 290 analytic_stddev, 291 ] = sess.run([ 292 sample_mean, 293 sample_cov, 294 sample_var, 295 sample_stddev, 296 dist.mean(), 297 dist.covariance(), 298 dist.variance(), 299 dist.stddev(), 300 ]) 301 self.assertAllClose(sample_mean_, analytic_mean, atol=0.01, rtol=0.01) 302 self.assertAllClose(sample_cov_, analytic_cov, atol=0.01, rtol=0.01) 303 self.assertAllClose(sample_var_, analytic_var, atol=0.01, rtol=0.01) 304 self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.01, rtol=0.01) 305 306 @test_util.run_v1_only("b/120545219") 307 def testSampleUnbiasedNonScalarBatch(self): 308 with self.cached_session() as sess: 309 dist = multinomial.Multinomial( 310 total_count=[7., 6., 5.], 311 logits=math_ops.log(2. * self._rng.rand(4, 3, 2).astype(np.float32))) 312 n = int(3e4) 313 x = dist.sample(n, seed=0) 314 sample_mean = math_ops.reduce_mean(x, 0) 315 # Cyclically rotate event dims left. 316 x_centered = array_ops.transpose(x - sample_mean, [1, 2, 3, 0]) 317 sample_covariance = math_ops.matmul( 318 x_centered, x_centered, adjoint_b=True) / n 319 [ 320 sample_mean_, 321 sample_covariance_, 322 actual_mean_, 323 actual_covariance_, 324 ] = sess.run([ 325 sample_mean, 326 sample_covariance, 327 dist.mean(), 328 dist.covariance(), 329 ]) 330 self.assertAllEqual([4, 3, 2], sample_mean.get_shape()) 331 self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.10) 332 self.assertAllEqual([4, 3, 2, 2], sample_covariance.get_shape()) 333 self.assertAllClose( 334 actual_covariance_, sample_covariance_, atol=0., rtol=0.20) 335 336 @test_util.run_v1_only("b/120545219") 337 def testSampleUnbiasedScalarBatch(self): 338 with self.cached_session() as sess: 339 dist = multinomial.Multinomial( 340 total_count=5., 341 logits=math_ops.log(2. * self._rng.rand(4).astype(np.float32))) 342 n = int(5e3) 343 x = dist.sample(n, seed=0) 344 sample_mean = math_ops.reduce_mean(x, 0) 345 x_centered = x - sample_mean # Already transposed to [n, 2]. 346 sample_covariance = math_ops.matmul( 347 x_centered, x_centered, adjoint_a=True) / n 348 [ 349 sample_mean_, 350 sample_covariance_, 351 actual_mean_, 352 actual_covariance_, 353 ] = sess.run([ 354 sample_mean, 355 sample_covariance, 356 dist.mean(), 357 dist.covariance(), 358 ]) 359 self.assertAllEqual([4], sample_mean.get_shape()) 360 self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.10) 361 self.assertAllEqual([4, 4], sample_covariance.get_shape()) 362 self.assertAllClose( 363 actual_covariance_, sample_covariance_, atol=0., rtol=0.20) 364 365 def testNotReparameterized(self): 366 total_count = constant_op.constant(5.0) 367 p = constant_op.constant([0.2, 0.6]) 368 with backprop.GradientTape() as tape: 369 tape.watch(total_count) 370 tape.watch(p) 371 dist = multinomial.Multinomial( 372 total_count=total_count, 373 probs=p) 374 samples = dist.sample(100) 375 grad_total_count, grad_p = tape.gradient(samples, [total_count, p]) 376 self.assertIsNone(grad_total_count) 377 self.assertIsNone(grad_p) 378 379 380if __name__ == "__main__": 381 test.main() 382