• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16import importlib
17
18import numpy as np
19
20from tensorflow.python.eager import backprop
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import tensor_shape
23from tensorflow.python.framework import test_util
24from tensorflow.python.ops import math_ops
25from tensorflow.python.ops import nn_ops
26from tensorflow.python.ops.distributions import gamma as gamma_lib
27from tensorflow.python.ops.distributions import kullback_leibler
28from tensorflow.python.platform import test
29from tensorflow.python.platform import tf_logging
30
31
def try_import(name):  # pylint: disable=invalid-name
  """Imports a module by name, returning None instead of raising on failure.

  Lets optional test dependencies (such as SciPy) be absent: a failed import
  is logged as a warning and the caller receives None.

  Args:
    name: Fully qualified module name, e.g. "scipy.stats".

  Returns:
    The imported module object, or None if the import raised ImportError.
  """
  try:
    return importlib.import_module(name)
  except ImportError as e:
    # Pass the values as lazy %-style arguments so the logger, not this call
    # site, formats the message (and only when the record is emitted).
    tf_logging.warning("Could not import %s: %s", name, e)
    return None
39
40
# SciPy is optional for these tests: each name is bound to the imported module,
# or to None when the import fails (try_import logs a warning in that case).
special, stats = (try_import(mod_name)
                  for mod_name in ("scipy.special", "scipy.stats"))
43
44
@test_util.run_all_in_graph_and_eager_modes
class GammaTest(test.TestCase):
  """Tests for `gamma_lib.Gamma`, run in both graph and eager modes.

  Reference values come from SciPy when it is importable. When it is not,
  `stats`/`special` are None: tests return right after their TensorFlow-only
  checks, and `_kstest` reports success without testing anything.
  """

  def testGammaShape(self):
    """A length-5 concentration with a scalar rate gives batch shape [5]."""
    alpha = constant_op.constant([3.0] * 5)
    beta = constant_op.constant(11.0)
    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)

    self.assertEqual(self.evaluate(gamma.batch_shape_tensor()), (5,))
    self.assertEqual(gamma.batch_shape, tensor_shape.TensorShape([5]))
    self.assertAllEqual(self.evaluate(gamma.event_shape_tensor()), [])
    self.assertEqual(gamma.event_shape, tensor_shape.TensorShape([]))

  def testGammaLogPDF(self):
    """log_prob/prob match scipy.stats.gamma for a 1-D batch."""
    batch_size = 6
    alpha = constant_op.constant([2.0] * batch_size)
    beta = constant_op.constant([3.0] * batch_size)
    alpha_v = 2.0
    beta_v = 3.0
    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
    log_pdf = gamma.log_prob(x)
    self.assertEqual(log_pdf.get_shape(), (6,))
    pdf = gamma.prob(x)
    self.assertEqual(pdf.get_shape(), (6,))
    if not stats:
      return
    # SciPy parameterizes by scale, which is the reciprocal of TF's rate.
    expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
    self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
    self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))

  def testGammaLogPDFBoundary(self):
    """With concentration 1 (exponential), log_prob(0) is finite: log(rate)."""
    rate = np.array([0.1, 0.5, 1., 2., 5., 10.], dtype=np.float32)
    gamma = gamma_lib.Gamma(concentration=1., rate=rate)
    log_pdf = gamma.log_prob(0.)
    self.assertAllClose(np.log(rate), self.evaluate(log_pdf))

  def testGammaLogPDFMultidimensional(self):
    """A (6, 1) x broadcasts against a (6, 2) batch of parameters."""
    batch_size = 6
    alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
    beta = constant_op.constant([[3.0, 4.0]] * batch_size)
    alpha_v = np.array([2.0, 4.0])
    beta_v = np.array([3.0, 4.0])
    x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
    log_pdf = gamma.log_prob(x)
    log_pdf_values = self.evaluate(log_pdf)
    self.assertEqual(log_pdf.get_shape(), (6, 2))
    pdf = gamma.prob(x)
    pdf_values = self.evaluate(pdf)
    self.assertEqual(pdf.get_shape(), (6, 2))
    if not stats:
      return
    expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
    self.assertAllClose(log_pdf_values, expected_log_pdf)
    self.assertAllClose(pdf_values, np.exp(expected_log_pdf))

  def testGammaLogPDFMultidimensionalBroadcasting(self):
    """Like the multidimensional test, but with a scalar rate."""
    batch_size = 6
    alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
    beta = constant_op.constant(3.0)
    alpha_v = np.array([2.0, 4.0])
    beta_v = 3.0
    x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
    log_pdf = gamma.log_prob(x)
    log_pdf_values = self.evaluate(log_pdf)
    self.assertEqual(log_pdf.get_shape(), (6, 2))
    pdf = gamma.prob(x)
    pdf_values = self.evaluate(pdf)
    self.assertEqual(pdf.get_shape(), (6, 2))

    if not stats:
      return
    expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
    self.assertAllClose(log_pdf_values, expected_log_pdf)
    self.assertAllClose(pdf_values, np.exp(expected_log_pdf))

  def testGammaCDF(self):
    """cdf matches scipy.stats.gamma.cdf for a 1-D batch."""
    batch_size = 6
    alpha = constant_op.constant([2.0] * batch_size)
    beta = constant_op.constant([3.0] * batch_size)
    alpha_v = 2.0
    beta_v = 3.0
    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)

    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
    cdf = gamma.cdf(x)
    self.assertEqual(cdf.get_shape(), (6,))
    if not stats:
      return
    expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v)
    self.assertAllClose(self.evaluate(cdf), expected_cdf)

  def testGammaMean(self):
    """mean() matches scipy.stats.gamma.mean."""
    alpha_v = np.array([1.0, 3.0, 2.5])
    beta_v = np.array([1.0, 4.0, 5.0])
    gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
    self.assertEqual(gamma.mean().get_shape(), (3,))
    if not stats:
      return
    expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v)
    self.assertAllClose(self.evaluate(gamma.mean()), expected_means)

  def testGammaModeAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
    """With allow_nan_stats=False, mode() succeeds if every concentration > 1."""
    alpha_v = np.array([5.5, 3.0, 2.5])
    beta_v = np.array([1.0, 4.0, 5.0])
    # allow_nan_stats=False is what this test is named for; without it the
    # default code path would be exercised instead.
    gamma = gamma_lib.Gamma(
        concentration=alpha_v, rate=beta_v, allow_nan_stats=False)
    expected_modes = (alpha_v - 1) / beta_v
    self.assertEqual(gamma.mode().get_shape(), (3,))
    self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)

  def testGammaModeAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
    """With allow_nan_stats=False, an undefined mode raises an op error."""
    # Mode will not be defined for the first entry.
    alpha_v = np.array([0.5, 3.0, 2.5])
    beta_v = np.array([1.0, 4.0, 5.0])
    gamma = gamma_lib.Gamma(
        concentration=alpha_v, rate=beta_v, allow_nan_stats=False)
    with self.assertRaisesOpError("x < y"):
      self.evaluate(gamma.mode())

  def testGammaModeAllowNanStatsIsTrueReturnsNaNforUndefinedBatchMembers(self):
    """With allow_nan_stats=True, an undefined mode is reported as NaN."""
    # Mode will not be defined for the first entry.
    alpha_v = np.array([0.5, 3.0, 2.5])
    beta_v = np.array([1.0, 4.0, 5.0])
    gamma = gamma_lib.Gamma(
        concentration=alpha_v, rate=beta_v, allow_nan_stats=True)
    expected_modes = (alpha_v - 1) / beta_v
    expected_modes[0] = np.nan
    self.assertEqual(gamma.mode().get_shape(), (3,))
    self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)

  def testGammaVariance(self):
    """variance() matches scipy.stats.gamma.var."""
    alpha_v = np.array([1.0, 3.0, 2.5])
    beta_v = np.array([1.0, 4.0, 5.0])
    gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
    self.assertEqual(gamma.variance().get_shape(), (3,))
    if not stats:
      return
    expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v)
    self.assertAllClose(self.evaluate(gamma.variance()), expected_variances)

  def testGammaStd(self):
    """stddev() matches scipy.stats.gamma.std."""
    alpha_v = np.array([1.0, 3.0, 2.5])
    beta_v = np.array([1.0, 4.0, 5.0])
    gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
    self.assertEqual(gamma.stddev().get_shape(), (3,))
    if not stats:
      return
    expected_stddev = stats.gamma.std(alpha_v, scale=1. / beta_v)
    self.assertAllClose(self.evaluate(gamma.stddev()), expected_stddev)

  def testGammaEntropy(self):
    """entropy() matches scipy.stats.gamma.entropy."""
    alpha_v = np.array([1.0, 3.0, 2.5])
    beta_v = np.array([1.0, 4.0, 5.0])
    gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
    self.assertEqual(gamma.entropy().get_shape(), (3,))
    if not stats:
      return
    expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v)
    self.assertAllClose(self.evaluate(gamma.entropy()), expected_entropy)

  def testGammaSampleSmallAlpha(self):
    """Sampling is accurate for a very small concentration (0.05)."""
    self._assertScalarGammaSamples(alpha_v=0.05, beta_v=1.0)

  def testGammaSample(self):
    """Sampling is accurate for a moderate concentration and rate."""
    self._assertScalarGammaSamples(alpha_v=4.0, beta_v=3.0)

  def _assertScalarGammaSamples(self, alpha_v, beta_v, n=100000):
    """Draws `n` samples from a scalar Gamma and checks their distribution.

    Checks the static and evaluated sample shapes and a KS goodness-of-fit
    test. When SciPy is available, it also checks the sample mean
    (atol=.01) and variance (atol=.15) against the analytic values.

    Args:
      alpha_v: Python float concentration.
      beta_v: Python float rate.
      n: Number of samples to draw (seeded with 137).
    """
    alpha = constant_op.constant(alpha_v)
    beta = constant_op.constant(beta_v)
    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
    samples = gamma.sample(n, seed=137)
    sample_values = self.evaluate(samples)
    self.assertEqual(samples.get_shape(), (n,))
    self.assertEqual(sample_values.shape, (n,))
    self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
    if not stats:
      return
    self.assertAllClose(
        sample_values.mean(),
        stats.gamma.mean(alpha_v, scale=1 / beta_v),
        atol=.01)
    self.assertAllClose(
        sample_values.var(),
        stats.gamma.var(alpha_v, scale=1 / beta_v),
        atol=.15)

  def testGammaFullyReparameterized(self):
    """Samples are differentiable w.r.t. both concentration and rate."""
    alpha = constant_op.constant(4.0)
    beta = constant_op.constant(3.0)
    with backprop.GradientTape() as tape:
      tape.watch(alpha)
      tape.watch(beta)
      gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
      samples = gamma.sample(100)
    grad_alpha, grad_beta = tape.gradient(samples, [alpha, beta])
    self.assertIsNotNone(grad_alpha)
    self.assertIsNotNone(grad_beta)

  def testGammaSampleMultiDimensional(self):
    """Sampling from a broadcast (10, 100) batch matches SciPy moments/KS."""
    alpha_v = np.array([np.arange(1, 101, dtype=np.float32)])  # 1 x 100
    beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T  # 10 x 1
    gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
    n = 10000
    samples = gamma.sample(n, seed=137)
    sample_values = self.evaluate(samples)
    self.assertEqual(samples.get_shape(), (n, 10, 100))
    self.assertEqual(sample_values.shape, (n, 10, 100))
    zeros = np.zeros_like(alpha_v + beta_v)  # 10 x 100
    alpha_bc = alpha_v + zeros
    beta_bc = beta_v + zeros
    if not stats:
      return
    self.assertAllClose(
        sample_values.mean(axis=0),
        stats.gamma.mean(alpha_bc, scale=1 / beta_bc),
        atol=0.,
        rtol=.05)
    self.assertAllClose(
        sample_values.var(axis=0),
        stats.gamma.var(alpha_bc, scale=1 / beta_bc),
        atol=10.0,
        rtol=0.)
    # Run a KS test on each (rate, concentration) cell. A small fraction of
    # failures is expected by chance, so only require that fewer than 3% fail.
    fails = sum(
        0 if self._kstest(a, b, sample_values[:, bi, ai]) else 1
        for ai, a in enumerate(np.reshape(alpha_v, [-1]))
        for bi, b in enumerate(np.reshape(beta_v, [-1])))
    trials = alpha_v.size * beta_v.size
    self.assertLess(fails, trials * 0.03)

  def _kstest(self, alpha, beta, samples):
    """Kolmogorov-Smirnov goodness-of-fit test against Gamma(alpha, beta).

    Args:
      alpha: Concentration of the reference distribution.
      beta: Rate of the reference distribution (SciPy scale is 1 / beta).
      samples: 1-D array of samples to test.

    Returns:
      True when the KS statistic is below 0.02, or when SciPy is unavailable
      and the test cannot be run.
    """
    if not stats:
      return True  # If we can't test, return that the test passes.
    ks, _ = stats.kstest(samples, stats.gamma(alpha, scale=1 / beta).cdf)
    return ks < 0.02

  def testGammaPdfOfSampleMultiDims(self):
    """prob() of samples integrates to ~1 for each member of a 2x2 batch."""
    gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]])
    num = 50000
    samples = gamma.sample(num, seed=137)
    pdfs = gamma.prob(samples)
    sample_vals, pdf_vals = self.evaluate([samples, pdfs])
    self.assertEqual(samples.get_shape(), (num, 2, 2))
    self.assertEqual(pdfs.get_shape(), (num, 2, 2))
    for i in range(2):
      for j in range(2):
        self._assertIntegral(sample_vals[:, i, j], pdf_vals[:, i, j], err=0.02)
    if not stats:
      return
    self.assertAllClose(
        stats.gamma.mean([[7., 11.], [7., 11.]],
                         scale=1 / np.array([[5., 5.], [6., 6.]])),
        sample_vals.mean(axis=0),
        atol=.1)
    self.assertAllClose(
        stats.gamma.var([[7., 11.], [7., 11.]],
                        scale=1 / np.array([[5., 5.], [6., 6.]])),
        sample_vals.var(axis=0),
        atol=.1)

  def _assertIntegral(self, sample_vals, pdf_vals, err=1e-3):
    """Asserts that the pdf, evaluated at the samples, integrates to ~1.

    Sorts the (sample, pdf) pairs by sample value and applies the trapezoid
    rule, starting from the anchor point (0, 0). That anchor assumes the
    support begins at 0 with density ~0 there, which holds for the
    concentrations (> 1) used by the caller in this file.

    Args:
      sample_vals: 1-D array of sample locations.
      pdf_vals: 1-D array of density values at `sample_vals`.
      err: Allowed absolute deviation of the integral from 1.
    """
    prev = (0, 0)
    total = 0
    for point in sorted(zip(sample_vals, pdf_vals), key=lambda x: x[0]):
      pair_pdf = (point[1] + prev[1]) / 2
      total += (point[0] - prev[0]) * pair_pdf
      prev = point
    self.assertNear(1., total, err=err)

  def testGammaNonPositiveInitializationParamsRaises(self):
    """validate_args=True rejects a zero concentration or a zero rate."""
    alpha_v = constant_op.constant(0.0, name="alpha")
    beta_v = constant_op.constant(1.0, name="beta")
    with self.assertRaisesOpError("x > 0"):
      gamma = gamma_lib.Gamma(
          concentration=alpha_v, rate=beta_v, validate_args=True)
      self.evaluate(gamma.mean())
    alpha_v = constant_op.constant(1.0, name="alpha")
    beta_v = constant_op.constant(0.0, name="beta")
    with self.assertRaisesOpError("x > 0"):
      gamma = gamma_lib.Gamma(
          concentration=alpha_v, rate=beta_v, validate_args=True)
      self.evaluate(gamma.mean())

  def testGammaWithSoftplusConcentrationRate(self):
    """The softplus variant stores softplus-transformed parameters."""
    alpha_v = constant_op.constant([0.0, -2.1], name="alpha")
    beta_v = constant_op.constant([1.0, -3.6], name="beta")
    gamma = gamma_lib.GammaWithSoftplusConcentrationRate(
        concentration=alpha_v, rate=beta_v)
    self.assertAllEqual(
        self.evaluate(nn_ops.softplus(alpha_v)),
        self.evaluate(gamma.concentration))
    self.assertAllEqual(
        self.evaluate(nn_ops.softplus(beta_v)), self.evaluate(gamma.rate))

  def testGammaGammaKL(self):
    """KL(g0 || g1) matches the closed form and a Monte Carlo estimate."""
    alpha0 = np.array([3.])
    beta0 = np.array([1., 2., 3., 1.5, 2.5, 3.5])

    alpha1 = np.array([0.4])
    beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.])

    # Build graph.
    g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0)
    g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1)
    x = g0.sample(int(1e4), seed=0)
    # Monte Carlo estimate: E_{x ~ g0}[log g0(x) - log g1(x)].
    kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0)
    kl_actual = kullback_leibler.kl_divergence(g0, g1)

    # Execute graph.
    [kl_sample_, kl_actual_] = self.evaluate([kl_sample, kl_actual])

    self.assertEqual(beta0.shape, kl_actual.get_shape())

    if not special:
      return
    # Closed-form KL between two Gamma distributions (rate parameterization).
    kl_expected = ((alpha0 - alpha1) * special.digamma(alpha0)
                   + special.gammaln(alpha1)
                   - special.gammaln(alpha0)
                   + alpha1 * np.log(beta0)
                   - alpha1 * np.log(beta1)
                   + alpha0 * (beta1 / beta0 - 1.))

    self.assertAllClose(kl_expected, kl_actual_, atol=0., rtol=1e-6)
    self.assertAllClose(kl_sample_, kl_actual_, atol=0., rtol=1e-1)
398
399
if __name__ == "__main__":
  # Run this module's tests under the TensorFlow test runner.
  test.main()
402