# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the Gamma distribution (`gamma_lib.Gamma`)."""

import importlib

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import gamma as gamma_lib
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def try_import(name):  # pylint: disable=invalid-name
  """Imports and returns the module `name`, or None if the import fails.

  An ImportError is logged as a warning instead of being raised. This lets
  the tests below skip their scipy-based reference checks when scipy is not
  installed.

  Args:
    name: Fully qualified module name, e.g. "scipy.special".

  Returns:
    The imported module, or None.
  """
  module = None
  try:
    module = importlib.import_module(name)
  except ImportError as e:
    # Lazy logging args: the message is only formatted if it is emitted.
    tf_logging.warning("Could not import %s: %s", name, e)
  return module


special = try_import("scipy.special")
stats = try_import("scipy.stats")


@test_util.run_all_in_graph_and_eager_modes
class GammaTest(test.TestCase):
  """Tests shapes, densities, moments, sampling and KL of `Gamma`.

  Wherever scipy is used as a reference, note the parameterization:
  `Gamma(concentration=a, rate=b)` corresponds to
  `scipy.stats.gamma(a, scale=1 / b)`.
  """

  def testGammaShape(self):
    """Batch shape follows the broadcast parameters; event shape is scalar."""
    alpha = constant_op.constant([3.0] * 5)
    beta = constant_op.constant(11.0)
    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)

    self.assertEqual(self.evaluate(gamma.batch_shape_tensor()), (5,))
    self.assertEqual(gamma.batch_shape, tensor_shape.TensorShape([5]))
    self.assertAllEqual(self.evaluate(gamma.event_shape_tensor()), [])
    self.assertEqual(gamma.event_shape, tensor_shape.TensorShape([]))

  def testGammaLogPDF(self):
    """log_prob and prob match scipy for a 1-D batch."""
    batch_size = 6
    alpha = constant_op.constant([2.0] * batch_size)
    beta = constant_op.constant([3.0] * batch_size)
    alpha_v = 2.0
    beta_v = 3.0
    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
    log_pdf = gamma.log_prob(x)
    self.assertEqual(log_pdf.get_shape(), (6,))
    pdf = gamma.prob(x)
    self.assertEqual(pdf.get_shape(), (6,))
    if not stats:
      return
    expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
    self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
    self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))

  def testGammaLogPDFBoundary(self):
    """log_prob at x = 0 is finite when concentration = 1."""
    # When concentration = 1, we have an exponential distribution. Check that at
    # 0 we have finite log prob.
    rate = np.array([0.1, 0.5, 1., 2., 5., 10.], dtype=np.float32)
    gamma = gamma_lib.Gamma(concentration=1., rate=rate)
    log_pdf = gamma.log_prob(0.)
    # The Exponential(rate) density at 0 is `rate`, so its log is log(rate).
    self.assertAllClose(np.log(rate), self.evaluate(log_pdf))

  def testGammaLogPDFMultidimensional(self):
    """log_prob and prob match scipy when x (6x1) broadcasts to a 6x2 batch."""
    batch_size = 6
    alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
    beta = constant_op.constant([[3.0, 4.0]] * batch_size)
    alpha_v = np.array([2.0, 4.0])
    beta_v = np.array([3.0, 4.0])
    x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
    log_pdf = gamma.log_prob(x)
    log_pdf_values = self.evaluate(log_pdf)
    self.assertEqual(log_pdf.get_shape(), (6, 2))
    pdf = gamma.prob(x)
    pdf_values = self.evaluate(pdf)
    self.assertEqual(pdf.get_shape(), (6, 2))
    if not stats:
      return
    expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
    self.assertAllClose(log_pdf_values, expected_log_pdf)
    self.assertAllClose(pdf_values, np.exp(expected_log_pdf))

  def testGammaLogPDFMultidimensionalBroadcasting(self):
    """log_prob and prob match scipy when a scalar rate broadcasts."""
    batch_size = 6
    alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
    beta = constant_op.constant(3.0)
    alpha_v = np.array([2.0, 4.0])
    beta_v = 3.0
    x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
    log_pdf = gamma.log_prob(x)
    log_pdf_values = self.evaluate(log_pdf)
    self.assertEqual(log_pdf.get_shape(), (6, 2))
    pdf = gamma.prob(x)
    pdf_values = self.evaluate(pdf)
    self.assertEqual(pdf.get_shape(), (6, 2))

    if not stats:
      return
    expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
    self.assertAllClose(log_pdf_values, expected_log_pdf)
    self.assertAllClose(pdf_values, np.exp(expected_log_pdf))

  def testGammaCDF(self):
    """cdf matches scipy for a 1-D batch."""
    batch_size = 6
    alpha = constant_op.constant([2.0] * batch_size)
    beta = constant_op.constant([3.0] * batch_size)
    alpha_v = 2.0
    beta_v = 3.0
    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)

    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
    cdf = gamma.cdf(x)
    self.assertEqual(cdf.get_shape(), (6,))
    if not stats:
      return
    expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v)
    self.assertAllClose(self.evaluate(cdf), expected_cdf)

  def testGammaMean(self):
    """mean matches scipy."""
    alpha_v = np.array([1.0, 3.0, 2.5])
    beta_v = np.array([1.0, 4.0, 5.0])
    gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
    self.assertEqual(gamma.mean().get_shape(), (3,))
    if not stats:
      return
    expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v)
    self.assertAllClose(self.evaluate(gamma.mean()), expected_means)

  def testGammaModeAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
    """With allow_nan_stats=False, mode is (alpha - 1) / beta when alpha > 1."""
    alpha_v = np.array([5.5, 3.0, 2.5])
    beta_v = np.array([1.0, 4.0, 5.0])
    # allow_nan_stats=False must be explicit: this is the code path the test
    # name refers to, and it differs from the default NaN-producing path.
    gamma = gamma_lib.Gamma(
        concentration=alpha_v, rate=beta_v, allow_nan_stats=False)
    expected_modes = (alpha_v - 1) / beta_v
    self.assertEqual(gamma.mode().get_shape(), (3,))
    self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)

  def testGammaModeAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
    """With allow_nan_stats=False, mode raises if any alpha is <= 1."""
    # Mode will not be defined for the first entry.
    alpha_v = np.array([0.5, 3.0, 2.5])
    beta_v = np.array([1.0, 4.0, 5.0])
    gamma = gamma_lib.Gamma(
        concentration=alpha_v, rate=beta_v, allow_nan_stats=False)
    with self.assertRaisesOpError("x < y"):
      self.evaluate(gamma.mode())

  def testGammaModeAllowNanStatsIsTrueReturnsNaNforUndefinedBatchMembers(self):
    """With allow_nan_stats=True, undefined modes are reported as NaN."""
    # Mode will not be defined for the first entry.
    alpha_v = np.array([0.5, 3.0, 2.5])
    beta_v = np.array([1.0, 4.0, 5.0])
    gamma = gamma_lib.Gamma(
        concentration=alpha_v, rate=beta_v, allow_nan_stats=True)
    expected_modes = (alpha_v - 1) / beta_v
    expected_modes[0] = np.nan
    self.assertEqual(gamma.mode().get_shape(), (3,))
    self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)

  def testGammaVariance(self):
    """variance matches scipy."""
    alpha_v = np.array([1.0, 3.0, 2.5])
    beta_v = np.array([1.0, 4.0, 5.0])
    gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
    self.assertEqual(gamma.variance().get_shape(), (3,))
    if not stats:
      return
    expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v)
    self.assertAllClose(self.evaluate(gamma.variance()), expected_variances)

  def testGammaStd(self):
    """stddev matches scipy."""
    alpha_v = np.array([1.0, 3.0, 2.5])
    beta_v = np.array([1.0, 4.0, 5.0])
    gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
    self.assertEqual(gamma.stddev().get_shape(), (3,))
    if not stats:
      return
    expected_stddev = stats.gamma.std(alpha_v, scale=1. / beta_v)
    self.assertAllClose(self.evaluate(gamma.stddev()), expected_stddev)

  def testGammaEntropy(self):
    """entropy matches scipy."""
    alpha_v = np.array([1.0, 3.0, 2.5])
    beta_v = np.array([1.0, 4.0, 5.0])
    gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
    self.assertEqual(gamma.entropy().get_shape(), (3,))
    if not stats:
      return
    expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v)
    self.assertAllClose(self.evaluate(gamma.entropy()), expected_entropy)

  def testGammaSampleSmallAlpha(self):
    """Samples with a small concentration (0.05) pass KS and moment checks."""
    alpha_v = 0.05
    beta_v = 1.0
    alpha = constant_op.constant(alpha_v)
    beta = constant_op.constant(beta_v)
    n = 100000
    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
    samples = gamma.sample(n, seed=137)
    sample_values = self.evaluate(samples)
    self.assertEqual(samples.get_shape(), (n,))
    self.assertEqual(sample_values.shape, (n,))
    self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
    if not stats:
      return
    self.assertAllClose(
        sample_values.mean(),
        stats.gamma.mean(alpha_v, scale=1 / beta_v),
        atol=.01)
    self.assertAllClose(
        sample_values.var(),
        stats.gamma.var(alpha_v, scale=1 / beta_v),
        atol=.15)

  def testGammaSample(self):
    """Samples pass KS and sample-moment checks against scipy."""
    alpha_v = 4.0
    beta_v = 3.0
    alpha = constant_op.constant(alpha_v)
    beta = constant_op.constant(beta_v)
    n = 100000
    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
    samples = gamma.sample(n, seed=137)
    sample_values = self.evaluate(samples)
    self.assertEqual(samples.get_shape(), (n,))
    self.assertEqual(sample_values.shape, (n,))
    self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
    if not stats:
      return
    self.assertAllClose(
        sample_values.mean(),
        stats.gamma.mean(alpha_v, scale=1 / beta_v),
        atol=.01)
    self.assertAllClose(
        sample_values.var(),
        stats.gamma.var(alpha_v, scale=1 / beta_v),
        atol=.15)

  def testGammaFullyReparameterized(self):
    """Gradients of samples exist w.r.t. both concentration and rate."""
    alpha = constant_op.constant(4.0)
    beta = constant_op.constant(3.0)
    with backprop.GradientTape() as tape:
      tape.watch(alpha)
      tape.watch(beta)
      gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
      samples = gamma.sample(100)
    grad_alpha, grad_beta = tape.gradient(samples, [alpha, beta])
    self.assertIsNotNone(grad_alpha)
    self.assertIsNotNone(grad_beta)

  def testGammaSampleMultiDimensional(self):
    """Samples from a broadcast 10x100 batch match scipy moments and pass KS."""
    alpha_v = np.array([np.arange(1, 101, dtype=np.float32)])  # 1 x 100
    beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T  # 10 x 1
    gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
    n = 10000
    samples = gamma.sample(n, seed=137)
    sample_values = self.evaluate(samples)
    self.assertEqual(samples.get_shape(), (n, 10, 100))
    self.assertEqual(sample_values.shape, (n, 10, 100))
    zeros = np.zeros_like(alpha_v + beta_v)  # 10 x 100
    alpha_bc = alpha_v + zeros
    beta_bc = beta_v + zeros
    if not stats:
      return
    self.assertAllClose(
        sample_values.mean(axis=0),
        stats.gamma.mean(alpha_bc, scale=1 / beta_bc),
        atol=0.,
        rtol=.05)
    self.assertAllClose(
        sample_values.var(axis=0),
        stats.gamma.var(alpha_bc, scale=1 / beta_bc),
        atol=10.0,
        rtol=0.)
    # Run a KS test per (rate, concentration) pair. Allow a small fraction of
    # individual failures, since some are expected by chance across 1000 tests.
    fails = 0
    trials = 0
    for ai, a in enumerate(np.reshape(alpha_v, [-1])):
      for bi, b in enumerate(np.reshape(beta_v, [-1])):
        s = sample_values[:, bi, ai]
        trials += 1
        fails += 0 if self._kstest(a, b, s) else 1
    self.assertLess(fails, trials * 0.03)

  def _kstest(self, alpha, beta, samples):
    """Runs a Kolmogorov-Smirnov goodness-of-fit test against Gamma(alpha, beta).

    Args:
      alpha: Concentration of the reference distribution.
      beta: Rate of the reference distribution.
      samples: 1-D array of samples to test.

    Returns:
      True if the KS statistic is below 0.02, or if scipy is unavailable.
    """
    # Uses the Kolmogorov-Smirnov test for goodness of fit.
    if not stats:
      return True  # If we can't test, return that the test passes.
    ks, _ = stats.kstest(samples, stats.gamma(alpha, scale=1 / beta).cdf)
    # Return True when the test passes.
    return ks < 0.02

  def testGammaPdfOfSampleMultiDims(self):
    """The pdf evaluated on samples integrates to ~1 for each batch member."""
    gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]])
    num = 50000
    samples = gamma.sample(num, seed=137)
    pdfs = gamma.prob(samples)
    sample_vals, pdf_vals = self.evaluate([samples, pdfs])
    self.assertEqual(samples.get_shape(), (num, 2, 2))
    self.assertEqual(pdfs.get_shape(), (num, 2, 2))
    self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
    self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
    self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
    self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
    if not stats:
      return
    self.assertAllClose(
        stats.gamma.mean([[7., 11.], [7., 11.]],
                         scale=1 / np.array([[5., 5.], [6., 6.]])),
        sample_vals.mean(axis=0),
        atol=.1)
    self.assertAllClose(
        stats.gamma.var([[7., 11.], [7., 11.]],
                        scale=1 / np.array([[5., 5.], [6., 6.]])),
        sample_vals.var(axis=0),
        atol=.1)

  def _assertIntegral(self, sample_vals, pdf_vals, err=1e-3):
    """Asserts that the pdf, integrated over the samples, is within err of 1.

    The integral is estimated with the trapezoid rule over the
    (sample, pdf) pairs sorted by sample value. The first trapezoid starts at
    the point (0, 0).

    Args:
      sample_vals: 1-D array of sample locations.
      pdf_vals: 1-D array of pdf values at `sample_vals`.
      err: Absolute tolerance passed to `assertNear`.
    """
    s_p = zip(sample_vals, pdf_vals)
    prev = (0, 0)
    total = 0
    for k in sorted(s_p, key=lambda x: x[0]):
      # Trapezoid between the previous point and the current one.
      pair_pdf = (k[1] + prev[1]) / 2
      total += (k[0] - prev[0]) * pair_pdf
      prev = k
    self.assertNear(1., total, err=err)

  def testGammaNonPositiveInitializationParamsRaises(self):
    """With validate_args=True, a zero concentration or rate raises."""
    alpha_v = constant_op.constant(0.0, name="alpha")
    beta_v = constant_op.constant(1.0, name="beta")
    with self.assertRaisesOpError("x > 0"):
      gamma = gamma_lib.Gamma(
          concentration=alpha_v, rate=beta_v, validate_args=True)
      self.evaluate(gamma.mean())
    alpha_v = constant_op.constant(1.0, name="alpha")
    beta_v = constant_op.constant(0.0, name="beta")
    with self.assertRaisesOpError("x > 0"):
      gamma = gamma_lib.Gamma(
          concentration=alpha_v, rate=beta_v, validate_args=True)
      self.evaluate(gamma.mean())

  def testGammaWithSoftplusConcentrationRate(self):
    """The softplus variant stores softplus(input) as concentration and rate."""
    alpha_v = constant_op.constant([0.0, -2.1], name="alpha")
    beta_v = constant_op.constant([1.0, -3.6], name="beta")
    gamma = gamma_lib.GammaWithSoftplusConcentrationRate(
        concentration=alpha_v, rate=beta_v)
    self.assertAllEqual(
        self.evaluate(nn_ops.softplus(alpha_v)),
        self.evaluate(gamma.concentration))
    self.assertAllEqual(
        self.evaluate(nn_ops.softplus(beta_v)), self.evaluate(gamma.rate))

  def testGammaGammaKL(self):
    """Analytic KL(g0 || g1) matches the closed form and a Monte Carlo estimate."""
    alpha0 = np.array([3.])
    beta0 = np.array([1., 2., 3., 1.5, 2.5, 3.5])

    alpha1 = np.array([0.4])
    beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.])

    # Build graph.
    g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0)
    g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1)
    x = g0.sample(int(1e4), seed=0)
    # Monte Carlo estimate: E_{x ~ g0}[log g0(x) - log g1(x)].
    kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0)
    kl_actual = kullback_leibler.kl_divergence(g0, g1)

    # Execute graph.
    [kl_sample_, kl_actual_] = self.evaluate([kl_sample, kl_actual])

    self.assertEqual(beta0.shape, kl_actual.get_shape())

    if not special:
      return
    # Closed form:
    #   (a0 - a1) digamma(a0) + lgamma(a1) - lgamma(a0)
    #   + a1 (log b0 - log b1) + a0 (b1 / b0 - 1).
    kl_expected = ((alpha0 - alpha1) * special.digamma(alpha0)
                   + special.gammaln(alpha1)
                   - special.gammaln(alpha0)
                   + alpha1 * np.log(beta0)
                   - alpha1 * np.log(beta1)
                   + alpha0 * (beta1 / beta0 - 1.))

    self.assertAllClose(kl_expected, kl_actual_, atol=0., rtol=1e-6)
    self.assertAllClose(kl_sample_, kl_actual_, atol=0., rtol=1e-1)


if __name__ == "__main__":
  test.main()