# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the Exponential distribution."""

import importlib

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import exponential as exponential_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def try_import(name):  # pylint: disable=invalid-name
  """Imports and returns module `name`, or None if it cannot be imported.

  An ImportError is logged as a warning instead of being raised, so the tests
  below can still run their TensorFlow-only checks without the optional
  dependency.

  Args:
    name: Fully qualified module name, e.g. "scipy.stats".

  Returns:
    The imported module, or None if the import raised ImportError.
  """
  try:
    return importlib.import_module(name)
  except ImportError as e:
    # Pass args lazily; the logger formats only if the record is emitted.
    tf_logging.warning("Could not import %s: %s", name, e)
    return None


# Optional reference implementation. When it is None, each test returns early
# after its shape checks and skips the scipy-based value comparisons.
stats = try_import("scipy.stats")


@test_util.run_all_in_graph_and_eager_modes
class ExponentialTest(test.TestCase):
  """Checks `Exponential` against closed forms and scipy.stats.expon.

  scipy's `expon` is parameterized by `scale`, the reciprocal of the
  TensorFlow `rate`; hence the `scale=1 / rate` arguments below.
  """

  def testExponentialLogPDF(self):
    """log_prob/prob have batch shape and match scipy's logpdf/exp(logpdf)."""
    batch_size = 6
    lam = constant_op.constant([2.0] * batch_size)
    lam_v = 2.0
    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
    exponential = exponential_lib.Exponential(rate=lam)

    log_pdf = exponential.log_prob(x)
    self.assertEqual(log_pdf.get_shape(), (batch_size,))

    pdf = exponential.prob(x)
    self.assertEqual(pdf.get_shape(), (batch_size,))

    if not stats:
      return
    expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v)
    self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
    self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))

  def testExponentialLogPDFBoundary(self):
    """log_prob at x=0 is finite and equals log(rate)."""
    rate = np.array([0.1, 0.5, 1., 2., 5., 10.], dtype=np.float32)
    exponential = exponential_lib.Exponential(rate=rate)
    log_pdf = exponential.log_prob(0.)
    self.assertAllClose(np.log(rate), self.evaluate(log_pdf))

  def testExponentialCDF(self):
    """cdf has batch shape and matches scipy.stats.expon.cdf."""
    batch_size = 6
    lam = constant_op.constant([2.0] * batch_size)
    lam_v = 2.0
    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)

    exponential = exponential_lib.Exponential(rate=lam)

    cdf = exponential.cdf(x)
    self.assertEqual(cdf.get_shape(), (batch_size,))

    if not stats:
      return
    expected_cdf = stats.expon.cdf(x, scale=1 / lam_v)
    self.assertAllClose(self.evaluate(cdf), expected_cdf)

  def testExponentialLogSurvival(self):
    """log_survival_function has batch shape and matches expon.logsf."""
    batch_size = 7
    lam = constant_op.constant([2.0] * batch_size)
    lam_v = 2.0
    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0, 10.0], dtype=np.float32)

    exponential = exponential_lib.Exponential(rate=lam)

    log_survival = exponential.log_survival_function(x)
    self.assertEqual(log_survival.get_shape(), (batch_size,))

    if not stats:
      return
    expected_log_survival = stats.expon.logsf(x, scale=1 / lam_v)
    self.assertAllClose(self.evaluate(log_survival), expected_log_survival)

  def testExponentialMean(self):
    """mean() has batch shape and matches scipy.stats.expon.mean."""
    lam_v = np.array([1.0, 4.0, 2.5])
    exponential = exponential_lib.Exponential(rate=lam_v)
    self.assertEqual(exponential.mean().get_shape(), (3,))
    if not stats:
      return
    expected_mean = stats.expon.mean(scale=1 / lam_v)
    self.assertAllClose(self.evaluate(exponential.mean()), expected_mean)

  def testExponentialVariance(self):
    """variance() has batch shape and matches scipy.stats.expon.var."""
    lam_v = np.array([1.0, 4.0, 2.5])
    exponential = exponential_lib.Exponential(rate=lam_v)
    self.assertEqual(exponential.variance().get_shape(), (3,))
    if not stats:
      return
    expected_variance = stats.expon.var(scale=1 / lam_v)
    self.assertAllClose(
        self.evaluate(exponential.variance()), expected_variance)

  def testExponentialEntropy(self):
    """entropy() has batch shape and matches scipy.stats.expon.entropy."""
    lam_v = np.array([1.0, 4.0, 2.5])
    exponential = exponential_lib.Exponential(rate=lam_v)
    self.assertEqual(exponential.entropy().get_shape(), (3,))
    if not stats:
      return
    expected_entropy = stats.expon.entropy(scale=1 / lam_v)
    self.assertAllClose(self.evaluate(exponential.entropy()), expected_entropy)

  def testExponentialSample(self):
    """Samples are non-negative and pass a KS test against expon."""
    lam_v = [3.0, 4.0]
    lam = constant_op.constant(lam_v)
    num_samples = 100000
    n = constant_op.constant(num_samples)
    exponential = exponential_lib.Exponential(rate=lam)

    samples = exponential.sample(n, seed=137)
    sample_values = self.evaluate(samples)
    self.assertEqual(sample_values.shape, (num_samples, len(lam_v)))
    self.assertFalse(np.any(sample_values < 0.0))
    if not stats:
      return
    for i, rate in enumerate(lam_v):
      # kstest returns (statistic, pvalue); [0] is the KS distance.
      self.assertLess(
          stats.kstest(sample_values[:, i],
                       stats.expon(scale=1.0 / rate).cdf)[0], 0.01)

  def testExponentialSampleMultiDimensional(self):
    """Sampling with a 2-D rate yields (n, batch, 2) draws matching expon."""
    batch_size = 2
    lam_v = [3.0, 22.0]
    lam = constant_op.constant([lam_v] * batch_size)

    exponential = exponential_lib.Exponential(rate=lam)

    n = 100000
    samples = exponential.sample(n, seed=138)
    self.assertEqual(samples.get_shape(), (n, batch_size, len(lam_v)))

    sample_values = self.evaluate(samples)

    self.assertFalse(np.any(sample_values < 0.0))
    if not stats:
      return
    # Every batch row uses the same rates, so each (row, column) slice is
    # compared against the same reference distribution for that column.
    for b in range(batch_size):
      for i, rate in enumerate(lam_v):
        self.assertLess(
            stats.kstest(sample_values[:, b, i],
                         stats.expon(scale=1.0 / rate).cdf)[0], 0.01)

  def testFullyReparameterized(self):
    """Gradients of samples with respect to `rate` are defined."""
    lam = constant_op.constant([0.1, 1.0])
    with backprop.GradientTape() as tape:
      tape.watch(lam)
      exponential = exponential_lib.Exponential(rate=lam)
      samples = exponential.sample(100)
    grad_lam = tape.gradient(samples, lam)
    self.assertIsNotNone(grad_lam)

  def testExponentialWithSoftplusRate(self):
    """ExponentialWithSoftplusRate stores softplus(rate) as its rate."""
    lam = [-2.2, -3.4]
    exponential = exponential_lib.ExponentialWithSoftplusRate(rate=lam)
    self.assertAllClose(
        self.evaluate(nn_ops.softplus(lam)), self.evaluate(exponential.rate))


if __name__ == "__main__":
  test.main()