• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for distributions KL mechanism."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import test_util
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops.distributions import kullback_leibler
24from tensorflow.python.ops.distributions import normal
25from tensorflow.python.platform import test
26
# Test-local aliases for two private members of `kullback_leibler`: the
# registry mapping argument-type tuples to KL functions, and the helper that
# looks a registered KL function up from such a tuple.
# pylint: disable=protected-access
_DIVERGENCES, _registered_kl = (
    kullback_leibler._DIVERGENCES, kullback_leibler._registered_kl)
# pylint: enable=protected-access
32
33
class KLTest(test.TestCase):
  """Tests KL-divergence registration and dispatch in `kullback_leibler`."""

  def testRegistration(self):
    """kl_divergence dispatches to the function registered for the pair."""

    class MyDist(normal.Normal):
      pass

    # Register a KL function that just echoes its `name` argument, so the
    # result shows both that it was dispatched to and that `name` was
    # forwarded to it.
    @kullback_leibler.RegisterKL(MyDist, MyDist)
    def _kl(a, b, name=None):  # pylint: disable=unused-argument,unused-variable
      return name

    a = MyDist(loc=0.0, scale=1.0)
    self.assertEqual("OK", kullback_leibler.kl_divergence(a, a, name="OK"))

  @test_util.run_deprecated_v1
  def testDomainErrorExceptions(self):
    """NaN KL results raise an error only when allow_nan_stats=False."""

    class MyDistException(normal.Normal):
      pass

    # Register a KL function that always returns NaN, to exercise the NaN
    # check controlled by `allow_nan_stats`.
    # pylint: disable=unused-argument,unused-variable
    @kullback_leibler.RegisterKL(MyDistException, MyDistException)
    def _kl(a, b, name=None):
      return array_ops.identity([float("nan")])
    # pylint: enable=unused-argument,unused-variable

    nan_error = "KL calculation between .* and .* returned NaN values"
    with self.cached_session():
      # allow_nan_stats=False: both the module-level function and the
      # distribution's own kl_divergence method must raise an OpError.
      a = MyDistException(loc=0.0, scale=1.0, allow_nan_stats=False)
      kl = kullback_leibler.kl_divergence(a, a, allow_nan_stats=False)
      with self.assertRaisesOpError(nan_error):
        self.evaluate(kl)
      with self.assertRaisesOpError(nan_error):
        self.evaluate(a.kl_divergence(a))

      # allow_nan_stats=True: the NaN is returned unchanged by the
      # module-level function, the kl_divergence method and cross_entropy.
      a = MyDistException(loc=0.0, scale=1.0, allow_nan_stats=True)
      kl_ok = kullback_leibler.kl_divergence(a, a)
      self.assertAllEqual([float("nan")], self.evaluate(kl_ok))
      self_kl_ok = a.kl_divergence(a)
      self.assertAllEqual([float("nan")], self.evaluate(self_kl_ok))
      cross_ok = a.cross_entropy(a)
      self.assertAllEqual([float("nan")], self.evaluate(cross_ok))

  def testRegistrationFailures(self):
    """RegisterKL rejects non-callables and duplicate registrations."""

    class MyDist(normal.Normal):
      pass

    with self.assertRaisesRegexp(TypeError, "must be callable"):
      kullback_leibler.RegisterKL(MyDist, MyDist)("blah")

    # First registration is OK.
    kullback_leibler.RegisterKL(MyDist, MyDist)(lambda a, b: None)

    # Second registration for the same pair fails.
    with self.assertRaisesRegexp(ValueError, "has already been registered"):
      kullback_leibler.RegisterKL(MyDist, MyDist)(lambda a, b: None)

  def testExactRegistrationsAllMatch(self):
    """Every registry entry is what the lookup returns for its exact key."""
    for k, v in _DIVERGENCES.items():
      self.assertEqual(v, _registered_kl(*k))

  def _testIndirectRegistration(self, fn):
    """Checks that KL dispatch follows registrations on parent classes.

    Args:
      fn: callable taking two distributions `(p, q)`. Its result is expected
        to equal the string returned by the KL function registered for the
        pair's classes, or for their parent classes.
    """

    # The `entropy` overrides return "" -- presumably so that cross-entropy,
    # if computed as entropy + KL, reduces to the registered KL string.
    # TODO(review): confirm against kullback_leibler.cross_entropy.
    class Sub1(normal.Normal):

      def entropy(self):
        return ""

    class Sub2(normal.Normal):

      def entropy(self):
        return ""

    class Sub11(Sub1):

      def entropy(self):
        return ""

    # pylint: disable=unused-argument,unused-variable
    @kullback_leibler.RegisterKL(Sub1, Sub1)
    def _kl11(a, b, name=None):
      return "sub1-1"

    @kullback_leibler.RegisterKL(Sub1, Sub2)
    def _kl12(a, b, name=None):
      return "sub1-2"

    @kullback_leibler.RegisterKL(Sub2, Sub1)
    def _kl21(a, b, name=None):
      return "sub2-1"

    # pylint: enable=unused-argument,unused-variable

    sub1 = Sub1(loc=0.0, scale=1.0)
    sub2 = Sub2(loc=0.0, scale=1.0)
    sub11 = Sub11(loc=0.0, scale=1.0)

    # Exact registrations.
    self.assertEqual("sub1-1", fn(sub1, sub1))
    self.assertEqual("sub1-2", fn(sub1, sub2))
    self.assertEqual("sub2-1", fn(sub2, sub1))
    # Sub11 has no registration of its own, so it must resolve via Sub1 in
    # either argument position.
    self.assertEqual("sub1-1", fn(sub11, sub11))
    self.assertEqual("sub1-1", fn(sub11, sub1))
    self.assertEqual("sub1-2", fn(sub11, sub2))
    self.assertEqual("sub2-1", fn(sub2, sub11))
    self.assertEqual("sub1-1", fn(sub1, sub11))

  def testIndirectRegistrationKLFun(self):
    """Indirect dispatch through the module-level kl_divergence."""
    self._testIndirectRegistration(kullback_leibler.kl_divergence)

  def testIndirectRegistrationKLSelf(self):
    """Indirect dispatch through the Distribution.kl_divergence method."""
    self._testIndirectRegistration(
        lambda p, q: p.kl_divergence(q))

  def testIndirectRegistrationCrossEntropy(self):
    """Indirect dispatch through the Distribution.cross_entropy method."""
    self._testIndirectRegistration(
        lambda p, q: p.cross_entropy(q))

  def testFunctionCrossEntropy(self):
    """Indirect dispatch through the module-level cross_entropy."""
    self._testIndirectRegistration(kullback_leibler.cross_entropy)
159
160
if __name__ == "__main__":
  # Run this module's tests through TensorFlow's test entry point.
  test.main()
163