# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Bijector."""

import abc

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class BaseBijectorTest(test.TestCase):
  """Tests properties of the Bijector base-class."""

  def testIsAbstract(self):
    # The wording of this TypeError differs across Python versions:
    #   before 3.9: "... with abstract methods __init__"
    #   3.9-3.11:   "... with abstract method __init__"
    #   3.12+:      "... without an implementation for abstract method
    #                '__init__'"
    # The pattern below accepts all of them.
    with self.assertRaisesRegex(
        TypeError,
        "Can't instantiate abstract class Bijector "
        ".*abstract methods? '?__init__'?"):
      bijector.Bijector()  # pylint: disable=abstract-class-instantiated

  def testDefaults(self):

    class _BareBonesBijector(bijector.Bijector):
      """Minimal specification of a `Bijector`."""

      def __init__(self):
        super().__init__(forward_min_event_ndims=0)

    bij = _BareBonesBijector()
    self.assertEqual([], bij.graph_parents)
    self.assertFalse(bij.is_constant_jacobian)
    self.assertFalse(bij.validate_args)
    self.assertIsNone(bij.dtype)
    # No `name` is passed, so this checks the base-class default naming.
    self.assertEqual("bare_bones_bijector", bij.name)

    # With no event-shape overrides, both the forward and the inverse
    # event-shape functions must return their input unchanged.
    for shape in [[], [1, 2], [1, 2, 3]]:
      forward_event_shape_ = self.evaluate(
          bij.forward_event_shape_tensor(shape))
      inverse_event_shape_ = self.evaluate(
          bij.inverse_event_shape_tensor(shape))
      self.assertAllEqual(shape, forward_event_shape_)
      self.assertAllEqual(shape, bij.forward_event_shape(shape))
      self.assertAllEqual(shape, inverse_event_shape_)
      self.assertAllEqual(shape, bij.inverse_event_shape(shape))

    # The transformation methods themselves have no default implementation.
    with self.assertRaisesRegex(NotImplementedError, "inverse not implemented"):
      bij.inverse(0)

    with self.assertRaisesRegex(NotImplementedError, "forward not implemented"):
      bij.forward(0)

    with self.assertRaisesRegex(NotImplementedError,
                                "inverse_log_det_jacobian not implemented"):
      bij.inverse_log_det_jacobian(0, event_ndims=0)

    with self.assertRaisesRegex(NotImplementedError,
                                "forward_log_det_jacobian not implemented"):
      bij.forward_log_det_jacobian(0, event_ndims=0)


class IntentionallyMissingError(Exception):
  """Raised by `BrokenBijector` from a method it was told to treat as missing."""


class BrokenBijector(bijector.Bijector):
  """Scalar bijector `y = 2 * x` whose methods can be made to raise.

  When `forward_missing` is set, `_forward` and `_forward_log_det_jacobian`
  raise `IntentionallyMissingError`; when `inverse_missing` is set, `_inverse`
  and `_inverse_log_det_jacobian` do. The caching tests use this to verify
  that the base class serves results from its cache instead of calling the
  "missing" implementation.
  """

  def __init__(
      self, forward_missing=False, inverse_missing=False, validate_args=False):
    super().__init__(
        validate_args=validate_args, forward_min_event_ndims=0, name="broken")
    self._forward_missing = forward_missing
    self._inverse_missing = inverse_missing

  def _forward(self, x):
    if self._forward_missing:
      raise IntentionallyMissingError
    return 2 * x

  def _inverse(self, y):
    if self._inverse_missing:
      raise IntentionallyMissingError
    return y / 2.

  def _inverse_log_det_jacobian(self, y):  # pylint:disable=unused-argument
    if self._inverse_missing:
      raise IntentionallyMissingError
    return -math_ops.log(2.)

  def _forward_log_det_jacobian(self, x):  # pylint:disable=unused-argument
    if self._forward_missing:
      raise IntentionallyMissingError
    return math_ops.log(2.)


class BijectorTestEventNdims(test.TestCase):
  """Tests validation of the `event_ndims` argument."""

  def testBijectorNonIntegerEventNdims(self):
    bij = BrokenBijector()
    with self.assertRaisesRegex(ValueError, "Expected integer"):
      bij.forward_log_det_jacobian(1., event_ndims=1.5)
    with self.assertRaisesRegex(ValueError, "Expected integer"):
      bij.inverse_log_det_jacobian(1., event_ndims=1.5)

  def testBijectorArrayEventNdims(self):
    bij = BrokenBijector()
    with self.assertRaisesRegex(ValueError, "Expected scalar"):
      bij.forward_log_det_jacobian(1., event_ndims=(1, 2))
    with self.assertRaisesRegex(ValueError, "Expected scalar"):
      bij.inverse_log_det_jacobian(1., event_ndims=(1, 2))

  @test_util.run_deprecated_v1
  def testBijectorDynamicEventNdims(self):
    # With a placeholder, the scalar check can only fail at run time, and
    # only because validate_args=True.
    bij = BrokenBijector(validate_args=True)
    event_ndims = array_ops.placeholder(dtype=np.int32, shape=None)
    with self.cached_session():
      with self.assertRaisesOpError("Expected scalar"):
        bij.forward_log_det_jacobian(1., event_ndims=event_ndims).eval({
            event_ndims: (1, 2)})
      with self.assertRaisesOpError("Expected scalar"):
        bij.inverse_log_det_jacobian(1., event_ndims=event_ndims).eval({
            event_ndims: (1, 2)})


class BijectorCachingTestBase(metaclass=abc.ABCMeta):
  """Caching tests shared by concrete `test.TestCase` subclasses.

  Subclasses must mix in `test.TestCase` (for `assertRaises`) and provide
  `broken_bijector_cls`.
  """

  @property
  @abc.abstractmethod
  def broken_bijector_cls(self):
    """A `BrokenBijector`-like class; its missing methods exercise caching."""
    raise NotImplementedError("Subclasses must define broken_bijector_cls.")

  def testCachingOfForwardResults(self):
    broken_bijector = self.broken_bijector_cls(inverse_missing=True)
    x = constant_op.constant(1.1)

    # Call forward and forward_log_det_jacobian one-by-one (not together).
    y = broken_bijector.forward(x)
    _ = broken_bijector.forward_log_det_jacobian(x, event_ndims=0)

    # Now, everything should be cached if the argument is y, so the missing
    # inverse implementations are never reached.
    broken_bijector.inverse(y)
    broken_bijector.inverse_log_det_jacobian(y, event_ndims=0)

    # Different event_ndims should not be cached.
    with self.assertRaises(IntentionallyMissingError):
      broken_bijector.inverse_log_det_jacobian(y, event_ndims=1)

  def testCachingOfInverseResults(self):
    broken_bijector = self.broken_bijector_cls(forward_missing=True)
    y = constant_op.constant(1.1)

    # Call inverse and inverse_log_det_jacobian one-by-one (not together).
    x = broken_bijector.inverse(y)
    _ = broken_bijector.inverse_log_det_jacobian(y, event_ndims=0)

    # Now, everything should be cached if the argument is x, so the missing
    # forward implementations are never reached.
    broken_bijector.forward(x)
    broken_bijector.forward_log_det_jacobian(x, event_ndims=0)

    # Different event_ndims should not be cached.
    with self.assertRaises(IntentionallyMissingError):
      broken_bijector.forward_log_det_jacobian(x, event_ndims=1)


class BijectorCachingTest(BijectorCachingTestBase, test.TestCase):
  """Test caching with BrokenBijector."""

  @property
  def broken_bijector_cls(self):
    return BrokenBijector


class ExpOnlyJacobian(bijector.Bijector):
  """Only used for jacobian calculations.

  Defines only the log-det-jacobians: `log(x)` forward and `-log(y)` inverse.
  """

  def __init__(self, forward_min_event_ndims=0):
    super().__init__(
        validate_args=False,
        is_constant_jacobian=False,
        forward_min_event_ndims=forward_min_event_ndims,
        name="exp")

  def _inverse_log_det_jacobian(self, y):
    return -math_ops.log(y)

  def _forward_log_det_jacobian(self, x):
    return math_ops.log(x)


class ConstantJacobian(bijector.Bijector):
  """Only used for jacobian calculations.

  Declares a constant jacobian: the forward log-det-jacobian is always -2 and
  the inverse log-det-jacobian is always 2, in the input's dtype.
  """

  def __init__(self, forward_min_event_ndims=0):
    super().__init__(
        validate_args=False,
        is_constant_jacobian=True,
        forward_min_event_ndims=forward_min_event_ndims,
        name="c")

  def _inverse_log_det_jacobian(self, y):
    return constant_op.constant(2., y.dtype)

  def _forward_log_det_jacobian(self, x):
    return constant_op.constant(-2., x.dtype)


class BijectorReduceEventDimsTest(test.TestCase):
  """Test reduction of log-det-jacobians over the requested event dims."""

  def testReduceEventNdimsForward(self):
    x = [[[1., 2.], [3., 4.]]]
    bij = ExpOnlyJacobian()
    self.assertAllClose(
        np.log(x),
        self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=0)))
    self.assertAllClose(
        np.sum(np.log(x), axis=-1),
        self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=1)))
    self.assertAllClose(
        np.sum(np.log(x), axis=(-1, -2)),
        self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=2)))

  def testReduceEventNdimsForwardRaiseError(self):
    x = [[[1., 2.], [3., 4.]]]
    bij = ExpOnlyJacobian(forward_min_event_ndims=1)
    with self.assertRaisesRegex(ValueError, "must be larger than"):
      bij.forward_log_det_jacobian(x, event_ndims=0)

  def testReduceEventNdimsInverse(self):
    x = [[[1., 2.], [3., 4.]]]
    bij = ExpOnlyJacobian()
    self.assertAllClose(
        -np.log(x),
        self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=0)))
    self.assertAllClose(
        np.sum(-np.log(x), axis=-1),
        self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=1)))
    self.assertAllClose(
        np.sum(-np.log(x), axis=(-1, -2)),
        self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=2)))

  def testReduceEventNdimsInverseRaiseError(self):
    x = [[[1., 2.], [3., 4.]]]
    bij = ExpOnlyJacobian(forward_min_event_ndims=1)
    with self.assertRaisesRegex(ValueError, "must be larger than"):
      bij.inverse_log_det_jacobian(x, event_ndims=0)

  def testReduceEventNdimsForwardConstJacobian(self):
    # A constant jacobian is scaled by the number of reduced elements:
    # -2 per element, times 2 (last dim) or 2 * 2 (last two dims).
    x = [[[1., 2.], [3., 4.]]]
    bij = ConstantJacobian()
    self.assertAllClose(
        -2.,
        self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=0)))
    self.assertAllClose(
        -4.,
        self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=1)))
    self.assertAllClose(
        -8.,
        self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=2)))

  def testReduceEventNdimsInverseConstJacobian(self):
    # Same scaling as the forward case, with +2 per element.
    x = [[[1., 2.], [3., 4.]]]
    bij = ConstantJacobian()
    self.assertAllClose(
        2.,
        self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=0)))
    self.assertAllClose(
        4.,
        self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=1)))
    self.assertAllClose(
        8.,
        self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=2)))

  @test_util.run_deprecated_v1
  def testHandlesNonStaticEventNdims(self):
    x_ = [[[1., 2.], [3., 4.]]]
    x = array_ops.placeholder_with_default(x_, shape=None)
    event_ndims = array_ops.placeholder(dtype=np.int32, shape=[])
    bij = ExpOnlyJacobian(forward_min_event_ndims=1)
    # Building the op with a placeholder event_ndims must not raise.
    bij.inverse_log_det_jacobian(x, event_ndims=event_ndims)
    with self.cached_session() as sess:
      ildj = sess.run(bij.inverse_log_det_jacobian(x, event_ndims=event_ndims),
                      feed_dict={event_ndims: 1})
    # event_ndims == forward_min_event_ndims, so nothing is reduced.
    self.assertAllClose(-np.log(x_), ildj)


if __name__ == "__main__":
  test.main()