1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for initializers.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import importlib 22import math 23 24import numpy as np 25 26from tensorflow.python.eager import backprop 27from tensorflow.python.framework import constant_op 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import tensor_shape 31from tensorflow.python.framework import test_util 32from tensorflow.python.ops import array_ops 33from tensorflow.python.ops import gradients_impl 34from tensorflow.python.ops import nn_ops 35from tensorflow.python.ops import variables 36from tensorflow.python.ops.distributions import kullback_leibler 37from tensorflow.python.ops.distributions import normal as normal_lib 38from tensorflow.python.platform import test 39from tensorflow.python.platform import tf_logging 40 41 42def try_import(name): # pylint: disable=invalid-name 43 module = None 44 try: 45 module = importlib.import_module(name) 46 except ImportError as e: 47 tf_logging.warning("Could not import %s: %s" % (name, str(e))) 48 return module 49 50stats = try_import("scipy.stats") 51 52 53class NormalTest(test.TestCase): 54 55 def setUp(self): 56 self._rng = np.random.RandomState(123) 57 58 def assertAllFinite(self, tensor): 59 is_finite = np.isfinite(self.evaluate(tensor)) 60 all_true = np.ones_like(is_finite, dtype=np.bool) 61 self.assertAllEqual(all_true, is_finite) 62 63 def _testParamShapes(self, sample_shape, expected): 64 param_shapes = normal_lib.Normal.param_shapes(sample_shape) 65 mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"] 66 self.assertAllEqual(expected, self.evaluate(mu_shape)) 67 self.assertAllEqual(expected, self.evaluate(sigma_shape)) 68 mu = array_ops.zeros(mu_shape) 69 sigma = array_ops.ones(sigma_shape) 70 self.assertAllEqual( 71 expected, 72 self.evaluate(array_ops.shape(normal_lib.Normal(mu, sigma).sample()))) 73 74 def _testParamStaticShapes(self, sample_shape, expected): 75 param_shapes = normal_lib.Normal.param_static_shapes(sample_shape) 76 mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"] 77 self.assertEqual(expected, mu_shape) 78 self.assertEqual(expected, sigma_shape) 79 80 @test_util.run_in_graph_and_eager_modes 81 def testSampleLikeArgsGetDistDType(self): 82 dist = normal_lib.Normal(0., 1.) 83 self.assertEqual(dtypes.float32, dist.dtype) 84 for method in ("log_prob", "prob", "log_cdf", "cdf", 85 "log_survival_function", "survival_function", "quantile"): 86 self.assertEqual(dtypes.float32, getattr(dist, method)(1).dtype) 87 88 @test_util.run_in_graph_and_eager_modes 89 def testParamShapes(self): 90 sample_shape = [10, 3, 4] 91 self._testParamShapes(sample_shape, sample_shape) 92 self._testParamShapes(constant_op.constant(sample_shape), sample_shape) 93 94 @test_util.run_in_graph_and_eager_modes 95 def testParamStaticShapes(self): 96 sample_shape = [10, 3, 4] 97 self._testParamStaticShapes(sample_shape, sample_shape) 98 self._testParamStaticShapes( 99 tensor_shape.TensorShape(sample_shape), sample_shape) 100 101 @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) 102 def testNormalWithSoftplusScale(self): 103 mu = array_ops.zeros((10, 3)) 104 rho = array_ops.ones((10, 3)) * -2. 105 normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho) 106 self.assertAllEqual(self.evaluate(mu), self.evaluate(normal.loc)) 107 self.assertAllEqual( 108 self.evaluate(nn_ops.softplus(rho)), self.evaluate(normal.scale)) 109 110 @test_util.run_in_graph_and_eager_modes 111 def testNormalLogPDF(self): 112 batch_size = 6 113 mu = constant_op.constant([3.0] * batch_size) 114 sigma = constant_op.constant([math.sqrt(10.0)] * batch_size) 115 x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32) 116 normal = normal_lib.Normal(loc=mu, scale=sigma) 117 118 log_pdf = normal.log_prob(x) 119 self.assertAllEqual( 120 self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape()) 121 self.assertAllEqual( 122 self.evaluate(normal.batch_shape_tensor()), 123 self.evaluate(log_pdf).shape) 124 self.assertAllEqual(normal.batch_shape, log_pdf.get_shape()) 125 self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape) 126 127 pdf = normal.prob(x) 128 self.assertAllEqual( 129 self.evaluate(normal.batch_shape_tensor()), pdf.get_shape()) 130 self.assertAllEqual( 131 self.evaluate(normal.batch_shape_tensor()), 132 self.evaluate(pdf).shape) 133 self.assertAllEqual(normal.batch_shape, pdf.get_shape()) 134 self.assertAllEqual(normal.batch_shape, self.evaluate(pdf).shape) 135 136 if not stats: 137 return 138 expected_log_pdf = stats.norm(self.evaluate(mu), 139 self.evaluate(sigma)).logpdf(x) 140 self.assertAllClose(expected_log_pdf, self.evaluate(log_pdf)) 141 self.assertAllClose(np.exp(expected_log_pdf), self.evaluate(pdf)) 142 143 @test_util.run_in_graph_and_eager_modes 144 def testNormalLogPDFMultidimensional(self): 145 batch_size = 6 146 mu = constant_op.constant([[3.0, -3.0]] * batch_size) 147 sigma = constant_op.constant( 148 [[math.sqrt(10.0), math.sqrt(15.0)]] * batch_size) 149 x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T 150 normal = normal_lib.Normal(loc=mu, scale=sigma) 151 152 log_pdf = normal.log_prob(x) 153 log_pdf_values = self.evaluate(log_pdf) 154 self.assertEqual(log_pdf.get_shape(), (6, 2)) 155 self.assertAllEqual( 156 self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape()) 157 self.assertAllEqual( 158 self.evaluate(normal.batch_shape_tensor()), 159 self.evaluate(log_pdf).shape) 160 self.assertAllEqual(normal.batch_shape, log_pdf.get_shape()) 161 self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape) 162 163 pdf = normal.prob(x) 164 pdf_values = self.evaluate(pdf) 165 self.assertEqual(pdf.get_shape(), (6, 2)) 166 self.assertAllEqual( 167 self.evaluate(normal.batch_shape_tensor()), pdf.get_shape()) 168 self.assertAllEqual( 169 self.evaluate(normal.batch_shape_tensor()), pdf_values.shape) 170 self.assertAllEqual(normal.batch_shape, pdf.get_shape()) 171 self.assertAllEqual(normal.batch_shape, pdf_values.shape) 172 173 if not stats: 174 return 175 expected_log_pdf = stats.norm(self.evaluate(mu), 176 self.evaluate(sigma)).logpdf(x) 177 self.assertAllClose(expected_log_pdf, log_pdf_values) 178 self.assertAllClose(np.exp(expected_log_pdf), pdf_values) 179 180 @test_util.run_in_graph_and_eager_modes 181 def testNormalCDF(self): 182 batch_size = 50 183 mu = self._rng.randn(batch_size) 184 sigma = self._rng.rand(batch_size) + 1.0 185 x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) 186 187 normal = normal_lib.Normal(loc=mu, scale=sigma) 188 cdf = normal.cdf(x) 189 self.assertAllEqual( 190 self.evaluate(normal.batch_shape_tensor()), cdf.get_shape()) 191 self.assertAllEqual( 192 self.evaluate(normal.batch_shape_tensor()), 193 self.evaluate(cdf).shape) 194 self.assertAllEqual(normal.batch_shape, cdf.get_shape()) 195 self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape) 196 if not stats: 197 return 198 expected_cdf = stats.norm(mu, sigma).cdf(x) 199 self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0) 200 201 @test_util.run_in_graph_and_eager_modes 202 def testNormalSurvivalFunction(self): 203 batch_size = 50 204 mu = self._rng.randn(batch_size) 205 sigma = self._rng.rand(batch_size) + 1.0 206 x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) 207 208 normal = normal_lib.Normal(loc=mu, scale=sigma) 209 210 sf = normal.survival_function(x) 211 self.assertAllEqual( 212 self.evaluate(normal.batch_shape_tensor()), sf.get_shape()) 213 self.assertAllEqual( 214 self.evaluate(normal.batch_shape_tensor()), 215 self.evaluate(sf).shape) 216 self.assertAllEqual(normal.batch_shape, sf.get_shape()) 217 self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape) 218 if not stats: 219 return 220 expected_sf = stats.norm(mu, sigma).sf(x) 221 self.assertAllClose(expected_sf, self.evaluate(sf), atol=0) 222 223 @test_util.run_in_graph_and_eager_modes 224 def testNormalLogCDF(self): 225 batch_size = 50 226 mu = self._rng.randn(batch_size) 227 sigma = self._rng.rand(batch_size) + 1.0 228 x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64) 229 230 normal = normal_lib.Normal(loc=mu, scale=sigma) 231 232 cdf = normal.log_cdf(x) 233 self.assertAllEqual( 234 self.evaluate(normal.batch_shape_tensor()), cdf.get_shape()) 235 self.assertAllEqual( 236 self.evaluate(normal.batch_shape_tensor()), 237 self.evaluate(cdf).shape) 238 self.assertAllEqual(normal.batch_shape, cdf.get_shape()) 239 self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape) 240 241 if not stats: 242 return 243 expected_cdf = stats.norm(mu, sigma).logcdf(x) 244 self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-3) 245 246 def testFiniteGradientAtDifficultPoints(self): 247 for dtype in [np.float32, np.float64]: 248 g = ops.Graph() 249 with g.as_default(): 250 mu = variables.Variable(dtype(0.0)) 251 sigma = variables.Variable(dtype(1.0)) 252 dist = normal_lib.Normal(loc=mu, scale=sigma) 253 x = np.array([-100., -20., -5., 0., 5., 20., 100.]).astype(dtype) 254 for func in [ 255 dist.cdf, dist.log_cdf, dist.survival_function, 256 dist.log_survival_function, dist.log_prob, dist.prob 257 ]: 258 value = func(x) 259 grads = gradients_impl.gradients(value, [mu, sigma]) 260 with self.session(graph=g): 261 variables.global_variables_initializer().run() 262 self.assertAllFinite(value) 263 self.assertAllFinite(grads[0]) 264 self.assertAllFinite(grads[1]) 265 266 @test_util.run_in_graph_and_eager_modes 267 def testNormalLogSurvivalFunction(self): 268 batch_size = 50 269 mu = self._rng.randn(batch_size) 270 sigma = self._rng.rand(batch_size) + 1.0 271 x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64) 272 273 normal = normal_lib.Normal(loc=mu, scale=sigma) 274 275 sf = normal.log_survival_function(x) 276 self.assertAllEqual( 277 self.evaluate(normal.batch_shape_tensor()), sf.get_shape()) 278 self.assertAllEqual( 279 self.evaluate(normal.batch_shape_tensor()), 280 self.evaluate(sf).shape) 281 self.assertAllEqual(normal.batch_shape, sf.get_shape()) 282 self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape) 283 284 if not stats: 285 return 286 expected_sf = stats.norm(mu, sigma).logsf(x) 287 self.assertAllClose(expected_sf, self.evaluate(sf), atol=0, rtol=1e-5) 288 289 @test_util.run_in_graph_and_eager_modes 290 def testNormalEntropyWithScalarInputs(self): 291 # Scipy.stats.norm cannot deal with the shapes in the other test. 292 mu_v = 2.34 293 sigma_v = 4.56 294 normal = normal_lib.Normal(loc=mu_v, scale=sigma_v) 295 296 entropy = normal.entropy() 297 self.assertAllEqual( 298 self.evaluate(normal.batch_shape_tensor()), entropy.get_shape()) 299 self.assertAllEqual( 300 self.evaluate(normal.batch_shape_tensor()), 301 self.evaluate(entropy).shape) 302 self.assertAllEqual(normal.batch_shape, entropy.get_shape()) 303 self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape) 304 # scipy.stats.norm cannot deal with these shapes. 305 if not stats: 306 return 307 expected_entropy = stats.norm(mu_v, sigma_v).entropy() 308 self.assertAllClose(expected_entropy, self.evaluate(entropy)) 309 310 @test_util.run_in_graph_and_eager_modes 311 def testNormalEntropy(self): 312 mu_v = np.array([1.0, 1.0, 1.0]) 313 sigma_v = np.array([[1.0, 2.0, 3.0]]).T 314 normal = normal_lib.Normal(loc=mu_v, scale=sigma_v) 315 316 # scipy.stats.norm cannot deal with these shapes. 317 sigma_broadcast = mu_v * sigma_v 318 expected_entropy = 0.5 * np.log(2 * np.pi * np.exp(1) * sigma_broadcast**2) 319 entropy = normal.entropy() 320 np.testing.assert_allclose(expected_entropy, self.evaluate(entropy)) 321 self.assertAllEqual( 322 self.evaluate(normal.batch_shape_tensor()), entropy.get_shape()) 323 self.assertAllEqual( 324 self.evaluate(normal.batch_shape_tensor()), 325 self.evaluate(entropy).shape) 326 self.assertAllEqual(normal.batch_shape, entropy.get_shape()) 327 self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape) 328 329 @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) 330 def testNormalMeanAndMode(self): 331 # Mu will be broadcast to [7, 7, 7]. 332 mu = [7.] 333 sigma = [11., 12., 13.] 334 335 normal = normal_lib.Normal(loc=mu, scale=sigma) 336 337 self.assertAllEqual((3,), normal.mean().get_shape()) 338 self.assertAllEqual([7., 7, 7], self.evaluate(normal.mean())) 339 340 self.assertAllEqual((3,), normal.mode().get_shape()) 341 self.assertAllEqual([7., 7, 7], self.evaluate(normal.mode())) 342 343 @test_util.run_in_graph_and_eager_modes 344 def testNormalQuantile(self): 345 batch_size = 52 346 mu = self._rng.randn(batch_size) 347 sigma = self._rng.rand(batch_size) + 1.0 348 p = np.linspace(0., 1.0, batch_size - 2).astype(np.float64) 349 # Quantile performs piecewise rational approximation so adding some 350 # special input values to make sure we hit all the pieces. 351 p = np.hstack((p, np.exp(-33), 1. - np.exp(-33))) 352 353 normal = normal_lib.Normal(loc=mu, scale=sigma) 354 x = normal.quantile(p) 355 356 self.assertAllEqual( 357 self.evaluate(normal.batch_shape_tensor()), x.get_shape()) 358 self.assertAllEqual( 359 self.evaluate(normal.batch_shape_tensor()), 360 self.evaluate(x).shape) 361 self.assertAllEqual(normal.batch_shape, x.get_shape()) 362 self.assertAllEqual(normal.batch_shape, self.evaluate(x).shape) 363 364 if not stats: 365 return 366 expected_x = stats.norm(mu, sigma).ppf(p) 367 self.assertAllClose(expected_x, self.evaluate(x), atol=0.) 368 369 def _baseQuantileFiniteGradientAtDifficultPoints(self, dtype): 370 g = ops.Graph() 371 with g.as_default(): 372 mu = variables.Variable(dtype(0.0)) 373 sigma = variables.Variable(dtype(1.0)) 374 dist = normal_lib.Normal(loc=mu, scale=sigma) 375 p = variables.Variable( 376 np.array([0., 377 np.exp(-32.), np.exp(-2.), 378 1. - np.exp(-2.), 1. - np.exp(-32.), 379 1.]).astype(dtype)) 380 381 value = dist.quantile(p) 382 grads = gradients_impl.gradients(value, [mu, p]) 383 with self.cached_session(graph=g): 384 variables.global_variables_initializer().run() 385 self.assertAllFinite(grads[0]) 386 self.assertAllFinite(grads[1]) 387 388 def testQuantileFiniteGradientAtDifficultPointsFloat32(self): 389 self._baseQuantileFiniteGradientAtDifficultPoints(np.float32) 390 391 def testQuantileFiniteGradientAtDifficultPointsFloat64(self): 392 self._baseQuantileFiniteGradientAtDifficultPoints(np.float64) 393 394 @test_util.run_in_graph_and_eager_modes 395 def testNormalVariance(self): 396 # sigma will be broadcast to [7, 7, 7] 397 mu = [1., 2., 3.] 398 sigma = [7.] 399 400 normal = normal_lib.Normal(loc=mu, scale=sigma) 401 402 self.assertAllEqual((3,), normal.variance().get_shape()) 403 self.assertAllEqual([49., 49, 49], self.evaluate(normal.variance())) 404 405 @test_util.run_in_graph_and_eager_modes 406 def testNormalStandardDeviation(self): 407 # sigma will be broadcast to [7, 7, 7] 408 mu = [1., 2., 3.] 409 sigma = [7.] 410 411 normal = normal_lib.Normal(loc=mu, scale=sigma) 412 413 self.assertAllEqual((3,), normal.stddev().get_shape()) 414 self.assertAllEqual([7., 7, 7], self.evaluate(normal.stddev())) 415 416 @test_util.run_in_graph_and_eager_modes 417 def testNormalSample(self): 418 mu = constant_op.constant(3.0) 419 sigma = constant_op.constant(math.sqrt(3.0)) 420 mu_v = 3.0 421 sigma_v = np.sqrt(3.0) 422 n = constant_op.constant(100000) 423 normal = normal_lib.Normal(loc=mu, scale=sigma) 424 samples = normal.sample(n) 425 sample_values = self.evaluate(samples) 426 # Note that the standard error for the sample mean is ~ sigma / sqrt(n). 427 # The sample variance similarly is dependent on sigma and n. 428 # Thus, the tolerances below are very sensitive to number of samples 429 # as well as the variances chosen. 430 self.assertEqual(sample_values.shape, (100000,)) 431 self.assertAllClose(sample_values.mean(), mu_v, atol=1e-1) 432 self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1) 433 434 expected_samples_shape = tensor_shape.TensorShape( 435 [self.evaluate(n)]).concatenate( 436 tensor_shape.TensorShape( 437 self.evaluate(normal.batch_shape_tensor()))) 438 439 self.assertAllEqual(expected_samples_shape, samples.get_shape()) 440 self.assertAllEqual(expected_samples_shape, sample_values.shape) 441 442 expected_samples_shape = ( 443 tensor_shape.TensorShape([self.evaluate(n)]).concatenate( 444 normal.batch_shape)) 445 446 self.assertAllEqual(expected_samples_shape, samples.get_shape()) 447 self.assertAllEqual(expected_samples_shape, sample_values.shape) 448 449 def testNormalFullyReparameterized(self): 450 mu = constant_op.constant(4.0) 451 sigma = constant_op.constant(3.0) 452 with backprop.GradientTape() as tape: 453 tape.watch(mu) 454 tape.watch(sigma) 455 normal = normal_lib.Normal(loc=mu, scale=sigma) 456 samples = normal.sample(100) 457 grad_mu, grad_sigma = tape.gradient(samples, [mu, sigma]) 458 self.assertIsNotNone(grad_mu) 459 self.assertIsNotNone(grad_sigma) 460 461 @test_util.run_in_graph_and_eager_modes 462 def testNormalSampleMultiDimensional(self): 463 batch_size = 2 464 mu = constant_op.constant([[3.0, -3.0]] * batch_size) 465 sigma = constant_op.constant( 466 [[math.sqrt(2.0), math.sqrt(3.0)]] * batch_size) 467 mu_v = [3.0, -3.0] 468 sigma_v = [np.sqrt(2.0), np.sqrt(3.0)] 469 n = constant_op.constant(100000) 470 normal = normal_lib.Normal(loc=mu, scale=sigma) 471 samples = normal.sample(n) 472 sample_values = self.evaluate(samples) 473 # Note that the standard error for the sample mean is ~ sigma / sqrt(n). 474 # The sample variance similarly is dependent on sigma and n. 475 # Thus, the tolerances below are very sensitive to number of samples 476 # as well as the variances chosen. 477 self.assertEqual(samples.get_shape(), (100000, batch_size, 2)) 478 self.assertAllClose(sample_values[:, 0, 0].mean(), mu_v[0], atol=1e-1) 479 self.assertAllClose(sample_values[:, 0, 0].std(), sigma_v[0], atol=1e-1) 480 self.assertAllClose(sample_values[:, 0, 1].mean(), mu_v[1], atol=1e-1) 481 self.assertAllClose(sample_values[:, 0, 1].std(), sigma_v[1], atol=1e-1) 482 483 expected_samples_shape = tensor_shape.TensorShape( 484 [self.evaluate(n)]).concatenate( 485 tensor_shape.TensorShape( 486 self.evaluate(normal.batch_shape_tensor()))) 487 self.assertAllEqual(expected_samples_shape, samples.get_shape()) 488 self.assertAllEqual(expected_samples_shape, sample_values.shape) 489 490 expected_samples_shape = ( 491 tensor_shape.TensorShape([self.evaluate(n)]).concatenate( 492 normal.batch_shape)) 493 self.assertAllEqual(expected_samples_shape, samples.get_shape()) 494 self.assertAllEqual(expected_samples_shape, sample_values.shape) 495 496 @test_util.run_in_graph_and_eager_modes 497 def testNegativeSigmaFails(self): 498 with self.assertRaisesOpError("Condition x > 0 did not hold"): 499 normal = normal_lib.Normal( 500 loc=[1.], scale=[-5.], validate_args=True, name="G") 501 self.evaluate(normal.mean()) 502 503 @test_util.run_in_graph_and_eager_modes 504 def testNormalShape(self): 505 mu = constant_op.constant([-3.0] * 5) 506 sigma = constant_op.constant(11.0) 507 normal = normal_lib.Normal(loc=mu, scale=sigma) 508 509 self.assertEqual(self.evaluate(normal.batch_shape_tensor()), [5]) 510 self.assertEqual(normal.batch_shape, tensor_shape.TensorShape([5])) 511 self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), []) 512 self.assertEqual(normal.event_shape, tensor_shape.TensorShape([])) 513 514 @test_util.run_deprecated_v1 515 def testNormalShapeWithPlaceholders(self): 516 mu = array_ops.placeholder(dtype=dtypes.float32) 517 sigma = array_ops.placeholder(dtype=dtypes.float32) 518 normal = normal_lib.Normal(loc=mu, scale=sigma) 519 520 with self.cached_session() as sess: 521 # get_batch_shape should return an "<unknown>" tensor. 522 self.assertEqual(normal.batch_shape, tensor_shape.TensorShape(None)) 523 self.assertEqual(normal.event_shape, ()) 524 self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), []) 525 self.assertAllEqual( 526 sess.run(normal.batch_shape_tensor(), 527 feed_dict={mu: 5.0, 528 sigma: [1.0, 2.0]}), [2]) 529 530 @test_util.run_in_graph_and_eager_modes 531 def testNormalNormalKL(self): 532 batch_size = 6 533 mu_a = np.array([3.0] * batch_size) 534 sigma_a = np.array([1.0, 2.0, 3.0, 1.5, 2.5, 3.5]) 535 mu_b = np.array([-3.0] * batch_size) 536 sigma_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0]) 537 538 n_a = normal_lib.Normal(loc=mu_a, scale=sigma_a) 539 n_b = normal_lib.Normal(loc=mu_b, scale=sigma_b) 540 541 kl = kullback_leibler.kl_divergence(n_a, n_b) 542 kl_val = self.evaluate(kl) 543 544 kl_expected = ((mu_a - mu_b)**2 / (2 * sigma_b**2) + 0.5 * ( 545 (sigma_a**2 / sigma_b**2) - 1 - 2 * np.log(sigma_a / sigma_b))) 546 547 self.assertEqual(kl.get_shape(), (batch_size,)) 548 self.assertAllClose(kl_val, kl_expected) 549 550 551if __name__ == "__main__": 552 test.main() 553