# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the Exponential distribution."""

import importlib

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import exponential as exponential_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


30def try_import(name):  # pylint: disable=invalid-name
31  module = None
32  try:
33    module = importlib.import_module(name)
34  except ImportError as e:
35    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
36  return module
37
38
39stats = try_import("scipy.stats")
40
41
42@test_util.run_all_in_graph_and_eager_modes
43class ExponentialTest(test.TestCase):
44
45  def testExponentialLogPDF(self):
46    batch_size = 6
47    lam = constant_op.constant([2.0] * batch_size)
48    lam_v = 2.0
49    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
50    exponential = exponential_lib.Exponential(rate=lam)
51
52    log_pdf = exponential.log_prob(x)
53    self.assertEqual(log_pdf.get_shape(), (6,))
54
55    pdf = exponential.prob(x)
56    self.assertEqual(pdf.get_shape(), (6,))
57
58    if not stats:
59      return
60    expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v)
61    self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
62    self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
63
64  def testExponentialLogPDFBoundary(self):
65    # Check that Log PDF is finite at 0.
66    rate = np.array([0.1, 0.5, 1., 2., 5., 10.], dtype=np.float32)
67    exponential = exponential_lib.Exponential(rate=rate)
68    log_pdf = exponential.log_prob(0.)
69    self.assertAllClose(np.log(rate), self.evaluate(log_pdf))
70
71  def testExponentialCDF(self):
72    batch_size = 6
73    lam = constant_op.constant([2.0] * batch_size)
74    lam_v = 2.0
75    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
76
77    exponential = exponential_lib.Exponential(rate=lam)
78
79    cdf = exponential.cdf(x)
80    self.assertEqual(cdf.get_shape(), (6,))
81
82    if not stats:
83      return
84    expected_cdf = stats.expon.cdf(x, scale=1 / lam_v)
85    self.assertAllClose(self.evaluate(cdf), expected_cdf)
86
87  def testExponentialLogSurvival(self):
88    batch_size = 7
89    lam = constant_op.constant([2.0] * batch_size)
90    lam_v = 2.0
91    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0, 10.0], dtype=np.float32)
92
93    exponential = exponential_lib.Exponential(rate=lam)
94
95    log_survival = exponential.log_survival_function(x)
96    self.assertEqual(log_survival.get_shape(), (7,))
97
98    if not stats:
99      return
100    expected_log_survival = stats.expon.logsf(x, scale=1 / lam_v)
101    self.assertAllClose(self.evaluate(log_survival), expected_log_survival)
102
103  def testExponentialMean(self):
104    lam_v = np.array([1.0, 4.0, 2.5])
105    exponential = exponential_lib.Exponential(rate=lam_v)
106    self.assertEqual(exponential.mean().get_shape(), (3,))
107    if not stats:
108      return
109    expected_mean = stats.expon.mean(scale=1 / lam_v)
110    self.assertAllClose(self.evaluate(exponential.mean()), expected_mean)
111
112  def testExponentialVariance(self):
113    lam_v = np.array([1.0, 4.0, 2.5])
114    exponential = exponential_lib.Exponential(rate=lam_v)
115    self.assertEqual(exponential.variance().get_shape(), (3,))
116    if not stats:
117      return
118    expected_variance = stats.expon.var(scale=1 / lam_v)
119    self.assertAllClose(
120        self.evaluate(exponential.variance()), expected_variance)
121
122  def testExponentialEntropy(self):
123    lam_v = np.array([1.0, 4.0, 2.5])
124    exponential = exponential_lib.Exponential(rate=lam_v)
125    self.assertEqual(exponential.entropy().get_shape(), (3,))
126    if not stats:
127      return
128    expected_entropy = stats.expon.entropy(scale=1 / lam_v)
129    self.assertAllClose(self.evaluate(exponential.entropy()), expected_entropy)
130
131  def testExponentialSample(self):
132    lam = constant_op.constant([3.0, 4.0])
133    lam_v = [3.0, 4.0]
134    n = constant_op.constant(100000)
135    exponential = exponential_lib.Exponential(rate=lam)
136
137    samples = exponential.sample(n, seed=137)
138    sample_values = self.evaluate(samples)
139    self.assertEqual(sample_values.shape, (100000, 2))
140    self.assertFalse(np.any(sample_values < 0.0))
141    if not stats:
142      return
143    for i in range(2):
144      self.assertLess(
145          stats.kstest(sample_values[:, i],
146                       stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)
147
148  def testExponentialSampleMultiDimensional(self):
149    batch_size = 2
150    lam_v = [3.0, 22.0]
151    lam = constant_op.constant([lam_v] * batch_size)
152
153    exponential = exponential_lib.Exponential(rate=lam)
154
155    n = 100000
156    samples = exponential.sample(n, seed=138)
157    self.assertEqual(samples.get_shape(), (n, batch_size, 2))
158
159    sample_values = self.evaluate(samples)
160
161    self.assertFalse(np.any(sample_values < 0.0))
162    if not stats:
163      return
164    for i in range(2):
165      self.assertLess(
166          stats.kstest(sample_values[:, 0, i],
167                       stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)
168      self.assertLess(
169          stats.kstest(sample_values[:, 1, i],
170                       stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)
171
172  def testFullyReparameterized(self):
173    lam = constant_op.constant([0.1, 1.0])
174    with backprop.GradientTape() as tape:
175      tape.watch(lam)
176      exponential = exponential_lib.Exponential(rate=lam)
177      samples = exponential.sample(100)
178    grad_lam = tape.gradient(samples, lam)
179    self.assertIsNotNone(grad_lam)
180
181  def testExponentialWithSoftplusRate(self):
182    lam = [-2.2, -3.4]
183    exponential = exponential_lib.ExponentialWithSoftplusRate(rate=lam)
184    self.assertAllClose(
185        self.evaluate(nn_ops.softplus(lam)), self.evaluate(exponential.rate))
186
187
188if __name__ == "__main__":
189  test.main()
190