# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the Bernoulli distribution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import importlib

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import bernoulli
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def try_import(name):  # pylint: disable=invalid-name
  module = None
  try:
    module = importlib.import_module(name)
  except ImportError as e:
    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
  return module


special = try_import("scipy.special")


def make_bernoulli(batch_shape, dtype=dtypes.int32):
  p = np.random.uniform(size=list(batch_shape))
  p = constant_op.constant(p, dtype=dtypes.float32)
  return bernoulli.Bernoulli(probs=p, dtype=dtype)


def entropy(p):
  q = 1. - p
  return -q * np.log(q) - p * np.log(p)
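# `entropy` above is the closed-form Bernoulli entropy in nats:
#   H(p) = -p * log(p) - (1 - p) * log(1 - p)
# For example, H(0.2) = -(0.2 * log(0.2) + 0.8 * log(0.8)) ~= 0.5004; the
# entropy tests below compare dist.entropy() against this reference for
# probabilities strictly inside (0, 1), where H is finite.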


class BernoulliTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testP(self):
    p = [0.2, 0.4]
    dist = bernoulli.Bernoulli(probs=p)
    self.assertAllClose(p, self.evaluate(dist.probs))

  @test_util.run_in_graph_and_eager_modes
  def testLogits(self):
    logits = [-42., 42.]
    dist = bernoulli.Bernoulli(logits=logits)
    self.assertAllClose(logits, self.evaluate(dist.logits))

    if not special:
      return

    self.assertAllClose(special.expit(logits), self.evaluate(dist.probs))

    p = [0.01, 0.99, 0.42]
    dist = bernoulli.Bernoulli(probs=p)
    self.assertAllClose(special.logit(p), self.evaluate(dist.logits))

  @test_util.run_in_graph_and_eager_modes
  def testInvalidP(self):
    invalid_ps = [1.01, 2.]
    for p in invalid_ps:
      with self.assertRaisesOpError("probs has components greater than 1"):
        dist = bernoulli.Bernoulli(probs=p, validate_args=True)
        self.evaluate(dist.probs)

    invalid_ps = [-0.01, -3.]
    for p in invalid_ps:
      with self.assertRaisesOpError("Condition x >= 0"):
        dist = bernoulli.Bernoulli(probs=p, validate_args=True)
        self.evaluate(dist.probs)

    valid_ps = [0.0, 0.5, 1.0]
    for p in valid_ps:
      dist = bernoulli.Bernoulli(probs=p)
      self.assertEqual(p, self.evaluate(dist.probs))  # Should not fail

  @test_util.run_in_graph_and_eager_modes
  def testShapes(self):
    for batch_shape in ([], [1], [2, 3, 4]):
      dist = make_bernoulli(batch_shape)
      self.assertAllEqual(batch_shape, dist.batch_shape.as_list())
      self.assertAllEqual(batch_shape,
                          self.evaluate(dist.batch_shape_tensor()))
      self.assertAllEqual([], dist.event_shape.as_list())
      self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))

  @test_util.run_in_graph_and_eager_modes
  def testDtype(self):
    dist = make_bernoulli([])
    self.assertEqual(dist.dtype, dtypes.int32)
    self.assertEqual(dist.dtype, dist.sample(5).dtype)
    self.assertEqual(dist.dtype, dist.mode().dtype)
    self.assertEqual(dist.probs.dtype, dist.mean().dtype)
    self.assertEqual(dist.probs.dtype, dist.variance().dtype)
    self.assertEqual(dist.probs.dtype, dist.stddev().dtype)
    self.assertEqual(dist.probs.dtype, dist.entropy().dtype)
    self.assertEqual(dist.probs.dtype, dist.prob(0).dtype)
    self.assertEqual(dist.probs.dtype, dist.prob(0.5).dtype)
    self.assertEqual(dist.probs.dtype, dist.log_prob(0).dtype)
    self.assertEqual(dist.probs.dtype, dist.log_prob(0.5).dtype)

    dist64 = make_bernoulli([], dtypes.int64)
    self.assertEqual(dist64.dtype, dtypes.int64)
    self.assertEqual(dist64.dtype, dist64.sample(5).dtype)
    self.assertEqual(dist64.dtype, dist64.mode().dtype)

  @test_util.run_in_graph_and_eager_modes
  def _testPmf(self, **kwargs):
    dist = bernoulli.Bernoulli(**kwargs)
    # pylint: disable=bad-continuation
    xs = [
        0,
        [1],
        [1, 0],
        [[1, 0]],
        [[1, 0], [1, 1]],
    ]
    expected_pmfs = [
        [[0.8, 0.6], [0.7, 0.4]],
        [[0.2, 0.4], [0.3, 0.6]],
        [[0.2, 0.6], [0.3, 0.4]],
        [[0.2, 0.6], [0.3, 0.4]],
        [[0.2, 0.6], [0.3, 0.6]],
    ]
    # pylint: enable=bad-continuation

    for x, expected_pmf in zip(xs, expected_pmfs):
      self.assertAllClose(self.evaluate(dist.prob(x)), expected_pmf)
      self.assertAllClose(self.evaluate(dist.log_prob(x)),
                          np.log(expected_pmf))

  @test_util.run_deprecated_v1
  def testPmfCorrectBroadcastDynamicShape(self):
    with self.cached_session():
      p = array_ops.placeholder(dtype=dtypes.float32)
      dist = bernoulli.Bernoulli(probs=p)
      event1 = [1, 0, 1]
      event2 = [[1, 0, 1]]
      self.assertAllClose(
          dist.prob(event1).eval({p: [0.2, 0.3, 0.4]}), [0.2, 0.7, 0.4])
      self.assertAllClose(
          dist.prob(event2).eval({p: [0.2, 0.3, 0.4]}), [[0.2, 0.7, 0.4]])

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def testPmfInvalid(self):
    p = [0.1, 0.2, 0.7]
    dist = bernoulli.Bernoulli(probs=p, validate_args=True)
    with self.assertRaisesOpError("must be non-negative."):
      self.evaluate(dist.prob([1, 1, -1]))
    with self.assertRaisesOpError("Elements cannot exceed 1."):
      self.evaluate(dist.prob([2, 0, 1]))

  @test_util.run_in_graph_and_eager_modes
  def testPmfWithP(self):
    p = [[0.2, 0.4], [0.3, 0.6]]
    self._testPmf(probs=p)
    if not special:
      return
    self._testPmf(logits=special.logit(p))

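  # A Bernoulli log-pmf evaluated at a float-valued x in [0, 1],
  #   x * log(p) + (1 - x) * log(1 - p),
  # is the negative cross-entropy between the "soft" label x and probs p.
  # The test below checks that log_prob returns exactly this expression
  # when validate_args is False, so non-integer events are not rejected.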
  @test_util.run_in_graph_and_eager_modes
  def testPmfWithFloatArgReturnsXEntropy(self):
    p = [[0.2], [0.4], [0.3], [0.6]]
    samps = [0, 0.1, 0.8]
    self.assertAllClose(
        np.float32(samps) * np.log(np.float32(p)) +
        (1 - np.float32(samps)) * np.log(1 - np.float32(p)),
        self.evaluate(
            bernoulli.Bernoulli(probs=p, validate_args=False).log_prob(samps)))

  @test_util.run_deprecated_v1
  def testBroadcasting(self):
    with self.cached_session():
      p = array_ops.placeholder(dtypes.float32)
      dist = bernoulli.Bernoulli(probs=p)
      self.assertAllClose(np.log(0.5), dist.log_prob(1).eval({p: 0.5}))
      self.assertAllClose(
          np.log([0.5, 0.5, 0.5]), dist.log_prob([1, 1, 1]).eval({p: 0.5}))
      self.assertAllClose(
          np.log([0.5, 0.5, 0.5]),
          dist.log_prob(1).eval({p: [0.5, 0.5, 0.5]}))

  @test_util.run_deprecated_v1
  def testPmfShapes(self):
    with self.cached_session():
      p = array_ops.placeholder(dtypes.float32, shape=[None, 1])
      dist = bernoulli.Bernoulli(probs=p)
      self.assertEqual(2,
                       len(dist.log_prob(1).eval({p: [[0.5], [0.5]]}).shape))

      dist = bernoulli.Bernoulli(probs=0.5)
      self.assertEqual(2, len(self.evaluate(dist.log_prob([[1], [1]])).shape))

      dist = bernoulli.Bernoulli(probs=0.5)
      self.assertEqual((), dist.log_prob(1).get_shape())
      self.assertEqual((1,), dist.log_prob([1]).get_shape())
      self.assertEqual((2, 1), dist.log_prob([[1], [1]]).get_shape())

      dist = bernoulli.Bernoulli(probs=[[0.5], [0.5]])
      self.assertEqual((2, 1), dist.log_prob(1).get_shape())

  @test_util.run_in_graph_and_eager_modes
  def testBoundaryConditions(self):
    dist = bernoulli.Bernoulli(probs=1.0)
    self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0)))
    self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))])

  @test_util.run_in_graph_and_eager_modes
  def testEntropyNoBatch(self):
    p = 0.2
    dist = bernoulli.Bernoulli(probs=p)
    self.assertAllClose(self.evaluate(dist.entropy()), entropy(p))

  @test_util.run_in_graph_and_eager_modes
  def testEntropyWithBatch(self):
    p = [[0.1, 0.7], [0.2, 0.6]]
    dist = bernoulli.Bernoulli(probs=p, validate_args=False)
    self.assertAllClose(
        self.evaluate(dist.entropy()),
        [[entropy(0.1), entropy(0.7)], [entropy(0.2), entropy(0.6)]])

  @test_util.run_in_graph_and_eager_modes
  def testSampleN(self):
    p = [0.2, 0.6]
    dist = bernoulli.Bernoulli(probs=p)
    n = 100000
    samples = dist.sample(n)
    samples.set_shape([n, 2])
    self.assertEqual(samples.dtype, dtypes.int32)
    sample_values = self.evaluate(samples)
    self.assertTrue(np.all(sample_values >= 0))
    self.assertTrue(np.all(sample_values <= 1))
    # Note that the standard error for the sample mean is
    # ~ sqrt(p * (1 - p) / n). This means that the tolerance is very
    # sensitive to the value of p as well as n.
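    # For example, with p = 0.2 and n = 100000 the standard error is
    # sqrt(0.2 * 0.8 / 100000) ~= 1.3e-3, so atol=1e-2 leaves a margin of
    # roughly six to eight standard errors across these values of p.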
    self.assertAllClose(p, np.mean(sample_values, axis=0), atol=1e-2)
    self.assertEqual(set([0, 1]), set(sample_values.flatten()))
    # In this test we're just interested in verifying there isn't a crash
    # owing to mismatched types. b/30940152
    dist = bernoulli.Bernoulli(np.log([.2, .4]))
    self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list())

  @test_util.run_in_graph_and_eager_modes
  def testNotReparameterized(self):
    p = constant_op.constant([0.2, 0.6])
    with backprop.GradientTape() as tape:
      tape.watch(p)
      dist = bernoulli.Bernoulli(probs=p)
      samples = dist.sample(100)
    grad_p = tape.gradient(samples, p)
    # Sampling from a Bernoulli is discrete and not reparameterized, so no
    # gradient should flow back to the parameters.
    self.assertIsNone(grad_p)

  @test_util.run_deprecated_v1
  def testSampleActsLikeSampleN(self):
    with self.cached_session() as sess:
      p = [0.2, 0.6]
      dist = bernoulli.Bernoulli(probs=p)
      n = 1000
      seed = 42
      self.assertAllEqual(
          self.evaluate(dist.sample(n, seed)),
          self.evaluate(dist.sample(n, seed)))
      n = array_ops.placeholder(dtypes.int32)
      sample1, sample2 = sess.run([dist.sample(n, seed), dist.sample(n, seed)],
                                  feed_dict={n: 1000})
      self.assertAllEqual(sample1, sample2)

  @test_util.run_in_graph_and_eager_modes
  def testMean(self):
    p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32)
    dist = bernoulli.Bernoulli(probs=p)
    self.assertAllEqual(self.evaluate(dist.mean()), p)

  @test_util.run_in_graph_and_eager_modes
  def testVarianceAndStd(self):
    var = lambda p: p * (1. - p)
    p = [[0.2, 0.7], [0.5, 0.4]]
    dist = bernoulli.Bernoulli(probs=p)
    self.assertAllClose(
        self.evaluate(dist.variance()),
        np.array([[var(0.2), var(0.7)], [var(0.5), var(0.4)]],
                 dtype=np.float32))
    self.assertAllClose(
        self.evaluate(dist.stddev()),
        np.array([[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
                  [np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
                 dtype=np.float32))

  @test_util.run_in_graph_and_eager_modes
  def testBernoulliBernoulliKL(self):
    batch_size = 6
    a_p = np.array([0.5] * batch_size, dtype=np.float32)
    b_p = np.array([0.4] * batch_size, dtype=np.float32)

    a = bernoulli.Bernoulli(probs=a_p)
    b = bernoulli.Bernoulli(probs=b_p)

    kl = kullback_leibler.kl_divergence(a, b)
    kl_val = self.evaluate(kl)

    # Closed form: KL(a || b)
    #   = a * log(a / b) + (1 - a) * log((1 - a) / (1 - b)).
    kl_expected = (a_p * np.log(a_p / b_p) +
                   (1. - a_p) * np.log((1. - a_p) / (1. - b_p)))

    self.assertEqual(kl.get_shape(), (batch_size,))
    self.assertAllClose(kl_val, kl_expected)


if __name__ == "__main__":
  test.main()