# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for initializers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import importlib
import math

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def try_import(name):  # pylint: disable=invalid-name
  module = None
  try:
    module = importlib.import_module(name)
  except ImportError as e:
    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
  return module

stats = try_import("scipy.stats")
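# Tests compare results against scipy.stats where it is available; when SciPy
# cannot be imported, they still exercise the TensorFlow code paths and skip
# only the comparison (via the early returns below).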


class NormalTest(test.TestCase):

  def setUp(self):
    self._rng = np.random.RandomState(123)

  def assertAllFinite(self, tensor):
    is_finite = np.isfinite(self.evaluate(tensor))
    # np.bool_ avoids the deprecated (and since-removed) np.bool alias.
    all_true = np.ones_like(is_finite, dtype=np.bool_)
    self.assertAllEqual(all_true, is_finite)

  def _testParamShapes(self, sample_shape, expected):
    param_shapes = normal_lib.Normal.param_shapes(sample_shape)
    mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"]
    self.assertAllEqual(expected, self.evaluate(mu_shape))
    self.assertAllEqual(expected, self.evaluate(sigma_shape))
    mu = array_ops.zeros(mu_shape)
    sigma = array_ops.ones(sigma_shape)
    self.assertAllEqual(
        expected,
        self.evaluate(array_ops.shape(normal_lib.Normal(mu, sigma).sample())))

  def _testParamStaticShapes(self, sample_shape, expected):
    param_shapes = normal_lib.Normal.param_static_shapes(sample_shape)
    mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"]
    self.assertEqual(expected, mu_shape)
    self.assertEqual(expected, sigma_shape)

  @test_util.run_in_graph_and_eager_modes
  def testSampleLikeArgsGetDistDType(self):
    dist = normal_lib.Normal(0., 1.)
    self.assertEqual(dtypes.float32, dist.dtype)
    for method in ("log_prob", "prob", "log_cdf", "cdf",
                   "log_survival_function", "survival_function", "quantile"):
      self.assertEqual(dtypes.float32, getattr(dist, method)(1).dtype)

  @test_util.run_in_graph_and_eager_modes
  def testParamShapes(self):
    sample_shape = [10, 3, 4]
    self._testParamShapes(sample_shape, sample_shape)
    self._testParamShapes(constant_op.constant(sample_shape), sample_shape)

  @test_util.run_in_graph_and_eager_modes
  def testParamStaticShapes(self):
    sample_shape = [10, 3, 4]
    self._testParamStaticShapes(sample_shape, sample_shape)
    self._testParamStaticShapes(
        tensor_shape.TensorShape(sample_shape), sample_shape)

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testNormalWithSoftplusScale(self):
    mu = array_ops.zeros((10, 3))
    rho = array_ops.ones((10, 3)) * -2.
    normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho)
    self.assertAllEqual(self.evaluate(mu), self.evaluate(normal.loc))
    self.assertAllEqual(
        self.evaluate(nn_ops.softplus(rho)), self.evaluate(normal.scale))
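    # softplus(x) = log(1 + exp(x)) maps any real rho to a positive scale;
    # here softplus(-2.) ~= 0.127, so a negative raw parameter still yields a
    # valid standard deviation.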

  @test_util.run_in_graph_and_eager_modes
  def testNormalLogPDF(self):
    batch_size = 6
    mu = constant_op.constant([3.0] * batch_size)
    sigma = constant_op.constant([math.sqrt(10.0)] * batch_size)
    x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
    normal = normal_lib.Normal(loc=mu, scale=sigma)

    log_pdf = normal.log_prob(x)
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(log_pdf).shape)
    self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)

    pdf = normal.prob(x)
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(pdf).shape)
    self.assertAllEqual(normal.batch_shape, pdf.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(pdf).shape)

    if not stats:
      return
    expected_log_pdf = stats.norm(self.evaluate(mu),
                                  self.evaluate(sigma)).logpdf(x)
    self.assertAllClose(expected_log_pdf, self.evaluate(log_pdf))
    self.assertAllClose(np.exp(expected_log_pdf), self.evaluate(pdf))
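    # The closed form being verified:
    #   log p(x) = -(x - mu)**2 / (2 * sigma**2)
    #              - log(sigma) - 0.5 * log(2 * pi)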

  @test_util.run_in_graph_and_eager_modes
  def testNormalLogPDFMultidimensional(self):
    batch_size = 6
    mu = constant_op.constant([[3.0, -3.0]] * batch_size)
    sigma = constant_op.constant(
        [[math.sqrt(10.0), math.sqrt(15.0)]] * batch_size)
    x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
    normal = normal_lib.Normal(loc=mu, scale=sigma)

    log_pdf = normal.log_prob(x)
    log_pdf_values = self.evaluate(log_pdf)
    self.assertEqual(log_pdf.get_shape(), (6, 2))
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(log_pdf).shape)
    self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)

    pdf = normal.prob(x)
    pdf_values = self.evaluate(pdf)
    self.assertEqual(pdf.get_shape(), (6, 2))
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), pdf_values.shape)
    self.assertAllEqual(normal.batch_shape, pdf.get_shape())
    self.assertAllEqual(normal.batch_shape, pdf_values.shape)

    if not stats:
      return
    expected_log_pdf = stats.norm(self.evaluate(mu),
                                  self.evaluate(sigma)).logpdf(x)
    self.assertAllClose(expected_log_pdf, log_pdf_values)
    self.assertAllClose(np.exp(expected_log_pdf), pdf_values)

  @test_util.run_in_graph_and_eager_modes
  def testNormalCDF(self):
    batch_size = 50
    mu = self._rng.randn(batch_size)
    sigma = self._rng.rand(batch_size) + 1.0
    x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)

    normal = normal_lib.Normal(loc=mu, scale=sigma)
    cdf = normal.cdf(x)
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(cdf).shape)
    self.assertAllEqual(normal.batch_shape, cdf.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)
    if not stats:
      return
    expected_cdf = stats.norm(mu, sigma).cdf(x)
    self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0)

  @test_util.run_in_graph_and_eager_modes
  def testNormalSurvivalFunction(self):
    batch_size = 50
    mu = self._rng.randn(batch_size)
    sigma = self._rng.rand(batch_size) + 1.0
    x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)

    normal = normal_lib.Normal(loc=mu, scale=sigma)

    sf = normal.survival_function(x)
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(sf).shape)
    self.assertAllEqual(normal.batch_shape, sf.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)
    if not stats:
      return
    expected_sf = stats.norm(mu, sigma).sf(x)
    self.assertAllClose(expected_sf, self.evaluate(sf), atol=0)

  @test_util.run_in_graph_and_eager_modes
  def testNormalLogCDF(self):
    batch_size = 50
    mu = self._rng.randn(batch_size)
    sigma = self._rng.rand(batch_size) + 1.0
    x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64)

    normal = normal_lib.Normal(loc=mu, scale=sigma)

    cdf = normal.log_cdf(x)
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(cdf).shape)
    self.assertAllEqual(normal.batch_shape, cdf.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)

    if not stats:
      return
    expected_cdf = stats.norm(mu, sigma).logcdf(x)
    self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-3)

  def testFiniteGradientAtDifficultPoints(self):
    for dtype in [np.float32, np.float64]:
      g = ops.Graph()
      with g.as_default():
        mu = variables.Variable(dtype(0.0))
        sigma = variables.Variable(dtype(1.0))
        dist = normal_lib.Normal(loc=mu, scale=sigma)
        x = np.array([-100., -20., -5., 0., 5., 20., 100.]).astype(dtype)
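        # These points are difficult because the CDF underflows to 0 (and the
        # survival function saturates at 1) deep in the tails; a naive
        # log(cdf(x)) there would produce -inf values and non-finite
        # gradients, so stable log-space formulations are required.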
        for func in [
            dist.cdf, dist.log_cdf, dist.survival_function,
            dist.log_survival_function, dist.log_prob, dist.prob
        ]:
          value = func(x)
          grads = gradients_impl.gradients(value, [mu, sigma])
          with self.session(graph=g):
            variables.global_variables_initializer().run()
            self.assertAllFinite(value)
            self.assertAllFinite(grads[0])
            self.assertAllFinite(grads[1])

  @test_util.run_in_graph_and_eager_modes
  def testNormalLogSurvivalFunction(self):
    batch_size = 50
    mu = self._rng.randn(batch_size)
    sigma = self._rng.rand(batch_size) + 1.0
    x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64)

    normal = normal_lib.Normal(loc=mu, scale=sigma)

    sf = normal.log_survival_function(x)
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(sf).shape)
    self.assertAllEqual(normal.batch_shape, sf.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)

    if not stats:
      return
    expected_sf = stats.norm(mu, sigma).logsf(x)
    self.assertAllClose(expected_sf, self.evaluate(sf), atol=0, rtol=1e-5)

  @test_util.run_in_graph_and_eager_modes
  def testNormalEntropyWithScalarInputs(self):
    # scipy.stats.norm cannot deal with the shapes in the other entropy test.
    mu_v = 2.34
    sigma_v = 4.56
    normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)

    entropy = normal.entropy()
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(entropy).shape)
    self.assertAllEqual(normal.batch_shape, entropy.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)
    # Compare against scipy.stats.norm only when SciPy is available.
    if not stats:
      return
    expected_entropy = stats.norm(mu_v, sigma_v).entropy()
    self.assertAllClose(expected_entropy, self.evaluate(entropy))

  @test_util.run_in_graph_and_eager_modes
  def testNormalEntropy(self):
    mu_v = np.array([1.0, 1.0, 1.0])
    sigma_v = np.array([[1.0, 2.0, 3.0]]).T
    normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)

    # scipy.stats.norm cannot deal with these shapes.
    sigma_broadcast = mu_v * sigma_v
    expected_entropy = 0.5 * np.log(2 * np.pi * np.exp(1) * sigma_broadcast**2)
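    # This is the closed-form differential entropy of a Normal:
    #   H = 0.5 * log(2 * pi * e * sigma**2)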
    entropy = normal.entropy()
    np.testing.assert_allclose(expected_entropy, self.evaluate(entropy))
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(entropy).shape)
    self.assertAllEqual(normal.batch_shape, entropy.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testNormalMeanAndMode(self):
    # Mu will be broadcast to [7, 7, 7].
    mu = [7.]
    sigma = [11., 12., 13.]

    normal = normal_lib.Normal(loc=mu, scale=sigma)

    self.assertAllEqual((3,), normal.mean().get_shape())
    self.assertAllEqual([7., 7, 7], self.evaluate(normal.mean()))

    self.assertAllEqual((3,), normal.mode().get_shape())
    self.assertAllEqual([7., 7, 7], self.evaluate(normal.mode()))

  @test_util.run_in_graph_and_eager_modes
  def testNormalQuantile(self):
    batch_size = 52
    mu = self._rng.randn(batch_size)
    sigma = self._rng.rand(batch_size) + 1.0
    p = np.linspace(0., 1.0, batch_size - 2).astype(np.float64)
    # Quantile performs piecewise rational approximation, so add some special
    # input values to make sure we hit all the pieces.
    p = np.hstack((p, np.exp(-33), 1. - np.exp(-33)))
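    # np.exp(-33) ~= 4.6e-15, so the appended points probe p essentially at 0
    # and 1, hitting the extreme-tail pieces of the approximation.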

    normal = normal_lib.Normal(loc=mu, scale=sigma)
    x = normal.quantile(p)

    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), x.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(x).shape)
    self.assertAllEqual(normal.batch_shape, x.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(x).shape)

    if not stats:
      return
    expected_x = stats.norm(mu, sigma).ppf(p)
    self.assertAllClose(expected_x, self.evaluate(x), atol=0.)

  def _baseQuantileFiniteGradientAtDifficultPoints(self, dtype):
    g = ops.Graph()
    with g.as_default():
      mu = variables.Variable(dtype(0.0))
      sigma = variables.Variable(dtype(1.0))
      dist = normal_lib.Normal(loc=mu, scale=sigma)
      p = variables.Variable(
          np.array([0.,
                    np.exp(-32.), np.exp(-2.),
                    1. - np.exp(-2.), 1. - np.exp(-32.),
                    1.]).astype(dtype))

      value = dist.quantile(p)
      grads = gradients_impl.gradients(value, [mu, p])
      with self.cached_session(graph=g):
        variables.global_variables_initializer().run()
        self.assertAllFinite(grads[0])
        self.assertAllFinite(grads[1])

  def testQuantileFiniteGradientAtDifficultPointsFloat32(self):
    self._baseQuantileFiniteGradientAtDifficultPoints(np.float32)

  def testQuantileFiniteGradientAtDifficultPointsFloat64(self):
    self._baseQuantileFiniteGradientAtDifficultPoints(np.float64)

  @test_util.run_in_graph_and_eager_modes
  def testNormalVariance(self):
    # sigma will be broadcast to [7, 7, 7].
    mu = [1., 2., 3.]
    sigma = [7.]

    normal = normal_lib.Normal(loc=mu, scale=sigma)

    self.assertAllEqual((3,), normal.variance().get_shape())
    self.assertAllEqual([49., 49, 49], self.evaluate(normal.variance()))

  @test_util.run_in_graph_and_eager_modes
  def testNormalStandardDeviation(self):
    # sigma will be broadcast to [7, 7, 7].
    mu = [1., 2., 3.]
    sigma = [7.]

    normal = normal_lib.Normal(loc=mu, scale=sigma)

    self.assertAllEqual((3,), normal.stddev().get_shape())
    self.assertAllEqual([7., 7, 7], self.evaluate(normal.stddev()))

  @test_util.run_in_graph_and_eager_modes
  def testNormalSample(self):
    mu = constant_op.constant(3.0)
    sigma = constant_op.constant(math.sqrt(3.0))
    mu_v = 3.0
    sigma_v = np.sqrt(3.0)
    n = constant_op.constant(100000)
    normal = normal_lib.Normal(loc=mu, scale=sigma)
    samples = normal.sample(n)
    sample_values = self.evaluate(samples)
    # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
    # The sample variance similarly depends on sigma and n.
    # Thus, the tolerances below are very sensitive to the number of samples
    # as well as the variances chosen.
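    # Concretely, sigma / sqrt(n) = sqrt(3) / sqrt(100000) ~= 0.0055, so
    # atol=1e-1 leaves roughly 18 standard errors of slack.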
    self.assertEqual(sample_values.shape, (100000,))
    self.assertAllClose(sample_values.mean(), mu_v, atol=1e-1)
    self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1)

    expected_samples_shape = tensor_shape.TensorShape(
        [self.evaluate(n)]).concatenate(
            tensor_shape.TensorShape(
                self.evaluate(normal.batch_shape_tensor())))

    self.assertAllEqual(expected_samples_shape, samples.get_shape())
    self.assertAllEqual(expected_samples_shape, sample_values.shape)

    expected_samples_shape = (
        tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
            normal.batch_shape))

    self.assertAllEqual(expected_samples_shape, samples.get_shape())
    self.assertAllEqual(expected_samples_shape, sample_values.shape)

  def testNormalFullyReparameterized(self):
    mu = constant_op.constant(4.0)
    sigma = constant_op.constant(3.0)
    with backprop.GradientTape() as tape:
      tape.watch(mu)
      tape.watch(sigma)
      normal = normal_lib.Normal(loc=mu, scale=sigma)
      samples = normal.sample(100)
    grad_mu, grad_sigma = tape.gradient(samples, [mu, sigma])
    self.assertIsNotNone(grad_mu)
    self.assertIsNotNone(grad_sigma)
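    # Sampling is reparameterized as mu + sigma * eps with eps ~ N(0, 1)
    # drawn independently of the parameters, which is why the sample tensor
    # has well-defined (non-None) gradients with respect to mu and sigma.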

  @test_util.run_in_graph_and_eager_modes
  def testNormalSampleMultiDimensional(self):
    batch_size = 2
    mu = constant_op.constant([[3.0, -3.0]] * batch_size)
    sigma = constant_op.constant(
        [[math.sqrt(2.0), math.sqrt(3.0)]] * batch_size)
    mu_v = [3.0, -3.0]
    sigma_v = [np.sqrt(2.0), np.sqrt(3.0)]
    n = constant_op.constant(100000)
    normal = normal_lib.Normal(loc=mu, scale=sigma)
    samples = normal.sample(n)
    sample_values = self.evaluate(samples)
    # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
    # The sample variance similarly depends on sigma and n.
    # Thus, the tolerances below are very sensitive to the number of samples
    # as well as the variances chosen.
    self.assertEqual(samples.get_shape(), (100000, batch_size, 2))
    self.assertAllClose(sample_values[:, 0, 0].mean(), mu_v[0], atol=1e-1)
    self.assertAllClose(sample_values[:, 0, 0].std(), sigma_v[0], atol=1e-1)
    self.assertAllClose(sample_values[:, 0, 1].mean(), mu_v[1], atol=1e-1)
    self.assertAllClose(sample_values[:, 0, 1].std(), sigma_v[1], atol=1e-1)

    expected_samples_shape = tensor_shape.TensorShape(
        [self.evaluate(n)]).concatenate(
            tensor_shape.TensorShape(
                self.evaluate(normal.batch_shape_tensor())))
    self.assertAllEqual(expected_samples_shape, samples.get_shape())
    self.assertAllEqual(expected_samples_shape, sample_values.shape)

    expected_samples_shape = (
        tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
            normal.batch_shape))
    self.assertAllEqual(expected_samples_shape, samples.get_shape())
    self.assertAllEqual(expected_samples_shape, sample_values.shape)

  @test_util.run_in_graph_and_eager_modes
  def testNegativeSigmaFails(self):
    with self.assertRaisesOpError("Condition x > 0 did not hold"):
      normal = normal_lib.Normal(
          loc=[1.], scale=[-5.], validate_args=True, name="G")
      self.evaluate(normal.mean())
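    # validate_args=True attaches a runtime check that scale > 0; the error
    # surfaces once an op depending on the invalid scale is evaluated.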

  @test_util.run_in_graph_and_eager_modes
  def testNormalShape(self):
    mu = constant_op.constant([-3.0] * 5)
    sigma = constant_op.constant(11.0)
    normal = normal_lib.Normal(loc=mu, scale=sigma)

    self.assertEqual(self.evaluate(normal.batch_shape_tensor()), [5])
    self.assertEqual(normal.batch_shape, tensor_shape.TensorShape([5]))
    self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), [])
    self.assertEqual(normal.event_shape, tensor_shape.TensorShape([]))

  @test_util.run_deprecated_v1
  def testNormalShapeWithPlaceholders(self):
    mu = array_ops.placeholder(dtype=dtypes.float32)
    sigma = array_ops.placeholder(dtype=dtypes.float32)
    normal = normal_lib.Normal(loc=mu, scale=sigma)

    with self.cached_session() as sess:
      # batch_shape should be unknown at graph-construction time.
      self.assertEqual(normal.batch_shape, tensor_shape.TensorShape(None))
      self.assertEqual(normal.event_shape, ())
      self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), [])
      self.assertAllEqual(
          sess.run(normal.batch_shape_tensor(),
                   feed_dict={mu: 5.0,
                              sigma: [1.0, 2.0]}), [2])

  @test_util.run_in_graph_and_eager_modes
  def testNormalNormalKL(self):
    batch_size = 6
    mu_a = np.array([3.0] * batch_size)
    sigma_a = np.array([1.0, 2.0, 3.0, 1.5, 2.5, 3.5])
    mu_b = np.array([-3.0] * batch_size)
    sigma_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])

    n_a = normal_lib.Normal(loc=mu_a, scale=sigma_a)
    n_b = normal_lib.Normal(loc=mu_b, scale=sigma_b)

    kl = kullback_leibler.kl_divergence(n_a, n_b)
    kl_val = self.evaluate(kl)

    kl_expected = ((mu_a - mu_b)**2 / (2 * sigma_b**2) + 0.5 * (
        (sigma_a**2 / sigma_b**2) - 1 - 2 * np.log(sigma_a / sigma_b)))
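    # Equivalent closed form for the KL divergence between two normals:
    #   KL(a || b) = log(sigma_b / sigma_a)
    #                + (sigma_a**2 + (mu_a - mu_b)**2) / (2 * sigma_b**2) - 0.5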

    self.assertEqual(kl.get_shape(), (batch_size,))
    self.assertAllClose(kl_val, kl_expected)


if __name__ == "__main__":
  test.main()