1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15from __future__ import absolute_import 16from __future__ import division 17from __future__ import print_function 18 19import importlib 20 21import numpy as np 22 23from tensorflow.python.eager import backprop 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import random_seed 26from tensorflow.python.framework import tensor_shape 27from tensorflow.python.framework import test_util 28from tensorflow.python.ops import math_ops 29from tensorflow.python.ops import nn_ops 30from tensorflow.python.ops.distributions import beta as beta_lib 31from tensorflow.python.ops.distributions import kullback_leibler 32from tensorflow.python.platform import test 33from tensorflow.python.platform import tf_logging 34 35 36def try_import(name): # pylint: disable=invalid-name 37 module = None 38 try: 39 module = importlib.import_module(name) 40 except ImportError as e: 41 tf_logging.warning("Could not import %s: %s" % (name, str(e))) 42 return module 43 44 45special = try_import("scipy.special") 46stats = try_import("scipy.stats") 47 48 49@test_util.run_all_in_graph_and_eager_modes 50class BetaTest(test.TestCase): 51 52 def testSimpleShapes(self): 53 a = np.random.rand(3) 54 b = np.random.rand(3) 55 dist = beta_lib.Beta(a, b) 56 self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) 57 self.assertAllEqual([3], self.evaluate(dist.batch_shape_tensor())) 58 self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) 59 self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape) 60 61 def testComplexShapes(self): 62 a = np.random.rand(3, 2, 2) 63 b = np.random.rand(3, 2, 2) 64 dist = beta_lib.Beta(a, b) 65 self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) 66 self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor())) 67 self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) 68 self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape) 69 70 def testComplexShapesBroadcast(self): 71 a = np.random.rand(3, 2, 2) 72 b = np.random.rand(2, 2) 73 dist = beta_lib.Beta(a, b) 74 self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) 75 self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor())) 76 self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) 77 self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape) 78 79 def testAlphaProperty(self): 80 a = [[1., 2, 3]] 81 b = [[2., 4, 3]] 82 dist = beta_lib.Beta(a, b) 83 self.assertEqual([1, 3], dist.concentration1.get_shape()) 84 self.assertAllClose(a, self.evaluate(dist.concentration1)) 85 86 def testBetaProperty(self): 87 a = [[1., 2, 3]] 88 b = [[2., 4, 3]] 89 dist = beta_lib.Beta(a, b) 90 self.assertEqual([1, 3], dist.concentration0.get_shape()) 91 self.assertAllClose(b, self.evaluate(dist.concentration0)) 92 93 def testPdfXProper(self): 94 a = [[1., 2, 3]] 95 b = [[2., 4, 3]] 96 dist = beta_lib.Beta(a, b, validate_args=True) 97 self.evaluate(dist.prob([.1, .3, .6])) 98 self.evaluate(dist.prob([.2, .3, .5])) 99 # Either condition can trigger. 100 with self.assertRaisesOpError("sample must be positive"): 101 self.evaluate(dist.prob([-1., 0.1, 0.5])) 102 with self.assertRaisesOpError("sample must be positive"): 103 self.evaluate(dist.prob([0., 0.1, 0.5])) 104 with self.assertRaisesOpError("sample must be less than `1`"): 105 self.evaluate(dist.prob([.1, .2, 1.2])) 106 with self.assertRaisesOpError("sample must be less than `1`"): 107 self.evaluate(dist.prob([.1, .2, 1.0])) 108 109 def testPdfTwoBatches(self): 110 a = [1., 2] 111 b = [1., 2] 112 x = [.5, .5] 113 dist = beta_lib.Beta(a, b) 114 pdf = dist.prob(x) 115 self.assertAllClose([1., 3. / 2], self.evaluate(pdf)) 116 self.assertEqual((2,), pdf.get_shape()) 117 118 def testPdfTwoBatchesNontrivialX(self): 119 a = [1., 2] 120 b = [1., 2] 121 x = [.3, .7] 122 dist = beta_lib.Beta(a, b) 123 pdf = dist.prob(x) 124 self.assertAllClose([1, 63. / 50], self.evaluate(pdf)) 125 self.assertEqual((2,), pdf.get_shape()) 126 127 def testPdfUniformZeroBatch(self): 128 # This is equivalent to a uniform distribution 129 a = 1. 130 b = 1. 131 x = np.array([.1, .2, .3, .5, .8], dtype=np.float32) 132 dist = beta_lib.Beta(a, b) 133 pdf = dist.prob(x) 134 self.assertAllClose([1.] * 5, self.evaluate(pdf)) 135 self.assertEqual((5,), pdf.get_shape()) 136 137 def testPdfAlphaStretchedInBroadcastWhenSameRank(self): 138 a = [[1., 2]] 139 b = [[1., 2]] 140 x = [[.5, .5], [.3, .7]] 141 dist = beta_lib.Beta(a, b) 142 pdf = dist.prob(x) 143 self.assertAllClose([[1., 3. / 2], [1., 63. / 50]], self.evaluate(pdf)) 144 self.assertEqual((2, 2), pdf.get_shape()) 145 146 def testPdfAlphaStretchedInBroadcastWhenLowerRank(self): 147 a = [1., 2] 148 b = [1., 2] 149 x = [[.5, .5], [.2, .8]] 150 pdf = beta_lib.Beta(a, b).prob(x) 151 self.assertAllClose([[1., 3. / 2], [1., 24. / 25]], self.evaluate(pdf)) 152 self.assertEqual((2, 2), pdf.get_shape()) 153 154 def testPdfXStretchedInBroadcastWhenSameRank(self): 155 a = [[1., 2], [2., 3]] 156 b = [[1., 2], [2., 3]] 157 x = [[.5, .5]] 158 pdf = beta_lib.Beta(a, b).prob(x) 159 self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf)) 160 self.assertEqual((2, 2), pdf.get_shape()) 161 162 def testPdfXStretchedInBroadcastWhenLowerRank(self): 163 a = [[1., 2], [2., 3]] 164 b = [[1., 2], [2., 3]] 165 x = [.5, .5] 166 pdf = beta_lib.Beta(a, b).prob(x) 167 self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf)) 168 self.assertEqual((2, 2), pdf.get_shape()) 169 170 def testLogPdfOnBoundaryIsFiniteWhenAlphaIsOne(self): 171 b = [[0.01, 0.1, 1., 2], [5., 10., 2., 3]] 172 pdf = self.evaluate(beta_lib.Beta(1., b).prob(0.)) 173 self.assertAllEqual(np.ones_like(pdf, dtype=np.bool), np.isfinite(pdf)) 174 175 def testBetaMean(self): 176 a = [1., 2, 3] 177 b = [2., 4, 1.2] 178 dist = beta_lib.Beta(a, b) 179 self.assertEqual(dist.mean().get_shape(), (3,)) 180 if not stats: 181 return 182 expected_mean = stats.beta.mean(a, b) 183 self.assertAllClose(expected_mean, self.evaluate(dist.mean())) 184 185 def testBetaVariance(self): 186 a = [1., 2, 3] 187 b = [2., 4, 1.2] 188 dist = beta_lib.Beta(a, b) 189 self.assertEqual(dist.variance().get_shape(), (3,)) 190 if not stats: 191 return 192 expected_variance = stats.beta.var(a, b) 193 self.assertAllClose(expected_variance, self.evaluate(dist.variance())) 194 195 def testBetaMode(self): 196 a = np.array([1.1, 2, 3]) 197 b = np.array([2., 4, 1.2]) 198 expected_mode = (a - 1) / (a + b - 2) 199 dist = beta_lib.Beta(a, b) 200 self.assertEqual(dist.mode().get_shape(), (3,)) 201 self.assertAllClose(expected_mode, self.evaluate(dist.mode())) 202 203 def testBetaModeInvalid(self): 204 a = np.array([1., 2, 3]) 205 b = np.array([2., 4, 1.2]) 206 dist = beta_lib.Beta(a, b, allow_nan_stats=False) 207 with self.assertRaisesOpError("Condition x < y.*"): 208 self.evaluate(dist.mode()) 209 210 a = np.array([2., 2, 3]) 211 b = np.array([1., 4, 1.2]) 212 dist = beta_lib.Beta(a, b, allow_nan_stats=False) 213 with self.assertRaisesOpError("Condition x < y.*"): 214 self.evaluate(dist.mode()) 215 216 def testBetaModeEnableAllowNanStats(self): 217 a = np.array([1., 2, 3]) 218 b = np.array([2., 4, 1.2]) 219 dist = beta_lib.Beta(a, b, allow_nan_stats=True) 220 221 expected_mode = (a - 1) / (a + b - 2) 222 expected_mode[0] = np.nan 223 self.assertEqual((3,), dist.mode().get_shape()) 224 self.assertAllClose(expected_mode, self.evaluate(dist.mode())) 225 226 a = np.array([2., 2, 3]) 227 b = np.array([1., 4, 1.2]) 228 dist = beta_lib.Beta(a, b, allow_nan_stats=True) 229 230 expected_mode = (a - 1) / (a + b - 2) 231 expected_mode[0] = np.nan 232 self.assertEqual((3,), dist.mode().get_shape()) 233 self.assertAllClose(expected_mode, self.evaluate(dist.mode())) 234 235 def testBetaEntropy(self): 236 a = [1., 2, 3] 237 b = [2., 4, 1.2] 238 dist = beta_lib.Beta(a, b) 239 self.assertEqual(dist.entropy().get_shape(), (3,)) 240 if not stats: 241 return 242 expected_entropy = stats.beta.entropy(a, b) 243 self.assertAllClose(expected_entropy, self.evaluate(dist.entropy())) 244 245 def testBetaSample(self): 246 a = 1. 247 b = 2. 248 beta = beta_lib.Beta(a, b) 249 n = constant_op.constant(100000) 250 samples = beta.sample(n) 251 sample_values = self.evaluate(samples) 252 self.assertEqual(sample_values.shape, (100000,)) 253 self.assertFalse(np.any(sample_values < 0.0)) 254 if not stats: 255 return 256 self.assertLess( 257 stats.kstest( 258 # Beta is a univariate distribution. 259 sample_values, 260 stats.beta(a=1., b=2.).cdf)[0], 261 0.01) 262 # The standard error of the sample mean is 1 / (sqrt(18 * n)) 263 self.assertAllClose( 264 sample_values.mean(axis=0), stats.beta.mean(a, b), atol=1e-2) 265 self.assertAllClose( 266 np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1) 267 268 def testBetaFullyReparameterized(self): 269 a = constant_op.constant(1.0) 270 b = constant_op.constant(2.0) 271 with backprop.GradientTape() as tape: 272 tape.watch(a) 273 tape.watch(b) 274 beta = beta_lib.Beta(a, b) 275 samples = beta.sample(100) 276 grad_a, grad_b = tape.gradient(samples, [a, b]) 277 self.assertIsNotNone(grad_a) 278 self.assertIsNotNone(grad_b) 279 280 # Test that sampling with the same seed twice gives the same results. 281 def testBetaSampleMultipleTimes(self): 282 a_val = 1. 283 b_val = 2. 284 n_val = 100 285 286 random_seed.set_random_seed(654321) 287 beta1 = beta_lib.Beta( 288 concentration1=a_val, concentration0=b_val, name="beta1") 289 samples1 = self.evaluate(beta1.sample(n_val, seed=123456)) 290 291 random_seed.set_random_seed(654321) 292 beta2 = beta_lib.Beta( 293 concentration1=a_val, concentration0=b_val, name="beta2") 294 samples2 = self.evaluate(beta2.sample(n_val, seed=123456)) 295 296 self.assertAllClose(samples1, samples2) 297 298 def testBetaSampleMultidimensional(self): 299 a = np.random.rand(3, 2, 2).astype(np.float32) 300 b = np.random.rand(3, 2, 2).astype(np.float32) 301 beta = beta_lib.Beta(a, b) 302 n = constant_op.constant(100000) 303 samples = beta.sample(n) 304 sample_values = self.evaluate(samples) 305 self.assertEqual(sample_values.shape, (100000, 3, 2, 2)) 306 self.assertFalse(np.any(sample_values < 0.0)) 307 if not stats: 308 return 309 self.assertAllClose( 310 sample_values[:, 1, :].mean(axis=0), 311 stats.beta.mean(a, b)[1, :], 312 atol=1e-1) 313 314 def testBetaCdf(self): 315 shape = (30, 40, 50) 316 for dt in (np.float32, np.float64): 317 a = 10. * np.random.random(shape).astype(dt) 318 b = 10. * np.random.random(shape).astype(dt) 319 x = np.random.random(shape).astype(dt) 320 actual = self.evaluate(beta_lib.Beta(a, b).cdf(x)) 321 self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x) 322 self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) 323 if not stats: 324 return 325 self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=9e-3, atol=5e-6) 326 327 def testBetaLogCdf(self): 328 shape = (30, 40, 50) 329 for dt in (np.float32, np.float64): 330 a = 10. * np.random.random(shape).astype(dt) 331 b = 10. * np.random.random(shape).astype(dt) 332 x = np.random.random(shape).astype(dt) 333 actual = self.evaluate(math_ops.exp(beta_lib.Beta(a, b).log_cdf(x))) 334 self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x) 335 self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) 336 if not stats: 337 return 338 self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=3e-3, atol=2e-5) 339 340 def testBetaWithSoftplusConcentration(self): 341 a, b = -4.2, -9.1 342 dist = beta_lib.BetaWithSoftplusConcentration(a, b) 343 self.assertAllClose( 344 self.evaluate(nn_ops.softplus(a)), self.evaluate(dist.concentration1)) 345 self.assertAllClose( 346 self.evaluate(nn_ops.softplus(b)), self.evaluate(dist.concentration0)) 347 348 def testBetaBetaKL(self): 349 for shape in [(10,), (4, 5)]: 350 a1 = 6.0 * np.random.random(size=shape) + 1e-4 351 b1 = 6.0 * np.random.random(size=shape) + 1e-4 352 a2 = 6.0 * np.random.random(size=shape) + 1e-4 353 b2 = 6.0 * np.random.random(size=shape) + 1e-4 354 # Take inverse softplus of values to test BetaWithSoftplusConcentration 355 a1_sp = np.log(np.exp(a1) - 1.0) 356 b1_sp = np.log(np.exp(b1) - 1.0) 357 a2_sp = np.log(np.exp(a2) - 1.0) 358 b2_sp = np.log(np.exp(b2) - 1.0) 359 360 d1 = beta_lib.Beta(concentration1=a1, concentration0=b1) 361 d2 = beta_lib.Beta(concentration1=a2, concentration0=b2) 362 d1_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a1_sp, 363 concentration0=b1_sp) 364 d2_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a2_sp, 365 concentration0=b2_sp) 366 367 if not special: 368 return 369 kl_expected = (special.betaln(a2, b2) - special.betaln(a1, b1) + 370 (a1 - a2) * special.digamma(a1) + 371 (b1 - b2) * special.digamma(b1) + 372 (a2 - a1 + b2 - b1) * special.digamma(a1 + b1)) 373 374 for dist1 in [d1, d1_sp]: 375 for dist2 in [d2, d2_sp]: 376 kl = kullback_leibler.kl_divergence(dist1, dist2) 377 kl_val = self.evaluate(kl) 378 self.assertEqual(kl.get_shape(), shape) 379 self.assertAllClose(kl_val, kl_expected) 380 381 # Make sure KL(d1||d1) is 0 382 kl_same = self.evaluate(kullback_leibler.kl_divergence(d1, d1)) 383 self.assertAllClose(kl_same, np.zeros_like(kl_expected)) 384 385 386if __name__ == "__main__": 387 test.main() 388