# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the Laplace distribution (`laplace_lib.Laplace`)."""

import importlib

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import laplace as laplace_lib
from tensorflow.python.platform import test

from tensorflow.python.platform import tf_logging


def try_import(name):  # pylint: disable=invalid-name
  """Imports and returns module `name`, or returns None if the import fails.

  A failed import is logged as a warning instead of being raised. Tests that
  need the optional module (scipy here) check for None and skip their
  reference comparisons.
  """
  module = None
  try:
    module = importlib.import_module(name)
  except ImportError as e:
    # Pass args lazily; the logger formats the message only if it is emitted.
    tf_logging.warning("Could not import %s: %s", name, e)
  return module


# scipy.stats supplies the reference values; None when scipy is unavailable.
stats = try_import("scipy.stats")


@test_util.run_all_in_graph_and_eager_modes
class LaplaceTest(test.TestCase):
  """Tests for `laplace_lib.Laplace` and `LaplaceWithSoftplusScale`."""

  def testLaplaceShape(self):
    loc = constant_op.constant([3.0] * 5)
    scale = constant_op.constant(11.0)
    laplace = laplace_lib.Laplace(loc=loc, scale=scale)

    # `batch_shape_tensor()` evaluates to an array, so compare element-wise
    # rather than relying on the truthiness of `array == tuple`.
    self.assertAllEqual(self.evaluate(laplace.batch_shape_tensor()), [5])
    self.assertEqual(laplace.batch_shape, tensor_shape.TensorShape([5]))
    self.assertAllEqual(self.evaluate(laplace.event_shape_tensor()), [])
    self.assertEqual(laplace.event_shape, tensor_shape.TensorShape([]))

  def testLaplaceLogPDF(self):
    batch_size = 6
    loc = constant_op.constant([2.0] * batch_size)
    scale = constant_op.constant([3.0] * batch_size)
    loc_v = 2.0
    scale_v = 3.0
    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
    log_pdf = laplace.log_prob(x)
    self.assertEqual(log_pdf.get_shape(), (6,))
    if not stats:
      return
    expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
    self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)

    pdf = laplace.prob(x)
    self.assertEqual(pdf.get_shape(), (6,))
    self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))

  def testLaplaceLogPDFMultidimensional(self):
    batch_size = 6
    loc = constant_op.constant([[2.0, 4.0]] * batch_size)
    scale = constant_op.constant([[3.0, 4.0]] * batch_size)
    loc_v = np.array([2.0, 4.0])
    scale_v = np.array([3.0, 4.0])
    # Column vector (6 x 1) that broadcasts against the (6 x 2) batch.
    x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
    log_pdf = laplace.log_prob(x)
    log_pdf_values = self.evaluate(log_pdf)
    self.assertEqual(log_pdf.get_shape(), (6, 2))

    pdf = laplace.prob(x)
    pdf_values = self.evaluate(pdf)
    self.assertEqual(pdf.get_shape(), (6, 2))
    if not stats:
      return
    expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
    self.assertAllClose(log_pdf_values, expected_log_pdf)
    self.assertAllClose(pdf_values, np.exp(expected_log_pdf))

  def testLaplaceLogPDFMultidimensionalBroadcasting(self):
    batch_size = 6
    loc = constant_op.constant([[2.0, 4.0]] * batch_size)
    scale = constant_op.constant(3.0)
    loc_v = np.array([2.0, 4.0])
    scale_v = 3.0
    x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
    log_pdf = laplace.log_prob(x)
    log_pdf_values = self.evaluate(log_pdf)
    self.assertEqual(log_pdf.get_shape(), (6, 2))

    pdf = laplace.prob(x)
    pdf_values = self.evaluate(pdf)
    self.assertEqual(pdf.get_shape(), (6, 2))
    if not stats:
      return
    expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
    self.assertAllClose(log_pdf_values, expected_log_pdf)
    self.assertAllClose(pdf_values, np.exp(expected_log_pdf))

  def testLaplaceCDF(self):
    batch_size = 6
    loc = constant_op.constant([2.0] * batch_size)
    scale = constant_op.constant([3.0] * batch_size)
    loc_v = 2.0
    scale_v = 3.0
    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)

    laplace = laplace_lib.Laplace(loc=loc, scale=scale)

    cdf = laplace.cdf(x)
    self.assertEqual(cdf.get_shape(), (6,))
    if not stats:
      return
    expected_cdf = stats.laplace.cdf(x, loc_v, scale=scale_v)
    self.assertAllClose(self.evaluate(cdf), expected_cdf)

  def testLaplaceLogCDF(self):
    batch_size = 6
    loc = constant_op.constant([2.0] * batch_size)
    scale = constant_op.constant([3.0] * batch_size)
    loc_v = 2.0
    scale_v = 3.0
    x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32)

    laplace = laplace_lib.Laplace(loc=loc, scale=scale)

    cdf = laplace.log_cdf(x)
    self.assertEqual(cdf.get_shape(), (6,))
    if not stats:
      return
    expected_cdf = stats.laplace.logcdf(x, loc_v, scale=scale_v)
    self.assertAllClose(self.evaluate(cdf), expected_cdf)

  def testLaplaceLogSurvivalFunction(self):
    batch_size = 6
    loc = constant_op.constant([2.0] * batch_size)
    scale = constant_op.constant([3.0] * batch_size)
    loc_v = 2.0
    scale_v = 3.0
    x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32)

    laplace = laplace_lib.Laplace(loc=loc, scale=scale)

    sf = laplace.log_survival_function(x)
    self.assertEqual(sf.get_shape(), (6,))
    if not stats:
      return
    expected_sf = stats.laplace.logsf(x, loc_v, scale=scale_v)
    self.assertAllClose(self.evaluate(sf), expected_sf)

  def testLaplaceMean(self):
    loc_v = np.array([1.0, 3.0, 2.5])
    scale_v = np.array([1.0, 4.0, 5.0])
    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
    self.assertEqual(laplace.mean().get_shape(), (3,))
    if not stats:
      return
    expected_means = stats.laplace.mean(loc_v, scale=scale_v)
    self.assertAllClose(self.evaluate(laplace.mean()), expected_means)

  def testLaplaceMode(self):
    loc_v = np.array([0.5, 3.0, 2.5])
    scale_v = np.array([1.0, 4.0, 5.0])
    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
    self.assertEqual(laplace.mode().get_shape(), (3,))
    self.assertAllClose(self.evaluate(laplace.mode()), loc_v)

  def testLaplaceVariance(self):
    loc_v = np.array([1.0, 3.0, 2.5])
    scale_v = np.array([1.0, 4.0, 5.0])
    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
    self.assertEqual(laplace.variance().get_shape(), (3,))
    if not stats:
      return
    expected_variances = stats.laplace.var(loc_v, scale=scale_v)
    self.assertAllClose(self.evaluate(laplace.variance()), expected_variances)

  def testLaplaceStd(self):
    loc_v = np.array([1.0, 3.0, 2.5])
    scale_v = np.array([1.0, 4.0, 5.0])
    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
    self.assertEqual(laplace.stddev().get_shape(), (3,))
    if not stats:
      return
    expected_stddev = stats.laplace.std(loc_v, scale=scale_v)
    self.assertAllClose(self.evaluate(laplace.stddev()), expected_stddev)

  def testLaplaceEntropy(self):
    loc_v = np.array([1.0, 3.0, 2.5])
    scale_v = np.array([1.0, 4.0, 5.0])
    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
    self.assertEqual(laplace.entropy().get_shape(), (3,))
    if not stats:
      return
    expected_entropy = stats.laplace.entropy(loc_v, scale=scale_v)
    self.assertAllClose(self.evaluate(laplace.entropy()), expected_entropy)

  def testLaplaceSample(self):
    loc_v = 4.0
    scale_v = 3.0
    loc = constant_op.constant(loc_v)
    scale = constant_op.constant(scale_v)
    n = 100000
    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
    samples = laplace.sample(n, seed=137)
    sample_values = self.evaluate(samples)
    self.assertEqual(samples.get_shape(), (n,))
    self.assertEqual(sample_values.shape, (n,))
    if not stats:
      return
    self.assertAllClose(
        sample_values.mean(),
        stats.laplace.mean(loc_v, scale=scale_v),
        rtol=0.05,
        atol=0.)
    self.assertAllClose(
        sample_values.var(),
        stats.laplace.var(loc_v, scale=scale_v),
        rtol=0.05,
        atol=0.)
    self.assertTrue(self._kstest(loc_v, scale_v, sample_values))

  def testLaplaceFullyReparameterized(self):
    loc = constant_op.constant(4.0)
    scale = constant_op.constant(3.0)
    with backprop.GradientTape() as tape:
      tape.watch(loc)
      tape.watch(scale)
      laplace = laplace_lib.Laplace(loc=loc, scale=scale)
      samples = laplace.sample(100)
    # Gradients must flow from the samples back to both parameters.
    grad_loc, grad_scale = tape.gradient(samples, [loc, scale])
    self.assertIsNotNone(grad_loc)
    self.assertIsNotNone(grad_scale)

  def testLaplaceSampleMultiDimensional(self):
    loc_v = np.array([np.arange(1, 101, dtype=np.float32)])  # 1 x 100
    scale_v = np.array([np.arange(1, 11, dtype=np.float32)]).T  # 10 x 1
    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
    n = 10000
    samples = laplace.sample(n, seed=137)
    sample_values = self.evaluate(samples)
    self.assertEqual(samples.get_shape(), (n, 10, 100))
    self.assertEqual(sample_values.shape, (n, 10, 100))
    zeros = np.zeros_like(loc_v + scale_v)  # 10 x 100
    loc_bc = loc_v + zeros
    scale_bc = scale_v + zeros
    if not stats:
      return
    self.assertAllClose(
        sample_values.mean(axis=0),
        stats.laplace.mean(loc_bc, scale=scale_bc),
        rtol=0.35,
        atol=0.)
    self.assertAllClose(
        sample_values.var(axis=0),
        stats.laplace.var(loc_bc, scale=scale_bc),
        rtol=0.105,
        atol=0.0)
    # Allow up to 3% of the per-(loc, scale) KS tests to fail.
    fails = 0
    trials = 0
    for ai, a in enumerate(np.reshape(loc_v, [-1])):
      for bi, b in enumerate(np.reshape(scale_v, [-1])):
        s = sample_values[:, bi, ai]
        trials += 1
        fails += 0 if self._kstest(a, b, s) else 1
    self.assertLess(fails, trials * 0.03)

  def _kstest(self, loc, scale, samples):
    """Returns True if `samples` pass a Kolmogorov-Smirnov goodness-of-fit test.

    The samples are compared against scipy's Laplace(loc, scale) CDF. The test
    passes when the KS statistic is below 0.02.
    """
    if not stats:
      return True  # If scipy isn't available, return "True" for passing
    ks, _ = stats.kstest(samples, stats.laplace(loc, scale=scale).cdf)
    # Return True when the test passes.
    return ks < 0.02

  def testLaplacePdfOfSampleMultiDims(self):
    laplace = laplace_lib.Laplace(loc=[7., 11.], scale=[[5.], [6.]])
    num = 50000
    samples = laplace.sample(num, seed=137)
    pdfs = laplace.prob(samples)
    sample_vals, pdf_vals = self.evaluate([samples, pdfs])
    self.assertEqual(samples.get_shape(), (num, 2, 2))
    self.assertEqual(pdfs.get_shape(), (num, 2, 2))
    self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
    self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
    self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
    self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
    if not stats:
      return
    self.assertAllClose(
        stats.laplace.mean(
            [[7., 11.], [7., 11.]], scale=np.array([[5., 5.], [6., 6.]])),
        sample_vals.mean(axis=0),
        rtol=0.05,
        atol=0.)
    self.assertAllClose(
        stats.laplace.var([[7., 11.], [7., 11.]],
                          scale=np.array([[5., 5.], [6., 6.]])),
        sample_vals.var(axis=0),
        rtol=0.05,
        atol=0.)

  def _assertIntegral(self, sample_vals, pdf_vals, err=1e-3):
    """Asserts that a trapezoid-rule integral of the pdf over samples is ~1.

    Args:
      sample_vals: 1-D array of sample locations (callers pass `[:, i, j]`
        slices).
      pdf_vals: 1-D array of pdf values evaluated at `sample_vals`.
      err: Absolute tolerance passed to `assertNear`.
    """
    order = np.argsort(sample_vals)
    xs = np.asarray(sample_vals)[order]
    ys = np.asarray(pdf_vals)[order]
    # Integrate only between consecutive sorted samples. Starting from an
    # artificial (0, 0) point would add a negative-width slice whenever some
    # samples are below zero.
    total = np.sum(np.diff(xs) * (ys[1:] + ys[:-1]) / 2.)
    self.assertNear(1., total, err=err)

  def testLaplaceNonPositiveInitializationParamsRaises(self):
    loc_v = constant_op.constant(0.0, name="loc")
    scale_v = constant_op.constant(-1.0, name="scale")
    with self.assertRaisesOpError("Condition x > 0 did not hold element-wise"):
      laplace = laplace_lib.Laplace(
          loc=loc_v, scale=scale_v, validate_args=True)
      self.evaluate(laplace.mean())
    loc_v = constant_op.constant(1.0, name="loc")
    scale_v = constant_op.constant(0.0, name="scale")
    with self.assertRaisesOpError("Condition x > 0 did not hold element-wise"):
      laplace = laplace_lib.Laplace(
          loc=loc_v, scale=scale_v, validate_args=True)
      self.evaluate(laplace.mean())

  def testLaplaceWithSoftplusScale(self):
    loc_v = constant_op.constant([0.0, 1.0], name="loc")
    scale_v = constant_op.constant([-1.0, 2.0], name="scale")
    laplace = laplace_lib.LaplaceWithSoftplusScale(loc=loc_v, scale=scale_v)
    self.assertAllClose(
        self.evaluate(nn_ops.softplus(scale_v)), self.evaluate(laplace.scale))
    self.assertAllClose(self.evaluate(loc_v), self.evaluate(laplace.loc))


if __name__ == "__main__":
  test.main()