• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Identity Tests."""
16
17from tensorflow.python.framework import test_util
18from tensorflow.python.ops.distributions import bijector_test_util
19from tensorflow.python.ops.distributions import identity_bijector
20from tensorflow.python.platform import test
21
22
class IdentityBijectorTest(test.TestCase):
  """Checks the identity bijector, Y = g(X) = X.

  Covers the bijector's name, both directions of the mapping, the
  log-det-Jacobian in both directions, and scalar congruency on [-2, 2].
  """

  def testBijector(self):
    identity = identity_bijector.Identity(validate_args=True)
    self.assertEqual("identity", identity.name)

    # Nested-list input of shape [1, 2, 1], i.e. rank 3.
    values = [[[0.], [1.]]]

    # Forward and inverse must each return the input unchanged.
    for transform in (identity.forward, identity.inverse):
      self.assertAllEqual(values, self.evaluate(transform(values)))

    # The Jacobian of X -> X is the identity, so log|det J| must be zero in
    # either direction; event_ndims=3 treats the whole rank-3 input as one
    # event.
    for log_det_jacobian in (identity.inverse_log_det_jacobian,
                             identity.forward_log_det_jacobian):
      ldj = log_det_jacobian(values, event_ndims=3)
      self.assertAllEqual(0., self.evaluate(ldj))

  @test_util.run_deprecated_v1
  def testScalarCongruency(self):
    # Delegates to the shared helper over the interval [-2, 2].
    with self.cached_session():
      bijector_test_util.assert_scalar_congruency(
          identity_bijector.Identity(), lower_x=-2., upper_x=2.)
48
if __name__ == "__main__":
  # Run this module's tests through TensorFlow's test entry point.
  test.main()
51