1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for the Exponential distribution."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import importlib
22
23import numpy as np
24
25from tensorflow.python.eager import backprop
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import test_util
28from tensorflow.python.ops import nn_ops
29from tensorflow.python.ops.distributions import exponential as exponential_lib
30from tensorflow.python.platform import test
31from tensorflow.python.platform import tf_logging
32
33
def try_import(name):  # pylint: disable=invalid-name
  """Return the module named `name`, or None if it cannot be imported.

  Import failures are logged as warnings rather than raised, so callers
  can degrade gracefully when an optional dependency is missing.
  """
  try:
    return importlib.import_module(name)
  except ImportError as e:
    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
    return None
41
42
# Optional dependency: module handle for scipy.stats, or None when scipy
# is not installed.  Tests below skip their scipy-based numerical
# comparisons when this is None.
stats = try_import("scipy.stats")
44
45
@test_util.run_all_in_graph_and_eager_modes
class ExponentialTest(test.TestCase):
  """Tests for the Exponential distribution (rate parameterization).

  Where scipy is available, numerical results are cross-checked against
  `scipy.stats.expon`.  Note that scipy parameterizes the exponential by
  `scale`, the reciprocal of the `rate` used by the TensorFlow class.
  """

  def testExponentialLogPDF(self):
    """log_prob/prob have the expected shape and agree with scipy."""
    batch_size = 6
    lam = constant_op.constant([2.0] * batch_size)
    lam_v = 2.0
    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
    exponential = exponential_lib.Exponential(rate=lam)

    log_pdf = exponential.log_prob(x)
    self.assertEqual(log_pdf.get_shape(), (6,))

    pdf = exponential.prob(x)
    self.assertEqual(pdf.get_shape(), (6,))

    # The numerical comparison requires scipy; skip it when unavailable.
    if not stats:
      return
    expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v)
    self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
    self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))

  def testExponentialLogPDFBoundary(self):
    """log_prob is finite (equal to log(rate)) at the boundary x = 0."""
    # Check that Log PDF is finite at 0.
    rate = np.array([0.1, 0.5, 1., 2., 5., 10.], dtype=np.float32)
    exponential = exponential_lib.Exponential(rate=rate)
    log_pdf = exponential.log_prob(0.)
    self.assertAllClose(np.log(rate), self.evaluate(log_pdf))

  def testExponentialCDF(self):
    """cdf has the expected shape and agrees with scipy."""
    batch_size = 6
    lam = constant_op.constant([2.0] * batch_size)
    lam_v = 2.0
    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)

    exponential = exponential_lib.Exponential(rate=lam)

    cdf = exponential.cdf(x)
    self.assertEqual(cdf.get_shape(), (6,))

    # The numerical comparison requires scipy; skip it when unavailable.
    if not stats:
      return
    expected_cdf = stats.expon.cdf(x, scale=1 / lam_v)
    self.assertAllClose(self.evaluate(cdf), expected_cdf)

  def testExponentialLogSurvival(self):
    """log_survival_function has the expected shape and agrees with scipy."""
    batch_size = 7
    lam = constant_op.constant([2.0] * batch_size)
    lam_v = 2.0
    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0, 10.0], dtype=np.float32)

    exponential = exponential_lib.Exponential(rate=lam)

    log_survival = exponential.log_survival_function(x)
    self.assertEqual(log_survival.get_shape(), (7,))

    # The numerical comparison requires scipy; skip it when unavailable.
    if not stats:
      return
    expected_log_survival = stats.expon.logsf(x, scale=1 / lam_v)
    self.assertAllClose(self.evaluate(log_survival), expected_log_survival)

  def testExponentialMean(self):
    """mean() has the expected shape and agrees with scipy."""
    lam_v = np.array([1.0, 4.0, 2.5])
    exponential = exponential_lib.Exponential(rate=lam_v)
    self.assertEqual(exponential.mean().get_shape(), (3,))
    if not stats:
      return
    expected_mean = stats.expon.mean(scale=1 / lam_v)
    self.assertAllClose(self.evaluate(exponential.mean()), expected_mean)

  def testExponentialVariance(self):
    """variance() has the expected shape and agrees with scipy."""
    lam_v = np.array([1.0, 4.0, 2.5])
    exponential = exponential_lib.Exponential(rate=lam_v)
    self.assertEqual(exponential.variance().get_shape(), (3,))
    if not stats:
      return
    expected_variance = stats.expon.var(scale=1 / lam_v)
    self.assertAllClose(
        self.evaluate(exponential.variance()), expected_variance)

  def testExponentialEntropy(self):
    """entropy() has the expected shape and agrees with scipy."""
    lam_v = np.array([1.0, 4.0, 2.5])
    exponential = exponential_lib.Exponential(rate=lam_v)
    self.assertEqual(exponential.entropy().get_shape(), (3,))
    if not stats:
      return
    expected_entropy = stats.expon.entropy(scale=1 / lam_v)
    self.assertAllClose(self.evaluate(exponential.entropy()), expected_entropy)

  def testExponentialSample(self):
    """Samples are non-negative, correctly shaped, and pass a KS test."""
    lam = constant_op.constant([3.0, 4.0])
    lam_v = [3.0, 4.0]
    n = constant_op.constant(100000)
    exponential = exponential_lib.Exponential(rate=lam)

    samples = exponential.sample(n, seed=137)
    sample_values = self.evaluate(samples)
    self.assertEqual(sample_values.shape, (100000, 2))
    # The exponential distribution has support [0, inf).
    self.assertFalse(np.any(sample_values < 0.0))
    if not stats:
      return
    # Kolmogorov-Smirnov check: the KS statistic (element [0] of the
    # kstest result) should be small when the samples match the
    # corresponding scipy distribution.
    for i in range(2):
      self.assertLess(
          stats.kstest(sample_values[:, i],
                       stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)

  def testExponentialSampleMultiDimensional(self):
    """Sampling from a batched distribution yields the right shape/values."""
    batch_size = 2
    lam_v = [3.0, 22.0]
    lam = constant_op.constant([lam_v] * batch_size)

    exponential = exponential_lib.Exponential(rate=lam)

    n = 100000
    samples = exponential.sample(n, seed=138)
    self.assertEqual(samples.get_shape(), (n, batch_size, 2))

    sample_values = self.evaluate(samples)

    # The exponential distribution has support [0, inf).
    self.assertFalse(np.any(sample_values < 0.0))
    if not stats:
      return
    # KS-test each batch member against the matching scipy distribution.
    for i in range(2):
      self.assertLess(
          stats.kstest(sample_values[:, 0, i],
                       stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)
      self.assertLess(
          stats.kstest(sample_values[:, 1, i],
                       stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)

  def testFullyReparameterized(self):
    """Sampling is reparameterized: gradients flow from samples to rate."""
    lam = constant_op.constant([0.1, 1.0])
    with backprop.GradientTape() as tape:
      tape.watch(lam)
      exponential = exponential_lib.Exponential(rate=lam)
      samples = exponential.sample(100)
    # A non-None gradient proves samples are differentiable w.r.t. `lam`.
    grad_lam = tape.gradient(samples, lam)
    self.assertIsNotNone(grad_lam)

  def testExponentialWithSoftplusRate(self):
    """ExponentialWithSoftplusRate applies softplus to the raw rate values."""
    lam = [-2.2, -3.4]
    exponential = exponential_lib.ExponentialWithSoftplusRate(rate=lam)
    self.assertAllClose(
        self.evaluate(nn_ops.softplus(lam)), self.evaluate(exponential.rate))
190
191
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
  test.main()
194