• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for distributions KL mechanism."""
16
17from tensorflow.python.framework import test_util
18from tensorflow.python.ops import array_ops
19from tensorflow.python.ops.distributions import kullback_leibler
20from tensorflow.python.ops.distributions import normal
21from tensorflow.python.platform import test
22
# The KL registry and its lookup helper are private to `kullback_leibler`;
# alias them at module level so the tests below can inspect dispatch directly.
# pylint: disable=protected-access
_DIVERGENCES, _registered_kl = (
    kullback_leibler._DIVERGENCES,
    kullback_leibler._registered_kl,
)
# pylint: enable=protected-access
28
29
class KLTest(test.TestCase):
  """Tests for KL-divergence registration and dispatch.

  Each test defines fresh local subclasses of `normal.Normal` before calling
  `kullback_leibler.RegisterKL`, so every registration uses new class keys.
  (Re-registering the same pair raises, as `testRegistrationFailures` shows.)
  """

  def testRegistration(self):
    """A registered KL function is what `kl_divergence` dispatches to."""

    class MyDist(normal.Normal):
      pass

    # Register a KL function that echoes back its `name` argument, so the
    # assertion below observes both the dispatch and the forwarded `name`.
    @kullback_leibler.RegisterKL(MyDist, MyDist)
    def _kl(a, b, name=None):  # pylint: disable=unused-argument,unused-variable
      return name

    a = MyDist(loc=0.0, scale=1.0)
    self.assertEqual("OK", kullback_leibler.kl_divergence(a, a, name="OK"))

  @test_util.run_deprecated_v1
  def testDomainErrorExceptions(self):
    """NaN KL results raise with allow_nan_stats=False, pass through with True."""

    class MyDistException(normal.Normal):
      pass

    # pylint: disable=unused-argument,unused-variable
    # Register a KL function that always returns NaN, to exercise the NaN
    # check governed by `allow_nan_stats`.
    @kullback_leibler.RegisterKL(MyDistException, MyDistException)
    def _kl(a, b, name=None):
      return array_ops.identity([float("nan")])

    # pylint: enable=unused-argument,unused-variable

    with self.cached_session():
      a = MyDistException(loc=0.0, scale=1.0, allow_nan_stats=False)
      kl = kullback_leibler.kl_divergence(a, a, allow_nan_stats=False)
      with self.assertRaisesOpError(
          "KL calculation between .* and .* returned NaN values"):
        self.evaluate(kl)
      # The method form, with no explicit allow_nan_stats, must also raise
      # for a distribution constructed with allow_nan_stats=False.
      with self.assertRaisesOpError(
          "KL calculation between .* and .* returned NaN values"):
        self.evaluate(a.kl_divergence(a))

      a = MyDistException(loc=0.0, scale=1.0, allow_nan_stats=True)
      kl_ok = kullback_leibler.kl_divergence(a, a)
      self.assertAllEqual([float("nan")], self.evaluate(kl_ok))
      self_kl_ok = a.kl_divergence(a)
      self.assertAllEqual([float("nan")], self.evaluate(self_kl_ok))
      cross_ok = a.cross_entropy(a)
      self.assertAllEqual([float("nan")], self.evaluate(cross_ok))

  def testRegistrationFailures(self):
    """Non-callables and duplicate (class, class) registrations are rejected."""

    class MyDist(normal.Normal):
      pass

    with self.assertRaisesRegex(TypeError, "must be callable"):
      kullback_leibler.RegisterKL(MyDist, MyDist)("blah")

    # First registration is OK.
    kullback_leibler.RegisterKL(MyDist, MyDist)(lambda a, b: None)

    # Second registration of the same pair fails.
    with self.assertRaisesRegex(ValueError, "has already been registered"):
      kullback_leibler.RegisterKL(MyDist, MyDist)(lambda a, b: None)

  def testExactRegistrationsAllMatch(self):
    """Every registry entry is returned by a lookup on its exact key."""
    for k, v in _DIVERGENCES.items():
      self.assertEqual(v, _registered_kl(*k))

  def _testIndirectRegistration(self, fn):
    """Checks that subclasses fall back to their parents' KL registrations.

    Args:
      fn: Callable `(p, q) -> result` under test, e.g. `kl_divergence` or a
        lambda wrapping `p.cross_entropy(q)`.
    """

    # NOTE(review): `entropy` returns "" presumably because cross_entropy is
    # computed as entropy() + KL; concatenating "" leaves the registered
    # string unchanged, so the same assertions serve for cross_entropy.
    class Sub1(normal.Normal):

      def entropy(self):
        return ""

    class Sub2(normal.Normal):

      def entropy(self):
        return ""

    class Sub11(Sub1):

      def entropy(self):
        return ""

    # pylint: disable=unused-argument,unused-variable
    @kullback_leibler.RegisterKL(Sub1, Sub1)
    def _kl11(a, b, name=None):
      return "sub1-1"

    @kullback_leibler.RegisterKL(Sub1, Sub2)
    def _kl12(a, b, name=None):
      return "sub1-2"

    @kullback_leibler.RegisterKL(Sub2, Sub1)
    def _kl21(a, b, name=None):
      return "sub2-1"

    # pylint: enable=unused-argument,unused-variable

    sub1 = Sub1(loc=0.0, scale=1.0)
    sub2 = Sub2(loc=0.0, scale=1.0)
    sub11 = Sub11(loc=0.0, scale=1.0)

    # Exact registrations.
    self.assertEqual("sub1-1", fn(sub1, sub1))
    self.assertEqual("sub1-2", fn(sub1, sub2))
    self.assertEqual("sub2-1", fn(sub2, sub1))
    # Sub11 has no registrations of its own and resolves via its parent Sub1,
    # in either argument position.
    self.assertEqual("sub1-1", fn(sub11, sub11))
    self.assertEqual("sub1-1", fn(sub11, sub1))
    self.assertEqual("sub1-2", fn(sub11, sub2))
    self.assertEqual("sub2-1", fn(sub2, sub11))
    self.assertEqual("sub1-1", fn(sub1, sub11))

  def testIndirectRegistrationKLFun(self):
    """Indirect dispatch through the module-level `kl_divergence`."""
    self._testIndirectRegistration(kullback_leibler.kl_divergence)

  def testIndirectRegistrationKLSelf(self):
    """Indirect dispatch through the `Distribution.kl_divergence` method."""
    self._testIndirectRegistration(
        lambda p, q: p.kl_divergence(q))

  def testIndirectRegistrationCrossEntropy(self):
    """Indirect dispatch through the `Distribution.cross_entropy` method."""
    self._testIndirectRegistration(
        lambda p, q: p.cross_entropy(q))

  def testFunctionCrossEntropy(self):
    """Indirect dispatch through the module-level `cross_entropy`."""
    self._testIndirectRegistration(kullback_leibler.cross_entropy)
156
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner when executed as a script.
  test.main()
159