# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the Normal distribution (`normal_lib.Normal`)."""

import importlib
import math

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def try_import(name):  # pylint: disable=invalid-name
  """Imports module `name`, returning None (with a warning) on failure.

  Args:
    name: Fully qualified module name, e.g. "scipy.stats".

  Returns:
    The imported module, or None if importing it raised `ImportError`.
  """
  module = None
  try:
    module = importlib.import_module(name)
  except ImportError as e:
    # Hand the arguments to the logger instead of pre-formatting the string;
    # the emitted message is unchanged.
    tf_logging.warning("Could not import %s: %s", name, e)
  return module


# Optional reference implementation. Tests that compare against scipy return
# early (after their shape checks) when it is unavailable.
stats = try_import("scipy.stats")


class NormalTest(test.TestCase):
  """Tests shapes, densities, moments, sampling and KL of `Normal`."""

  def setUp(self):
    # Fixed seed so the randomly drawn loc/scale parameters are reproducible.
    self._rng = np.random.RandomState(123)

  def assertAllFinite(self, tensor):
    """Asserts that every element of the evaluated `tensor` is finite."""
    is_finite = np.isfinite(self.evaluate(tensor))
    all_true = np.ones_like(is_finite, dtype=np.bool_)
    self.assertAllEqual(all_true, is_finite)

  def _testParamShapes(self, sample_shape, expected):
    """Checks `Normal.param_shapes` and the shape of a sample built from it.

    Args:
      sample_shape: Desired sample shape (list or shape tensor).
      expected: Shape that both parameters and one sample must have.
    """
    param_shapes = normal_lib.Normal.param_shapes(sample_shape)
    mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"]
    self.assertAllEqual(expected, self.evaluate(mu_shape))
    self.assertAllEqual(expected, self.evaluate(sigma_shape))
    mu = array_ops.zeros(mu_shape)
    sigma = array_ops.ones(sigma_shape)
    self.assertAllEqual(
        expected,
        self.evaluate(array_ops.shape(normal_lib.Normal(mu, sigma).sample())))

  def _testParamStaticShapes(self, sample_shape, expected):
    """Checks that `Normal.param_static_shapes` gives `expected` for both."""
    param_shapes = normal_lib.Normal.param_static_shapes(sample_shape)
    mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"]
    self.assertEqual(expected, mu_shape)
    self.assertEqual(expected, sigma_shape)

  @test_util.run_in_graph_and_eager_modes
  def testSampleLikeArgsGetDistDType(self):
    dist = normal_lib.Normal(0., 1.)
    self.assertEqual(dtypes.float32, dist.dtype)
    # Passing the Python int `1` must still yield float32 results.
    for method in ("log_prob", "prob", "log_cdf", "cdf",
                   "log_survival_function", "survival_function", "quantile"):
      self.assertEqual(dtypes.float32, getattr(dist, method)(1).dtype)

  @test_util.run_in_graph_and_eager_modes
  def testParamShapes(self):
    sample_shape = [10, 3, 4]
    self._testParamShapes(sample_shape, sample_shape)
    self._testParamShapes(constant_op.constant(sample_shape), sample_shape)

  @test_util.run_in_graph_and_eager_modes
  def testParamStaticShapes(self):
    sample_shape = [10, 3, 4]
    self._testParamStaticShapes(sample_shape, sample_shape)
    self._testParamStaticShapes(
        tensor_shape.TensorShape(sample_shape), sample_shape)

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testNormalWithSoftplusScale(self):
    mu = array_ops.zeros((10, 3))
    rho = array_ops.ones((10, 3)) * -2.
    normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho)
    self.assertAllEqual(self.evaluate(mu), self.evaluate(normal.loc))
    # The raw `scale` argument is expected to come back passed through
    # softplus.
    self.assertAllEqual(
        self.evaluate(nn_ops.softplus(rho)), self.evaluate(normal.scale))

  @test_util.run_in_graph_and_eager_modes
  def testNormalLogPDF(self):
    batch_size = 6
    mu = constant_op.constant([3.0] * batch_size)
    sigma = constant_op.constant([math.sqrt(10.0)] * batch_size)
    x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
    normal = normal_lib.Normal(loc=mu, scale=sigma)

    log_pdf = normal.log_prob(x)
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(log_pdf).shape)
    self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)

    pdf = normal.prob(x)
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(pdf).shape)
    self.assertAllEqual(normal.batch_shape, pdf.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(pdf).shape)

    if not stats:
      return
    expected_log_pdf = stats.norm(self.evaluate(mu),
                                  self.evaluate(sigma)).logpdf(x)
    self.assertAllClose(expected_log_pdf, self.evaluate(log_pdf))
    self.assertAllClose(np.exp(expected_log_pdf), self.evaluate(pdf))

  @test_util.run_in_graph_and_eager_modes
  def testNormalLogPDFMultidimensional(self):
    batch_size = 6
    mu = constant_op.constant([[3.0, -3.0]] * batch_size)
    sigma = constant_op.constant(
        [[math.sqrt(10.0), math.sqrt(15.0)]] * batch_size)
    # Shape (6, 1) column; broadcasts against the (6, 2) parameters.
    x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
    normal = normal_lib.Normal(loc=mu, scale=sigma)

    log_pdf = normal.log_prob(x)
    log_pdf_values = self.evaluate(log_pdf)
    self.assertEqual(log_pdf.get_shape(), (6, 2))
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(log_pdf).shape)
    self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)

    pdf = normal.prob(x)
    pdf_values = self.evaluate(pdf)
    self.assertEqual(pdf.get_shape(), (6, 2))
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), pdf_values.shape)
    self.assertAllEqual(normal.batch_shape, pdf.get_shape())
    self.assertAllEqual(normal.batch_shape, pdf_values.shape)

    if not stats:
      return
    expected_log_pdf = stats.norm(self.evaluate(mu),
                                  self.evaluate(sigma)).logpdf(x)
    self.assertAllClose(expected_log_pdf, log_pdf_values)
    self.assertAllClose(np.exp(expected_log_pdf), pdf_values)

  @test_util.run_in_graph_and_eager_modes
  def testNormalCDF(self):
    batch_size = 50
    mu = self._rng.randn(batch_size)
    # Scales drawn from [1, 2) so they are strictly positive.
    sigma = self._rng.rand(batch_size) + 1.0
    x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)

    normal = normal_lib.Normal(loc=mu, scale=sigma)
    cdf = normal.cdf(x)
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(cdf).shape)
    self.assertAllEqual(normal.batch_shape, cdf.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)
    if not stats:
      return
    expected_cdf = stats.norm(mu, sigma).cdf(x)
    self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0)

  @test_util.run_in_graph_and_eager_modes
  def testNormalSurvivalFunction(self):
    batch_size = 50
    mu = self._rng.randn(batch_size)
    sigma = self._rng.rand(batch_size) + 1.0
    x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)

    normal = normal_lib.Normal(loc=mu, scale=sigma)

    sf = normal.survival_function(x)
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(sf).shape)
    self.assertAllEqual(normal.batch_shape, sf.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)
    if not stats:
      return
    expected_sf = stats.norm(mu, sigma).sf(x)
    self.assertAllClose(expected_sf, self.evaluate(sf), atol=0)

  @test_util.run_in_graph_and_eager_modes
  def testNormalLogCDF(self):
    batch_size = 50
    mu = self._rng.randn(batch_size)
    sigma = self._rng.rand(batch_size) + 1.0
    # Reaches far into the lower tail to exercise log_cdf there.
    x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64)

    normal = normal_lib.Normal(loc=mu, scale=sigma)

    cdf = normal.log_cdf(x)
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(cdf).shape)
    self.assertAllEqual(normal.batch_shape, cdf.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)

    if not stats:
      return
    expected_cdf = stats.norm(mu, sigma).logcdf(x)
    self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-3)

  def testFiniteGradientAtDifficultPoints(self):
    # Graph-only test: builds a fresh graph per dtype and uses symbolic
    # gradients with respect to the variable loc and scale.
    for dtype in [np.float32, np.float64]:
      g = ops.Graph()
      with g.as_default():
        mu = variables.Variable(dtype(0.0))
        sigma = variables.Variable(dtype(1.0))
        dist = normal_lib.Normal(loc=mu, scale=sigma)
        # With loc=0, scale=1 these points reach 100 standard deviations
        # into both tails.
        x = np.array([-100., -20., -5., 0., 5., 20., 100.]).astype(dtype)
        for func in [
            dist.cdf, dist.log_cdf, dist.survival_function,
            dist.log_survival_function, dist.log_prob, dist.prob
        ]:
          value = func(x)
          grads = gradients_impl.gradients(value, [mu, sigma])
          with self.session(graph=g):
            self.evaluate(variables.global_variables_initializer())
            self.assertAllFinite(value)
            self.assertAllFinite(grads[0])
            self.assertAllFinite(grads[1])

  @test_util.run_in_graph_and_eager_modes
  def testNormalLogSurvivalFunction(self):
    batch_size = 50
    mu = self._rng.randn(batch_size)
    sigma = self._rng.rand(batch_size) + 1.0
    # Reaches far into the upper tail to exercise log_survival_function there.
    x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64)

    normal = normal_lib.Normal(loc=mu, scale=sigma)

    sf = normal.log_survival_function(x)
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(sf).shape)
    self.assertAllEqual(normal.batch_shape, sf.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)

    if not stats:
      return
    expected_sf = stats.norm(mu, sigma).logsf(x)
    self.assertAllClose(expected_sf, self.evaluate(sf), atol=0, rtol=1e-5)

  @test_util.run_in_graph_and_eager_modes
  def testNormalEntropyWithScalarInputs(self):
    # Scipy.stats.norm cannot deal with the shapes in the other test.
    mu_v = 2.34
    sigma_v = 4.56
    normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)

    entropy = normal.entropy()
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(entropy).shape)
    self.assertAllEqual(normal.batch_shape, entropy.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)
    # Scalar parameters, so scipy can serve as the reference when available.
    if not stats:
      return
    expected_entropy = stats.norm(mu_v, sigma_v).entropy()
    self.assertAllClose(expected_entropy, self.evaluate(entropy))

  @test_util.run_in_graph_and_eager_modes
  def testNormalEntropy(self):
    mu_v = np.array([1.0, 1.0, 1.0])
    sigma_v = np.array([[1.0, 2.0, 3.0]]).T
    normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)

    # scipy.stats.norm cannot deal with these shapes.
    # Multiplying by the all-ones mu_v broadcasts sigma_v to the (3, 3) batch.
    sigma_broadcast = mu_v * sigma_v
    expected_entropy = 0.5 * np.log(2 * np.pi * np.exp(1) * sigma_broadcast**2)
    entropy = normal.entropy()
    np.testing.assert_allclose(expected_entropy, self.evaluate(entropy))
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(entropy).shape)
    self.assertAllEqual(normal.batch_shape, entropy.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testNormalMeanAndMode(self):
    # Mu will be broadcast to [7, 7, 7].
    mu = [7.]
    sigma = [11., 12., 13.]

    normal = normal_lib.Normal(loc=mu, scale=sigma)

    self.assertAllEqual((3,), normal.mean().get_shape())
    self.assertAllEqual([7., 7, 7], self.evaluate(normal.mean()))

    self.assertAllEqual((3,), normal.mode().get_shape())
    self.assertAllEqual([7., 7, 7], self.evaluate(normal.mode()))

  @test_util.run_in_graph_and_eager_modes
  def testNormalQuantile(self):
    batch_size = 52
    mu = self._rng.randn(batch_size)
    sigma = self._rng.rand(batch_size) + 1.0
    p = np.linspace(0., 1.0, batch_size - 2).astype(np.float64)
    # Quantile performs piecewise rational approximation so adding some
    # special input values to make sure we hit all the pieces.
    p = np.hstack((p, np.exp(-33), 1. - np.exp(-33)))

    normal = normal_lib.Normal(loc=mu, scale=sigma)
    x = normal.quantile(p)

    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), x.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(x).shape)
    self.assertAllEqual(normal.batch_shape, x.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(x).shape)

    if not stats:
      return
    expected_x = stats.norm(mu, sigma).ppf(p)
    self.assertAllClose(expected_x, self.evaluate(x), atol=0.)

  def _baseQuantileFiniteGradientAtDifficultPoints(self, dtype):
    """Checks quantile gradients w.r.t. loc and p are finite for `dtype`.

    The probabilities include the endpoints 0 and 1 and values within
    exp(-32) or exp(-2) of them.
    """
    g = ops.Graph()
    with g.as_default():
      mu = variables.Variable(dtype(0.0))
      sigma = variables.Variable(dtype(1.0))
      dist = normal_lib.Normal(loc=mu, scale=sigma)
      p = variables.Variable(
          np.array([0.,
                    np.exp(-32.), np.exp(-2.),
                    1. - np.exp(-2.), 1. - np.exp(-32.),
                    1.]).astype(dtype))

      value = dist.quantile(p)
      grads = gradients_impl.gradients(value, [mu, p])
      with self.cached_session(graph=g):
        self.evaluate(variables.global_variables_initializer())
        self.assertAllFinite(grads[0])
        self.assertAllFinite(grads[1])

  def testQuantileFiniteGradientAtDifficultPointsFloat32(self):
    self._baseQuantileFiniteGradientAtDifficultPoints(np.float32)

  def testQuantileFiniteGradientAtDifficultPointsFloat64(self):
    self._baseQuantileFiniteGradientAtDifficultPoints(np.float64)

  @test_util.run_in_graph_and_eager_modes
  def testNormalVariance(self):
    # sigma will be broadcast to [7, 7, 7]
    mu = [1., 2., 3.]
    sigma = [7.]

    normal = normal_lib.Normal(loc=mu, scale=sigma)

    self.assertAllEqual((3,), normal.variance().get_shape())
    self.assertAllEqual([49., 49, 49], self.evaluate(normal.variance()))

  @test_util.run_in_graph_and_eager_modes
  def testNormalStandardDeviation(self):
    # sigma will be broadcast to [7, 7, 7]
    mu = [1., 2., 3.]
    sigma = [7.]

    normal = normal_lib.Normal(loc=mu, scale=sigma)

    self.assertAllEqual((3,), normal.stddev().get_shape())
    self.assertAllEqual([7., 7, 7], self.evaluate(normal.stddev()))

  @test_util.run_in_graph_and_eager_modes
  def testNormalSample(self):
    mu = constant_op.constant(3.0)
    sigma = constant_op.constant(math.sqrt(3.0))
    mu_v = 3.0
    sigma_v = np.sqrt(3.0)
    n = constant_op.constant(100000)
    normal = normal_lib.Normal(loc=mu, scale=sigma)
    samples = normal.sample(n)
    sample_values = self.evaluate(samples)
    # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
    # The sample variance similarly is dependent on sigma and n.
    # Thus, the tolerances below are very sensitive to number of samples
    # as well as the variances chosen.
    self.assertEqual(sample_values.shape, (100000,))
    self.assertAllClose(sample_values.mean(), mu_v, atol=1e-1)
    self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1)

    # Expected shape from the dynamic batch shape...
    expected_samples_shape = tensor_shape.TensorShape(
        [self.evaluate(n)]).concatenate(
            tensor_shape.TensorShape(
                self.evaluate(normal.batch_shape_tensor())))

    self.assertAllEqual(expected_samples_shape, samples.get_shape())
    self.assertAllEqual(expected_samples_shape, sample_values.shape)

    # ...and from the static batch shape.
    expected_samples_shape = (
        tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
            normal.batch_shape))

    self.assertAllEqual(expected_samples_shape, samples.get_shape())
    self.assertAllEqual(expected_samples_shape, sample_values.shape)

  def testNormalFullyReparameterized(self):
    mu = constant_op.constant(4.0)
    sigma = constant_op.constant(3.0)
    with backprop.GradientTape() as tape:
      tape.watch(mu)
      tape.watch(sigma)
      normal = normal_lib.Normal(loc=mu, scale=sigma)
      samples = normal.sample(100)
    # Gradients flow from the samples back to both parameters.
    grad_mu, grad_sigma = tape.gradient(samples, [mu, sigma])
    self.assertIsNotNone(grad_mu)
    self.assertIsNotNone(grad_sigma)

  @test_util.run_in_graph_and_eager_modes
  def testNormalSampleMultiDimensional(self):
    batch_size = 2
    mu = constant_op.constant([[3.0, -3.0]] * batch_size)
    sigma = constant_op.constant(
        [[math.sqrt(2.0), math.sqrt(3.0)]] * batch_size)
    mu_v = [3.0, -3.0]
    sigma_v = [np.sqrt(2.0), np.sqrt(3.0)]
    n = constant_op.constant(100000)
    normal = normal_lib.Normal(loc=mu, scale=sigma)
    samples = normal.sample(n)
    sample_values = self.evaluate(samples)
    # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
    # The sample variance similarly is dependent on sigma and n.
    # Thus, the tolerances below are very sensitive to number of samples
    # as well as the variances chosen.
    self.assertEqual(samples.get_shape(), (100000, batch_size, 2))
    self.assertAllClose(sample_values[:, 0, 0].mean(), mu_v[0], atol=1e-1)
    self.assertAllClose(sample_values[:, 0, 0].std(), sigma_v[0], atol=1e-1)
    self.assertAllClose(sample_values[:, 0, 1].mean(), mu_v[1], atol=1e-1)
    self.assertAllClose(sample_values[:, 0, 1].std(), sigma_v[1], atol=1e-1)

    expected_samples_shape = tensor_shape.TensorShape(
        [self.evaluate(n)]).concatenate(
            tensor_shape.TensorShape(
                self.evaluate(normal.batch_shape_tensor())))
    self.assertAllEqual(expected_samples_shape, samples.get_shape())
    self.assertAllEqual(expected_samples_shape, sample_values.shape)

    expected_samples_shape = (
        tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
            normal.batch_shape))
    self.assertAllEqual(expected_samples_shape, samples.get_shape())
    self.assertAllEqual(expected_samples_shape, sample_values.shape)

  @test_util.run_in_graph_and_eager_modes
  def testNegativeSigmaFails(self):
    with self.assertRaisesOpError("Condition x > 0 did not hold"):
      normal = normal_lib.Normal(
          loc=[1.], scale=[-5.], validate_args=True, name="G")
      self.evaluate(normal.mean())

  @test_util.run_in_graph_and_eager_modes
  def testNormalShape(self):
    mu = constant_op.constant([-3.0] * 5)
    sigma = constant_op.constant(11.0)
    normal = normal_lib.Normal(loc=mu, scale=sigma)

    # assertAllEqual compares arrays element-wise (and by shape) instead of
    # relying on the truthiness of a one-element boolean ndarray.
    self.assertAllEqual(self.evaluate(normal.batch_shape_tensor()), [5])
    self.assertEqual(normal.batch_shape, tensor_shape.TensorShape([5]))
    self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), [])
    self.assertEqual(normal.event_shape, tensor_shape.TensorShape([]))

  @test_util.run_deprecated_v1
  def testNormalShapeWithPlaceholders(self):
    mu = array_ops.placeholder(dtype=dtypes.float32)
    sigma = array_ops.placeholder(dtype=dtypes.float32)
    normal = normal_lib.Normal(loc=mu, scale=sigma)

    with self.cached_session() as sess:
      # With unknown-shape placeholders the static batch_shape is unknown;
      # the dynamic batch_shape_tensor resolves once values are fed.
      self.assertEqual(normal.batch_shape, tensor_shape.TensorShape(None))
      self.assertEqual(normal.event_shape, ())
      self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), [])
      self.assertAllEqual(
          sess.run(normal.batch_shape_tensor(),
                   feed_dict={mu: 5.0,
                              sigma: [1.0, 2.0]}), [2])

  @test_util.run_in_graph_and_eager_modes
  def testNormalNormalKL(self):
    batch_size = 6
    mu_a = np.array([3.0] * batch_size)
    sigma_a = np.array([1.0, 2.0, 3.0, 1.5, 2.5, 3.5])
    mu_b = np.array([-3.0] * batch_size)
    sigma_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])

    n_a = normal_lib.Normal(loc=mu_a, scale=sigma_a)
    n_b = normal_lib.Normal(loc=mu_b, scale=sigma_b)

    kl = kullback_leibler.kl_divergence(n_a, n_b)
    kl_val = self.evaluate(kl)

    # Closed form of KL(N(mu_a, sigma_a^2) || N(mu_b, sigma_b^2)).
    kl_expected = ((mu_a - mu_b)**2 / (2 * sigma_b**2) + 0.5 * (
        (sigma_a**2 / sigma_b**2) - 1 - 2 * np.log(sigma_a / sigma_b)))

    self.assertEqual(kl.get_shape(), (batch_size,))
    self.assertAllClose(kl_val, kl_expected)


if __name__ == "__main__":
  test.main()