• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for the Bernoulli distribution."""
16
17import importlib
18
19import numpy as np
20
21from tensorflow.python.eager import backprop
22from tensorflow.python.framework import constant_op
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import test_util
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops.distributions import bernoulli
27from tensorflow.python.ops.distributions import kullback_leibler
28from tensorflow.python.platform import test
29from tensorflow.python.platform import tf_logging
30
31
def try_import(name):  # pylint: disable=invalid-name
  """Imports the module called `name`, returning None when that fails.

  Args:
    name: Dotted module name handed to `importlib.import_module`.

  Returns:
    The imported module, or None if an `ImportError` was raised. In the
    failure case the error is reported as a warning through `tf_logging`
    instead of being propagated.
  """
  try:
    return importlib.import_module(name)
  except ImportError as err:
    tf_logging.warning("Could not import %s: %s" % (name, str(err)))
    return None
39
40
# `scipy.special` when it can be imported, otherwise None. Tests that rely on
# it check `special` for truthiness and return early when it is unavailable.
special = try_import("scipy.special")
42
43
def make_bernoulli(batch_shape, dtype=dtypes.int32):
  """Builds a Bernoulli distribution with uniformly random probabilities.

  Args:
    batch_shape: Iterable of ints; converted to a list and used as the `size`
      of the `np.random.uniform` draw that supplies `probs`.
    dtype: Passed through as the `dtype` argument of `bernoulli.Bernoulli`.

  Returns:
    A `bernoulli.Bernoulli` whose `probs` is a float32 constant holding the
    random draw.
  """
  random_probs = np.random.uniform(size=list(batch_shape))
  return bernoulli.Bernoulli(
      probs=constant_op.constant(random_probs, dtype=dtypes.float32),
      dtype=dtype)
48
49
def entropy(p):
  """Returns `-(1 - p) * log(1 - p) - p * log(p)` evaluated with NumPy.

  Args:
    p: Scalar or NumPy array; used as-is in the arithmetic and `np.log` calls.

  Returns:
    The value of the expression above, elementwise for array input.
  """
  complement = 1. - p
  complement_term = -complement * np.log(complement)
  direct_term = p * np.log(p)
  return complement_term - direct_term
53
54
class BernoulliTest(test.TestCase):
  """Tests for `bernoulli.Bernoulli`: parameters, shapes, pmf, sampling, KL."""

  @test_util.run_in_graph_and_eager_modes
  def testP(self):
    """`probs` passed to the constructor is returned by `dist.probs`."""
    p = [0.2, 0.4]
    dist = bernoulli.Bernoulli(probs=p)
    self.assertAllClose(p, self.evaluate(dist.probs))

  @test_util.run_in_graph_and_eager_modes
  def testLogits(self):
    """`logits` round-trips, and agrees with scipy's expit/logit if present."""
    logits = [-42., 42.]
    dist = bernoulli.Bernoulli(logits=logits)
    self.assertAllClose(logits, self.evaluate(dist.logits))

    # The remaining checks need scipy.special as the reference implementation.
    if not special:
      return

    self.assertAllClose(special.expit(logits), self.evaluate(dist.probs))

    p = [0.01, 0.99, 0.42]
    dist = bernoulli.Bernoulli(probs=p)
    self.assertAllClose(special.logit(p), self.evaluate(dist.logits))

  @test_util.run_in_graph_and_eager_modes
  def testInvalidP(self):
    """With `validate_args=True`, probs outside [0, 1] raise an op error."""
    invalid_ps = [1.01, 2.]
    for p in invalid_ps:
      with self.assertRaisesOpError("probs has components greater than 1"):
        dist = bernoulli.Bernoulli(probs=p, validate_args=True)
        self.evaluate(dist.probs)

    invalid_ps = [-0.01, -3.]
    for p in invalid_ps:
      with self.assertRaisesOpError("Condition x >= 0"):
        dist = bernoulli.Bernoulli(probs=p, validate_args=True)
        self.evaluate(dist.probs)

    valid_ps = [0.0, 0.5, 1.0]
    for p in valid_ps:
      dist = bernoulli.Bernoulli(probs=p)
      self.assertEqual(p, self.evaluate(dist.probs))  # Should not fail

  @test_util.run_in_graph_and_eager_modes
  def testShapes(self):
    """Batch shape follows `probs`; event shape is always scalar."""
    for batch_shape in ([], [1], [2, 3, 4]):
      dist = make_bernoulli(batch_shape)
      self.assertAllEqual(batch_shape, dist.batch_shape.as_list())
      self.assertAllEqual(batch_shape, self.evaluate(dist.batch_shape_tensor()))
      self.assertAllEqual([], dist.event_shape.as_list())
      self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))

  @test_util.run_in_graph_and_eager_modes
  def testDtype(self):
    """Samples/mode use `dist.dtype`; statistics and probs use `probs.dtype`."""
    dist = make_bernoulli([])
    self.assertEqual(dist.dtype, dtypes.int32)
    self.assertEqual(dist.dtype, dist.sample(5).dtype)
    self.assertEqual(dist.dtype, dist.mode().dtype)
    self.assertEqual(dist.probs.dtype, dist.mean().dtype)
    self.assertEqual(dist.probs.dtype, dist.variance().dtype)
    self.assertEqual(dist.probs.dtype, dist.stddev().dtype)
    self.assertEqual(dist.probs.dtype, dist.entropy().dtype)
    self.assertEqual(dist.probs.dtype, dist.prob(0).dtype)
    self.assertEqual(dist.probs.dtype, dist.prob(0.5).dtype)
    self.assertEqual(dist.probs.dtype, dist.log_prob(0).dtype)
    self.assertEqual(dist.probs.dtype, dist.log_prob(0.5).dtype)

    dist64 = make_bernoulli([], dtypes.int64)
    self.assertEqual(dist64.dtype, dtypes.int64)
    self.assertEqual(dist64.dtype, dist64.sample(5).dtype)
    self.assertEqual(dist64.dtype, dist64.mode().dtype)

  @test_util.run_in_graph_and_eager_modes
  def _testPmf(self, **kwargs):
    """Checks `prob`/`log_prob` of a Bernoulli built from `kwargs`.

    The expected values below correspond to the parameters used by
    `testPmfWithP` (probs `[[0.2, 0.4], [0.3, 0.6]]` or the matching logits).

    Args:
      **kwargs: Forwarded unchanged to `bernoulli.Bernoulli`.
    """
    dist = bernoulli.Bernoulli(**kwargs)
    # pylint: disable=bad-continuation
    xs = [
        0,
        [1],
        [1, 0],
        [[1, 0]],
        [[1, 0], [1, 1]],
    ]
    expected_pmfs = [
        [[0.8, 0.6], [0.7, 0.4]],
        [[0.2, 0.4], [0.3, 0.6]],
        [[0.2, 0.6], [0.3, 0.4]],
        [[0.2, 0.6], [0.3, 0.4]],
        [[0.2, 0.6], [0.3, 0.6]],
    ]
    # pylint: enable=bad-continuation

    for x, expected_pmf in zip(xs, expected_pmfs):
      self.assertAllClose(self.evaluate(dist.prob(x)), expected_pmf)
      self.assertAllClose(self.evaluate(dist.log_prob(x)), np.log(expected_pmf))

  @test_util.run_deprecated_v1
  def testPmfCorrectBroadcastDynamicShape(self):
    """`prob` broadcasts correctly when `probs` is a shapeless placeholder."""
    with self.cached_session():
      p = array_ops.placeholder(dtype=dtypes.float32)
      dist = bernoulli.Bernoulli(probs=p)
      event1 = [1, 0, 1]
      event2 = [[1, 0, 1]]
      self.assertAllClose(
          dist.prob(event1).eval({
              p: [0.2, 0.3, 0.4]
          }), [0.2, 0.7, 0.4])
      self.assertAllClose(
          dist.prob(event2).eval({
              p: [0.2, 0.3, 0.4]
          }), [[0.2, 0.7, 0.4]])

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def testPmfInvalid(self):
    """With `validate_args=True`, events outside {0, 1} raise an op error."""
    p = [0.1, 0.2, 0.7]
    dist = bernoulli.Bernoulli(probs=p, validate_args=True)
    with self.assertRaisesOpError("must be non-negative."):
      self.evaluate(dist.prob([1, 1, -1]))
    with self.assertRaisesOpError("Elements cannot exceed 1."):
      self.evaluate(dist.prob([2, 0, 1]))

  @test_util.run_in_graph_and_eager_modes
  def testPmfWithP(self):
    """Runs `_testPmf` with probs and, if scipy is present, with logits."""
    p = [[0.2, 0.4], [0.3, 0.6]]
    self._testPmf(probs=p)
    if not special:
      return
    self._testPmf(logits=special.logit(p))

  @test_util.run_in_graph_and_eager_modes
  def testPmfWithFloatArgReturnsXEntropy(self):
    """`log_prob` of float samples equals x*log(p) + (1-x)*log(1-p)."""
    p = [[0.2], [0.4], [0.3], [0.6]]
    samps = [0, 0.1, 0.8]
    self.assertAllClose(
        np.float32(samps) * np.log(np.float32(p)) +
        (1 - np.float32(samps)) * np.log(1 - np.float32(p)),
        self.evaluate(
            bernoulli.Bernoulli(probs=p, validate_args=False).log_prob(samps)))

  @test_util.run_deprecated_v1
  def testBroadcasting(self):
    """`log_prob` broadcasts between scalar/vector events and fed probs."""
    with self.cached_session():
      p = array_ops.placeholder(dtypes.float32)
      dist = bernoulli.Bernoulli(probs=p)
      self.assertAllClose(np.log(0.5), dist.log_prob(1).eval({p: 0.5}))
      self.assertAllClose(
          np.log([0.5, 0.5, 0.5]), dist.log_prob([1, 1, 1]).eval({
              p: 0.5
          }))
      self.assertAllClose(
          np.log([0.5, 0.5, 0.5]), dist.log_prob(1).eval({
              p: [0.5, 0.5, 0.5]
          }))

  @test_util.run_deprecated_v1
  def testPmfShapes(self):
    """Checks dynamic and static shapes of `log_prob` results."""
    with self.cached_session():
      p = array_ops.placeholder(dtypes.float32, shape=[None, 1])
      dist = bernoulli.Bernoulli(probs=p)
      self.assertEqual(2, len(dist.log_prob(1).eval({p: [[0.5], [0.5]]}).shape))

      dist = bernoulli.Bernoulli(probs=0.5)
      self.assertEqual(2, len(self.evaluate(dist.log_prob([[1], [1]])).shape))

      dist = bernoulli.Bernoulli(probs=0.5)
      self.assertEqual((), dist.log_prob(1).get_shape())
      # The expected shape is spelled as a real 1-tuple; a bare `(1)` is just
      # the int 1 and only reads like a shape.
      self.assertEqual((1,), dist.log_prob([1]).get_shape())
      self.assertEqual((2, 1), dist.log_prob([[1], [1]]).get_shape())

      dist = bernoulli.Bernoulli(probs=[[0.5], [0.5]])
      self.assertEqual((2, 1), dist.log_prob(1).get_shape())

  @test_util.run_in_graph_and_eager_modes
  def testBoundaryConditions(self):
    """Compares `log_prob(0)` and `log_prob(1)` at probs=1.0 against NaN."""
    dist = bernoulli.Bernoulli(probs=1.0)
    self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0)))
    self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))])

  @test_util.run_in_graph_and_eager_modes
  def testEntropyNoBatch(self):
    """Scalar-probs entropy matches the NumPy reference `entropy`."""
    p = 0.2
    dist = bernoulli.Bernoulli(probs=p)
    self.assertAllClose(self.evaluate(dist.entropy()), entropy(p))

  @test_util.run_in_graph_and_eager_modes
  def testEntropyWithBatch(self):
    """Batched entropy matches the NumPy reference `entropy` elementwise."""
    p = [[0.1, 0.7], [0.2, 0.6]]
    dist = bernoulli.Bernoulli(probs=p, validate_args=False)
    self.assertAllClose(
        self.evaluate(dist.entropy()),
        [[entropy(0.1), entropy(0.7)], [entropy(0.2),
                                        entropy(0.6)]])

  @test_util.run_in_graph_and_eager_modes
  def testSampleN(self):
    """Samples are int32 values in {0, 1} whose mean is close to `probs`."""
    p = [0.2, 0.6]
    dist = bernoulli.Bernoulli(probs=p)
    n = 100000
    samples = dist.sample(n)
    samples.set_shape([n, 2])
    self.assertEqual(samples.dtype, dtypes.int32)
    sample_values = self.evaluate(samples)
    self.assertTrue(np.all(sample_values >= 0))
    self.assertTrue(np.all(sample_values <= 1))
    # Note that the standard error for the sample mean is ~ sqrt(p * (1 - p) /
    # n). This means that the tolerance is very sensitive to the value of p
    # as well as n.
    self.assertAllClose(p, np.mean(sample_values, axis=0), atol=1e-2)
    self.assertEqual({0, 1}, set(sample_values.flatten()))
    # In this test we're just interested in verifying there isn't a crash
    # owing to mismatched types. b/30940152
    dist = bernoulli.Bernoulli(np.log([.2, .4]))
    self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list())

  @test_util.run_in_graph_and_eager_modes
  def testNotReparameterized(self):
    """The gradient of samples with respect to `probs` is None."""
    p = constant_op.constant([0.2, 0.6])
    with backprop.GradientTape() as tape:
      tape.watch(p)
      dist = bernoulli.Bernoulli(probs=p)
      samples = dist.sample(100)
    grad_p = tape.gradient(samples, p)
    self.assertIsNone(grad_p)

  @test_util.run_deprecated_v1
  def testSampleActsLikeSampleN(self):
    """Two `sample` calls with the same seed return identical draws."""
    with self.cached_session() as sess:
      p = [0.2, 0.6]
      dist = bernoulli.Bernoulli(probs=p)
      n = 1000
      seed = 42
      self.assertAllEqual(
          self.evaluate(dist.sample(n, seed)),
          self.evaluate(dist.sample(n, seed)))
      # Same check with the sample count supplied through a placeholder.
      n = array_ops.placeholder(dtypes.int32)
      sample1, sample2 = sess.run([dist.sample(n, seed), dist.sample(n, seed)],
                                  feed_dict={n: 1000})
      self.assertAllEqual(sample1, sample2)

  @test_util.run_in_graph_and_eager_modes
  def testMean(self):
    """`mean()` equals `probs`."""
    p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32)
    dist = bernoulli.Bernoulli(probs=p)
    self.assertAllEqual(self.evaluate(dist.mean()), p)

  @test_util.run_in_graph_and_eager_modes
  def testVarianceAndStd(self):
    """`variance()` is p * (1 - p) and `stddev()` is its square root."""
    var = lambda p: p * (1. - p)
    p = [[0.2, 0.7], [0.5, 0.4]]
    dist = bernoulli.Bernoulli(probs=p)
    self.assertAllClose(
        self.evaluate(dist.variance()),
        np.array([[var(0.2), var(0.7)], [var(0.5), var(0.4)]],
                 dtype=np.float32))
    self.assertAllClose(
        self.evaluate(dist.stddev()),
        np.array([[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
                  [np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
                 dtype=np.float32))

  @test_util.run_in_graph_and_eager_modes
  def testBernoulliBernoulliKL(self):
    """KL(a || b) between two Bernoullis matches the NumPy closed form."""
    batch_size = 6
    a_p = np.array([0.5] * batch_size, dtype=np.float32)
    b_p = np.array([0.4] * batch_size, dtype=np.float32)

    a = bernoulli.Bernoulli(probs=a_p)
    b = bernoulli.Bernoulli(probs=b_p)

    kl = kullback_leibler.kl_divergence(a, b)
    kl_val = self.evaluate(kl)

    kl_expected = (a_p * np.log(a_p / b_p) + (1. - a_p) * np.log(
        (1. - a_p) / (1. - b_p)))

    self.assertEqual(kl.get_shape(), (batch_size,))
    self.assertAllClose(kl_val, kl_expected)
333
# Hand control to `test.main()` only when this file is executed as a script.
if __name__ == "__main__":
  test.main()
336