# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for initializers."""
16
17import importlib
18import math
19
20import numpy as np
21
22from tensorflow.python.eager import backprop
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.framework import test_util
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import gradients_impl
30from tensorflow.python.ops import nn_ops
31from tensorflow.python.ops import variables
32from tensorflow.python.ops.distributions import kullback_leibler
33from tensorflow.python.ops.distributions import normal as normal_lib
34from tensorflow.python.platform import test
35from tensorflow.python.platform import tf_logging
36
37
38def try_import(name):  # pylint: disable=invalid-name
39  module = None
40  try:
41    module = importlib.import_module(name)
42  except ImportError as e:
43    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
44  return module
45
46stats = try_import("scipy.stats")
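# scipy.stats is an optional dependency: the comparisons against it below run
# only if the import above succeeded; otherwise the scipy-based checks are
# skipped.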


class NormalTest(test.TestCase):

  def setUp(self):
    self._rng = np.random.RandomState(123)

  def assertAllFinite(self, tensor):
    is_finite = np.isfinite(self.evaluate(tensor))
    all_true = np.ones_like(is_finite, dtype=np.bool_)
    self.assertAllEqual(all_true, is_finite)

  def _testParamShapes(self, sample_shape, expected):
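    # param_shapes maps a desired sample shape to the (dynamic) shapes that the
    # loc and scale parameters must have for sample() to return that shape.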
    param_shapes = normal_lib.Normal.param_shapes(sample_shape)
    mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"]
    self.assertAllEqual(expected, self.evaluate(mu_shape))
    self.assertAllEqual(expected, self.evaluate(sigma_shape))
    mu = array_ops.zeros(mu_shape)
    sigma = array_ops.ones(sigma_shape)
    self.assertAllEqual(
        expected,
        self.evaluate(array_ops.shape(normal_lib.Normal(mu, sigma).sample())))

  def _testParamStaticShapes(self, sample_shape, expected):
    param_shapes = normal_lib.Normal.param_static_shapes(sample_shape)
    mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"]
    self.assertEqual(expected, mu_shape)
    self.assertEqual(expected, sigma_shape)

  @test_util.run_in_graph_and_eager_modes
  def testSampleLikeArgsGetDistDType(self):
    dist = normal_lib.Normal(0., 1.)
    self.assertEqual(dtypes.float32, dist.dtype)
    for method in ("log_prob", "prob", "log_cdf", "cdf",
                   "log_survival_function", "survival_function", "quantile"):
      self.assertEqual(dtypes.float32, getattr(dist, method)(1).dtype)

  @test_util.run_in_graph_and_eager_modes
  def testParamShapes(self):
    sample_shape = [10, 3, 4]
    self._testParamShapes(sample_shape, sample_shape)
    self._testParamShapes(constant_op.constant(sample_shape), sample_shape)

  @test_util.run_in_graph_and_eager_modes
  def testParamStaticShapes(self):
    sample_shape = [10, 3, 4]
    self._testParamStaticShapes(sample_shape, sample_shape)
    self._testParamStaticShapes(
        tensor_shape.TensorShape(sample_shape), sample_shape)

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testNormalWithSoftplusScale(self):
    mu = array_ops.zeros((10, 3))
    rho = array_ops.ones((10, 3)) * -2.
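    # softplus(rho) = log(1 + exp(rho)) is strictly positive, so the negative
    # rho above still yields a valid (positive) scale.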
    normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho)
    self.assertAllEqual(self.evaluate(mu), self.evaluate(normal.loc))
    self.assertAllEqual(
        self.evaluate(nn_ops.softplus(rho)), self.evaluate(normal.scale))

  @test_util.run_in_graph_and_eager_modes
  def testNormalLogPDF(self):
    batch_size = 6
    mu = constant_op.constant([3.0] * batch_size)
    sigma = constant_op.constant([math.sqrt(10.0)] * batch_size)
    x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
    normal = normal_lib.Normal(loc=mu, scale=sigma)

    log_pdf = normal.log_prob(x)
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(log_pdf).shape)
    self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)

    pdf = normal.prob(x)
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(pdf).shape)
    self.assertAllEqual(normal.batch_shape, pdf.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(pdf).shape)

    if not stats:
      return
    expected_log_pdf = stats.norm(self.evaluate(mu),
                                  self.evaluate(sigma)).logpdf(x)
    self.assertAllClose(expected_log_pdf, self.evaluate(log_pdf))
    self.assertAllClose(np.exp(expected_log_pdf), self.evaluate(pdf))

  @test_util.run_in_graph_and_eager_modes
  def testNormalLogPDFMultidimensional(self):
    batch_size = 6
    mu = constant_op.constant([[3.0, -3.0]] * batch_size)
    sigma = constant_op.constant(
        [[math.sqrt(10.0), math.sqrt(15.0)]] * batch_size)
    x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
    normal = normal_lib.Normal(loc=mu, scale=sigma)

    log_pdf = normal.log_prob(x)
    log_pdf_values = self.evaluate(log_pdf)
    self.assertEqual(log_pdf.get_shape(), (6, 2))
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(log_pdf).shape)
    self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)

    pdf = normal.prob(x)
    pdf_values = self.evaluate(pdf)
    self.assertEqual(pdf.get_shape(), (6, 2))
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), pdf_values.shape)
    self.assertAllEqual(normal.batch_shape, pdf.get_shape())
    self.assertAllEqual(normal.batch_shape, pdf_values.shape)

    if not stats:
      return
    expected_log_pdf = stats.norm(self.evaluate(mu),
                                  self.evaluate(sigma)).logpdf(x)
    self.assertAllClose(expected_log_pdf, log_pdf_values)
    self.assertAllClose(np.exp(expected_log_pdf), pdf_values)

  @test_util.run_in_graph_and_eager_modes
  def testNormalCDF(self):
    batch_size = 50
    mu = self._rng.randn(batch_size)
    sigma = self._rng.rand(batch_size) + 1.0
    x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)

    normal = normal_lib.Normal(loc=mu, scale=sigma)
    cdf = normal.cdf(x)
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(cdf).shape)
    self.assertAllEqual(normal.batch_shape, cdf.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)
    if not stats:
      return
    expected_cdf = stats.norm(mu, sigma).cdf(x)
    self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0)

  @test_util.run_in_graph_and_eager_modes
  def testNormalSurvivalFunction(self):
    batch_size = 50
    mu = self._rng.randn(batch_size)
    sigma = self._rng.rand(batch_size) + 1.0
    x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)

    normal = normal_lib.Normal(loc=mu, scale=sigma)

    sf = normal.survival_function(x)
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(sf).shape)
    self.assertAllEqual(normal.batch_shape, sf.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)
    if not stats:
      return
    expected_sf = stats.norm(mu, sigma).sf(x)
    self.assertAllClose(expected_sf, self.evaluate(sf), atol=0)

  @test_util.run_in_graph_and_eager_modes
  def testNormalLogCDF(self):
    batch_size = 50
    mu = self._rng.randn(batch_size)
    sigma = self._rng.rand(batch_size) + 1.0
    x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64)

    normal = normal_lib.Normal(loc=mu, scale=sigma)

    cdf = normal.log_cdf(x)
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(cdf).shape)
    self.assertAllEqual(normal.batch_shape, cdf.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)

    if not stats:
      return
    expected_cdf = stats.norm(mu, sigma).logcdf(x)
    self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-3)

  def testFiniteGradientAtDifficultPoints(self):
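    # x covers the far tails (down to -100, up to +100) where cdf and
    # survival_function saturate; check that the values and their gradients
    # with respect to loc and scale stay finite there.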
    for dtype in [np.float32, np.float64]:
      g = ops.Graph()
      with g.as_default():
        mu = variables.Variable(dtype(0.0))
        sigma = variables.Variable(dtype(1.0))
        dist = normal_lib.Normal(loc=mu, scale=sigma)
        x = np.array([-100., -20., -5., 0., 5., 20., 100.]).astype(dtype)
        for func in [
            dist.cdf, dist.log_cdf, dist.survival_function,
            dist.log_survival_function, dist.log_prob, dist.prob
        ]:
          value = func(x)
          grads = gradients_impl.gradients(value, [mu, sigma])
          with self.session(graph=g):
            self.evaluate(variables.global_variables_initializer())
            self.assertAllFinite(value)
            self.assertAllFinite(grads[0])
            self.assertAllFinite(grads[1])

  @test_util.run_in_graph_and_eager_modes
  def testNormalLogSurvivalFunction(self):
    batch_size = 50
    mu = self._rng.randn(batch_size)
    sigma = self._rng.rand(batch_size) + 1.0
    x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64)

    normal = normal_lib.Normal(loc=mu, scale=sigma)

    sf = normal.log_survival_function(x)
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(sf).shape)
    self.assertAllEqual(normal.batch_shape, sf.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)

    if not stats:
      return
    expected_sf = stats.norm(mu, sigma).logsf(x)
    self.assertAllClose(expected_sf, self.evaluate(sf), atol=0, rtol=1e-5)

  @test_util.run_in_graph_and_eager_modes
  def testNormalEntropyWithScalarInputs(self):
    # scipy.stats.norm cannot handle the broadcast shapes used in
    # testNormalEntropy, so the scipy comparison is done here on scalar inputs.
    mu_v = 2.34
    sigma_v = 4.56
    normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)

    entropy = normal.entropy()
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(entropy).shape)
    self.assertAllEqual(normal.batch_shape, entropy.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)
    if not stats:
      return
    expected_entropy = stats.norm(mu_v, sigma_v).entropy()
    self.assertAllClose(expected_entropy, self.evaluate(entropy))

  @test_util.run_in_graph_and_eager_modes
  def testNormalEntropy(self):
    mu_v = np.array([1.0, 1.0, 1.0])
    sigma_v = np.array([[1.0, 2.0, 3.0]]).T
    normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)

    # scipy.stats.norm cannot deal with these shapes.
    sigma_broadcast = mu_v * sigma_v
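    # mu_v has shape (3,) and sigma_v has shape (3, 1), so mu_v * sigma_v
    # broadcasts the scale to the distribution's (3, 3) batch shape; the
    # entropy 0.5 * log(2 * pi * e * sigma^2) depends only on sigma.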
    expected_entropy = 0.5 * np.log(2 * np.pi * np.exp(1) * sigma_broadcast**2)
    entropy = normal.entropy()
    np.testing.assert_allclose(expected_entropy, self.evaluate(entropy))
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(entropy).shape)
    self.assertAllEqual(normal.batch_shape, entropy.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testNormalMeanAndMode(self):
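    # For a normal distribution the mean and the mode both equal loc.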
    # Mu will be broadcast to [7, 7, 7].
    mu = [7.]
    sigma = [11., 12., 13.]

    normal = normal_lib.Normal(loc=mu, scale=sigma)

    self.assertAllEqual((3,), normal.mean().get_shape())
    self.assertAllEqual([7., 7, 7], self.evaluate(normal.mean()))

    self.assertAllEqual((3,), normal.mode().get_shape())
    self.assertAllEqual([7., 7, 7], self.evaluate(normal.mode()))

  @test_util.run_in_graph_and_eager_modes
  def testNormalQuantile(self):
    batch_size = 52
    mu = self._rng.randn(batch_size)
    sigma = self._rng.rand(batch_size) + 1.0
    p = np.linspace(0., 1.0, batch_size - 2).astype(np.float64)
    # Quantile uses a piecewise rational approximation, so add a few special
    # input values to make sure we hit all of the pieces.
    p = np.hstack((p, np.exp(-33), 1. - np.exp(-33)))

    normal = normal_lib.Normal(loc=mu, scale=sigma)
    x = normal.quantile(p)

    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()), x.get_shape())
    self.assertAllEqual(
        self.evaluate(normal.batch_shape_tensor()),
        self.evaluate(x).shape)
    self.assertAllEqual(normal.batch_shape, x.get_shape())
    self.assertAllEqual(normal.batch_shape, self.evaluate(x).shape)

    if not stats:
      return
    expected_x = stats.norm(mu, sigma).ppf(p)
    self.assertAllClose(expected_x, self.evaluate(x), atol=0.)

  def _baseQuantileFiniteGradientAtDifficultPoints(self, dtype):
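    # p includes exactly 0 and 1 and values extremely close to both tails,
    # where a numerically careless quantile (or its gradient) can produce
    # inf or NaN.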
    g = ops.Graph()
    with g.as_default():
      mu = variables.Variable(dtype(0.0))
      sigma = variables.Variable(dtype(1.0))
      dist = normal_lib.Normal(loc=mu, scale=sigma)
      p = variables.Variable(
          np.array([0.,
                    np.exp(-32.), np.exp(-2.),
                    1. - np.exp(-2.), 1. - np.exp(-32.),
                    1.]).astype(dtype))

      value = dist.quantile(p)
      grads = gradients_impl.gradients(value, [mu, p])
      with self.cached_session(graph=g):
        self.evaluate(variables.global_variables_initializer())
        self.assertAllFinite(grads[0])
        self.assertAllFinite(grads[1])

  def testQuantileFiniteGradientAtDifficultPointsFloat32(self):
    self._baseQuantileFiniteGradientAtDifficultPoints(np.float32)

  def testQuantileFiniteGradientAtDifficultPointsFloat64(self):
    self._baseQuantileFiniteGradientAtDifficultPoints(np.float64)

  @test_util.run_in_graph_and_eager_modes
  def testNormalVariance(self):
    # sigma will be broadcast to [7, 7, 7]
    mu = [1., 2., 3.]
    sigma = [7.]

    normal = normal_lib.Normal(loc=mu, scale=sigma)

    self.assertAllEqual((3,), normal.variance().get_shape())
    self.assertAllEqual([49., 49, 49], self.evaluate(normal.variance()))

  @test_util.run_in_graph_and_eager_modes
  def testNormalStandardDeviation(self):
    # sigma will be broadcast to [7, 7, 7]
    mu = [1., 2., 3.]
    sigma = [7.]

    normal = normal_lib.Normal(loc=mu, scale=sigma)

    self.assertAllEqual((3,), normal.stddev().get_shape())
    self.assertAllEqual([7., 7, 7], self.evaluate(normal.stddev()))

  @test_util.run_in_graph_and_eager_modes
  def testNormalSample(self):
    mu = constant_op.constant(3.0)
    sigma = constant_op.constant(math.sqrt(3.0))
    mu_v = 3.0
    sigma_v = np.sqrt(3.0)
    n = constant_op.constant(100000)
    normal = normal_lib.Normal(loc=mu, scale=sigma)
    samples = normal.sample(n)
    sample_values = self.evaluate(samples)
    # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
    # The sample variance similarly depends on sigma and n.
    # Thus, the tolerances below are very sensitive to the number of samples
    # as well as to the variances chosen.
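    # With n = 100000 and sigma = sqrt(3) ~ 1.73, the standard error of the
    # sample mean is roughly 1.73 / sqrt(100000) ~ 0.005, comfortably within
    # atol=1e-1.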
    self.assertEqual(sample_values.shape, (100000,))
    self.assertAllClose(sample_values.mean(), mu_v, atol=1e-1)
    self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1)

    expected_samples_shape = tensor_shape.TensorShape(
        [self.evaluate(n)]).concatenate(
            tensor_shape.TensorShape(
                self.evaluate(normal.batch_shape_tensor())))

    self.assertAllEqual(expected_samples_shape, samples.get_shape())
    self.assertAllEqual(expected_samples_shape, sample_values.shape)

    expected_samples_shape = (
        tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
            normal.batch_shape))

    self.assertAllEqual(expected_samples_shape, samples.get_shape())
    self.assertAllEqual(expected_samples_shape, sample_values.shape)

  def testNormalFullyReparameterized(self):
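    # Normal.sample draws eps ~ N(0, 1) and returns loc + scale * eps (the
    # reparameterization trick), so gradients of the samples with respect to
    # loc and scale should exist, i.e. not be None.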
    mu = constant_op.constant(4.0)
    sigma = constant_op.constant(3.0)
    with backprop.GradientTape() as tape:
      tape.watch(mu)
      tape.watch(sigma)
      normal = normal_lib.Normal(loc=mu, scale=sigma)
      samples = normal.sample(100)
    grad_mu, grad_sigma = tape.gradient(samples, [mu, sigma])
    self.assertIsNotNone(grad_mu)
    self.assertIsNotNone(grad_sigma)

  @test_util.run_in_graph_and_eager_modes
  def testNormalSampleMultiDimensional(self):
    batch_size = 2
    mu = constant_op.constant([[3.0, -3.0]] * batch_size)
    sigma = constant_op.constant(
        [[math.sqrt(2.0), math.sqrt(3.0)]] * batch_size)
    mu_v = [3.0, -3.0]
    sigma_v = [np.sqrt(2.0), np.sqrt(3.0)]
    n = constant_op.constant(100000)
    normal = normal_lib.Normal(loc=mu, scale=sigma)
    samples = normal.sample(n)
    sample_values = self.evaluate(samples)
    # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
    # The sample variance similarly depends on sigma and n.
    # Thus, the tolerances below are very sensitive to the number of samples
    # as well as to the variances chosen.
    self.assertEqual(samples.get_shape(), (100000, batch_size, 2))
    self.assertAllClose(sample_values[:, 0, 0].mean(), mu_v[0], atol=1e-1)
    self.assertAllClose(sample_values[:, 0, 0].std(), sigma_v[0], atol=1e-1)
    self.assertAllClose(sample_values[:, 0, 1].mean(), mu_v[1], atol=1e-1)
    self.assertAllClose(sample_values[:, 0, 1].std(), sigma_v[1], atol=1e-1)

    expected_samples_shape = tensor_shape.TensorShape(
        [self.evaluate(n)]).concatenate(
            tensor_shape.TensorShape(
                self.evaluate(normal.batch_shape_tensor())))
    self.assertAllEqual(expected_samples_shape, samples.get_shape())
    self.assertAllEqual(expected_samples_shape, sample_values.shape)

    expected_samples_shape = (
        tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
            normal.batch_shape))
    self.assertAllEqual(expected_samples_shape, samples.get_shape())
    self.assertAllEqual(expected_samples_shape, sample_values.shape)

  @test_util.run_in_graph_and_eager_modes
  def testNegativeSigmaFails(self):
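    # With validate_args=True the constructor adds a runtime check that
    # scale > 0, so evaluating the distribution with a negative scale should
    # raise.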
    with self.assertRaisesOpError("Condition x > 0 did not hold"):
      normal = normal_lib.Normal(
          loc=[1.], scale=[-5.], validate_args=True, name="G")
      self.evaluate(normal.mean())

  @test_util.run_in_graph_and_eager_modes
  def testNormalShape(self):
    mu = constant_op.constant([-3.0] * 5)
    sigma = constant_op.constant(11.0)
    normal = normal_lib.Normal(loc=mu, scale=sigma)

    self.assertEqual(self.evaluate(normal.batch_shape_tensor()), [5])
    self.assertEqual(normal.batch_shape, tensor_shape.TensorShape([5]))
    self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), [])
    self.assertEqual(normal.event_shape, tensor_shape.TensorShape([]))

  @test_util.run_deprecated_v1
  def testNormalShapeWithPlaceholders(self):
    mu = array_ops.placeholder(dtype=dtypes.float32)
    sigma = array_ops.placeholder(dtype=dtypes.float32)
    normal = normal_lib.Normal(loc=mu, scale=sigma)

    with self.cached_session() as sess:
      # The static batch_shape is unknown when loc and scale are placeholders.
      self.assertEqual(normal.batch_shape, tensor_shape.TensorShape(None))
      self.assertEqual(normal.event_shape, ())
      self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), [])
      self.assertAllEqual(
          sess.run(normal.batch_shape_tensor(),
                   feed_dict={mu: 5.0,
                              sigma: [1.0, 2.0]}), [2])

  @test_util.run_in_graph_and_eager_modes
  def testNormalNormalKL(self):
    batch_size = 6
    mu_a = np.array([3.0] * batch_size)
    sigma_a = np.array([1.0, 2.0, 3.0, 1.5, 2.5, 3.5])
    mu_b = np.array([-3.0] * batch_size)
    sigma_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])

    n_a = normal_lib.Normal(loc=mu_a, scale=sigma_a)
    n_b = normal_lib.Normal(loc=mu_b, scale=sigma_b)

    kl = kullback_leibler.kl_divergence(n_a, n_b)
    kl_val = self.evaluate(kl)

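    # Compare against the closed-form KL divergence between two univariate
    # normals:
    #   KL(N(mu_a, s_a) || N(mu_b, s_b))
    #     = (mu_a - mu_b)^2 / (2 s_b^2) + 0.5 * (s_a^2 / s_b^2 - 1) - log(s_a / s_b)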
    kl_expected = ((mu_a - mu_b)**2 / (2 * sigma_b**2) + 0.5 * (
        (sigma_a**2 / sigma_b**2) - 1 - 2 * np.log(sigma_a / sigma_b)))

    self.assertEqual(kl.get_shape(), (batch_size,))
    self.assertAllClose(kl_val, kl_expected)


if __name__ == "__main__":
  test.main()