# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Affine bijector."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.contrib.distributions.python.ops import distribution_util
22from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import tensor_util
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import check_ops
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops.distributions import bijector
31from tensorflow.python.ops.linalg import linalg
32from tensorflow.python.util import deprecation
33
34
35__all__ = [
36    "Affine",
37]
38
39
40@deprecation.deprecated(
41    "2018-10-01",
42    "The TensorFlow Distributions library has moved to "
43    "TensorFlow Probability "
44    "(https://github.com/tensorflow/probability). You "
45    "should update all references to use `tfp.distributions` "
46    "instead of `tf.contrib.distributions`.",
47    warn_once=True)
48def _as_tensor(x, name):
49  """Convenience to convert to `Tensor` or leave as `None`."""
50  return None if x is None else ops.convert_to_tensor(x, name=name)
51
52
class Affine(bijector.Bijector):
  """Compute `Y = g(X; shift, scale) = scale @ X + shift`.

  Here `scale = c * I + diag(D1) + tril(L) + V @ diag(D2) @ V.T`.

  In TF parlance, the `scale` term is logically equivalent to:

  ```python
  scale = (
    scale_identity_multiplier * tf.diag(tf.ones(d)) +
    tf.diag(scale_diag) +
    scale_tril +
    scale_perturb_factor @ diag(scale_perturb_diag) @
      tf.transpose([scale_perturb_factor])
  )
  ```

  The `scale` term is applied without necessarily materializing constituent
  matrices, i.e., the matmul is [matrix-free](
  https://en.wikipedia.org/wiki/Matrix-free_methods) when possible.

  #### Examples

  ```python
  # Y = X
  b = Affine()

  # Y = X + shift
  b = Affine(shift=[1., 2, 3])

  # Y = 2 * I @ X.T + shift
  b = Affine(shift=[1., 2, 3],
             scale_identity_multiplier=2.)

  # Y = tf.diag(d1) @ X.T + shift
  b = Affine(shift=[1., 2, 3],
             scale_diag=[-1., 2, 1])         # Implicitly 3x3.

  # Y = (I + v * v.T) @ X.T + shift
  b = Affine(shift=[1., 2, 3],
             scale_perturb_factor=[[1., 0],
                                   [0, 1],
                                   [1, 1]])

  # Y = (diag(d1) + v * diag(d2) * v.T) @ X.T + shift
  b = Affine(shift=[1., 2, 3],
             scale_diag=[1., 3, 3],          # Implicitly 3x3.
             scale_perturb_diag=[2., 1],     # Implicitly 2x2.
             scale_perturb_factor=[[1., 0],
                                   [0, 1],
                                   [1, 1]])
  ```
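
  A minimal round-trip usage sketch, assuming `b` is any of the 3-dimensional
  instances above (`event_ndims=1` matches this bijector's
  `forward_min_event_ndims`):

  ```python
  x = [1., 0., 0.]
  y = b.forward(x)       # Applies scale @ x + shift.
  x_back = b.inverse(y)  # Recovers x (up to numerical error).
  ldj = b.forward_log_det_jacobian(x, event_ndims=1)
  ```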

  """

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(self,
               shift=None,
               scale_identity_multiplier=None,
               scale_diag=None,
               scale_tril=None,
               scale_perturb_factor=None,
               scale_perturb_diag=None,
               validate_args=False,
               name="affine"):
    """Instantiates the `Affine` bijector.

    This `Bijector` is initialized with `shift` `Tensor` and `scale` arguments,
    giving the forward operation:

    ```none
    Y = g(X) = scale @ X + shift
    ```

    where the `scale` term is logically equivalent to:

    ```python
    scale = (
      scale_identity_multiplier * tf.diag(tf.ones(d)) +
      tf.diag(scale_diag) +
      scale_tril +
      scale_perturb_factor @ diag(scale_perturb_diag) @
        tf.transpose([scale_perturb_factor])
    )
    ```

    If none of `scale_identity_multiplier`, `scale_diag`, or `scale_tril` are
    specified then `scale += IdentityMatrix`. Otherwise specifying a
    `scale` argument has the semantics of `scale += Expand(arg)`, i.e.,
    `scale_diag != None` means `scale += tf.diag(scale_diag)`.
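
    For example, a small sketch (the 2 x 2 shape is implied by `scale_diag`):
    combining a diagonal with an identity multiplier sums the two terms:

    ```none
    scale_diag=[1., 2.], scale_identity_multiplier=3.
    ==> scale = [[4., 0.],
                 [0., 5.]]
    ```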

    Args:
      shift: Floating-point `Tensor`. If this is set to `None`, no shift is
        applied.
      scale_identity_multiplier: Floating-point rank-0 `Tensor` representing a
        scaling of the identity matrix. When
        `scale_identity_multiplier = scale_diag = scale_tril = None` then
        `scale += IdentityMatrix`; otherwise no default identity term is added
        to `scale`.
      scale_diag: Floating-point `Tensor` representing the diagonal matrix.
        `scale_diag` has shape [N1, N2, ...  k], which represents a k x k
        diagonal matrix. When `None` no diagonal term is added to `scale`.
      scale_tril: Floating-point `Tensor` representing the lower triangular
        matrix. `scale_tril` has shape [N1, N2, ...  k, k], which represents a
        k x k lower triangular matrix. When `None` no `scale_tril` term is
        added to `scale`. The upper triangular elements above the diagonal are
        ignored.
      scale_perturb_factor: Floating-point `Tensor` representing a factor
        matrix with last two dimensions of shape `(k, r)`. When `None`, no
        rank-r update is added to `scale`.
      scale_perturb_diag: Floating-point `Tensor` representing the diagonal
        matrix. `scale_perturb_diag` has shape [N1, N2, ...  r], which
        represents an `r x r` diagonal matrix. When `None` low rank updates
        will take the form `scale_perturb_factor * scale_perturb_factor.T`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.

    Raises:
      ValueError: if `scale_perturb_diag` is specified but not
        `scale_perturb_factor`.
      TypeError: if `shift` has a different `dtype` from the `scale` arguments.
    """
    self._graph_parents = []
    self._name = name
    self._validate_args = validate_args

    # Ambiguous definition of low rank update.
    if scale_perturb_diag is not None and scale_perturb_factor is None:
      raise ValueError("When scale_perturb_diag is specified, "
                       "scale_perturb_factor must be specified.")

    # Special case: the scale is only a scaled identity matrix. We don't know
    # its dimensions, so it is handled separately below.
    # We don't check identity_multiplier, since below we set it to 1. if all
    # other scale args are None.
    self._is_only_identity_multiplier = (scale_tril is None and
                                         scale_diag is None and
                                         scale_perturb_factor is None)

    with self._name_scope("init", values=[
        shift, scale_identity_multiplier, scale_diag, scale_tril,
        scale_perturb_diag, scale_perturb_factor]):

      # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`.
      dtype = dtypes.float32

      if shift is not None:
        shift = ops.convert_to_tensor(shift, name="shift")
        dtype = shift.dtype.base_dtype
      self._shift = shift

      # When no scale args are specified, pretend the scale matrix is the
      # identity matrix.
      if (self._is_only_identity_multiplier and
          scale_identity_multiplier is None):
        scale_identity_multiplier = ops.convert_to_tensor(1., dtype=dtype)

      # self._create_scale_operator returns a LinearOperator in all cases
      # except when self._is_only_identity_multiplier is True, in which case
      # it returns a scalar Tensor.
      scale = self._create_scale_operator(
          identity_multiplier=scale_identity_multiplier,
          diag=scale_diag,
          tril=scale_tril,
          perturb_diag=scale_perturb_diag,
          perturb_factor=scale_perturb_factor,
          shift=shift,
          validate_args=validate_args)

      if scale.dtype is not None:
        dtype = scale.dtype.base_dtype

      if scale is not None and not self._is_only_identity_multiplier:
        if (shift is not None and
            shift.dtype.base_dtype != scale.dtype.base_dtype):
          raise TypeError(
              "shift.dtype({}) is incompatible with scale.dtype({}).".format(
                  shift.dtype, scale.dtype))

        if scale.tensor_rank is not None:
          batch_ndims = scale.tensor_rank - 2
        else:
          batch_ndims = scale.tensor_rank_tensor() - 2
      else:
        # We won't need shape inference when scale is None or when scale is a
        # scalar.
        batch_ndims = 0
      self._scale = scale
      self._shaper = _DistributionShape(
          batch_ndims=batch_ndims,
          event_ndims=1,
          validate_args=validate_args)
      super(Affine, self).__init__(
          forward_min_event_ndims=1,
          graph_parents=(
              ([self._scale] if tensor_util.is_tensor(self._scale)
               else self._scale.graph_parents) +
              ([self._shift] if self._shift is not None else [])),
          is_constant_jacobian=True,
          dtype=dtype,
          validate_args=validate_args,
          name=name)

  def _create_scale_operator(self, identity_multiplier, diag, tril,
                             perturb_diag, perturb_factor, shift,
                             validate_args):
    """Construct `scale` from various components.

    Args:
      identity_multiplier: Floating-point rank-0 `Tensor` representing a
        scaling done to the identity matrix.
      diag: Floating-point `Tensor` representing the diagonal matrix.
        `scale_diag` has shape [N1, N2, ...  k], which represents a k x k
        diagonal matrix.
      tril: Floating-point `Tensor` representing the lower triangular matrix.
        `scale_tril` has shape [N1, N2, ...  k, k], which represents a k x k
        lower triangular matrix.
      perturb_diag: Floating-point `Tensor` representing the diagonal matrix of
        the low rank update.
      perturb_factor: Floating-point `Tensor` representing factor matrix.
      shift: Floating-point `Tensor` representing `shift` in
        `scale @ X + shift`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.

    Returns:
      scale. In the case of scaling by a constant, scale is a
      floating point `Tensor`. Otherwise, scale is a `LinearOperator`.

    Raises:
      ValueError: if all of `tril`, `diag` and `identity_multiplier` are `None`.
    """
    identity_multiplier = _as_tensor(identity_multiplier, "identity_multiplier")
    diag = _as_tensor(diag, "diag")
    tril = _as_tensor(tril, "tril")
    perturb_diag = _as_tensor(perturb_diag, "perturb_diag")
    perturb_factor = _as_tensor(perturb_factor, "perturb_factor")

    # If possible, use the low rank update to infer the shape of
    # the identity matrix, when scale represents a scaled identity matrix
    # with a low rank update.
    shape_hint = None
    if perturb_factor is not None:
      shape_hint = distribution_util.dimension_size(perturb_factor, axis=-2)

    if self._is_only_identity_multiplier:
      if validate_args:
        return control_flow_ops.with_dependencies(
            [check_ops.assert_none_equal(
                identity_multiplier,
                array_ops.zeros([], identity_multiplier.dtype),
                ["identity_multiplier should be non-zero."])],
            identity_multiplier)
      return identity_multiplier

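    # Combine the tril, diag, and identity-multiplier terms into a single
    # `LinearOperator`.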
    scale = distribution_util.make_tril_scale(
        loc=shift,
        scale_tril=tril,
        scale_diag=diag,
        scale_identity_multiplier=identity_multiplier,
        validate_args=validate_args,
        assert_positive=False,
        shape_hint=shape_hint)

    if perturb_factor is not None:
      return linalg.LinearOperatorLowRankUpdate(
          scale,
          u=perturb_factor,
          diag_update=perturb_diag,
          is_diag_update_positive=perturb_diag is None,
          is_non_singular=True,  # Implied by is_positive_definite=True.
          is_self_adjoint=True,
          is_positive_definite=True,
          is_square=True)

    return scale

  @property
  def shift(self):
    """The `shift` `Tensor` in `Y = scale @ X + shift`."""
    return self._shift

  @property
  def scale(self):
    """The `scale` `LinearOperator` in `Y = scale @ X + shift`."""
    return self._scale

  def _forward(self, x):
    y = x
    if self._is_only_identity_multiplier:
      y *= self._scale
      if self.shift is not None:
        return y + self.shift
      return y
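    # Flatten `y` into a batch of [event_size, sample_size] matrices so the
    # scale `LinearOperator` can be applied with a single matmul.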
    y, sample_shape = self._shaper.make_batch_of_event_sample_matrices(
        y, expand_batch_dim=False)
    with ops.control_dependencies(self._maybe_check_scale() if
                                  self.validate_args else []):
      y = self.scale.matmul(y)
    y = self._shaper.undo_make_batch_of_event_sample_matrices(
        y, sample_shape, expand_batch_dim=False)
    if self.shift is not None:
      y += self.shift
    return y

  def _inverse(self, y):
    x = y
    if self.shift is not None:
      x -= self.shift
    if self._is_only_identity_multiplier:
      return x / self._scale

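    # As in `_forward`, flatten `x` into a batch of matrices before solving
    # against the scale operator.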
    x, sample_shape = self._shaper.make_batch_of_event_sample_matrices(
        x, expand_batch_dim=False)
    # `solve` fails if the operator is singular, so we may safely skip the
    # non-singularity assertion here.
    x = self.scale.solve(x)
    x = self._shaper.undo_make_batch_of_event_sample_matrices(
        x, sample_shape, expand_batch_dim=False)
    return x

  def _forward_log_det_jacobian(self, x):
    # is_constant_jacobian = True for this bijector, hence the
    # `log_det_jacobian` need only be specified for a single input, as it will
    # be tiled to match `event_ndims`.
    if self._is_only_identity_multiplier:
      # We don't pad in this case and instead let the fldj be applied
      # via broadcast.
      event_size = array_ops.shape(x)[-1]
      event_size = math_ops.cast(event_size, dtype=self._scale.dtype)
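      # For scale = c * I_k, log|det(scale)| = k * log|c|, hence the
      # multiplication by the event size.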
      return math_ops.log(math_ops.abs(self._scale)) * event_size

    return self.scale.log_abs_determinant()

  def _maybe_check_scale(self):
    try:
      return [self.scale.assert_non_singular()]
    except NotImplementedError:
      pass
    return []