• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15from __future__ import absolute_import
16from __future__ import division
17from __future__ import print_function
18
19import numpy as np
20
21from tensorflow.python.eager import backprop
22from tensorflow.python.framework import constant_op
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import tensor_shape
25from tensorflow.python.framework import test_util
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.ops.distributions import multinomial
29from tensorflow.python.platform import test
30
31
class MultinomialTest(test.TestCase):
  """Tests for `multinomial.Multinomial`.

  Covers static and dynamic shapes, the `total_count`/`probs`/`logits`
  accessors, `prob`/`log_prob` (argument validation and broadcasting), the
  analytic mean and covariance, sample statistics, and the absence of
  gradients through `sample`.
  """

  def setUp(self):
    """Creates a seeded generator so random test parameters are repeatable."""
    self._rng = np.random.RandomState(42)

  @test_util.run_v1_only("b/120545219")
  def testSimpleShapes(self):
    """Rank-1 `probs` gives event shape [3] and an empty batch shape."""
    with self.cached_session():
      p = [.1, .3, .6]
      dist = multinomial.Multinomial(total_count=1., probs=p)
      # The shape tensor is 1-D, so compare against a list instead of relying
      # on a bare scalar comparing equal to a single-element array.
      self.assertAllEqual([3], dist.event_shape_tensor().eval())
      self.assertAllEqual([], dist.batch_shape_tensor().eval())
      self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape)
      self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)

  @test_util.run_v1_only("b/120545219")
  def testComplexShapes(self):
    """`probs` of shape [3, 2, 2] gives batch shape [3, 2], event shape [2]."""
    with self.cached_session():
      p = 0.5 * np.ones([3, 2, 2], dtype=np.float32)
      n = [[3., 2], [4, 5], [6, 7]]
      dist = multinomial.Multinomial(total_count=n, probs=p)
      # See testSimpleShapes: the shape tensor is 1-D.
      self.assertAllEqual([2], dist.event_shape_tensor().eval())
      self.assertAllEqual([3, 2], dist.batch_shape_tensor().eval())
      self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape)
      self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape)

  @test_util.run_v1_only("b/120545219")
  def testN(self):
    """`total_count` keeps the shape and values it was constructed with."""
    p = [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]]
    n = [[3.], [4]]
    with self.cached_session():
      dist = multinomial.Multinomial(total_count=n, probs=p)
      self.assertEqual((2, 1), dist.total_count.get_shape())
      self.assertAllClose(n, dist.total_count.eval())

  @test_util.run_v1_only("b/120545219")
  def testP(self):
    """`probs` and `logits` both have the shape of the `probs` argument."""
    p = [[0.1, 0.2, 0.7]]
    with self.cached_session():
      dist = multinomial.Multinomial(total_count=3., probs=p)
      self.assertEqual((1, 3), dist.probs.get_shape())
      self.assertEqual((1, 3), dist.logits.get_shape())
      self.assertAllClose(p, dist.probs.eval())

  @test_util.run_v1_only("b/120545219")
  def testLogits(self):
    """Shifted logits are returned as given; `probs` comes out normalized."""
    p = np.array([[0.1, 0.2, 0.7]], dtype=np.float32)
    # Subtracting a constant leaves the implied probabilities unchanged.
    logits = np.log(p) - 50.
    with self.cached_session():
      multinom = multinomial.Multinomial(total_count=3., logits=logits)
      self.assertEqual((1, 3), multinom.probs.get_shape())
      self.assertEqual((1, 3), multinom.logits.get_shape())
      self.assertAllClose(p, multinom.probs.eval())
      self.assertAllClose(logits, multinom.logits.eval())

  @test_util.run_v1_only("b/120545219")
  def testPmfUnderflow(self):
    """`log_prob` stays finite and accurate for a very negative logit."""
    logits = np.array([[-200, 0]], dtype=np.float32)
    with self.cached_session():
      dist = multinomial.Multinomial(total_count=1., logits=logits)
      lp = dist.log_prob([1., 0.]).eval()[0]
      self.assertAllClose(-200, lp, atol=0, rtol=1e-6)

  @test_util.run_v1_only("b/120545219")
  def testPmfandCountsAgree(self):
    """With `validate_args=True`, invalid counts raise op errors."""
    p = [[0.1, 0.2, 0.7]]
    n = [[5.]]
    with self.cached_session():
      dist = multinomial.Multinomial(total_count=n, probs=p, validate_args=True)
      # Valid counts: non-negative and summing to `total_count`.
      dist.prob([2., 3, 0]).eval()
      dist.prob([3., 0, 2]).eval()
      with self.assertRaisesOpError("must be non-negative"):
        dist.prob([-1., 4, 2]).eval()
      with self.assertRaisesOpError("counts must sum to `self.total_count`"):
        dist.prob([3., 3, 0]).eval()

  @test_util.run_v1_only("b/120545219")
  def testPmfNonIntegerCounts(self):
    """Fractional counts are rejected only when `validate_args=True`."""
    p = [[0.1, 0.2, 0.7]]
    n = [[5.]]
    with self.cached_session():
      # No errors with integer n.
      multinom = multinomial.Multinomial(
          total_count=n, probs=p, validate_args=True)
      multinom.prob([2., 1, 2]).eval()
      multinom.prob([3., 0, 2]).eval()
      # Counts don't sum to n.
      with self.assertRaisesOpError("counts must sum to `self.total_count`"):
        multinom.prob([2., 3, 2]).eval()
      # Counts are non-integers.
      x = array_ops.placeholder(dtypes.float32)
      with self.assertRaisesOpError(
          "cannot contain fractional components."):
        multinom.prob(x).eval(feed_dict={x: [1.0, 2.5, 1.5]})

      multinom = multinomial.Multinomial(
          total_count=n, probs=p, validate_args=False)
      multinom.prob([1., 2., 2.]).eval()
      # Non-integer arguments work.
      multinom.prob([1.0, 2.5, 1.5]).eval()

  def testPmfBothZeroBatches(self):
    """Scalar-batch `probs` and counts give a scalar pmf."""
    with self.cached_session():
      # Both zero-batches.  No broadcast
      p = [0.5, 0.5]
      counts = [1., 0]
      pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
      self.assertAllClose(0.5, self.evaluate(pmf))
      self.assertEqual((), pmf.get_shape())

  def testPmfBothZeroBatchesNontrivialN(self):
    """Scalar-batch pmf matches the closed form for `total_count=5`."""
    with self.cached_session():
      # Both zero-batches.  No broadcast
      p = [0.1, 0.9]
      counts = [3., 2]
      dist = multinomial.Multinomial(total_count=5., probs=p)
      pmf = dist.prob(counts)
      # 5 choose 3 = 5 choose 2 = 10. 10 * (.9)^2 * (.1)^3 = 81/10000.
      self.assertAllClose(81. / 10000, self.evaluate(pmf))
      self.assertEqual((), pmf.get_shape())

  def testPmfPStretchedInBroadcastWhenSameRank(self):
    """`probs` of shape [1, 2] broadcasts against counts of shape [2, 2]."""
    with self.cached_session():
      p = [[0.1, 0.9]]
      counts = [[1., 0], [0, 1]]
      pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
      self.assertAllClose([0.1, 0.9], self.evaluate(pmf))
      self.assertEqual((2,), pmf.get_shape())

  def testPmfPStretchedInBroadcastWhenLowerRank(self):
    """`probs` of shape [2] broadcasts against counts of shape [2, 2]."""
    with self.cached_session():
      p = [0.1, 0.9]
      counts = [[1., 0], [0, 1]]
      pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
      self.assertAllClose([0.1, 0.9], self.evaluate(pmf))
      self.assertEqual((2,), pmf.get_shape())

  @test_util.run_v1_only("b/120545219")
  def testPmfCountsStretchedInBroadcastWhenSameRank(self):
    """Counts of shape [1, 2] broadcast against `probs` of shape [2, 2]."""
    with self.cached_session():
      p = [[0.1, 0.9], [0.7, 0.3]]
      counts = [[1., 0]]
      pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
      self.assertAllClose([0.1, 0.7], pmf.eval())
      self.assertEqual((2,), pmf.get_shape())

  @test_util.run_v1_only("b/120545219")
  def testPmfCountsStretchedInBroadcastWhenLowerRank(self):
    """Counts of shape [2] broadcast against `probs` of shape [2, 2]."""
    with self.cached_session():
      p = [[0.1, 0.9], [0.7, 0.3]]
      counts = [1., 0]
      pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
      self.assertAllClose([0.1, 0.7], pmf.eval())
      self.assertEqual((2,), pmf.get_shape())

  def testPmfShapeCountsStretchedN(self):
    """Counts of shape [2] broadcast to a pmf of shape [2, 2]."""
    with self.cached_session():
      # [2, 2, 2]
      p = [[[0.1, 0.9], [0.1, 0.9]], [[0.7, 0.3], [0.7, 0.3]]]
      # [2, 2]
      n = [[3., 3], [3, 3]]
      # [2]
      counts = [2., 1]
      pmf = multinomial.Multinomial(total_count=n, probs=p).prob(counts)
      self.evaluate(pmf)
      self.assertEqual((2, 2), pmf.get_shape())

  def testPmfShapeCountsPStretchedN(self):
    """`total_count` of shape [4, 3] sets the pmf shape to [4, 3]."""
    with self.cached_session():
      p = [0.1, 0.9]
      counts = [3., 2]
      n = np.full([4, 3], 5., dtype=np.float32)
      pmf = multinomial.Multinomial(total_count=n, probs=p).prob(counts)
      self.evaluate(pmf)
      self.assertEqual((4, 3), pmf.get_shape())

  @test_util.run_v1_only("b/120545219")
  def testMultinomialMean(self):
    """`mean()` equals `total_count * probs`."""
    with self.cached_session():
      n = 5.
      p = [0.1, 0.2, 0.7]
      dist = multinomial.Multinomial(total_count=n, probs=p)
      expected_means = n * np.array(p, dtype=np.float32)
      self.assertEqual((3,), dist.mean().get_shape())
      self.assertAllClose(expected_means, dist.mean().eval())

  @test_util.run_v1_only("b/120545219")
  def testMultinomialCovariance(self):
    """`covariance()` matches hand-computed values for n=5."""
    with self.cached_session():
      n = 5.
      p = [0.1, 0.2, 0.7]
      dist = multinomial.Multinomial(total_count=n, probs=p)
      # Diagonal: n * p_i * (1 - p_i); off-diagonal: -n * p_i * p_j.
      expected_covariances = [[9. / 20, -1 / 10, -7 / 20],
                              [-1 / 10, 4 / 5, -7 / 10],
                              [-7 / 20, -7 / 10, 21 / 20]]
      self.assertEqual((3, 3), dist.covariance().get_shape())
      self.assertAllClose(expected_covariances, dist.covariance().eval())

  @test_util.run_v1_only("b/120545219")
  def testMultinomialCovarianceBatch(self):
    """`covariance()` broadcasts `total_count` and `probs` over the batch."""
    with self.cached_session():
      # Shape [2]
      n = [5.] * 2
      # Shape [4, 1, 2]
      p = [[[0.1, 0.9]], [[0.1, 0.9]]] * 2
      dist = multinomial.Multinomial(total_count=n, probs=p)
      # Shape [2, 2]
      inner_var = [[9. / 20, -9 / 20], [-9 / 20, 9 / 20]]
      # Shape [4, 2, 2, 2]
      expected_covariances = [[inner_var, inner_var]] * 4
      self.assertEqual((4, 2, 2, 2), dist.covariance().get_shape())
      self.assertAllClose(expected_covariances, dist.covariance().eval())

  def testCovarianceMultidimensional(self):
    """Static covariance shape is batch shape plus [k, k]."""
    # Draw from the seeded per-test generator rather than the global NumPy
    # state so the inputs are the same on every run.
    # Shape [3, 5, 4]
    p = self._rng.dirichlet([.25, .25, .25, .25], [3, 5]).astype(np.float32)
    # Shape [6, 3, 3]
    p2 = self._rng.dirichlet([.3, .3, .4], [6, 3]).astype(np.float32)

    ns = self._rng.randint(low=1, high=11, size=[3, 5]).astype(np.float32)
    ns2 = self._rng.randint(low=1, high=11, size=[6, 1]).astype(np.float32)

    with self.cached_session():
      dist = multinomial.Multinomial(ns, p)
      dist2 = multinomial.Multinomial(ns2, p2)

      covariance = dist.covariance()
      covariance2 = dist2.covariance()
      self.assertEqual((3, 5, 4, 4), covariance.get_shape())
      self.assertEqual((6, 3, 3, 3), covariance2.get_shape())

  @test_util.run_v1_only("b/120545219")
  def testCovarianceFromSampling(self):
    """Sample mean/cov/var/stddev agree with the analytic values."""
    # We will test mean, cov, var, stddev on a Multinomial constructed via
    # broadcast between theta, n.
    theta = np.array([[1., 2, 3],
                      [2.5, 4, 0.01]], dtype=np.float32)
    # Normalize each row so it sums to one.
    theta /= np.sum(theta, axis=1, keepdims=True)
    n = np.array([[10., 9.], [8., 7.], [6., 5.]], dtype=np.float32)
    with self.cached_session() as sess:
      # batch_shape=[3, 2], event_shape=[3]
      dist = multinomial.Multinomial(n, theta)
      x = dist.sample(int(1000e3), seed=1)
      sample_mean = math_ops.reduce_mean(x, 0)
      x_centered = x - sample_mean[array_ops.newaxis, ...]
      # Average the per-sample outer products of the centered draws.
      sample_cov = math_ops.reduce_mean(math_ops.matmul(
          x_centered[..., array_ops.newaxis],
          x_centered[..., array_ops.newaxis, :]), 0)
      sample_var = array_ops.matrix_diag_part(sample_cov)
      sample_stddev = math_ops.sqrt(sample_var)
      [
          sample_mean_,
          sample_cov_,
          sample_var_,
          sample_stddev_,
          analytic_mean,
          analytic_cov,
          analytic_var,
          analytic_stddev,
      ] = sess.run([
          sample_mean,
          sample_cov,
          sample_var,
          sample_stddev,
          dist.mean(),
          dist.covariance(),
          dist.variance(),
          dist.stddev(),
      ])
      self.assertAllClose(sample_mean_, analytic_mean, atol=0.01, rtol=0.01)
      self.assertAllClose(sample_cov_, analytic_cov, atol=0.01, rtol=0.01)
      self.assertAllClose(sample_var_, analytic_var, atol=0.01, rtol=0.01)
      self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.01, rtol=0.01)

  @test_util.run_v1_only("b/120545219")
  def testSampleUnbiasedNonScalarBatch(self):
    """Sample moments match analytic moments for a [4, 3] batch."""
    with self.cached_session() as sess:
      dist = multinomial.Multinomial(
          total_count=[7., 6., 5.],
          logits=math_ops.log(2. * self._rng.rand(4, 3, 2).astype(np.float32)))
      n = int(3e4)
      x = dist.sample(n, seed=0)
      sample_mean = math_ops.reduce_mean(x, 0)
      # Move the sample axis to the end: [n, 4, 3, 2] -> [4, 3, 2, n].
      x_centered = array_ops.transpose(x - sample_mean, [1, 2, 3, 0])
      sample_covariance = math_ops.matmul(
          x_centered, x_centered, adjoint_b=True) / n
      [
          sample_mean_,
          sample_covariance_,
          actual_mean_,
          actual_covariance_,
      ] = sess.run([
          sample_mean,
          sample_covariance,
          dist.mean(),
          dist.covariance(),
      ])
      self.assertAllEqual([4, 3, 2], sample_mean.get_shape())
      self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.10)
      self.assertAllEqual([4, 3, 2, 2], sample_covariance.get_shape())
      self.assertAllClose(
          actual_covariance_, sample_covariance_, atol=0., rtol=0.20)

  @test_util.run_v1_only("b/120545219")
  def testSampleUnbiasedScalarBatch(self):
    """Sample moments match analytic moments for a scalar batch."""
    with self.cached_session() as sess:
      dist = multinomial.Multinomial(
          total_count=5.,
          logits=math_ops.log(2. * self._rng.rand(4).astype(np.float32)))
      n = int(5e3)
      x = dist.sample(n, seed=0)
      sample_mean = math_ops.reduce_mean(x, 0)
      x_centered = x - sample_mean  # Already laid out as [n, 4].
      sample_covariance = math_ops.matmul(
          x_centered, x_centered, adjoint_a=True) / n
      [
          sample_mean_,
          sample_covariance_,
          actual_mean_,
          actual_covariance_,
      ] = sess.run([
          sample_mean,
          sample_covariance,
          dist.mean(),
          dist.covariance(),
      ])
      self.assertAllEqual([4], sample_mean.get_shape())
      self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.10)
      self.assertAllEqual([4, 4], sample_covariance.get_shape())
      self.assertAllClose(
          actual_covariance_, sample_covariance_, atol=0., rtol=0.20)

  def testNotReparameterized(self):
    """Samples carry no gradient with respect to `total_count` or `probs`."""
    total_count = constant_op.constant(5.0)
    p = constant_op.constant([0.2, 0.6])
    with backprop.GradientTape() as tape:
      tape.watch(total_count)
      tape.watch(p)
      dist = multinomial.Multinomial(
          total_count=total_count,
          probs=p)
      samples = dist.sample(100)
    grad_total_count, grad_p = tape.gradient(samples, [total_count, p])
    self.assertIsNone(grad_total_count)
    self.assertIsNone(grad_p)
378
379
if __name__ == "__main__":
  # Hand off to the TensorFlow test entry point when executed as a script.
  test.main()
382