• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Identity bijector."""
16
17from tensorflow.python.framework import constant_op
18from tensorflow.python.ops.distributions import bijector
19from tensorflow.python.util import deprecation
20
21
# Public names exported by `from <module> import *`.
__all__ = ["Identity"]
25
26
class Identity(bijector.Bijector):
  """Bijector implementing the identity map `Y = g(X) = X`.

  The forward and inverse transforms hand their argument back untouched, and
  both log-det-Jacobian hooks produce a constant `0.` whose dtype is taken
  from the argument they receive.

  Example Use:

  ```python
  # Create the Y=g(X)=X transform.
  identity = Identity()
  x = [[1., 2],
       [3, 4]]
  x == identity.forward(x) == identity.inverse(x)
  ```
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to TensorFlow "
      "Probability (https://github.com/tensorflow/probability). "
      "You should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self, validate_args=False, name="identity"):
    """Instantiates the identity bijector.

    Args:
      validate_args: Passed through unchanged to the `Bijector` base
        constructor. Defaults to `False`.
      name: Passed through unchanged to the `Bijector` base constructor.
        Defaults to `"identity"`.
    """
    # The base class is always configured with a constant Jacobian and a
    # minimum forward event rank of 0; only the caller-supplied options vary.
    super().__init__(
        name=name,
        validate_args=validate_args,
        is_constant_jacobian=True,
        forward_min_event_ndims=0)

  def _forward(self, x):
    """Returns `x` unchanged."""
    return x

  def _inverse(self, y):
    """Returns `y` unchanged."""
    return y

  def _forward_log_det_jacobian(self, x):
    """Returns a constant `0.` with the dtype of `x`."""
    zero = constant_op.constant(0., dtype=x.dtype)
    return zero

  def _inverse_log_det_jacobian(self, y):
    """Returns a constant `0.` with the dtype of `y`."""
    zero = constant_op.constant(0., dtype=y.dtype)
    return zero
69