• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Bijector."""
16
17import abc
18
19import numpy as np
20
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import test_util
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import math_ops
25from tensorflow.python.ops.distributions import bijector
26from tensorflow.python.platform import test
27
28
@test_util.run_all_in_graph_and_eager_modes
class BaseBijectorTest(test.TestCase):
  """Tests properties of the Bijector base-class."""

  def testIsAbstract(self):
    # The TypeError wording for instantiating an abstract class depends on the
    # Python version: older releases say "with abstract methods __init__",
    # later ones use the singular "with abstract method __init__", and
    # Python 3.12+ says "without an implementation for abstract method
    # '__init__'" (with quotes). Accept every variant so the test does not
    # break when the interpreter is upgraded.
    with self.assertRaisesRegex(
        TypeError,
        ("Can't instantiate abstract class Bijector "
         "(with abstract methods?|"
         "without an implementation for abstract methods?) '?__init__'?")):
      bijector.Bijector()  # pylint: disable=abstract-class-instantiated

  def testDefaults(self):
    class _BareBonesBijector(bijector.Bijector):
      """Minimal specification of a `Bijector`."""

      def __init__(self):
        super().__init__(forward_min_event_ndims=0)

    bij = _BareBonesBijector()
    self.assertEqual([], bij.graph_parents)
    self.assertEqual(False, bij.is_constant_jacobian)
    self.assertEqual(False, bij.validate_args)
    self.assertEqual(None, bij.dtype)
    # The name is derived from the subclass name.
    self.assertEqual("bare_bones_bijector", bij.name)

    # With no overrides, event shapes pass through unchanged in both
    # directions, for both the static and the tensor-valued APIs.
    for shape in [[], [1, 2], [1, 2, 3]]:
      forward_event_shape_ = self.evaluate(
          bij.forward_event_shape_tensor(shape))
      inverse_event_shape_ = self.evaluate(
          bij.inverse_event_shape_tensor(shape))
      self.assertAllEqual(shape, forward_event_shape_)
      self.assertAllEqual(shape, bij.forward_event_shape(shape))
      self.assertAllEqual(shape, inverse_event_shape_)
      self.assertAllEqual(shape, bij.inverse_event_shape(shape))

    # The transformation methods themselves are not provided by the base class.
    with self.assertRaisesRegex(NotImplementedError, "inverse not implemented"):
      bij.inverse(0)

    with self.assertRaisesRegex(NotImplementedError, "forward not implemented"):
      bij.forward(0)

    with self.assertRaisesRegex(NotImplementedError,
                                "inverse_log_det_jacobian not implemented"):
      bij.inverse_log_det_jacobian(0, event_ndims=0)

    with self.assertRaisesRegex(NotImplementedError,
                                "forward_log_det_jacobian not implemented"):
      bij.forward_log_det_jacobian(0, event_ndims=0)
77
78
class IntentionallyMissingError(Exception):
  """Signals that a test bijector method was deliberately disabled."""
81
82
class BrokenBijector(bijector.Bijector):
  """Scalar bijector `y = 2 * x` whose methods can be disabled on demand.

  Despite the name, forward (`2 * x`) and inverse (`y / 2.`) are exact
  inverses of each other, with log-det-Jacobians `log(2)` and `-log(2)`.
  The "broken" part is opt-in:

  * `forward_missing=True` makes `_forward` and `_forward_log_det_jacobian`
    raise `IntentionallyMissingError`.
  * `inverse_missing=True` does the same for `_inverse` and
    `_inverse_log_det_jacobian`.

  The caching tests rely on this. If a disabled method is still answered,
  the value must have come from the bijector's cache.
  """

  def __init__(
      self, forward_missing=False, inverse_missing=False, validate_args=False):
    super().__init__(
        validate_args=validate_args, forward_min_event_ndims=0, name="broken")
    self._forward_missing = forward_missing
    self._inverse_missing = inverse_missing

  def _forward(self, x):
    if self._forward_missing:
      raise IntentionallyMissingError
    return 2 * x

  def _inverse(self, y):
    if self._inverse_missing:
      raise IntentionallyMissingError
    return y / 2.

  def _inverse_log_det_jacobian(self, y):  # pylint:disable=unused-argument
    if self._inverse_missing:
      raise IntentionallyMissingError
    return -math_ops.log(2.)

  def _forward_log_det_jacobian(self, x):  # pylint:disable=unused-argument
    if self._forward_missing:
      raise IntentionallyMissingError
    return math_ops.log(2.)
112
113
class BijectorTestEventNdims(test.TestCase):
  """Checks validation of `event_ndims` passed to the log-det-Jacobians."""

  def testBijectorNonIntegerEventNdims(self):
    broken = BrokenBijector()
    ldj_methods = (broken.forward_log_det_jacobian,
                   broken.inverse_log_det_jacobian)
    for ldj in ldj_methods:
      # A static, non-integer event_ndims is rejected eagerly.
      with self.assertRaisesRegex(ValueError, "Expected integer"):
        ldj(1., event_ndims=1.5)

  def testBijectorArrayEventNdims(self):
    broken = BrokenBijector()
    ldj_methods = (broken.forward_log_det_jacobian,
                   broken.inverse_log_det_jacobian)
    for ldj in ldj_methods:
      # A static, non-scalar event_ndims is rejected eagerly.
      with self.assertRaisesRegex(ValueError, "Expected scalar"):
        ldj(1., event_ndims=(1, 2))

  @test_util.run_deprecated_v1
  def testBijectorDynamicEventNdims(self):
    broken = BrokenBijector(validate_args=True)
    # Shape is unknown at graph-construction time, so the scalar check
    # can only fire when the graph runs (hence assertRaisesOpError).
    event_ndims = array_ops.placeholder(dtype=np.int32, shape=None)
    ldj_methods = (broken.forward_log_det_jacobian,
                   broken.inverse_log_det_jacobian)
    with self.cached_session():
      for ldj in ldj_methods:
        with self.assertRaisesOpError("Expected scalar"):
          ldj(1., event_ndims=event_ndims).eval({event_ndims: (1, 2)})
141
142
class BijectorCachingTestBase(metaclass=abc.ABCMeta):
  """Mixin of caching tests, parameterized by a `BrokenBijector`-like class.

  Concrete subclasses must also inherit from a `TestCase` (the tests use
  `self.assertRaises`) and must override `broken_bijector_cls`.
  """

  # `abc.abstractproperty` is deprecated since Python 3.3; stacking
  # `property` over `abc.abstractmethod` is the supported equivalent.
  @property
  @abc.abstractmethod
  def broken_bijector_cls(self):
    """The bijector class exercised by the caching tests.

    It must accept `forward_missing` / `inverse_missing` keyword arguments
    and raise `IntentionallyMissingError` from the disabled methods, as
    `BrokenBijector` does. Any recomputation then surfaces as an error.
    """
    raise IntentionallyMissingError("Not implemented")

  def testCachingOfForwardResults(self):
    broken_bijector = self.broken_bijector_cls(inverse_missing=True)
    x = constant_op.constant(1.1)

    # Call forward and forward_log_det_jacobian one-by-one (not together).
    y = broken_bijector.forward(x)
    _ = broken_bijector.forward_log_det_jacobian(x, event_ndims=0)

    # Now, everything should be cached if the argument is y. The inverse
    # methods are disabled, so these calls only succeed via the cache.
    broken_bijector.inverse(y)
    broken_bijector.inverse_log_det_jacobian(y, event_ndims=0)

    # Different event_ndims should not be cached.
    with self.assertRaises(IntentionallyMissingError):
      broken_bijector.inverse_log_det_jacobian(y, event_ndims=1)

  def testCachingOfInverseResults(self):
    broken_bijector = self.broken_bijector_cls(forward_missing=True)
    y = constant_op.constant(1.1)

    # Call inverse and inverse_log_det_jacobian one-by-one (not together).
    x = broken_bijector.inverse(y)
    _ = broken_bijector.inverse_log_det_jacobian(y, event_ndims=0)

    # Now, everything should be cached if the argument is x. The forward
    # methods are disabled, so these calls only succeed via the cache.
    broken_bijector.forward(x)
    broken_bijector.forward_log_det_jacobian(x, event_ndims=0)

    # Different event_ndims should not be cached.
    with self.assertRaises(IntentionallyMissingError):
      broken_bijector.forward_log_det_jacobian(x, event_ndims=1)
181
182
class BijectorCachingTest(BijectorCachingTestBase, test.TestCase):
  """Runs the shared caching tests against `BrokenBijector`."""

  # Read-only property that satisfies the abstract property on the base.
  broken_bijector_cls = property(lambda self: BrokenBijector)
189
190
class ExpOnlyJacobian(bijector.Bijector):
  """Test stub that defines only the log-det-Jacobian hooks.

  No `_forward` / `_inverse` are defined. The per-element log-det-Jacobians
  are `log(x)` (forward) and `-log(y)` (inverse). Tests use it to check how
  the base class reduces them over `event_ndims`.
  """

  def __init__(self, forward_min_event_ndims=0):
    super().__init__(
        name="exp",
        forward_min_event_ndims=forward_min_event_ndims,
        is_constant_jacobian=False,
        validate_args=False)

  def _inverse_log_det_jacobian(self, y):
    log_y = math_ops.log(y)
    return -log_y

  def _forward_log_det_jacobian(self, x):
    log_x = math_ops.log(x)
    return log_x
206
207
class ConstantJacobian(bijector.Bijector):
  """Test stub with constant log-det-Jacobians.

  It declares `is_constant_jacobian=True`. The forward LDJ is the constant
  -2 and the inverse LDJ is the constant 2, each in the dtype of its
  argument. No `_forward` / `_inverse` are defined.
  """

  def __init__(self, forward_min_event_ndims=0):
    super().__init__(
        name="c",
        forward_min_event_ndims=forward_min_event_ndims,
        is_constant_jacobian=True,
        validate_args=False)

  def _inverse_log_det_jacobian(self, y):
    return constant_op.constant(2., dtype=y.dtype)

  def _forward_log_det_jacobian(self, x):
    return constant_op.constant(-2., dtype=x.dtype)
223
224
class BijectorReduceEventDimsTest(test.TestCase):
  """Tests how log-det-Jacobians are reduced over `event_ndims`.

  When `event_ndims` exceeds the bijector's min event ndims, the per-element
  log-det-Jacobian is summed over the extra rightmost dimensions. An
  `event_ndims` below the minimum is rejected with a `ValueError`.
  """

  def testReduceEventNdimsForward(self):
    x = [[[1., 2.], [3., 4.]]]
    bij = ExpOnlyJacobian()
    self.assertAllClose(
        np.log(x),
        self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=0)))
    self.assertAllClose(
        np.sum(np.log(x), axis=-1),
        self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=1)))
    self.assertAllClose(
        np.sum(np.log(x), axis=(-1, -2)),
        self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=2)))

  def testReduceEventNdimsForwardRaiseError(self):
    x = [[[1., 2.], [3., 4.]]]
    bij = ExpOnlyJacobian(forward_min_event_ndims=1)
    with self.assertRaisesRegex(ValueError, "must be larger than"):
      bij.forward_log_det_jacobian(x, event_ndims=0)

  def testReduceEventNdimsInverse(self):
    x = [[[1., 2.], [3., 4.]]]
    bij = ExpOnlyJacobian()
    self.assertAllClose(
        -np.log(x),
        self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=0)))
    self.assertAllClose(
        np.sum(-np.log(x), axis=-1),
        self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=1)))
    self.assertAllClose(
        np.sum(-np.log(x), axis=(-1, -2)),
        self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=2)))

  def testReduceEventNdimsInverseRaiseError(self):
    x = [[[1., 2.], [3., 4.]]]
    bij = ExpOnlyJacobian(forward_min_event_ndims=1)
    with self.assertRaisesRegex(ValueError, "must be larger than"):
      bij.inverse_log_det_jacobian(x, event_ndims=0)

  def testReduceEventNdimsForwardConstJacobian(self):
    # x has shape [1, 2, 2]. Reducing one event dim therefore sums 2 copies
    # of the constant, and reducing two event dims sums 4 copies.
    x = [[[1., 2.], [3., 4.]]]
    bij = ConstantJacobian()
    self.assertAllClose(
        -2.,
        self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=0)))
    self.assertAllClose(
        -4.,
        self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=1)))
    self.assertAllClose(
        -8.,
        self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=2)))

  def testReduceEventNdimsInverseConstJacobian(self):
    # Same shape reasoning as the forward case, with the opposite sign.
    x = [[[1., 2.], [3., 4.]]]
    bij = ConstantJacobian()
    self.assertAllClose(
        2.,
        self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=0)))
    self.assertAllClose(
        4.,
        self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=1)))
    self.assertAllClose(
        8.,
        self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=2)))

  @test_util.run_deprecated_v1
  def testHandlesNonStaticEventNdims(self):
    x_ = [[[1., 2.], [3., 4.]]]
    x = array_ops.placeholder_with_default(x_, shape=None)
    event_ndims = array_ops.placeholder(dtype=np.int32, shape=[])
    bij = ExpOnlyJacobian(forward_min_event_ndims=1)
    # NOTE(review): the result of this call is discarded. It presumably checks
    # that graph construction succeeds with a placeholder `event_ndims`, and it
    # may warm the bijector's cache before the op that is actually run below.
    # Confirm before removing it.
    bij.inverse_log_det_jacobian(x, event_ndims=event_ndims)
    with self.cached_session() as sess:
      ildj = sess.run(bij.inverse_log_det_jacobian(x, event_ndims=event_ndims),
                      feed_dict={event_ndims: 1})
    # event_ndims == forward_min_event_ndims, so no extra reduction applies
    # and the result is the elementwise -log(x) from ExpOnlyJacobian.
    self.assertAllClose(-np.log(x_), ildj)
303
304
if __name__ == "__main__":
  # Run the test cases in this module via TensorFlow's test entry point.
  test.main()
307