• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import importlib
16
17import numpy as np
18
19from tensorflow.python.eager import backprop
20from tensorflow.python.framework import constant_op
21from tensorflow.python.framework import random_seed
22from tensorflow.python.framework import tensor_shape
23from tensorflow.python.framework import test_util
24from tensorflow.python.ops import math_ops
25from tensorflow.python.ops import nn_ops
26from tensorflow.python.ops.distributions import beta as beta_lib
27from tensorflow.python.ops.distributions import kullback_leibler
28from tensorflow.python.platform import test
29from tensorflow.python.platform import tf_logging
30
31
def try_import(name):  # pylint: disable=invalid-name
  """Imports a module by name, returning None if it cannot be imported.

  Used for optional dependencies so that checks needing them can be skipped
  instead of failing the whole module at import time.

  Args:
    name: Fully qualified module name, e.g. "scipy.stats".

  Returns:
    The imported module, or None if the import raised ImportError (a warning
    is logged in that case).
  """
  try:
    return importlib.import_module(name)
  except ImportError as e:
    # Pass arguments to the logger instead of pre-formatting the message, so
    # formatting is deferred to the logging framework.
    tf_logging.warning("Could not import %s: %s", name, e)
    return None
39
40
# Optional SciPy modules used to compute reference values. try_import yields
# None for a module that cannot be imported, and the SciPy-dependent checks in
# the tests below are skipped in that case.
special, stats = (
    try_import(module_name)
    for module_name in ("scipy.special", "scipy.stats"))
43
44
@test_util.run_all_in_graph_and_eager_modes
class BetaTest(test.TestCase):
  """Tests for `beta_lib.Beta` and `beta_lib.BetaWithSoftplusConcentration`.

  Expected values come either from closed-form Beta formulas written out in
  the individual tests or, when SciPy could be imported, from `scipy.stats`
  and `scipy.special`. Checks that need SciPy are skipped when the
  module-level `stats` / `special` handles are None.
  """

  def testSimpleShapes(self):
    a = np.random.rand(3)
    b = np.random.rand(3)
    dist = beta_lib.Beta(a, b)
    self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
    self.assertAllEqual([3], self.evaluate(dist.batch_shape_tensor()))
    self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
    self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape)

  def testComplexShapes(self):
    a = np.random.rand(3, 2, 2)
    b = np.random.rand(3, 2, 2)
    dist = beta_lib.Beta(a, b)
    self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
    self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
    self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
    self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)

  def testComplexShapesBroadcast(self):
    # `a` of shape (3, 2, 2) and `b` of shape (2, 2) broadcast together.
    a = np.random.rand(3, 2, 2)
    b = np.random.rand(2, 2)
    dist = beta_lib.Beta(a, b)
    self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
    self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
    self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
    self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)

  def testAlphaProperty(self):
    # The first constructor argument is exposed as `concentration1`.
    a = [[1., 2, 3]]
    b = [[2., 4, 3]]
    dist = beta_lib.Beta(a, b)
    self.assertEqual([1, 3], dist.concentration1.get_shape())
    self.assertAllClose(a, self.evaluate(dist.concentration1))

  def testBetaProperty(self):
    # The second constructor argument is exposed as `concentration0`.
    a = [[1., 2, 3]]
    b = [[2., 4, 3]]
    dist = beta_lib.Beta(a, b)
    self.assertEqual([1, 3], dist.concentration0.get_shape())
    self.assertAllClose(b, self.evaluate(dist.concentration0))

  def testPdfXProper(self):
    a = [[1., 2, 3]]
    b = [[2., 4, 3]]
    dist = beta_lib.Beta(a, b, validate_args=True)
    self.evaluate(dist.prob([.1, .3, .6]))
    self.evaluate(dist.prob([.2, .3, .5]))
    # With validate_args=True, samples <= 0 and samples >= 1 are rejected.
    with self.assertRaisesOpError("sample must be positive"):
      self.evaluate(dist.prob([-1., 0.1, 0.5]))
    with self.assertRaisesOpError("sample must be positive"):
      self.evaluate(dist.prob([0., 0.1, 0.5]))
    with self.assertRaisesOpError("sample must be less than `1`"):
      self.evaluate(dist.prob([.1, .2, 1.2]))
    with self.assertRaisesOpError("sample must be less than `1`"):
      self.evaluate(dist.prob([.1, .2, 1.0]))

  def testPdfTwoBatches(self):
    a = [1., 2]
    b = [1., 2]
    x = [.5, .5]
    dist = beta_lib.Beta(a, b)
    pdf = dist.prob(x)
    # Beta(1, 1) is uniform (density 1); Beta(2, 2) has density 6x(1 - x),
    # which is 3/2 at x = .5.
    self.assertAllClose([1., 3. / 2], self.evaluate(pdf), rtol=1e-5, atol=1e-5)
    self.assertEqual((2,), pdf.get_shape())

  def testPdfTwoBatchesNontrivialX(self):
    a = [1., 2]
    b = [1., 2]
    x = [.3, .7]
    dist = beta_lib.Beta(a, b)
    pdf = dist.prob(x)
    # Beta(2, 2) at x = .7: 6 * .7 * .3 = 63/50.
    self.assertAllClose([1, 63. / 50], self.evaluate(pdf), rtol=1e-5, atol=1e-5)
    self.assertEqual((2,), pdf.get_shape())

  def testPdfUniformZeroBatch(self):
    # This is equivalent to a uniform distribution
    a = 1.
    b = 1.
    x = np.array([.1, .2, .3, .5, .8], dtype=np.float32)
    dist = beta_lib.Beta(a, b)
    pdf = dist.prob(x)
    self.assertAllClose([1.] * 5, self.evaluate(pdf))
    self.assertEqual((5,), pdf.get_shape())

  def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
    a = [[1., 2]]
    b = [[1., 2]]
    x = [[.5, .5], [.3, .7]]
    dist = beta_lib.Beta(a, b)
    pdf = dist.prob(x)
    self.assertAllClose([[1., 3. / 2], [1., 63. / 50]],
                        self.evaluate(pdf),
                        rtol=1e-5,
                        atol=1e-5)
    self.assertEqual((2, 2), pdf.get_shape())

  def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
    a = [1., 2]
    b = [1., 2]
    x = [[.5, .5], [.2, .8]]
    pdf = beta_lib.Beta(a, b).prob(x)
    # Beta(2, 2) at x = .8: 6 * .8 * .2 = 24/25.
    self.assertAllClose([[1., 3. / 2], [1., 24. / 25]],
                        self.evaluate(pdf),
                        rtol=1e-5,
                        atol=1e-5)
    self.assertEqual((2, 2), pdf.get_shape())

  def testPdfXStretchedInBroadcastWhenSameRank(self):
    a = [[1., 2], [2., 3]]
    b = [[1., 2], [2., 3]]
    x = [[.5, .5]]
    pdf = beta_lib.Beta(a, b).prob(x)
    # Beta(3, 3) has density 30 x^2 (1 - x)^2, which is 15/8 at x = .5.
    self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]],
                        self.evaluate(pdf),
                        rtol=1e-5,
                        atol=1e-5)
    self.assertEqual((2, 2), pdf.get_shape())

  def testPdfXStretchedInBroadcastWhenLowerRank(self):
    a = [[1., 2], [2., 3]]
    b = [[1., 2], [2., 3]]
    x = [.5, .5]
    pdf = beta_lib.Beta(a, b).prob(x)
    self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]],
                        self.evaluate(pdf),
                        rtol=1e-5,
                        atol=1e-5)
    self.assertEqual((2, 2), pdf.get_shape())

  def testLogPdfOnBoundaryIsFiniteWhenAlphaIsOne(self):
    # With concentration1 == 1 the density at x = 0 equals concentration0,
    # so both prob and log_prob should be finite on that boundary.
    b = [[0.01, 0.1, 1., 2], [5., 10., 2., 3]]
    dist = beta_lib.Beta(1., b)
    pdf = self.evaluate(dist.prob(0.))
    self.assertAllEqual(np.ones_like(pdf, dtype=np.bool_), np.isfinite(pdf))
    log_pdf = self.evaluate(dist.log_prob(0.))
    self.assertAllEqual(
        np.ones_like(log_pdf, dtype=np.bool_), np.isfinite(log_pdf))

  def testBetaMean(self):
    a = [1., 2, 3]
    b = [2., 4, 1.2]
    dist = beta_lib.Beta(a, b)
    self.assertEqual(dist.mean().get_shape(), (3,))
    if not stats:
      return
    expected_mean = stats.beta.mean(a, b)
    self.assertAllClose(expected_mean, self.evaluate(dist.mean()))

  def testBetaVariance(self):
    a = [1., 2, 3]
    b = [2., 4, 1.2]
    dist = beta_lib.Beta(a, b)
    self.assertEqual(dist.variance().get_shape(), (3,))
    if not stats:
      return
    expected_variance = stats.beta.var(a, b)
    self.assertAllClose(expected_variance, self.evaluate(dist.variance()))

  def testBetaMode(self):
    # Closed-form mode; every concentration here is greater than 1.
    a = np.array([1.1, 2, 3])
    b = np.array([2., 4, 1.2])
    expected_mode = (a - 1) / (a + b - 2)
    dist = beta_lib.Beta(a, b)
    self.assertEqual(dist.mode().get_shape(), (3,))
    self.assertAllClose(expected_mode, self.evaluate(dist.mode()))

  def testBetaModeInvalid(self):
    # With allow_nan_stats=False, a concentration equal to 1 (in either
    # parameter) makes mode() fail an assertion.
    a = np.array([1., 2, 3])
    b = np.array([2., 4, 1.2])
    dist = beta_lib.Beta(a, b, allow_nan_stats=False)
    with self.assertRaisesOpError("Condition x < y.*"):
      self.evaluate(dist.mode())

    a = np.array([2., 2, 3])
    b = np.array([1., 4, 1.2])
    dist = beta_lib.Beta(a, b, allow_nan_stats=False)
    with self.assertRaisesOpError("Condition x < y.*"):
      self.evaluate(dist.mode())

  def testBetaModeEnableAllowNanStats(self):
    # With allow_nan_stats=True, the mode is NaN where a concentration is 1.
    a = np.array([1., 2, 3])
    b = np.array([2., 4, 1.2])
    dist = beta_lib.Beta(a, b, allow_nan_stats=True)

    expected_mode = (a - 1) / (a + b - 2)
    expected_mode[0] = np.nan
    self.assertEqual((3,), dist.mode().get_shape())
    self.assertAllClose(expected_mode, self.evaluate(dist.mode()))

    a = np.array([2., 2, 3])
    b = np.array([1., 4, 1.2])
    dist = beta_lib.Beta(a, b, allow_nan_stats=True)

    expected_mode = (a - 1) / (a + b - 2)
    expected_mode[0] = np.nan
    self.assertEqual((3,), dist.mode().get_shape())
    self.assertAllClose(expected_mode, self.evaluate(dist.mode()))

  def testBetaEntropy(self):
    a = [1., 2, 3]
    b = [2., 4, 1.2]
    dist = beta_lib.Beta(a, b)
    self.assertEqual(dist.entropy().get_shape(), (3,))
    if not stats:
      return
    expected_entropy = stats.beta.entropy(a, b)
    self.assertAllClose(expected_entropy, self.evaluate(dist.entropy()))

  def testBetaSample(self):
    a = 1.
    b = 2.
    beta = beta_lib.Beta(a, b)
    n = constant_op.constant(100000)
    samples = beta.sample(n)
    sample_values = self.evaluate(samples)
    self.assertEqual(sample_values.shape, (100000,))
    self.assertFalse(np.any(sample_values < 0.0))
    if not stats:
      return
    # Kolmogorov-Smirnov statistic against the reference Beta(a, b) CDF.
    self.assertLess(
        stats.kstest(
            # Beta is a univariate distribution.
            sample_values,
            stats.beta(a=a, b=b).cdf)[0],
        0.01)
    # Beta(1, 2) has variance 1/18, so the standard error of the sample mean
    # is 1 / (sqrt(18 * n)).
    self.assertAllClose(
        sample_values.mean(axis=0), stats.beta.mean(a, b), atol=1e-2)
    self.assertAllClose(
        np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1)

  def testBetaFullyReparameterized(self):
    a = constant_op.constant(1.0)
    b = constant_op.constant(2.0)
    with backprop.GradientTape() as tape:
      tape.watch(a)
      tape.watch(b)
      beta = beta_lib.Beta(a, b)
      samples = beta.sample(100)
    grad_a, grad_b = tape.gradient(samples, [a, b])
    self.assertIsNotNone(grad_a)
    self.assertIsNotNone(grad_b)

  # Test that sampling with the same seed twice gives the same results.
  def testBetaSampleMultipleTimes(self):
    a_val = 1.
    b_val = 2.
    n_val = 100

    random_seed.set_random_seed(654321)
    beta1 = beta_lib.Beta(
        concentration1=a_val, concentration0=b_val, name="beta1")
    samples1 = self.evaluate(beta1.sample(n_val, seed=123456))

    random_seed.set_random_seed(654321)
    beta2 = beta_lib.Beta(
        concentration1=a_val, concentration0=b_val, name="beta2")
    samples2 = self.evaluate(beta2.sample(n_val, seed=123456))

    self.assertAllClose(samples1, samples2)

  def testBetaSampleMultidimensional(self):
    a = np.random.rand(3, 2, 2).astype(np.float32)
    b = np.random.rand(3, 2, 2).astype(np.float32)
    beta = beta_lib.Beta(a, b)
    n = constant_op.constant(100000)
    samples = beta.sample(n)
    sample_values = self.evaluate(samples)
    self.assertEqual(sample_values.shape, (100000, 3, 2, 2))
    self.assertFalse(np.any(sample_values < 0.0))
    if not stats:
      return
    self.assertAllClose(
        sample_values[:, 1, :].mean(axis=0),
        stats.beta.mean(a, b)[1, :],
        atol=1e-1)

  def testBetaCdf(self):
    shape = (30, 40, 50)
    for dt in (np.float32, np.float64):
      a = 10. * np.random.random(shape).astype(dt)
      b = 10. * np.random.random(shape).astype(dt)
      x = np.random.random(shape).astype(dt)
      actual = self.evaluate(beta_lib.Beta(a, b).cdf(x))
      # CDF values must lie in [0, 1].
      self.assertAllEqual(np.ones(shape, dtype=np.bool_), 0. <= actual)
      self.assertAllEqual(np.ones(shape, dtype=np.bool_), 1. >= actual)
      if not stats:
        # Keep running the SciPy-free checks for the remaining dtypes.
        continue
      self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=9e-3, atol=5e-6)

  def testBetaLogCdf(self):
    shape = (30, 40, 50)
    for dt in (np.float32, np.float64):
      a = 10. * np.random.random(shape).astype(dt)
      b = 10. * np.random.random(shape).astype(dt)
      x = np.random.random(shape).astype(dt)
      # Exponentiate log_cdf so it can be compared with the CDF directly.
      actual = self.evaluate(math_ops.exp(beta_lib.Beta(a, b).log_cdf(x)))
      # The recovered CDF values must lie in [0, 1].
      self.assertAllEqual(np.ones(shape, dtype=np.bool_), 0. <= actual)
      self.assertAllEqual(np.ones(shape, dtype=np.bool_), 1. >= actual)
      if not stats:
        # Keep running the SciPy-free checks for the remaining dtypes.
        continue
      self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=3e-3, atol=2e-5)

  def testBetaWithSoftplusConcentration(self):
    # Raw (negative) inputs are mapped through softplus to get concentrations.
    a, b = -4.2, -9.1
    dist = beta_lib.BetaWithSoftplusConcentration(a, b)
    self.assertAllClose(
        self.evaluate(nn_ops.softplus(a)), self.evaluate(dist.concentration1))
    self.assertAllClose(
        self.evaluate(nn_ops.softplus(b)), self.evaluate(dist.concentration0))

  def testBetaBetaKL(self):
    for shape in [(10,), (4, 5)]:
      a1 = 6.0 * np.random.random(size=shape) + 1e-4
      b1 = 6.0 * np.random.random(size=shape) + 1e-4
      a2 = 6.0 * np.random.random(size=shape) + 1e-4
      b2 = 6.0 * np.random.random(size=shape) + 1e-4
      # Take the inverse softplus, log(exp(v) - 1), of the values to test
      # BetaWithSoftplusConcentration; expm1 avoids cancellation for small v.
      a1_sp = np.log(np.expm1(a1))
      b1_sp = np.log(np.expm1(b1))
      a2_sp = np.log(np.expm1(a2))
      b2_sp = np.log(np.expm1(b2))

      d1 = beta_lib.Beta(concentration1=a1, concentration0=b1)
      d2 = beta_lib.Beta(concentration1=a2, concentration0=b2)
      d1_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a1_sp,
                                                     concentration0=b1_sp)
      d2_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a2_sp,
                                                     concentration0=b2_sp)

      # Make sure KL(d1||d1) is 0; this needs no SciPy reference.
      kl_same = self.evaluate(kullback_leibler.kl_divergence(d1, d1))
      self.assertAllClose(kl_same, np.zeros(shape))

      if not special:
        continue
      # Closed-form KL(Beta(a1, b1) || Beta(a2, b2)).
      kl_expected = (special.betaln(a2, b2) - special.betaln(a1, b1) +
                     (a1 - a2) * special.digamma(a1) +
                     (b1 - b2) * special.digamma(b1) +
                     (a2 - a1 + b2 - b1) * special.digamma(a1 + b1))

      for dist1 in [d1, d1_sp]:
        for dist2 in [d2, d2_sp]:
          kl = kullback_leibler.kl_divergence(dist1, dist2)
          kl_val = self.evaluate(kl)
          self.assertEqual(kl.get_shape(), shape)
          self.assertAllClose(kl_val, kl_expected)
392
393
if __name__ == "__main__":
  # Run the tests in this module through the TensorFlow test runner
  # (`tensorflow.python.platform.test`) when executed as a script.
  test.main()
396