# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AffineLinearOperator bijector."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops.distributions import bijector
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.util import deprecation


__all__ = [
    "AffineLinearOperator",
]


class AffineLinearOperator(bijector.Bijector):
  """Compute `Y = g(X; shift, scale) = scale @ X + shift`.

  `shift` is a numeric `Tensor` and `scale` is a `LinearOperator`.

  If `X` is a scalar then the forward transformation is: `scale * X + shift`
  where `*` denotes scalar multiplication.

  Note: we don't always simply transpose `X` (but write it this way for
  brevity). Actually the input `X` undergoes the following transformation
  before being premultiplied by `scale`:

  1. If there are no sample dims, we call `X = tf.expand_dims(X, 0)`, i.e.,
     `new_sample_shape = [1]`. Otherwise we do nothing.
  2. The sample shape is flattened to have one dimension, i.e.,
     `new_sample_shape = [n]` where `n = tf.reduce_prod(old_sample_shape)`.
  3. The sample dim is cyclically rotated left by 1, i.e.,
     `new_shape = [B1, ..., Bb, k, n]` where `n` is as above, `k` is the
     event size, and `B1, ..., Bb` are the batch shapes for each of `b` batch
     dimensions.

  (For more details see `shape.make_batch_of_event_sample_matrices`.)

  The result of the above transformation is that `X` can be regarded as a batch
  of matrices where each column is a draw from the distribution. After
  premultiplying by `scale`, we take the inverse of this procedure. The input
  `Y` also undergoes the same transformation before/after premultiplying by
  `inv(scale)`.
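
  For intuition: with no batch dims (`b = 0`) and event size `k`, steps 1-3
  reduce to flattening the sample dims and transposing. A rough sketch of the
  effect (not the actual implementation, which is
  `shape.make_batch_of_event_sample_matrices`):

  ```python
  # x.shape = sample_shape + [k]
  x = tf.reshape(x, [-1, k])         # Steps 1-2: shape [n, k].
  x = tf.linalg.matrix_transpose(x)  # Step 3:    shape [k, n].
  ```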

  Example Use:

  ```python
  linalg = tf.linalg

  x = [1., 2, 3]

  shift = [-1., 0., 1]
  diag = [1., 2, 3]
  scale = linalg.LinearOperatorDiag(diag)
  affine = AffineLinearOperator(shift, scale)
  # In this case, `forward` is equivalent to:
  # y = scale @ x + shift
  y = affine.forward(x)  # [0., 4, 10]

  shift = [2., 3, 1]
  tril = [[1., 0, 0],
          [2, 1, 0],
          [3, 2, 1]]
  scale = linalg.LinearOperatorLowerTriangular(tril)
  affine = AffineLinearOperator(shift, scale)
  # In this case, `forward` is equivalent to:
  # np.squeeze(np.matmul(tril, np.expand_dims(x, -1)), -1) + shift
  y = affine.forward(x)  # [3., 7, 11]
  ```

  """

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(self,
               shift=None,
               scale=None,
               validate_args=False,
               name="affine_linear_operator"):
    """Instantiates the `AffineLinearOperator` bijector.

    Args:
      shift: Floating-point `Tensor`.
      scale: Subclass of `LinearOperator`. Represents the (batch) non-singular
        matrix `M` in `R^{k x k}`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.

    Raises:
      TypeError: if `scale` is not a `LinearOperator`.
      TypeError: if `shift.dtype` does not match `scale.dtype`.
      ValueError: if not `scale.is_non_singular`.
    """
    self._graph_parents = []
    self._name = name
    self._validate_args = validate_args
    graph_parents = []
    with self._name_scope("init", values=[shift]):
      # In the absence of `shift` and `scale`, we'll assume `dtype` is
      # `float32`.
      dtype = dtypes.float32

      if shift is not None:
        shift = ops.convert_to_tensor(shift, name="shift")
        graph_parents += [shift]
        dtype = shift.dtype.base_dtype
      self._shift = shift

      if scale is not None:
        # Check the type first: a non-`LinearOperator` may not even expose a
        # `dtype`, which would otherwise surface as a confusing AttributeError.
        if not isinstance(scale, linear_operator.LinearOperator):
          raise TypeError("scale is not an instance of tf.LinearOperator")
        if (shift is not None and
            shift.dtype.base_dtype != scale.dtype.base_dtype):
          raise TypeError(
              "shift.dtype({}) is incompatible with scale.dtype({}).".format(
                  shift.dtype, scale.dtype))
        if validate_args and not scale.is_non_singular:
          raise ValueError("Scale matrix must be non-singular.")
        graph_parents += scale.graph_parents
        if scale.tensor_rank is not None:
          batch_ndims = scale.tensor_rank - 2
        else:
          batch_ndims = scale.tensor_rank_tensor() - 2
          graph_parents += [batch_ndims]
        if scale.dtype is not None:
          dtype = scale.dtype.base_dtype
      else:
        batch_ndims = 0  # We won't need shape inference when scale is None.
      self._scale = scale
      self._shaper = _DistributionShape(
          batch_ndims=batch_ndims,
          event_ndims=1,
          validate_args=validate_args)
      super(AffineLinearOperator, self).__init__(
          forward_min_event_ndims=1,
          graph_parents=graph_parents,
          is_constant_jacobian=True,
          dtype=dtype,
          validate_args=validate_args,
          name=name)

  @property
  def shift(self):
    """The `shift` `Tensor` in `Y = scale @ X + shift`."""
    return self._shift

  @property
  def scale(self):
    """The `scale` `LinearOperator` in `Y = scale @ X + shift`."""
    return self._scale

  def _forward(self, x):
    y = x
    if self.scale is not None:
      # Reshape `y` into a batch of event-sample matrices (see the class
      # docstring), premultiply by `scale`, then undo the reshaping.
      y, sample_shape = self._shaper.make_batch_of_event_sample_matrices(
          y, expand_batch_dim=False)
      with ops.control_dependencies(self._maybe_collect_assertions() if
                                    self.validate_args else []):
        y = self.scale.matmul(y)
      y = self._shaper.undo_make_batch_of_event_sample_matrices(
          y, sample_shape, expand_batch_dim=False)
    if self.shift is not None:
      y += self.shift
    return y

  def _inverse(self, y):
    x = y
    if self.shift is not None:
      x -= self.shift
    if self.scale is not None:
      x, sample_shape = self._shaper.make_batch_of_event_sample_matrices(
          x, expand_batch_dim=False)
      # `solve` fails if the op is singular, so we may safely skip the
      # non-singularity assertion here.
      x = self.scale.solve(x)
      x = self._shaper.undo_make_batch_of_event_sample_matrices(
          x, sample_shape, expand_batch_dim=False)
    return x

  def _forward_log_det_jacobian(self, x):
    # `is_constant_jacobian` is `True` for this bijector, hence the
    # `log_det_jacobian` need only be specified for a single input, as it will
    # be tiled to match `event_ndims`.
    if self.scale is None:
      return constant_op.constant(0., dtype=x.dtype.base_dtype)

    with ops.control_dependencies(self._maybe_collect_assertions() if
                                  self.validate_args else []):
      return self.scale.log_abs_determinant()

  def _maybe_collect_assertions(self):
    try:
      return [self.scale.assert_non_singular()]
    except NotImplementedError:
      return []
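

# A minimal usage sketch (illustrative only, not part of the original module):
# round-trips `forward`/`inverse` and evaluates the constant log det Jacobian,
# mirroring the lower-triangular example in the class docstring. Assumes a
# TF 1.x graph-mode environment where this module is importable.
if __name__ == "__main__":
  import tensorflow as tf

  tril = [[1., 0., 0.],
          [2., 1., 0.],
          [3., 2., 1.]]
  scale = tf.linalg.LinearOperatorLowerTriangular(tril)
  affine = AffineLinearOperator(shift=[2., 3., 1.], scale=scale)

  x = tf.constant([1., 2., 3.])
  y = affine.forward(x)       # Expected: [3., 7., 11.]
  x_back = affine.inverse(y)  # Expected: recovers x.
  # log|det(scale)| = log(1 * 1 * 1) = 0 for this unit-diagonal `tril`.
  fldj = affine.forward_log_det_jacobian(x, event_ndims=1)

  with tf.Session() as sess:
    print(sess.run([y, x_back, fldj]))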