# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for distributions KL mechanism."""

from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal
from tensorflow.python.platform import test

# Private registry internals, aliased so the tests below can inspect them.
# pylint: disable=protected-access
_DIVERGENCES = kullback_leibler._DIVERGENCES
_registered_kl = kullback_leibler._registered_kl

# pylint: enable=protected-access


class KLTest(test.TestCase):
  """Tests for KL-divergence registration and dispatch."""

  def testRegistration(self):
    """A registered KL function is called with the caller's `name` kwarg."""

    class MyDist(normal.Normal):
      pass

    # Register a KL function that simply echoes back its `name` argument,
    # so the assertion below proves both dispatch and argument forwarding.
    @kullback_leibler.RegisterKL(MyDist, MyDist)
    def _kl(a, b, name=None):  # pylint: disable=unused-argument,unused-variable
      return name

    a = MyDist(loc=0.0, scale=1.0)
    self.assertEqual("OK", kullback_leibler.kl_divergence(a, a, name="OK"))

  @test_util.run_deprecated_v1
  def testDomainErrorExceptions(self):
    """NaN KL values raise an op error unless `allow_nan_stats` is True."""

    class MyDistException(normal.Normal):
      pass

    # Register a KL function that always produces a NaN tensor, so the
    # NaN-checking path of kl_divergence is exercised.
    # pylint: disable=unused-argument,unused-variable
    @kullback_leibler.RegisterKL(MyDistException, MyDistException)
    def _kl(a, b, name=None):
      return array_ops.identity([float("nan")])

    # pylint: enable=unused-argument,unused-variable

    with self.cached_session():
      # allow_nan_stats=False: evaluating the KL, via either the free
      # function or the distribution method, must fail at run time.
      a = MyDistException(loc=0.0, scale=1.0, allow_nan_stats=False)
      kl = kullback_leibler.kl_divergence(a, a, allow_nan_stats=False)
      with self.assertRaisesOpError(
          "KL calculation between .* and .* returned NaN values"):
        self.evaluate(kl)
      with self.assertRaisesOpError(
          "KL calculation between .* and .* returned NaN values"):
        self.evaluate(a.kl_divergence(a))

      # allow_nan_stats=True: the NaN is passed through unchanged.
      a = MyDistException(loc=0.0, scale=1.0, allow_nan_stats=True)
      kl_ok = kullback_leibler.kl_divergence(a, a)
      self.assertAllEqual([float("nan")], self.evaluate(kl_ok))
      self_kl_ok = a.kl_divergence(a)
      self.assertAllEqual([float("nan")], self.evaluate(self_kl_ok))
      cross_ok = a.cross_entropy(a)
      self.assertAllEqual([float("nan")], self.evaluate(cross_ok))

  def testRegistrationFailures(self):
    """Non-callables and duplicate registrations are rejected."""

    class MyDist(normal.Normal):
      pass

    with self.assertRaisesRegex(TypeError, "must be callable"):
      kullback_leibler.RegisterKL(MyDist, MyDist)("blah")

    # First registration is OK
    kullback_leibler.RegisterKL(MyDist, MyDist)(lambda a, b: None)

    # Second registration for the same (MyDist, MyDist) pair fails
    with self.assertRaisesRegex(ValueError, "has already been registered"):
      kullback_leibler.RegisterKL(MyDist, MyDist)(lambda a, b: None)

  def testExactRegistrationsAllMatch(self):
    """Every entry in the registry is what lookup returns for its key."""
    for (k, v) in _DIVERGENCES.items():
      self.assertEqual(v, _registered_kl(*k))

  def _testIndirectRegistration(self, fn):
    """Checks that KL dispatch falls back to registrations on parent classes.

    Args:
      fn: Callable `(p, q) -> result` under test (the KL function, the
        `kl_divergence`/`cross_entropy` methods, or the cross_entropy function).
    """

    # NOTE(review): `entropy` presumably returns "" so that cross_entropy,
    # which combines entropy with the KL result, evaluates to the bare KL
    # string asserted below -- confirm against kullback_leibler.cross_entropy.
    class Sub1(normal.Normal):

      def entropy(self):
        return ""

    class Sub2(normal.Normal):

      def entropy(self):
        return ""

    class Sub11(Sub1):

      def entropy(self):
        return ""

    # pylint: disable=unused-argument,unused-variable
    @kullback_leibler.RegisterKL(Sub1, Sub1)
    def _kl11(a, b, name=None):
      return "sub1-1"

    @kullback_leibler.RegisterKL(Sub1, Sub2)
    def _kl12(a, b, name=None):
      return "sub1-2"

    @kullback_leibler.RegisterKL(Sub2, Sub1)
    def _kl21(a, b, name=None):
      return "sub2-1"

    # pylint: enable=unused-argument,unused-variable

    sub1 = Sub1(loc=0.0, scale=1.0)
    sub2 = Sub2(loc=0.0, scale=1.0)
    sub11 = Sub11(loc=0.0, scale=1.0)

    # Exact matches.
    self.assertEqual("sub1-1", fn(sub1, sub1))
    self.assertEqual("sub1-2", fn(sub1, sub2))
    self.assertEqual("sub2-1", fn(sub2, sub1))
    # Sub11 has no registrations of its own, so any pair involving it must
    # resolve through its parent class Sub1.
    self.assertEqual("sub1-1", fn(sub11, sub11))
    self.assertEqual("sub1-1", fn(sub11, sub1))
    self.assertEqual("sub1-2", fn(sub11, sub2))
    self.assertEqual("sub2-1", fn(sub2, sub11))
    self.assertEqual("sub1-1", fn(sub1, sub11))

  def testIndirectRegistrationKLFun(self):
    self._testIndirectRegistration(kullback_leibler.kl_divergence)

  def testIndirectRegistrationKLSelf(self):
    self._testIndirectRegistration(
        lambda p, q: p.kl_divergence(q))

  def testIndirectRegistrationCrossEntropy(self):
    self._testIndirectRegistration(
        lambda p, q: p.cross_entropy(q))

  def testFunctionCrossEntropy(self):
    self._testIndirectRegistration(kullback_leibler.cross_entropy)


if __name__ == "__main__":
  test.main()