1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for initializers.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import importlib 22 23import numpy as np 24 25from tensorflow.python.eager import backprop 26from tensorflow.python.framework import constant_op 27from tensorflow.python.framework import test_util 28from tensorflow.python.ops import nn_ops 29from tensorflow.python.ops.distributions import exponential as exponential_lib 30from tensorflow.python.platform import test 31from tensorflow.python.platform import tf_logging 32 33 34def try_import(name): # pylint: disable=invalid-name 35 module = None 36 try: 37 module = importlib.import_module(name) 38 except ImportError as e: 39 tf_logging.warning("Could not import %s: %s" % (name, str(e))) 40 return module 41 42 43stats = try_import("scipy.stats") 44 45 46@test_util.run_all_in_graph_and_eager_modes 47class ExponentialTest(test.TestCase): 48 49 def testExponentialLogPDF(self): 50 batch_size = 6 51 lam = constant_op.constant([2.0] * batch_size) 52 lam_v = 2.0 53 x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) 54 exponential = exponential_lib.Exponential(rate=lam) 55 56 log_pdf = exponential.log_prob(x) 57 self.assertEqual(log_pdf.get_shape(), (6,)) 58 59 pdf = exponential.prob(x) 60 self.assertEqual(pdf.get_shape(), (6,)) 61 62 if not stats: 63 return 64 expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v) 65 self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf) 66 self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf)) 67 68 def testExponentialLogPDFBoundary(self): 69 # Check that Log PDF is finite at 0. 70 rate = np.array([0.1, 0.5, 1., 2., 5., 10.], dtype=np.float32) 71 exponential = exponential_lib.Exponential(rate=rate) 72 log_pdf = exponential.log_prob(0.) 73 self.assertAllClose(np.log(rate), self.evaluate(log_pdf)) 74 75 def testExponentialCDF(self): 76 batch_size = 6 77 lam = constant_op.constant([2.0] * batch_size) 78 lam_v = 2.0 79 x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) 80 81 exponential = exponential_lib.Exponential(rate=lam) 82 83 cdf = exponential.cdf(x) 84 self.assertEqual(cdf.get_shape(), (6,)) 85 86 if not stats: 87 return 88 expected_cdf = stats.expon.cdf(x, scale=1 / lam_v) 89 self.assertAllClose(self.evaluate(cdf), expected_cdf) 90 91 def testExponentialLogSurvival(self): 92 batch_size = 7 93 lam = constant_op.constant([2.0] * batch_size) 94 lam_v = 2.0 95 x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0, 10.0], dtype=np.float32) 96 97 exponential = exponential_lib.Exponential(rate=lam) 98 99 log_survival = exponential.log_survival_function(x) 100 self.assertEqual(log_survival.get_shape(), (7,)) 101 102 if not stats: 103 return 104 expected_log_survival = stats.expon.logsf(x, scale=1 / lam_v) 105 self.assertAllClose(self.evaluate(log_survival), expected_log_survival) 106 107 def testExponentialMean(self): 108 lam_v = np.array([1.0, 4.0, 2.5]) 109 exponential = exponential_lib.Exponential(rate=lam_v) 110 self.assertEqual(exponential.mean().get_shape(), (3,)) 111 if not stats: 112 return 113 expected_mean = stats.expon.mean(scale=1 / lam_v) 114 self.assertAllClose(self.evaluate(exponential.mean()), expected_mean) 115 116 def testExponentialVariance(self): 117 lam_v = np.array([1.0, 4.0, 2.5]) 118 exponential = exponential_lib.Exponential(rate=lam_v) 119 self.assertEqual(exponential.variance().get_shape(), (3,)) 120 if not stats: 121 return 122 expected_variance = stats.expon.var(scale=1 / lam_v) 123 self.assertAllClose( 124 self.evaluate(exponential.variance()), expected_variance) 125 126 def testExponentialEntropy(self): 127 lam_v = np.array([1.0, 4.0, 2.5]) 128 exponential = exponential_lib.Exponential(rate=lam_v) 129 self.assertEqual(exponential.entropy().get_shape(), (3,)) 130 if not stats: 131 return 132 expected_entropy = stats.expon.entropy(scale=1 / lam_v) 133 self.assertAllClose(self.evaluate(exponential.entropy()), expected_entropy) 134 135 def testExponentialSample(self): 136 lam = constant_op.constant([3.0, 4.0]) 137 lam_v = [3.0, 4.0] 138 n = constant_op.constant(100000) 139 exponential = exponential_lib.Exponential(rate=lam) 140 141 samples = exponential.sample(n, seed=137) 142 sample_values = self.evaluate(samples) 143 self.assertEqual(sample_values.shape, (100000, 2)) 144 self.assertFalse(np.any(sample_values < 0.0)) 145 if not stats: 146 return 147 for i in range(2): 148 self.assertLess( 149 stats.kstest(sample_values[:, i], 150 stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01) 151 152 def testExponentialSampleMultiDimensional(self): 153 batch_size = 2 154 lam_v = [3.0, 22.0] 155 lam = constant_op.constant([lam_v] * batch_size) 156 157 exponential = exponential_lib.Exponential(rate=lam) 158 159 n = 100000 160 samples = exponential.sample(n, seed=138) 161 self.assertEqual(samples.get_shape(), (n, batch_size, 2)) 162 163 sample_values = self.evaluate(samples) 164 165 self.assertFalse(np.any(sample_values < 0.0)) 166 if not stats: 167 return 168 for i in range(2): 169 self.assertLess( 170 stats.kstest(sample_values[:, 0, i], 171 stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01) 172 self.assertLess( 173 stats.kstest(sample_values[:, 1, i], 174 stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01) 175 176 def testFullyReparameterized(self): 177 lam = constant_op.constant([0.1, 1.0]) 178 with backprop.GradientTape() as tape: 179 tape.watch(lam) 180 exponential = exponential_lib.Exponential(rate=lam) 181 samples = exponential.sample(100) 182 grad_lam = tape.gradient(samples, lam) 183 self.assertIsNotNone(grad_lam) 184 185 def testExponentialWithSoftplusRate(self): 186 lam = [-2.2, -3.4] 187 exponential = exponential_lib.ExponentialWithSoftplusRate(rate=lam) 188 self.assertAllClose( 189 self.evaluate(nn_ops.softplus(lam)), self.evaluate(exponential.rate)) 190 191 192if __name__ == "__main__": 193 test.main() 194