# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Affine bijector."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
from tensorflow.python.ops.linalg import linalg
from tensorflow.python.util import deprecation


__all__ = [
    "Affine",
]


@deprecation.deprecated(
    "2018-10-01",
    "The TensorFlow Distributions library has moved to "
    "TensorFlow Probability "
    "(https://github.com/tensorflow/probability). You "
    "should update all references to use `tfp.distributions` "
    "instead of `tf.contrib.distributions`.",
    warn_once=True)
def _as_tensor(x, name):
  """Convenience to convert to `Tensor` or leave as `None`."""
  return None if x is None else ops.convert_to_tensor(x, name=name)


class Affine(bijector.Bijector):
  """Compute `Y = g(X; shift, scale) = scale @ X + shift`.

  Here `scale = c * I + diag(D1) + tril(L) + V @ diag(D2) @ V.T`.

  In TF parlance, the `scale` term is logically equivalent to:

  ```python
  scale = (
    scale_identity_multiplier * tf.diag(tf.ones(d)) +
    tf.diag(scale_diag) +
    scale_tril +
    scale_perturb_factor @ diag(scale_perturb_diag) @
      tf.transpose(scale_perturb_factor)
  )
  ```

  The `scale` term is applied without necessarily materializing constituent
  matrices, i.e., the matmul is [matrix-free](
  https://en.wikipedia.org/wiki/Matrix-free_methods) when possible.

  #### Examples

  ```python
  # Y = X
  b = Affine()

  # Y = X + shift
  b = Affine(shift=[1., 2, 3])

  # Y = 2 * I @ X + shift
  b = Affine(shift=[1., 2, 3],
             scale_identity_multiplier=2.)

  # Y = tf.diag(d1) @ X + shift
  b = Affine(shift=[1., 2, 3],
             scale_diag=[-1., 2, 1])  # Implicitly 3x3.

  # Y = (I + v * v.T) @ X + shift
  b = Affine(shift=[1., 2, 3],
             scale_perturb_factor=[[1., 0],
                                   [0, 1],
                                   [1, 1]])

  # Y = (diag(d1) + v * diag(d2) * v.T) @ X + shift
  b = Affine(shift=[1., 2, 3],
             scale_diag=[1., 3, 3],  # Implicitly 3x3.
             scale_perturb_diag=[2., 1],  # Implicitly 2x2.
             scale_perturb_factor=[[1., 0],
                                   [0, 1],
                                   [1, 1]])
  ```
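
  A minimal usage sketch (output values follow from the definitions above):

  ```python
  b = Affine(shift=[1., 2, 3],
             scale_diag=[1., 2, 3])
  b.forward([1., 1, 1])  # ==> [2., 4., 6.]
  b.inverse([2., 4, 6])  # ==> [1., 1., 1.]
  # log|det(diag([1., 2, 3]))| == log(6.)
  b.forward_log_det_jacobian([1., 1, 1], event_ndims=1)
  ```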
  """

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(self,
               shift=None,
               scale_identity_multiplier=None,
               scale_diag=None,
               scale_tril=None,
               scale_perturb_factor=None,
               scale_perturb_diag=None,
               validate_args=False,
               name="affine"):
    """Instantiates the `Affine` bijector.

    This `Bijector` is initialized with `shift` `Tensor` and `scale` arguments,
    giving the forward operation:

    ```none
    Y = g(X) = scale @ X + shift
    ```

    where the `scale` term is logically equivalent to:

    ```python
    scale = (
      scale_identity_multiplier * tf.diag(tf.ones(d)) +
      tf.diag(scale_diag) +
      scale_tril +
      scale_perturb_factor @ diag(scale_perturb_diag) @
        tf.transpose(scale_perturb_factor)
    )
    ```

    If none of `scale_identity_multiplier`, `scale_diag`, or `scale_tril` is
    specified then `scale += IdentityMatrix`. Otherwise specifying a
    `scale` argument has the semantics of `scale += Expand(arg)`, e.g.,
    `scale_diag != None` means `scale += tf.diag(scale_diag)`.
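
    For example (a sketch of the additive semantics), these two bijectors
    apply the same `scale` matrix, since
    `2 * I + diag([1., 1, 1]) == diag([3., 3, 3])`:

    ```python
    b1 = Affine(scale_identity_multiplier=2.,
                scale_diag=[1., 1, 1])
    b2 = Affine(scale_diag=[3., 3, 3])
    ```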

    Args:
      shift: Floating-point `Tensor`. If this is set to `None`, no shift is
        applied.
      scale_identity_multiplier: floating point rank 0 `Tensor` representing a
        scaling done to the identity matrix. When
        `scale_identity_multiplier = scale_diag = scale_tril = None` then
        `scale += IdentityMatrix`; otherwise the scaled-identity-matrix term
        is added to `scale` only when `scale_identity_multiplier` is
        explicitly specified.
      scale_diag: Floating-point `Tensor` representing the diagonal matrix.
        `scale_diag` has shape [N1, N2, ...  k], which represents a k x k
        diagonal matrix. When `None` no diagonal term is added to `scale`.
      scale_tril: Floating-point `Tensor` representing the lower triangular
        matrix. `scale_tril` has shape [N1, N2, ...  k, k], which represents
        a k x k lower triangular matrix. When `None` no `scale_tril` term is
        added to `scale`. The upper triangular elements above the diagonal
        are ignored.
      scale_perturb_factor: Floating-point `Tensor` representing factor matrix
        with last two dimensions of shape `(k, r)`. When `None`, no rank-r
        update is added to `scale`.
      scale_perturb_diag: Floating-point `Tensor` representing the diagonal
        matrix. `scale_perturb_diag` has shape [N1, N2, ...  r], which
        represents an `r x r` diagonal matrix. When `None` low rank updates
        will take the form `scale_perturb_factor * scale_perturb_factor.T`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.

    Raises:
      ValueError: if `scale_perturb_diag` is specified but not
        `scale_perturb_factor`.
      TypeError: if `shift` has a different `dtype` from the `scale`
        arguments.
    """
    self._graph_parents = []
    self._name = name
    self._validate_args = validate_args

    # Ambiguous definition of low rank update.
    if scale_perturb_diag is not None and scale_perturb_factor is None:
      raise ValueError("When scale_perturb_diag is specified, "
                       "scale_perturb_factor must be specified.")

    # Special case, only handling a scaled identity matrix.  We don't know its
    # dimensions, so this is special cased.
    # We don't check identity_multiplier, since below we set it to 1. if all
    # other scale args are None.
    self._is_only_identity_multiplier = (scale_tril is None and
                                         scale_diag is None and
                                         scale_perturb_factor is None)

    with self._name_scope("init", values=[
        shift, scale_identity_multiplier, scale_diag, scale_tril,
        scale_perturb_diag, scale_perturb_factor]):

      # In the absence of `shift` and `scale`, we'll assume `dtype` is
      # `float32`.
      dtype = dtypes.float32

      if shift is not None:
        shift = ops.convert_to_tensor(shift, name="shift")
        dtype = shift.dtype.base_dtype
      self._shift = shift

      # When no args are specified, pretend the scale matrix is the identity
      # matrix.
      if (self._is_only_identity_multiplier and
          scale_identity_multiplier is None):
        scale_identity_multiplier = ops.convert_to_tensor(1., dtype=dtype)

      # self._create_scale_operator returns a LinearOperator in all cases
      # except if self._is_only_identity_multiplier; in which case it
      # returns a scalar Tensor.
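      # E.g. (a sketch): `Affine(scale_identity_multiplier=2.)` keeps `scale`
      # as the scalar Tensor `2.`, while `Affine(scale_diag=[2., 2.])` yields
      # a `LinearOperator` (a `LinearOperatorDiag` in that case).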
      scale = self._create_scale_operator(
          identity_multiplier=scale_identity_multiplier,
          diag=scale_diag,
          tril=scale_tril,
          perturb_diag=scale_perturb_diag,
          perturb_factor=scale_perturb_factor,
          shift=shift,
          validate_args=validate_args)

      if scale.dtype is not None:
        dtype = scale.dtype.base_dtype

      if scale is not None and not self._is_only_identity_multiplier:
        if (shift is not None and
            shift.dtype.base_dtype != scale.dtype.base_dtype):
          raise TypeError(
              "shift.dtype({}) is incompatible with scale.dtype({}).".format(
                  shift.dtype, scale.dtype))

        if scale.tensor_rank is not None:
          batch_ndims = scale.tensor_rank - 2
        else:
          batch_ndims = scale.tensor_rank_tensor() - 2
      else:
        # We won't need shape inference when scale is None or when scale is a
        # scalar.
        batch_ndims = 0
      self._scale = scale
      self._shaper = _DistributionShape(
          batch_ndims=batch_ndims,
          event_ndims=1,
          validate_args=validate_args)
      # Collect graph parents from `scale` (a Tensor or a LinearOperator)
      # and, when present, `shift`.
      graph_parents = (
          [self._scale] if tensor_util.is_tensor(self._scale)
          else list(self._scale.graph_parents))
      if self._shift is not None:
        graph_parents += [self._shift]
      super(Affine, self).__init__(
          forward_min_event_ndims=1,
          graph_parents=graph_parents,
          is_constant_jacobian=True,
          dtype=dtype,
          validate_args=validate_args,
          name=name)

  def _create_scale_operator(self, identity_multiplier, diag, tril,
                             perturb_diag, perturb_factor, shift,
                             validate_args):
    """Construct `scale` from various components.

    Args:
      identity_multiplier: floating point rank 0 `Tensor` representing a
        scaling done to the identity matrix.
      diag: Floating-point `Tensor` representing the diagonal matrix.
        `scale_diag` has shape [N1, N2, ...  k], which represents a k x k
        diagonal matrix.
      tril: Floating-point `Tensor` representing the lower triangular matrix.
        `scale_tril` has shape [N1, N2, ...  k, k], which represents a k x k
        lower triangular matrix.
      perturb_diag: Floating-point `Tensor` representing the diagonal matrix
        of the low rank update.
      perturb_factor: Floating-point `Tensor` representing factor matrix.
      shift: Floating-point `Tensor` representing the `shift` in
        `scale @ X + shift`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.

    Returns:
      scale. In the case of scaling by a constant, scale is a
        floating point `Tensor`. Otherwise, scale is a `LinearOperator`.

    Raises:
      ValueError: if all of `tril`, `diag` and `identity_multiplier` are
        `None`.
    """
    identity_multiplier = _as_tensor(identity_multiplier,
                                     "identity_multiplier")
    diag = _as_tensor(diag, "diag")
    tril = _as_tensor(tril, "tril")
    perturb_diag = _as_tensor(perturb_diag, "perturb_diag")
    perturb_factor = _as_tensor(perturb_factor, "perturb_factor")

    # If possible, use the low rank update to infer the shape of
    # the identity matrix, when scale represents a scaled identity matrix
    # with a low rank update.
    shape_hint = None
    if perturb_factor is not None:
      shape_hint = distribution_util.dimension_size(perturb_factor, axis=-2)

    if self._is_only_identity_multiplier:
      if validate_args:
        return control_flow_ops.with_dependencies(
            [check_ops.assert_none_equal(
                identity_multiplier,
                array_ops.zeros([], identity_multiplier.dtype),
                ["identity_multiplier should be non-zero."])],
            identity_multiplier)
      return identity_multiplier

    scale = distribution_util.make_tril_scale(
        loc=shift,
        scale_tril=tril,
        scale_diag=diag,
        scale_identity_multiplier=identity_multiplier,
        validate_args=validate_args,
        assert_positive=False,
        shape_hint=shape_hint)

    if perturb_factor is not None:
      return linalg.LinearOperatorLowRankUpdate(
          scale,
          u=perturb_factor,
          diag_update=perturb_diag,
          is_diag_update_positive=perturb_diag is None,
          is_non_singular=True,  # Implied by is_positive_definite=True.
          is_self_adjoint=True,
          is_positive_definite=True,
          is_square=True)

    return scale

  @property
  def shift(self):
    """The `shift` `Tensor` in `Y = scale @ X + shift`."""
    return self._shift

  @property
  def scale(self):
    """The `scale` `LinearOperator` in `Y = scale @ X + shift`."""
    return self._scale

  def _forward(self, x):
    y = x
    if self._is_only_identity_multiplier:
      y *= self._scale
      if self.shift is not None:
        return y + self.shift
      return y
    y, sample_shape = self._shaper.make_batch_of_event_sample_matrices(
        y, expand_batch_dim=False)
    with ops.control_dependencies(self._maybe_check_scale() if
                                  self.validate_args else []):
      y = self.scale.matmul(y)
    y = self._shaper.undo_make_batch_of_event_sample_matrices(
        y, sample_shape, expand_batch_dim=False)
    if self.shift is not None:
      y += self.shift
    return y

  def _inverse(self, y):
    x = y
    if self.shift is not None:
      x -= self.shift
    if self._is_only_identity_multiplier:
      return x / self._scale

    x, sample_shape = self._shaper.make_batch_of_event_sample_matrices(
        x, expand_batch_dim=False)
    # Solve fails if the op is singular so we may safely skip this assertion.
    x = self.scale.solve(x)
    x = self._shaper.undo_make_batch_of_event_sample_matrices(
        x, sample_shape, expand_batch_dim=False)
    return x

  def _forward_log_det_jacobian(self, x):
    # is_constant_jacobian = True for this bijector, hence the
    # `log_det_jacobian` need only be specified for a single input, as this
    # will be tiled to match `event_ndims`.
    if self._is_only_identity_multiplier:
      # We don't pad in this case and instead let the fldj be applied
      # via broadcast.
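      # E.g. (a worked sketch): for `scale_identity_multiplier = 2.` and a
      # length-3 event, det(2. * I_3) = 2.**3, so the fldj is 3 * log(|2.|),
      # computed below as `event_size * log(|scale|)`.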
      event_size = array_ops.shape(x)[-1]
      event_size = math_ops.cast(event_size, dtype=self._scale.dtype)
      return math_ops.log(math_ops.abs(self._scale)) * event_size

    return self.scale.log_abs_determinant()

  def _maybe_check_scale(self):
    try:
      return [self.scale.assert_non_singular()]
    except NotImplementedError:
      pass
    return []