• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15from __future__ import absolute_import
16from __future__ import division
17from __future__ import print_function
18
19import importlib
20
21import numpy as np
22
23from tensorflow.python.eager import backprop
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import random_seed
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.framework import test_util
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops import nn_ops
30from tensorflow.python.ops.distributions import beta as beta_lib
31from tensorflow.python.ops.distributions import kullback_leibler
32from tensorflow.python.platform import test
33from tensorflow.python.platform import tf_logging
34
35
def try_import(name):  # pylint: disable=invalid-name
  """Imports the module called `name`, tolerating its absence.

  Args:
    name: Dotted module path handed to `importlib.import_module`.

  Returns:
    The imported module, or `None` if the import raised `ImportError`; in that
    case a warning naming the module and the error is logged.
  """
  try:
    return importlib.import_module(name)
  except ImportError as e:
    # Hand the arguments to the logger rather than pre-formatting the message.
    tf_logging.warning("Could not import %s: %s", name, e)
    return None
43
44
# SciPy is optional. Each name holds the imported module, or None when the
# import failed; tests check these names before computing SciPy references.
special = try_import("scipy.special")
stats = try_import("scipy.stats")
47
48
@test_util.run_all_in_graph_and_eager_modes
class BetaTest(test.TestCase):
  """Tests for `beta_lib.Beta` and `beta_lib.BetaWithSoftplusConcentration`.

  Reference values come from SciPy where it is importable; tests that need it
  return early when the module-level `stats` / `special` is falsy.
  """

  def testSimpleShapes(self):
    """Rank-1 parameters give a scalar event shape and batch shape [3]."""
    a = np.random.rand(3)
    b = np.random.rand(3)
    dist = beta_lib.Beta(a, b)
    self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
    self.assertAllEqual([3], self.evaluate(dist.batch_shape_tensor()))
    self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
    self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape)

  def testComplexShapes(self):
    """Rank-3 parameters give a scalar event shape and batch shape [3, 2, 2]."""
    a = np.random.rand(3, 2, 2)
    b = np.random.rand(3, 2, 2)
    dist = beta_lib.Beta(a, b)
    self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
    self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
    self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
    self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)

  def testComplexShapesBroadcast(self):
    """Parameters of shapes (3, 2, 2) and (2, 2) broadcast to [3, 2, 2]."""
    a = np.random.rand(3, 2, 2)
    b = np.random.rand(2, 2)
    dist = beta_lib.Beta(a, b)
    self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
    self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
    self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
    self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)

  def testAlphaProperty(self):
    """`concentration1` exposes the first constructor argument."""
    a = [[1., 2, 3]]
    b = [[2., 4, 3]]
    dist = beta_lib.Beta(a, b)
    self.assertEqual([1, 3], dist.concentration1.get_shape())
    self.assertAllClose(a, self.evaluate(dist.concentration1))

  def testBetaProperty(self):
    """`concentration0` exposes the second constructor argument."""
    a = [[1., 2, 3]]
    b = [[2., 4, 3]]
    dist = beta_lib.Beta(a, b)
    self.assertEqual([1, 3], dist.concentration0.get_shape())
    self.assertAllClose(b, self.evaluate(dist.concentration0))

  def testPdfXProper(self):
    """With `validate_args=True`, `prob` rejects samples outside (0, 1)."""
    a = [[1., 2, 3]]
    b = [[2., 4, 3]]
    dist = beta_lib.Beta(a, b, validate_args=True)
    # Samples strictly inside (0, 1) evaluate without error.
    self.evaluate(dist.prob([.1, .3, .6]))
    self.evaluate(dist.prob([.2, .3, .5]))
    # Either condition can trigger.
    with self.assertRaisesOpError("sample must be positive"):
      self.evaluate(dist.prob([-1., 0.1, 0.5]))
    with self.assertRaisesOpError("sample must be positive"):
      self.evaluate(dist.prob([0., 0.1, 0.5]))
    with self.assertRaisesOpError("sample must be less than `1`"):
      self.evaluate(dist.prob([.1, .2, 1.2]))
    with self.assertRaisesOpError("sample must be less than `1`"):
      self.evaluate(dist.prob([.1, .2, 1.0]))

  def testPdfTwoBatches(self):
    """`prob` at x = 0.5 for Beta(1, 1) and Beta(2, 2) is 1 and 3/2."""
    a = [1., 2]
    b = [1., 2]
    x = [.5, .5]
    dist = beta_lib.Beta(a, b)
    pdf = dist.prob(x)
    self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
    self.assertEqual((2,), pdf.get_shape())

  def testPdfTwoBatchesNontrivialX(self):
    """`prob` for Beta(1, 1) at 0.3 and Beta(2, 2) at 0.7 is 1 and 63/50."""
    a = [1., 2]
    b = [1., 2]
    x = [.3, .7]
    dist = beta_lib.Beta(a, b)
    pdf = dist.prob(x)
    self.assertAllClose([1, 63. / 50], self.evaluate(pdf))
    self.assertEqual((2,), pdf.get_shape())

  def testPdfUniformZeroBatch(self):
    """Scalar Beta(1, 1) has density 1 at each of five sample points."""
    # This is equivalent to a uniform distribution
    a = 1.
    b = 1.
    x = np.array([.1, .2, .3, .5, .8], dtype=np.float32)
    dist = beta_lib.Beta(a, b)
    pdf = dist.prob(x)
    self.assertAllClose([1.] * 5, self.evaluate(pdf))
    self.assertEqual((5,), pdf.get_shape())

  def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
    """Parameters of shape (1, 2) broadcast against x of shape (2, 2)."""
    a = [[1., 2]]
    b = [[1., 2]]
    x = [[.5, .5], [.3, .7]]
    dist = beta_lib.Beta(a, b)
    pdf = dist.prob(x)
    self.assertAllClose([[1., 3. / 2], [1., 63. / 50]], self.evaluate(pdf))
    self.assertEqual((2, 2), pdf.get_shape())

  def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
    """Parameters of shape (2,) broadcast against x of shape (2, 2)."""
    a = [1., 2]
    b = [1., 2]
    x = [[.5, .5], [.2, .8]]
    pdf = beta_lib.Beta(a, b).prob(x)
    self.assertAllClose([[1., 3. / 2], [1., 24. / 25]], self.evaluate(pdf))
    self.assertEqual((2, 2), pdf.get_shape())

  def testPdfXStretchedInBroadcastWhenSameRank(self):
    """x of shape (1, 2) broadcasts against parameters of shape (2, 2)."""
    a = [[1., 2], [2., 3]]
    b = [[1., 2], [2., 3]]
    x = [[.5, .5]]
    pdf = beta_lib.Beta(a, b).prob(x)
    self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf))
    self.assertEqual((2, 2), pdf.get_shape())

  def testPdfXStretchedInBroadcastWhenLowerRank(self):
    """x of shape (2,) broadcasts against parameters of shape (2, 2)."""
    a = [[1., 2], [2., 3]]
    b = [[1., 2], [2., 3]]
    x = [.5, .5]
    pdf = beta_lib.Beta(a, b).prob(x)
    self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf))
    self.assertEqual((2, 2), pdf.get_shape())

  def testLogPdfOnBoundaryIsFiniteWhenAlphaIsOne(self):
    """`prob(0.)` is finite for every `concentration0` when alpha is 1."""
    b = [[0.01, 0.1, 1., 2], [5., 10., 2., 3]]
    pdf = self.evaluate(beta_lib.Beta(1., b).prob(0.))
    # np.bool_ instead of the np.bool alias, which NumPy 1.24 removed.
    self.assertAllEqual(np.ones_like(pdf, dtype=np.bool_), np.isfinite(pdf))

  def testBetaMean(self):
    """`mean` has shape (3,) and matches `scipy.stats.beta.mean`."""
    a = [1., 2, 3]
    b = [2., 4, 1.2]
    dist = beta_lib.Beta(a, b)
    self.assertEqual(dist.mean().get_shape(), (3,))
    if not stats:
      return
    expected_mean = stats.beta.mean(a, b)
    self.assertAllClose(expected_mean, self.evaluate(dist.mean()))

  def testBetaVariance(self):
    """`variance` has shape (3,) and matches `scipy.stats.beta.var`."""
    a = [1., 2, 3]
    b = [2., 4, 1.2]
    dist = beta_lib.Beta(a, b)
    self.assertEqual(dist.variance().get_shape(), (3,))
    if not stats:
      return
    expected_variance = stats.beta.var(a, b)
    self.assertAllClose(expected_variance, self.evaluate(dist.variance()))

  def testBetaMode(self):
    """`mode` equals (a - 1) / (a + b - 2) for the given concentrations."""
    a = np.array([1.1, 2, 3])
    b = np.array([2., 4, 1.2])
    expected_mode = (a - 1) / (a + b - 2)
    dist = beta_lib.Beta(a, b)
    self.assertEqual(dist.mode().get_shape(), (3,))
    self.assertAllClose(expected_mode, self.evaluate(dist.mode()))

  def testBetaModeInvalid(self):
    """With `allow_nan_stats=False`, `mode` raises if a concentration is 1."""
    # First entry of concentration1 equals 1.
    a = np.array([1., 2, 3])
    b = np.array([2., 4, 1.2])
    dist = beta_lib.Beta(a, b, allow_nan_stats=False)
    with self.assertRaisesOpError("Condition x < y.*"):
      self.evaluate(dist.mode())

    # First entry of concentration0 equals 1.
    a = np.array([2., 2, 3])
    b = np.array([1., 4, 1.2])
    dist = beta_lib.Beta(a, b, allow_nan_stats=False)
    with self.assertRaisesOpError("Condition x < y.*"):
      self.evaluate(dist.mode())

  def testBetaModeEnableAllowNanStats(self):
    """With `allow_nan_stats=True`, the same entries yield NaN instead."""
    # First entry of concentration1 equals 1, so that mode is expected NaN.
    a = np.array([1., 2, 3])
    b = np.array([2., 4, 1.2])
    dist = beta_lib.Beta(a, b, allow_nan_stats=True)

    expected_mode = (a - 1) / (a + b - 2)
    expected_mode[0] = np.nan
    self.assertEqual((3,), dist.mode().get_shape())
    self.assertAllClose(expected_mode, self.evaluate(dist.mode()))

    # First entry of concentration0 equals 1, so that mode is expected NaN.
    a = np.array([2., 2, 3])
    b = np.array([1., 4, 1.2])
    dist = beta_lib.Beta(a, b, allow_nan_stats=True)

    expected_mode = (a - 1) / (a + b - 2)
    expected_mode[0] = np.nan
    self.assertEqual((3,), dist.mode().get_shape())
    self.assertAllClose(expected_mode, self.evaluate(dist.mode()))

  def testBetaEntropy(self):
    """`entropy` has shape (3,) and matches `scipy.stats.beta.entropy`."""
    a = [1., 2, 3]
    b = [2., 4, 1.2]
    dist = beta_lib.Beta(a, b)
    self.assertEqual(dist.entropy().get_shape(), (3,))
    if not stats:
      return
    expected_entropy = stats.beta.entropy(a, b)
    self.assertAllClose(expected_entropy, self.evaluate(dist.entropy()))

  def testBetaSample(self):
    """Scalar Beta(1, 2) samples have the right shape and distribution."""
    a = 1.
    b = 2.
    beta = beta_lib.Beta(a, b)
    n = constant_op.constant(100000)
    samples = beta.sample(n)
    sample_values = self.evaluate(samples)
    self.assertEqual(sample_values.shape, (100000,))
    self.assertFalse(np.any(sample_values < 0.0))
    if not stats:
      return
    # The Kolmogorov-Smirnov statistic against SciPy's Beta(1, 2) CDF must be
    # below 0.01.
    self.assertLess(
        stats.kstest(
            # Beta is a univariate distribution.
            sample_values,
            stats.beta(a=1., b=2.).cdf)[0],
        0.01)
    # The standard error of the sample mean is 1 / (sqrt(18 * n))
    self.assertAllClose(
        sample_values.mean(axis=0), stats.beta.mean(a, b), atol=1e-2)
    self.assertAllClose(
        np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1)

  def testBetaFullyReparameterized(self):
    """Gradients of samples w.r.t. both concentrations are not `None`."""
    a = constant_op.constant(1.0)
    b = constant_op.constant(2.0)
    with backprop.GradientTape() as tape:
      tape.watch(a)
      tape.watch(b)
      beta = beta_lib.Beta(a, b)
      samples = beta.sample(100)
    grad_a, grad_b = tape.gradient(samples, [a, b])
    self.assertIsNotNone(grad_a)
    self.assertIsNotNone(grad_b)

  def testBetaSampleMultipleTimes(self):
    """Sampling with the same seed twice gives the same results."""
    a_val = 1.
    b_val = 2.
    n_val = 100

    random_seed.set_random_seed(654321)
    beta1 = beta_lib.Beta(
        concentration1=a_val, concentration0=b_val, name="beta1")
    samples1 = self.evaluate(beta1.sample(n_val, seed=123456))

    # Reset to the same seed before building the second distribution.
    random_seed.set_random_seed(654321)
    beta2 = beta_lib.Beta(
        concentration1=a_val, concentration0=b_val, name="beta2")
    samples2 = self.evaluate(beta2.sample(n_val, seed=123456))

    self.assertAllClose(samples1, samples2)

  def testBetaSampleMultidimensional(self):
    """Samples from a (3, 2, 2) batch have shape (n, 3, 2, 2)."""
    a = np.random.rand(3, 2, 2).astype(np.float32)
    b = np.random.rand(3, 2, 2).astype(np.float32)
    beta = beta_lib.Beta(a, b)
    n = constant_op.constant(100000)
    samples = beta.sample(n)
    sample_values = self.evaluate(samples)
    self.assertEqual(sample_values.shape, (100000, 3, 2, 2))
    self.assertFalse(np.any(sample_values < 0.0))
    if not stats:
      return
    # Only the slice at index 1 of the first batch axis is compared with SciPy.
    self.assertAllClose(
        sample_values[:, 1, :].mean(axis=0),
        stats.beta.mean(a, b)[1, :],
        atol=1e-1)

  def testBetaCdf(self):
    """`cdf` matches `scipy.stats.beta.cdf` for float32 and float64."""
    shape = (30, 40, 50)
    for dt in (np.float32, np.float64):
      a = 10. * np.random.random(shape).astype(dt)
      b = 10. * np.random.random(shape).astype(dt)
      x = np.random.random(shape).astype(dt)
      actual = self.evaluate(beta_lib.Beta(a, b).cdf(x))
      # Check that every input x lies in [0, 1]. np.bool_ is used because the
      # np.bool alias was removed in NumPy 1.24.
      self.assertAllEqual(np.ones(shape, dtype=np.bool_), 0. <= x)
      self.assertAllEqual(np.ones(shape, dtype=np.bool_), 1. >= x)
      if not stats:
        return
      self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=9e-3, atol=5e-6)

  def testBetaLogCdf(self):
    """`exp(log_cdf)` matches `scipy.stats.beta.cdf` for both float types."""
    shape = (30, 40, 50)
    for dt in (np.float32, np.float64):
      a = 10. * np.random.random(shape).astype(dt)
      b = 10. * np.random.random(shape).astype(dt)
      x = np.random.random(shape).astype(dt)
      actual = self.evaluate(math_ops.exp(beta_lib.Beta(a, b).log_cdf(x)))
      # Check that every input x lies in [0, 1]. np.bool_ is used because the
      # np.bool alias was removed in NumPy 1.24.
      self.assertAllEqual(np.ones(shape, dtype=np.bool_), 0. <= x)
      self.assertAllEqual(np.ones(shape, dtype=np.bool_), 1. >= x)
      if not stats:
        return
      self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=3e-3, atol=2e-5)

  def testBetaWithSoftplusConcentration(self):
    """Both concentrations equal the softplus of the raw constructor inputs."""
    a, b = -4.2, -9.1
    dist = beta_lib.BetaWithSoftplusConcentration(a, b)
    self.assertAllClose(
        self.evaluate(nn_ops.softplus(a)), self.evaluate(dist.concentration1))
    self.assertAllClose(
        self.evaluate(nn_ops.softplus(b)), self.evaluate(dist.concentration0))

  def testBetaBetaKL(self):
    """KL divergence between Betas matches the `scipy.special` expression."""
    for shape in [(10,), (4, 5)]:
      a1 = 6.0 * np.random.random(size=shape) + 1e-4
      b1 = 6.0 * np.random.random(size=shape) + 1e-4
      a2 = 6.0 * np.random.random(size=shape) + 1e-4
      b2 = 6.0 * np.random.random(size=shape) + 1e-4
      # Take inverse softplus of values to test BetaWithSoftplusConcentration
      a1_sp = np.log(np.exp(a1) - 1.0)
      b1_sp = np.log(np.exp(b1) - 1.0)
      a2_sp = np.log(np.exp(a2) - 1.0)
      b2_sp = np.log(np.exp(b2) - 1.0)

      d1 = beta_lib.Beta(concentration1=a1, concentration0=b1)
      d2 = beta_lib.Beta(concentration1=a2, concentration0=b2)
      d1_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a1_sp,
                                                     concentration0=b1_sp)
      d2_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a2_sp,
                                                     concentration0=b2_sp)

      if not special:
        return
      kl_expected = (special.betaln(a2, b2) - special.betaln(a1, b1) +
                     (a1 - a2) * special.digamma(a1) +
                     (b1 - b2) * special.digamma(b1) +
                     (a2 - a1 + b2 - b1) * special.digamma(a1 + b1))

      # Every pairing of plain and softplus-parameterized distributions is
      # checked against the same expected value.
      for dist1 in [d1, d1_sp]:
        for dist2 in [d2, d2_sp]:
          kl = kullback_leibler.kl_divergence(dist1, dist2)
          kl_val = self.evaluate(kl)
          self.assertEqual(kl.get_shape(), shape)
          self.assertAllClose(kl_val, kl_expected)

      # Make sure KL(d1||d1) is 0
      kl_same = self.evaluate(kullback_leibler.kl_divergence(d1, d1))
      self.assertAllClose(kl_same, np.zeros_like(kl_expected))
384
385
if __name__ == "__main__":
  # Run this module's tests when the file is executed as a script.
  test.main()
388