• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import importlib
16
17import numpy as np
18
19from tensorflow.python.eager import backprop
20from tensorflow.python.framework import constant_op
21from tensorflow.python.framework import tensor_shape
22from tensorflow.python.framework import test_util
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import math_ops
25from tensorflow.python.ops.distributions import dirichlet as dirichlet_lib
26from tensorflow.python.ops.distributions import kullback_leibler
27from tensorflow.python.platform import test
28from tensorflow.python.platform import tf_logging
29
30
def try_import(name):  # pylint: disable=invalid-name
  """Imports the module `name`, returning `None` if it cannot be imported.

  Args:
    name: Fully qualified module name, e.g. "scipy.special".

  Returns:
    The imported module, or `None` when importing raises `ImportError`. A
    warning is logged in the failure case.
  """
  try:
    return importlib.import_module(name)
  except ImportError as e:
    # Pass the values as logging arguments so that formatting is left to the
    # logging framework rather than done eagerly with `%`.
    tf_logging.warning("Could not import %s: %s", name, e)
    return None
38
39
# scipy is an optional dependency: each name is `None` when the corresponding
# module cannot be imported, and the tests below check for that before using
# the scipy-based reference values.
special = try_import("scipy.special")
stats = try_import("scipy.stats")
42
43
@test_util.run_all_in_graph_and_eager_modes
class DirichletTest(test.TestCase):
  """Tests for `dirichlet_lib.Dirichlet`.

  Covers shape inference, `prob`/`log_prob` values and broadcasting, summary
  statistics (mean, covariance, mode, entropy), sampling, and the
  Dirichlet-Dirichlet KL divergence. Checks that need scipy reference values
  return early when scipy could not be imported.
  """

  def testSimpleShapes(self):
    alpha = np.random.rand(3)
    dist = dirichlet_lib.Dirichlet(alpha)
    # The evaluated shape tensors are arrays, so compare them element-wise
    # instead of relying on the truthiness of `scalar == array`.
    self.assertAllEqual([3], self.evaluate(dist.event_shape_tensor()))
    self.assertAllEqual([], self.evaluate(dist.batch_shape_tensor()))
    self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape)
    self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)

  def testComplexShapes(self):
    alpha = np.random.rand(3, 2, 2)
    dist = dirichlet_lib.Dirichlet(alpha)
    # The last axis of `alpha` is the event axis; leading axes are batch axes.
    self.assertAllEqual([2], self.evaluate(dist.event_shape_tensor()))
    self.assertAllEqual([3, 2], self.evaluate(dist.batch_shape_tensor()))
    self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape)
    self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape)

  def testConcentrationProperty(self):
    alpha = [[1., 2, 3]]
    dist = dirichlet_lib.Dirichlet(alpha)
    self.assertEqual([1, 3], dist.concentration.get_shape())
    self.assertAllClose(alpha, self.evaluate(dist.concentration))

  def testPdfXProper(self):
    alpha = [[1., 2, 3]]
    dist = dirichlet_lib.Dirichlet(alpha, validate_args=True)
    # Valid points: strictly positive and summing to one.
    self.evaluate(dist.prob([.1, .3, .6]))
    self.evaluate(dist.prob([.2, .3, .5]))
    # Either condition can trigger.
    with self.assertRaisesOpError("samples must be positive"):
      self.evaluate(dist.prob([-1., 1.5, 0.5]))
    with self.assertRaisesOpError("samples must be positive"):
      self.evaluate(dist.prob([0., .1, .9]))
    with self.assertRaisesOpError("sample last-dimension must sum to `1`"):
      self.evaluate(dist.prob([.1, .2, .8]))

  def testLogPdfOnBoundaryIsFiniteWhenAlphaIsOne(self):
    # Test concentration = 1. for each dimension.
    # Row i has concentration 1 (and x == 0) at index i; every other entry
    # has concentration 3 and x == 1/9, so each row of x sums to one.
    concentration = 3 * np.ones((10, 10)).astype(np.float32)
    concentration[range(10), range(10)] = 1.
    x = 1 / 9. * np.ones((10, 10)).astype(np.float32)
    x[range(10), range(10)] = 0.
    dist = dirichlet_lib.Dirichlet(concentration)
    log_prob = self.evaluate(dist.log_prob(x))
    self.assertAllEqual(
        np.ones_like(log_prob, dtype=np.bool_), np.isfinite(log_prob))

    # Test when concentration[k] = 1., and x is zero at various dimensions.
    dist = dirichlet_lib.Dirichlet(10 * [1.])
    log_prob = self.evaluate(dist.log_prob(x))
    self.assertAllEqual(
        np.ones_like(log_prob, dtype=np.bool_), np.isfinite(log_prob))

  def testPdfZeroBatches(self):
    alpha = [1., 2]
    x = [.5, .5]
    dist = dirichlet_lib.Dirichlet(alpha)
    pdf = dist.prob(x)
    # For alpha = [1, 2] the density is 2 * x[1].
    self.assertAllClose(1., self.evaluate(pdf))
    self.assertEqual((), pdf.get_shape())

  def testPdfZeroBatchesNontrivialX(self):
    alpha = [1., 2]
    x = [.3, .7]
    dist = dirichlet_lib.Dirichlet(alpha)
    pdf = dist.prob(x)
    # For alpha = [1, 2] the density is 2 * x[1] = 1.4.
    self.assertAllClose(7. / 5, self.evaluate(pdf))
    self.assertEqual((), pdf.get_shape())

  def testPdfUniformZeroBatches(self):
    # Corresponds to a uniform distribution
    alpha = [1., 1, 1]
    x = [[.2, .5, .3], [.3, .4, .3]]
    dist = dirichlet_lib.Dirichlet(alpha)
    pdf = dist.prob(x)
    self.assertAllClose([2., 2.], self.evaluate(pdf))
    # `(2,)` is the one-element shape tuple; a bare `(2)` is just the int 2.
    self.assertEqual((2,), pdf.get_shape())

  def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
    alpha = [[1., 2]]
    x = [[.5, .5], [.3, .7]]
    dist = dirichlet_lib.Dirichlet(alpha)
    pdf = dist.prob(x)
    self.assertAllClose([1., 7. / 5], self.evaluate(pdf))
    self.assertEqual((2,), pdf.get_shape())

  def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
    alpha = [1., 2]
    x = [[.5, .5], [.2, .8]]
    pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
    self.assertAllClose([1., 8. / 5], self.evaluate(pdf))
    self.assertEqual((2,), pdf.get_shape())

  def testPdfXStretchedInBroadcastWhenSameRank(self):
    alpha = [[1., 2], [2., 3]]
    x = [[.5, .5]]
    pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
    # For alpha = [2, 3] the density is 12 * x[0] * x[1]**2 = 1.5 at x.
    self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
    self.assertEqual((2,), pdf.get_shape())

  def testPdfXStretchedInBroadcastWhenLowerRank(self):
    alpha = [[1., 2], [2., 3]]
    x = [.5, .5]
    pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
    self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
    self.assertEqual((2,), pdf.get_shape())

  def testMean(self):
    alpha = [1., 2, 3]
    dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
    self.assertEqual(dirichlet.mean().get_shape(), [3])
    # The value check needs the scipy reference implementation.
    if not stats:
      return
    expected_mean = stats.dirichlet.mean(alpha)
    self.assertAllClose(self.evaluate(dirichlet.mean()), expected_mean)

  def testCovarianceFromSampling(self):
    alpha = np.array([[1., 2, 3],
                      [2.5, 4, 0.01]], dtype=np.float32)
    dist = dirichlet_lib.Dirichlet(alpha)  # batch_shape=[2], event_shape=[3]
    x = dist.sample(int(250e3), seed=1)
    sample_mean = math_ops.reduce_mean(x, 0)
    x_centered = x - sample_mean[None, ...]
    # Average the per-sample outer products over the sample axis.
    sample_cov = math_ops.reduce_mean(math_ops.matmul(
        x_centered[..., None], x_centered[..., None, :]), 0)
    sample_var = array_ops.matrix_diag_part(sample_cov)
    sample_stddev = math_ops.sqrt(sample_var)

    # Evaluate everything in a single call so that all sample statistics are
    # computed from the same draw.
    [
        sample_mean_,
        sample_cov_,
        sample_var_,
        sample_stddev_,
        analytic_mean,
        analytic_cov,
        analytic_var,
        analytic_stddev,
    ] = self.evaluate([
        sample_mean,
        sample_cov,
        sample_var,
        sample_stddev,
        dist.mean(),
        dist.covariance(),
        dist.variance(),
        dist.stddev(),
    ])

    self.assertAllClose(sample_mean_, analytic_mean, atol=0.04, rtol=0.)
    self.assertAllClose(sample_cov_, analytic_cov, atol=0.06, rtol=0.)
    self.assertAllClose(sample_var_, analytic_var, atol=0.04, rtol=0.)
    self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.)

  @test_util.run_without_tensor_float_32(
      "Calls Dirichlet.covariance, which calls matmul")
  def testVariance(self):
    alpha = [1., 2, 3]
    denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1)
    dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
    self.assertEqual(dirichlet.covariance().get_shape(), (3, 3))
    if not stats:
      return
    # Diagonal: per-component variances from scipy. Off-diagonal entries are
    # -alpha[i] * alpha[j] / denominator.
    expected_covariance = np.diag(stats.dirichlet.var(alpha))
    expected_covariance += [[0., -2, -3], [-2, 0, -6], [-3, -6, 0]
                           ] / denominator
    self.assertAllClose(
        self.evaluate(dirichlet.covariance()), expected_covariance)

  def testMode(self):
    alpha = np.array([1.1, 2, 3])
    # mode = (alpha - 1) / (sum(alpha) - K), with K = 3 components here.
    expected_mode = (alpha - 1) / (np.sum(alpha) - 3)
    dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
    self.assertEqual(dirichlet.mode().get_shape(), [3])
    self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)

  def testModeInvalid(self):
    # One concentration equals 1, so with `allow_nan_stats=False` evaluating
    # the mode is expected to fail an assertion.
    alpha = np.array([1., 2, 3])
    dirichlet = dirichlet_lib.Dirichlet(
        concentration=alpha, allow_nan_stats=False)
    with self.assertRaisesOpError("Condition x < y.*"):
      self.evaluate(dirichlet.mode())

  def testModeEnableAllowNanStats(self):
    # Same concentration as testModeInvalid, but NaNs are returned instead of
    # raising when `allow_nan_stats=True`.
    alpha = np.array([1., 2, 3])
    dirichlet = dirichlet_lib.Dirichlet(
        concentration=alpha, allow_nan_stats=True)
    expected_mode = np.zeros_like(alpha) + np.nan

    self.assertEqual(dirichlet.mode().get_shape(), [3])
    self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)

  def testEntropy(self):
    alpha = [1., 2, 3]
    dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
    self.assertEqual(dirichlet.entropy().get_shape(), ())
    if not stats:
      return
    expected_entropy = stats.dirichlet.entropy(alpha)
    self.assertAllClose(self.evaluate(dirichlet.entropy()), expected_entropy)

  def testSample(self):
    alpha = [1., 2]
    dirichlet = dirichlet_lib.Dirichlet(alpha)
    n = constant_op.constant(100000)
    samples = dirichlet.sample(n)
    sample_values = self.evaluate(samples)
    self.assertEqual(sample_values.shape, (100000, 2))
    self.assertTrue(np.all(sample_values > 0.0))
    if not stats:
      return
    # Kolmogorov-Smirnov test of the first coordinate against Beta(1, 2).
    self.assertLess(
        stats.kstest(
            # Beta is a univariate distribution.
            sample_values[:, 0],
            stats.beta(a=1., b=2.).cdf)[0],
        0.01)

  def testDirichletFullyReparameterized(self):
    alpha = constant_op.constant([1.0, 2.0, 3.0])
    with backprop.GradientTape() as tape:
      tape.watch(alpha)
      dirichlet = dirichlet_lib.Dirichlet(alpha)
      samples = dirichlet.sample(100)
    # A non-None gradient shows that samples are differentiable with respect
    # to the concentration.
    grad_alpha = tape.gradient(samples, alpha)
    self.assertIsNotNone(grad_alpha)

  def testDirichletDirichletKL(self):
    conc1 = np.array([[1., 2., 3., 1.5, 2.5, 3.5],
                      [1.5, 2.5, 3.5, 4.5, 5.5, 6.5]])
    conc2 = np.array([[0.5, 1., 1.5, 2., 2.5, 3.]])

    d1 = dirichlet_lib.Dirichlet(conc1)
    d2 = dirichlet_lib.Dirichlet(conc2)
    # Monte Carlo estimate of KL(d1 || d2) from samples of d1.
    x = d1.sample(int(1e4), seed=0)
    kl_sample = math_ops.reduce_mean(d1.log_prob(x) - d2.log_prob(x), 0)
    kl_actual = kullback_leibler.kl_divergence(d1, d2)

    kl_sample_val = self.evaluate(kl_sample)
    kl_actual_val = self.evaluate(kl_actual)

    self.assertEqual(conc1.shape[:-1], kl_actual.get_shape())

    if not special:
      return

    # Closed-form KL divergence computed with scipy's gammaln and digamma.
    kl_expected = (
        special.gammaln(np.sum(conc1, -1))
        - special.gammaln(np.sum(conc2, -1))
        - np.sum(special.gammaln(conc1) - special.gammaln(conc2), -1)
        + np.sum((conc1 - conc2) * (special.digamma(conc1) - special.digamma(
            np.sum(conc1, -1, keepdims=True))), -1))

    self.assertAllClose(kl_expected, kl_actual_val, atol=0., rtol=1e-6)
    self.assertAllClose(kl_sample_val, kl_actual_val, atol=0., rtol=1e-1)

    # Make sure KL(d1||d1) is 0
    kl_same = self.evaluate(kullback_leibler.kl_divergence(d1, d1))
    self.assertAllClose(kl_same, np.zeros_like(kl_expected))
305
# Run the tests in this file when it is executed as a script.
if __name__ == "__main__":
  test.main()
308