# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bijector base."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import contextlib
import re

import numpy as np
import six

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import util as distribution_util


__all__ = [
    "Bijector",
]


class _Mapping(collections.namedtuple(
    "_Mapping", ["x", "y", "ildj_map", "kwargs"])):
  """Helper class to make it easier to manage caching in `Bijector`."""

  def __new__(cls, x=None, y=None, ildj_map=None, kwargs=None):
    """Custom __new__ so namedtuple items have defaults.

    Args:
      x: `Tensor`. Forward.
      y: `Tensor`. Inverse.
      ildj_map: `Dictionary`. This is a mapping from event_ndims to a `Tensor`
        representing the inverse log det jacobian.
      kwargs: Python dictionary. Extra args supplied to
        forward/inverse/etc functions.

    Returns:
      mapping: New instance of _Mapping.
    """
    return super(_Mapping, cls).__new__(cls, x, y, ildj_map, kwargs)

  @property
  def x_key(self):
    """Returns key used for caching Y=g(X)."""
    return (self.x,) + self._deep_tuple(tuple(sorted(self.kwargs.items())))

  @property
  def y_key(self):
    """Returns key used for caching X=g^{-1}(Y)."""
    return (self.y,) + self._deep_tuple(tuple(sorted(self.kwargs.items())))

  def merge(self, x=None, y=None, ildj_map=None, kwargs=None, mapping=None):
    """Returns new _Mapping with args merged with self.

    Args:
      x: `Tensor`. Forward.
      y: `Tensor`. Inverse.
      ildj_map: `Dictionary`. This is a mapping from event_ndims to a `Tensor`
        representing the inverse log det jacobian.
      kwargs: Python dictionary. Extra args supplied to
        forward/inverse/etc functions.
      mapping: Instance of _Mapping to merge. Can only be specified if no other
        arg is specified.

    Returns:
      mapping: New instance of `_Mapping` which has inputs merged with self.

    Raises:
      ValueError: if mapping and any other arg is not `None`.
    """
    if mapping is None:
      mapping = _Mapping(x=x, y=y, ildj_map=ildj_map, kwargs=kwargs)
    elif any(arg is not None for arg in [x, y, ildj_map, kwargs]):
      raise ValueError("Cannot simultaneously specify mapping and individual "
                       "arguments.")

    return _Mapping(
        x=self._merge(self.x, mapping.x),
        y=self._merge(self.y, mapping.y),
        ildj_map=self._merge_dicts(self.ildj_map, mapping.ildj_map),
        kwargs=self._merge(self.kwargs, mapping.kwargs))

  def _merge_dicts(self, old=None, new=None):
    """Helper to merge two dictionaries."""
    old = dict() if old is None else old
    new = dict() if new is None else new
    for k, v in six.iteritems(new):
      val = old.get(k, None)
      if val is not None and val != v:
        raise ValueError("Found different value for existing key "
                         "(key:{} old_value:{} new_value:{})".format(
                             k, old[k], v))
      old[k] = v
    return old

  def _merge(self, old, new):
    """Helper to merge which handles merging one value."""
    if old is None:
      return new
    elif new is not None and old != new:
      raise ValueError("Incompatible values: %s != %s" % (old, new))
    return old

  def _deep_tuple(self, x):
    """Converts lists of lists to tuples of tuples."""
    return (tuple(map(self._deep_tuple, x))
            if isinstance(x, (list, tuple)) else x)


@six.add_metaclass(abc.ABCMeta)
class Bijector(object):
  r"""Interface for transformations of a `Distribution` sample.

  Bijectors can be used to represent any differentiable and injective
  (one to one) function defined on an open subset of `R^n`. Some non-injective
  transformations are also supported (see "Non Injective Transforms" below).

  #### Mathematical Details

  A `Bijector` implements a [smooth covering map](
  https://en.wikipedia.org/wiki/Local_diffeomorphism), i.e., a local
  diffeomorphism such that every point in the target has a neighborhood evenly
  covered by a map ([see also](
  https://en.wikipedia.org/wiki/Covering_space#Covering_of_a_manifold)).
  A `Bijector` is used by `TransformedDistribution` but can be generally used
  for transforming a `Distribution` generated `Tensor`. A `Bijector` is
  characterized by three operations:

  1. Forward

     Useful for turning one random outcome into another random outcome from a
     different distribution.

  2. Inverse

     Useful for "reversing" a transformation to compute one probability in
     terms of another.

  3. `log_det_jacobian(x)`

     "The log of the absolute value of the determinant of the matrix of all
     first-order partial derivatives of the inverse function."

     Useful for inverting a transformation to compute one probability in terms
     of another. Geometrically, the Jacobian determinant measures how the
     transformation scales volume and is used to rescale the probability
     density.

     We take the absolute value of the determinant before log to avoid NaN
     values. Geometrically, a negative determinant corresponds to an
     orientation-reversing transformation. It is ok for us to discard the sign
     of the determinant because we only integrate everywhere-nonnegative
     functions (probability densities) and the correct orientation is always
     the one that produces a nonnegative integrand.

  By convention, transformations of random variables are named in terms of the
  forward transformation. The forward transformation creates samples; the
  inverse is useful for computing probabilities.
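
  The three operations combine in the change of variables formula; it is
  restated here because the examples below rely on it:

  ```none
  prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y))
  log(prob(Y=y)) = inverse_log_det_jacobian(y) + log(prob(X=g^{-1}(y)))
  ```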

  #### Example Uses

  - Basic properties:

  ```python
  x = ...  # A tensor.
  # Evaluate forward transformation.
  fwd_x = my_bijector.forward(x)
  x == my_bijector.inverse(fwd_x)
  x != my_bijector.forward(fwd_x)  # Not equal because x != g(g(x)).
  ```

  - Computing a log-likelihood:

  ```python
  def transformed_log_prob(bijector, log_prob, x):
    return (bijector.inverse_log_det_jacobian(x, event_ndims=0) +
            log_prob(bijector.inverse(x)))
  ```

  - Transforming a random outcome:

  ```python
  def transformed_sample(bijector, x):
    return bijector.forward(x)
  ```

  #### Example Bijectors

  - "Exponential"

    ```none
    Y = g(X) = exp(X)
    X ~ Normal(0, 1)  # Univariate.
    ```

    Implies:

    ```none
    g^{-1}(Y) = log(Y)
    |Jacobian(g^{-1})(y)| = 1 / y
    Y ~ LogNormal(0, 1), i.e.,
    prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y))
              = (1 / y) Normal(log(y); 0, 1)
    ```

    Here is an example of how one might implement the `Exp` bijector:

    ```python
    class Exp(Bijector):

      def __init__(self, validate_args=False, name="exp"):
        super(Exp, self).__init__(
            validate_args=validate_args,
            forward_min_event_ndims=0,
            name=name)

      def _forward(self, x):
        return math_ops.exp(x)

      def _inverse(self, y):
        return math_ops.log(y)

      def _inverse_log_det_jacobian(self, y):
        return -self._forward_log_det_jacobian(self._inverse(y))

      def _forward_log_det_jacobian(self, x):
        # Notice that we needn't do any reducing, even when `event_ndims > 0`.
        # The base Bijector class will handle reducing for us; it knows how
        # to do so because we called `super` `__init__` with
        # `forward_min_event_ndims=0`.
        return x
    ```

  - "Affine"

    ```none
    Y = g(X) = sqrtSigma * X + mu
    X ~ MultivariateNormal(0, I_d)
    ```

    Implies:

    ```none
    g^{-1}(Y) = inv(sqrtSigma) * (Y - mu)
    |Jacobian(g^{-1})(y)| = det(inv(sqrtSigma))
    Y ~ MultivariateNormal(mu, sqrtSigma), i.e.,
    prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y))
              = det(sqrtSigma)^(-d) *
                MultivariateNormal(inv(sqrtSigma) * (y - mu); 0, I_d)
    ```

  #### Min_event_ndims and Naming

  Bijectors are named for the dimensionality of data they act on (i.e. without
  broadcasting). We can think of bijectors as having an intrinsic
  `min_event_ndims`, which is the minimum number of dimensions the bijector
  acts on. For instance, a Cholesky decomposition requires a matrix, and hence
  `min_event_ndims=2`.

  Some examples:

  `AffineScalar: min_event_ndims=0`
  `Affine: min_event_ndims=1`
  `Cholesky: min_event_ndims=2`
  `Exp: min_event_ndims=0`
  `Sigmoid: min_event_ndims=0`
  `SoftmaxCentered: min_event_ndims=1`

  Note the difference between `Affine` and `AffineScalar`. `AffineScalar`
  operates on scalar events, whereas `Affine` operates on vector-valued events.

  More generally, there is a `forward_min_event_ndims` and an
  `inverse_min_event_ndims`. In most cases, these will be the same.
  However, for some shape-changing bijectors, these will be different
  (e.g. a bijector which pads an extra dimension at the end might have
  `forward_min_event_ndims=0` and `inverse_min_event_ndims=1`).
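
  For example, using the `Exp` bijector sketched above (which has
  `forward_min_event_ndims=0`), a caller may choose any `event_ndims >= 0`;
  the log det Jacobian is then reduced over the extra
  `event_ndims - forward_min_event_ndims` dimensions. The shapes in the
  comments below are illustrative:

  ```python
  exp = Exp()
  x = ...  # A tensor with shape [4, 2, 3, 3].
  # Treat the trailing two dimensions as the event; the result is reduced
  # over the event_ndims - forward_min_event_ndims = 2 rightmost dimensions.
  fldj = exp.forward_log_det_jacobian(x, event_ndims=2)
  # fldj has shape [4, 2].
  ```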

  #### Jacobian Determinant

  The Jacobian determinant is a reduction over `event_ndims - min_event_ndims`
  (`forward_min_event_ndims` for `forward_log_det_jacobian` and
  `inverse_min_event_ndims` for `inverse_log_det_jacobian`).
  To see this, consider the `Exp` `Bijector` applied to a `Tensor` which has
  sample, batch, and event (S, B, E) shape semantics. Suppose the `Tensor`'s
  partitioned-shape is `(S=[4], B=[2], E=[3, 3])`. The shape of the `Tensor`
  returned by `forward` and `inverse` is unchanged, i.e., `[4, 2, 3, 3]`.
  However the shape returned by `inverse_log_det_jacobian` is `[4, 2]` because
  the Jacobian determinant is a reduction over the event dimensions.

  Another example is the `Affine` `Bijector`. Because `min_event_ndims = 1`, the
  Jacobian determinant reduction is over `event_ndims - 1`.

  It is sometimes useful to implement the inverse Jacobian determinant as the
  negative forward Jacobian determinant. For example,

  ```python
  def _inverse_log_det_jacobian(self, y):
    return -self._forward_log_det_jacobian(self._inverse(y))  # Note negation.
  ```

  The correctness of this approach can be seen from the following claim.

  - Claim:

      Assume `Y = g(X)` is a bijection whose derivative exists and is nonzero
      for its domain, i.e., `dY/dX = d/dX g(X) != 0`. Then:

      ```none
      (log o det o jacobian o g^{-1})(Y) = -(log o det o jacobian o g)(X)
      ```

  - Proof:

      From the bijective, nonzero differentiability of `g`, the
      [inverse function theorem](
      https://en.wikipedia.org/wiki/Inverse_function_theorem)
      implies `g^{-1}` is differentiable in the image of `g`.
      Applying the chain rule to `y = g(x) = g(g^{-1}(y))` yields
      `I = g'(g^{-1}(y))*g^{-1}'(y)`.
      The same theorem also implies `g^{-1}'` is non-singular, therefore:
      `inv[ g'(g^{-1}(y)) ] = g^{-1}'(y)`.
      The claim follows from [properties of determinant](
      https://en.wikipedia.org/wiki/Determinant#Multiplicativity_and_matrix_groups).

  Generally it's preferable to directly implement the inverse Jacobian
  determinant. This should have superior numerical stability and will often
  share subgraphs with the `_inverse` implementation.

  #### Is_constant_jacobian

  Certain bijectors will have constant Jacobian matrices. For instance, the
  `Affine` bijector encodes multiplication by a matrix plus a shift, and its
  Jacobian matrix is that same matrix.

  `is_constant_jacobian` encodes the fact that the Jacobian matrix is constant.
  The semantics of this argument are the following:

  * Repeated calls to "log_det_jacobian" functions with the same
    `event_ndims` (but not necessarily the same input) will return the first
    computed Jacobian (because the matrix is constant, and hence is input
    independent).
  * `log_det_jacobian` implementations are merely broadcastable to the true
    `log_det_jacobian` (because, again, the Jacobian matrix is input
    independent). Specifically, `log_det_jacobian` is implemented as the
    log Jacobian determinant for a single input.
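
  For example, the `Identity` bijector below declares
  `is_constant_jacobian=True` and returns the scalar constant `0.`, which
  merely broadcasts against the true `array_ops.zeros_like(x)` log det
  Jacobian: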

  ```python
  class Identity(Bijector):

    def __init__(self, validate_args=False, name="identity"):
      super(Identity, self).__init__(
          is_constant_jacobian=True,
          validate_args=validate_args,
          forward_min_event_ndims=0,
          name=name)

    def _forward(self, x):
      return x

    def _inverse(self, y):
      return y

    def _inverse_log_det_jacobian(self, y):
      return -self._forward_log_det_jacobian(self._inverse(y))

    def _forward_log_det_jacobian(self, x):
      # The full log jacobian determinant would be array_ops.zeros_like(x).
      # However, we circumvent materializing that, since the jacobian
      # calculation is input independent, and we specify it for one input.
      return constant_op.constant(0., x.dtype.base_dtype)
  ```

  #### Subclass Requirements

  - Subclasses typically implement:

    - `_forward`,
    - `_inverse`,
    - `_inverse_log_det_jacobian`,
    - `_forward_log_det_jacobian` (optional).

    The `_forward_log_det_jacobian` is called when the bijector is inverted via
    the `Invert` bijector. If undefined, a slightly less efficient
    calculation, `-1 * _inverse_log_det_jacobian`, is used.

    If the bijector changes the shape of the input, you must also implement:

    - `_forward_event_shape_tensor`,
    - `_forward_event_shape` (optional),
    - `_inverse_event_shape_tensor`,
    - `_inverse_event_shape` (optional).

    By default the event-shape is assumed unchanged from input.

  - If the `Bijector`'s use is limited to `TransformedDistribution` (or friends
    like `QuantizedDistribution`) then depending on your use, you may not need
    to implement all of the `_forward` and `_inverse` functions.

    Examples:

    1. Sampling (e.g., `sample`) only requires `_forward`.
    2. Probability functions (e.g., `prob`, `cdf`, `survival`) only require
       `_inverse` (and related).
    3. Only calling probability functions on the output of `sample` means
       `_inverse` can be implemented as a cache lookup.

    See "Example Uses" [above] which shows how these functions are used to
    transform a distribution. (Note: `_forward` could theoretically be
    implemented as a cache lookup but this would require controlling the
    underlying sample generation mechanism.)

  #### Non Injective Transforms

  **WARNING** Handling of non-injective transforms is subject to change.

  Non-injective maps `g` are supported, provided their domain `D` can be
  partitioned into `k` disjoint subsets, `Union{D1, ..., Dk}`, such that,
  ignoring sets of measure zero, the restriction of `g` to each subset is a
  differentiable bijection onto `g(D)`. In particular, this implies that for
  `y in g(D)`, the set inverse, i.e. `g^{-1}(y) = {x in D : g(x) = y}`, always
  contains exactly `k` distinct points.

  The property `_is_injective` is set to `False` to indicate that the bijector
  is not injective, yet satisfies the above condition.

  The usual bijector API is modified in the case `_is_injective is False` (see
  method docstrings for specifics). Here we show by example the `AbsoluteValue`
  bijector. In this case, the domain `D = (-inf, inf)` can be partitioned
  into `D1 = (-inf, 0)`, `D2 = {0}`, and `D3 = (0, inf)`. Let `gi` be the
  restriction of `g` to `Di`, then both `g1` and `g3` are bijections onto
  `(0, inf)`, with `g1^{-1}(y) = -y`, and `g3^{-1}(y) = y`. We will use
  `g1` and `g3` to define bijector methods over `D1` and `D3`.
  `D2 = {0}` is an oddball in that `g2` is one to one, and the derivative is
  not well defined. Fortunately, when considering transformations of
  probability densities (e.g. in `TransformedDistribution`), sets of measure
  zero have no effect in theory, and only a small effect in 32 or 64 bit
  precision. For that reason, we define `inverse(0)` and
  `inverse_log_det_jacobian(0)` both as `[0, 0]`, which is convenient and
  results in a left-semicontinuous pdf.

  ```python
  abs = tfp.distributions.bijectors.AbsoluteValue()

  abs.forward(-1.)
  ==> 1.

  abs.forward(1.)
  ==> 1.

  abs.inverse(1.)
  ==> (-1., 1.)

  # The |dX/dY| is constant, == 1. So Log|dX/dY| == 0.
  abs.inverse_log_det_jacobian(1., event_ndims=0)
  ==> (0., 0.)

  # Special case handling of 0.
  abs.inverse(0.)
  ==> (0., 0.)

  abs.inverse_log_det_jacobian(0., event_ndims=0)
  ==> (0., 0.)
  ```

  """

  @abc.abstractmethod
  def __init__(self,
               graph_parents=None,
               is_constant_jacobian=False,
               validate_args=False,
               dtype=None,
               forward_min_event_ndims=None,
               inverse_min_event_ndims=None,
               name=None):
    """Constructs Bijector.

    A `Bijector` transforms random variables into new random variables.

    Examples:

    ```python
    # Create the Y = g(X) = X transform.
    identity = Identity()

    # Create the Y = g(X) = exp(X) transform.
    exp = Exp()
    ```

    See `Bijector` subclass docstring for more details and specific examples.

    Args:
      graph_parents: Python list of graph prerequisites of this `Bijector`.
      is_constant_jacobian: Python `bool` indicating that the Jacobian matrix
        is not a function of the input.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      dtype: `tf.dtype` supported by this `Bijector`. `None` means dtype is not
        enforced.
      forward_min_event_ndims: Python `integer` indicating the minimum number
        of dimensions `forward` operates on.
      inverse_min_event_ndims: Python `integer` indicating the minimum number
        of dimensions `inverse` operates on. Will be set to
        `forward_min_event_ndims` by default, if no value is provided.
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: If neither `forward_min_event_ndims` nor
        `inverse_min_event_ndims` is specified, or if either of them is
        negative.
      ValueError: If a member of `graph_parents` is not a `Tensor`.
    """
    self._graph_parents = graph_parents or []

    if forward_min_event_ndims is None and inverse_min_event_ndims is None:
      raise ValueError("Must specify at least one of "
                       "`forward_min_event_ndims` and "
                       "`inverse_min_event_ndims`.")
    elif inverse_min_event_ndims is None:
      inverse_min_event_ndims = forward_min_event_ndims
    elif forward_min_event_ndims is None:
      forward_min_event_ndims = inverse_min_event_ndims

    if not isinstance(forward_min_event_ndims, int):
      raise TypeError("Expected forward_min_event_ndims to be of "
                      "type int, got {}".format(
                          type(forward_min_event_ndims).__name__))

    if not isinstance(inverse_min_event_ndims, int):
      raise TypeError("Expected inverse_min_event_ndims to be of "
                      "type int, got {}".format(
                          type(inverse_min_event_ndims).__name__))

    if forward_min_event_ndims < 0:
      raise ValueError("forward_min_event_ndims must be a non-negative "
                       "integer.")
    if inverse_min_event_ndims < 0:
      raise ValueError("inverse_min_event_ndims must be a non-negative "
                       "integer.")

    self._forward_min_event_ndims = forward_min_event_ndims
    self._inverse_min_event_ndims = inverse_min_event_ndims
    self._is_constant_jacobian = is_constant_jacobian
    self._constant_ildj_map = {}
    self._validate_args = validate_args
    self._dtype = dtype
    self._from_y = {}
    self._from_x = {}
    if name:
      self._name = name
    else:
      # We want the default convention to be snake_case rather than CamelCase
      # since `Chain` uses bijector.name as the kwargs dictionary key.
      def camel_to_snake(name):
        s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
        return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
      self._name = camel_to_snake(type(self).__name__.lstrip("_"))

    for i, t in enumerate(self._graph_parents):
      if t is None or not tensor_util.is_tensor(t):
        raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))

  @property
  def graph_parents(self):
    """Returns this `Bijector`'s graph_parents as a Python list."""
    return self._graph_parents

  @property
  def forward_min_event_ndims(self):
    """Returns the minimal number of dimensions bijector.forward operates on."""
    return self._forward_min_event_ndims

  @property
  def inverse_min_event_ndims(self):
    """Returns the minimal number of dimensions bijector.inverse operates on."""
    return self._inverse_min_event_ndims

  @property
  def is_constant_jacobian(self):
    """Returns true iff the Jacobian matrix is not a function of x.

    Note: Jacobian matrix is either constant for both forward and inverse or
    neither.

    Returns:
      is_constant_jacobian: Python `bool`.
    """
    return self._is_constant_jacobian

  @property
  def _is_injective(self):
    """Returns true iff the forward map `g` is injective (one-to-one function).

    **WARNING** This hidden property and its behavior are subject to change.

    Note: Non-injective maps `g` are supported, provided their domain `D` can
    be partitioned into `k` disjoint subsets, `Union{D1, ..., Dk}`, such that,
    ignoring sets of measure zero, the restriction of `g` to each subset is a
    differentiable bijection onto `g(D)`.

    Returns:
      is_injective: Python `bool`.
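
    As a sketch, a non-injective subclass (for example, one following the
    `AbsoluteValue` pattern in the class docstring) would typically just
    override this property to return `False`:

    ```python
    @property
    def _is_injective(self):
      return False
    ```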
    """
    return True

  @property
  def validate_args(self):
    """Returns True if Tensor arguments will be validated."""
    return self._validate_args

  @property
  def dtype(self):
    """dtype of `Tensor`s transformable by this distribution."""
    return self._dtype

  @property
  def name(self):
    """Returns the string name of this `Bijector`."""
    return self._name

  def _forward_event_shape_tensor(self, input_shape):
    """Subclass implementation for `forward_event_shape_tensor` function."""
    # By default, we assume event_shape is unchanged.
    return input_shape

  def forward_event_shape_tensor(self,
                                 input_shape,
                                 name="forward_event_shape_tensor"):
    """Shape of a single sample from a single batch as an `int32` 1D `Tensor`.

    Args:
      input_shape: `Tensor`, `int32` vector indicating event-portion shape
        passed into `forward` function.
      name: Name to give to the op.

    Returns:
      forward_event_shape_tensor: `Tensor`, `int32` vector indicating
        event-portion shape after applying `forward`.
    """
    with self._name_scope(name, [input_shape]):
      input_shape = ops.convert_to_tensor(input_shape, dtype=dtypes.int32,
                                          name="input_shape")
      return self._forward_event_shape_tensor(input_shape)

  def _forward_event_shape(self, input_shape):
    """Subclass implementation for `forward_event_shape` public function."""
    # By default, we assume event_shape is unchanged.
    return input_shape

  def forward_event_shape(self, input_shape):
    """Shape of a single sample from a single batch as a `TensorShape`.

    Same meaning as `forward_event_shape_tensor`. May be only partially defined.

    Args:
      input_shape: `TensorShape` indicating event-portion shape passed into
        `forward` function.

    Returns:
      forward_event_shape_tensor: `TensorShape` indicating event-portion shape
        after applying `forward`. Possibly unknown.
    """
    return self._forward_event_shape(tensor_shape.TensorShape(input_shape))

  def _inverse_event_shape_tensor(self, output_shape):
    """Subclass implementation for `inverse_event_shape_tensor` function."""
    # By default, we assume event_shape is unchanged.
    return output_shape

  def inverse_event_shape_tensor(self,
                                 output_shape,
                                 name="inverse_event_shape_tensor"):
    """Shape of a single sample from a single batch as an `int32` 1D `Tensor`.

    Args:
      output_shape: `Tensor`, `int32` vector indicating event-portion shape
        passed into `inverse` function.
      name: Name to give to the op.

    Returns:
      inverse_event_shape_tensor: `Tensor`, `int32` vector indicating
        event-portion shape after applying `inverse`.
    """
    with self._name_scope(name, [output_shape]):
      output_shape = ops.convert_to_tensor(output_shape, dtype=dtypes.int32,
                                           name="output_shape")
      return self._inverse_event_shape_tensor(output_shape)

  def _inverse_event_shape(self, output_shape):
    """Subclass implementation for `inverse_event_shape` public function."""
    # By default, we assume event_shape is unchanged.
    return tensor_shape.TensorShape(output_shape)

  def inverse_event_shape(self, output_shape):
    """Shape of a single sample from a single batch as a `TensorShape`.

    Same meaning as `inverse_event_shape_tensor`. May be only partially defined.

    Args:
      output_shape: `TensorShape` indicating event-portion shape passed into
        `inverse` function.

    Returns:
      inverse_event_shape_tensor: `TensorShape` indicating event-portion shape
        after applying `inverse`. Possibly unknown.
    """
    return self._inverse_event_shape(output_shape)

  def _forward(self, x):
    """Subclass implementation for `forward` public function."""
    raise NotImplementedError("forward not implemented.")

  def _call_forward(self, x, name, **kwargs):
    with self._name_scope(name, [x]):
      x = ops.convert_to_tensor(x, name="x")
      self._maybe_assert_dtype(x)
      if not self._is_injective:  # No caching for non-injective
        return self._forward(x, **kwargs)
      mapping = self._lookup(x=x, kwargs=kwargs)
      if mapping.y is not None:
        return mapping.y
      mapping = mapping.merge(y=self._forward(x, **kwargs))
      self._cache(mapping)
      return mapping.y

  def forward(self, x, name="forward"):
    """Returns the forward `Bijector` evaluation, i.e., Y = g(X).

    Args:
      x: `Tensor`. The input to the "forward" evaluation.
      name: The name to give this op.

    Returns:
      `Tensor`.

    Raises:
      TypeError: if `self.dtype` is specified and `x.dtype` is not
        `self.dtype`.
      NotImplementedError: if `_forward` is not implemented.
    """
    return self._call_forward(x, name)

  def _inverse(self, y):
    """Subclass implementation for `inverse` public function."""
    raise NotImplementedError("inverse not implemented.")

  def _call_inverse(self, y, name, **kwargs):
    with self._name_scope(name, [y]):
      y = ops.convert_to_tensor(y, name="y")
      self._maybe_assert_dtype(y)
      if not self._is_injective:  # No caching for non-injective
        return self._inverse(y, **kwargs)
      mapping = self._lookup(y=y, kwargs=kwargs)
      if mapping.x is not None:
        return mapping.x
      mapping = mapping.merge(x=self._inverse(y, **kwargs))
      self._cache(mapping)
      return mapping.x

  def inverse(self, y, name="inverse"):
    """Returns the inverse `Bijector` evaluation, i.e., X = g^{-1}(Y).

    Args:
      y: `Tensor`. The input to the "inverse" evaluation.
      name: The name to give this op.

    Returns:
      `Tensor`, if this bijector is injective.
        If not injective, returns the k-tuple containing the unique
        `k` points `(x1, ..., xk)` such that `g(xi) = y`.

    Raises:
      TypeError: if `self.dtype` is specified and `y.dtype` is not
        `self.dtype`.
      NotImplementedError: if `_inverse` is not implemented.
    """
    return self._call_inverse(y, name)

  def _inverse_log_det_jacobian(self, y):
    """Subclass implementation of `inverse_log_det_jacobian` public function.

    In particular, this method differs from the public function, in that it
    does not take `event_ndims`. Thus, this implements the minimal Jacobian
    determinant calculation (i.e. over `inverse_min_event_ndims`).

    Args:
      y: `Tensor`. The input to the "inverse_log_det_jacobian" evaluation.

    Returns:
      inverse_log_det_jacobian: `Tensor`, if this bijector is injective.
        If not injective, returns the k-tuple containing jacobians for the
        unique `k` points `(x1, ..., xk)` such that `g(xi) = y`.
    """
    raise NotImplementedError("inverse_log_det_jacobian not implemented.")

  def _call_inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs):
    with self._name_scope(name, [y]):
      if event_ndims in self._constant_ildj_map:
        return self._constant_ildj_map[event_ndims]
      y = ops.convert_to_tensor(y, name="y")
      self._maybe_assert_dtype(y)
      with ops.control_dependencies(self._check_valid_event_ndims(
          min_event_ndims=self.inverse_min_event_ndims,
          event_ndims=event_ndims)):
        if not self._is_injective:  # No caching for non-injective
          try:
            ildjs = self._inverse_log_det_jacobian(y, **kwargs)
            return tuple(self._reduce_jacobian_det_over_event(
                y, ildj, self.inverse_min_event_ndims, event_ndims)
                         for ildj in ildjs)
          except NotImplementedError as original_exception:
            try:
              x = self._inverse(y, **kwargs)
              fldjs = self._forward_log_det_jacobian(x, **kwargs)
              return tuple(self._reduce_jacobian_det_over_event(
                  x, -fldj, self.forward_min_event_ndims, event_ndims)
                           for fldj in fldjs)
            except NotImplementedError:
              raise original_exception

        mapping = self._lookup(y=y, kwargs=kwargs)
        if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
          return mapping.ildj_map[event_ndims]
        try:
          x = None  # Not needed; leave cache as is.
          ildj = self._inverse_log_det_jacobian(y, **kwargs)
          ildj = self._reduce_jacobian_det_over_event(
              y, ildj, self.inverse_min_event_ndims, event_ndims)
        except NotImplementedError as original_exception:
          try:
            x = (mapping.x if mapping.x is not None
                 else self._inverse(y, **kwargs))
            ildj = -self._forward_log_det_jacobian(x, **kwargs)
            ildj = self._reduce_jacobian_det_over_event(
                x, ildj, self.forward_min_event_ndims, event_ndims)
          except NotImplementedError:
            raise original_exception

        mapping = mapping.merge(x=x, ildj_map={event_ndims: ildj})
        self._cache(mapping)
        if self.is_constant_jacobian:
          self._constant_ildj_map[event_ndims] = ildj
        return ildj

  def inverse_log_det_jacobian(
      self, y, event_ndims, name="inverse_log_det_jacobian"):
    """Returns the (log o det o Jacobian o inverse)(y).

    Mathematically, returns: `log(det(dX/dY))(Y)`. (Recall that: `X=g^{-1}(Y)`.)

    Note that `forward_log_det_jacobian` is the negative of this function,
    evaluated at `g^{-1}(y)`.

    Args:
      y: `Tensor`. The input to the "inverse" Jacobian determinant evaluation.
      event_ndims: Number of dimensions in the probabilistic events being
        transformed. Must be greater than or equal to
        `self.inverse_min_event_ndims`. The result is summed over the final
        dimensions to produce a scalar Jacobian determinant for each event,
        i.e., it has `y.shape.ndims - event_ndims` dimensions.
      name: The name to give this op.

    Returns:
      `Tensor`, if this bijector is injective.
        If not injective, returns the tuple of local log det
        Jacobians, `log(det(Dg_i^{-1}(y)))`, where `g_i` is the restriction
        of `g` to the `ith` partition `Di`.

    Raises:
      TypeError: if `self.dtype` is specified and `y.dtype` is not
        `self.dtype`.
      NotImplementedError: if `_inverse_log_det_jacobian` is not implemented.
    """
    return self._call_inverse_log_det_jacobian(y, event_ndims, name)

  def _forward_log_det_jacobian(self, x):
    """Subclass implementation of `forward_log_det_jacobian` public function.

    In particular, this method differs from the public function, in that it
    does not take `event_ndims`. Thus, this implements the minimal Jacobian
    determinant calculation (i.e. over `forward_min_event_ndims`).

    Args:
      x: `Tensor`. The input to the "forward_log_det_jacobian" evaluation.

    Returns:
      forward_log_det_jacobian: `Tensor`, if this bijector is injective.
        If not injective, returns the k-tuple containing jacobians for the
        unique `k` points `(x1, ..., xk)` such that `g(xi) = y`.
    """

    raise NotImplementedError(
        "forward_log_det_jacobian not implemented.")

  def _call_forward_log_det_jacobian(self, x, event_ndims, name, **kwargs):
    if not self._is_injective:
      raise NotImplementedError(
          "forward_log_det_jacobian cannot be implemented for non-injective "
          "transforms.")
    with self._name_scope(name, [x]):
      with ops.control_dependencies(self._check_valid_event_ndims(
          min_event_ndims=self.forward_min_event_ndims,
          event_ndims=event_ndims)):
        if event_ndims in self._constant_ildj_map:
          # Need "-1. *" to avoid invalid-unary-operand-type linter warning.
          return -1. * self._constant_ildj_map[event_ndims]
        x = ops.convert_to_tensor(x, name="x")
        self._maybe_assert_dtype(x)
        if not self._is_injective:  # No caching for non-injective
          try:
            fldjs = self._forward_log_det_jacobian(x, **kwargs)  # No caching.
            return tuple(self._reduce_jacobian_det_over_event(
                x, fldj, self.forward_min_event_ndims, event_ndims)
                         for fldj in fldjs)
          except NotImplementedError as original_exception:
            try:
              y = self._forward(x, **kwargs)
              ildjs = self._inverse_log_det_jacobian(y, **kwargs)
              return tuple(self._reduce_jacobian_det_over_event(
                  y, -ildj, self.inverse_min_event_ndims, event_ndims)
                           for ildj in ildjs)
            except NotImplementedError:
              raise original_exception
        mapping = self._lookup(x=x, kwargs=kwargs)
        if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
          return -mapping.ildj_map[event_ndims]
        try:
          y = None  # Not needed; leave cache as is.
          ildj = -self._forward_log_det_jacobian(x, **kwargs)
          ildj = self._reduce_jacobian_det_over_event(
              x, ildj, self.forward_min_event_ndims, event_ndims)
        except NotImplementedError as original_exception:
          try:
            y = (mapping.y if mapping.y is not None
                 else self._forward(x, **kwargs))
            ildj = self._inverse_log_det_jacobian(y, **kwargs)
            ildj = self._reduce_jacobian_det_over_event(
                y, ildj, self.inverse_min_event_ndims, event_ndims)
          except NotImplementedError:
            raise original_exception
        mapping = mapping.merge(y=y, ildj_map={event_ndims: ildj})
        self._cache(mapping)
        if self.is_constant_jacobian:
          self._constant_ildj_map[event_ndims] = ildj
        return -ildj

  def forward_log_det_jacobian(
      self, x, event_ndims, name="forward_log_det_jacobian"):
    """Returns the forward_log_det_jacobian.

    Args:
      x: `Tensor`. The input to the "forward" Jacobian determinant evaluation.
      event_ndims: Number of dimensions in the probabilistic events being
        transformed. Must be greater than or equal to
        `self.forward_min_event_ndims`. The result is summed over the final
        dimensions to produce a scalar Jacobian determinant for each event,
        i.e., it has `x.shape.ndims - event_ndims` dimensions.
      name: The name to give this op.

    Returns:
      `Tensor`, if this bijector is injective.
        If not injective, this is not implemented.

    Raises:
      TypeError: if `self.dtype` is specified and `x.dtype` is not
        `self.dtype`.
      NotImplementedError: if neither `_forward_log_det_jacobian`
        nor {`_inverse`, `_inverse_log_det_jacobian`} are implemented, or
        this is a non-injective bijector.
    """
    return self._call_forward_log_det_jacobian(x, event_ndims, name)

  @contextlib.contextmanager
  def _name_scope(self, name=None, values=None):
    """Helper function to standardize op scope."""
    with ops.name_scope(self.name):
      with ops.name_scope(
          name, values=(values or []) + self.graph_parents) as scope:
        yield scope

  def _maybe_assert_dtype(self, x):
    """Helper to check dtype when self.dtype is known."""
    if self.dtype is not None and self.dtype.base_dtype != x.dtype.base_dtype:
      raise TypeError("Input had dtype %s but expected %s." %
                      (self.dtype, x.dtype))

  def _cache(self, mapping):
    """Helper which stores mapping info in forward/inverse dicts."""
    # Merging from lookup is an added check that we're not overwriting anything
    # which is not None.
    mapping = mapping.merge(mapping=self._lookup(
        mapping.x, mapping.y, mapping.kwargs))
    if mapping.x is None and mapping.y is None:
      raise ValueError("Caching expects at least one of (x,y) to be known, "
                       "i.e., not None.")
    self._from_x[mapping.x_key] = mapping
    self._from_y[mapping.y_key] = mapping

  def _lookup(self, x=None, y=None, kwargs=None):
    """Helper which retrieves mapping info from forward/inverse dicts."""
    mapping = _Mapping(x=x, y=y, kwargs=kwargs)
    # Since _cache requires both x,y to be set, we only need to do one cache
    # lookup since the mapping is always in both or neither.
    if mapping.x is not None:
      return self._from_x.get(mapping.x_key, mapping)
    if mapping.y is not None:
      return self._from_y.get(mapping.y_key, mapping)
    return mapping

  def _reduce_jacobian_det_over_event(
      self, y, ildj, min_event_ndims, event_ndims):
    """Reduce jacobian over event_ndims - min_event_ndims."""
    # Tile the Jacobian over the event shape and reduce.
    y_rank = array_ops.rank(y)
    y_shape = array_ops.shape(y)[
        y_rank - event_ndims : y_rank - min_event_ndims]

    ones = array_ops.ones(y_shape, ildj.dtype)
    reduced_ildj = math_ops.reduce_sum(
        ones * ildj,
        axis=self._get_event_reduce_dims(min_event_ndims, event_ndims))
    # The multiplication by ones can change the inferred static shape so we try
    # to recover as much as possible.
    event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
    if (event_ndims_ is not None and
        y.shape.ndims is not None and
        ildj.shape.ndims is not None):
      y_shape = y.shape[y.shape.ndims - event_ndims_ :
                        y.shape.ndims - min_event_ndims]
      broadcast_shape = array_ops.broadcast_static_shape(ildj.shape, y_shape)
      reduced_ildj.set_shape(
          broadcast_shape[: broadcast_shape.ndims - (
              event_ndims_ - min_event_ndims)])

    return reduced_ildj

  def _get_event_reduce_dims(self, min_event_ndims, event_ndims):
    """Compute the reduction dimensions given event_ndims."""
    event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)

    if event_ndims_ is not None:
      return [-index for index in range(1, event_ndims_ - min_event_ndims + 1)]
    else:
      reduce_ndims = event_ndims - min_event_ndims
      return math_ops.range(-reduce_ndims, 0)

  def _check_valid_event_ndims(self, min_event_ndims, event_ndims):
    """Check whether event_ndims is at least min_event_ndims."""
    event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
    event_ndims_ = tensor_util.constant_value(event_ndims)
    assertions = []

    if not event_ndims.dtype.is_integer:
      raise ValueError("Expected integer dtype, got dtype {}".format(
          event_ndims.dtype))

    if event_ndims_ is not None:
      if event_ndims.shape.ndims != 0:
        raise ValueError("Expected scalar event_ndims, got shape {}".format(
            event_ndims.shape))
      if min_event_ndims > event_ndims_:
        raise ValueError("event_ndims ({}) must be at least "
                         "min_event_ndims ({})".format(
                             event_ndims_, min_event_ndims))
    elif self.validate_args:
      assertions += [
          check_ops.assert_greater_equal(event_ndims, min_event_ndims)]

    if event_ndims.shape.is_fully_defined():
      if event_ndims.shape.ndims != 0:
        raise ValueError("Expected scalar shape, got ndims {}".format(
            event_ndims.shape.ndims))

    elif self.validate_args:
      assertions += [
          check_ops.assert_rank(event_ndims, 0, message="Expected scalar.")]
    return assertions

  def _maybe_get_static_event_ndims(self, event_ndims):
    """Helper which tries to return a static integer value for event_ndims."""
    event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)

    if isinstance(event_ndims_, (np.generic, np.ndarray)):
      if event_ndims_.dtype not in (np.int32, np.int64):
        raise ValueError("Expected integer dtype, got dtype {}".format(
            event_ndims_.dtype))

      if isinstance(event_ndims_, np.ndarray) and len(event_ndims_.shape):
        raise ValueError("Expected a scalar integer, got {}".format(
            event_ndims_))
      event_ndims_ = int(event_ndims_)

    return event_ndims_