• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15from __future__ import absolute_import
16from __future__ import division
17from __future__ import print_function
18
19import numpy as np
20
21from tensorflow.python.eager import backprop
22from tensorflow.python.framework import constant_op
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import tensor_shape
25from tensorflow.python.framework import test_util
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.ops.distributions import dirichlet_multinomial
29from tensorflow.python.platform import test
30
31
# Short alias for the module under test; the test cases below refer to it as
# `ds.DirichletMultinomial`.
ds = dirichlet_multinomial
33
34
class DirichletMultinomialTest(test.TestCase):
  """Tests for `DirichletMultinomial`: shapes, pmf, moments and sampling."""

  def setUp(self):
    # Seeded generator so that randomly drawn parameters are identical on every
    # run; tests should draw from this instead of the global `np.random`.
    self._rng = np.random.RandomState(42)

  @staticmethod
  def _unit_covariance(alpha):
    """Returns the expected covariance with the `total_count` factor removed.

    With `p = alpha / sum(alpha)`, entry `(i, i)` is `p_i * (1 - p_i)` and
    entry `(i, j)` for `i != j` is `-p_i * p_j`.  Multiplying the result by
    `n * (n + sum(alpha)) / (1 + sum(alpha))` gives the covariance the tests
    expect for `total_count = n`.

    Args:
      alpha: 1-D sequence of concentration values.

    Returns:
      A `float64` ndarray of shape `[len(alpha), len(alpha)]`.
    """
    p = np.asarray(alpha, dtype=np.float64)
    p = p / p.sum()
    return np.diag(p) - np.outer(p, p)

  @test_util.run_deprecated_v1
  def testSimpleShapes(self):
    # Scalar total_count and rank-1 concentration: scalar batch, 3-way event.
    with self.cached_session():
      alpha = self._rng.rand(3)
      dist = ds.DirichletMultinomial(1., alpha)
      self.assertEqual(3, dist.event_shape_tensor().eval())
      self.assertAllEqual([], dist.batch_shape_tensor().eval())
      self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape)
      self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)

  @test_util.run_deprecated_v1
  def testComplexShapes(self):
    # total_count of shape [3, 2] with concentration of shape [3, 2, 2]:
    # batch shape [3, 2], event shape [2].
    with self.cached_session():
      alpha = self._rng.rand(3, 2, 2)
      n = [[3., 2], [4, 5], [6, 7]]
      dist = ds.DirichletMultinomial(n, alpha)
      self.assertEqual(2, dist.event_shape_tensor().eval())
      self.assertAllEqual([3, 2], dist.batch_shape_tensor().eval())
      self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape)
      self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape)

  @test_util.run_deprecated_v1
  def testNproperty(self):
    # `total_count` must expose the value and shape passed to the constructor.
    alpha = [[1., 2, 3]]
    n = [[5.]]
    with self.cached_session():
      dist = ds.DirichletMultinomial(n, alpha)
      self.assertEqual([1, 1], dist.total_count.get_shape())
      self.assertAllClose(n, dist.total_count.eval())

  @test_util.run_deprecated_v1
  def testAlphaProperty(self):
    # `concentration` must expose the value and shape passed to the
    # constructor.
    alpha = [[1., 2, 3]]
    with self.cached_session():
      dist = ds.DirichletMultinomial(1, alpha)
      self.assertEqual([1, 3], dist.concentration.get_shape())
      self.assertAllClose(alpha, dist.concentration.eval())

  @test_util.run_deprecated_v1
  def testPmfNandCountsAgree(self):
    # With validate_args=True, counts must be non-negative and sum to n.
    alpha = [[1., 2, 3]]
    n = [[5.]]
    with self.cached_session():
      dist = ds.DirichletMultinomial(n, alpha, validate_args=True)
      # Valid counts (sum to 5) evaluate without error.
      dist.prob([2., 3, 0]).eval()
      dist.prob([3., 0, 2]).eval()
      with self.assertRaisesOpError("must be non-negative"):
        dist.prob([-1., 4, 2]).eval()
      # Counts summing to 6 rather than 5.
      with self.assertRaisesOpError(
          "last-dimension must sum to `self.total_count`"):
        dist.prob([3., 3, 0]).eval()

  @test_util.run_deprecated_v1
  def testPmfNonIntegerCounts(self):
    alpha = [[1., 2, 3]]
    n = [[5.]]
    with self.cached_session():
      dist = ds.DirichletMultinomial(n, alpha, validate_args=True)
      dist.prob([2., 3, 0]).eval()
      dist.prob([3., 0, 2]).eval()
      dist.prob([3.0, 0, 2.0]).eval()
      # These counts do sum to n = 5, but the fractional components must be
      # rejected.  A placeholder is used so the values are only known at run
      # time.
      placeholder = array_ops.placeholder(dtypes.float32)
      with self.assertRaisesOpError(
          "cannot contain fractional components"):
        dist.prob(placeholder).eval(feed_dict={placeholder: [1.0, 2.5, 1.5]})
      dist = ds.DirichletMultinomial(n, alpha, validate_args=False)
      dist.prob([1., 2., 3.]).eval()
      # Non-integer arguments work.
      dist.prob([1.0, 2.5, 1.5]).eval()

  def testPmfBothZeroBatches(self):
    # The probability of one vote falling into class k is the mean for class
    # k.
    with self.cached_session():
      # Both zero-batches.  No broadcast
      alpha = [1., 2]
      counts = [1., 0]
      dist = ds.DirichletMultinomial(1., alpha)
      pmf = dist.prob(counts)
      self.assertAllClose(1 / 3., self.evaluate(pmf))
      self.assertEqual((), pmf.get_shape())

  def testPmfBothZeroBatchesNontrivialN(self):
    # With n = 5, alpha = [1, 2] and counts = [3, 2]:
    # pmf = 5! * Gamma(3) / Gamma(8)
    #       * Gamma(4) / (3! * Gamma(1)) * Gamma(4) / (2! * Gamma(2))
    #     = 120 * (2 / 5040) * 1 * 3 = 1 / 7.
    with self.cached_session():
      # Both zero-batches.  No broadcast
      alpha = [1., 2]
      counts = [3., 2]
      dist = ds.DirichletMultinomial(5., alpha)
      pmf = dist.prob(counts)
      self.assertAllClose(1 / 7., self.evaluate(pmf))
      self.assertEqual((), pmf.get_shape())

  def testPmfBothZeroBatchesMultidimensionalN(self):
    # Same alpha and counts as the scalar n = 5 case (pmf = 1 / 7), but n has
    # shape [4, 3], so the result is broadcast to that batch shape.
    with self.cached_session():
      alpha = [1., 2]
      counts = [3., 2]
      n = np.full([4, 3], 5., dtype=np.float32)
      dist = ds.DirichletMultinomial(n, alpha)
      pmf = dist.prob(counts)
      self.assertAllClose([[1 / 7., 1 / 7., 1 / 7.]] * 4, self.evaluate(pmf))
      self.assertEqual((4, 3), pmf.get_shape())

  def testPmfAlphaStretchedInBroadcastWhenSameRank(self):
    # The probability of one vote falling into class k is the mean for class
    # k.
    with self.cached_session():
      alpha = [[1., 2]]
      counts = [[1., 0], [0., 1]]
      dist = ds.DirichletMultinomial([1.], alpha)
      pmf = dist.prob(counts)
      self.assertAllClose([1 / 3., 2 / 3.], self.evaluate(pmf))
      self.assertAllEqual([2], pmf.get_shape())

  def testPmfAlphaStretchedInBroadcastWhenLowerRank(self):
    # The probability of one vote falling into class k is the mean for class
    # k.
    with self.cached_session():
      alpha = [1., 2]
      counts = [[1., 0], [0., 1]]
      pmf = ds.DirichletMultinomial(1., alpha).prob(counts)
      self.assertAllClose([1 / 3., 2 / 3.], self.evaluate(pmf))
      self.assertAllEqual([2], pmf.get_shape())

  def testPmfCountsStretchedInBroadcastWhenSameRank(self):
    # The probability of one vote falling into class k is the mean for class
    # k.
    with self.cached_session():
      alpha = [[1., 2], [2., 3]]
      counts = [[1., 0]]
      pmf = ds.DirichletMultinomial([1., 1.], alpha).prob(counts)
      self.assertAllClose([1 / 3., 2 / 5.], self.evaluate(pmf))
      self.assertAllEqual([2], pmf.get_shape())

  def testPmfCountsStretchedInBroadcastWhenLowerRank(self):
    # The probability of one vote falling into class k is the mean for class
    # k.
    with self.cached_session():
      alpha = [[1., 2], [2., 3]]
      counts = [1., 0]
      pmf = ds.DirichletMultinomial(1., alpha).prob(counts)
      self.assertAllClose([1 / 3., 2 / 5.], self.evaluate(pmf))
      self.assertAllEqual([2], pmf.get_shape())

  @test_util.run_deprecated_v1
  def testPmfForOneVoteIsTheMeanWithOneRecordInput(self):
    # The probability of one vote falling into class k is the mean for class
    # k.
    alpha = [1., 2, 3]
    with self.cached_session():
      # The distribution and its mean do not depend on the class being probed,
      # so build and evaluate them once.
      dist = ds.DirichletMultinomial(1., alpha)
      mean = dist.mean().eval()
      self.assertAllEqual([3], mean.shape)

      for class_num in range(3):
        # One-hot counts: a single vote for `class_num`.
        counts = np.zeros([3], dtype=np.float32)
        counts[class_num] = 1
        pmf = dist.prob(counts).eval()

        self.assertAllClose(mean[class_num], pmf)
        self.assertAllEqual([], pmf.shape)

  @test_util.run_deprecated_v1
  def testMeanDoubleTwoVotes(self):
    # The mean count of class k under DirichletMultinomial(2, alpha) is twice
    # the mean count of class k under DirichletMultinomial(1, alpha).
    alpha = [1., 2, 3]
    with self.cached_session():
      mean1 = ds.DirichletMultinomial(1., alpha).mean().eval()
      mean2 = ds.DirichletMultinomial(2., alpha).mean().eval()

      self.assertAllEqual([3], mean1.shape)
      self.assertAllClose(mean2, 2 * mean1)

  @test_util.run_deprecated_v1
  def testCovarianceFromSampling(self):
    # We will test mean, cov, var, stddev on a DirichletMultinomial constructed
    # via broadcast between alpha, n.
    alpha = np.array([[1., 2, 3],
                      [2.5, 4, 0.01]], dtype=np.float32)
    # Ideally we'd be able to test broadcasting, but the multinomial sampler
    # doesn't support different total counts.
    n = np.float32(5)
    with self.cached_session() as sess:
      # batch_shape=[2], event_shape=[3]
      dist = ds.DirichletMultinomial(n, alpha)
      x = dist.sample(int(250e3), seed=1)
      sample_mean = math_ops.reduce_mean(x, 0)
      x_centered = x - sample_mean[array_ops.newaxis, ...]
      # Average of per-sample outer products of the centered draws.
      sample_cov = math_ops.reduce_mean(math_ops.matmul(
          x_centered[..., array_ops.newaxis],
          x_centered[..., array_ops.newaxis, :]), 0)
      sample_var = array_ops.matrix_diag_part(sample_cov)
      sample_stddev = math_ops.sqrt(sample_var)
      [
          sample_mean_,
          sample_cov_,
          sample_var_,
          sample_stddev_,
          analytic_mean,
          analytic_cov,
          analytic_var,
          analytic_stddev,
      ] = sess.run([
          sample_mean,
          sample_cov,
          sample_var,
          sample_stddev,
          dist.mean(),
          dist.covariance(),
          dist.variance(),
          dist.stddev(),
      ])
      self.assertAllClose(sample_mean_, analytic_mean, atol=0.04, rtol=0.)
      self.assertAllClose(sample_cov_, analytic_cov, atol=0.05, rtol=0.)
      self.assertAllClose(sample_var_, analytic_var, atol=0.05, rtol=0.)
      self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.)

  def testCovariance(self):
    # Shape [2]
    alpha = [1., 2]
    ns = [2., 3., 4., 5.]
    alpha_0 = np.sum(alpha)

    # Diagonal entries are of the form:
    # Var(X_i) = n * alpha_i / alpha_sum * (1 - alpha_i / alpha_sum) *
    # (alpha_sum + n) / (alpha_sum + 1)
    # Off diagonal entries are of the form:
    # Cov(X_i, X_j) = -n * alpha_i * alpha_j / (alpha_sum ** 2) *
    # (alpha_sum + n) / (alpha_sum + 1)
    # `shared_matrix` holds the n-independent part of both.  Shape [2, 2].
    shared_matrix = self._unit_covariance(alpha)

    with self.cached_session():
      for n in ns:
        # n is shape [] and alpha is shape [2].
        dist = ds.DirichletMultinomial(n, alpha)
        covariance = dist.covariance()
        expected_covariance = n * (n + alpha_0) / (1 + alpha_0) * shared_matrix

        self.assertEqual([2, 2], covariance.get_shape())
        self.assertAllClose(expected_covariance, self.evaluate(covariance))

  def testCovarianceNAlphaBroadcast(self):
    alpha_v = [1., 2, 3]
    alpha_0 = 6.

    # Shape [4, 3]
    alpha = np.array(4 * [alpha_v], dtype=np.float32)
    # Shape [4, 1]
    ns = np.array([[2.], [3.], [4.], [5.]], dtype=np.float32)

    # Shape [4, 3, 3]: the same n-independent matrix for every batch member.
    shared_matrix = np.array(
        4 * [self._unit_covariance(alpha_v)], dtype=np.float32)

    with self.cached_session():
      # ns is shape [4, 1], and alpha is shape [4, 3].
      dist = ds.DirichletMultinomial(ns, alpha)
      covariance = dist.covariance()
      # The per-batch scale has shape [4, 1, 1] and broadcasts over the matrix.
      expected_covariance = shared_matrix * (
          ns * (ns + alpha_0) / (1 + alpha_0))[..., np.newaxis]

      self.assertEqual([4, 3, 3], covariance.get_shape())
      self.assertAllClose(expected_covariance, self.evaluate(covariance))

  def testCovarianceMultidimensional(self):
    # Only static shapes are checked here: batch_shape + [k, k].
    alpha = self._rng.rand(3, 5, 4).astype(np.float32)
    alpha2 = self._rng.rand(6, 3, 3).astype(np.float32)

    ns = self._rng.randint(low=1, high=11, size=[3, 5, 1]).astype(np.float32)
    ns2 = self._rng.randint(low=1, high=11, size=[6, 1, 1]).astype(np.float32)

    with self.cached_session():
      dist = ds.DirichletMultinomial(ns, alpha)
      dist2 = ds.DirichletMultinomial(ns2, alpha2)

      covariance = dist.covariance()
      covariance2 = dist2.covariance()
      self.assertEqual([3, 5, 4, 4], covariance.get_shape())
      self.assertEqual([6, 3, 3, 3], covariance2.get_shape())

  def testZeroCountsResultsInPmfEqualToOne(self):
    # There is only one way for zero items to be selected, and this happens with
    # probability 1.
    alpha = [5, 0.5]
    counts = [0., 0]
    with self.cached_session():
      dist = ds.DirichletMultinomial(0., alpha)
      pmf = dist.prob(counts)
      self.assertAllClose(1.0, self.evaluate(pmf))
      self.assertEqual((), pmf.get_shape())

  def testLargeTauGivesPreciseProbabilities(self):
    # If tau is large, we are doing coin flips with probability mu.
    mu = np.array([0.1, 0.1, 0.8], dtype=np.float32)
    tau = np.array([100.], dtype=np.float32)
    alpha = tau * mu

    # One (three sided) coin flip.  Prob[coin 3] = 0.8.
    # Note that since it was one flip, value of tau didn't matter.
    counts = [0., 0, 1]
    with self.cached_session():
      dist = ds.DirichletMultinomial(1., alpha)
      pmf = dist.prob(counts)
      self.assertAllClose(0.8, self.evaluate(pmf), atol=1e-4)
      self.assertEqual((), pmf.get_shape())

    # Two (three sided) coin flips.  Prob[coin 3] = 0.8.
    counts = [0., 0, 2]
    with self.cached_session():
      dist = ds.DirichletMultinomial(2., alpha)
      pmf = dist.prob(counts)
      self.assertAllClose(0.8**2, self.evaluate(pmf), atol=1e-2)
      self.assertEqual((), pmf.get_shape())

    # Three (three sided) coin flips.  There are 3 orderings of one "coin 1"
    # and two "coin 3" outcomes.
    counts = [1., 0, 2]
    with self.cached_session():
      dist = ds.DirichletMultinomial(3., alpha)
      pmf = dist.prob(counts)
      self.assertAllClose(3 * 0.1 * 0.8 * 0.8, self.evaluate(pmf), atol=1e-2)
      self.assertEqual((), pmf.get_shape())

  def testSmallTauPrefersCorrelatedResults(self):
    # If tau is small, then correlation between draws is large, so draws that
    # are both of the same class are more likely.
    mu = np.array([0.5, 0.5], dtype=np.float32)
    tau = np.array([0.1], dtype=np.float32)
    alpha = tau * mu

    # If there is only one draw, it is still a coin flip, even with small tau.
    counts = [1., 0]
    with self.cached_session():
      dist = ds.DirichletMultinomial(1., alpha)
      pmf = dist.prob(counts)
      self.assertAllClose(0.5, self.evaluate(pmf))
      self.assertEqual((), pmf.get_shape())

    # If there are two draws, it is much more likely that they are the same.
    counts_same = [2., 0]
    counts_different = [1, 1.]
    with self.cached_session():
      dist = ds.DirichletMultinomial(2., alpha)
      pmf_same = dist.prob(counts_same)
      pmf_different = dist.prob(counts_different)
      self.assertLess(5 * self.evaluate(pmf_different), self.evaluate(pmf_same))
      self.assertEqual((), pmf_same.get_shape())

  @test_util.run_deprecated_v1
  def testNonStrictTurnsOffAllChecks(self):
    # Make totally invalid input.
    with self.cached_session():
      alpha = [[-1., 2]]  # alpha should be positive.
      counts = [[1., 0], [0., -1]]  # counts should be non-negative.
      n = [-5.3]  # n should be a non negative integer equal to counts.sum.
      dist = ds.DirichletMultinomial(n, alpha, validate_args=False)
      dist.prob(counts).eval()  # Should not raise.

  @test_util.run_deprecated_v1
  def testSampleUnbiasedNonScalarBatch(self):
    with self.cached_session() as sess:
      dist = ds.DirichletMultinomial(
          total_count=5.,
          concentration=1. + 2. * self._rng.rand(4, 3, 2).astype(np.float32))
      n = int(3e3)
      x = dist.sample(n, seed=0)
      sample_mean = math_ops.reduce_mean(x, 0)
      # Move the sample dimension to the end: [n, 4, 3, 2] -> [4, 3, 2, n], so
      # the batched matmul below contracts over the samples.
      x_centered = array_ops.transpose(x - sample_mean, [1, 2, 3, 0])
      sample_covariance = math_ops.matmul(
          x_centered, x_centered, adjoint_b=True) / n
      [
          sample_mean_,
          sample_covariance_,
          actual_mean_,
          actual_covariance_,
      ] = sess.run([
          sample_mean,
          sample_covariance,
          dist.mean(),
          dist.covariance(),
      ])
      self.assertAllEqual([4, 3, 2], sample_mean.get_shape())
      self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.20)
      self.assertAllEqual([4, 3, 2, 2], sample_covariance.get_shape())
      self.assertAllClose(
          actual_covariance_, sample_covariance_, atol=0., rtol=0.20)

  @test_util.run_deprecated_v1
  def testSampleUnbiasedScalarBatch(self):
    with self.cached_session() as sess:
      dist = ds.DirichletMultinomial(
          total_count=5.,
          concentration=1. + 2. * self._rng.rand(4).astype(np.float32))
      n = int(5e3)
      x = dist.sample(n, seed=0)
      sample_mean = math_ops.reduce_mean(x, 0)
      # Shape [n, 4]; `adjoint_a=True` below contracts over the n samples.
      x_centered = x - sample_mean
      sample_covariance = math_ops.matmul(
          x_centered, x_centered, adjoint_a=True) / n
      [
          sample_mean_,
          sample_covariance_,
          actual_mean_,
          actual_covariance_,
      ] = sess.run([
          sample_mean,
          sample_covariance,
          dist.mean(),
          dist.covariance(),
      ])
      self.assertAllEqual([4], sample_mean.get_shape())
      self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.20)
      self.assertAllEqual([4, 4], sample_covariance.get_shape())
      self.assertAllClose(
          actual_covariance_, sample_covariance_, atol=0., rtol=0.20)

  def testNotReparameterized(self):
    # Samples must not propagate gradients back to either parameter.
    total_count = constant_op.constant(5.0)
    concentration = constant_op.constant([0.1, 0.1, 0.1])
    with backprop.GradientTape() as tape:
      # Constants are not tracked automatically, so watch them explicitly.
      tape.watch(total_count)
      tape.watch(concentration)
      dist = ds.DirichletMultinomial(
          total_count=total_count,
          concentration=concentration)
      samples = dist.sample(100)
    grad_total_count, grad_concentration = tape.gradient(
        samples, [total_count, concentration])
    self.assertIsNone(grad_total_count)
    self.assertIsNone(grad_concentration)
509
if __name__ == "__main__":
  # Hand control to the TensorFlow test entry point when run as a script.
  test.main()
512