# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import dirichlet_multinomial
from tensorflow.python.platform import test


ds = dirichlet_multinomial


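# Quick reference for the expected values asserted below (standard
# Dirichlet-Multinomial identities): with total_count n, concentration alpha
# and alpha_0 = sum(alpha),
#   pmf(counts) = n! / prod(counts[k]!) * Beta(alpha + counts) / Beta(alpha)
#   mean = n * alpha / alpha_0
# where Beta(a) = prod(Gamma(a[k])) / Gamma(sum(a)) and sum(counts) = n.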
class DirichletMultinomialTest(test.TestCase):

  def setUp(self):
    self._rng = np.random.RandomState(42)

  @test_util.run_deprecated_v1
  def testSimpleShapes(self):
    with self.cached_session():
      alpha = np.random.rand(3)
      dist = ds.DirichletMultinomial(1., alpha)
      self.assertEqual(3, dist.event_shape_tensor().eval())
      self.assertAllEqual([], dist.batch_shape_tensor())
      self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape)
      self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)

  @test_util.run_deprecated_v1
  def testComplexShapes(self):
    with self.cached_session():
      alpha = np.random.rand(3, 2, 2)
      n = [[3., 2], [4, 5], [6, 7]]
      dist = ds.DirichletMultinomial(n, alpha)
      self.assertEqual(2, dist.event_shape_tensor().eval())
      self.assertAllEqual([3, 2], dist.batch_shape_tensor())
      self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape)
      self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape)

  @test_util.run_deprecated_v1
  def testNproperty(self):
    alpha = [[1., 2, 3]]
    n = [[5.]]
    with self.cached_session():
      dist = ds.DirichletMultinomial(n, alpha)
      self.assertEqual([1, 1], dist.total_count.get_shape())
      self.assertAllClose(n, dist.total_count)

  @test_util.run_deprecated_v1
  def testAlphaProperty(self):
    alpha = [[1., 2, 3]]
    with self.cached_session():
      dist = ds.DirichletMultinomial(1, alpha)
      self.assertEqual([1, 3], dist.concentration.get_shape())
      self.assertAllClose(alpha, dist.concentration)

  @test_util.run_deprecated_v1
  def testPmfNandCountsAgree(self):
    alpha = [[1., 2, 3]]
    n = [[5.]]
    with self.cached_session():
      dist = ds.DirichletMultinomial(n, alpha, validate_args=True)
      dist.prob([2., 3, 0]).eval()
      dist.prob([3., 0, 2]).eval()
      with self.assertRaisesOpError("must be non-negative"):
        dist.prob([-1., 4, 2]).eval()
      with self.assertRaisesOpError(
          "last-dimension must sum to `self.total_count`"):
        dist.prob([3., 3, 0]).eval()

  @test_util.run_deprecated_v1
  def testPmfNonIntegerCounts(self):
    alpha = [[1., 2, 3]]
    n = [[5.]]
    with self.cached_session():
      dist = ds.DirichletMultinomial(n, alpha, validate_args=True)
      dist.prob([2., 3, 0]).eval()
      dist.prob([3., 0, 2]).eval()
      dist.prob([3.0, 0, 2.0]).eval()
      # Validation runs at graph execution time when fed through a
      # placeholder; here the fractional-components check fails (the counts
      # still sum to n).
      placeholder = array_ops.placeholder(dtypes.float32)
      with self.assertRaisesOpError(
          "cannot contain fractional components"):
        dist.prob(placeholder).eval(feed_dict={placeholder: [1.0, 2.5, 1.5]})
      dist = ds.DirichletMultinomial(n, alpha, validate_args=False)
      dist.prob([1., 2., 3.]).eval()
      # Non-integer arguments work.
      dist.prob([1.0, 2.5, 1.5]).eval()

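  # For a single draw (total_count=1), the pmf of landing in class k reduces
  # to the mean alpha[k] / alpha_0, so the value asserted below is
  # 1. / (1. + 2.) = 1 / 3.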
  def testPmfBothZeroBatches(self):
    # The probability of one vote falling into class k is the mean for
    # class k.
    with self.cached_session():
      # Both zero-batches.  No broadcast.
      alpha = [1., 2]
      counts = [1., 0]
      dist = ds.DirichletMultinomial(1., alpha)
      pmf = dist.prob(counts)
      self.assertAllClose(1 / 3., self.evaluate(pmf))
      self.assertEqual((), pmf.get_shape())

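  # Worked value for the next test: with alpha = [1., 2] and counts = [3., 2],
  # the pmf is 5! / (3! * 2!) * Beta([4., 4.]) / Beta([1., 2.])
  #          = 10 * (36 / 5040) / (1 / 2) = 1 / 7.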
  def testPmfBothZeroBatchesNontrivialN(self):
    # The probability of one vote falling into class k is the mean for
    # class k.
    with self.cached_session():
      # Both zero-batches.  No broadcast.
      alpha = [1., 2]
      counts = [3., 2]
      dist = ds.DirichletMultinomial(5., alpha)
      pmf = dist.prob(counts)
      self.assertAllClose(1 / 7., self.evaluate(pmf))
      self.assertEqual((), pmf.get_shape())

  def testPmfBothZeroBatchesMultidimensionalN(self):
    # The probability of one vote falling into class k is the mean for
    # class k.
    with self.cached_session():
      alpha = [1., 2]
      counts = [3., 2]
      n = np.full([4, 3], 5., dtype=np.float32)
      dist = ds.DirichletMultinomial(n, alpha)
      pmf = dist.prob(counts)
      self.assertAllClose([[1 / 7., 1 / 7., 1 / 7.]] * 4, self.evaluate(pmf))
      self.assertEqual((4, 3), pmf.get_shape())

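  # In the broadcast tests below, counts, alpha and total_count broadcast
  # against one another, and each asserted entry is again alpha[k] / alpha_0
  # for the class k that receives the single vote: 1/3 and 2/3 under
  # alpha = [1., 2], and 2/5 for class 0 under alpha = [2., 3].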
  def testPmfAlphaStretchedInBroadcastWhenSameRank(self):
    # The probability of one vote falling into class k is the mean for
    # class k.
    with self.cached_session():
      alpha = [[1., 2]]
      counts = [[1., 0], [0., 1]]
      dist = ds.DirichletMultinomial([1.], alpha)
      pmf = dist.prob(counts)
      self.assertAllClose([1 / 3., 2 / 3.], self.evaluate(pmf))
      self.assertAllEqual([2], pmf.get_shape())

  def testPmfAlphaStretchedInBroadcastWhenLowerRank(self):
    # The probability of one vote falling into class k is the mean for
    # class k.
    with self.cached_session():
      alpha = [1., 2]
      counts = [[1., 0], [0., 1]]
      pmf = ds.DirichletMultinomial(1., alpha).prob(counts)
      self.assertAllClose([1 / 3., 2 / 3.], self.evaluate(pmf))
      self.assertAllEqual([2], pmf.get_shape())

  def testPmfCountsStretchedInBroadcastWhenSameRank(self):
    # The probability of one vote falling into class k is the mean for
    # class k.
    with self.cached_session():
      alpha = [[1., 2], [2., 3]]
      counts = [[1., 0]]
      pmf = ds.DirichletMultinomial([1., 1.], alpha).prob(counts)
      self.assertAllClose([1 / 3., 2 / 5.], self.evaluate(pmf))
      self.assertAllEqual([2], pmf.get_shape())

  def testPmfCountsStretchedInBroadcastWhenLowerRank(self):
    # The probability of one vote falling into class k is the mean for
    # class k.
    with self.cached_session():
      alpha = [[1., 2], [2., 3]]
      counts = [1., 0]
      pmf = ds.DirichletMultinomial(1., alpha).prob(counts)
      self.assertAllClose([1 / 3., 2 / 5.], self.evaluate(pmf))
      self.assertAllEqual([2], pmf.get_shape())

  @test_util.run_deprecated_v1
  def testPmfForOneVoteIsTheMeanWithOneRecordInput(self):
    # The probability of one vote falling into class k is the mean for
    # class k.
    alpha = [1., 2, 3]
    with self.cached_session():
      for class_num in range(3):
        counts = np.zeros([3], dtype=np.float32)
        counts[class_num] = 1
        dist = ds.DirichletMultinomial(1., alpha)
        mean = dist.mean().eval()
        pmf = dist.prob(counts).eval()

        self.assertAllClose(mean[class_num], pmf)
        self.assertAllEqual([3], mean.shape)
        self.assertAllEqual([], pmf.shape)

  @test_util.run_deprecated_v1
  def testMeanDoubleTwoVotes(self):
    # The mean number of votes falling into class k under
    # DirichletMultinomial(2, alpha) is twice the mean under
    # DirichletMultinomial(1, alpha).
    alpha = [1., 2, 3]
    with self.cached_session():
      for class_num in range(3):
        counts_one = np.zeros([3], dtype=np.float32)
        counts_one[class_num] = 1.
        counts_two = np.zeros([3], dtype=np.float32)
        counts_two[class_num] = 2

        dist1 = ds.DirichletMultinomial(1., alpha)
        dist2 = ds.DirichletMultinomial(2., alpha)

        mean1 = dist1.mean().eval()
        mean2 = dist2.mean().eval()

        self.assertAllClose(mean2[class_num], 2 * mean1[class_num])
        self.assertAllEqual([3], mean1.shape)

  @test_util.run_deprecated_v1
  def testCovarianceFromSampling(self):
    # We will test mean, cov, var, stddev on a DirichletMultinomial constructed
    # via broadcast between alpha, n.
    alpha = np.array([[1., 2, 3],
                      [2.5, 4, 0.01]], dtype=np.float32)
    # Ideally we'd be able to test broadcasting, but the multinomial sampler
    # doesn't support different total counts.
    n = np.float32(5)
    with self.cached_session() as sess:
      # batch_shape=[2], event_shape=[3]
      dist = ds.DirichletMultinomial(n, alpha)
      x = dist.sample(int(250e3), seed=1)
      sample_mean = math_ops.reduce_mean(x, 0)
      x_centered = x - sample_mean[array_ops.newaxis, ...]
      sample_cov = math_ops.reduce_mean(math_ops.matmul(
          x_centered[..., array_ops.newaxis],
          x_centered[..., array_ops.newaxis, :]), 0)
      sample_var = array_ops.matrix_diag_part(sample_cov)
      sample_stddev = math_ops.sqrt(sample_var)
      [
          sample_mean_,
          sample_cov_,
          sample_var_,
          sample_stddev_,
          analytic_mean,
          analytic_cov,
          analytic_var,
          analytic_stddev,
      ] = sess.run([
          sample_mean,
          sample_cov,
          sample_var,
          sample_stddev,
          dist.mean(),
          dist.covariance(),
          dist.variance(),
          dist.stddev(),
      ])
      self.assertAllClose(sample_mean_, analytic_mean, atol=0.04, rtol=0.)
      self.assertAllClose(sample_cov_, analytic_cov, atol=0.05, rtol=0.)
      self.assertAllClose(sample_var_, analytic_var, atol=0.05, rtol=0.)
      self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.)

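  # The analytic covariance checked below is the Multinomial(n, p) covariance
  # n * (diag(p) - p p^T) with p = alpha / alpha_sum, scaled by the
  # over-dispersion factor (alpha_sum + n) / (alpha_sum + 1); shared_matrix
  # holds the n-independent part, diag(p) - p p^T.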
  @test_util.run_without_tensor_float_32(
      "Tests DirichletMultinomial.covariance, which calls matmul")
  def testCovariance(self):
    # Shape [2]
    alpha = [1., 2]
    ns = [2., 3., 4., 5.]
    alpha_0 = np.sum(alpha)

    # Diagonal entries are of the form:
    # Var(X_i) = n * alpha_i / alpha_sum * (1 - alpha_i / alpha_sum) *
    # (alpha_sum + n) / (alpha_sum + 1)
    variance_entry = lambda a, a_sum: a / a_sum * (1 - a / a_sum)
    # Off diagonal entries are of the form:
    # Cov(X_i, X_j) = -n * alpha_i * alpha_j / (alpha_sum ** 2) *
    # (alpha_sum + n) / (alpha_sum + 1)
    covariance_entry = lambda a, b, a_sum: -a * b / a_sum**2
    # Shape [2, 2].
    shared_matrix = np.array([[
        variance_entry(alpha[0], alpha_0),
        covariance_entry(alpha[0], alpha[1], alpha_0)
    ], [
        covariance_entry(alpha[1], alpha[0], alpha_0),
        variance_entry(alpha[1], alpha_0)
    ]])

    with self.cached_session():
      for n in ns:
        # n is shape [] and alpha is shape [2].
        dist = ds.DirichletMultinomial(n, alpha)
        covariance = dist.covariance()
        expected_covariance = n * (n + alpha_0) / (1 + alpha_0) * shared_matrix

        self.assertEqual([2, 2], covariance.get_shape())
        self.assertAllClose(expected_covariance, self.evaluate(covariance))

  def testCovarianceNAlphaBroadcast(self):
    alpha_v = [1., 2, 3]
    alpha_0 = 6.

    # Shape [4, 3]
    alpha = np.array(4 * [alpha_v], dtype=np.float32)
    # Shape [4, 1]
    ns = np.array([[2.], [3.], [4.], [5.]], dtype=np.float32)

    variance_entry = lambda a, a_sum: a / a_sum * (1 - a / a_sum)
    covariance_entry = lambda a, b, a_sum: -a * b / a_sum**2
    # Shape [4, 3, 3]
    shared_matrix = np.array(
        4 * [[[
            variance_entry(alpha_v[0], alpha_0),
            covariance_entry(alpha_v[0], alpha_v[1], alpha_0),
            covariance_entry(alpha_v[0], alpha_v[2], alpha_0)
        ], [
            covariance_entry(alpha_v[1], alpha_v[0], alpha_0),
            variance_entry(alpha_v[1], alpha_0),
            covariance_entry(alpha_v[1], alpha_v[2], alpha_0)
        ], [
            covariance_entry(alpha_v[2], alpha_v[0], alpha_0),
            covariance_entry(alpha_v[2], alpha_v[1], alpha_0),
            variance_entry(alpha_v[2], alpha_0)
        ]]],
        dtype=np.float32)

    with self.cached_session():
      # ns is shape [4, 1], and alpha is shape [4, 3].
      dist = ds.DirichletMultinomial(ns, alpha)
      covariance = dist.covariance()
      expected_covariance = shared_matrix * (
          ns * (ns + alpha_0) / (1 + alpha_0))[..., array_ops.newaxis]

      self.assertEqual([4, 3, 3], covariance.get_shape())
      self.assertAllClose(expected_covariance, self.evaluate(covariance))

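  # covariance() returns a matrix of shape batch_shape + [k, k], where k is
  # the number of classes; the test below only checks these static shapes.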
  def testCovarianceMultidimensional(self):
    alpha = np.random.rand(3, 5, 4).astype(np.float32)
    alpha2 = np.random.rand(6, 3, 3).astype(np.float32)

    ns = np.random.randint(low=1, high=11, size=[3, 5, 1]).astype(np.float32)
    ns2 = np.random.randint(low=1, high=11, size=[6, 1, 1]).astype(np.float32)

    with self.cached_session():
      dist = ds.DirichletMultinomial(ns, alpha)
      dist2 = ds.DirichletMultinomial(ns2, alpha2)

      covariance = dist.covariance()
      covariance2 = dist2.covariance()
      self.assertEqual([3, 5, 4, 4], covariance.get_shape())
      self.assertEqual([6, 3, 3, 3], covariance2.get_shape())

  def testZeroCountsResultsInPmfEqualToOne(self):
    # There is only one way for zero items to be selected, and this happens
    # with probability 1.
    alpha = [5, 0.5]
    counts = [0., 0]
    with self.cached_session():
      dist = ds.DirichletMultinomial(0., alpha)
      pmf = dist.prob(counts)
      self.assertAllClose(1.0, self.evaluate(pmf))
      self.assertEqual((), pmf.get_shape())

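  # As tau -> infinity, DirichletMultinomial(n, tau * mu) approaches
  # Multinomial(n, mu), so the pmf values below are compared against the
  # corresponding multinomial probabilities (0.8, 0.8**2, 3 * 0.1 * 0.8**2)
  # with a loose tolerance.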
  def testLargeTauGivesPreciseProbabilities(self):
    # If tau is large, we are doing coin flips with probability mu.
    mu = np.array([0.1, 0.1, 0.8], dtype=np.float32)
    tau = np.array([100.], dtype=np.float32)
    alpha = tau * mu

    # One (three-sided) coin flip.  Prob[coin 3] = 0.8.
    # Note that since it was one flip, the value of tau didn't matter.
    counts = [0., 0, 1]
    with self.cached_session():
      dist = ds.DirichletMultinomial(1., alpha)
      pmf = dist.prob(counts)
      self.assertAllClose(0.8, self.evaluate(pmf), atol=1e-4)
      self.assertEqual((), pmf.get_shape())

    # Two (three-sided) coin flips.  Prob[coin 3] = 0.8.
    counts = [0., 0, 2]
    with self.cached_session():
      dist = ds.DirichletMultinomial(2., alpha)
      pmf = dist.prob(counts)
      self.assertAllClose(0.8**2, self.evaluate(pmf), atol=1e-2)
      self.assertEqual((), pmf.get_shape())

    # Three (three-sided) coin flips.
    counts = [1., 0, 2]
    with self.cached_session():
      dist = ds.DirichletMultinomial(3., alpha)
      pmf = dist.prob(counts)
      self.assertAllClose(3 * 0.1 * 0.8 * 0.8, self.evaluate(pmf), atol=1e-2)
      self.assertEqual((), pmf.get_shape())

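  # For two draws, P(both fall in class k)
  #   = (alpha[k] / alpha_sum) * ((alpha[k] + 1) / (alpha_sum + 1)),
  # which for alpha = [0.05, 0.05] is roughly 0.48, versus roughly 0.045 for
  # one draw in each class; the factor-of-5 check below relies on this gap.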
  def testSmallTauPrefersCorrelatedResults(self):
    # If tau is small, the correlation between draws is large, so it is more
    # likely that both draws fall into the same class.
    mu = np.array([0.5, 0.5], dtype=np.float32)
    tau = np.array([0.1], dtype=np.float32)
    alpha = tau * mu

    # If there is only one draw, it is still a coin flip, even with small tau.
    counts = [1., 0]
    with self.cached_session():
      dist = ds.DirichletMultinomial(1., alpha)
      pmf = dist.prob(counts)
      self.assertAllClose(0.5, self.evaluate(pmf))
      self.assertEqual((), pmf.get_shape())

    # If there are two draws, it is much more likely that they are the same.
    counts_same = [2., 0]
    counts_different = [1, 1.]
    with self.cached_session():
      dist = ds.DirichletMultinomial(2., alpha)
      pmf_same = dist.prob(counts_same)
      pmf_different = dist.prob(counts_different)
      self.assertLess(5 * self.evaluate(pmf_different), self.evaluate(pmf_same))
      self.assertEqual((), pmf_same.get_shape())

  @test_util.run_deprecated_v1
  def testNonStrictTurnsOffAllChecks(self):
    # Make totally invalid input.
    with self.cached_session():
      alpha = [[-1., 2]]  # alpha should be positive.
      counts = [[1., 0], [0., -1]]  # counts should be non-negative.
      n = [-5.3]  # n should be a non-negative integer equal to counts.sum.
      dist = ds.DirichletMultinomial(n, alpha, validate_args=False)
      dist.prob(counts).eval()  # Should not raise.

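  # The two sampling tests below estimate the covariance by averaging outer
  # products of centered samples (a matmul contracting over the n samples)
  # and compare against mean() and covariance() with a 20% relative tolerance.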
  @test_util.run_deprecated_v1
  def testSampleUnbiasedNonScalarBatch(self):
    with self.cached_session() as sess:
      dist = ds.DirichletMultinomial(
          total_count=5.,
          concentration=1. + 2. * self._rng.rand(4, 3, 2).astype(np.float32))
      n = int(3e3)
      x = dist.sample(n, seed=0)
      sample_mean = math_ops.reduce_mean(x, 0)
      # Cyclically rotate event dims left.
      x_centered = array_ops.transpose(x - sample_mean, [1, 2, 3, 0])
      sample_covariance = math_ops.matmul(
          x_centered, x_centered, adjoint_b=True) / n
      [
          sample_mean_,
          sample_covariance_,
          actual_mean_,
          actual_covariance_,
      ] = sess.run([
          sample_mean,
          sample_covariance,
          dist.mean(),
          dist.covariance(),
      ])
      self.assertAllEqual([4, 3, 2], sample_mean.get_shape())
      self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.20)
      self.assertAllEqual([4, 3, 2, 2], sample_covariance.get_shape())
      self.assertAllClose(
          actual_covariance_, sample_covariance_, atol=0., rtol=0.20)

  @test_util.run_deprecated_v1
  def testSampleUnbiasedScalarBatch(self):
    with self.cached_session() as sess:
      dist = ds.DirichletMultinomial(
          total_count=5.,
          concentration=1. + 2. * self._rng.rand(4).astype(np.float32))
      n = int(5e3)
      x = dist.sample(n, seed=0)
      sample_mean = math_ops.reduce_mean(x, 0)
      x_centered = x - sample_mean  # Already shaped [n, 4].
      sample_covariance = math_ops.matmul(
          x_centered, x_centered, adjoint_a=True) / n
      [
          sample_mean_,
          sample_covariance_,
          actual_mean_,
          actual_covariance_,
      ] = sess.run([
          sample_mean,
          sample_covariance,
          dist.mean(),
          dist.covariance(),
      ])
      self.assertAllEqual([4], sample_mean.get_shape())
      self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.20)
      self.assertAllEqual([4, 4], sample_covariance.get_shape())
      self.assertAllClose(
          actual_covariance_, sample_covariance_, atol=0., rtol=0.20)

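  # Sampling is not reparameterized, so gradients of the samples with respect
  # to total_count and concentration are expected to be None.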
  def testNotReparameterized(self):
    total_count = constant_op.constant(5.0)
    concentration = constant_op.constant([0.1, 0.1, 0.1])
    with backprop.GradientTape() as tape:
      tape.watch(total_count)
      tape.watch(concentration)
      dist = ds.DirichletMultinomial(
          total_count=total_count,
          concentration=concentration)
      samples = dist.sample(100)
    grad_total_count, grad_concentration = tape.gradient(
        samples, [total_count, concentration])
    self.assertIsNone(grad_total_count)
    self.assertIsNone(grad_concentration)


if __name__ == "__main__":
  test.main()