# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bijector base."""

import abc
import collections
import contextlib
import re

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import object_identity


__all__ = [
    "Bijector",
]


class _Mapping(collections.namedtuple(
    "_Mapping", ["x", "y", "ildj_map", "kwargs"])):
  """Helper class to make it easier to manage caching in `Bijector`."""
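  # E.g. `Bijector._call_forward` caches `_Mapping(x=x, y=g(x), kwargs=kwargs)`
  # so that a later `inverse` call on the same `y` is a dictionary lookup
  # rather than a recomputation; see `Bijector._cache` and `Bijector._lookup`.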

  def __new__(cls, x=None, y=None, ildj_map=None, kwargs=None):
    """Custom __new__ so namedtuple items have defaults.

    Args:
      x: `Tensor`. Forward.
      y: `Tensor`. Inverse.
      ildj_map: `Dictionary`. This is a mapping from event_ndims to a `Tensor`
        representing the inverse log det jacobian.
      kwargs: Python dictionary. Extra args supplied to
        forward/inverse/etc functions.

    Returns:
      mapping: New instance of _Mapping.
    """
    return super(_Mapping, cls).__new__(cls, x, y, ildj_map, kwargs)

  @property
  def x_key(self):
    """Returns key used for caching Y=g(X)."""
    return ((object_identity.Reference(self.x),) +
            self._deep_tuple(tuple(sorted(self.kwargs.items()))))

  @property
  def y_key(self):
    """Returns key used for caching X=g^{-1}(Y)."""
    return ((object_identity.Reference(self.y),) +
            self._deep_tuple(tuple(sorted(self.kwargs.items()))))

  def merge(self, x=None, y=None, ildj_map=None, kwargs=None, mapping=None):
    """Returns new _Mapping with args merged with self.

    Args:
      x: `Tensor`. Forward.
      y: `Tensor`. Inverse.
      ildj_map: `Dictionary`. This is a mapping from event_ndims to a `Tensor`
        representing the inverse log det jacobian.
      kwargs: Python dictionary. Extra args supplied to
        forward/inverse/etc functions.
      mapping: Instance of _Mapping to merge. Can only be specified if no other
        arg is specified.

    Returns:
      mapping: New instance of `_Mapping` which has inputs merged with self.

    Raises:
      ValueError: if mapping and any other arg is not `None`.
    """
    if mapping is None:
      mapping = _Mapping(x=x, y=y, ildj_map=ildj_map, kwargs=kwargs)
    elif any(arg is not None for arg in [x, y, ildj_map, kwargs]):
      raise ValueError("Cannot simultaneously specify mapping and individual "
                       "arguments.")

    return _Mapping(
        x=self._merge(self.x, mapping.x),
        y=self._merge(self.y, mapping.y),
        ildj_map=self._merge_dicts(self.ildj_map, mapping.ildj_map),
        kwargs=self._merge(self.kwargs, mapping.kwargs))

  def _merge_dicts(self, old=None, new=None):
    """Helper to merge two dictionaries."""
    old = {} if old is None else old
    new = {} if new is None else new
    for k, v in new.items():
      val = old.get(k, None)
      if val is not None and val is not v:
        raise ValueError("Found different value for existing key "
                         "(key:{} old_value:{} new_value:{})".format(
                             k, old[k], v))
      old[k] = v
    return old

  def _merge(self, old, new):
    """Helper to merge which handles merging one value."""
    if old is None:
      return new
    elif new is not None and old is not new:
      raise ValueError("Incompatible values: %s != %s" % (old, new))
    return old

  def _deep_tuple(self, x):
    """Converts lists of lists to tuples of tuples."""
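    # E.g. [[1, 2], 3] -> ((1, 2), 3).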
    return (tuple(map(self._deep_tuple, x))
            if isinstance(x, (list, tuple)) else x)


class Bijector(metaclass=abc.ABCMeta):
  r"""Interface for transformations of a `Distribution` sample.

  Bijectors can be used to represent any differentiable and injective
  (one to one) function defined on an open subset of `R^n`.  Some non-injective
  transformations are also supported (see "Non Injective Transforms" below).

  #### Mathematical Details

  A `Bijector` implements a [smooth covering map](
  https://en.wikipedia.org/wiki/Local_diffeomorphism), i.e., a local
  diffeomorphism such that every point in the target has a neighborhood evenly
  covered by a map ([see also](
  https://en.wikipedia.org/wiki/Covering_space#Covering_of_a_manifold)).
  A `Bijector` is used by `TransformedDistribution` but can be generally used
  for transforming a `Distribution` generated `Tensor`. A `Bijector` is
  characterized by three operations:

  1. Forward

     Useful for turning one random outcome into another random outcome from a
     different distribution.

  2. Inverse

     Useful for "reversing" a transformation to compute one probability in
     terms of another.

  3. `log_det_jacobian(x)`

     "The log of the absolute value of the determinant of the matrix of all
     first-order partial derivatives of the inverse function."

     Useful for inverting a transformation to compute one probability in terms
     of another. Geometrically, the Jacobian determinant is the volume of the
     transformation and is used to scale the probability.

     We take the absolute value of the determinant before log to avoid NaN
     values.  Geometrically, a negative determinant corresponds to an
     orientation-reversing transformation.  It is ok for us to discard the sign
     of the determinant because we only integrate everywhere-nonnegative
     functions (probability densities) and the correct orientation is always the
     one that produces a nonnegative integrand.

  By convention, transformations of random variables are named in terms of the
  forward transformation. The forward transformation creates samples, the
  inverse is useful for computing probabilities.

  #### Example Uses

  - Basic properties:

  ```python
  x = ...  # A tensor.
  # Evaluate forward transformation.
  fwd_x = my_bijector.forward(x)
  x == my_bijector.inverse(fwd_x)
  x != my_bijector.forward(fwd_x)  # Not equal because x != g(g(x)).
  ```

  - Computing a log-likelihood:

  ```python
  def transformed_log_prob(bijector, log_prob, x):
    return (bijector.inverse_log_det_jacobian(x, event_ndims=0) +
            log_prob(bijector.inverse(x)))
  ```

  - Transforming a random outcome:

  ```python
  def transformed_sample(bijector, x):
    return bijector.forward(x)
  ```

  #### Example Bijectors

  - "Exponential"

    ```none
    Y = g(X) = exp(X)
    X ~ Normal(0, 1)  # Univariate.
    ```

    Implies:

    ```none
      g^{-1}(Y) = log(Y)
      |Jacobian(g^{-1})(y)| = 1 / y
      Y ~ LogNormal(0, 1), i.e.,
      prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y))
                = (1 / y) Normal(log(y); 0, 1)
    ```

    Here is an example of how one might implement the `Exp` bijector:

    ```python
      class Exp(Bijector):

        def __init__(self, validate_args=False, name="exp"):
          super(Exp, self).__init__(
              validate_args=validate_args,
              forward_min_event_ndims=0,
              name=name)

        def _forward(self, x):
          return math_ops.exp(x)

        def _inverse(self, y):
          return math_ops.log(y)

        def _inverse_log_det_jacobian(self, y):
          return -self._forward_log_det_jacobian(self._inverse(y))

        def _forward_log_det_jacobian(self, x):
          # Notice that we needn't do any reducing, even when `event_ndims > 0`.
          # The base Bijector class will handle reducing for us; it knows how
          # to do so because we called `super` `__init__` with
          # `forward_min_event_ndims = 0`.
          return x
    ```

  - "Affine"

    ```none
    Y = g(X) = sqrtSigma * X + mu
    X ~ MultivariateNormal(0, I_d)
    ```

    Implies:

    ```none
      g^{-1}(Y) = inv(sqrtSigma) * (Y - mu)
      |Jacobian(g^{-1})(y)| = det(inv(sqrtSigma))
      Y ~ MultivariateNormal(mu, sqrtSigma), i.e.,
      prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y))
                = det(sqrtSigma)^(-d) *
                  MultivariateNormal(inv(sqrtSigma) * (y - mu); 0, I_d)
    ```

  #### Min_event_ndims and Naming

  Bijectors are named for the dimensionality of data they act on (i.e. without
  broadcasting). We can think of bijectors as having an intrinsic
  `min_event_ndims`, which is the minimum number of dimensions the bijector
  acts on. For instance, a Cholesky decomposition requires a matrix, and hence
  `min_event_ndims=2`.

  Some examples:

  `AffineScalar:  min_event_ndims=0`
  `Affine:  min_event_ndims=1`
  `Cholesky:  min_event_ndims=2`
  `Exp:  min_event_ndims=0`
  `Sigmoid:  min_event_ndims=0`
  `SoftmaxCentered:  min_event_ndims=1`

  Note the difference between `Affine` and `AffineScalar`. `AffineScalar`
  operates on scalar events, whereas `Affine` operates on vector-valued events.

  More generally, there is a `forward_min_event_ndims` and an
  `inverse_min_event_ndims`. In most cases, these will be the same.
  However, for some shape changing bijectors, these will be different
  (e.g. a bijector which pads an extra dimension at the end might have
  `forward_min_event_ndims=0` and `inverse_min_event_ndims=1`).
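
  As a hedged illustration (`PadOne` is hypothetical, not a bijector defined in
  this module), such a padding bijector would declare:

  ```python
  class PadOne(Bijector):  # Hypothetical shape-changing bijector.

    def __init__(self, validate_args=False, name="pad_one"):
      super(PadOne, self).__init__(
          forward_min_event_ndims=0,   # `forward` accepts scalar events...
          inverse_min_event_ndims=1,   # ...but `inverse` needs vector events.
          validate_args=validate_args,
          name=name)
  ```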


  #### Jacobian Determinant

  The Jacobian determinant is a reduction over `event_ndims - min_event_ndims`
  (`forward_min_event_ndims` for `forward_log_det_jacobian` and
  `inverse_min_event_ndims` for `inverse_log_det_jacobian`).
  To see this, consider the `Exp` `Bijector` applied to a `Tensor` which has
  sample, batch, and event (S, B, E) shape semantics. Suppose the `Tensor`'s
  partitioned-shape is `(S=[4], B=[2], E=[3, 3])`. The shape of the `Tensor`
  returned by `forward` and `inverse` is unchanged, i.e., `[4, 2, 3, 3]`.
  However the shape returned by `inverse_log_det_jacobian` is `[4, 2]` because
  the Jacobian determinant is a reduction over the event dimensions.

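  A sketch of these shape semantics, assuming the `Exp` bijector implemented
  earlier in this docstring:

  ```python
  exp = Exp()
  x = ...  # Tensor with shape [4, 2, 3, 3], i.e., S=[4], B=[2], E=[3, 3].
  exp.forward(x)                                  # ==> shape [4, 2, 3, 3]
  exp.inverse_log_det_jacobian(x, event_ndims=2)  # ==> shape [4, 2]
  ```
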
  Another example is the `Affine` `Bijector`. Because `min_event_ndims = 1`, the
  Jacobian determinant reduction is over `event_ndims - 1`.

  It is sometimes useful to implement the inverse Jacobian determinant as the
  negative forward Jacobian determinant. For example,

  ```python
  def _inverse_log_det_jacobian(self, y):
    return -self._forward_log_det_jacobian(self._inverse(y))  # Note negation.
  ```

  The correctness of this approach can be seen from the following claim.

  - Claim:

      Assume `Y = g(X)` is a bijection whose derivative exists and is nonzero
      for its domain, i.e., `dY/dX = d/dX g(X) != 0`. Then:

      ```none
      (log o det o jacobian o g^{-1})(Y) = -(log o det o jacobian o g)(X)
      ```

  - Proof:

      From the bijective, nonzero differentiability of `g`, the
      [inverse function theorem](
          https://en.wikipedia.org/wiki/Inverse_function_theorem)
      implies `g^{-1}` is differentiable in the image of `g`.
      Applying the chain rule to `y = g(x) = g(g^{-1}(y))` yields
      `I = g'(g^{-1}(y))*g^{-1}'(y)`.
      The same theorem also implies `g^{-1}'` is non-singular, therefore:
      `inv[ g'(g^{-1}(y)) ] = g^{-1}'(y)`.
      The claim follows from [properties of determinant](
  https://en.wikipedia.org/wiki/Determinant#Multiplicativity_and_matrix_groups).
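
  For example, a quick sanity check of the claim for `Y = g(X) = exp(X)`
  (so `g^{-1}(Y) = log(Y)`):

  ```none
  forward:  log|g'(x)|      = log|exp(x)| = x
  inverse:  log|g^{-1}'(y)| = log|1 / y|  = -log(y) = -x   # since y = exp(x)
  ```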

  Generally it's preferable to directly implement the inverse Jacobian
  determinant.  This should have superior numerical stability and will often
  share subgraphs with the `_inverse` implementation.

  #### Is_constant_jacobian

  Certain bijectors will have constant jacobian matrices. For instance, the
  `Affine` bijector encodes multiplication by a matrix plus a shift, and its
  jacobian matrix is that same matrix.

  `is_constant_jacobian` encodes the fact that the jacobian matrix is constant.
  The semantics of this argument are the following:

    * Repeated calls to "log_det_jacobian" functions with the same
      `event_ndims` (but not necessarily the same input) will return the first
      computed jacobian (because the matrix is constant, and hence is input
      independent).
    * `log_det_jacobian` implementations are merely broadcastable to the true
      `log_det_jacobian` (because, again, the jacobian matrix is input
      independent). Specifically, `log_det_jacobian` is implemented as the
      log jacobian determinant for a single input.

      ```python
      class Identity(Bijector):

        def __init__(self, validate_args=False, name="identity"):
          super(Identity, self).__init__(
              is_constant_jacobian=True,
              validate_args=validate_args,
              forward_min_event_ndims=0,
              name=name)

        def _forward(self, x):
          return x

        def _inverse(self, y):
          return y

        def _inverse_log_det_jacobian(self, y):
          return -self._forward_log_det_jacobian(self._inverse(y))

        def _forward_log_det_jacobian(self, x):
          # The full log jacobian determinant would be array_ops.zeros_like(x).
          # However, we circumvent materializing that, since the jacobian
          # calculation is input independent, and we specify it for one input.
          return constant_op.constant(0., x.dtype.base_dtype)

      ```

  #### Subclass Requirements

  - Subclasses typically implement:

      - `_forward`,
      - `_inverse`,
      - `_inverse_log_det_jacobian`,
      - `_forward_log_det_jacobian` (optional).

    The `_forward_log_det_jacobian` is called when the bijector is inverted via
    the `Invert` bijector. If undefined, a slightly less efficient
    calculation, `-1 * _inverse_log_det_jacobian`, is used.

    If the bijector changes the shape of the input, you must also implement
    (a sketch appears at the end of this section):

      - `_forward_event_shape_tensor`,
      - `_forward_event_shape` (optional),
      - `_inverse_event_shape_tensor`,
      - `_inverse_event_shape` (optional).

    By default the event-shape is assumed unchanged from input.

  - If the `Bijector`'s use is limited to `TransformedDistribution` (or friends
    like `QuantizedDistribution`) then depending on your use, you may not need
    to implement all of the `_forward` and `_inverse` functions.

    Examples:

      1. Sampling (e.g., `sample`) only requires `_forward`.
      2. Probability functions (e.g., `prob`, `cdf`, `survival`) only require
         `_inverse` (and related).
      3. Only calling probability functions on the output of `sample` means
         `_inverse` can be implemented as a cache lookup.

    See "Example Uses" [above] which shows how these functions are used to
    transform a distribution. (Note: `_forward` could theoretically be
    implemented as a cache lookup but this would require controlling the
    underlying sample generation mechanism.)
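
  Here is a hedged sketch of the event-shape hooks for the hypothetical
  `PadOne` bijector above, which pads an extra size-1 dimension at the end of
  each event:

  ```python
  class PadOne(Bijector):  # Hypothetical; illustrates the event-shape hooks.

    def _forward_event_shape_tensor(self, input_shape):
      # Event gains a rightmost size-1 dimension: shape [...] -> [..., 1].
      return array_ops.concat([input_shape, [1]], axis=0)

    def _inverse_event_shape_tensor(self, output_shape):
      # Drop the padded rightmost dimension.
      return output_shape[:-1]
  ```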

  #### Non Injective Transforms

  **WARNING** Handling of non-injective transforms is subject to change.

  Non-injective maps `g` are supported, provided their domain `D` can be
  partitioned into `k` disjoint subsets, `Union{D1, ..., Dk}`, such that,
  ignoring sets of measure zero, the restriction of `g` to each subset is a
  differentiable bijection onto `g(D)`.  In particular, this implies that for
  `y in g(D)`, the set inverse, i.e. `g^{-1}(y) = {x in D : g(x) = y}`, always
  contains exactly `k` distinct points.

  The property `_is_injective` is set to `False` to indicate that the bijector
  is not injective, yet satisfies the above condition.

  The usual bijector API is modified in the case `_is_injective is False` (see
  method docstrings for specifics).  Here we show by example the `AbsoluteValue`
  bijector.  In this case, the domain `D = (-inf, inf)` can be partitioned
  into `D1 = (-inf, 0)`, `D2 = {0}`, and `D3 = (0, inf)`.  Let `gi` be the
  restriction of `g` to `Di`, then both `g1` and `g3` are bijections onto
  `(0, inf)`, with `g1^{-1}(y) = -y`, and `g3^{-1}(y) = y`.  We will use
  `g1` and `g3` to define bijector methods over `D1` and `D3`.  `D2 = {0}` is
  an oddball in that `g2` is one to one, and the derivative is not well defined.
  Fortunately, when considering transformations of probability densities
  (e.g. in `TransformedDistribution`), sets of measure zero have no effect in
  theory, and only a small effect in 32 or 64 bit precision.  For that reason,
  we define `inverse(0)` and `inverse_log_det_jacobian(0)` both as `[0, 0]`,
  which is convenient and results in a left-semicontinuous pdf.


  ```python
  abs = tfp.distributions.bijectors.AbsoluteValue()

  abs.forward(-1.)
  ==> 1.

  abs.forward(1.)
  ==> 1.

  abs.inverse(1.)
  ==> (-1., 1.)

  # The |dX/dY| is constant, == 1.  So Log|dX/dY| == 0.
  abs.inverse_log_det_jacobian(1., event_ndims=0)
  ==> (0., 0.)

  # Special case handling of 0.
  abs.inverse(0.)
  ==> (0., 0.)

  abs.inverse_log_det_jacobian(0., event_ndims=0)
  ==> (0., 0.)
  ```

  """

  @abc.abstractmethod
  def __init__(self,
               graph_parents=None,
               is_constant_jacobian=False,
               validate_args=False,
               dtype=None,
               forward_min_event_ndims=None,
               inverse_min_event_ndims=None,
               name=None):
    """Constructs Bijector.

    A `Bijector` transforms random variables into new random variables.

    Examples:

    ```python
    # Create the Y = g(X) = X transform.
    identity = Identity()

    # Create the Y = g(X) = exp(X) transform.
    exp = Exp()
    ```

    See `Bijector` subclass docstring for more details and specific examples.

    Args:
      graph_parents: Python list of graph prerequisites of this `Bijector`.
      is_constant_jacobian: Python `bool` indicating that the Jacobian matrix is
        not a function of the input.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are invalid,
        correct behavior is not guaranteed.
      dtype: `tf.dtype` supported by this `Bijector`. `None` means dtype is not
        enforced.
      forward_min_event_ndims: Python `integer` indicating the minimum number of
        dimensions `forward` operates on.
      inverse_min_event_ndims: Python `integer` indicating the minimum number of
        dimensions `inverse` operates on. Will be set to
        `forward_min_event_ndims` by default, if no value is provided.
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError:  If neither `forward_min_event_ndims` nor
        `inverse_min_event_ndims` is specified, or if either of them is
        negative.
      ValueError:  If a member of `graph_parents` is not a `Tensor`.
    """
    self._graph_parents = graph_parents or []

    if forward_min_event_ndims is None and inverse_min_event_ndims is None:
      raise ValueError("Must specify at least one of `forward_min_event_ndims` "
                       "and `inverse_min_event_ndims`.")
    elif inverse_min_event_ndims is None:
      inverse_min_event_ndims = forward_min_event_ndims
    elif forward_min_event_ndims is None:
      forward_min_event_ndims = inverse_min_event_ndims

    if not isinstance(forward_min_event_ndims, int):
      raise TypeError("Expected forward_min_event_ndims to be of "
                      "type int, got {}".format(
                          type(forward_min_event_ndims).__name__))

    if not isinstance(inverse_min_event_ndims, int):
      raise TypeError("Expected inverse_min_event_ndims to be of "
                      "type int, got {}".format(
                          type(inverse_min_event_ndims).__name__))

    if forward_min_event_ndims < 0:
      raise ValueError("forward_min_event_ndims must be a non-negative "
                       "integer.")
    if inverse_min_event_ndims < 0:
      raise ValueError("inverse_min_event_ndims must be a non-negative "
                       "integer.")

    self._forward_min_event_ndims = forward_min_event_ndims
    self._inverse_min_event_ndims = inverse_min_event_ndims
    self._is_constant_jacobian = is_constant_jacobian
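    # Maps event_ndims to a constant ILDJ `Tensor`; memoized by the
    # log-det-jacobian calls when `is_constant_jacobian` is True.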
    self._constant_ildj_map = {}
    self._validate_args = validate_args
    self._dtype = dtype
    # These dicts can only be accessed using _Mapping.x_key or _Mapping.y_key
    self._from_y = {}
    self._from_x = {}
    if name:
      self._name = name
    else:
      # We want the default convention to be snake_case rather than CamelCase
      # since `Chain` uses bijector.name as the kwargs dictionary key.
      def camel_to_snake(name):
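        # E.g. "SoftmaxCentered" -> "softmax_centered".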
        s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
        return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
      self._name = camel_to_snake(type(self).__name__.lstrip("_"))

    for i, t in enumerate(self._graph_parents):
      if t is None or not tensor_util.is_tf_type(t):
        raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))

  @property
  def graph_parents(self):
    """Returns this `Bijector`'s graph_parents as a Python list."""
    return self._graph_parents

  @property
  def forward_min_event_ndims(self):
    """Returns the minimal number of dimensions bijector.forward operates on."""
    return self._forward_min_event_ndims

  @property
  def inverse_min_event_ndims(self):
    """Returns the minimal number of dimensions bijector.inverse operates on."""
    return self._inverse_min_event_ndims

  @property
  def is_constant_jacobian(self):
    """Returns true iff the Jacobian matrix is not a function of x.

    Note: Jacobian matrix is either constant for both forward and inverse or
    neither.

    Returns:
      is_constant_jacobian: Python `bool`.
    """
    return self._is_constant_jacobian

  @property
  def _is_injective(self):
    """Returns true iff the forward map `g` is injective (one-to-one function).

    **WARNING** This hidden property and its behavior are subject to change.

    Note:  Non-injective maps `g` are supported, provided their domain `D` can
    be partitioned into `k` disjoint subsets, `Union{D1, ..., Dk}`, such that,
    ignoring sets of measure zero, the restriction of `g` to each subset is a
    differentiable bijection onto `g(D)`.

    Returns:
      is_injective: Python `bool`.
    """
    return True

  @property
  def validate_args(self):
    """Returns True if Tensor arguments will be validated."""
    return self._validate_args

  @property
  def dtype(self):
    """dtype of `Tensor`s transformable by this `Bijector`."""
    return self._dtype

  @property
  def name(self):
    """Returns the string name of this `Bijector`."""
    return self._name

  def _forward_event_shape_tensor(self, input_shape):
    """Subclass implementation for `forward_event_shape_tensor` function."""
    # By default, we assume event_shape is unchanged.
    return input_shape

  def forward_event_shape_tensor(self,
                                 input_shape,
                                 name="forward_event_shape_tensor"):
    """Shape of a single sample from a single batch as an `int32` 1D `Tensor`.

    Args:
      input_shape: `Tensor`, `int32` vector indicating event-portion shape
        passed into `forward` function.
      name: name to give to the op

    Returns:
      forward_event_shape_tensor: `Tensor`, `int32` vector indicating
        event-portion shape after applying `forward`.
    """
    with self._name_scope(name, [input_shape]):
      input_shape = ops.convert_to_tensor(input_shape, dtype=dtypes.int32,
                                          name="input_shape")
      return self._forward_event_shape_tensor(input_shape)

  def _forward_event_shape(self, input_shape):
    """Subclass implementation for `forward_event_shape` public function."""
    # By default, we assume event_shape is unchanged.
    return input_shape

  def forward_event_shape(self, input_shape):
    """Shape of a single sample from a single batch as a `TensorShape`.

    Same meaning as `forward_event_shape_tensor`. May be only partially defined.

    Args:
      input_shape: `TensorShape` indicating event-portion shape passed into
        `forward` function.

    Returns:
      forward_event_shape_tensor: `TensorShape` indicating event-portion shape
        after applying `forward`. Possibly unknown.
    """
    return self._forward_event_shape(tensor_shape.TensorShape(input_shape))

  def _inverse_event_shape_tensor(self, output_shape):
    """Subclass implementation for `inverse_event_shape_tensor` function."""
    # By default, we assume event_shape is unchanged.
    return output_shape

  def inverse_event_shape_tensor(self,
                                 output_shape,
                                 name="inverse_event_shape_tensor"):
    """Shape of a single sample from a single batch as an `int32` 1D `Tensor`.

    Args:
      output_shape: `Tensor`, `int32` vector indicating event-portion shape
        passed into `inverse` function.
      name: name to give to the op

    Returns:
      inverse_event_shape_tensor: `Tensor`, `int32` vector indicating
        event-portion shape after applying `inverse`.
    """
    with self._name_scope(name, [output_shape]):
      output_shape = ops.convert_to_tensor(output_shape, dtype=dtypes.int32,
                                           name="output_shape")
      return self._inverse_event_shape_tensor(output_shape)

  def _inverse_event_shape(self, output_shape):
    """Subclass implementation for `inverse_event_shape` public function."""
    # By default, we assume event_shape is unchanged.
    return tensor_shape.TensorShape(output_shape)

  def inverse_event_shape(self, output_shape):
    """Shape of a single sample from a single batch as a `TensorShape`.

    Same meaning as `inverse_event_shape_tensor`. May be only partially defined.

    Args:
      output_shape: `TensorShape` indicating event-portion shape passed into
        `inverse` function.

    Returns:
      inverse_event_shape_tensor: `TensorShape` indicating event-portion shape
        after applying `inverse`. Possibly unknown.
    """
    return self._inverse_event_shape(output_shape)

  def _forward(self, x):
    """Subclass implementation for `forward` public function."""
    raise NotImplementedError("forward not implemented.")

  def _call_forward(self, x, name, **kwargs):
    with self._name_scope(name, [x]):
      x = ops.convert_to_tensor(x, name="x")
      self._maybe_assert_dtype(x)
      if not self._is_injective:  # No caching for non-injective
        return self._forward(x, **kwargs)
      mapping = self._lookup(x=x, kwargs=kwargs)
      if mapping.y is not None:
        return mapping.y
      mapping = mapping.merge(y=self._forward(x, **kwargs))
      self._cache(mapping)
      return mapping.y

  def forward(self, x, name="forward"):
    """Returns the forward `Bijector` evaluation, i.e., Y = g(X).

    Args:
      x: `Tensor`. The input to the "forward" evaluation.
      name: The name to give this op.

    Returns:
      `Tensor`.

    Raises:
      TypeError: if `self.dtype` is specified and `x.dtype` is not
        `self.dtype`.
      NotImplementedError: if `_forward` is not implemented.
    """
    return self._call_forward(x, name)

  def _inverse(self, y):
    """Subclass implementation for `inverse` public function."""
    raise NotImplementedError("inverse not implemented.")

  def _call_inverse(self, y, name, **kwargs):
    with self._name_scope(name, [y]):
      y = ops.convert_to_tensor(y, name="y")
      self._maybe_assert_dtype(y)
      if not self._is_injective:  # No caching for non-injective
        return self._inverse(y, **kwargs)
      mapping = self._lookup(y=y, kwargs=kwargs)
      if mapping.x is not None:
        return mapping.x
      mapping = mapping.merge(x=self._inverse(y, **kwargs))
      self._cache(mapping)
      return mapping.x

  def inverse(self, y, name="inverse"):
    """Returns the inverse `Bijector` evaluation, i.e., X = g^{-1}(Y).

    Args:
      y: `Tensor`. The input to the "inverse" evaluation.
      name: The name to give this op.

    Returns:
      `Tensor`, if this bijector is injective.
        If not injective, returns the k-tuple containing the unique
        `k` points `(x1, ..., xk)` such that `g(xi) = y`.

    Raises:
      TypeError: if `self.dtype` is specified and `y.dtype` is not
        `self.dtype`.
      NotImplementedError: if `_inverse` is not implemented.
    """
    return self._call_inverse(y, name)

  def _inverse_log_det_jacobian(self, y):
    """Subclass implementation of `inverse_log_det_jacobian` public function.

    In particular, this method differs from the public function, in that it
    does not take `event_ndims`. Thus, this implements the minimal Jacobian
    determinant calculation (i.e. over `inverse_min_event_ndims`).

    Args:
      y: `Tensor`. The input to the "inverse_log_det_jacobian" evaluation.

    Returns:
      inverse_log_det_jacobian: `Tensor`, if this bijector is injective.
        If not injective, returns the k-tuple containing jacobians for the
        unique `k` points `(x1, ..., xk)` such that `g(xi) = y`.
    """
    raise NotImplementedError("inverse_log_det_jacobian not implemented.")

  def _call_inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs):
    with self._name_scope(name, [y]):
      if event_ndims in self._constant_ildj_map:
        return self._constant_ildj_map[event_ndims]
      y = ops.convert_to_tensor(y, name="y")
      self._maybe_assert_dtype(y)
      with ops.control_dependencies(self._check_valid_event_ndims(
          min_event_ndims=self.inverse_min_event_ndims,
          event_ndims=event_ndims)):
        if not self._is_injective:  # No caching for non-injective
          try:
            ildjs = self._inverse_log_det_jacobian(y, **kwargs)
            return tuple(self._reduce_jacobian_det_over_event(
                y, ildj, self.inverse_min_event_ndims, event_ndims)
                         for ildj in ildjs)
          except NotImplementedError as original_exception:
            try:
              x = self._inverse(y, **kwargs)
              fldjs = self._forward_log_det_jacobian(x, **kwargs)
              return tuple(self._reduce_jacobian_det_over_event(
                  x, -fldj, self.forward_min_event_ndims, event_ndims)
                           for fldj in fldjs)
            except NotImplementedError:
              raise original_exception

        mapping = self._lookup(y=y, kwargs=kwargs)
        if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
          return mapping.ildj_map[event_ndims]
        try:
          x = None  # Not needed; leave cache as is.
          ildj = self._inverse_log_det_jacobian(y, **kwargs)
          ildj = self._reduce_jacobian_det_over_event(
              y, ildj, self.inverse_min_event_ndims, event_ndims)
        except NotImplementedError as original_exception:
          try:
            x = (mapping.x if mapping.x is not None
                 else self._inverse(y, **kwargs))
            ildj = -self._forward_log_det_jacobian(x, **kwargs)
            ildj = self._reduce_jacobian_det_over_event(
                x, ildj, self.forward_min_event_ndims, event_ndims)
          except NotImplementedError:
            raise original_exception

        mapping = mapping.merge(x=x, ildj_map={event_ndims: ildj})
        self._cache(mapping)
        if self.is_constant_jacobian:
          self._constant_ildj_map[event_ndims] = ildj
        return ildj

  def inverse_log_det_jacobian(
      self, y, event_ndims, name="inverse_log_det_jacobian"):
    """Returns the (log o det o Jacobian o inverse)(y).

    Mathematically, returns: `log(det(dX/dY))(Y)`. (Recall that: `X=g^{-1}(Y)`.)

    Note that `forward_log_det_jacobian` is the negative of this function,
    evaluated at `g^{-1}(y)`.

    Args:
      y: `Tensor`. The input to the "inverse" Jacobian determinant evaluation.
      event_ndims: Number of dimensions in the probabilistic events being
        transformed. Must be greater than or equal to
        `self.inverse_min_event_ndims`. The result is summed over the final
        dimensions to produce a scalar Jacobian determinant for each event,
        i.e. the result has `y.shape.ndims - event_ndims` dimensions.
      name: The name to give this op.

    Returns:
      `Tensor`, if this bijector is injective.
        If not injective, returns the tuple of local log det
        Jacobians, `log(det(Dg_i^{-1}(y)))`, where `g_i` is the restriction
        of `g` to the `ith` partition `Di`.

    Raises:
      TypeError: if `self.dtype` is specified and `y.dtype` is not
        `self.dtype`.
      NotImplementedError: if `_inverse_log_det_jacobian` is not implemented.
    """
    return self._call_inverse_log_det_jacobian(y, event_ndims, name)

  def _forward_log_det_jacobian(self, x):
    """Subclass implementation of `forward_log_det_jacobian` public function.

    In particular, this method differs from the public function, in that it
    does not take `event_ndims`. Thus, this implements the minimal Jacobian
    determinant calculation (i.e. over `forward_min_event_ndims`).

    Args:
      x: `Tensor`. The input to the "forward_log_det_jacobian" evaluation.

    Returns:
      forward_log_det_jacobian: `Tensor`, if this bijector is injective.
        If not injective, returns the k-tuple containing jacobians for the
        unique `k` points `(x1, ..., xk)` such that `g(xi) = y`.
    """
    raise NotImplementedError(
        "forward_log_det_jacobian not implemented.")

  def _call_forward_log_det_jacobian(self, x, event_ndims, name, **kwargs):
    if not self._is_injective:
      raise NotImplementedError(
          "forward_log_det_jacobian cannot be implemented for non-injective "
          "transforms.")
    with self._name_scope(name, [x]):
      with ops.control_dependencies(self._check_valid_event_ndims(
          min_event_ndims=self.forward_min_event_ndims,
          event_ndims=event_ndims)):
        if event_ndims in self._constant_ildj_map:
          # Need "-1. *" to avoid invalid-unary-operand-type linter warning.
          return -1. * self._constant_ildj_map[event_ndims]
        x = ops.convert_to_tensor(x, name="x")
        self._maybe_assert_dtype(x)
        if not self._is_injective:  # No caching for non-injective
          try:
            fldjs = self._forward_log_det_jacobian(x, **kwargs)  # No caching.
            return tuple(self._reduce_jacobian_det_over_event(
                x, fldj, self.forward_min_event_ndims, event_ndims)
                         for fldj in fldjs)
          except NotImplementedError as original_exception:
            try:
              y = self._forward(x, **kwargs)
              ildjs = self._inverse_log_det_jacobian(y, **kwargs)
              return tuple(self._reduce_jacobian_det_over_event(
                  y, -ildj, self.inverse_min_event_ndims, event_ndims)
                           for ildj in ildjs)
            except NotImplementedError:
              raise original_exception
        mapping = self._lookup(x=x, kwargs=kwargs)
        if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
          return -mapping.ildj_map[event_ndims]
        try:
          y = None  # Not needed; leave cache as is.
          ildj = -self._forward_log_det_jacobian(x, **kwargs)
          ildj = self._reduce_jacobian_det_over_event(
              x, ildj, self.forward_min_event_ndims, event_ndims)
        except NotImplementedError as original_exception:
          try:
            y = (mapping.y if mapping.y is not None
                 else self._forward(x, **kwargs))
            ildj = self._inverse_log_det_jacobian(y, **kwargs)
            ildj = self._reduce_jacobian_det_over_event(
                y, ildj, self.inverse_min_event_ndims, event_ndims)
          except NotImplementedError:
            raise original_exception
        mapping = mapping.merge(y=y, ildj_map={event_ndims: ildj})
        self._cache(mapping)
        if self.is_constant_jacobian:
          self._constant_ildj_map[event_ndims] = ildj
        return -ildj

  def forward_log_det_jacobian(
      self, x, event_ndims, name="forward_log_det_jacobian"):
    """Returns the forward_log_det_jacobian.

    Args:
      x: `Tensor`. The input to the "forward" Jacobian determinant evaluation.
      event_ndims: Number of dimensions in the probabilistic events being
        transformed. Must be greater than or equal to
        `self.forward_min_event_ndims`. The result is summed over the final
        dimensions to produce a scalar Jacobian determinant for each event,
        i.e. the result has `x.shape.ndims - event_ndims` dimensions.
      name: The name to give this op.

    Returns:
      `Tensor`, if this bijector is injective.
        If not injective this is not implemented.

    Raises:
      TypeError: if `self.dtype` is specified and `x.dtype` is not
        `self.dtype`.
      NotImplementedError: if neither `_forward_log_det_jacobian`
        nor {`_inverse`, `_inverse_log_det_jacobian`} are implemented, or
        this is a non-injective bijector.
    """
    return self._call_forward_log_det_jacobian(x, event_ndims, name)

  @contextlib.contextmanager
  def _name_scope(self, name=None, values=None):
    """Helper function to standardize op scope."""
    with ops.name_scope(self.name):
      with ops.name_scope(
          name, values=(values or []) + self.graph_parents) as scope:
        yield scope

  def _maybe_assert_dtype(self, x):
    """Helper to check dtype when self.dtype is known."""
    if self.dtype is not None and self.dtype.base_dtype != x.dtype.base_dtype:
      raise TypeError("Input had dtype %s but expected %s." %
                      (x.dtype, self.dtype))

  def _cache(self, mapping):
    """Helper which stores mapping info in forward/inverse dicts."""
    # Merging from lookup is an added check that we're not overwriting anything
    # which is not None.
    mapping = mapping.merge(mapping=self._lookup(
        mapping.x, mapping.y, mapping.kwargs))
    if mapping.x is None and mapping.y is None:
      raise ValueError("Caching expects at least one of (x,y) to be known, "
                       "i.e., not None.")
    self._from_x[mapping.x_key] = mapping
    self._from_y[mapping.y_key] = mapping

  def _lookup(self, x=None, y=None, kwargs=None):
    """Helper which retrieves mapping info from forward/inverse dicts."""
    mapping = _Mapping(x=x, y=y, kwargs=kwargs)
    # Since _cache requires both x,y to be set, we only need to do one cache
    # lookup since the mapping is always in both or neither.
    if mapping.x is not None:
      return self._from_x.get(mapping.x_key, mapping)
    if mapping.y is not None:
      return self._from_y.get(mapping.y_key, mapping)
    return mapping

  def _reduce_jacobian_det_over_event(
      self, y, ildj, min_event_ndims, event_ndims):
    """Reduce jacobian over event_ndims - min_event_ndims."""
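    # E.g. `Exp` (min_event_ndims=0) on an input of shape [4, 2, 3, 3] with
    # event_ndims=2: `ildj` is broadcast against the trailing [3, 3] event
    # dims and summed over them, yielding a result of shape [4, 2].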
    # In this case, we need to tile the Jacobian over the event and reduce.
    y_rank = array_ops.rank(y)
    y_shape = array_ops.shape(y)[
        y_rank - event_ndims : y_rank - min_event_ndims]

    ones = array_ops.ones(y_shape, ildj.dtype)
    reduced_ildj = math_ops.reduce_sum(
        ones * ildj,
        axis=self._get_event_reduce_dims(min_event_ndims, event_ndims))
    # The multiplication by ones can change the inferred static shape so we try
    # to recover as much as possible.
    event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
    if (event_ndims_ is not None and
        y.shape.ndims is not None and
        ildj.shape.ndims is not None):
      y_shape = y.shape[y.shape.ndims - event_ndims_ :
                        y.shape.ndims - min_event_ndims]
      broadcast_shape = array_ops.broadcast_static_shape(ildj.shape, y_shape)
      reduced_ildj.set_shape(
          broadcast_shape[: broadcast_shape.ndims - (
              event_ndims_ - min_event_ndims)])

    return reduced_ildj

  def _get_event_reduce_dims(self, min_event_ndims, event_ndims):
    """Compute the reduction dimensions given event_ndims."""
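    # E.g. event_ndims=3, min_event_ndims=1 -> reduce over axes [-1, -2].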
    event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)

    if event_ndims_ is not None:
      return [-index for index in range(1, event_ndims_ - min_event_ndims + 1)]
    else:
      reduce_ndims = event_ndims - min_event_ndims
      return math_ops.range(-reduce_ndims, 0)

  def _check_valid_event_ndims(self, min_event_ndims, event_ndims):
    """Check whether event_ndims is at least min_event_ndims."""
    event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
    event_ndims_ = tensor_util.constant_value(event_ndims)
    assertions = []

    if not event_ndims.dtype.is_integer:
      raise ValueError("Expected integer dtype, got dtype {}".format(
          event_ndims.dtype))

    if event_ndims_ is not None:
      if event_ndims.shape.ndims != 0:
        raise ValueError("Expected scalar event_ndims, got shape {}".format(
            event_ndims.shape))
      if min_event_ndims > event_ndims_:
        raise ValueError("event_ndims ({}) must be at least "
                         "min_event_ndims ({})".format(
                             event_ndims_, min_event_ndims))
    elif self.validate_args:
      assertions += [
          check_ops.assert_greater_equal(event_ndims, min_event_ndims)]

    if event_ndims.shape.is_fully_defined():
      if event_ndims.shape.ndims != 0:
        raise ValueError("Expected scalar shape, got ndims {}".format(
            event_ndims.shape.ndims))

    elif self.validate_args:
      assertions += [
          check_ops.assert_rank(event_ndims, 0, message="Expected scalar.")]
    return assertions

  def _maybe_get_static_event_ndims(self, event_ndims):
    """Helper which tries to return an integer static value."""
    event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)

    if isinstance(event_ndims_, (np.generic, np.ndarray)):
      if event_ndims_.dtype not in (np.int32, np.int64):
        raise ValueError("Expected integer dtype, got dtype {}".format(
            event_ndims_.dtype))

      if isinstance(event_ndims_, np.ndarray) and len(event_ndims_.shape):
        raise ValueError("Expected a scalar integer, got {}".format(
            event_ndims_))
      event_ndims_ = int(event_ndims_)

    return event_ndims_