• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15from __future__ import absolute_import
16from __future__ import division
17from __future__ import print_function
18
19import importlib
20
21import numpy as np
22
23from tensorflow.python.eager import backprop
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import tensor_shape
26from tensorflow.python.framework import test_util
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops.distributions import dirichlet as dirichlet_lib
30from tensorflow.python.ops.distributions import kullback_leibler
31from tensorflow.python.platform import test
32from tensorflow.python.platform import tf_logging
33
34
def try_import(name):  # pylint: disable=invalid-name
  """Imports module `name`, returning None instead of raising if unavailable.

  Lets optional dependencies (e.g. SciPy) be soft requirements: callers test
  the result for truthiness and skip work that needs the module.

  Args:
    name: Fully qualified module name, e.g. "scipy.stats".

  Returns:
    The imported module, or None if the import raised `ImportError`
    (a warning is logged in that case).
  """
  try:
    return importlib.import_module(name)
  except ImportError as e:
    # Hand the arguments to the logger rather than pre-formatting with `%`,
    # so interpolation is deferred to the logging machinery.
    tf_logging.warning("Could not import %s: %s", name, e)
    return None
42
43
# Optional SciPy modules used for reference values. `try_import` yields None
# for any module that cannot be imported; tests that need one of these check
# it and return early when it is missing.
special, stats = map(try_import, ("scipy.special", "scipy.stats"))
46
47
@test_util.run_all_in_graph_and_eager_modes
class DirichletTest(test.TestCase):
  """Tests for `dirichlet_lib.Dirichlet`: shapes, pdf, moments, sampling, KL.

  The decorator runs every test in both graph and eager mode. Checks against
  SciPy reference values return early when the optional `stats` / `special`
  module could not be imported (is falsy).
  """

  def testSimpleShapes(self):
    """A rank-1 concentration yields event_shape [3] and an empty batch."""
    alpha = np.random.rand(3)
    dist = dirichlet_lib.Dirichlet(alpha)
    # event_shape_tensor() is a vector; compare it as one instead of relying
    # on `3 == np.array([3])` happening to be truthy.
    self.assertAllEqual([3], self.evaluate(dist.event_shape_tensor()))
    self.assertAllEqual([], self.evaluate(dist.batch_shape_tensor()))
    self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape)
    self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)

  def testComplexShapes(self):
    """The last concentration axis is the event; leading axes are the batch."""
    alpha = np.random.rand(3, 2, 2)
    dist = dirichlet_lib.Dirichlet(alpha)
    self.assertAllEqual([2], self.evaluate(dist.event_shape_tensor()))
    self.assertAllEqual([3, 2], self.evaluate(dist.batch_shape_tensor()))
    self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape)
    self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape)

  def testConcentrationProperty(self):
    """`concentration` keeps the shape and values passed to the constructor."""
    alpha = [[1., 2, 3]]
    dist = dirichlet_lib.Dirichlet(alpha)
    self.assertEqual([1, 3], dist.concentration.get_shape())
    self.assertAllClose(alpha, self.evaluate(dist.concentration))

  def testPdfXProper(self):
    """With validate_args=True, samples off the simplex raise op errors."""
    alpha = [[1., 2, 3]]
    dist = dirichlet_lib.Dirichlet(alpha, validate_args=True)
    # Valid points on the simplex evaluate without error.
    self.evaluate(dist.prob([.1, .3, .6]))
    self.evaluate(dist.prob([.2, .3, .5]))
    # A negative and a zero component each fail the positivity check (both
    # of these samples still sum to 1).
    with self.assertRaisesOpError("samples must be positive"):
      self.evaluate(dist.prob([-1., 1.5, 0.5]))
    with self.assertRaisesOpError("samples must be positive"):
      self.evaluate(dist.prob([0., .1, .9]))
    # All components are positive, but they sum to 1.1.
    with self.assertRaisesOpError("sample last-dimension must sum to `1`"):
      self.evaluate(dist.prob([.1, .2, .8]))

  def testLogPdfOnBoundaryIsFiniteWhenAlphaIsOne(self):
    """log_prob is finite where x_k == 0 and concentration_k == 1."""
    # Row i of `concentration` is 1 at column i and 3 elsewhere; row i of `x`
    # is 0 at column i and 1/9 elsewhere, so each row of `x` sums to 1.
    concentration = 3 * np.ones((10, 10)).astype(np.float32)
    concentration[range(10), range(10)] = 1.
    x = 1 / 9. * np.ones((10, 10)).astype(np.float32)
    x[range(10), range(10)] = 0.
    dist = dirichlet_lib.Dirichlet(concentration)
    log_prob = self.evaluate(dist.log_prob(x))
    # Use the builtin `bool`: the `np.bool` alias was removed in NumPy 1.24.
    self.assertAllEqual(
        np.ones_like(log_prob, dtype=bool), np.isfinite(log_prob))

    # Test when concentration[k] = 1., and x is zero at various dimensions.
    dist = dirichlet_lib.Dirichlet(10 * [1.])
    log_prob = self.evaluate(dist.log_prob(x))
    self.assertAllEqual(
        np.ones_like(log_prob, dtype=bool), np.isfinite(log_prob))

  def testPdfZeroBatches(self):
    """An unbatched distribution gives a scalar pdf (Dir([1, 2]) at [.5, .5])."""
    alpha = [1., 2]
    x = [.5, .5]
    dist = dirichlet_lib.Dirichlet(alpha)
    pdf = dist.prob(x)
    self.assertAllClose(1., self.evaluate(pdf))
    self.assertEqual((), pdf.get_shape())

  def testPdfZeroBatchesNontrivialX(self):
    """Dir([1, 2]) has density 2 * x_2, i.e. 7/5 at [.3, .7]."""
    alpha = [1., 2]
    x = [.3, .7]
    dist = dirichlet_lib.Dirichlet(alpha)
    pdf = dist.prob(x)
    self.assertAllClose(7. / 5, self.evaluate(pdf))
    self.assertEqual((), pdf.get_shape())

  def testPdfUniformZeroBatches(self):
    """Dir([1, 1, 1]) is uniform on the simplex with constant density 2."""
    # Corresponds to a uniform distribution
    alpha = [1., 1, 1]
    x = [[.2, .5, .3], [.3, .4, .3]]
    dist = dirichlet_lib.Dirichlet(alpha)
    pdf = dist.prob(x)
    self.assertAllClose([2., 2.], self.evaluate(pdf))
    # `(2,)` is the one-element shape tuple; a bare `(2)` is just the int 2.
    self.assertEqual((2,), pdf.get_shape())

  def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
    """A [1, k] concentration broadcasts against a [2, k] batch of x."""
    alpha = [[1., 2]]
    x = [[.5, .5], [.3, .7]]
    dist = dirichlet_lib.Dirichlet(alpha)
    pdf = dist.prob(x)
    self.assertAllClose([1., 7. / 5], self.evaluate(pdf))
    self.assertEqual((2,), pdf.get_shape())

  def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
    """A rank-1 concentration broadcasts against a [2, k] batch of x."""
    alpha = [1., 2]
    x = [[.5, .5], [.2, .8]]
    pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
    self.assertAllClose([1., 8. / 5], self.evaluate(pdf))
    self.assertEqual((2,), pdf.get_shape())

  def testPdfXStretchedInBroadcastWhenSameRank(self):
    """A [1, k] x broadcasts against a batch of two concentrations."""
    alpha = [[1., 2], [2., 3]]
    x = [[.5, .5]]
    pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
    self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
    self.assertEqual((2,), pdf.get_shape())

  def testPdfXStretchedInBroadcastWhenLowerRank(self):
    """A rank-1 x broadcasts against a batch of two concentrations."""
    alpha = [[1., 2], [2., 3]]
    x = [.5, .5]
    pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
    self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
    self.assertEqual((2,), pdf.get_shape())

  def testMean(self):
    """mean() has the event shape and matches scipy.stats.dirichlet.mean."""
    alpha = [1., 2, 3]
    dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
    self.assertEqual(dirichlet.mean().get_shape(), [3])
    if not stats:
      return
    expected_mean = stats.dirichlet.mean(alpha)
    self.assertAllClose(self.evaluate(dirichlet.mean()), expected_mean)

  def testCovarianceFromSampling(self):
    """Analytic moments agree with Monte Carlo estimates from 250k samples."""
    alpha = np.array([[1., 2, 3],
                      [2.5, 4, 0.01]], dtype=np.float32)
    dist = dirichlet_lib.Dirichlet(alpha)  # batch_shape=[2], event_shape=[3]
    x = dist.sample(int(250e3), seed=1)
    sample_mean = math_ops.reduce_mean(x, 0)
    x_centered = x - sample_mean[None, ...]
    # Average of per-sample outer products of the centered draws.
    sample_cov = math_ops.reduce_mean(math_ops.matmul(
        x_centered[..., None], x_centered[..., None, :]), 0)
    sample_var = array_ops.matrix_diag_part(sample_cov)
    sample_stddev = math_ops.sqrt(sample_var)

    [
        sample_mean_,
        sample_cov_,
        sample_var_,
        sample_stddev_,
        analytic_mean,
        analytic_cov,
        analytic_var,
        analytic_stddev,
    ] = self.evaluate([
        sample_mean,
        sample_cov,
        sample_var,
        sample_stddev,
        dist.mean(),
        dist.covariance(),
        dist.variance(),
        dist.stddev(),
    ])

    self.assertAllClose(sample_mean_, analytic_mean, atol=0.04, rtol=0.)
    self.assertAllClose(sample_cov_, analytic_cov, atol=0.06, rtol=0.)
    self.assertAllClose(sample_var_, analytic_var, atol=0.03, rtol=0.)
    self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.)

  def testVariance(self):
    """covariance() of Dir([1, 2, 3]) matches the closed-form matrix."""
    alpha = [1., 2, 3]
    denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1)
    dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
    self.assertEqual(dirichlet.covariance().get_shape(), (3, 3))
    if not stats:
      return
    expected_covariance = np.diag(stats.dirichlet.var(alpha))
    # Off-diagonal entries are -alpha_i * alpha_j / (a0**2 * (a0 + 1)) with
    # a0 = sum(alpha); the products for alpha = [1, 2, 3] are written out.
    expected_covariance += np.array(
        [[0., -2, -3], [-2, 0, -6], [-3, -6, 0]]) / denominator
    self.assertAllClose(
        self.evaluate(dirichlet.covariance()), expected_covariance)

  def testMode(self):
    """mode() equals (alpha - 1) / (sum(alpha) - k) for alpha = [1.1, 2, 3]."""
    alpha = np.array([1.1, 2, 3])
    expected_mode = (alpha - 1) / (np.sum(alpha) - 3)
    dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
    self.assertEqual(dirichlet.mode().get_shape(), [3])
    self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)

  def testModeInvalid(self):
    """With allow_nan_stats=False, mode() raises when an alpha equals 1."""
    alpha = np.array([1., 2, 3])
    dirichlet = dirichlet_lib.Dirichlet(
        concentration=alpha, allow_nan_stats=False)
    with self.assertRaisesOpError("Condition x < y.*"):
      self.evaluate(dirichlet.mode())

  def testModeEnableAllowNanStats(self):
    """With allow_nan_stats=True, the undefined mode is reported as NaN."""
    alpha = np.array([1., 2, 3])
    dirichlet = dirichlet_lib.Dirichlet(
        concentration=alpha, allow_nan_stats=True)
    expected_mode = np.zeros_like(alpha) + np.nan

    self.assertEqual(dirichlet.mode().get_shape(), [3])
    self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)

  def testEntropy(self):
    """entropy() is a scalar matching scipy.stats.dirichlet.entropy."""
    alpha = [1., 2, 3]
    dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
    self.assertEqual(dirichlet.entropy().get_shape(), ())
    if not stats:
      return
    expected_entropy = stats.dirichlet.entropy(alpha)
    self.assertAllClose(self.evaluate(dirichlet.entropy()), expected_entropy)

  def testSample(self):
    """Samples are positive; the first marginal of Dir([1, 2]) is Beta(1, 2)."""
    alpha = [1., 2]
    dirichlet = dirichlet_lib.Dirichlet(alpha)
    n = constant_op.constant(100000)
    samples = dirichlet.sample(n)
    sample_values = self.evaluate(samples)
    self.assertEqual(sample_values.shape, (100000, 2))
    self.assertTrue(np.all(sample_values > 0.0))
    if not stats:
      return
    # Kolmogorov-Smirnov statistic (element [0] of kstest's result).
    self.assertLess(
        stats.kstest(
            # Beta is a univariate distribution.
            sample_values[:, 0],
            stats.beta(a=1., b=2.).cdf)[0],
        0.01)

  def testDirichletFullyReparameterized(self):
    """Gradients of samples with respect to the concentration exist."""
    alpha = constant_op.constant([1.0, 2.0, 3.0])
    with backprop.GradientTape() as tape:
      tape.watch(alpha)
      dirichlet = dirichlet_lib.Dirichlet(alpha)
      samples = dirichlet.sample(100)
    grad_alpha = tape.gradient(samples, alpha)
    self.assertIsNotNone(grad_alpha)

  def testDirichletDirichletKL(self):
    """kl_divergence matches the closed form and a Monte Carlo estimate."""
    conc1 = np.array([[1., 2., 3., 1.5, 2.5, 3.5],
                      [1.5, 2.5, 3.5, 4.5, 5.5, 6.5]])
    conc2 = np.array([[0.5, 1., 1.5, 2., 2.5, 3.]])

    d1 = dirichlet_lib.Dirichlet(conc1)
    d2 = dirichlet_lib.Dirichlet(conc2)
    # Monte Carlo estimate: E_{x ~ d1}[log d1(x) - log d2(x)].
    x = d1.sample(int(1e4), seed=0)
    kl_sample = math_ops.reduce_mean(d1.log_prob(x) - d2.log_prob(x), 0)
    kl_actual = kullback_leibler.kl_divergence(d1, d2)

    kl_sample_val = self.evaluate(kl_sample)
    kl_actual_val = self.evaluate(kl_actual)

    self.assertEqual(conc1.shape[:-1], kl_actual.get_shape())

    if not special:
      return

    # Closed form for KL(Dir(conc1) || Dir(conc2)), reduced over the event
    # axis; conc2's leading size-1 axis broadcasts against conc1's batch.
    kl_expected = (
        special.gammaln(np.sum(conc1, -1))
        - special.gammaln(np.sum(conc2, -1))
        - np.sum(special.gammaln(conc1) - special.gammaln(conc2), -1)
        + np.sum((conc1 - conc2) * (special.digamma(conc1) - special.digamma(
            np.sum(conc1, -1, keepdims=True))), -1))

    self.assertAllClose(kl_expected, kl_actual_val, atol=0., rtol=1e-6)
    self.assertAllClose(kl_sample_val, kl_actual_val, atol=0., rtol=1e-1)

    # Make sure KL(d1||d1) is 0
    kl_same = self.evaluate(kullback_leibler.kl_divergence(d1, d1))
    self.assertAllClose(kl_same, np.zeros_like(kl_expected))
306
307
if __name__ == "__main__":
  # Executed only when run as a script: hand off to TensorFlow's test runner.
  test.main()
310