# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AffineLinearOperator bijector."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops.distributions import bijector
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.util import deprecation


__all__ = [
    "AffineLinearOperator",
]


class AffineLinearOperator(bijector.Bijector):
  """Compute `Y = g(X; shift, scale) = scale @ X + shift`.

  `shift` is a numeric `Tensor` and `scale` is a `LinearOperator`.

  If `X` is a scalar then the forward transformation is: `scale * X + shift`
  where `*` denotes the scalar product.

  Note: we don't always simply transpose `X` (the notation above is written
  for brevity). The input `X` actually undergoes the following transformation
  before being premultiplied by `scale`:

  1. If there are no sample dims, we call `X = tf.expand_dims(X, 0)`, i.e.,
     `new_sample_shape = [1]`. Otherwise do nothing.
  2. The sample shape is flattened to have one dimension, i.e.,
     `new_sample_shape = [n]` where `n = tf.reduce_prod(old_sample_shape)`.
  3. The sample dim is cyclically rotated left by 1, i.e.,
     `new_shape = [B1,...,Bb, k, n]` where `n` is as above, `k` is the
     event_shape, and `B1,...,Bb` are the batch shapes for each of `b` batch
     dimensions.

  (For more details see `shape.make_batch_of_event_sample_matrices`.)

  The result of the above transformation is that `X` can be regarded as a batch
  of matrices where each column is a draw from the distribution. After
  premultiplying by `scale`, we take the inverse of this procedure. The input
  `Y` also undergoes the same transformation before/after premultiplying by
  `inv(scale)`.
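
  For concreteness, here is a minimal NumPy sketch of steps 2 and 3 above,
  with hypothetical shape values (step 1 is a no-op whenever sample dims are
  present):

  ```python
  import numpy as np

  # Batch shape [B1] = [4], event size k = 3, old_sample_shape = [2, 5],
  # so n = 2 * 5 = 10.
  x = np.zeros([2, 5, 4, 3])      # [old_sample_shape..., B1, k]
  x = x.reshape([10, 4, 3])       # Step 2: flatten samples to [n, B1, k].
  x = np.transpose(x, [1, 2, 0])  # Step 3: rotate left by 1 to [B1, k, n].
  print(x.shape)                  # (4, 3, 10)
  ```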

  Example Use:

  ```python
  linalg = tf.linalg

  x = [1., 2, 3]

  shift = [-1., 0., 1]
  diag = [1., 2, 3]
  scale = linalg.LinearOperatorDiag(diag)
  affine = AffineLinearOperator(shift, scale)
  # In this case, `forward` is equivalent to:
  # y = scale @ x + shift
  y = affine.forward(x)  # [0., 4, 10]

  shift = [2., 3, 1]
  tril = [[1., 0, 0],
          [2, 1, 0],
          [3, 2, 1]]
  scale = linalg.LinearOperatorLowerTriangular(tril)
  affine = AffineLinearOperator(shift, scale)
  # In this case, `forward` is equivalent to:
  # np.squeeze(np.matmul(tril, np.expand_dims(x, -1)), -1) + shift
  y = affine.forward(x)  # [3., 7, 11]
  ```
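
  Since `scale` must be non-singular, `inverse` recovers the input; e.g., in
  the second example above, `affine.inverse(y)` returns `[1., 2, 3]`.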

  """

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(self,
               shift=None,
               scale=None,
               validate_args=False,
               name="affine_linear_operator"):
    """Instantiates the `AffineLinearOperator` bijector.

    Args:
      shift: Floating-point `Tensor`.
      scale: Subclass of `LinearOperator`. Represents the (batch) non-singular
        matrix `M` in `R^{k x k}`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.

    Raises:
      TypeError: if `scale` is not a `LinearOperator`.
      TypeError: if `shift.dtype` does not match `scale.dtype`.
      ValueError: if not `scale.is_non_singular`.
    """
    self._graph_parents = []
    self._name = name
    self._validate_args = validate_args
    graph_parents = []
    with self._name_scope("init", values=[shift]):
      # In the absence of `shift` and `scale`, we'll assume `dtype` is
      # `float32`.
      dtype = dtypes.float32

      if shift is not None:
        shift = ops.convert_to_tensor(shift, name="shift")
        graph_parents += [shift]
        dtype = shift.dtype.base_dtype
      self._shift = shift

      if scale is not None:
        if (shift is not None and
            shift.dtype.base_dtype != scale.dtype.base_dtype):
          raise TypeError(
              "shift.dtype({}) is incompatible with scale.dtype({}).".format(
                  shift.dtype, scale.dtype))
        if not isinstance(scale, linear_operator.LinearOperator):
          raise TypeError(
              "scale is not an instance of tf.linalg.LinearOperator")
        if validate_args and not scale.is_non_singular:
          raise ValueError("Scale matrix must be non-singular.")
        graph_parents += scale.graph_parents
        if scale.tensor_rank is not None:
          batch_ndims = scale.tensor_rank - 2
        else:
          batch_ndims = scale.tensor_rank_tensor() - 2
          graph_parents += [batch_ndims]
        if scale.dtype is not None:
          dtype = scale.dtype.base_dtype
      else:
        batch_ndims = 0  # We won't need shape inference when scale is None.
      self._scale = scale
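      # `_DistributionShape` implements the reshaping described in the class
      # docstring: sample dims are flattened and rotated into a rightmost
      # matrix dimension, with `event_ndims=1` marking the vector event dim.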
      self._shaper = _DistributionShape(
          batch_ndims=batch_ndims,
          event_ndims=1,
          validate_args=validate_args)
      super(AffineLinearOperator, self).__init__(
          forward_min_event_ndims=1,
          graph_parents=graph_parents,
          is_constant_jacobian=True,
          dtype=dtype,
          validate_args=validate_args,
          name=name)

  @property
  def shift(self):
    """The `shift` `Tensor` in `Y = scale @ X + shift`."""
    return self._shift

  @property
  def scale(self):
    """The `scale` `LinearOperator` in `Y = scale @ X + shift`."""
    return self._scale

  def _forward(self, x):
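    # Reshape `x` into a batch of matrices (steps 1-3 in the class
    # docstring), left-multiply by `scale`, then undo the reshaping.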
    y = x
    if self.scale is not None:
      y, sample_shape = self._shaper.make_batch_of_event_sample_matrices(
          y, expand_batch_dim=False)
      with ops.control_dependencies(self._maybe_collect_assertions() if
                                    self.validate_args else []):
        y = self.scale.matmul(y)
      y = self._shaper.undo_make_batch_of_event_sample_matrices(
          y, sample_shape, expand_batch_dim=False)
    if self.shift is not None:
      y += self.shift
    return y

  def _inverse(self, y):
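    # Undo the affine map: subtract `shift`, then solve `scale @ x = y` using
    # the same reshaping as in `_forward`.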
    x = y
    if self.shift is not None:
      x -= self.shift
    if self.scale is not None:
      x, sample_shape = self._shaper.make_batch_of_event_sample_matrices(
          x, expand_batch_dim=False)
      # `solve` fails if the operator is singular, so we may safely skip this
      # assertion.
      x = self.scale.solve(x)
      x = self._shaper.undo_make_batch_of_event_sample_matrices(
          x, sample_shape, expand_batch_dim=False)
    return x

  def _forward_log_det_jacobian(self, x):
    # is_constant_jacobian = True for this bijector, hence the
    # `log_det_jacobian` need only be specified for a single input, as this
    # will be tiled to match `event_ndims`.
    if self.scale is None:
      return constant_op.constant(0., dtype=x.dtype.base_dtype)

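    # E.g., with `scale = LinearOperatorDiag([1., 2, 3])` the Jacobian of
    # x -> scale @ x + shift is diag([1, 2, 3]) for every x, so the log-det
    # is the constant log(1 * 2 * 3) = log(6).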
    with ops.control_dependencies(self._maybe_collect_assertions() if
                                  self.validate_args else []):
      return self.scale.log_abs_determinant()

  def _maybe_collect_assertions(self):
    try:
      return [self.scale.assert_non_singular()]
    except NotImplementedError:
      pass
    return []