# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the Bernoulli distribution."""

import importlib

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import bernoulli
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def try_import(name):  # pylint: disable=invalid-name
  module = None
  try:
    module = importlib.import_module(name)
  except ImportError as e:
    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
  return module


special = try_import("scipy.special")


def make_bernoulli(batch_shape, dtype=dtypes.int32):
  p = np.random.uniform(size=list(batch_shape))
  p = constant_op.constant(p, dtype=dtypes.float32)
  return bernoulli.Bernoulli(probs=p, dtype=dtype)


def entropy(p):
  q = 1. - p
  return -q * np.log(q) - p * np.log(p)


class BernoulliTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testP(self):
    p = [0.2, 0.4]
    dist = bernoulli.Bernoulli(probs=p)
    self.assertAllClose(p, self.evaluate(dist.probs))

  @test_util.run_in_graph_and_eager_modes
  def testLogits(self):
    logits = [-42., 42.]
    dist = bernoulli.Bernoulli(logits=logits)
    self.assertAllClose(logits, self.evaluate(dist.logits))

    if not special:
      return

    self.assertAllClose(special.expit(logits), self.evaluate(dist.probs))

    p = [0.01, 0.99, 0.42]
    dist = bernoulli.Bernoulli(probs=p)
    self.assertAllClose(special.logit(p), self.evaluate(dist.logits))

  @test_util.run_in_graph_and_eager_modes
  def testInvalidP(self):
    invalid_ps = [1.01, 2.]
    for p in invalid_ps:
      with self.assertRaisesOpError("probs has components greater than 1"):
        dist = bernoulli.Bernoulli(probs=p, validate_args=True)
        self.evaluate(dist.probs)

    invalid_ps = [-0.01, -3.]
    for p in invalid_ps:
      with self.assertRaisesOpError("Condition x >= 0"):
        dist = bernoulli.Bernoulli(probs=p, validate_args=True)
        self.evaluate(dist.probs)

    valid_ps = [0.0, 0.5, 1.0]
    for p in valid_ps:
      dist = bernoulli.Bernoulli(probs=p)
      self.assertEqual(p, self.evaluate(dist.probs))  # Should not fail

  @test_util.run_in_graph_and_eager_modes
  def testShapes(self):
    for batch_shape in ([], [1], [2, 3, 4]):
      dist = make_bernoulli(batch_shape)
      self.assertAllEqual(batch_shape, dist.batch_shape.as_list())
      self.assertAllEqual(batch_shape,
                          self.evaluate(dist.batch_shape_tensor()))
      self.assertAllEqual([], dist.event_shape.as_list())
      self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))

  @test_util.run_in_graph_and_eager_modes
  def testDtype(self):
    dist = make_bernoulli([])
    self.assertEqual(dist.dtype, dtypes.int32)
    self.assertEqual(dist.dtype, dist.sample(5).dtype)
    self.assertEqual(dist.dtype, dist.mode().dtype)
    self.assertEqual(dist.probs.dtype, dist.mean().dtype)
    self.assertEqual(dist.probs.dtype, dist.variance().dtype)
    self.assertEqual(dist.probs.dtype, dist.stddev().dtype)
    self.assertEqual(dist.probs.dtype, dist.entropy().dtype)
    self.assertEqual(dist.probs.dtype, dist.prob(0).dtype)
    self.assertEqual(dist.probs.dtype, dist.prob(0.5).dtype)
    self.assertEqual(dist.probs.dtype, dist.log_prob(0).dtype)
    self.assertEqual(dist.probs.dtype, dist.log_prob(0.5).dtype)

    dist64 = make_bernoulli([], dtypes.int64)
    self.assertEqual(dist64.dtype, dtypes.int64)
    self.assertEqual(dist64.dtype, dist64.sample(5).dtype)
    self.assertEqual(dist64.dtype, dist64.mode().dtype)

  @test_util.run_in_graph_and_eager_modes
  def _testPmf(self, **kwargs):
    dist = bernoulli.Bernoulli(**kwargs)
    # pylint: disable=bad-continuation
    xs = [
        0,
        [1],
        [1, 0],
        [[1, 0]],
        [[1, 0], [1, 1]],
    ]
    expected_pmfs = [
        [[0.8, 0.6], [0.7, 0.4]],
        [[0.2, 0.4], [0.3, 0.6]],
        [[0.2, 0.6], [0.3, 0.4]],
        [[0.2, 0.6], [0.3, 0.4]],
        [[0.2, 0.6], [0.3, 0.6]],
    ]
    # pylint: enable=bad-continuation

    for x, expected_pmf in zip(xs, expected_pmfs):
      self.assertAllClose(self.evaluate(dist.prob(x)), expected_pmf)
      self.assertAllClose(
          self.evaluate(dist.log_prob(x)), np.log(expected_pmf))

  @test_util.run_deprecated_v1
  def testPmfCorrectBroadcastDynamicShape(self):
    with self.cached_session():
      p = array_ops.placeholder(dtype=dtypes.float32)
      dist = bernoulli.Bernoulli(probs=p)
      event1 = [1, 0, 1]
      event2 = [[1, 0, 1]]
      self.assertAllClose(
          dist.prob(event1).eval({
              p: [0.2, 0.3, 0.4]
          }), [0.2, 0.7, 0.4])
      self.assertAllClose(
          dist.prob(event2).eval({
              p: [0.2, 0.3, 0.4]
          }), [[0.2, 0.7, 0.4]])

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_deprecated_v1
  def testPmfInvalid(self):
    p = [0.1, 0.2, 0.7]
    dist = bernoulli.Bernoulli(probs=p, validate_args=True)
    with self.assertRaisesOpError("must be non-negative."):
      self.evaluate(dist.prob([1, 1, -1]))
    with self.assertRaisesOpError("Elements cannot exceed 1."):
      self.evaluate(dist.prob([2, 0, 1]))

  @test_util.run_in_graph_and_eager_modes
  def testPmfWithP(self):
    p = [[0.2, 0.4], [0.3, 0.6]]
    self._testPmf(probs=p)
    if not special:
      return
    self._testPmf(logits=special.logit(p))

  @test_util.run_in_graph_and_eager_modes
  def testPmfWithFloatArgReturnsXEntropy(self):
    p = [[0.2], [0.4], [0.3], [0.6]]
    samps = [0, 0.1, 0.8]
    self.assertAllClose(
        np.float32(samps) * np.log(np.float32(p)) +
        (1 - np.float32(samps)) * np.log(1 - np.float32(p)),
        self.evaluate(
            bernoulli.Bernoulli(probs=p, validate_args=False).log_prob(samps)))

  @test_util.run_deprecated_v1
  def testBroadcasting(self):
    with self.cached_session():
      p = array_ops.placeholder(dtypes.float32)
      dist = bernoulli.Bernoulli(probs=p)
      self.assertAllClose(np.log(0.5), dist.log_prob(1).eval({p: 0.5}))
      self.assertAllClose(
          np.log([0.5, 0.5, 0.5]), dist.log_prob([1, 1, 1]).eval({
              p: 0.5
          }))
      self.assertAllClose(
          np.log([0.5, 0.5, 0.5]), dist.log_prob(1).eval({
              p: [0.5, 0.5, 0.5]
          }))

  @test_util.run_deprecated_v1
  def testPmfShapes(self):
    with self.cached_session():
      p = array_ops.placeholder(dtypes.float32, shape=[None, 1])
      dist = bernoulli.Bernoulli(probs=p)
      self.assertEqual(2,
                       len(dist.log_prob(1).eval({p: [[0.5], [0.5]]}).shape))

      dist = bernoulli.Bernoulli(probs=0.5)
      self.assertEqual(2, len(self.evaluate(dist.log_prob([[1], [1]])).shape))

      dist = bernoulli.Bernoulli(probs=0.5)
      self.assertEqual((), dist.log_prob(1).get_shape())
      self.assertEqual((1), dist.log_prob([1]).get_shape())
      self.assertEqual((2, 1), dist.log_prob([[1], [1]]).get_shape())

      dist = bernoulli.Bernoulli(probs=[[0.5], [0.5]])
      self.assertEqual((2, 1), dist.log_prob(1).get_shape())

  @test_util.run_in_graph_and_eager_modes
  def testBoundaryConditions(self):
    dist = bernoulli.Bernoulli(probs=1.0)
    self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0)))
    self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))])

  @test_util.run_in_graph_and_eager_modes
  def testEntropyNoBatch(self):
    p = 0.2
    dist = bernoulli.Bernoulli(probs=p)
    self.assertAllClose(self.evaluate(dist.entropy()), entropy(p))

  @test_util.run_in_graph_and_eager_modes
  def testEntropyWithBatch(self):
    p = [[0.1, 0.7], [0.2, 0.6]]
    dist = bernoulli.Bernoulli(probs=p, validate_args=False)
    self.assertAllClose(
        self.evaluate(dist.entropy()),
        [[entropy(0.1), entropy(0.7)], [entropy(0.2),
                                        entropy(0.6)]])

  @test_util.run_in_graph_and_eager_modes
  def testSampleN(self):
    p = [0.2, 0.6]
    dist = bernoulli.Bernoulli(probs=p)
    n = 100000
    samples = dist.sample(n)
    samples.set_shape([n, 2])
    self.assertEqual(samples.dtype, dtypes.int32)
    sample_values = self.evaluate(samples)
    self.assertTrue(np.all(sample_values >= 0))
    self.assertTrue(np.all(sample_values <= 1))
    # Note that the standard error for the sample mean is ~ sqrt(p * (1 - p) /
    # n). This means that the tolerance is very sensitive to the value of p
    # as well as n.
    self.assertAllClose(p, np.mean(sample_values, axis=0), atol=1e-2)
    self.assertEqual(set([0, 1]), set(sample_values.flatten()))
    # In this test we're just interested in verifying there isn't a crash
    # owing to mismatched types. b/30940152
    dist = bernoulli.Bernoulli(np.log([.2, .4]))
    self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list())

  @test_util.run_in_graph_and_eager_modes
  def testNotReparameterized(self):
    p = constant_op.constant([0.2, 0.6])
    with backprop.GradientTape() as tape:
      tape.watch(p)
      dist = bernoulli.Bernoulli(probs=p)
      samples = dist.sample(100)
    grad_p = tape.gradient(samples, p)
    self.assertIsNone(grad_p)

  @test_util.run_deprecated_v1
  def testSampleActsLikeSampleN(self):
    with self.cached_session() as sess:
      p = [0.2, 0.6]
      dist = bernoulli.Bernoulli(probs=p)
      n = 1000
      seed = 42
      self.assertAllEqual(
          self.evaluate(dist.sample(n, seed)),
          self.evaluate(dist.sample(n, seed)))
      n = array_ops.placeholder(dtypes.int32)
      sample1, sample2 = sess.run([dist.sample(n, seed), dist.sample(n, seed)],
                                  feed_dict={n: 1000})
      self.assertAllEqual(sample1, sample2)

  @test_util.run_in_graph_and_eager_modes
  def testMean(self):
    p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32)
    dist = bernoulli.Bernoulli(probs=p)
    self.assertAllEqual(self.evaluate(dist.mean()), p)

  @test_util.run_in_graph_and_eager_modes
  def testVarianceAndStd(self):
    var = lambda p: p * (1. - p)
    p = [[0.2, 0.7], [0.5, 0.4]]
    dist = bernoulli.Bernoulli(probs=p)
    self.assertAllClose(
        self.evaluate(dist.variance()),
        np.array([[var(0.2), var(0.7)], [var(0.5), var(0.4)]],
                 dtype=np.float32))
    self.assertAllClose(
        self.evaluate(dist.stddev()),
        np.array([[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
                  [np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
                 dtype=np.float32))

  @test_util.run_in_graph_and_eager_modes
  def testBernoulliBernoulliKL(self):
    batch_size = 6
    a_p = np.array([0.5] * batch_size, dtype=np.float32)
    b_p = np.array([0.4] * batch_size, dtype=np.float32)

    a = bernoulli.Bernoulli(probs=a_p)
    b = bernoulli.Bernoulli(probs=b_p)

    kl = kullback_leibler.kl_divergence(a, b)
    kl_val = self.evaluate(kl)

    kl_expected = (a_p * np.log(a_p / b_p) + (1. - a_p) * np.log(
        (1. - a_p) / (1. - b_p)))

    self.assertEqual(kl.get_shape(), (batch_size,))
    self.assertAllClose(kl_val, kl_expected)


if __name__ == "__main__":
  test.main()