# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bijector unit-test utilities."""

import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import uniform as uniform_lib


def assert_finite(array):
  if not np.isfinite(array).all():
    raise AssertionError("array was not all finite. %s" % array[:15])


def assert_strictly_increasing(array):
  np.testing.assert_array_less(0., np.diff(array))


def assert_strictly_decreasing(array):
  np.testing.assert_array_less(np.diff(array), 0.)


def assert_strictly_monotonic(array):
  if array[0] < array[-1]:
    assert_strictly_increasing(array)
  else:
    assert_strictly_decreasing(array)


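# A minimal illustrative sketch (not part of the original utilities): the
# hypothetical values below show what the helpers above accept.  It is wrapped
# in a function so nothing runs at import time.
def _example_monotonicity_checks():
  assert_finite(np.array([0., 0.5, 1.]))
  assert_strictly_increasing(np.array([0., 0.5, 1.]))
  assert_strictly_decreasing(np.array([1., 0.5, 0.]))
  # Direction is inferred from the endpoints, so this dispatches to the
  # strictly-decreasing check.
  assert_strictly_monotonic(np.array([1., 0.5, 0.]))

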
def assert_scalar_congruency(bijector,
                             lower_x,
                             upper_x,
                             n=int(10e3),
                             rtol=0.01,
                             sess=None):
  """Assert `bijector`'s forward/inverse/inverse_log_det_jacobian are congruent.

  We draw samples `X ~ U(lower_x, upper_x)`, then feed these through the
  `bijector` in order to check that:

  1. The forward is strictly monotonic.
  2. The forward/inverse methods are inverses of each other.
  3. The Jacobian is the correct change of measure.

  This can only be used for a `Bijector` mapping open subsets of the real line
  to themselves, because the test compares the `prob` before/after
  transformation with the Lebesgue measure on the line.

  Args:
    bijector:  Instance of `Bijector`.
    lower_x:  Python scalar.
    upper_x:  Python scalar.  Must have `lower_x < upper_x`, and both must be in
      the domain of the `bijector`.  The `bijector` should probably not produce
      huge variation in values over the interval `(lower_x, upper_x)`, or else
      the variance-based check of the Jacobian will require a loose `rtol` or a
      huge `n`.
    n:  Number of samples to draw for the checks.
    rtol:  Positive number.  Used for the Jacobian check.
    sess:  `tf.compat.v1.Session`.  Defaults to the default session.

  Raises:
    AssertionError:  If the tests fail.
  """
  # Checks and defaults.
  if sess is None:
    sess = ops.get_default_session()

  # Should be monotonic over this interval
  ten_x_pts = np.linspace(lower_x, upper_x, num=10).astype(np.float32)
  if bijector.dtype is not None:
    ten_x_pts = ten_x_pts.astype(bijector.dtype.as_numpy_dtype)
  forward_on_10_pts = bijector.forward(ten_x_pts)

  # Set the lower/upper limits in the range of the bijector.
  lower_y, upper_y = sess.run(
      [bijector.forward(lower_x), bijector.forward(upper_x)])
  if upper_y < lower_y:  # If bijector.forward is a decreasing function.
    lower_y, upper_y = upper_y, lower_y

  # Uniform samples from the domain, range.
  uniform_x_samps = uniform_lib.Uniform(
      low=lower_x, high=upper_x).sample(n, seed=0)
  uniform_y_samps = uniform_lib.Uniform(
      low=lower_y, high=upper_y).sample(n, seed=1)

  # These compositions should be the identity.
  inverse_forward_x = bijector.inverse(bijector.forward(uniform_x_samps))
  forward_inverse_y = bijector.forward(bijector.inverse(uniform_y_samps))

  # For a < b, and transformation y = y(x),
  # (b - a) = \int_a^b dx = \int_{y(a)}^{y(b)} |dx/dy| dy
  # "change_measure_dy_dx" below is a Monte Carlo approximation to the right
  # hand side, which should then be close to the left, which is (b - a).
  # We assume event_ndims=0 because we assume scalar -> scalar. The log_det
  # methods will handle whether they expect event_ndims > 0.
  dy_dx = math_ops.exp(bijector.inverse_log_det_jacobian(
      uniform_y_samps, event_ndims=0))
  # E[|dx/dy|] under Uniform[lower_y, upper_y]
  # = \int_{y(a)}^{y(b)} |dx/dy| dP(u), where dP(u) is the uniform measure
  expectation_of_dy_dx_under_uniform = math_ops.reduce_mean(dy_dx)
  # dy = dP(u) * (upper_y - lower_y)
  change_measure_dy_dx = (
      (upper_y - lower_y) * expectation_of_dy_dx_under_uniform)

  # We'll also check that dy_dx = 1 / dx_dy.
  dx_dy = math_ops.exp(
      bijector.forward_log_det_jacobian(
          bijector.inverse(uniform_y_samps), event_ndims=0))

  [
      forward_on_10_pts_v,
      dy_dx_v,
      dx_dy_v,
      change_measure_dy_dx_v,
      uniform_x_samps_v,
      uniform_y_samps_v,
      inverse_forward_x_v,
      forward_inverse_y_v,
  ] = sess.run([
      forward_on_10_pts,
      dy_dx,
      dx_dy,
      change_measure_dy_dx,
      uniform_x_samps,
      uniform_y_samps,
      inverse_forward_x,
      forward_inverse_y,
  ])

  assert_strictly_monotonic(forward_on_10_pts_v)
  # Composition of forward/inverse should be the identity.
  np.testing.assert_allclose(
      inverse_forward_x_v, uniform_x_samps_v, atol=1e-5, rtol=1e-3)
  np.testing.assert_allclose(
      forward_inverse_y_v, uniform_y_samps_v, atol=1e-5, rtol=1e-3)
  # Change of measure should be correct.
  np.testing.assert_allclose(
      upper_x - lower_x, change_measure_dy_dx_v, atol=0, rtol=rtol)
  # Inverse Jacobian should be equivalent to the reciprocal of the forward
  # Jacobian.
  np.testing.assert_allclose(
      dy_dx_v, np.divide(1., dx_dy_v), atol=1e-5, rtol=1e-3)


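# An illustrative usage sketch (not part of the original utilities).  It
# assumes the `Identity` bijector from identity_bijector.py in this package
# and a TF1-style graph session; any scalar bijector would do.  Wrapped in a
# function so nothing runs at import time.
def _example_assert_scalar_congruency_usage():
  from tensorflow.python.client import session as session_lib
  from tensorflow.python.ops.distributions import identity_bijector

  with ops.Graph().as_default(), session_lib.Session() as sess:
    # The identity map has unit Jacobian, so the change-of-measure check
    # reduces to (upper_x - lower_x) == (upper_y - lower_y).
    assert_scalar_congruency(
        identity_bijector.Identity(), lower_x=-2., upper_x=2., sess=sess)

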
def assert_bijective_and_finite(
    bijector, x, y, event_ndims, atol=0, rtol=1e-5, sess=None):
  """Assert that forward/inverse (along with Jacobians) are inverses and finite.

  It is recommended to use `x` and `y` values that lie very close to the edges
  of the Bijector's domain.

  Args:
    bijector:  A `Bijector` instance.
    x:  np.array of values in the domain of `bijector.forward`.
    y:  np.array of values in the domain of `bijector.inverse`.
    event_ndims:  Integer describing the number of event dimensions this
      bijector operates on.
    atol:  Absolute tolerance.
    rtol:  Relative tolerance.
    sess:  TensorFlow session.  Defaults to the default session.

  Raises:
    AssertionError:  If the tests fail.
  """
  sess = sess or ops.get_default_session()

  # Check the incoming points first; a carelessly chosen range can already
  # make them non-finite, especially in 16-bit precision.
  assert_finite(x)
  assert_finite(y)

  f_x = bijector.forward(x)
  g_y = bijector.inverse(y)

  [
      x_from_x,
      y_from_y,
      ildj_f_x,
      fldj_x,
      ildj_y,
      fldj_g_y,
      f_x_v,
      g_y_v,
  ] = sess.run([
      bijector.inverse(f_x),
      bijector.forward(g_y),
      bijector.inverse_log_det_jacobian(f_x, event_ndims=event_ndims),
      bijector.forward_log_det_jacobian(x, event_ndims=event_ndims),
      bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims),
      bijector.forward_log_det_jacobian(g_y, event_ndims=event_ndims),
      f_x,
      g_y,
  ])

  assert_finite(x_from_x)
  assert_finite(y_from_y)
  assert_finite(ildj_f_x)
  assert_finite(fldj_x)
  assert_finite(ildj_y)
  assert_finite(fldj_g_y)
  assert_finite(f_x_v)
  assert_finite(g_y_v)

  np.testing.assert_allclose(x_from_x, x, atol=atol, rtol=rtol)
  np.testing.assert_allclose(y_from_y, y, atol=atol, rtol=rtol)
  np.testing.assert_allclose(-ildj_f_x, fldj_x, atol=atol, rtol=rtol)
  np.testing.assert_allclose(-ildj_y, fldj_g_y, atol=atol, rtol=rtol)
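

# An illustrative usage sketch (not part of the original utilities).  As
# above, it assumes the `Identity` bijector and a TF1-style graph session;
# the probe points are hypothetical and would normally hug the edges of the
# bijector's domain.
def _example_assert_bijective_and_finite_usage():
  from tensorflow.python.client import session as session_lib
  from tensorflow.python.ops.distributions import identity_bijector

  x = np.linspace(-10., 10., num=100).astype(np.float32)
  y = np.linspace(-10., 10., num=100).astype(np.float32)
  with ops.Graph().as_default(), session_lib.Session() as sess:
    assert_bijective_and_finite(
        identity_bijector.Identity(), x, y, event_ndims=0, sess=sess)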