1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A Transformed Distribution class.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21 22from tensorflow.python.framework import constant_op 23from tensorflow.python.framework import dtypes 24from tensorflow.python.framework import ops 25from tensorflow.python.framework import tensor_util 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import check_ops 28from tensorflow.python.ops import control_flow_ops 29from tensorflow.python.ops import math_ops 30from tensorflow.python.ops.distributions import distribution as distribution_lib 31from tensorflow.python.ops.distributions import identity_bijector 32from tensorflow.python.ops.distributions import util as distribution_util 33from tensorflow.python.util import deprecation 34 35__all__ = [ 36 "TransformedDistribution", 37] 38 39 40# The following helper functions attempt to statically perform a TF operation. 41# These functions make debugging easier since we can do more validation during 42# graph construction. 43 44 45def _static_value(x): 46 """Returns the static value of a `Tensor` or `None`.""" 47 return tensor_util.constant_value(ops.convert_to_tensor(x)) 48 49 50def _logical_and(*args): 51 """Convenience function which attempts to statically `reduce_all`.""" 52 args_ = [_static_value(x) for x in args] 53 if any(x is not None and not bool(x) for x in args_): 54 return constant_op.constant(False) 55 if all(x is not None and bool(x) for x in args_): 56 return constant_op.constant(True) 57 if len(args) == 2: 58 return math_ops.logical_and(*args) 59 return math_ops.reduce_all(args) 60 61 62def _logical_equal(x, y): 63 """Convenience function which attempts to statically compute `x == y`.""" 64 x_ = _static_value(x) 65 y_ = _static_value(y) 66 if x_ is None or y_ is None: 67 return math_ops.equal(x, y) 68 return constant_op.constant(np.array_equal(x_, y_)) 69 70 71def _logical_not(x): 72 """Convenience function which attempts to statically apply `logical_not`.""" 73 x_ = _static_value(x) 74 if x_ is None: 75 return math_ops.logical_not(x) 76 return constant_op.constant(np.logical_not(x_)) 77 78 79def _concat_vectors(*args): 80 """Convenience function which concatenates input vectors.""" 81 args_ = [_static_value(x) for x in args] 82 if any(x_ is None for x_ in args_): 83 return array_ops.concat(args, 0) 84 return constant_op.constant([x_ for vec_ in args_ for x_ in vec_]) 85 86 87def _pick_scalar_condition(pred, cond_true, cond_false): 88 """Convenience function which chooses the condition based on the predicate.""" 89 # Note: This function is only valid if all of pred, cond_true, and cond_false 90 # are scalars. This means its semantics are arguably more like tf.cond than 91 # tf.select even though we use tf.select to implement it. 92 pred_ = _static_value(pred) 93 if pred_ is None: 94 return array_ops.where(pred, cond_true, cond_false) 95 return cond_true if pred_ else cond_false 96 97 98def _ones_like(x): 99 """Convenience function attempts to statically construct `ones_like`.""" 100 # Should only be used for small vectors. 101 if x.get_shape().is_fully_defined(): 102 return array_ops.ones(x.get_shape().as_list(), dtype=x.dtype) 103 return array_ops.ones_like(x) 104 105 106def _ndims_from_shape(shape): 107 """Returns `Tensor`'s `rank` implied by a `Tensor` shape.""" 108 if shape.get_shape().ndims not in (None, 1): 109 raise ValueError("input is not a valid shape: not 1D") 110 if not shape.dtype.is_integer: 111 raise TypeError("input is not a valid shape: wrong dtype") 112 if shape.get_shape().is_fully_defined(): 113 return constant_op.constant(shape.get_shape().as_list()[0]) 114 return array_ops.shape(shape)[0] 115 116 117def _is_scalar_from_shape(shape): 118 """Returns `True` `Tensor` if `Tensor` shape implies a scalar.""" 119 return _logical_equal(_ndims_from_shape(shape), 0) 120 121 122class TransformedDistribution(distribution_lib.Distribution): 123 """A Transformed Distribution. 124 125 A `TransformedDistribution` models `p(y)` given a base distribution `p(x)`, 126 and a deterministic, invertible, differentiable transform, `Y = g(X)`. The 127 transform is typically an instance of the `Bijector` class and the base 128 distribution is typically an instance of the `Distribution` class. 129 130 A `Bijector` is expected to implement the following functions: 131 - `forward`, 132 - `inverse`, 133 - `inverse_log_det_jacobian`. 134 The semantics of these functions are outlined in the `Bijector` documentation. 135 136 We now describe how a `TransformedDistribution` alters the input/outputs of a 137 `Distribution` associated with a random variable (rv) `X`. 138 139 Write `cdf(Y=y)` for an absolutely continuous cumulative distribution function 140 of random variable `Y`; write the probability density function `pdf(Y=y) := 141 d^k / (dy_1,...,dy_k) cdf(Y=y)` for its derivative wrt to `Y` evaluated at 142 `y`. Assume that `Y = g(X)` where `g` is a deterministic diffeomorphism, 143 i.e., a non-random, continuous, differentiable, and invertible function. 144 Write the inverse of `g` as `X = g^{-1}(Y)` and `(J o g)(x)` for the Jacobian 145 of `g` evaluated at `x`. 146 147 A `TransformedDistribution` implements the following operations: 148 149 * `sample` 150 Mathematically: `Y = g(X)` 151 Programmatically: `bijector.forward(distribution.sample(...))` 152 153 * `log_prob` 154 Mathematically: `(log o pdf)(Y=y) = (log o pdf o g^{-1})(y) 155 + (log o abs o det o J o g^{-1})(y)` 156 Programmatically: `(distribution.log_prob(bijector.inverse(y)) 157 + bijector.inverse_log_det_jacobian(y))` 158 159 * `log_cdf` 160 Mathematically: `(log o cdf)(Y=y) = (log o cdf o g^{-1})(y)` 161 Programmatically: `distribution.log_cdf(bijector.inverse(x))` 162 163 * and similarly for: `cdf`, `prob`, `log_survival_function`, 164 `survival_function`. 165 166 A simple example constructing a Log-Normal distribution from a Normal 167 distribution: 168 169 ```python 170 ds = tfp.distributions 171 log_normal = ds.TransformedDistribution( 172 distribution=ds.Normal(loc=0., scale=1.), 173 bijector=ds.bijectors.Exp(), 174 name="LogNormalTransformedDistribution") 175 ``` 176 177 A `LogNormal` made from callables: 178 179 ```python 180 ds = tfp.distributions 181 log_normal = ds.TransformedDistribution( 182 distribution=ds.Normal(loc=0., scale=1.), 183 bijector=ds.bijectors.Inline( 184 forward_fn=tf.exp, 185 inverse_fn=tf.log, 186 inverse_log_det_jacobian_fn=( 187 lambda y: -tf.reduce_sum(tf.log(y), axis=-1)), 188 name="LogNormalTransformedDistribution") 189 ``` 190 191 Another example constructing a Normal from a StandardNormal: 192 193 ```python 194 ds = tfp.distributions 195 normal = ds.TransformedDistribution( 196 distribution=ds.Normal(loc=0., scale=1.), 197 bijector=ds.bijectors.Affine( 198 shift=-1., 199 scale_identity_multiplier=2.) 200 name="NormalTransformedDistribution") 201 ``` 202 203 A `TransformedDistribution`'s batch- and event-shape are implied by the base 204 distribution unless explicitly overridden by `batch_shape` or `event_shape` 205 arguments. Specifying an overriding `batch_shape` (`event_shape`) is 206 permitted only if the base distribution has scalar batch-shape (event-shape). 207 The bijector is applied to the distribution as if the distribution possessed 208 the overridden shape(s). The following example demonstrates how to construct a 209 multivariate Normal as a `TransformedDistribution`. 210 211 ```python 212 ds = tfp.distributions 213 # We will create two MVNs with batch_shape = event_shape = 2. 214 mean = [[-1., 0], # batch:0 215 [0., 1]] # batch:1 216 chol_cov = [[[1., 0], 217 [0, 1]], # batch:0 218 [[1, 0], 219 [2, 2]]] # batch:1 220 mvn1 = ds.TransformedDistribution( 221 distribution=ds.Normal(loc=0., scale=1.), 222 bijector=ds.bijectors.Affine(shift=mean, scale_tril=chol_cov), 223 batch_shape=[2], # Valid because base_distribution.batch_shape == []. 224 event_shape=[2]) # Valid because base_distribution.event_shape == []. 225 mvn2 = ds.MultivariateNormalTriL(loc=mean, scale_tril=chol_cov) 226 # mvn1.log_prob(x) == mvn2.log_prob(x) 227 ``` 228 229 """ 230 231 @deprecation.deprecated( 232 "2019-01-01", 233 "The TensorFlow Distributions library has moved to " 234 "TensorFlow Probability " 235 "(https://github.com/tensorflow/probability). You " 236 "should update all references to use `tfp.distributions` " 237 "instead of `tf.distributions`.", 238 warn_once=True) 239 def __init__(self, 240 distribution, 241 bijector=None, 242 batch_shape=None, 243 event_shape=None, 244 validate_args=False, 245 name=None): 246 """Construct a Transformed Distribution. 247 248 Args: 249 distribution: The base distribution instance to transform. Typically an 250 instance of `Distribution`. 251 bijector: The object responsible for calculating the transformation. 252 Typically an instance of `Bijector`. `None` means `Identity()`. 253 batch_shape: `integer` vector `Tensor` which overrides `distribution` 254 `batch_shape`; valid only if `distribution.is_scalar_batch()`. 255 event_shape: `integer` vector `Tensor` which overrides `distribution` 256 `event_shape`; valid only if `distribution.is_scalar_event()`. 257 validate_args: Python `bool`, default `False`. When `True` distribution 258 parameters are checked for validity despite possibly degrading runtime 259 performance. When `False` invalid inputs may silently render incorrect 260 outputs. 261 name: Python `str` name prefixed to Ops created by this class. Default: 262 `bijector.name + distribution.name`. 263 """ 264 parameters = dict(locals()) 265 name = name or (("" if bijector is None else bijector.name) + 266 distribution.name) 267 with ops.name_scope(name, values=[event_shape, batch_shape]) as name: 268 # For convenience we define some handy constants. 269 self._zero = constant_op.constant(0, dtype=dtypes.int32, name="zero") 270 self._empty = constant_op.constant([], dtype=dtypes.int32, name="empty") 271 272 if bijector is None: 273 bijector = identity_bijector.Identity(validate_args=validate_args) 274 275 # We will keep track of a static and dynamic version of 276 # self._is_{batch,event}_override. This way we can do more prior to graph 277 # execution, including possibly raising Python exceptions. 278 279 self._override_batch_shape = self._maybe_validate_shape_override( 280 batch_shape, distribution.is_scalar_batch(), validate_args, 281 "batch_shape") 282 self._is_batch_override = _logical_not(_logical_equal( 283 _ndims_from_shape(self._override_batch_shape), self._zero)) 284 self._is_maybe_batch_override = bool( 285 tensor_util.constant_value(self._override_batch_shape) is None or 286 tensor_util.constant_value(self._override_batch_shape).size != 0) 287 288 self._override_event_shape = self._maybe_validate_shape_override( 289 event_shape, distribution.is_scalar_event(), validate_args, 290 "event_shape") 291 self._is_event_override = _logical_not(_logical_equal( 292 _ndims_from_shape(self._override_event_shape), self._zero)) 293 self._is_maybe_event_override = bool( 294 tensor_util.constant_value(self._override_event_shape) is None or 295 tensor_util.constant_value(self._override_event_shape).size != 0) 296 297 # To convert a scalar distribution into a multivariate distribution we 298 # will draw dims from the sample dims, which are otherwise iid. This is 299 # easy to do except in the case that the base distribution has batch dims 300 # and we're overriding event shape. When that case happens the event dims 301 # will incorrectly be to the left of the batch dims. In this case we'll 302 # cyclically permute left the new dims. 303 self._needs_rotation = _logical_and( 304 self._is_event_override, 305 _logical_not(self._is_batch_override), 306 _logical_not(distribution.is_scalar_batch())) 307 override_event_ndims = _ndims_from_shape(self._override_event_shape) 308 self._rotate_ndims = _pick_scalar_condition( 309 self._needs_rotation, override_event_ndims, 0) 310 # We'll be reducing the head dims (if at all), i.e., this will be [] 311 # if we don't need to reduce. 312 self._reduce_event_indices = math_ops.range( 313 self._rotate_ndims - override_event_ndims, self._rotate_ndims) 314 315 self._distribution = distribution 316 self._bijector = bijector 317 super(TransformedDistribution, self).__init__( 318 dtype=self._distribution.dtype, 319 reparameterization_type=self._distribution.reparameterization_type, 320 validate_args=validate_args, 321 allow_nan_stats=self._distribution.allow_nan_stats, 322 parameters=parameters, 323 # We let TransformedDistribution access _graph_parents since this class 324 # is more like a baseclass than derived. 325 graph_parents=(distribution._graph_parents + # pylint: disable=protected-access 326 bijector.graph_parents), 327 name=name) 328 329 @property 330 def distribution(self): 331 """Base distribution, p(x).""" 332 return self._distribution 333 334 @property 335 def bijector(self): 336 """Function transforming x => y.""" 337 return self._bijector 338 339 def _event_shape_tensor(self): 340 return self.bijector.forward_event_shape_tensor( 341 distribution_util.pick_vector( 342 self._is_event_override, 343 self._override_event_shape, 344 self.distribution.event_shape_tensor())) 345 346 def _event_shape(self): 347 # If there's a chance that the event_shape has been overridden, we return 348 # what we statically know about the `event_shape_override`. This works 349 # because: `_is_maybe_event_override` means `static_override` is `None` or a 350 # non-empty list, i.e., we don't statically know the `event_shape` or we do. 351 # 352 # Since the `bijector` may change the `event_shape`, we then forward what we 353 # know to the bijector. This allows the `bijector` to have final say in the 354 # `event_shape`. 355 static_override = tensor_util.constant_value_as_shape( 356 self._override_event_shape) 357 return self.bijector.forward_event_shape( 358 static_override 359 if self._is_maybe_event_override 360 else self.distribution.event_shape) 361 362 def _batch_shape_tensor(self): 363 return distribution_util.pick_vector( 364 self._is_batch_override, 365 self._override_batch_shape, 366 self.distribution.batch_shape_tensor()) 367 368 def _batch_shape(self): 369 # If there's a chance that the batch_shape has been overridden, we return 370 # what we statically know about the `batch_shape_override`. This works 371 # because: `_is_maybe_batch_override` means `static_override` is `None` or a 372 # non-empty list, i.e., we don't statically know the `batch_shape` or we do. 373 # 374 # Notice that this implementation parallels the `_event_shape` except that 375 # the `bijector` doesn't get to alter the `batch_shape`. Recall that 376 # `batch_shape` is a property of a distribution while `event_shape` is 377 # shared between both the `distribution` instance and the `bijector`. 378 static_override = tensor_util.constant_value_as_shape( 379 self._override_batch_shape) 380 return (static_override 381 if self._is_maybe_batch_override 382 else self.distribution.batch_shape) 383 384 def _sample_n(self, n, seed=None): 385 sample_shape = _concat_vectors( 386 distribution_util.pick_vector(self._needs_rotation, self._empty, [n]), 387 self._override_batch_shape, 388 self._override_event_shape, 389 distribution_util.pick_vector(self._needs_rotation, [n], self._empty)) 390 x = self.distribution.sample(sample_shape=sample_shape, seed=seed) 391 x = self._maybe_rotate_dims(x) 392 # We'll apply the bijector in the `_call_sample_n` function. 393 return x 394 395 def _call_sample_n(self, sample_shape, seed, name, **kwargs): 396 # We override `_call_sample_n` rather than `_sample_n` so we can ensure that 397 # the result of `self.bijector.forward` is not modified (and thus caching 398 # works). 399 with self._name_scope(name, values=[sample_shape]): 400 sample_shape = ops.convert_to_tensor( 401 sample_shape, dtype=dtypes.int32, name="sample_shape") 402 sample_shape, n = self._expand_sample_shape_to_vector( 403 sample_shape, "sample_shape") 404 405 # First, generate samples. We will possibly generate extra samples in the 406 # event that we need to reinterpret the samples as part of the 407 # event_shape. 408 x = self._sample_n(n, seed, **kwargs) 409 410 # Next, we reshape `x` into its final form. We do this prior to the call 411 # to the bijector to ensure that the bijector caching works. 412 batch_event_shape = array_ops.shape(x)[1:] 413 final_shape = array_ops.concat([sample_shape, batch_event_shape], 0) 414 x = array_ops.reshape(x, final_shape) 415 416 # Finally, we apply the bijector's forward transformation. For caching to 417 # work, it is imperative that this is the last modification to the 418 # returned result. 419 y = self.bijector.forward(x, **kwargs) 420 y = self._set_sample_static_shape(y, sample_shape) 421 422 return y 423 424 def _log_prob(self, y): 425 # For caching to work, it is imperative that the bijector is the first to 426 # modify the input. 427 x = self.bijector.inverse(y) 428 event_ndims = self._maybe_get_static_event_ndims() 429 430 ildj = self.bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims) 431 if self.bijector._is_injective: # pylint: disable=protected-access 432 return self._finish_log_prob_for_one_fiber(y, x, ildj, event_ndims) 433 434 lp_on_fibers = [ 435 self._finish_log_prob_for_one_fiber(y, x_i, ildj_i, event_ndims) 436 for x_i, ildj_i in zip(x, ildj)] 437 return math_ops.reduce_logsumexp(array_ops.stack(lp_on_fibers), axis=0) 438 439 def _finish_log_prob_for_one_fiber(self, y, x, ildj, event_ndims): 440 """Finish computation of log_prob on one element of the inverse image.""" 441 x = self._maybe_rotate_dims(x, rotate_right=True) 442 log_prob = self.distribution.log_prob(x) 443 if self._is_maybe_event_override: 444 log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices) 445 log_prob += math_ops.cast(ildj, log_prob.dtype) 446 if self._is_maybe_event_override and isinstance(event_ndims, int): 447 log_prob.set_shape( 448 array_ops.broadcast_static_shape( 449 y.get_shape().with_rank_at_least(1)[:-event_ndims], 450 self.batch_shape)) 451 return log_prob 452 453 def _prob(self, y): 454 x = self.bijector.inverse(y) 455 event_ndims = self._maybe_get_static_event_ndims() 456 ildj = self.bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims) 457 if self.bijector._is_injective: # pylint: disable=protected-access 458 return self._finish_prob_for_one_fiber(y, x, ildj, event_ndims) 459 460 prob_on_fibers = [ 461 self._finish_prob_for_one_fiber(y, x_i, ildj_i, event_ndims) 462 for x_i, ildj_i in zip(x, ildj)] 463 return sum(prob_on_fibers) 464 465 def _finish_prob_for_one_fiber(self, y, x, ildj, event_ndims): 466 """Finish computation of prob on one element of the inverse image.""" 467 x = self._maybe_rotate_dims(x, rotate_right=True) 468 prob = self.distribution.prob(x) 469 if self._is_maybe_event_override: 470 prob = math_ops.reduce_prod(prob, self._reduce_event_indices) 471 prob *= math_ops.exp(math_ops.cast(ildj, prob.dtype)) 472 if self._is_maybe_event_override and isinstance(event_ndims, int): 473 prob.set_shape( 474 array_ops.broadcast_static_shape( 475 y.get_shape().with_rank_at_least(1)[:-event_ndims], 476 self.batch_shape)) 477 return prob 478 479 def _log_cdf(self, y): 480 if self._is_maybe_event_override: 481 raise NotImplementedError("log_cdf is not implemented when overriding " 482 "event_shape") 483 if not self.bijector._is_injective: # pylint: disable=protected-access 484 raise NotImplementedError("log_cdf is not implemented when " 485 "bijector is not injective.") 486 x = self.bijector.inverse(y) 487 return self.distribution.log_cdf(x) 488 489 def _cdf(self, y): 490 if self._is_maybe_event_override: 491 raise NotImplementedError("cdf is not implemented when overriding " 492 "event_shape") 493 if not self.bijector._is_injective: # pylint: disable=protected-access 494 raise NotImplementedError("cdf is not implemented when " 495 "bijector is not injective.") 496 x = self.bijector.inverse(y) 497 return self.distribution.cdf(x) 498 499 def _log_survival_function(self, y): 500 if self._is_maybe_event_override: 501 raise NotImplementedError("log_survival_function is not implemented when " 502 "overriding event_shape") 503 if not self.bijector._is_injective: # pylint: disable=protected-access 504 raise NotImplementedError("log_survival_function is not implemented when " 505 "bijector is not injective.") 506 x = self.bijector.inverse(y) 507 return self.distribution.log_survival_function(x) 508 509 def _survival_function(self, y): 510 if self._is_maybe_event_override: 511 raise NotImplementedError("survival_function is not implemented when " 512 "overriding event_shape") 513 if not self.bijector._is_injective: # pylint: disable=protected-access 514 raise NotImplementedError("survival_function is not implemented when " 515 "bijector is not injective.") 516 x = self.bijector.inverse(y) 517 return self.distribution.survival_function(x) 518 519 def _quantile(self, value): 520 if self._is_maybe_event_override: 521 raise NotImplementedError("quantile is not implemented when overriding " 522 "event_shape") 523 if not self.bijector._is_injective: # pylint: disable=protected-access 524 raise NotImplementedError("quantile is not implemented when " 525 "bijector is not injective.") 526 # x_q is the "qth quantile" of X iff q = P[X <= x_q]. Now, since X = 527 # g^{-1}(Y), q = P[X <= x_q] = P[g^{-1}(Y) <= x_q] = P[Y <= g(x_q)], 528 # implies the qth quantile of Y is g(x_q). 529 inv_cdf = self.distribution.quantile(value) 530 return self.bijector.forward(inv_cdf) 531 532 def _entropy(self): 533 if not self.bijector.is_constant_jacobian: 534 raise NotImplementedError("entropy is not implemented") 535 if not self.bijector._is_injective: # pylint: disable=protected-access 536 raise NotImplementedError("entropy is not implemented when " 537 "bijector is not injective.") 538 # Suppose Y = g(X) where g is a diffeomorphism and X is a continuous rv. It 539 # can be shown that: 540 # H[Y] = H[X] + E_X[(log o abs o det o J o g)(X)]. 541 # If is_constant_jacobian then: 542 # E_X[(log o abs o det o J o g)(X)] = (log o abs o det o J o g)(c) 543 # where c can by anything. 544 entropy = self.distribution.entropy() 545 if self._is_maybe_event_override: 546 # H[X] = sum_i H[X_i] if X_i are mutually independent. 547 # This means that a reduce_sum is a simple rescaling. 548 entropy *= math_ops.cast(math_ops.reduce_prod(self._override_event_shape), 549 dtype=entropy.dtype.base_dtype) 550 if self._is_maybe_batch_override: 551 new_shape = array_ops.concat([ 552 _ones_like(self._override_batch_shape), 553 self.distribution.batch_shape_tensor() 554 ], 0) 555 entropy = array_ops.reshape(entropy, new_shape) 556 multiples = array_ops.concat([ 557 self._override_batch_shape, 558 _ones_like(self.distribution.batch_shape_tensor()) 559 ], 0) 560 entropy = array_ops.tile(entropy, multiples) 561 dummy = array_ops.zeros( 562 shape=array_ops.concat( 563 [self.batch_shape_tensor(), self.event_shape_tensor()], 564 0), 565 dtype=self.dtype) 566 event_ndims = (self.event_shape.ndims if self.event_shape.ndims is not None 567 else array_ops.size(self.event_shape_tensor())) 568 ildj = self.bijector.inverse_log_det_jacobian( 569 dummy, event_ndims=event_ndims) 570 571 entropy -= math_ops.cast(ildj, entropy.dtype) 572 entropy.set_shape(self.batch_shape) 573 return entropy 574 575 def _maybe_validate_shape_override(self, override_shape, base_is_scalar, 576 validate_args, name): 577 """Helper to __init__ which ensures override batch/event_shape are valid.""" 578 if override_shape is None: 579 override_shape = [] 580 581 override_shape = ops.convert_to_tensor(override_shape, dtype=dtypes.int32, 582 name=name) 583 584 if not override_shape.dtype.is_integer: 585 raise TypeError("shape override must be an integer") 586 587 override_is_scalar = _is_scalar_from_shape(override_shape) 588 if tensor_util.constant_value(override_is_scalar): 589 return self._empty 590 591 dynamic_assertions = [] 592 593 if override_shape.get_shape().ndims is not None: 594 if override_shape.get_shape().ndims != 1: 595 raise ValueError("shape override must be a vector") 596 elif validate_args: 597 dynamic_assertions += [check_ops.assert_rank( 598 override_shape, 1, 599 message="shape override must be a vector")] 600 601 if tensor_util.constant_value(override_shape) is not None: 602 if any(s <= 0 for s in tensor_util.constant_value(override_shape)): 603 raise ValueError("shape override must have positive elements") 604 elif validate_args: 605 dynamic_assertions += [check_ops.assert_positive( 606 override_shape, 607 message="shape override must have positive elements")] 608 609 is_both_nonscalar = _logical_and(_logical_not(base_is_scalar), 610 _logical_not(override_is_scalar)) 611 if tensor_util.constant_value(is_both_nonscalar) is not None: 612 if tensor_util.constant_value(is_both_nonscalar): 613 raise ValueError("base distribution not scalar") 614 elif validate_args: 615 dynamic_assertions += [check_ops.assert_equal( 616 is_both_nonscalar, False, 617 message="base distribution not scalar")] 618 619 if not dynamic_assertions: 620 return override_shape 621 return control_flow_ops.with_dependencies( 622 dynamic_assertions, override_shape) 623 624 def _maybe_rotate_dims(self, x, rotate_right=False): 625 """Helper which rolls left event_dims left or right event_dims right.""" 626 needs_rotation_const = tensor_util.constant_value(self._needs_rotation) 627 if needs_rotation_const is not None and not needs_rotation_const: 628 return x 629 ndims = array_ops.rank(x) 630 n = (ndims - self._rotate_ndims) if rotate_right else self._rotate_ndims 631 return array_ops.transpose( 632 x, _concat_vectors(math_ops.range(n, ndims), math_ops.range(0, n))) 633 634 def _maybe_get_static_event_ndims(self): 635 if self.event_shape.ndims is not None: 636 return self.event_shape.ndims 637 638 event_ndims = array_ops.size(self.event_shape_tensor()) 639 event_ndims_ = distribution_util.maybe_get_static_value(event_ndims) 640 641 if event_ndims_ is not None: 642 return event_ndims_ 643 644 return event_ndims 645