# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the Beta distribution in `tensorflow.python.ops.distributions`."""

import importlib

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import beta as beta_lib
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def try_import(name):  # pylint: disable=invalid-name
  """Imports module `name`, returning `None` if the import fails.

  Args:
    name: Module name handed to `importlib.import_module`.

  Returns:
    The imported module, or `None` when `ImportError` is raised (in which case
    a warning is logged).
  """
  module = None
  try:
    module = importlib.import_module(name)
  except ImportError as e:
    # Hand the arguments to the logger instead of pre-formatting the string, so
    # formatting only happens if the record is actually emitted.
    tf_logging.warning("Could not import %s: %s", name, e)
  return module


# Optional scipy modules; tests that need reference values return early when
# these are `None`.
special = try_import("scipy.special")
stats = try_import("scipy.stats")


@test_util.run_all_in_graph_and_eager_modes
class BetaTest(test.TestCase):
  """Tests for `beta_lib.Beta` and `beta_lib.BetaWithSoftplusConcentration`."""

  def testSimpleShapes(self):
    """Rank-1 concentrations give a scalar event and a `[3]` batch shape."""
    a = np.random.rand(3)
    b = np.random.rand(3)
    dist = beta_lib.Beta(a, b)
    self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
    self.assertAllEqual([3], self.evaluate(dist.batch_shape_tensor()))
    self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
    self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape)

  def testComplexShapes(self):
    """Rank-3 concentrations give a scalar event and a `[3, 2, 2]` batch."""
    a = np.random.rand(3, 2, 2)
    b = np.random.rand(3, 2, 2)
    dist = beta_lib.Beta(a, b)
    self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
    self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
    self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
    self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)

  def testComplexShapesBroadcast(self):
    """Concentrations of shapes (3, 2, 2) and (2, 2) broadcast to (3, 2, 2)."""
    a = np.random.rand(3, 2, 2)
    b = np.random.rand(2, 2)
    dist = beta_lib.Beta(a, b)
    self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
    self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
    self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
    self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)

  def testAlphaProperty(self):
    """`concentration1` exposes the first constructor argument unchanged."""
    a = [[1., 2, 3]]
    b = [[2., 4, 3]]
    dist = beta_lib.Beta(a, b)
    self.assertEqual([1, 3], dist.concentration1.get_shape())
    self.assertAllClose(a, self.evaluate(dist.concentration1))

  def testBetaProperty(self):
    """`concentration0` exposes the second constructor argument unchanged."""
    a = [[1., 2, 3]]
    b = [[2., 4, 3]]
    dist = beta_lib.Beta(a, b)
    self.assertEqual([1, 3], dist.concentration0.get_shape())
    self.assertAllClose(b, self.evaluate(dist.concentration0))

  def testPdfXProper(self):
    """With `validate_args=True`, `prob` rejects samples <= 0 or >= 1."""
    a = [[1., 2, 3]]
    b = [[2., 4, 3]]
    dist = beta_lib.Beta(a, b, validate_args=True)
    # Points strictly inside (0, 1) must evaluate without error.
    self.evaluate(dist.prob([.1, .3, .6]))
    self.evaluate(dist.prob([.2, .3, .5]))
    # Either condition can trigger.
    with self.assertRaisesOpError("sample must be positive"):
      self.evaluate(dist.prob([-1., 0.1, 0.5]))
    with self.assertRaisesOpError("sample must be positive"):
      self.evaluate(dist.prob([0., 0.1, 0.5]))
    with self.assertRaisesOpError("sample must be less than `1`"):
      self.evaluate(dist.prob([.1, .2, 1.2]))
    with self.assertRaisesOpError("sample must be less than `1`"):
      self.evaluate(dist.prob([.1, .2, 1.0]))

  def testPdfTwoBatches(self):
    """`prob` at x = 0.5 for a batch of two distributions."""
    a = [1., 2]
    b = [1., 2]
    x = [.5, .5]
    dist = beta_lib.Beta(a, b)
    pdf = dist.prob(x)
    self.assertAllClose([1., 3. / 2], self.evaluate(pdf), rtol=1e-5, atol=1e-5)
    self.assertEqual((2,), pdf.get_shape())

  def testPdfTwoBatchesNontrivialX(self):
    """`prob` at distinct points for a batch of two distributions."""
    a = [1., 2]
    b = [1., 2]
    x = [.3, .7]
    dist = beta_lib.Beta(a, b)
    pdf = dist.prob(x)
    self.assertAllClose([1, 63. / 50], self.evaluate(pdf), rtol=1e-5, atol=1e-5)
    self.assertEqual((2,), pdf.get_shape())

  def testPdfUniformZeroBatch(self):
    """Scalar Beta(1, 1) has density 1 at every tested point."""
    # This is equivalent to a uniform distribution
    a = 1.
    b = 1.
    x = np.array([.1, .2, .3, .5, .8], dtype=np.float32)
    dist = beta_lib.Beta(a, b)
    pdf = dist.prob(x)
    self.assertAllClose([1.] * 5, self.evaluate(pdf))
    self.assertEqual((5,), pdf.get_shape())

  def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
    """(1, 2) concentrations broadcast against a (2, 2) sample in `prob`."""
    a = [[1., 2]]
    b = [[1., 2]]
    x = [[.5, .5], [.3, .7]]
    dist = beta_lib.Beta(a, b)
    pdf = dist.prob(x)
    self.assertAllClose([[1., 3. / 2], [1., 63. / 50]],
                        self.evaluate(pdf),
                        rtol=1e-5,
                        atol=1e-5)
    self.assertEqual((2, 2), pdf.get_shape())

  def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
    """(2,) concentrations broadcast against a (2, 2) sample in `prob`."""
    a = [1., 2]
    b = [1., 2]
    x = [[.5, .5], [.2, .8]]
    pdf = beta_lib.Beta(a, b).prob(x)
    self.assertAllClose([[1., 3. / 2], [1., 24. / 25]],
                        self.evaluate(pdf),
                        rtol=1e-5,
                        atol=1e-5)
    self.assertEqual((2, 2), pdf.get_shape())

  def testPdfXStretchedInBroadcastWhenSameRank(self):
    """A (1, 2) sample broadcasts against (2, 2) concentrations in `prob`."""
    a = [[1., 2], [2., 3]]
    b = [[1., 2], [2., 3]]
    x = [[.5, .5]]
    pdf = beta_lib.Beta(a, b).prob(x)
    self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]],
                        self.evaluate(pdf),
                        rtol=1e-5,
                        atol=1e-5)
    self.assertEqual((2, 2), pdf.get_shape())

  def testPdfXStretchedInBroadcastWhenLowerRank(self):
    """A (2,) sample broadcasts against (2, 2) concentrations in `prob`."""
    a = [[1., 2], [2., 3]]
    b = [[1., 2], [2., 3]]
    x = [.5, .5]
    pdf = beta_lib.Beta(a, b).prob(x)
    self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]],
                        self.evaluate(pdf),
                        rtol=1e-5,
                        atol=1e-5)
    self.assertEqual((2, 2), pdf.get_shape())

  def testLogPdfOnBoundaryIsFiniteWhenAlphaIsOne(self):
    """`prob(0.)` is finite for every tested `b` when concentration1 is 1."""
    b = [[0.01, 0.1, 1., 2], [5., 10., 2., 3]]
    pdf = self.evaluate(beta_lib.Beta(1., b).prob(0.))
    self.assertAllEqual(np.ones_like(pdf, dtype=np.bool_), np.isfinite(pdf))

  def testBetaMean(self):
    """`mean` has the batch shape and matches `scipy.stats.beta.mean`."""
    a = [1., 2, 3]
    b = [2., 4, 1.2]
    dist = beta_lib.Beta(a, b)
    self.assertEqual(dist.mean().get_shape(), (3,))
    if not stats:
      return
    expected_mean = stats.beta.mean(a, b)
    self.assertAllClose(expected_mean, self.evaluate(dist.mean()))

  def testBetaVariance(self):
    """`variance` has the batch shape and matches `scipy.stats.beta.var`."""
    a = [1., 2, 3]
    b = [2., 4, 1.2]
    dist = beta_lib.Beta(a, b)
    self.assertEqual(dist.variance().get_shape(), (3,))
    if not stats:
      return
    expected_variance = stats.beta.var(a, b)
    self.assertAllClose(expected_variance, self.evaluate(dist.variance()))

  def testBetaMode(self):
    """`mode` equals (a - 1) / (a + b - 2) for the tested concentrations."""
    a = np.array([1.1, 2, 3])
    b = np.array([2., 4, 1.2])
    expected_mode = (a - 1) / (a + b - 2)
    dist = beta_lib.Beta(a, b)
    self.assertEqual(dist.mode().get_shape(), (3,))
    self.assertAllClose(expected_mode, self.evaluate(dist.mode()))

  def testBetaModeInvalid(self):
    """With `allow_nan_stats=False`, `mode` raises if a concentration is 1."""
    # First concentration1 entry equals 1.
    a = np.array([1., 2, 3])
    b = np.array([2., 4, 1.2])
    dist = beta_lib.Beta(a, b, allow_nan_stats=False)
    with self.assertRaisesOpError("Condition x < y.*"):
      self.evaluate(dist.mode())

    # First concentration0 entry equals 1.
    a = np.array([2., 2, 3])
    b = np.array([1., 4, 1.2])
    dist = beta_lib.Beta(a, b, allow_nan_stats=False)
    with self.assertRaisesOpError("Condition x < y.*"):
      self.evaluate(dist.mode())

  def testBetaModeEnableAllowNanStats(self):
    """With `allow_nan_stats=True`, `mode` is NaN where a concentration is 1."""
    # First concentration1 entry equals 1.
    a = np.array([1., 2, 3])
    b = np.array([2., 4, 1.2])
    dist = beta_lib.Beta(a, b, allow_nan_stats=True)

    expected_mode = (a - 1) / (a + b - 2)
    expected_mode[0] = np.nan
    self.assertEqual((3,), dist.mode().get_shape())
    self.assertAllClose(expected_mode, self.evaluate(dist.mode()))

    # First concentration0 entry equals 1.
    a = np.array([2., 2, 3])
    b = np.array([1., 4, 1.2])
    dist = beta_lib.Beta(a, b, allow_nan_stats=True)

    expected_mode = (a - 1) / (a + b - 2)
    expected_mode[0] = np.nan
    self.assertEqual((3,), dist.mode().get_shape())
    self.assertAllClose(expected_mode, self.evaluate(dist.mode()))

  def testBetaEntropy(self):
    """`entropy` has the batch shape and matches `scipy.stats.beta.entropy`."""
    a = [1., 2, 3]
    b = [2., 4, 1.2]
    dist = beta_lib.Beta(a, b)
    self.assertEqual(dist.entropy().get_shape(), (3,))
    if not stats:
      return
    expected_entropy = stats.beta.entropy(a, b)
    self.assertAllClose(expected_entropy, self.evaluate(dist.entropy()))

  def testBetaSample(self):
    """Scalar-batch samples have the right shape, sign and statistics."""
    a = 1.
    b = 2.
    beta = beta_lib.Beta(a, b)
    n = constant_op.constant(100000)
    samples = beta.sample(n)
    sample_values = self.evaluate(samples)
    self.assertEqual(sample_values.shape, (100000,))
    self.assertFalse(np.any(sample_values < 0.0))
    if not stats:
      return
    # Use the same `a` and `b` the samples were drawn with as the reference
    # distribution for the Kolmogorov-Smirnov statistic.
    self.assertLess(
        stats.kstest(
            # Beta is a univariate distribution.
            sample_values,
            stats.beta(a=a, b=b).cdf)[0],
        0.01)
    # The standard error of the sample mean is 1 / (sqrt(18 * n))
    self.assertAllClose(
        sample_values.mean(axis=0), stats.beta.mean(a, b), atol=1e-2)
    self.assertAllClose(
        np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1)

  def testBetaFullyReparameterized(self):
    """Gradients of samples w.r.t. both concentrations are not `None`."""
    a = constant_op.constant(1.0)
    b = constant_op.constant(2.0)
    with backprop.GradientTape() as tape:
      tape.watch(a)
      tape.watch(b)
      beta = beta_lib.Beta(a, b)
      samples = beta.sample(100)
    grad_a, grad_b = tape.gradient(samples, [a, b])
    self.assertIsNotNone(grad_a)
    self.assertIsNotNone(grad_b)

  def testBetaSampleMultipleTimes(self):
    """Sampling with the same seed twice gives the same results."""
    a_val = 1.
    b_val = 2.
    n_val = 100

    random_seed.set_random_seed(654321)
    beta1 = beta_lib.Beta(
        concentration1=a_val, concentration0=b_val, name="beta1")
    samples1 = self.evaluate(beta1.sample(n_val, seed=123456))

    random_seed.set_random_seed(654321)
    beta2 = beta_lib.Beta(
        concentration1=a_val, concentration0=b_val, name="beta2")
    samples2 = self.evaluate(beta2.sample(n_val, seed=123456))

    self.assertAllClose(samples1, samples2)

  def testBetaSampleMultidimensional(self):
    """Batched samples have shape (n, 3, 2, 2) and a plausible mean."""
    a = np.random.rand(3, 2, 2).astype(np.float32)
    b = np.random.rand(3, 2, 2).astype(np.float32)
    beta = beta_lib.Beta(a, b)
    n = constant_op.constant(100000)
    samples = beta.sample(n)
    sample_values = self.evaluate(samples)
    self.assertEqual(sample_values.shape, (100000, 3, 2, 2))
    self.assertFalse(np.any(sample_values < 0.0))
    if not stats:
      return
    # Only the slice at batch index 1 is compared against the scipy mean.
    self.assertAllClose(
        sample_values[:, 1, :].mean(axis=0),
        stats.beta.mean(a, b)[1, :],
        atol=1e-1)

  def testBetaCdf(self):
    """`cdf` matches `scipy.stats.beta.cdf` for float32 and float64 inputs."""
    shape = (30, 40, 50)
    for dt in (np.float32, np.float64):
      a = 10. * np.random.random(shape).astype(dt)
      b = 10. * np.random.random(shape).astype(dt)
      x = np.random.random(shape).astype(dt)
      actual = self.evaluate(beta_lib.Beta(a, b).cdf(x))
      # NOTE: these two assertions check the evaluation points `x`, not the
      # computed CDF values in `actual`.
      self.assertAllEqual(np.ones(shape, dtype=np.bool_), 0. <= x)
      self.assertAllEqual(np.ones(shape, dtype=np.bool_), 1. >= x)
      if not stats:
        return
      self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=9e-3, atol=5e-6)

  def testBetaLogCdf(self):
    """`exp(log_cdf)` matches `scipy.stats.beta.cdf` for both float dtypes."""
    shape = (30, 40, 50)
    for dt in (np.float32, np.float64):
      a = 10. * np.random.random(shape).astype(dt)
      b = 10. * np.random.random(shape).astype(dt)
      x = np.random.random(shape).astype(dt)
      actual = self.evaluate(math_ops.exp(beta_lib.Beta(a, b).log_cdf(x)))
      # NOTE: these two assertions check the evaluation points `x`, not the
      # computed values in `actual`.
      self.assertAllEqual(np.ones(shape, dtype=np.bool_), 0. <= x)
      self.assertAllEqual(np.ones(shape, dtype=np.bool_), 1. >= x)
      if not stats:
        return
      self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=3e-3, atol=2e-5)

  def testBetaWithSoftplusConcentration(self):
    """Concentrations equal the softplus of the constructor arguments."""
    a, b = -4.2, -9.1
    dist = beta_lib.BetaWithSoftplusConcentration(a, b)
    self.assertAllClose(
        self.evaluate(nn_ops.softplus(a)), self.evaluate(dist.concentration1))
    self.assertAllClose(
        self.evaluate(nn_ops.softplus(b)), self.evaluate(dist.concentration0))

  def testBetaBetaKL(self):
    """KL between Beta distributions matches the scipy-based closed form."""
    for shape in [(10,), (4, 5)]:
      a1 = 6.0 * np.random.random(size=shape) + 1e-4
      b1 = 6.0 * np.random.random(size=shape) + 1e-4
      a2 = 6.0 * np.random.random(size=shape) + 1e-4
      b2 = 6.0 * np.random.random(size=shape) + 1e-4
      # Take inverse softplus of values to test BetaWithSoftplusConcentration
      a1_sp = np.log(np.exp(a1) - 1.0)
      b1_sp = np.log(np.exp(b1) - 1.0)
      a2_sp = np.log(np.exp(a2) - 1.0)
      b2_sp = np.log(np.exp(b2) - 1.0)

      d1 = beta_lib.Beta(concentration1=a1, concentration0=b1)
      d2 = beta_lib.Beta(concentration1=a2, concentration0=b2)
      d1_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a1_sp,
                                                     concentration0=b1_sp)
      d2_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a2_sp,
                                                     concentration0=b2_sp)

      if not special:
        return
      kl_expected = (special.betaln(a2, b2) - special.betaln(a1, b1) +
                     (a1 - a2) * special.digamma(a1) +
                     (b1 - b2) * special.digamma(b1) +
                     (a2 - a1 + b2 - b1) * special.digamma(a1 + b1))

      # Every pairing of plain and softplus-parameterized distributions is
      # compared against the same expected KL value.
      for dist1 in [d1, d1_sp]:
        for dist2 in [d2, d2_sp]:
          kl = kullback_leibler.kl_divergence(dist1, dist2)
          kl_val = self.evaluate(kl)
          self.assertEqual(kl.get_shape(), shape)
          self.assertAllClose(kl_val, kl_expected)

      # Make sure KL(d1||d1) is 0
      kl_same = self.evaluate(kullback_leibler.kl_divergence(d1, d1))
      self.assertAllClose(kl_same, np.zeros_like(kl_expected))


if __name__ == "__main__":
  test.main()