# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Transformed Distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib
from tensorflow.python.ops.distributions import identity_bijector
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation

__all__ = [
    "TransformedDistribution",
]


# The following helper functions attempt to statically perform a TF operation.
# These functions make debugging easier since we can do more validation during
# graph construction.
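#
# For example (an illustrative sketch, not part of the module API;
# `dynamic_pred` stands for any boolean `Tensor` whose value is unknown at
# graph-construction time):
#
#   _logical_and(True, constant_op.constant(False))  # folds to a constant
#                                                     # `False` Tensor.
#   _logical_and(dynamic_pred, True)                  # falls back to a graph
#                                                     # `logical_and` op.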


def _static_value(x):
  """Returns the static value of a `Tensor` or `None`."""
  return tensor_util.constant_value(ops.convert_to_tensor(x))


def _logical_and(*args):
  """Convenience function which attempts to statically apply `reduce_all`."""
  args_ = [_static_value(x) for x in args]
  if any(x is not None and not bool(x) for x in args_):
    return constant_op.constant(False)
  if all(x is not None and bool(x) for x in args_):
    return constant_op.constant(True)
  if len(args) == 2:
    return math_ops.logical_and(*args)
  return math_ops.reduce_all(args)


def _logical_equal(x, y):
  """Convenience function which attempts to statically compute `x == y`."""
  x_ = _static_value(x)
  y_ = _static_value(y)
  if x_ is None or y_ is None:
    return math_ops.equal(x, y)
  return constant_op.constant(np.array_equal(x_, y_))


def _logical_not(x):
  """Convenience function which attempts to statically apply `logical_not`."""
  x_ = _static_value(x)
  if x_ is None:
    return math_ops.logical_not(x)
  return constant_op.constant(np.logical_not(x_))


def _concat_vectors(*args):
  """Convenience function which concatenates input vectors."""
  args_ = [_static_value(x) for x in args]
  if any(x_ is None for x_ in args_):
    return array_ops.concat(args, 0)
  return constant_op.constant([x_ for vec_ in args_ for x_ in vec_])


def _pick_scalar_condition(pred, cond_true, cond_false):
  """Convenience function which chooses between two values based on `pred`."""
  # Note: This function is only valid if all of pred, cond_true, and cond_false
  # are scalars. This means its semantics are arguably more like tf.cond than
  # tf.select even though we use tf.select to implement it.
  pred_ = _static_value(pred)
  if pred_ is None:
    return array_ops.where(pred, cond_true, cond_false)
  return cond_true if pred_ else cond_false


def _ones_like(x):
  """Convenience function which attempts to statically construct `ones_like`."""
  # Should only be used for small vectors.
  if x.get_shape().is_fully_defined():
    return array_ops.ones(x.get_shape().as_list(), dtype=x.dtype)
  return array_ops.ones_like(x)


def _ndims_from_shape(shape):
  """Returns the rank implied by a `Tensor`-valued shape."""
  if shape.get_shape().ndims not in (None, 1):
    raise ValueError("input is not a valid shape: not 1D")
  if not shape.dtype.is_integer:
    raise TypeError("input is not a valid shape: wrong dtype")
  if shape.get_shape().is_fully_defined():
    return constant_op.constant(shape.get_shape().as_list()[0])
  return array_ops.shape(shape)[0]


def _is_scalar_from_shape(shape):
  """Returns a `Tensor` which is `True` if the shape implies a scalar."""
  return _logical_equal(_ndims_from_shape(shape), 0)


class TransformedDistribution(distribution_lib.Distribution):
  """A Transformed Distribution.

  A `TransformedDistribution` models `p(y)` given a base distribution `p(x)`,
  and a deterministic, invertible, differentiable transform, `Y = g(X)`. The
  transform is typically an instance of the `Bijector` class and the base
  distribution is typically an instance of the `Distribution` class.

  A `Bijector` is expected to implement the following functions:
  - `forward`,
  - `inverse`,
  - `inverse_log_det_jacobian`.
  The semantics of these functions are outlined in the `Bijector` documentation.

  We now describe how a `TransformedDistribution` alters the inputs/outputs of a
  `Distribution` associated with a random variable (rv) `X`.

  Write `cdf(Y=y)` for an absolutely continuous cumulative distribution function
  of random variable `Y`; write the probability density function `pdf(Y=y) :=
  d^k / (dy_1,...,dy_k) cdf(Y=y)` for its derivative with respect to `y`,
  evaluated at `y`. Assume that `Y = g(X)` where `g` is a deterministic
  diffeomorphism, i.e., a non-random, continuous, differentiable, and invertible
  function. Write the inverse of `g` as `X = g^{-1}(Y)` and `(J o g)(x)` for the
  Jacobian of `g` evaluated at `x`.

  A `TransformedDistribution` implements the following operations:

    * `sample`
      Mathematically:   `Y = g(X)`
      Programmatically: `bijector.forward(distribution.sample(...))`

    * `log_prob`
      Mathematically:   `(log o pdf)(Y=y) = (log o pdf o g^{-1})(y)
                         + (log o abs o det o J o g^{-1})(y)`
      Programmatically: `(distribution.log_prob(bijector.inverse(y))
                         + bijector.inverse_log_det_jacobian(y))`

    * `log_cdf`
      Mathematically:   `(log o cdf)(Y=y) = (log o cdf o g^{-1})(y)`
      Programmatically: `distribution.log_cdf(bijector.inverse(y))`

    * and similarly for: `cdf`, `prob`, `log_survival_function`,
      `survival_function`.

  A simple example constructing a Log-Normal distribution from a Normal
  distribution:

  ```python
  ds = tfp.distributions
  log_normal = ds.TransformedDistribution(
    distribution=ds.Normal(loc=0., scale=1.),
    bijector=ds.bijectors.Exp(),
    name="LogNormalTransformedDistribution")
  ```
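
  As a quick check of the `log_prob` bookkeeping above (an illustrative sketch;
  it assumes the same `ds` alias and `tf` import used elsewhere in these
  examples), the identity for the `Exp` bijector reduces to:

  ```python
  y = 2.
  # log_normal.log_prob(y) is computed as
  #   distribution.log_prob(bijector.inverse(y))
  #   + bijector.inverse_log_det_jacobian(y),
  # which for the Exp bijector works out to:
  manual = ds.Normal(loc=0., scale=1.).log_prob(tf.log(y)) - tf.log(y)
  # log_normal.log_prob(y) and `manual` should agree up to numerical error.
  ```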

  A `LogNormal` made from callables:

  ```python
  ds = tfp.distributions
  log_normal = ds.TransformedDistribution(
    distribution=ds.Normal(loc=0., scale=1.),
    bijector=ds.bijectors.Inline(
      forward_fn=tf.exp,
      inverse_fn=tf.log,
      inverse_log_det_jacobian_fn=(
        lambda y: -tf.reduce_sum(tf.log(y), axis=-1))),
    name="LogNormalTransformedDistribution")
  ```

  Another example constructing a Normal from a StandardNormal:

  ```python
  ds = tfp.distributions
  normal = ds.TransformedDistribution(
    distribution=ds.Normal(loc=0., scale=1.),
    bijector=ds.bijectors.Affine(
      shift=-1.,
      scale_identity_multiplier=2.),
    name="NormalTransformedDistribution")
  ```

  A `TransformedDistribution`'s batch- and event-shape are implied by the base
  distribution unless explicitly overridden by `batch_shape` or `event_shape`
  arguments. Specifying an overriding `batch_shape` (`event_shape`) is
  permitted only if the base distribution has scalar batch-shape (event-shape).
  The bijector is applied to the distribution as if the distribution possessed
  the overridden shape(s). The following example demonstrates how to construct a
  multivariate Normal as a `TransformedDistribution`.

  ```python
  ds = tfp.distributions
  # We will create two MVNs with batch_shape = event_shape = [2].
  mean = [[-1., 0],      # batch:0
          [0., 1]]       # batch:1
  chol_cov = [[[1., 0],
               [0, 1]],  # batch:0
              [[1, 0],
               [2, 2]]]  # batch:1
  mvn1 = ds.TransformedDistribution(
      distribution=ds.Normal(loc=0., scale=1.),
      bijector=ds.bijectors.Affine(shift=mean, scale_tril=chol_cov),
      batch_shape=[2],  # Valid because base_distribution.batch_shape == [].
      event_shape=[2])  # Valid because base_distribution.event_shape == [].
  mvn2 = ds.MultivariateNormalTriL(loc=mean, scale_tril=chol_cov)
  # mvn1.log_prob(x) == mvn2.log_prob(x)
  ```
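
  Continuing the example above (a sketch of the implied shapes; `mvn1` is the
  `TransformedDistribution` just constructed):

  ```python
  mvn1.batch_shape    # ==> [2]
  mvn1.event_shape    # ==> [2]
  x = mvn1.sample(3)  # shape [3, 2, 2] == sample_shape + batch + event.
  ```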

  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               distribution,
               bijector=None,
               batch_shape=None,
               event_shape=None,
               validate_args=False,
               name=None):
    """Construct a Transformed Distribution.

    Args:
      distribution: The base distribution instance to transform. Typically an
        instance of `Distribution`.
      bijector: The object responsible for calculating the transformation.
        Typically an instance of `Bijector`. `None` means `Identity()`.
      batch_shape: `integer` vector `Tensor` which overrides `distribution`
        `batch_shape`; valid only if `distribution.is_scalar_batch()`.
      event_shape: `integer` vector `Tensor` which overrides `distribution`
        `event_shape`; valid only if `distribution.is_scalar_event()`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      name: Python `str` name prefixed to Ops created by this class. Default:
        `bijector.name + distribution.name`.
    """
    parameters = dict(locals())
    name = name or (("" if bijector is None else bijector.name) +
                    distribution.name)
    with ops.name_scope(name, values=[event_shape, batch_shape]) as name:
      # For convenience we define some handy constants.
      self._zero = constant_op.constant(0, dtype=dtypes.int32, name="zero")
      self._empty = constant_op.constant([], dtype=dtypes.int32, name="empty")

      if bijector is None:
        bijector = identity_bijector.Identity(validate_args=validate_args)

      # We will keep track of a static and dynamic version of
      # self._is_{batch,event}_override. This way we can do more prior to graph
      # execution, including possibly raising Python exceptions.

      self._override_batch_shape = self._maybe_validate_shape_override(
          batch_shape, distribution.is_scalar_batch(), validate_args,
          "batch_shape")
      self._is_batch_override = _logical_not(_logical_equal(
          _ndims_from_shape(self._override_batch_shape), self._zero))
      self._is_maybe_batch_override = bool(
          tensor_util.constant_value(self._override_batch_shape) is None or
          tensor_util.constant_value(self._override_batch_shape).size != 0)

      self._override_event_shape = self._maybe_validate_shape_override(
          event_shape, distribution.is_scalar_event(), validate_args,
          "event_shape")
      self._is_event_override = _logical_not(_logical_equal(
          _ndims_from_shape(self._override_event_shape), self._zero))
      self._is_maybe_event_override = bool(
          tensor_util.constant_value(self._override_event_shape) is None or
          tensor_util.constant_value(self._override_event_shape).size != 0)

      # To convert a scalar distribution into a multivariate distribution we
      # will draw dims from the sample dims, which are otherwise iid. This is
      # easy to do except in the case that the base distribution has batch dims
      # and we're overriding event shape. When that case happens the event dims
      # will incorrectly be to the left of the batch dims. In this case we'll
      # cyclically permute the new dims to the left.
      self._needs_rotation = _logical_and(
          self._is_event_override,
          _logical_not(self._is_batch_override),
          _logical_not(distribution.is_scalar_batch()))
      override_event_ndims = _ndims_from_shape(self._override_event_shape)
      self._rotate_ndims = _pick_scalar_condition(
          self._needs_rotation, override_event_ndims, 0)
      # We'll be reducing the head dims (if at all), i.e., this will be []
      # if we don't need to reduce.
      self._reduce_event_indices = math_ops.range(
          self._rotate_ndims - override_event_ndims, self._rotate_ndims)
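      # E.g. (illustrative): with a k-dim event-shape override, these indices
      # are range(0, k) when rotation is needed (the new event dims get rotated
      # to the front before reduction) and range(-k, 0) otherwise (the new
      # event dims sit rightmost).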

    self._distribution = distribution
    self._bijector = bijector
    super(TransformedDistribution, self).__init__(
        dtype=self._distribution.dtype,
        reparameterization_type=self._distribution.reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=self._distribution.allow_nan_stats,
        parameters=parameters,
        # We let TransformedDistribution access _graph_parents since this class
        # is more like a baseclass than derived.
        graph_parents=(distribution._graph_parents +  # pylint: disable=protected-access
                       bijector.graph_parents),
        name=name)

  @property
  def distribution(self):
    """Base distribution, p(x)."""
    return self._distribution

  @property
  def bijector(self):
    """Function transforming x => y."""
    return self._bijector

  def _event_shape_tensor(self):
    return self.bijector.forward_event_shape_tensor(
        distribution_util.pick_vector(
            self._is_event_override,
            self._override_event_shape,
            self.distribution.event_shape_tensor()))

  def _event_shape(self):
    # If there's a chance that the event_shape has been overridden, we return
    # what we statically know about the `event_shape_override`. This works
    # because `_is_maybe_event_override` means `static_override` is `None` or a
    # non-empty list, i.e., either we can't statically rule out an override or
    # we know one was given.
    #
    # Since the `bijector` may change the `event_shape`, we then forward what we
    # know to the bijector. This allows the `bijector` to have final say in the
    # `event_shape`.
    static_override = tensor_util.constant_value_as_shape(
        self._override_event_shape)
    return self.bijector.forward_event_shape(
        static_override
        if self._is_maybe_event_override
        else self.distribution.event_shape)

  def _batch_shape_tensor(self):
    return distribution_util.pick_vector(
        self._is_batch_override,
        self._override_batch_shape,
        self.distribution.batch_shape_tensor())

  def _batch_shape(self):
    # If there's a chance that the batch_shape has been overridden, we return
    # what we statically know about the `batch_shape_override`. This works
    # because `_is_maybe_batch_override` means `static_override` is `None` or a
    # non-empty list, i.e., either we can't statically rule out an override or
    # we know one was given.
    #
    # Notice that this implementation parallels the `_event_shape` except that
    # the `bijector` doesn't get to alter the `batch_shape`. Recall that
    # `batch_shape` is a property of a distribution while `event_shape` is
    # shared between both the `distribution` instance and the `bijector`.
    static_override = tensor_util.constant_value_as_shape(
        self._override_batch_shape)
    return (static_override
            if self._is_maybe_batch_override
            else self.distribution.batch_shape)

  def _sample_n(self, n, seed=None):
    sample_shape = _concat_vectors(
        distribution_util.pick_vector(self._needs_rotation, self._empty, [n]),
        self._override_batch_shape,
        self._override_event_shape,
        distribution_util.pick_vector(self._needs_rotation, [n], self._empty))
    x = self.distribution.sample(sample_shape=sample_shape, seed=seed)
    x = self._maybe_rotate_dims(x)
    # We'll apply the bijector in the `_call_sample_n` function.
    return x

  def _call_sample_n(self, sample_shape, seed, name, **kwargs):
    # We override `_call_sample_n` rather than `_sample_n` so we can ensure that
    # the result of `self.bijector.forward` is not modified (and thus caching
    # works).
    with self._name_scope(name, values=[sample_shape]):
      sample_shape = ops.convert_to_tensor(
          sample_shape, dtype=dtypes.int32, name="sample_shape")
      sample_shape, n = self._expand_sample_shape_to_vector(
          sample_shape, "sample_shape")

      # First, generate samples. We will possibly generate extra samples in the
      # event that we need to reinterpret the samples as part of the
      # event_shape.
      x = self._sample_n(n, seed, **kwargs)

      # Next, we reshape `x` into its final form. We do this prior to the call
      # to the bijector to ensure that the bijector caching works.
      batch_event_shape = array_ops.shape(x)[1:]
      final_shape = array_ops.concat([sample_shape, batch_event_shape], 0)
      x = array_ops.reshape(x, final_shape)

      # Finally, we apply the bijector's forward transformation. For caching to
      # work, it is imperative that this is the last modification to the
      # returned result.
      y = self.bijector.forward(x, **kwargs)
      y = self._set_sample_static_shape(y, sample_shape)

      return y

  def _log_prob(self, y):
    # For caching to work, it is imperative that the bijector is the first to
    # modify the input.
    x = self.bijector.inverse(y)
    event_ndims = self._maybe_get_static_event_ndims()

    ildj = self.bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims)
    if self.bijector._is_injective:  # pylint: disable=protected-access
      return self._finish_log_prob_for_one_fiber(y, x, ildj, event_ndims)

    lp_on_fibers = [
        self._finish_log_prob_for_one_fiber(y, x_i, ildj_i, event_ndims)
        for x_i, ildj_i in zip(x, ildj)]
    return math_ops.reduce_logsumexp(array_ops.stack(lp_on_fibers), axis=0)

  def _finish_log_prob_for_one_fiber(self, y, x, ildj, event_ndims):
    """Finish computation of log_prob on one element of the inverse image."""
    x = self._maybe_rotate_dims(x, rotate_right=True)
    log_prob = self.distribution.log_prob(x)
    if self._is_maybe_event_override:
      log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices)
    log_prob += math_ops.cast(ildj, log_prob.dtype)
    if self._is_maybe_event_override and isinstance(event_ndims, int):
      log_prob.set_shape(
          array_ops.broadcast_static_shape(
              y.get_shape().with_rank_at_least(1)[:-event_ndims],
              self.batch_shape))
    return log_prob

  def _prob(self, y):
    x = self.bijector.inverse(y)
    event_ndims = self._maybe_get_static_event_ndims()
    ildj = self.bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims)
    if self.bijector._is_injective:  # pylint: disable=protected-access
      return self._finish_prob_for_one_fiber(y, x, ildj, event_ndims)

    prob_on_fibers = [
        self._finish_prob_for_one_fiber(y, x_i, ildj_i, event_ndims)
        for x_i, ildj_i in zip(x, ildj)]
    return sum(prob_on_fibers)

  def _finish_prob_for_one_fiber(self, y, x, ildj, event_ndims):
    """Finish computation of prob on one element of the inverse image."""
    x = self._maybe_rotate_dims(x, rotate_right=True)
    prob = self.distribution.prob(x)
    if self._is_maybe_event_override:
      prob = math_ops.reduce_prod(prob, self._reduce_event_indices)
    prob *= math_ops.exp(math_ops.cast(ildj, prob.dtype))
    if self._is_maybe_event_override and isinstance(event_ndims, int):
      prob.set_shape(
          array_ops.broadcast_static_shape(
              y.get_shape().with_rank_at_least(1)[:-event_ndims],
              self.batch_shape))
    return prob

  def _log_cdf(self, y):
    if self._is_maybe_event_override:
      raise NotImplementedError("log_cdf is not implemented when overriding "
                                "event_shape")
    if not self.bijector._is_injective:  # pylint: disable=protected-access
      raise NotImplementedError("log_cdf is not implemented when "
                                "bijector is not injective.")
    x = self.bijector.inverse(y)
    return self.distribution.log_cdf(x)

  def _cdf(self, y):
    if self._is_maybe_event_override:
      raise NotImplementedError("cdf is not implemented when overriding "
                                "event_shape")
    if not self.bijector._is_injective:  # pylint: disable=protected-access
      raise NotImplementedError("cdf is not implemented when "
                                "bijector is not injective.")
    x = self.bijector.inverse(y)
    return self.distribution.cdf(x)

  def _log_survival_function(self, y):
    if self._is_maybe_event_override:
      raise NotImplementedError("log_survival_function is not implemented when "
                                "overriding event_shape")
    if not self.bijector._is_injective:  # pylint: disable=protected-access
      raise NotImplementedError("log_survival_function is not implemented when "
                                "bijector is not injective.")
    x = self.bijector.inverse(y)
    return self.distribution.log_survival_function(x)

  def _survival_function(self, y):
    if self._is_maybe_event_override:
      raise NotImplementedError("survival_function is not implemented when "
                                "overriding event_shape")
    if not self.bijector._is_injective:  # pylint: disable=protected-access
      raise NotImplementedError("survival_function is not implemented when "
                                "bijector is not injective.")
    x = self.bijector.inverse(y)
    return self.distribution.survival_function(x)

  def _quantile(self, value):
    if self._is_maybe_event_override:
      raise NotImplementedError("quantile is not implemented when overriding "
                                "event_shape")
    if not self.bijector._is_injective:  # pylint: disable=protected-access
      raise NotImplementedError("quantile is not implemented when "
                                "bijector is not injective.")
    # x_q is the "qth quantile" of X iff q = P[X <= x_q].  Now, since X =
    # g^{-1}(Y), q = P[X <= x_q] = P[g^{-1}(Y) <= x_q] = P[Y <= g(x_q)],
    # which implies the qth quantile of Y is g(x_q).
    inv_cdf = self.distribution.quantile(value)
    return self.bijector.forward(inv_cdf)

  def _entropy(self):
    if not self.bijector.is_constant_jacobian:
      raise NotImplementedError("entropy is not implemented")
    if not self.bijector._is_injective:  # pylint: disable=protected-access
      raise NotImplementedError("entropy is not implemented when "
                                "bijector is not injective.")
    # Suppose Y = g(X) where g is a diffeomorphism and X is a continuous rv. It
    # can be shown that:
    #   H[Y] = H[X] + E_X[(log o abs o det o J o g)(X)].
    # If is_constant_jacobian then:
    #   E_X[(log o abs o det o J o g)(X)] = (log o abs o det o J o g)(c)
    # where c can be anything.
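    # E.g. (illustrative): for Y = 2 * X + 1, a constant-Jacobian Affine
    # transform, H[Y] = H[X] + log(2); below this surfaces as subtracting the
    # inverse Jacobian log-determinant (-log(2)) evaluated at a dummy point.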
    entropy = self.distribution.entropy()
    if self._is_maybe_event_override:
      # H[X] = sum_i H[X_i] if X_i are mutually independent.
      # This means that a reduce_sum is a simple rescaling.
      entropy *= math_ops.cast(math_ops.reduce_prod(self._override_event_shape),
                               dtype=entropy.dtype.base_dtype)
    if self._is_maybe_batch_override:
      new_shape = array_ops.concat([
          _ones_like(self._override_batch_shape),
          self.distribution.batch_shape_tensor()
      ], 0)
      entropy = array_ops.reshape(entropy, new_shape)
      multiples = array_ops.concat([
          self._override_batch_shape,
          _ones_like(self.distribution.batch_shape_tensor())
      ], 0)
      entropy = array_ops.tile(entropy, multiples)
    dummy = array_ops.zeros(
        shape=array_ops.concat(
            [self.batch_shape_tensor(), self.event_shape_tensor()],
            0),
        dtype=self.dtype)
    event_ndims = (self.event_shape.ndims if self.event_shape.ndims is not None
                   else array_ops.size(self.event_shape_tensor()))
    ildj = self.bijector.inverse_log_det_jacobian(
        dummy, event_ndims=event_ndims)

    entropy -= math_ops.cast(ildj, entropy.dtype)
    entropy.set_shape(self.batch_shape)
    return entropy

  def _maybe_validate_shape_override(self, override_shape, base_is_scalar,
                                     validate_args, name):
    """Helper to __init__ which ensures override batch/event_shape are valid."""
    if override_shape is None:
      override_shape = []

    override_shape = ops.convert_to_tensor(override_shape, dtype=dtypes.int32,
                                           name=name)

    if not override_shape.dtype.is_integer:
      raise TypeError("shape override must be an integer")

    override_is_scalar = _is_scalar_from_shape(override_shape)
    if tensor_util.constant_value(override_is_scalar):
      return self._empty

    dynamic_assertions = []

    if override_shape.get_shape().ndims is not None:
      if override_shape.get_shape().ndims != 1:
        raise ValueError("shape override must be a vector")
    elif validate_args:
      dynamic_assertions += [check_ops.assert_rank(
          override_shape, 1,
          message="shape override must be a vector")]

    if tensor_util.constant_value(override_shape) is not None:
      if any(s <= 0 for s in tensor_util.constant_value(override_shape)):
        raise ValueError("shape override must have positive elements")
    elif validate_args:
      dynamic_assertions += [check_ops.assert_positive(
          override_shape,
          message="shape override must have positive elements")]

    is_both_nonscalar = _logical_and(_logical_not(base_is_scalar),
                                     _logical_not(override_is_scalar))
    if tensor_util.constant_value(is_both_nonscalar) is not None:
      if tensor_util.constant_value(is_both_nonscalar):
        raise ValueError("base distribution not scalar")
    elif validate_args:
      dynamic_assertions += [check_ops.assert_equal(
          is_both_nonscalar, False,
          message="base distribution not scalar")]

    if not dynamic_assertions:
      return override_shape
    return control_flow_ops.with_dependencies(
        dynamic_assertions, override_shape)

  def _maybe_rotate_dims(self, x, rotate_right=False):
    """Helper which rolls the event dims left, or right if `rotate_right`."""
    needs_rotation_const = tensor_util.constant_value(self._needs_rotation)
    if needs_rotation_const is not None and not needs_rotation_const:
      return x
    ndims = array_ops.rank(x)
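    # E.g. (illustrative): with rank(x) == 3 and self._rotate_ndims == 1, the
    # left rotation uses perm [1, 2, 0] and the right rotation uses [2, 0, 1].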
    n = (ndims - self._rotate_ndims) if rotate_right else self._rotate_ndims
    return array_ops.transpose(
        x, _concat_vectors(math_ops.range(n, ndims), math_ops.range(0, n)))

  def _maybe_get_static_event_ndims(self):
    if self.event_shape.ndims is not None:
      return self.event_shape.ndims

    event_ndims = array_ops.size(self.event_shape_tensor())
    event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)

    if event_ndims_ is not None:
      return event_ndims_

    return event_ndims