# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for probability distributions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg
from tensorflow.python.ops.distributions import distribution as distribution_lib

# The following two lines are redundant, in a sense. The first enables
# good coding practice *within* this file (`util.prefer_static_value`
# rather than `prefer_static_value`). The second ensures that users
# also get the core utils when they import this file.
from tensorflow.python.ops.distributions import util
from tensorflow.python.ops.distributions.util import *  # pylint: disable=wildcard-import


def _convert_to_tensor(x, name):
  return None if x is None else ops.convert_to_tensor(x, name=name)


def mixture_stddev(mixture_weight_vector, mean_vector, stddev_vector):
  """Computes the standard deviation of a mixture distribution.

  This function works regardless of the component distribution, so long as
  each component's mean and standard deviation can be provided.

  Args:
    mixture_weight_vector: A 2D tensor of mixture weights. Has shape
      `[batch_size, num_components]`.
    mean_vector: A 2D tensor of mixture component means. Has shape
      `[batch_size, num_components]`.
    stddev_vector: A 2D tensor of mixture component standard deviations. Has
      shape `[batch_size, num_components]`.

  Returns:
    A 1D tensor of shape `[batch_size]` representing the standard deviation of
    the mixture distribution with given weights and component means and
    standard deviations.

  Raises:
    ValueError: If the shapes of the input tensors are not as expected.
  """
  mixture_weight_vector.shape.assert_has_rank(2)
  if not mean_vector.shape.is_compatible_with(mixture_weight_vector.shape):
    raise ValueError("Expecting means to have same shape as mixture weights.")
  if not stddev_vector.shape.is_compatible_with(mixture_weight_vector.shape):
    raise ValueError("Expecting stddevs to have same shape as mixture weights.")

  # Reshape the distribution parameters for batched vectorized dot products.
  pi_for_dot_prod = array_ops.expand_dims(mixture_weight_vector, axis=1)
  mu_for_dot_prod = array_ops.expand_dims(mean_vector, axis=2)
  sigma_for_dot_prod = array_ops.expand_dims(stddev_vector, axis=2)

  # Weighted average of component means under the mixture distribution.
  mean_wa = math_ops.matmul(pi_for_dot_prod, mu_for_dot_prod)
  mean_wa = array_ops.reshape(mean_wa, (-1,))
  # Weighted average of component variances under the mixture distribution.
  var_wa = math_ops.matmul(pi_for_dot_prod,
                           math_ops.square(sigma_for_dot_prod))
  var_wa = array_ops.reshape(var_wa, (-1,))
  # Weighted average of component squared means under the mixture distribution.
  sq_mean_wa = math_ops.matmul(pi_for_dot_prod,
                               math_ops.square(mu_for_dot_prod))
  sq_mean_wa = array_ops.reshape(sq_mean_wa, (-1,))
  mixture_variance = var_wa + sq_mean_wa - math_ops.square(mean_wa)
  return math_ops.sqrt(mixture_variance)
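

# A minimal usage sketch for `mixture_stddev` (the example function and the
# values below are illustrative assumptions, not part of the library).
def _example_mixture_stddev():
  """Sketch: stddev of a two-component mixture for a single batch member."""
  weights = ops.convert_to_tensor([[0.25, 0.75]])  # [batch_size=1, 2 components]
  means = ops.convert_to_tensor([[-1., 1.]])
  stddevs = ops.convert_to_tensor([[0.5, 2.]])
  # Internally: Var = sum_i pi_i * (sigma_i^2 + mu_i^2) - (sum_i pi_i * mu_i)^2.
  return mixture_stddev(weights, means, stddevs)  # `Tensor` of shape [1].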


def make_tril_scale(
    loc=None,
    scale_tril=None,
    scale_diag=None,
    scale_identity_multiplier=None,
    shape_hint=None,
    validate_args=False,
    assert_positive=False,
    name=None):
  """Creates a LinOp representing a lower triangular matrix.

  Args:
    loc: Floating-point `Tensor`. This is used for inferring shape in the case
      where only `scale_identity_multiplier` is set.
    scale_tril: Floating-point `Tensor` representing the lower triangular
      matrix. `scale_tril` has shape [N1, N2, ... k, k], which represents a
      k x k lower triangular matrix.
      When `None` no `scale_tril` term is added to the LinOp.
      The upper triangular elements above the diagonal are ignored.
    scale_diag: Floating-point `Tensor` representing the diagonal matrix.
      `scale_diag` has shape [N1, N2, ... k], which represents a k x k
      diagonal matrix.
      When `None` no diagonal term is added to the LinOp.
    scale_identity_multiplier: floating point rank 0 `Tensor` representing a
      scaling done to the identity matrix.
      When `scale_identity_multiplier = scale_diag = scale_tril = None` then
      `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added
      to `scale`.
    shape_hint: scalar integer `Tensor` representing a hint at the dimension of
      the identity matrix when only `scale_identity_multiplier` is set.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    assert_positive: Python `bool` indicating whether LinOp should be checked
      for being positive definite.
    name: Python `str` name given to ops managed by this object.

  Returns:
    `LinearOperator` representing a lower triangular matrix.

  Raises:
    ValueError: If only `scale_identity_multiplier` is set and `loc` and
      `shape_hint` are both None.
  """

  def _maybe_attach_assertion(x):
    if not validate_args:
      return x
    if assert_positive:
      return control_flow_ops.with_dependencies([
          check_ops.assert_positive(
              array_ops.matrix_diag_part(x),
              message="diagonal part must be positive"),
      ], x)
    return control_flow_ops.with_dependencies([
        check_ops.assert_none_equal(
            array_ops.matrix_diag_part(x),
            array_ops.zeros([], x.dtype),
            message="diagonal part must be non-zero"),
    ], x)

  with ops.name_scope(name, "make_tril_scale",
                      values=[loc, scale_diag, scale_identity_multiplier]):

    loc = _convert_to_tensor(loc, name="loc")
    scale_tril = _convert_to_tensor(scale_tril, name="scale_tril")
    scale_diag = _convert_to_tensor(scale_diag, name="scale_diag")
    scale_identity_multiplier = _convert_to_tensor(
        scale_identity_multiplier,
        name="scale_identity_multiplier")

    if scale_tril is not None:
      scale_tril = array_ops.matrix_band_part(scale_tril, -1, 0)  # Zero out TriU.
      tril_diag = array_ops.matrix_diag_part(scale_tril)
      if scale_diag is not None:
        tril_diag += scale_diag
      if scale_identity_multiplier is not None:
        tril_diag += scale_identity_multiplier[..., array_ops.newaxis]

      scale_tril = array_ops.matrix_set_diag(scale_tril, tril_diag)

      return linalg.LinearOperatorLowerTriangular(
          tril=_maybe_attach_assertion(scale_tril),
          is_non_singular=True,
          is_self_adjoint=False,
          is_positive_definite=assert_positive)

    return make_diag_scale(
        loc=loc,
        scale_diag=scale_diag,
        scale_identity_multiplier=scale_identity_multiplier,
        shape_hint=shape_hint,
        validate_args=validate_args,
        assert_positive=assert_positive,
        name=name)
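

# A minimal usage sketch for `make_tril_scale` (the example function and the
# values below are illustrative assumptions, not part of the library).
def _example_make_tril_scale():
  """Sketch: build a 2 x 2 lower triangular scale operator."""
  scale_tril = ops.convert_to_tensor([[1., 0.], [2., 3.]])
  scale_diag = ops.convert_to_tensor([0.5, 0.5])
  # The upper triangle of `scale_tril` is zeroed out and `scale_diag` is added
  # to its diagonal, yielding tril = [[1.5, 0.], [2., 3.5]].
  return make_tril_scale(scale_tril=scale_tril, scale_diag=scale_diag)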


def make_diag_scale(
    loc=None,
    scale_diag=None,
    scale_identity_multiplier=None,
    shape_hint=None,
    validate_args=False,
    assert_positive=False,
    name=None):
  """Creates a LinOp representing a diagonal matrix.

  Args:
    loc: Floating-point `Tensor`. This is used for inferring shape in the case
      where only `scale_identity_multiplier` is set.
    scale_diag: Floating-point `Tensor` representing the diagonal matrix.
      `scale_diag` has shape [N1, N2, ... k], which represents a k x k
      diagonal matrix.
      When `None` no diagonal term is added to the LinOp.
    scale_identity_multiplier: floating point rank 0 `Tensor` representing a
      scaling done to the identity matrix.
      When `scale_identity_multiplier = scale_diag = None` then
      `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added
      to `scale`.
    shape_hint: scalar integer `Tensor` representing a hint at the dimension of
      the identity matrix when only `scale_identity_multiplier` is set.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    assert_positive: Python `bool` indicating whether LinOp should be checked
      for being positive definite.
    name: Python `str` name given to ops managed by this object.

  Returns:
    `LinearOperator` representing a diagonal matrix.

  Raises:
    ValueError: If only `scale_identity_multiplier` is set and `loc` and
      `shape_hint` are both None.
  """

  def _maybe_attach_assertion(x):
    if not validate_args:
      return x
    if assert_positive:
      return control_flow_ops.with_dependencies([
          check_ops.assert_positive(
              x, message="diagonal part must be positive"),
      ], x)
    return control_flow_ops.with_dependencies([
        check_ops.assert_none_equal(
            x,
            array_ops.zeros([], x.dtype),
            message="diagonal part must be non-zero")], x)

  with ops.name_scope(name, "make_diag_scale",
                      values=[loc, scale_diag, scale_identity_multiplier]):
    loc = _convert_to_tensor(loc, name="loc")
    scale_diag = _convert_to_tensor(scale_diag, name="scale_diag")
    scale_identity_multiplier = _convert_to_tensor(
        scale_identity_multiplier,
        name="scale_identity_multiplier")

    if scale_diag is not None:
      if scale_identity_multiplier is not None:
        scale_diag += scale_identity_multiplier[..., array_ops.newaxis]
      return linalg.LinearOperatorDiag(
          diag=_maybe_attach_assertion(scale_diag),
          is_non_singular=True,
          is_self_adjoint=True,
          is_positive_definite=assert_positive)

    if loc is None and shape_hint is None:
      raise ValueError(
          "Cannot infer `event_shape` unless `loc` or "
          "`shape_hint` is specified.")

    if shape_hint is None:
      shape_hint = loc.shape[-1]

    if scale_identity_multiplier is None:
      return linalg.LinearOperatorIdentity(
          num_rows=shape_hint,
          dtype=loc.dtype.base_dtype,
          is_self_adjoint=True,
          is_positive_definite=True,
          assert_proper_shapes=validate_args)

    return linalg.LinearOperatorScaledIdentity(
        num_rows=shape_hint,
        multiplier=_maybe_attach_assertion(scale_identity_multiplier),
        is_non_singular=True,
        is_self_adjoint=True,
        is_positive_definite=assert_positive,
        assert_proper_shapes=validate_args)
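

# A minimal usage sketch for `make_diag_scale` (the example function and the
# values below are illustrative assumptions, not part of the library).
def _example_make_diag_scale():
  """Sketch: build an operator representing diag([1.5, 2.5])."""
  scale_diag = ops.convert_to_tensor([1., 2.])
  multiplier = ops.convert_to_tensor(0.5)
  # `scale_identity_multiplier` is broadcast-added onto the diagonal, so the
  # resulting `LinearOperatorDiag` represents diag([1.5, 2.5]).
  return make_diag_scale(
      scale_diag=scale_diag, scale_identity_multiplier=multiplier)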


def shapes_from_loc_and_scale(loc, scale, name="shapes_from_loc_and_scale"):
  """Infer distribution batch and event shapes from a location and scale.

  Location and scale family distributions determine their batch/event shape by
  broadcasting the `loc` and `scale` args. This helper does that broadcast,
  statically if possible.

  Batch shape broadcasts as per the normal rules.
  We allow the `loc` event shape to broadcast up to that of `scale`. We do not
  allow `scale`'s event shape to change. Therefore, the last dimension of `loc`
  must either be size `1`, or the same as `scale.range_dimension`.

  See `MultivariateNormalLinearOperator` for a usage example.

  Args:
    loc: `N-D` `Tensor` with `N >= 1` (already converted to tensor) or `None`.
      If `None`, both batch and event shape are determined by `scale`.
    scale: A `LinearOperator` instance.
    name: A string name to prepend to created ops.

  Returns:
    batch_shape: `TensorShape` (if broadcast is done statically), or `Tensor`.
    event_shape: `TensorShape` (if broadcast is done statically), or `Tensor`.

  Raises:
    ValueError: If the last dimension of `loc` is determined statically to be
      different than the range of `scale`.
  """
  with ops.name_scope(name, values=[loc] + scale.graph_parents):
    # Get event shape.
    event_size = scale.range_dimension_tensor()
    event_size_const = tensor_util.constant_value(event_size)
    if event_size_const is not None:
      event_shape = event_size_const.reshape([1])
    else:
      event_shape = event_size[array_ops.newaxis]

    # Static check that event shapes match.
    if loc is not None:
      loc_event_size = tensor_shape.dimension_value(loc.get_shape()[-1])
      if loc_event_size is not None and event_size_const is not None:
        if loc_event_size != 1 and loc_event_size != event_size_const:
          raise ValueError(
              "Event size of 'loc' (%d) could not be broadcast up to that of "
              "'scale' (%d)." % (loc_event_size, event_size_const))

    # Get batch shape.
    batch_shape = scale.batch_shape_tensor()
    if loc is None:
      batch_shape_const = tensor_util.constant_value(batch_shape)
      batch_shape = (
          batch_shape_const if batch_shape_const is not None else batch_shape)
    else:
      loc_batch_shape = loc.get_shape().with_rank_at_least(1)[:-1]
      if (loc.get_shape().ndims is None or
          not loc_batch_shape.is_fully_defined()):
        loc_batch_shape = array_ops.shape(loc)[:-1]
      else:
        loc_batch_shape = ops.convert_to_tensor(loc_batch_shape,
                                                name="loc_batch_shape")
      # This is defined in the core util module.
      # pylint: disable=undefined-variable
      batch_shape = prefer_static_broadcast_shape(batch_shape, loc_batch_shape)
      # pylint: enable=undefined-variable

    return batch_shape, event_shape
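

# A minimal usage sketch for `shapes_from_loc_and_scale` (the example function
# and the shapes below are illustrative assumptions, not part of the library).
def _example_shapes_from_loc_and_scale():
  """Sketch: infer batch shape [3] and event shape [2] from `loc`/`scale`."""
  loc = array_ops.zeros([3, 2])  # Batch shape [3], event shape [2].
  scale = linalg.LinearOperatorDiag(array_ops.ones([2]))  # Batch shape [].
  # Batch shapes broadcast to [3]; the event shape is taken from `scale`.
  return shapes_from_loc_and_scale(loc, scale)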


def get_broadcast_shape(*tensors):
  """Get broadcast shape as a Python list of integers (preferred) or `Tensor`.

  Args:
    *tensors: One or more `Tensor` objects (already converted!).

  Returns:
    broadcast shape: Python list (if shapes determined statically), otherwise
      an `int32` `Tensor`.
  """
  # Try static.
  s_shape = tensors[0].shape
  for t in tensors[1:]:
    s_shape = array_ops.broadcast_static_shape(s_shape, t.shape)
  if s_shape.is_fully_defined():
    return s_shape.as_list()

  # Fallback on dynamic.
  d_shape = array_ops.shape(tensors[0])
  for t in tensors[1:]:
    d_shape = array_ops.broadcast_dynamic_shape(d_shape, array_ops.shape(t))
  return d_shape


def is_diagonal_scale(scale):
  """Returns `True` if `scale` is a `LinearOperator` that is known to be diag.

  Args:
    scale: `LinearOperator` instance.

  Returns:
    Python `bool`.

  Raises:
    TypeError: If `scale` is not a `LinearOperator`.
  """
  if not isinstance(scale, linalg.LinearOperator):
    raise TypeError("Expected argument 'scale' to be instance of LinearOperator"
                    ". Found: %s" % scale)
  return (isinstance(scale, linalg.LinearOperatorIdentity) or
          isinstance(scale, linalg.LinearOperatorScaledIdentity) or
          isinstance(scale, linalg.LinearOperatorDiag))
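

# A minimal usage sketch for `get_broadcast_shape` and `is_diagonal_scale` (the
# example function and the shapes below are illustrative assumptions, not part
# of the library).
def _example_broadcast_and_diag_checks():
  """Sketch: static broadcast shape and a diagonal-scale check."""
  a = array_ops.zeros([3, 1, 2])
  b = array_ops.zeros([4, 2])
  shape = get_broadcast_shape(a, b)  # Python list [3, 4, 2] (static path).
  diag_op = linalg.LinearOperatorDiag(array_ops.ones([2]))
  return shape, is_diagonal_scale(diag_op)  # ([3, 4, 2], True)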


def maybe_check_scalar_distribution(
    distribution, expected_base_dtype, validate_args):
  """Helper which checks validity of a scalar `distribution` init arg.

  Valid here means:

  * `distribution` has scalar batch and event shapes.
  * `distribution` is `FULLY_REPARAMETERIZED`.
  * `distribution` has expected dtype.

  Args:
    distribution: `Distribution`-like object.
    expected_base_dtype: `TensorFlow` `dtype`.
    validate_args: Python `bool`. Whether to do additional checks:
      (i) check that reparameterization_type is `FULLY_REPARAMETERIZED`.
      (ii) add `tf.Assert` ops to the graph to enforce that distribution
        is scalar in the event that this cannot be determined statically.

  Returns:
    List of `tf.Assert` ops to run to enforce validity checks that could not
    be statically determined. Empty if `not validate_args`.

  Raises:
    TypeError: If `distribution.dtype` does not equal `expected_base_dtype`.
    ValueError: If validate_args and distribution is not FULLY_REPARAMETERIZED.
    ValueError: If distribution is statically determined to not have both
      scalar batch and scalar event shapes.
  """
  if distribution.dtype != expected_base_dtype:
    raise TypeError("dtype mismatch; "
                    "distribution.dtype=\"{}\" is not \"{}\"".format(
                        distribution.dtype.name, expected_base_dtype.name))

  # Although `reparameterization_type` is a static property, we guard it by
  # `validate_args`. This allows users to use a `distribution` which is not
  # reparameterized itself. However, we tacitly assume that although the
  # distribution is not reparameterized, it only depends on non-trainable
  # variables.
  if validate_args and (distribution.reparameterization_type
                        != distribution_lib.FULLY_REPARAMETERIZED):
    raise ValueError("Base distribution should be reparameterized or be "
                     "a function of non-trainable variables; "
                     "distribution.reparameterization_type = \"{}\" "
                     "!= \"FULLY_REPARAMETERIZED\".".format(
                         distribution.reparameterization_type))
  with ops.name_scope(name="check_distribution"):
    assertions = []
    def check_is_scalar(is_scalar, name):
      is_scalar_ = static_value(is_scalar)
      if is_scalar_ is not None:
        if not is_scalar_:
          raise ValueError("distribution must be scalar; "
                           "distribution.{}=False is not True".format(name))
      elif validate_args:
        assertions.append(check_ops.assert_equal(
            is_scalar, True,
            message=("distribution must be scalar; "
                     "distribution.{}=False is not True".format(name))))
    check_is_scalar(distribution.is_scalar_event(), "is_scalar_event")
    check_is_scalar(distribution.is_scalar_batch(), "is_scalar_batch")
    return assertions


def pad_mixture_dimensions(x, mixture_distribution, categorical_distribution,
                           event_ndims):
  """Pad dimensions of event tensors for mixture distributions.

  See `Mixture._sample_n` and `MixtureSameFamily._sample_n` for usage examples.

  Args:
    x: event tensor to pad.
    mixture_distribution: Base distribution of the mixture.
    categorical_distribution: `Categorical` distribution that mixes the base
      distribution.
    event_ndims: Integer specifying the number of event dimensions in the event
      tensor.

  Returns:
    A padded version of `x` that can broadcast with `categorical_distribution`.
  """
  with ops.name_scope("pad_mix_dims", values=[x]):
    def _get_ndims(d):
      if d.batch_shape.ndims is not None:
        return d.batch_shape.ndims
      return array_ops.shape(d.batch_shape_tensor())[0]
    dist_batch_ndims = _get_ndims(mixture_distribution)
    cat_batch_ndims = _get_ndims(categorical_distribution)
    pad_ndims = array_ops.where(
        categorical_distribution.is_scalar_batch(),
        dist_batch_ndims,
        dist_batch_ndims - cat_batch_ndims)
    s = array_ops.shape(x)
    x = array_ops.reshape(x, shape=array_ops.concat([
        s[:-1],
        array_ops.ones([pad_ndims], dtype=dtypes.int32),
        s[-1:],
        array_ops.ones([event_ndims], dtype=dtypes.int32),
    ], axis=0))
    return x
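

# A minimal usage sketch for `pad_mixture_dimensions` (the example function,
# the distributions, and the shapes below are illustrative assumptions, not
# part of the library).
def _example_pad_mixture_dimensions():
  """Sketch: insert singleton dims so samples broadcast with mixture weights."""
  from tensorflow.python.ops.distributions import categorical as categorical_lib
  from tensorflow.python.ops.distributions import normal as normal_lib
  cat = categorical_lib.Categorical(probs=[0.3, 0.7])  # Scalar batch shape.
  components = normal_lib.Normal(loc=array_ops.zeros([2]),
                                 scale=array_ops.ones([2]))  # Batch shape [2].
  x = array_ops.zeros([4, 2])  # E.g. 4 draws from each of the 2 components.
  # Inserts `pad_ndims` (= 1 here) singleton axes before the last axis of `x`
  # and `event_ndims` (= 0 here) axes after it, giving shape [4, 1, 2].
  return pad_mixture_dimensions(x, components, cat, event_ndims=0)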


def static_value(x):
  """Returns the static value of a `Tensor` or `None`."""
  return tensor_util.constant_value(ops.convert_to_tensor(x))


def move_dimension(x, source_idx, dest_idx):
  """Move a single tensor dimension within its shape.

  This is a special case of `tf.transpose()`, which applies
  arbitrary permutations to tensor dimensions.

  Args:
    x: Tensor of rank `ndims`.
    source_idx: Integer index into `x.shape` (negative indexing is
      supported).
    dest_idx: Integer index into `x.shape` (negative indexing is
      supported).

  Returns:
    x_perm: Tensor of rank `ndims`, in which the dimension at original
      index `source_idx` has been moved to new index `dest_idx`, with
      all other dimensions retained in their original order.

  Example:

  ```python
  x = tf.placeholder(tf.float32, shape=[200, 30, 4, 1, 6])
  x_perm = move_dimension(x, 1, 1)  # no-op
  x_perm = move_dimension(x, 0, 3)  # result shape [30, 4, 1, 200, 6]
  x_perm = move_dimension(x, 0, -2)  # equivalent to previous
  x_perm = move_dimension(x, 4, 2)  # result shape [200, 30, 6, 4, 1]
  ```
  """
  ndims = util.prefer_static_rank(x)
  if isinstance(source_idx, int):
    dtype = dtypes.int32
  else:
    dtype = dtypes.as_dtype(source_idx.dtype)

  # Handle negative indexing. Since ndims might be dynamic, this makes
  # source_idx and dest_idx also possibly dynamic.
  if source_idx < 0:
    source_idx = ndims + source_idx
  if dest_idx < 0:
    dest_idx = ndims + dest_idx

  # Construct the appropriate permutation of dimensions, depending
  # whether the source is before or after the destination.
  def move_left_permutation():
    return util.prefer_static_value(
        array_ops.concat([
            math_ops.range(0, dest_idx, dtype=dtype),
            [source_idx],
            math_ops.range(dest_idx, source_idx, dtype=dtype),
            math_ops.range(source_idx+1, ndims, dtype=dtype)], axis=0))

  def move_right_permutation():
    return util.prefer_static_value(
        array_ops.concat([
            math_ops.range(0, source_idx, dtype=dtype),
            math_ops.range(source_idx+1, dest_idx+1, dtype=dtype),
            [source_idx],
            math_ops.range(dest_idx+1, ndims, dtype=dtype)], axis=0))

  def x_permuted():
    return array_ops.transpose(
        x, perm=smart_cond.smart_cond(source_idx < dest_idx,
                                      move_right_permutation,
                                      move_left_permutation))

  # One final conditional to handle the special case where source
  # and destination indices are equal.
  return smart_cond.smart_cond(math_ops.equal(source_idx, dest_idx),
                               lambda: x,
                               x_permuted)
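

# A minimal runnable sketch for `move_dimension` (the example function and the
# shape below are illustrative assumptions, not part of the library).
def _example_move_dimension():
  """Sketch: move axis 0 of a rank-5 tensor to position 3."""
  x = array_ops.zeros([200, 30, 4, 1, 6])
  # The permutation [1, 2, 3, 0, 4] is built statically, so the result has
  # static shape [30, 4, 1, 200, 6].
  return move_dimension(x, 0, 3)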