1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Bijector base.""" 16 17import abc 18import collections 19import contextlib 20import re 21 22import numpy as np 23 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import tensor_shape 27from tensorflow.python.framework import tensor_util 28from tensorflow.python.ops import array_ops 29from tensorflow.python.ops import check_ops 30from tensorflow.python.ops import math_ops 31from tensorflow.python.ops.distributions import util as distribution_util 32from tensorflow.python.util import object_identity 33 34 35__all__ = [ 36 "Bijector", 37] 38 39 40class _Mapping(collections.namedtuple( 41 "_Mapping", ["x", "y", "ildj_map", "kwargs"])): 42 """Helper class to make it easier to manage caching in `Bijector`.""" 43 44 def __new__(cls, x=None, y=None, ildj_map=None, kwargs=None): 45 """Custom __new__ so namedtuple items have defaults. 46 47 Args: 48 x: `Tensor`. Forward. 49 y: `Tensor`. Inverse. 50 ildj_map: `Dictionary`. This is a mapping from event_ndims to a `Tensor` 51 representing the inverse log det jacobian. 52 kwargs: Python dictionary. Extra args supplied to 53 forward/inverse/etc functions. 54 55 Returns: 56 mapping: New instance of _Mapping. 57 """ 58 return super(_Mapping, cls).__new__(cls, x, y, ildj_map, kwargs) 59 60 @property 61 def x_key(self): 62 """Returns key used for caching Y=g(X).""" 63 return ((object_identity.Reference(self.x),) + 64 self._deep_tuple(tuple(sorted(self.kwargs.items())))) 65 66 @property 67 def y_key(self): 68 """Returns key used for caching X=g^{-1}(Y).""" 69 return ((object_identity.Reference(self.y),) + 70 self._deep_tuple(tuple(sorted(self.kwargs.items())))) 71 72 def merge(self, x=None, y=None, ildj_map=None, kwargs=None, mapping=None): 73 """Returns new _Mapping with args merged with self. 74 75 Args: 76 x: `Tensor`. Forward. 77 y: `Tensor`. Inverse. 78 ildj_map: `Dictionary`. This is a mapping from event_ndims to a `Tensor` 79 representing the inverse log det jacobian. 80 kwargs: Python dictionary. Extra args supplied to 81 forward/inverse/etc functions. 82 mapping: Instance of _Mapping to merge. Can only be specified if no other 83 arg is specified. 84 85 Returns: 86 mapping: New instance of `_Mapping` which has inputs merged with self. 87 88 Raises: 89 ValueError: if mapping and any other arg is not `None`. 90 """ 91 if mapping is None: 92 mapping = _Mapping(x=x, y=y, ildj_map=ildj_map, kwargs=kwargs) 93 elif any(arg is not None for arg in [x, y, ildj_map, kwargs]): 94 raise ValueError("Cannot simultaneously specify mapping and individual " 95 "arguments.") 96 97 return _Mapping( 98 x=self._merge(self.x, mapping.x), 99 y=self._merge(self.y, mapping.y), 100 ildj_map=self._merge_dicts(self.ildj_map, mapping.ildj_map), 101 kwargs=self._merge(self.kwargs, mapping.kwargs)) 102 103 def _merge_dicts(self, old=None, new=None): 104 """Helper to merge two dictionaries.""" 105 old = {} if old is None else old 106 new = {} if new is None else new 107 for k, v in new.items(): 108 val = old.get(k, None) 109 if val is not None and val is not v: 110 raise ValueError("Found different value for existing key " 111 "(key:{} old_value:{} new_value:{}".format( 112 k, old[k], v)) 113 old[k] = v 114 return old 115 116 def _merge(self, old, new): 117 """Helper to merge which handles merging one value.""" 118 if old is None: 119 return new 120 elif new is not None and old is not new: 121 raise ValueError("Incompatible values: %s != %s" % (old, new)) 122 return old 123 124 def _deep_tuple(self, x): 125 """Converts lists of lists to tuples of tuples.""" 126 return (tuple(map(self._deep_tuple, x)) 127 if isinstance(x, (list, tuple)) else x) 128 129 130class Bijector(metaclass=abc.ABCMeta): 131 r"""Interface for transformations of a `Distribution` sample. 132 133 Bijectors can be used to represent any differentiable and injective 134 (one to one) function defined on an open subset of `R^n`. Some non-injective 135 transformations are also supported (see "Non Injective Transforms" below). 136 137 #### Mathematical Details 138 139 A `Bijector` implements a [smooth covering map]( 140 https://en.wikipedia.org/wiki/Local_diffeomorphism), i.e., a local 141 diffeomorphism such that every point in the target has a neighborhood evenly 142 covered by a map ([see also]( 143 https://en.wikipedia.org/wiki/Covering_space#Covering_of_a_manifold)). 144 A `Bijector` is used by `TransformedDistribution` but can be generally used 145 for transforming a `Distribution` generated `Tensor`. A `Bijector` is 146 characterized by three operations: 147 148 1. Forward 149 150 Useful for turning one random outcome into another random outcome from a 151 different distribution. 152 153 2. Inverse 154 155 Useful for "reversing" a transformation to compute one probability in 156 terms of another. 157 158 3. `log_det_jacobian(x)` 159 160 "The log of the absolute value of the determinant of the matrix of all 161 first-order partial derivatives of the inverse function." 162 163 Useful for inverting a transformation to compute one probability in terms 164 of another. Geometrically, the Jacobian determinant is the volume of the 165 transformation and is used to scale the probability. 166 167 We take the absolute value of the determinant before log to avoid NaN 168 values. Geometrically, a negative determinant corresponds to an 169 orientation-reversing transformation. It is ok for us to discard the sign 170 of the determinant because we only integrate everywhere-nonnegative 171 functions (probability densities) and the correct orientation is always the 172 one that produces a nonnegative integrand. 173 174 By convention, transformations of random variables are named in terms of the 175 forward transformation. The forward transformation creates samples, the 176 inverse is useful for computing probabilities. 177 178 #### Example Uses 179 180 - Basic properties: 181 182 ```python 183 x = ... # A tensor. 184 # Evaluate forward transformation. 185 fwd_x = my_bijector.forward(x) 186 x == my_bijector.inverse(fwd_x) 187 x != my_bijector.forward(fwd_x) # Not equal because x != g(g(x)). 188 ``` 189 190 - Computing a log-likelihood: 191 192 ```python 193 def transformed_log_prob(bijector, log_prob, x): 194 return (bijector.inverse_log_det_jacobian(x, event_ndims=0) + 195 log_prob(bijector.inverse(x))) 196 ``` 197 198 - Transforming a random outcome: 199 200 ```python 201 def transformed_sample(bijector, x): 202 return bijector.forward(x) 203 ``` 204 205 #### Example Bijectors 206 207 - "Exponential" 208 209 ```none 210 Y = g(X) = exp(X) 211 X ~ Normal(0, 1) # Univariate. 212 ``` 213 214 Implies: 215 216 ```none 217 g^{-1}(Y) = log(Y) 218 |Jacobian(g^{-1})(y)| = 1 / y 219 Y ~ LogNormal(0, 1), i.e., 220 prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y)) 221 = (1 / y) Normal(log(y); 0, 1) 222 ``` 223 224 Here is an example of how one might implement the `Exp` bijector: 225 226 ```python 227 class Exp(Bijector): 228 229 def __init__(self, validate_args=False, name="exp"): 230 super(Exp, self).__init__( 231 validate_args=validate_args, 232 forward_min_event_ndims=0, 233 name=name) 234 235 def _forward(self, x): 236 return math_ops.exp(x) 237 238 def _inverse(self, y): 239 return math_ops.log(y) 240 241 def _inverse_log_det_jacobian(self, y): 242 return -self._forward_log_det_jacobian(self._inverse(y)) 243 244 def _forward_log_det_jacobian(self, x): 245 # Notice that we needn't do any reducing, even when`event_ndims > 0`. 246 # The base Bijector class will handle reducing for us; it knows how 247 # to do so because we called `super` `__init__` with 248 # `forward_min_event_ndims = 0`. 249 return x 250 ``` 251 252 - "Affine" 253 254 ```none 255 Y = g(X) = sqrtSigma * X + mu 256 X ~ MultivariateNormal(0, I_d) 257 ``` 258 259 Implies: 260 261 ```none 262 g^{-1}(Y) = inv(sqrtSigma) * (Y - mu) 263 |Jacobian(g^{-1})(y)| = det(inv(sqrtSigma)) 264 Y ~ MultivariateNormal(mu, sqrtSigma) , i.e., 265 prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y)) 266 = det(sqrtSigma)^(-d) * 267 MultivariateNormal(inv(sqrtSigma) * (y - mu); 0, I_d) 268 ``` 269 270 #### Min_event_ndims and Naming 271 272 Bijectors are named for the dimensionality of data they act on (i.e. without 273 broadcasting). We can think of bijectors having an intrinsic `min_event_ndims` 274 , which is the minimum number of dimensions for the bijector act on. For 275 instance, a Cholesky decomposition requires a matrix, and hence 276 `min_event_ndims=2`. 277 278 Some examples: 279 280 `AffineScalar: min_event_ndims=0` 281 `Affine: min_event_ndims=1` 282 `Cholesky: min_event_ndims=2` 283 `Exp: min_event_ndims=0` 284 `Sigmoid: min_event_ndims=0` 285 `SoftmaxCentered: min_event_ndims=1` 286 287 Note the difference between `Affine` and `AffineScalar`. `AffineScalar` 288 operates on scalar events, whereas `Affine` operates on vector-valued events. 289 290 More generally, there is a `forward_min_event_ndims` and an 291 `inverse_min_event_ndims`. In most cases, these will be the same. 292 However, for some shape changing bijectors, these will be different 293 (e.g. a bijector which pads an extra dimension at the end, might have 294 `forward_min_event_ndims=0` and `inverse_min_event_ndims=1`. 295 296 297 #### Jacobian Determinant 298 299 The Jacobian determinant is a reduction over `event_ndims - min_event_ndims` 300 (`forward_min_event_ndims` for `forward_log_det_jacobian` and 301 `inverse_min_event_ndims` for `inverse_log_det_jacobian`). 302 To see this, consider the `Exp` `Bijector` applied to a `Tensor` which has 303 sample, batch, and event (S, B, E) shape semantics. Suppose the `Tensor`'s 304 partitioned-shape is `(S=[4], B=[2], E=[3, 3])`. The shape of the `Tensor` 305 returned by `forward` and `inverse` is unchanged, i.e., `[4, 2, 3, 3]`. 306 However the shape returned by `inverse_log_det_jacobian` is `[4, 2]` because 307 the Jacobian determinant is a reduction over the event dimensions. 308 309 Another example is the `Affine` `Bijector`. Because `min_event_ndims = 1`, the 310 Jacobian determinant reduction is over `event_ndims - 1`. 311 312 It is sometimes useful to implement the inverse Jacobian determinant as the 313 negative forward Jacobian determinant. For example, 314 315 ```python 316 def _inverse_log_det_jacobian(self, y): 317 return -self._forward_log_det_jac(self._inverse(y)) # Note negation. 318 ``` 319 320 The correctness of this approach can be seen from the following claim. 321 322 - Claim: 323 324 Assume `Y = g(X)` is a bijection whose derivative exists and is nonzero 325 for its domain, i.e., `dY/dX = d/dX g(X) != 0`. Then: 326 327 ```none 328 (log o det o jacobian o g^{-1})(Y) = -(log o det o jacobian o g)(X) 329 ``` 330 331 - Proof: 332 333 From the bijective, nonzero differentiability of `g`, the 334 [inverse function theorem]( 335 https://en.wikipedia.org/wiki/Inverse_function_theorem) 336 implies `g^{-1}` is differentiable in the image of `g`. 337 Applying the chain rule to `y = g(x) = g(g^{-1}(y))` yields 338 `I = g'(g^{-1}(y))*g^{-1}'(y)`. 339 The same theorem also implies `g^{-1}'` is non-singular therefore: 340 `inv[ g'(g^{-1}(y)) ] = g^{-1}'(y)`. 341 The claim follows from [properties of determinant]( 342 https://en.wikipedia.org/wiki/Determinant#Multiplicativity_and_matrix_groups). 343 344 Generally its preferable to directly implement the inverse Jacobian 345 determinant. This should have superior numerical stability and will often 346 share subgraphs with the `_inverse` implementation. 347 348 #### Is_constant_jacobian 349 350 Certain bijectors will have constant jacobian matrices. For instance, the 351 `Affine` bijector encodes multiplication by a matrix plus a shift, with 352 jacobian matrix, the same aforementioned matrix. 353 354 `is_constant_jacobian` encodes the fact that the jacobian matrix is constant. 355 The semantics of this argument are the following: 356 357 * Repeated calls to "log_det_jacobian" functions with the same 358 `event_ndims` (but not necessarily same input), will return the first 359 computed jacobian (because the matrix is constant, and hence is input 360 independent). 361 * `log_det_jacobian` implementations are merely broadcastable to the true 362 `log_det_jacobian` (because, again, the jacobian matrix is input 363 independent). Specifically, `log_det_jacobian` is implemented as the 364 log jacobian determinant for a single input. 365 366 ```python 367 class Identity(Bijector): 368 369 def __init__(self, validate_args=False, name="identity"): 370 super(Identity, self).__init__( 371 is_constant_jacobian=True, 372 validate_args=validate_args, 373 forward_min_event_ndims=0, 374 name=name) 375 376 def _forward(self, x): 377 return x 378 379 def _inverse(self, y): 380 return y 381 382 def _inverse_log_det_jacobian(self, y): 383 return -self._forward_log_det_jacobian(self._inverse(y)) 384 385 def _forward_log_det_jacobian(self, x): 386 # The full log jacobian determinant would be array_ops.zero_like(x). 387 # However, we circumvent materializing that, since the jacobian 388 # calculation is input independent, and we specify it for one input. 389 return constant_op.constant(0., x.dtype.base_dtype) 390 391 ``` 392 393 #### Subclass Requirements 394 395 - Subclasses typically implement: 396 397 - `_forward`, 398 - `_inverse`, 399 - `_inverse_log_det_jacobian`, 400 - `_forward_log_det_jacobian` (optional). 401 402 The `_forward_log_det_jacobian` is called when the bijector is inverted via 403 the `Invert` bijector. If undefined, a slightly less efficiently 404 calculation, `-1 * _inverse_log_det_jacobian`, is used. 405 406 If the bijector changes the shape of the input, you must also implement: 407 408 - _forward_event_shape_tensor, 409 - _forward_event_shape (optional), 410 - _inverse_event_shape_tensor, 411 - _inverse_event_shape (optional). 412 413 By default the event-shape is assumed unchanged from input. 414 415 - If the `Bijector`'s use is limited to `TransformedDistribution` (or friends 416 like `QuantizedDistribution`) then depending on your use, you may not need 417 to implement all of `_forward` and `_inverse` functions. 418 419 Examples: 420 421 1. Sampling (e.g., `sample`) only requires `_forward`. 422 2. Probability functions (e.g., `prob`, `cdf`, `survival`) only require 423 `_inverse` (and related). 424 3. Only calling probability functions on the output of `sample` means 425 `_inverse` can be implemented as a cache lookup. 426 427 See "Example Uses" [above] which shows how these functions are used to 428 transform a distribution. (Note: `_forward` could theoretically be 429 implemented as a cache lookup but this would require controlling the 430 underlying sample generation mechanism.) 431 432 #### Non Injective Transforms 433 434 **WARNING** Handing of non-injective transforms is subject to change. 435 436 Non injective maps `g` are supported, provided their domain `D` can be 437 partitioned into `k` disjoint subsets, `Union{D1, ..., Dk}`, such that, 438 ignoring sets of measure zero, the restriction of `g` to each subset is a 439 differentiable bijection onto `g(D)`. In particular, this implies that for 440 `y in g(D)`, the set inverse, i.e. `g^{-1}(y) = {x in D : g(x) = y}`, always 441 contains exactly `k` distinct points. 442 443 The property, `_is_injective` is set to `False` to indicate that the bijector 444 is not injective, yet satisfies the above condition. 445 446 The usual bijector API is modified in the case `_is_injective is False` (see 447 method docstrings for specifics). Here we show by example the `AbsoluteValue` 448 bijector. In this case, the domain `D = (-inf, inf)`, can be partitioned 449 into `D1 = (-inf, 0)`, `D2 = {0}`, and `D3 = (0, inf)`. Let `gi` be the 450 restriction of `g` to `Di`, then both `g1` and `g3` are bijections onto 451 `(0, inf)`, with `g1^{-1}(y) = -y`, and `g3^{-1}(y) = y`. We will use 452 `g1` and `g3` to define bijector methods over `D1` and `D3`. `D2 = {0}` is 453 an oddball in that `g2` is one to one, and the derivative is not well defined. 454 Fortunately, when considering transformations of probability densities 455 (e.g. in `TransformedDistribution`), sets of measure zero have no effect in 456 theory, and only a small effect in 32 or 64 bit precision. For that reason, 457 we define `inverse(0)` and `inverse_log_det_jacobian(0)` both as `[0, 0]`, 458 which is convenient and results in a left-semicontinuous pdf. 459 460 461 ```python 462 abs = tfp.distributions.bijectors.AbsoluteValue() 463 464 abs.forward(-1.) 465 ==> 1. 466 467 abs.forward(1.) 468 ==> 1. 469 470 abs.inverse(1.) 471 ==> (-1., 1.) 472 473 # The |dX/dY| is constant, == 1. So Log|dX/dY| == 0. 474 abs.inverse_log_det_jacobian(1., event_ndims=0) 475 ==> (0., 0.) 476 477 # Special case handling of 0. 478 abs.inverse(0.) 479 ==> (0., 0.) 480 481 abs.inverse_log_det_jacobian(0., event_ndims=0) 482 ==> (0., 0.) 483 ``` 484 485 """ 486 487 @abc.abstractmethod 488 def __init__(self, 489 graph_parents=None, 490 is_constant_jacobian=False, 491 validate_args=False, 492 dtype=None, 493 forward_min_event_ndims=None, 494 inverse_min_event_ndims=None, 495 name=None): 496 """Constructs Bijector. 497 498 A `Bijector` transforms random variables into new random variables. 499 500 Examples: 501 502 ```python 503 # Create the Y = g(X) = X transform. 504 identity = Identity() 505 506 # Create the Y = g(X) = exp(X) transform. 507 exp = Exp() 508 ``` 509 510 See `Bijector` subclass docstring for more details and specific examples. 511 512 Args: 513 graph_parents: Python list of graph prerequisites of this `Bijector`. 514 is_constant_jacobian: Python `bool` indicating that the Jacobian matrix is 515 not a function of the input. 516 validate_args: Python `bool`, default `False`. Whether to validate input 517 with asserts. If `validate_args` is `False`, and the inputs are invalid, 518 correct behavior is not guaranteed. 519 dtype: `tf.dtype` supported by this `Bijector`. `None` means dtype is not 520 enforced. 521 forward_min_event_ndims: Python `integer` indicating the minimum number of 522 dimensions `forward` operates on. 523 inverse_min_event_ndims: Python `integer` indicating the minimum number of 524 dimensions `inverse` operates on. Will be set to 525 `forward_min_event_ndims` by default, if no value is provided. 526 name: The name to give Ops created by the initializer. 527 528 Raises: 529 ValueError: If neither `forward_min_event_ndims` and 530 `inverse_min_event_ndims` are specified, or if either of them is 531 negative. 532 ValueError: If a member of `graph_parents` is not a `Tensor`. 533 """ 534 self._graph_parents = graph_parents or [] 535 536 if forward_min_event_ndims is None and inverse_min_event_ndims is None: 537 raise ValueError("Must specify at least one of `forward_min_event_ndims` " 538 "and `inverse_min_event_ndims`.") 539 elif inverse_min_event_ndims is None: 540 inverse_min_event_ndims = forward_min_event_ndims 541 elif forward_min_event_ndims is None: 542 forward_min_event_ndims = inverse_min_event_ndims 543 544 if not isinstance(forward_min_event_ndims, int): 545 raise TypeError("Expected forward_min_event_ndims to be of " 546 "type int, got {}".format( 547 type(forward_min_event_ndims).__name__)) 548 549 if not isinstance(inverse_min_event_ndims, int): 550 raise TypeError("Expected inverse_min_event_ndims to be of " 551 "type int, got {}".format( 552 type(inverse_min_event_ndims).__name__)) 553 554 if forward_min_event_ndims < 0: 555 raise ValueError("forward_min_event_ndims must be a non-negative " 556 "integer.") 557 if inverse_min_event_ndims < 0: 558 raise ValueError("inverse_min_event_ndims must be a non-negative " 559 "integer.") 560 561 self._forward_min_event_ndims = forward_min_event_ndims 562 self._inverse_min_event_ndims = inverse_min_event_ndims 563 self._is_constant_jacobian = is_constant_jacobian 564 self._constant_ildj_map = {} 565 self._validate_args = validate_args 566 self._dtype = dtype 567 # These dicts can only be accessed using _Mapping.x_key or _Mapping.y_key 568 self._from_y = {} 569 self._from_x = {} 570 if name: 571 self._name = name 572 else: 573 # We want the default convention to be snake_case rather than CamelCase 574 # since `Chain` uses bijector.name as the kwargs dictionary key. 575 def camel_to_snake(name): 576 s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) 577 return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() 578 self._name = camel_to_snake(type(self).__name__.lstrip("_")) 579 580 for i, t in enumerate(self._graph_parents): 581 if t is None or not tensor_util.is_tf_type(t): 582 raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t)) 583 584 @property 585 def graph_parents(self): 586 """Returns this `Bijector`'s graph_parents as a Python list.""" 587 return self._graph_parents 588 589 @property 590 def forward_min_event_ndims(self): 591 """Returns the minimal number of dimensions bijector.forward operates on.""" 592 return self._forward_min_event_ndims 593 594 @property 595 def inverse_min_event_ndims(self): 596 """Returns the minimal number of dimensions bijector.inverse operates on.""" 597 return self._inverse_min_event_ndims 598 599 @property 600 def is_constant_jacobian(self): 601 """Returns true iff the Jacobian matrix is not a function of x. 602 603 Note: Jacobian matrix is either constant for both forward and inverse or 604 neither. 605 606 Returns: 607 is_constant_jacobian: Python `bool`. 608 """ 609 return self._is_constant_jacobian 610 611 @property 612 def _is_injective(self): 613 """Returns true iff the forward map `g` is injective (one-to-one function). 614 615 **WARNING** This hidden property and its behavior are subject to change. 616 617 Note: Non-injective maps `g` are supported, provided their domain `D` can 618 be partitioned into `k` disjoint subsets, `Union{D1, ..., Dk}`, such that, 619 ignoring sets of measure zero, the restriction of `g` to each subset is a 620 differentiable bijection onto `g(D)`. 621 622 Returns: 623 is_injective: Python `bool`. 624 """ 625 return True 626 627 @property 628 def validate_args(self): 629 """Returns True if Tensor arguments will be validated.""" 630 return self._validate_args 631 632 @property 633 def dtype(self): 634 """dtype of `Tensor`s transformable by this distribution.""" 635 return self._dtype 636 637 @property 638 def name(self): 639 """Returns the string name of this `Bijector`.""" 640 return self._name 641 642 def _forward_event_shape_tensor(self, input_shape): 643 """Subclass implementation for `forward_event_shape_tensor` function.""" 644 # By default, we assume event_shape is unchanged. 645 return input_shape 646 647 def forward_event_shape_tensor(self, 648 input_shape, 649 name="forward_event_shape_tensor"): 650 """Shape of a single sample from a single batch as an `int32` 1D `Tensor`. 651 652 Args: 653 input_shape: `Tensor`, `int32` vector indicating event-portion shape 654 passed into `forward` function. 655 name: name to give to the op 656 657 Returns: 658 forward_event_shape_tensor: `Tensor`, `int32` vector indicating 659 event-portion shape after applying `forward`. 660 """ 661 with self._name_scope(name, [input_shape]): 662 input_shape = ops.convert_to_tensor(input_shape, dtype=dtypes.int32, 663 name="input_shape") 664 return self._forward_event_shape_tensor(input_shape) 665 666 def _forward_event_shape(self, input_shape): 667 """Subclass implementation for `forward_event_shape` public function.""" 668 # By default, we assume event_shape is unchanged. 669 return input_shape 670 671 def forward_event_shape(self, input_shape): 672 """Shape of a single sample from a single batch as a `TensorShape`. 673 674 Same meaning as `forward_event_shape_tensor`. May be only partially defined. 675 676 Args: 677 input_shape: `TensorShape` indicating event-portion shape passed into 678 `forward` function. 679 680 Returns: 681 forward_event_shape_tensor: `TensorShape` indicating event-portion shape 682 after applying `forward`. Possibly unknown. 683 """ 684 return self._forward_event_shape(tensor_shape.TensorShape(input_shape)) 685 686 def _inverse_event_shape_tensor(self, output_shape): 687 """Subclass implementation for `inverse_event_shape_tensor` function.""" 688 # By default, we assume event_shape is unchanged. 689 return output_shape 690 691 def inverse_event_shape_tensor(self, 692 output_shape, 693 name="inverse_event_shape_tensor"): 694 """Shape of a single sample from a single batch as an `int32` 1D `Tensor`. 695 696 Args: 697 output_shape: `Tensor`, `int32` vector indicating event-portion shape 698 passed into `inverse` function. 699 name: name to give to the op 700 701 Returns: 702 inverse_event_shape_tensor: `Tensor`, `int32` vector indicating 703 event-portion shape after applying `inverse`. 704 """ 705 with self._name_scope(name, [output_shape]): 706 output_shape = ops.convert_to_tensor(output_shape, dtype=dtypes.int32, 707 name="output_shape") 708 return self._inverse_event_shape_tensor(output_shape) 709 710 def _inverse_event_shape(self, output_shape): 711 """Subclass implementation for `inverse_event_shape` public function.""" 712 # By default, we assume event_shape is unchanged. 713 return tensor_shape.TensorShape(output_shape) 714 715 def inverse_event_shape(self, output_shape): 716 """Shape of a single sample from a single batch as a `TensorShape`. 717 718 Same meaning as `inverse_event_shape_tensor`. May be only partially defined. 719 720 Args: 721 output_shape: `TensorShape` indicating event-portion shape passed into 722 `inverse` function. 723 724 Returns: 725 inverse_event_shape_tensor: `TensorShape` indicating event-portion shape 726 after applying `inverse`. Possibly unknown. 727 """ 728 return self._inverse_event_shape(output_shape) 729 730 def _forward(self, x): 731 """Subclass implementation for `forward` public function.""" 732 raise NotImplementedError("forward not implemented.") 733 734 def _call_forward(self, x, name, **kwargs): 735 with self._name_scope(name, [x]): 736 x = ops.convert_to_tensor(x, name="x") 737 self._maybe_assert_dtype(x) 738 if not self._is_injective: # No caching for non-injective 739 return self._forward(x, **kwargs) 740 mapping = self._lookup(x=x, kwargs=kwargs) 741 if mapping.y is not None: 742 return mapping.y 743 mapping = mapping.merge(y=self._forward(x, **kwargs)) 744 self._cache(mapping) 745 return mapping.y 746 747 def forward(self, x, name="forward"): 748 """Returns the forward `Bijector` evaluation, i.e., X = g(Y). 749 750 Args: 751 x: `Tensor`. The input to the "forward" evaluation. 752 name: The name to give this op. 753 754 Returns: 755 `Tensor`. 756 757 Raises: 758 TypeError: if `self.dtype` is specified and `x.dtype` is not 759 `self.dtype`. 760 NotImplementedError: if `_forward` is not implemented. 761 """ 762 return self._call_forward(x, name) 763 764 def _inverse(self, y): 765 """Subclass implementation for `inverse` public function.""" 766 raise NotImplementedError("inverse not implemented") 767 768 def _call_inverse(self, y, name, **kwargs): 769 with self._name_scope(name, [y]): 770 y = ops.convert_to_tensor(y, name="y") 771 self._maybe_assert_dtype(y) 772 if not self._is_injective: # No caching for non-injective 773 return self._inverse(y, **kwargs) 774 mapping = self._lookup(y=y, kwargs=kwargs) 775 if mapping.x is not None: 776 return mapping.x 777 mapping = mapping.merge(x=self._inverse(y, **kwargs)) 778 self._cache(mapping) 779 return mapping.x 780 781 def inverse(self, y, name="inverse"): 782 """Returns the inverse `Bijector` evaluation, i.e., X = g^{-1}(Y). 783 784 Args: 785 y: `Tensor`. The input to the "inverse" evaluation. 786 name: The name to give this op. 787 788 Returns: 789 `Tensor`, if this bijector is injective. 790 If not injective, returns the k-tuple containing the unique 791 `k` points `(x1, ..., xk)` such that `g(xi) = y`. 792 793 Raises: 794 TypeError: if `self.dtype` is specified and `y.dtype` is not 795 `self.dtype`. 796 NotImplementedError: if `_inverse` is not implemented. 797 """ 798 return self._call_inverse(y, name) 799 800 def _inverse_log_det_jacobian(self, y): 801 """Subclass implementation of `inverse_log_det_jacobian` public function. 802 803 In particular, this method differs from the public function, in that it 804 does not take `event_ndims`. Thus, this implements the minimal Jacobian 805 determinant calculation (i.e. over `inverse_min_event_ndims`). 806 807 Args: 808 y: `Tensor`. The input to the "inverse_log_det_jacobian" evaluation. 809 Returns: 810 inverse_log_det_jacobian: `Tensor`, if this bijector is injective. 811 If not injective, returns the k-tuple containing jacobians for the 812 unique `k` points `(x1, ..., xk)` such that `g(xi) = y`. 813 """ 814 raise NotImplementedError("inverse_log_det_jacobian not implemented.") 815 816 def _call_inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs): 817 with self._name_scope(name, [y]): 818 if event_ndims in self._constant_ildj_map: 819 return self._constant_ildj_map[event_ndims] 820 y = ops.convert_to_tensor(y, name="y") 821 self._maybe_assert_dtype(y) 822 with ops.control_dependencies(self._check_valid_event_ndims( 823 min_event_ndims=self.inverse_min_event_ndims, 824 event_ndims=event_ndims)): 825 if not self._is_injective: # No caching for non-injective 826 try: 827 ildjs = self._inverse_log_det_jacobian(y, **kwargs) 828 return tuple(self._reduce_jacobian_det_over_event( 829 y, ildj, self.inverse_min_event_ndims, event_ndims) 830 for ildj in ildjs) 831 except NotImplementedError as original_exception: 832 try: 833 x = self._inverse(y, **kwargs) 834 fldjs = self._forward_log_det_jacobian(x, **kwargs) 835 return tuple(self._reduce_jacobian_det_over_event( 836 x, -fldj, self.forward_min_event_ndims, event_ndims) 837 for fldj in fldjs) 838 except NotImplementedError: 839 raise original_exception 840 841 mapping = self._lookup(y=y, kwargs=kwargs) 842 if mapping.ildj_map is not None and event_ndims in mapping.ildj_map: 843 return mapping.ildj_map[event_ndims] 844 try: 845 x = None # Not needed; leave cache as is. 846 ildj = self._inverse_log_det_jacobian(y, **kwargs) 847 ildj = self._reduce_jacobian_det_over_event( 848 y, ildj, self.inverse_min_event_ndims, event_ndims) 849 except NotImplementedError as original_exception: 850 try: 851 x = (mapping.x if mapping.x is not None 852 else self._inverse(y, **kwargs)) 853 ildj = -self._forward_log_det_jacobian(x, **kwargs) 854 ildj = self._reduce_jacobian_det_over_event( 855 x, ildj, self.forward_min_event_ndims, event_ndims) 856 except NotImplementedError: 857 raise original_exception 858 859 mapping = mapping.merge(x=x, ildj_map={event_ndims: ildj}) 860 self._cache(mapping) 861 if self.is_constant_jacobian: 862 self._constant_ildj_map[event_ndims] = ildj 863 return ildj 864 865 def inverse_log_det_jacobian( 866 self, y, event_ndims, name="inverse_log_det_jacobian"): 867 """Returns the (log o det o Jacobian o inverse)(y). 868 869 Mathematically, returns: `log(det(dX/dY))(Y)`. (Recall that: `X=g^{-1}(Y)`.) 870 871 Note that `forward_log_det_jacobian` is the negative of this function, 872 evaluated at `g^{-1}(y)`. 873 874 Args: 875 y: `Tensor`. The input to the "inverse" Jacobian determinant evaluation. 876 event_ndims: Number of dimensions in the probabilistic events being 877 transformed. Must be greater than or equal to 878 `self.inverse_min_event_ndims`. The result is summed over the final 879 dimensions to produce a scalar Jacobian determinant for each event, 880 i.e. it has shape `y.shape.ndims - event_ndims` dimensions. 881 name: The name to give this op. 882 883 Returns: 884 `Tensor`, if this bijector is injective. 885 If not injective, returns the tuple of local log det 886 Jacobians, `log(det(Dg_i^{-1}(y)))`, where `g_i` is the restriction 887 of `g` to the `ith` partition `Di`. 888 889 Raises: 890 TypeError: if `self.dtype` is specified and `y.dtype` is not 891 `self.dtype`. 892 NotImplementedError: if `_inverse_log_det_jacobian` is not implemented. 893 """ 894 return self._call_inverse_log_det_jacobian(y, event_ndims, name) 895 896 def _forward_log_det_jacobian(self, x): 897 """Subclass implementation of `forward_log_det_jacobian` public function. 898 899 In particular, this method differs from the public function, in that it 900 does not take `event_ndims`. Thus, this implements the minimal Jacobian 901 determinant calculation (i.e. over `forward_min_event_ndims`). 902 903 Args: 904 x: `Tensor`. The input to the "forward_log_det_jacobian" evaluation. 905 906 Returns: 907 forward_log_det_jacobian: `Tensor`, if this bijector is injective. 908 If not injective, returns the k-tuple containing jacobians for the 909 unique `k` points `(x1, ..., xk)` such that `g(xi) = y`. 910 """ 911 912 raise NotImplementedError( 913 "forward_log_det_jacobian not implemented.") 914 915 def _call_forward_log_det_jacobian(self, x, event_ndims, name, **kwargs): 916 if not self._is_injective: 917 raise NotImplementedError( 918 "forward_log_det_jacobian cannot be implemented for non-injective " 919 "transforms.") 920 with self._name_scope(name, [x]): 921 with ops.control_dependencies(self._check_valid_event_ndims( 922 min_event_ndims=self.forward_min_event_ndims, 923 event_ndims=event_ndims)): 924 if event_ndims in self._constant_ildj_map: 925 # Need "-1. *" to avoid invalid-unary-operand-type linter warning. 926 return -1. * self._constant_ildj_map[event_ndims] 927 x = ops.convert_to_tensor(x, name="x") 928 self._maybe_assert_dtype(x) 929 if not self._is_injective: # No caching for non-injective 930 try: 931 fldjs = self._forward_log_det_jacobian(x, **kwargs) # No caching. 932 return tuple(self._reduce_jacobian_det_over_event( 933 x, fldj, self.forward_min_event_ndims, event_ndims) 934 for fldj in fldjs) 935 except NotImplementedError as original_exception: 936 try: 937 y = self._forward(x, **kwargs) 938 ildjs = self._inverse_log_det_jacobian(y, **kwargs) 939 return tuple(self._reduce_jacobian_det_over_event( 940 y, -ildj, self.inverse_min_event_ndims, event_ndims) 941 for ildj in ildjs) 942 except NotImplementedError: 943 raise original_exception 944 mapping = self._lookup(x=x, kwargs=kwargs) 945 if mapping.ildj_map is not None and event_ndims in mapping.ildj_map: 946 return -mapping.ildj_map[event_ndims] 947 try: 948 y = None # Not needed; leave cache as is. 949 ildj = -self._forward_log_det_jacobian(x, **kwargs) 950 ildj = self._reduce_jacobian_det_over_event( 951 x, ildj, self.forward_min_event_ndims, event_ndims) 952 except NotImplementedError as original_exception: 953 try: 954 y = (mapping.y if mapping.y is not None 955 else self._forward(x, **kwargs)) 956 ildj = self._inverse_log_det_jacobian(y, **kwargs) 957 ildj = self._reduce_jacobian_det_over_event( 958 y, ildj, self.inverse_min_event_ndims, event_ndims) 959 except NotImplementedError: 960 raise original_exception 961 mapping = mapping.merge(y=y, ildj_map={event_ndims: ildj}) 962 self._cache(mapping) 963 if self.is_constant_jacobian: 964 self._constant_ildj_map[event_ndims] = ildj 965 return -ildj 966 967 def forward_log_det_jacobian( 968 self, x, event_ndims, name="forward_log_det_jacobian"): 969 """Returns both the forward_log_det_jacobian. 970 971 Args: 972 x: `Tensor`. The input to the "forward" Jacobian determinant evaluation. 973 event_ndims: Number of dimensions in the probabilistic events being 974 transformed. Must be greater than or equal to 975 `self.forward_min_event_ndims`. The result is summed over the final 976 dimensions to produce a scalar Jacobian determinant for each event, 977 i.e. it has shape `x.shape.ndims - event_ndims` dimensions. 978 name: The name to give this op. 979 980 Returns: 981 `Tensor`, if this bijector is injective. 982 If not injective this is not implemented. 983 984 Raises: 985 TypeError: if `self.dtype` is specified and `y.dtype` is not 986 `self.dtype`. 987 NotImplementedError: if neither `_forward_log_det_jacobian` 988 nor {`_inverse`, `_inverse_log_det_jacobian`} are implemented, or 989 this is a non-injective bijector. 990 """ 991 return self._call_forward_log_det_jacobian(x, event_ndims, name) 992 993 @contextlib.contextmanager 994 def _name_scope(self, name=None, values=None): 995 """Helper function to standardize op scope.""" 996 with ops.name_scope(self.name): 997 with ops.name_scope( 998 name, values=(values or []) + self.graph_parents) as scope: 999 yield scope 1000 1001 def _maybe_assert_dtype(self, x): 1002 """Helper to check dtype when self.dtype is known.""" 1003 if self.dtype is not None and self.dtype.base_dtype != x.dtype.base_dtype: 1004 raise TypeError("Input had dtype %s but expected %s." % 1005 (self.dtype, x.dtype)) 1006 1007 def _cache(self, mapping): 1008 """Helper which stores mapping info in forward/inverse dicts.""" 1009 # Merging from lookup is an added check that we're not overwriting anything 1010 # which is not None. 1011 mapping = mapping.merge(mapping=self._lookup( 1012 mapping.x, mapping.y, mapping.kwargs)) 1013 if mapping.x is None and mapping.y is None: 1014 raise ValueError("Caching expects at least one of (x,y) to be known, " 1015 "i.e., not None.") 1016 self._from_x[mapping.x_key] = mapping 1017 self._from_y[mapping.y_key] = mapping 1018 1019 def _lookup(self, x=None, y=None, kwargs=None): 1020 """Helper which retrieves mapping info from forward/inverse dicts.""" 1021 mapping = _Mapping(x=x, y=y, kwargs=kwargs) 1022 # Since _cache requires both x,y to be set, we only need to do one cache 1023 # lookup since the mapping is always in both or neither. 1024 if mapping.x is not None: 1025 return self._from_x.get(mapping.x_key, mapping) 1026 if mapping.y is not None: 1027 return self._from_y.get(mapping.y_key, mapping) 1028 return mapping 1029 1030 def _reduce_jacobian_det_over_event( 1031 self, y, ildj, min_event_ndims, event_ndims): 1032 """Reduce jacobian over event_ndims - min_event_ndims.""" 1033 # In this case, we need to tile the Jacobian over the event and reduce. 1034 y_rank = array_ops.rank(y) 1035 y_shape = array_ops.shape(y)[ 1036 y_rank - event_ndims : y_rank - min_event_ndims] 1037 1038 ones = array_ops.ones(y_shape, ildj.dtype) 1039 reduced_ildj = math_ops.reduce_sum( 1040 ones * ildj, 1041 axis=self._get_event_reduce_dims(min_event_ndims, event_ndims)) 1042 # The multiplication by ones can change the inferred static shape so we try 1043 # to recover as much as possible. 1044 event_ndims_ = self._maybe_get_static_event_ndims(event_ndims) 1045 if (event_ndims_ is not None and 1046 y.shape.ndims is not None and 1047 ildj.shape.ndims is not None): 1048 y_shape = y.shape[y.shape.ndims - event_ndims_ : 1049 y.shape.ndims - min_event_ndims] 1050 broadcast_shape = array_ops.broadcast_static_shape(ildj.shape, y_shape) 1051 reduced_ildj.set_shape( 1052 broadcast_shape[: broadcast_shape.ndims - ( 1053 event_ndims_ - min_event_ndims)]) 1054 1055 return reduced_ildj 1056 1057 def _get_event_reduce_dims(self, min_event_ndims, event_ndims): 1058 """Compute the reduction dimensions given event_ndims.""" 1059 event_ndims_ = self._maybe_get_static_event_ndims(event_ndims) 1060 1061 if event_ndims_ is not None: 1062 return [-index for index in range(1, event_ndims_ - min_event_ndims + 1)] 1063 else: 1064 reduce_ndims = event_ndims - min_event_ndims 1065 return math_ops.range(-reduce_ndims, 0) 1066 1067 def _check_valid_event_ndims(self, min_event_ndims, event_ndims): 1068 """Check whether event_ndims is at least min_event_ndims.""" 1069 event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") 1070 event_ndims_ = tensor_util.constant_value(event_ndims) 1071 assertions = [] 1072 1073 if not event_ndims.dtype.is_integer: 1074 raise ValueError("Expected integer dtype, got dtype {}".format( 1075 event_ndims.dtype)) 1076 1077 if event_ndims_ is not None: 1078 if event_ndims.shape.ndims != 0: 1079 raise ValueError("Expected scalar event_ndims, got shape {}".format( 1080 event_ndims.shape)) 1081 if min_event_ndims > event_ndims_: 1082 raise ValueError("event_ndims ({}) must be larger than " 1083 "min_event_ndims ({})".format( 1084 event_ndims_, min_event_ndims)) 1085 elif self.validate_args: 1086 assertions += [ 1087 check_ops.assert_greater_equal(event_ndims, min_event_ndims)] 1088 1089 if event_ndims.shape.is_fully_defined(): 1090 if event_ndims.shape.ndims != 0: 1091 raise ValueError("Expected scalar shape, got ndims {}".format( 1092 event_ndims.shape.ndims)) 1093 1094 elif self.validate_args: 1095 assertions += [ 1096 check_ops.assert_rank(event_ndims, 0, message="Expected scalar.")] 1097 return assertions 1098 1099 def _maybe_get_static_event_ndims(self, event_ndims): 1100 """Helper which returns tries to return an integer static value.""" 1101 event_ndims_ = distribution_util.maybe_get_static_value(event_ndims) 1102 1103 if isinstance(event_ndims_, (np.generic, np.ndarray)): 1104 if event_ndims_.dtype not in (np.int32, np.int64): 1105 raise ValueError("Expected integer dtype, got dtype {}".format( 1106 event_ndims_.dtype)) 1107 1108 if isinstance(event_ndims_, np.ndarray) and len(event_ndims_.shape): 1109 raise ValueError("Expected a scalar integer, got {}".format( 1110 event_ndims_)) 1111 event_ndims_ = int(event_ndims_) 1112 1113 return event_ndims_ 1114