# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import multinomial
from tensorflow.python.platform import test


class MultinomialTest(test.TestCase):
  """Tests for the Multinomial distribution.

  Covers static/dynamic shape inference, pmf evaluation (including
  broadcasting and argument validation), analytic moments (mean,
  covariance), sampling-based moment checks, and gradient behavior.
  """

  def setUp(self):
    # Fixed seed so sampling-based statistics are reproducible.
    self._rng = np.random.RandomState(42)

  @test_util.run_v1_only("b/120545219")
  def testSimpleShapes(self):
    """Scalar batch: event_shape [3], batch_shape []."""
    with self.cached_session():
      p = [.1, .3, .6]
      dist = multinomial.Multinomial(total_count=1., probs=p)
      self.assertEqual(3, dist.event_shape_tensor().eval())
      self.assertAllEqual([], dist.batch_shape_tensor())
      self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape)
      self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)

  @test_util.run_v1_only("b/120545219")
  def testComplexShapes(self):
    """Non-scalar batch: event_shape [2], batch_shape [3, 2]."""
    with self.cached_session():
      p = 0.5 * np.ones([3, 2, 2], dtype=np.float32)
      n = [[3., 2], [4, 5], [6, 7]]
      dist = multinomial.Multinomial(total_count=n, probs=p)
      self.assertEqual(2, dist.event_shape_tensor().eval())
      self.assertAllEqual([3, 2], dist.batch_shape_tensor())
      self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape)
      self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape)

  @test_util.run_v1_only("b/120545219")
  def testN(self):
    """total_count is stored with its given shape and values."""
    p = [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]]
    n = [[3.], [4]]
    with self.cached_session():
      dist = multinomial.Multinomial(total_count=n, probs=p)
      self.assertEqual((2, 1), dist.total_count.get_shape())
      self.assertAllClose(n, dist.total_count)

  @test_util.run_v1_only("b/120545219")
  def testP(self):
    """probs is stored as given; logits has matching shape."""
    p = [[0.1, 0.2, 0.7]]
    with self.cached_session():
      dist = multinomial.Multinomial(total_count=3., probs=p)
      self.assertEqual((1, 3), dist.probs.get_shape())
      self.assertEqual((1, 3), dist.logits.get_shape())
      self.assertAllClose(p, dist.probs)

  @test_util.run_v1_only("b/120545219")
  def testLogits(self):
    """Constructing from logits recovers the same probs (shift-invariant)."""
    p = np.array([[0.1, 0.2, 0.7]], dtype=np.float32)
    logits = np.log(p) - 50.
    with self.cached_session():
      multinom = multinomial.Multinomial(total_count=3., logits=logits)
      self.assertEqual((1, 3), multinom.probs.get_shape())
      self.assertEqual((1, 3), multinom.logits.get_shape())
      self.assertAllClose(p, multinom.probs)
      self.assertAllClose(logits, multinom.logits)

  @test_util.run_v1_only("b/120545219")
  def testPmfUnderflow(self):
    """log_prob stays finite/accurate even when prob underflows to 0."""
    logits = np.array([[-200, 0]], dtype=np.float32)
    with self.cached_session():
      dist = multinomial.Multinomial(total_count=1., logits=logits)
      lp = dist.log_prob([1., 0.]).eval()[0]
      self.assertAllClose(-200, lp, atol=0, rtol=1e-6)

  @test_util.run_v1_only("b/120545219")
  def testPmfandCountsAgree(self):
    """validate_args rejects negative counts and counts not summing to n."""
    p = [[0.1, 0.2, 0.7]]
    n = [[5.]]
    with self.cached_session():
      dist = multinomial.Multinomial(total_count=n, probs=p, validate_args=True)
      dist.prob([2., 3, 0]).eval()
      dist.prob([3., 0, 2]).eval()
      with self.assertRaisesOpError("must be non-negative"):
        dist.prob([-1., 4, 2]).eval()
      with self.assertRaisesOpError("counts must sum to `self.total_count`"):
        dist.prob([3., 3, 0]).eval()

  @test_util.run_v1_only("b/120545219")
  def testPmfNonIntegerCounts(self):
    """validate_args rejects fractional counts; validate_args=False allows."""
    p = [[0.1, 0.2, 0.7]]
    n = [[5.]]
    with self.cached_session():
      # No errors with integer n.
      multinom = multinomial.Multinomial(
          total_count=n, probs=p, validate_args=True)
      multinom.prob([2., 1, 2]).eval()
      multinom.prob([3., 0, 2]).eval()
      # Counts don't sum to n.
      with self.assertRaisesOpError("counts must sum to `self.total_count`"):
        multinom.prob([2., 3, 2]).eval()
      # Counts are non-integers.
      x = array_ops.placeholder(dtypes.float32)
      with self.assertRaisesOpError(
          "cannot contain fractional components."):
        multinom.prob(x).eval(feed_dict={x: [1.0, 2.5, 1.5]})

      multinom = multinomial.Multinomial(
          total_count=n, probs=p, validate_args=False)
      multinom.prob([1., 2., 2.]).eval()
      # Non-integer arguments work.
      multinom.prob([1.0, 2.5, 1.5]).eval()

  def testPmfBothZeroBatches(self):
    """Scalar-batch pmf with no broadcasting."""
    with self.cached_session():
      # Both zero-batches.  No broadcast
      p = [0.5, 0.5]
      counts = [1., 0]
      pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
      self.assertAllClose(0.5, self.evaluate(pmf))
      self.assertEqual((), pmf.get_shape())

  def testPmfBothZeroBatchesNontrivialN(self):
    """Scalar-batch pmf with n > 1 includes the multinomial coefficient."""
    with self.cached_session():
      # Both zero-batches.  No broadcast
      p = [0.1, 0.9]
      counts = [3., 2]
      dist = multinomial.Multinomial(total_count=5., probs=p)
      pmf = dist.prob(counts)
      # 5 choose 3 = 5 choose 2 = 10. 10 * (.9)^2 * (.1)^3 = 81/10000.
      self.assertAllClose(81. / 10000, self.evaluate(pmf))
      self.assertEqual((), pmf.get_shape())

  def testPmfPStretchedInBroadcastWhenSameRank(self):
    """probs of batch 1 broadcasts against counts of batch 2 (same rank)."""
    with self.cached_session():
      p = [[0.1, 0.9]]
      counts = [[1., 0], [0, 1]]
      pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
      self.assertAllClose([0.1, 0.9], self.evaluate(pmf))
      self.assertEqual((2), pmf.get_shape())

  def testPmfPStretchedInBroadcastWhenLowerRank(self):
    """Rank-1 probs broadcasts against rank-2 counts."""
    with self.cached_session():
      p = [0.1, 0.9]
      counts = [[1., 0], [0, 1]]
      pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
      self.assertAllClose([0.1, 0.9], self.evaluate(pmf))
      self.assertEqual((2), pmf.get_shape())

  @test_util.run_v1_only("b/120545219")
  def testPmfCountsStretchedInBroadcastWhenSameRank(self):
    """counts of batch 1 broadcasts against probs of batch 2 (same rank)."""
    with self.cached_session():
      p = [[0.1, 0.9], [0.7, 0.3]]
      counts = [[1., 0]]
      pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
      self.assertAllClose(pmf, [0.1, 0.7])
      self.assertEqual((2), pmf.get_shape())

  @test_util.run_v1_only("b/120545219")
  def testPmfCountsStretchedInBroadcastWhenLowerRank(self):
    """Rank-1 counts broadcasts against rank-2 probs."""
    with self.cached_session():
      p = [[0.1, 0.9], [0.7, 0.3]]
      counts = [1., 0]
      pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
      self.assertAllClose(pmf, [0.1, 0.7])
      self.assertEqual(pmf.get_shape(), (2))

  def testPmfShapeCountsStretchedN(self):
    """pmf output shape follows the broadcast of p, n, and counts."""
    with self.cached_session():
      # [2, 2, 2]
      p = [[[0.1, 0.9], [0.1, 0.9]], [[0.7, 0.3], [0.7, 0.3]]]
      # [2, 2]
      n = [[3., 3], [3, 3]]
      # [2]
      counts = [2., 1]
      pmf = multinomial.Multinomial(total_count=n, probs=p).prob(counts)
      self.evaluate(pmf)
      self.assertEqual(pmf.get_shape(), (2, 2))

  def testPmfShapeCountsPStretchedN(self):
    """Batched total_count broadcasts scalar-batch p and counts."""
    with self.cached_session():
      p = [0.1, 0.9]
      counts = [3., 2]
      n = np.full([4, 3], 5., dtype=np.float32)
      pmf = multinomial.Multinomial(total_count=n, probs=p).prob(counts)
      self.evaluate(pmf)
      self.assertEqual((4, 3), pmf.get_shape())

  @test_util.run_v1_only("b/120545219")
  def testMultinomialMean(self):
    """mean() equals n * p."""
    with self.cached_session():
      n = 5.
      p = [0.1, 0.2, 0.7]
      dist = multinomial.Multinomial(total_count=n, probs=p)
      expected_means = 5 * np.array(p, dtype=np.float32)
      self.assertEqual((3,), dist.mean().get_shape())
      self.assertAllClose(expected_means, dist.mean())

  @test_util.run_v1_only("b/120545219")
  def testMultinomialCovariance(self):
    """covariance() matches the analytic n*(diag(p) - p p^T)."""
    with self.cached_session():
      n = 5.
      p = [0.1, 0.2, 0.7]
      dist = multinomial.Multinomial(total_count=n, probs=p)
      expected_covariances = [[9. / 20, -1 / 10, -7 / 20],
                              [-1 / 10, 4 / 5, -7 / 10],
                              [-7 / 20, -7 / 10, 21 / 20]]
      self.assertEqual((3, 3), dist.covariance().get_shape())
      self.assertAllClose(expected_covariances, dist.covariance())

  @test_util.run_v1_only("b/120545219")
  def testMultinomialCovarianceBatch(self):
    """covariance() broadcasts over batch dimensions."""
    with self.cached_session():
      # Shape [2]
      n = [5.] * 2
      # Shape [4, 1, 2]
      p = [[[0.1, 0.9]], [[0.1, 0.9]]] * 2
      dist = multinomial.Multinomial(total_count=n, probs=p)
      # Shape [2, 2]
      inner_var = [[9. / 20, -9 / 20], [-9 / 20, 9 / 20]]
      # Shape [4, 2, 2, 2]
      expected_covariances = [[inner_var, inner_var]] * 4
      self.assertEqual((4, 2, 2, 2), dist.covariance().get_shape())
      self.assertAllClose(expected_covariances, dist.covariance())

  def testCovarianceMultidimensional(self):
    """covariance() static shape is batch_shape + [k, k]."""
    # Shape [3, 5, 4]
    p = np.random.dirichlet([.25, .25, .25, .25], [3, 5]).astype(np.float32)
    # Shape [6, 3, 3]
    p2 = np.random.dirichlet([.3, .3, .4], [6, 3]).astype(np.float32)

    ns = np.random.randint(low=1, high=11, size=[3, 5]).astype(np.float32)
    ns2 = np.random.randint(low=1, high=11, size=[6, 1]).astype(np.float32)

    with self.cached_session():
      dist = multinomial.Multinomial(ns, p)
      dist2 = multinomial.Multinomial(ns2, p2)

      covariance = dist.covariance()
      covariance2 = dist2.covariance()
      self.assertEqual((3, 5, 4, 4), covariance.get_shape())
      self.assertEqual((6, 3, 3, 3), covariance2.get_shape())

  @test_util.run_v1_only("b/120545219")
  def testCovarianceFromSampling(self):
    """Sample moments agree with analytic mean/cov/var/stddev."""
    # We will test mean, cov, var, stddev on a DirichletMultinomial constructed
    # via broadcast between alpha, n.
    theta = np.array([[1., 2, 3],
                      [2.5, 4, 0.01]], dtype=np.float32)
    theta /= np.sum(theta, 1)[..., array_ops.newaxis]
    n = np.array([[10., 9.], [8., 7.], [6., 5.]], dtype=np.float32)
    with self.cached_session() as sess:
      # batch_shape=[3, 2], event_shape=[3]
      dist = multinomial.Multinomial(n, theta)
      x = dist.sample(int(1000e3), seed=1)
      sample_mean = math_ops.reduce_mean(x, 0)
      x_centered = x - sample_mean[array_ops.newaxis, ...]
      sample_cov = math_ops.reduce_mean(math_ops.matmul(
          x_centered[..., array_ops.newaxis],
          x_centered[..., array_ops.newaxis, :]), 0)
      sample_var = array_ops.matrix_diag_part(sample_cov)
      sample_stddev = math_ops.sqrt(sample_var)
      [
          sample_mean_,
          sample_cov_,
          sample_var_,
          sample_stddev_,
          analytic_mean,
          analytic_cov,
          analytic_var,
          analytic_stddev,
      ] = sess.run([
          sample_mean,
          sample_cov,
          sample_var,
          sample_stddev,
          dist.mean(),
          dist.covariance(),
          dist.variance(),
          dist.stddev(),
      ])
      self.assertAllClose(sample_mean_, analytic_mean, atol=0.01, rtol=0.01)
      self.assertAllClose(sample_cov_, analytic_cov, atol=0.01, rtol=0.01)
      self.assertAllClose(sample_var_, analytic_var, atol=0.01, rtol=0.01)
      self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.01, rtol=0.01)

  @test_util.run_v1_only("b/120545219")
  def testSampleUnbiasedNonScalarBatch(self):
    """Sampling is unbiased for a non-scalar batch."""
    with self.cached_session() as sess:
      dist = multinomial.Multinomial(
          total_count=[7., 6., 5.],
          logits=math_ops.log(2. * self._rng.rand(4, 3, 2).astype(np.float32)))
      n = int(3e4)
      x = dist.sample(n, seed=0)
      sample_mean = math_ops.reduce_mean(x, 0)
      # Cyclically rotate event dims left.
      x_centered = array_ops.transpose(x - sample_mean, [1, 2, 3, 0])
      sample_covariance = math_ops.matmul(
          x_centered, x_centered, adjoint_b=True) / n
      [
          sample_mean_,
          sample_covariance_,
          actual_mean_,
          actual_covariance_,
      ] = sess.run([
          sample_mean,
          sample_covariance,
          dist.mean(),
          dist.covariance(),
      ])
      self.assertAllEqual([4, 3, 2], sample_mean.get_shape())
      self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.10)
      self.assertAllEqual([4, 3, 2, 2], sample_covariance.get_shape())
      self.assertAllClose(
          actual_covariance_, sample_covariance_, atol=0., rtol=0.20)

  @test_util.run_v1_only("b/120545219")
  def testSampleUnbiasedScalarBatch(self):
    """Sampling is unbiased for a scalar batch."""
    with self.cached_session() as sess:
      dist = multinomial.Multinomial(
          total_count=5.,
          logits=math_ops.log(2. * self._rng.rand(4).astype(np.float32)))
      n = int(5e3)
      x = dist.sample(n, seed=0)
      sample_mean = math_ops.reduce_mean(x, 0)
      x_centered = x - sample_mean  # Already transposed to [n, 2].
      sample_covariance = math_ops.matmul(
          x_centered, x_centered, adjoint_a=True) / n
      [
          sample_mean_,
          sample_covariance_,
          actual_mean_,
          actual_covariance_,
      ] = sess.run([
          sample_mean,
          sample_covariance,
          dist.mean(),
          dist.covariance(),
      ])
      self.assertAllEqual([4], sample_mean.get_shape())
      self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.10)
      self.assertAllEqual([4, 4], sample_covariance.get_shape())
      self.assertAllClose(
          actual_covariance_, sample_covariance_, atol=0., rtol=0.20)

  def testNotReparameterized(self):
    """Samples carry no gradient back to total_count or probs."""
    total_count = constant_op.constant(5.0)
    p = constant_op.constant([0.2, 0.6])
    with backprop.GradientTape() as tape:
      tape.watch(total_count)
      tape.watch(p)
      dist = multinomial.Multinomial(
          total_count=total_count,
          probs=p)
      samples = dist.sample(100)
    grad_total_count, grad_p = tape.gradient(samples, [total_count, p])
    self.assertIsNone(grad_total_count)
    self.assertIsNone(grad_p)


if __name__ == "__main__":
  # Run all MultinomialTest cases under the TensorFlow test runner.
  test.main()