# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the Bernoulli distribution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import importlib

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import bernoulli
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def try_import(name):  # pylint: disable=invalid-name
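  """Attempts to import module `name`; returns None and warns on failure."""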
  module = None
  try:
    module = importlib.import_module(name)
  except ImportError as e:
    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
  return module


special = try_import("scipy.special")


def make_bernoulli(batch_shape, dtype=dtypes.int32):
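  """Builds a Bernoulli with uniform-random probs over `batch_shape`."""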
  p = np.random.uniform(size=list(batch_shape))
  p = constant_op.constant(p, dtype=dtypes.float32)
  return bernoulli.Bernoulli(probs=p, dtype=dtype)


def entropy(p):
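  """Binary entropy -p*log(p) - (1-p)*log(1-p) in nats; NaN at p in {0, 1}."""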
  q = 1. - p
  return -q * np.log(q) - p * np.log(p)


class BernoulliTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testP(self):
    p = [0.2, 0.4]
    dist = bernoulli.Bernoulli(probs=p)
    self.assertAllClose(p, self.evaluate(dist.probs))

  @test_util.run_in_graph_and_eager_modes
  def testLogits(self):
    logits = [-42., 42.]
    dist = bernoulli.Bernoulli(logits=logits)
    self.assertAllClose(logits, self.evaluate(dist.logits))

    if not special:
      return

    self.assertAllClose(special.expit(logits), self.evaluate(dist.probs))

    p = [0.01, 0.99, 0.42]
    dist = bernoulli.Bernoulli(probs=p)
    self.assertAllClose(special.logit(p), self.evaluate(dist.logits))

  @test_util.run_in_graph_and_eager_modes
  def testInvalidP(self):
    invalid_ps = [1.01, 2.]
    for p in invalid_ps:
      with self.assertRaisesOpError("probs has components greater than 1"):
        dist = bernoulli.Bernoulli(probs=p, validate_args=True)
        self.evaluate(dist.probs)

    invalid_ps = [-0.01, -3.]
    for p in invalid_ps:
      with self.assertRaisesOpError("Condition x >= 0"):
        dist = bernoulli.Bernoulli(probs=p, validate_args=True)
        self.evaluate(dist.probs)

    valid_ps = [0.0, 0.5, 1.0]
    for p in valid_ps:
      dist = bernoulli.Bernoulli(probs=p)
      self.assertEqual(p, self.evaluate(dist.probs))  # Should not fail

  @test_util.run_in_graph_and_eager_modes
  def testShapes(self):
    for batch_shape in ([], [1], [2, 3, 4]):
      dist = make_bernoulli(batch_shape)
      self.assertAllEqual(batch_shape, dist.batch_shape.as_list())
      self.assertAllEqual(batch_shape, self.evaluate(dist.batch_shape_tensor()))
      self.assertAllEqual([], dist.event_shape.as_list())
      self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))

  @test_util.run_in_graph_and_eager_modes
  def testDtype(self):
    dist = make_bernoulli([])
    self.assertEqual(dist.dtype, dtypes.int32)
    self.assertEqual(dist.dtype, dist.sample(5).dtype)
    self.assertEqual(dist.dtype, dist.mode().dtype)
    self.assertEqual(dist.probs.dtype, dist.mean().dtype)
    self.assertEqual(dist.probs.dtype, dist.variance().dtype)
    self.assertEqual(dist.probs.dtype, dist.stddev().dtype)
    self.assertEqual(dist.probs.dtype, dist.entropy().dtype)
    self.assertEqual(dist.probs.dtype, dist.prob(0).dtype)
    self.assertEqual(dist.probs.dtype, dist.prob(0.5).dtype)
    self.assertEqual(dist.probs.dtype, dist.log_prob(0).dtype)
    self.assertEqual(dist.probs.dtype, dist.log_prob(0.5).dtype)

    dist64 = make_bernoulli([], dtypes.int64)
    self.assertEqual(dist64.dtype, dtypes.int64)
    self.assertEqual(dist64.dtype, dist64.sample(5).dtype)
    self.assertEqual(dist64.dtype, dist64.mode().dtype)

  def _testPmf(self, **kwargs):
    dist = bernoulli.Bernoulli(**kwargs)
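    # The expected values below assume probs=[[0.2, 0.4], [0.3, 0.6]] (see
    # testPmfWithP); each event x broadcasts against that 2x2 batch.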
    # pylint: disable=bad-continuation
    xs = [
        0,
        [1],
        [1, 0],
        [[1, 0]],
        [[1, 0], [1, 1]],
    ]
    expected_pmfs = [
        [[0.8, 0.6], [0.7, 0.4]],
        [[0.2, 0.4], [0.3, 0.6]],
        [[0.2, 0.6], [0.3, 0.4]],
        [[0.2, 0.6], [0.3, 0.4]],
        [[0.2, 0.6], [0.3, 0.6]],
    ]
    # pylint: enable=bad-continuation

    for x, expected_pmf in zip(xs, expected_pmfs):
      self.assertAllClose(self.evaluate(dist.prob(x)), expected_pmf)
      self.assertAllClose(self.evaluate(dist.log_prob(x)), np.log(expected_pmf))

  @test_util.run_deprecated_v1
  def testPmfCorrectBroadcastDynamicShape(self):
    with self.cached_session():
      p = array_ops.placeholder(dtype=dtypes.float32)
      dist = bernoulli.Bernoulli(probs=p)
      event1 = [1, 0, 1]
      event2 = [[1, 0, 1]]
      self.assertAllClose(
          dist.prob(event1).eval({
              p: [0.2, 0.3, 0.4]
          }), [0.2, 0.7, 0.4])
      self.assertAllClose(
          dist.prob(event2).eval({
              p: [0.2, 0.3, 0.4]
          }), [[0.2, 0.7, 0.4]])

  @test_util.run_in_graph_and_eager_modes
  def testPmfInvalid(self):
    p = [0.1, 0.2, 0.7]
    dist = bernoulli.Bernoulli(probs=p, validate_args=True)
    with self.assertRaisesOpError("must be non-negative."):
      self.evaluate(dist.prob([1, 1, -1]))
    with self.assertRaisesOpError("Elements cannot exceed 1."):
      self.evaluate(dist.prob([2, 0, 1]))

  @test_util.run_in_graph_and_eager_modes
  def testPmfWithP(self):
    p = [[0.2, 0.4], [0.3, 0.6]]
    self._testPmf(probs=p)
    if not special:
      return
    self._testPmf(logits=special.logit(p))

  @test_util.run_in_graph_and_eager_modes
  def testPmfWithFloatArgReturnsXEntropy(self):
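    # For non-integer x, log_prob returns the cross-entropy expression
    # x * log(p) + (1 - x) * log(1 - p), checked here against numpy.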
    p = [[0.2], [0.4], [0.3], [0.6]]
    samps = [0, 0.1, 0.8]
    self.assertAllClose(
        np.float32(samps) * np.log(np.float32(p)) +
        (1 - np.float32(samps)) * np.log(1 - np.float32(p)),
        self.evaluate(
            bernoulli.Bernoulli(probs=p, validate_args=False).log_prob(samps)))

  @test_util.run_deprecated_v1
  def testBroadcasting(self):
    with self.cached_session():
      p = array_ops.placeholder(dtypes.float32)
      dist = bernoulli.Bernoulli(probs=p)
      self.assertAllClose(np.log(0.5), dist.log_prob(1).eval({p: 0.5}))
      self.assertAllClose(
          np.log([0.5, 0.5, 0.5]), dist.log_prob([1, 1, 1]).eval({
              p: 0.5
          }))
      self.assertAllClose(
          np.log([0.5, 0.5, 0.5]), dist.log_prob(1).eval({
              p: [0.5, 0.5, 0.5]
          }))

  @test_util.run_deprecated_v1
  def testPmfShapes(self):
    with self.cached_session():
      p = array_ops.placeholder(dtypes.float32, shape=[None, 1])
      dist = bernoulli.Bernoulli(probs=p)
      self.assertEqual(2, len(dist.log_prob(1).eval({p: [[0.5], [0.5]]}).shape))

      dist = bernoulli.Bernoulli(probs=0.5)
      self.assertEqual(2, len(self.evaluate(dist.log_prob([[1], [1]])).shape))

      dist = bernoulli.Bernoulli(probs=0.5)
      self.assertEqual((), dist.log_prob(1).get_shape())
      self.assertEqual((1,), dist.log_prob([1]).get_shape())
      self.assertEqual((2, 1), dist.log_prob([[1], [1]]).get_shape())

      dist = bernoulli.Bernoulli(probs=[[0.5], [0.5]])
      self.assertEqual((2, 1), dist.log_prob(1).get_shape())

  @test_util.run_in_graph_and_eager_modes
  def testBoundaryConditions(self):
    dist = bernoulli.Bernoulli(probs=1.0)
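    # probs=1.0 corresponds to infinite logits; at this boundary the
    # log-prob computation yields NaN rather than 0. or -inf.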
    self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0)))
    self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))])

  @test_util.run_in_graph_and_eager_modes
  def testEntropyNoBatch(self):
    p = 0.2
    dist = bernoulli.Bernoulli(probs=p)
    self.assertAllClose(self.evaluate(dist.entropy()), entropy(p))

  @test_util.run_in_graph_and_eager_modes
  def testEntropyWithBatch(self):
    p = [[0.1, 0.7], [0.2, 0.6]]
    dist = bernoulli.Bernoulli(probs=p, validate_args=False)
    self.assertAllClose(
        self.evaluate(dist.entropy()),
        [[entropy(0.1), entropy(0.7)], [entropy(0.2),
                                        entropy(0.6)]])

  @test_util.run_in_graph_and_eager_modes
  def testSampleN(self):
    p = [0.2, 0.6]
    dist = bernoulli.Bernoulli(probs=p)
    n = 100000
    samples = dist.sample(n)
    samples.set_shape([n, 2])
    self.assertEqual(samples.dtype, dtypes.int32)
    sample_values = self.evaluate(samples)
    self.assertTrue(np.all(sample_values >= 0))
    self.assertTrue(np.all(sample_values <= 1))
    # Note that the standard error for the sample mean is ~ sqrt(p * (1 - p) /
    # n). This means that the tolerance is very sensitive to the value of p
    # as well as n.
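    # For example, at p = 0.2 and n = 100000 the standard error is
    # sqrt(0.2 * 0.8 / 1e5) ~= 1.3e-3, so atol=1e-2 is roughly 8 standard
    # errors and this check should essentially never fail by chance.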
    self.assertAllClose(p, np.mean(sample_values, axis=0), atol=1e-2)
    self.assertEqual(set([0, 1]), set(sample_values.flatten()))
    # In this test we're just interested in verifying there isn't a crash
    # owing to mismatched types. b/30940152
    dist = bernoulli.Bernoulli(np.log([.2, .4]))
    self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list())

  @test_util.run_in_graph_and_eager_modes
  def testNotReparameterized(self):
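    # Bernoulli samples are discrete, so sampling is not reparameterized:
    # the gradient of the samples w.r.t. probs should be None.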
    p = constant_op.constant([0.2, 0.6])
    with backprop.GradientTape() as tape:
      tape.watch(p)
      dist = bernoulli.Bernoulli(probs=p)
      samples = dist.sample(100)
    grad_p = tape.gradient(samples, p)
    self.assertIsNone(grad_p)

  @test_util.run_deprecated_v1
  def testSampleActsLikeSampleN(self):
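    # With a fixed seed, building the sample op twice should produce
    # identical draws, for both static and placeholder-fed sample sizes.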
    with self.cached_session() as sess:
      p = [0.2, 0.6]
      dist = bernoulli.Bernoulli(probs=p)
      n = 1000
      seed = 42
      self.assertAllEqual(
          self.evaluate(dist.sample(n, seed)),
          self.evaluate(dist.sample(n, seed)))
      n = array_ops.placeholder(dtypes.int32)
      sample1, sample2 = sess.run([dist.sample(n, seed), dist.sample(n, seed)],
                                  feed_dict={n: 1000})
      self.assertAllEqual(sample1, sample2)

  @test_util.run_in_graph_and_eager_modes
  def testMean(self):
    p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32)
    dist = bernoulli.Bernoulli(probs=p)
    self.assertAllEqual(self.evaluate(dist.mean()), p)

  @test_util.run_in_graph_and_eager_modes
  def testVarianceAndStd(self):
    var = lambda p: p * (1. - p)
    p = [[0.2, 0.7], [0.5, 0.4]]
    dist = bernoulli.Bernoulli(probs=p)
    self.assertAllClose(
        self.evaluate(dist.variance()),
        np.array([[var(0.2), var(0.7)], [var(0.5), var(0.4)]],
                 dtype=np.float32))
    self.assertAllClose(
        self.evaluate(dist.stddev()),
        np.array([[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
                  [np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
                 dtype=np.float32))

  @test_util.run_in_graph_and_eager_modes
  def testBernoulliBernoulliKL(self):
    batch_size = 6
    a_p = np.array([0.5] * batch_size, dtype=np.float32)
    b_p = np.array([0.4] * batch_size, dtype=np.float32)

    a = bernoulli.Bernoulli(probs=a_p)
    b = bernoulli.Bernoulli(probs=b_p)

    kl = kullback_leibler.kl_divergence(a, b)
    kl_val = self.evaluate(kl)

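    # Closed form: KL(Ber(a) || Ber(b))
    #     = a * log(a / b) + (1 - a) * log((1 - a) / (1 - b)).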
    kl_expected = (a_p * np.log(a_p / b_p) + (1. - a_p) * np.log(
        (1. - a_p) / (1. - b_p)))

    self.assertEqual(kl.get_shape(), (batch_size,))
    self.assertAllClose(kl_val, kl_expected)


if __name__ == "__main__":
  test.main()