# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for probability distributions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import hashlib

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util import tf_inspect


def assert_integer_form(x,
                        data=None,
                        summarize=None,
                        message=None,
                        int_dtype=None,
                        name="assert_integer_form"):
  """Assert that x has integer components (or floats equal to integers).

  Args:
    x: Floating-point `Tensor`.
    data: The tensors to print out if the condition is `False`. Defaults to
      the error message and the first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    int_dtype: A `tf.dtype` to cast `x` to. The default (`None`) implies the
      smallest possible signed int will be used for casting.
    name: A name for this operation (optional).

  Returns:
    Op raising `InvalidArgumentError` if `cast(x, int_dtype) != x`.
  """
  with ops.name_scope(name, values=[x, data]):
    x = ops.convert_to_tensor(x, name="x")
    if x.dtype.is_integer:
      return control_flow_ops.no_op()
    message = message or "{} has non-integer components".format(x)
    if int_dtype is None:
      try:
        int_dtype = {
            dtypes.float16: dtypes.int16,
            dtypes.float32: dtypes.int32,
            dtypes.float64: dtypes.int64,
        }[x.dtype.base_dtype]
      except KeyError:
        raise TypeError("Unrecognized type {}".format(x.dtype.name))
    return check_ops.assert_equal(
        x,
        math_ops.cast(math_ops.cast(x, int_dtype), x.dtype),
        data=data,
        summarize=summarize,
        message=message,
        name=name)


def assert_symmetric(matrix):
  matrix_t = array_ops.matrix_transpose(matrix)
  return control_flow_ops.with_dependencies(
      [check_ops.assert_equal(matrix, matrix_t)], matrix)


def embed_check_nonnegative_integer_form(
    x, name="embed_check_nonnegative_integer_form"):
  """Assert x is a non-negative tensor, and optionally of integers."""
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    assertions = [
        check_ops.assert_non_negative(
            x, message="'{}' must be non-negative.".format(x)),
    ]
    if not x.dtype.is_integer:
      assertions += [
          assert_integer_form(
              x,
              message="'{}' cannot contain fractional components.".format(x)),
      ]
    return control_flow_ops.with_dependencies(assertions, x)


def same_dynamic_shape(a, b):
  """Returns whether a and b have the same dynamic shape.

  Args:
    a: `Tensor`
    b: `Tensor`

  Returns:
    `bool` `Tensor` representing if both tensors have the same shape.
  """
  a = ops.convert_to_tensor(a, name="a")
  b = ops.convert_to_tensor(b, name="b")

  # Here we can't just do math_ops.equal(a.shape, b.shape), since
  # static shape inference may break the equality comparison between
  # shape(a) and shape(b) in math_ops.equal.
  def all_shapes_equal():
    return math_ops.reduce_all(
        math_ops.equal(
            array_ops.concat(
                [array_ops.shape(a), array_ops.shape(b)], 0),
            array_ops.concat(
                [array_ops.shape(b), array_ops.shape(a)], 0)))

  # One of the shapes isn't fully defined, so we need to use the dynamic
  # shape.
  return control_flow_ops.cond(
      math_ops.equal(array_ops.rank(a), array_ops.rank(b)),
      all_shapes_equal, lambda: constant_op.constant(False))


def maybe_get_static_value(x, dtype=None):
  """Helper which tries to return a static value.

  Given `x`, extract its value statically, optionally casting to a specific
  dtype. If this is not possible, None is returned.

  Args:
    x: `Tensor` for which to extract a value statically.
    dtype: Optional dtype to cast to.

  Returns:
    Statically inferred value if possible, otherwise None.
  """
  if x is None:
    return x
  try:
    # This returns an np.ndarray.
    x_ = tensor_util.constant_value(x)
  except TypeError:
    x_ = x
  if x_ is None or dtype is None:
    return x_
  return np.array(x_, dtype)


def get_logits_and_probs(logits=None,
                         probs=None,
                         multidimensional=False,
                         validate_args=False,
                         name="get_logits_and_probs",
                         dtype=None):
  """Converts logits to probabilities (or vice-versa), and returns both.

  Args:
    logits: Floating-point `Tensor` representing log-odds.
    probs: Floating-point `Tensor` representing probabilities.
    multidimensional: Python `bool`, default `False`. If `True`, `logits` or
      `probs` is treated as a `[N1, N2, ..., k]`-shaped tensor whose last
      dimension holds the logits or probabilities of `k` classes.
    validate_args: Python `bool`, default `False`. When `True`, either assert
      `0 <= probs <= 1` (if not `multidimensional`) or that the last dimension
      of `probs` sums to one.
    name: A name for this operation (optional).
    dtype: `tf.DType` to prefer when converting args to `Tensor`s.

  Returns:
    logits, probs: Tuple of `Tensor`s. If `probs` has an entry that is `0` or
      `1`, then the corresponding entry in the returned logit will be `-Inf`
      and `Inf` respectively.

  Raises:
    ValueError: if neither `probs` nor `logits` were passed in, or both were.
  """
  with ops.name_scope(name, values=[probs, logits]):
    if (probs is None) == (logits is None):
      raise ValueError("Must pass probs or logits, but not both.")

    if probs is None:
      logits = ops.convert_to_tensor(logits, name="logits", dtype=dtype)
      if not logits.dtype.is_floating:
        raise TypeError("logits must have floating type.")
      # We can early return since we constructed probs and therefore know
      # they're valid.
      if multidimensional:
        if validate_args:
          logits = embed_check_categorical_event_shape(logits)
        return logits, nn.softmax(logits, name="probs")
      return logits, math_ops.sigmoid(logits, name="probs")

    probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
    if not probs.dtype.is_floating:
      raise TypeError("probs must have floating type.")

    if validate_args:
      with ops.name_scope("validate_probs"):
        one = constant_op.constant(1., probs.dtype)
        dependencies = [check_ops.assert_non_negative(probs)]
        if multidimensional:
          probs = embed_check_categorical_event_shape(probs)
          dependencies += [
              check_ops.assert_near(
                  math_ops.reduce_sum(probs, -1),
                  one,
                  message="probs does not sum to 1.")
          ]
        else:
          dependencies += [
              check_ops.assert_less_equal(
                  probs, one, message="probs has components greater than 1.")
          ]
        probs = control_flow_ops.with_dependencies(dependencies, probs)

    with ops.name_scope("logits"):
      if multidimensional:
        # Here we don't compute the multidimensional case in a manner
        # consistent with the unidimensional case. We do so following the TF
        # convention. Typically, you might expect to see
        # logits = log(probs) - log(probs[pivot]). A side-effect of being
        # consistent with the TF approach is that the unidimensional case
        # implicitly handles the second dimension but the multidimensional
        # case explicitly keeps the pivot dimension.
        return math_ops.log(probs), probs
      return math_ops.log(probs) - math_ops.log1p(-1. * probs), probs
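

# Illustrative sketch (not part of the original module) of the binary branch
# above: the returned logit is the usual log-odds transform, e.g. (with eager
# execution or a session to evaluate the tensors)
#
#   logits, probs = get_logits_and_probs(probs=[0.1, 0.5, 1.])
#   # logits ==> [log(0.1 / 0.9), 0., +inf]; probs is returned unchanged.
#   logits, probs = get_logits_and_probs(logits=[0., 2.])
#   # probs ==> sigmoid([0., 2.]) = [0.5, ~0.881].
#
# With `multidimensional=True` the mapping is `softmax`/`log` instead.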


def _is_known_unsigned_by_dtype(dt):
  """Helper returning True if dtype is known to be unsigned."""
  return {
      dtypes.bool: True,
      dtypes.uint8: True,
      dtypes.uint16: True,
  }.get(dt.base_dtype, False)


def _is_known_signed_by_dtype(dt):
  """Helper returning True if dtype is known to be signed."""
  return {
      dtypes.float16: True,
      dtypes.float32: True,
      dtypes.float64: True,
      dtypes.int8: True,
      dtypes.int16: True,
      dtypes.int32: True,
      dtypes.int64: True,
  }.get(dt.base_dtype, False)


def _is_known_dtype(dt):
  """Helper returning True if dtype is known."""
  return _is_known_unsigned_by_dtype(dt) or _is_known_signed_by_dtype(dt)


def _largest_integer_by_dtype(dt):
  """Helper returning the largest integer exactly representable by dtype."""
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  if dt.is_floating:
    return int(2**(np.finfo(dt.as_numpy_dtype).nmant + 1))
  if dt.is_integer:
    return np.iinfo(dt.as_numpy_dtype).max
  if dt.base_dtype == dtypes.bool:
    return int(1)
  # We actually can't land here but keep the case for completeness.
  raise TypeError("Unrecognized dtype: {}".format(dt.name))


def _smallest_integer_by_dtype(dt):
  """Helper returning the smallest integer exactly representable by dtype."""
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  if _is_known_unsigned_by_dtype(dt):
    return 0
  return -1 * _largest_integer_by_dtype(dt)


def _is_integer_like_by_dtype(dt):
  """Helper returning True if dtype.is_integer or is `bool`."""
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  return dt.is_integer or dt.base_dtype == dtypes.bool


def embed_check_categorical_event_shape(
    categorical_param, name="embed_check_categorical_event_shape"):
  """Embeds checks that categorical distributions don't have too many classes.

  A categorical-type distribution is one which, e.g., returns the class label
  rather than a one-hot encoding. E.g., `Categorical(probs)`.

  Since distributions output samples in the same dtype as the parameters, we
  must ensure that casting doesn't lose precision. That is, the
  `parameter.dtype` implies a maximum number of classes. However, since shape
  is `int32` and categorical variables are presumed to be indexes into a
  `Tensor`, we must also ensure that the number of classes is no larger than
  the largest possible `int32` index, i.e., `2**31-1`.

  In other words the number of classes, `K`, must satisfy the following
  condition:

  ```python
  K <= min(
      int(2**31 - 1),  # Largest `int32` index.
      {
          dtypes.float16: int(2**11),   # Largest int as a float16.
          dtypes.float32: int(2**24),
          dtypes.float64: int(2**53),
      }.get(categorical_param.dtype.base_dtype, 0))
  ```

  Args:
    categorical_param: Floating-point `Tensor` representing parameters of
      distribution over categories. The rightmost shape is presumed to be the
      number of categories.
    name: A name for this operation (optional).

  Returns:
    categorical_param: Input `Tensor` with appropriate assertions embedded.

  Raises:
    TypeError: if `categorical_param` has an unknown `dtype`.
    ValueError: if we can statically identify `categorical_param` as being too
      large (for being closed under int32/float casting).
  """
  with ops.name_scope(name, values=[categorical_param]):
    x = ops.convert_to_tensor(categorical_param, name="categorical_param")
    # The size must not exceed either of:
    # - The largest possible int32 (since categorical values are presumed to
    #   be indexes into a Tensor).
    # - The largest possible integer exactly representable under the given
    #   floating-point dtype (since we need to cast to/from).
    #
    # The chosen floating-point thresholds are 2**(1 + mantissa_bits).
    # For more details, see:
    # https://en.wikipedia.org/wiki/Floating-point_arithmetic#Internal_representation
    x_dtype = x.dtype.base_dtype
    max_event_size = (
        _largest_integer_by_dtype(x_dtype) if x_dtype.is_floating else 0)
    if max_event_size == 0:
      raise TypeError("Unable to validate size of unrecognized dtype "
                      "({}).".format(x_dtype.name))
    try:
      x_shape_static = x.get_shape().with_rank_at_least(1)
    except ValueError:
      raise ValueError("A categorical-distribution parameter must have "
                       "at least 1 dimension.")
    if tensor_shape.dimension_value(x_shape_static[-1]) is not None:
      event_size = x_shape_static.dims[-1].value
      if event_size < 2:
        raise ValueError("A categorical-distribution parameter must have at "
                         "least 2 events.")
      if event_size > max_event_size:
        raise ValueError("Number of classes exceeds `dtype` precision, i.e., "
                         "{} implies shape ({}) cannot exceed {}.".format(
                             x_dtype.name, event_size, max_event_size))
      return x
    else:
      event_size = array_ops.shape(x, name="x_shape")[-1]
      return control_flow_ops.with_dependencies([
          check_ops.assert_rank_at_least(
              x,
              1,
              message=("A categorical-distribution parameter must have "
                       "at least 1 dimension.")),
          check_ops.assert_greater_equal(
              array_ops.shape(x)[-1],
              2,
              message=("A categorical-distribution parameter must have at "
                       "least 2 events.")),
          check_ops.assert_less_equal(
              event_size,
              max_event_size,
              message="Number of classes exceeds `dtype` precision, "
              "i.e., {} dtype cannot exceed {} shape.".format(
                  x_dtype.name, max_event_size)),
      ], x)


def embed_check_integer_casting_closed(x,
                                       target_dtype,
                                       assert_nonnegative=True,
                                       name="embed_check_casting_closed"):
  """Ensures integers remain unaffected despite casting to/from int/float types.

  Example integer-types: `uint8`, `int32`, `bool`.
  Example floating-types: `float32`, `float64`.

  The largest possible integer representable by an IEEE754 floating-point is
  `2**(1 + mantissa_bits)` yet the largest possible integer as an int-type is
  `2**(bits - 1) - 1`. This function ensures that a `Tensor` purporting to
  have integer-form values can be cast to some other type without loss of
  precision.

  The smallest representable integer is the negative of the largest
  representable integer, except for types: `uint8`, `uint16`, `bool`. For
  these types, the smallest representable integer is `0`.

  Args:
    x: `Tensor` representing integer-form values.
    target_dtype: TF `dtype` under which `x` should have identical values.
    assert_nonnegative: `bool` indicating `x` should contain nonnegative
      values.
    name: A name for this operation (optional).

  Returns:
    x: Input `Tensor` with appropriate assertions embedded.

  Raises:
    TypeError: if `x` is neither integer- nor floating-type.
    TypeError: if `target_dtype` is neither integer- nor floating-type.
    TypeError: if neither `x` nor `target_dtype` are integer-type.
  """

  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if (not _is_integer_like_by_dtype(x.dtype) and not x.dtype.is_floating):
      raise TypeError("{}.dtype must be floating- or "
                      "integer-type.".format(x.dtype.name))
    if (not _is_integer_like_by_dtype(target_dtype) and
        not target_dtype.is_floating):
      raise TypeError("target_dtype ({}) must be floating- or "
                      "integer-type.".format(target_dtype.name))
    if (not _is_integer_like_by_dtype(x.dtype) and
        not _is_integer_like_by_dtype(target_dtype)):
      raise TypeError("At least one of {}.dtype ({}) and target_dtype ({}) "
                      "must be integer-type.".format(x, x.dtype.name,
                                                     target_dtype.name))

    assertions = []
    if assert_nonnegative:
      assertions += [
          check_ops.assert_non_negative(
              x, message="Elements must be non-negative."),
      ]

    if x.dtype.is_floating:
      # Being here means _is_integer_like_by_dtype(target_dtype) = True.
      # Since this check implies the magnitude check below, we need only it.
      assertions += [
          assert_integer_form(
              x,
              int_dtype=target_dtype,
              message="Elements must be {}-equivalent.".format(
                  target_dtype.name)),
      ]
    else:
      if (_largest_integer_by_dtype(x.dtype) >
          _largest_integer_by_dtype(target_dtype)):
        # Cast may lose integer precision.
        assertions += [
            check_ops.assert_less_equal(
                x,
                _largest_integer_by_dtype(target_dtype),
                message=("Elements cannot exceed {}.".format(
                    _largest_integer_by_dtype(target_dtype)))),
        ]
      if (not assert_nonnegative and (_smallest_integer_by_dtype(
          x.dtype) < _smallest_integer_by_dtype(target_dtype))):
        assertions += [
            check_ops.assert_greater_equal(
                x,
                _smallest_integer_by_dtype(target_dtype),
                message=("Elements cannot be smaller than {}.".format(
                    _smallest_integer_by_dtype(target_dtype)))),
        ]

    if not assertions:
      return x
    return control_flow_ops.with_dependencies(assertions, x)
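

# A small numerical aside (illustrative, not from the original source) on why
# the casting-closure checks above exist: IEEE754 floats cannot represent
# every integer, e.g.
#
#   np.float32(2**24) + np.float32(1) == np.float32(2**24)  # ==> True
#   np.float64(2**53) + np.float64(1) == np.float64(2**53)  # ==> True
#
# so casting integer values above 2**24 (float32) or 2**53 (float64) to float
# silently loses precision; the embedded assertions guard against that.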


def log_combinations(n, counts, name="log_combinations"):
  """Multinomial coefficient.

  Given `n` and `counts`, where `counts` has last dimension `k`, we compute
  the log of the multinomial coefficient:

  ```n! / prod_i n_i!```

  where `i` runs over all `k` classes.

  Args:
    n: Floating-point `Tensor` broadcastable with `counts`. This represents
      `n` outcomes.
    counts: Floating-point `Tensor` broadcastable with `n`. This represents
      counts in `k` classes, where `k` is the last dimension of the tensor.
    name: A name for this operation (optional).

  Returns:
    `Tensor` representing the log of the multinomial coefficient between `n`
    and `counts`.
  """
  # First a bit about the number of ways counts could have come in:
  # E.g. if counts = [1, 2], then this is 3 choose 2.
  # In general, this is (sum counts)! / prod(counts!)
  # The sum should be along the last dimension of counts. This is the
  # "distribution" dimension. Here n a priori represents the sum of counts.
  with ops.name_scope(name, values=[n, counts]):
    n = ops.convert_to_tensor(n, name="n")
    counts = ops.convert_to_tensor(counts, name="counts")
    total_permutations = math_ops.lgamma(n + 1)
    counts_factorial = math_ops.lgamma(counts + 1)
    redundant_permutations = math_ops.reduce_sum(counts_factorial, axis=[-1])
    return total_permutations - redundant_permutations
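

# Quick sanity check (illustrative only): with counts = [1., 2.] there are
# 3! / (1! * 2!) = 3 distinct orderings, so
#
#   log_combinations(n=3., counts=[1., 2.])  # ==> log(3.) ~= 1.0986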


def matrix_diag_transform(matrix, transform=None, name=None):
  """Transform diagonal of [batch-]matrix, leave rest of matrix unchanged.

  Create a trainable covariance defined by a Cholesky factor:

  ```python
  # Transform network layer into 2 x 2 array.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))

  # Make the diagonal positive. If the upper triangle was zero, this would be
  # a valid Cholesky factor.
  chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # LinearOperatorLowerTriangular ignores the upper triangle.
  operator = LinearOperatorLowerTriangular(chol)
  ```

  Example of heteroskedastic 2-D linear regression.

  ```python
  tfd = tfp.distributions

  # Get a trainable Cholesky factor.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
  chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # Get a trainable mean.
  mu = tf.contrib.layers.fully_connected(activations, 2)

  # This is a fully trainable multivariate normal!
  dist = tfd.MultivariateNormalTriL(mu, chol)

  # Standard log loss. Minimizing this will "train" mu and chol, and then
  # dist will be a distribution predicting labels as multivariate Gaussians.
  loss = -1 * tf.reduce_mean(dist.log_prob(labels))
  ```

  Args:
    matrix: Rank `R` `Tensor`, `R >= 2`, where the last two dimensions are
      equal.
    transform: Element-wise function mapping `Tensors` to `Tensors`. To be
      applied to the diagonal of `matrix`. If `None`, `matrix` is returned
      unchanged. Defaults to `None`.
    name: A name to give created ops. Defaults to "matrix_diag_transform".

  Returns:
    A `Tensor` with same shape and `dtype` as `matrix`.
  """
  with ops.name_scope(name, "matrix_diag_transform", [matrix]):
    matrix = ops.convert_to_tensor(matrix, name="matrix")
    if transform is None:
      return matrix
    # Replace the diag with transformed diag.
    diag = array_ops.matrix_diag_part(matrix)
    transformed_diag = transform(diag)
    transformed_mat = array_ops.matrix_set_diag(matrix, transformed_diag)

  return transformed_mat


def rotate_transpose(x, shift, name="rotate_transpose"):
  """Circularly moves dims left or right.

  Effectively identical to:

  ```python
  numpy.transpose(x, numpy.roll(numpy.arange(len(x.shape)), shift))
  ```

  When the rank of `x` or the value of `shift` is not known statically,
  additional graph-runtime ops are created to compute the permutation; these
  entail moving data from GPU to CPU.

  Example:

  ```python
  x = tf.random.normal([1, 2, 3, 4])  # Tensor of shape [1, 2, 3, 4].
  rotate_transpose(x, -1).shape == [2, 3, 4, 1]
  rotate_transpose(x, -2).shape == [3, 4, 1, 2]
  rotate_transpose(x, 1).shape == [4, 1, 2, 3]
  rotate_transpose(x, 2).shape == [3, 4, 1, 2]
  rotate_transpose(x, 7).shape == rotate_transpose(x, 3).shape  # [2, 3, 4, 1]
  rotate_transpose(x, -7).shape == rotate_transpose(x, -3).shape  # [4, 1, 2, 3]
  ```

  Args:
    x: `Tensor`.
    shift: `Tensor`. Number of dimensions to transpose left (shift<0) or
      transpose right (shift>0).
    name: Python `str`. The name to give this op.

  Returns:
    rotated_x: Input `Tensor` with dimensions circularly rotated by shift.

  Raises:
    TypeError: if shift is not integer type.
  """
  with ops.name_scope(name, values=[x, shift]):
    x = ops.convert_to_tensor(x, name="x")
    shift = ops.convert_to_tensor(shift, name="shift")
    # We do not assign back to preserve constant-ness.
    check_ops.assert_integer(shift)
    shift_value_static = tensor_util.constant_value(shift)
    ndims = x.get_shape().ndims
    if ndims is not None and shift_value_static is not None:
      if ndims < 2:
        return x
      shift_value_static = np.sign(shift_value_static) * (
          abs(shift_value_static) % ndims)
      if shift_value_static == 0:
        return x
      perm = np.roll(np.arange(ndims), shift_value_static)
      return array_ops.transpose(x, perm=perm)
    else:
      # Consider if we always had a positive shift, and some specified
      # direction.
      # When shifting left we want the new array:
      #   last(x, n-shift) + first(x, shift)
      # and if shifting right then we want:
      #   last(x, shift) + first(x, n-shift)
      # Observe that last(a) == slice(a, n) and first(a) == slice(0, a).
      # Also, we can encode direction and shift as one: direction * shift.
      # Combining these facts, we have:
      #   a = cond(shift<0, -shift, n-shift)
      #   last(x, n-a) + first(x, a) == x[a:n] + x[0:a]
      # Finally, we transform shift by modulo length so it can be specified
      # independently from the array upon which it operates (like python).
      ndims = array_ops.rank(x)
      shift = array_ops.where_v2(
          math_ops.less(shift, 0),
          math_ops.mod(-shift, ndims),  # pylint: disable=invalid-unary-operand-type
          ndims - math_ops.mod(shift, ndims))
      first = math_ops.range(0, shift)
      last = math_ops.range(shift, ndims)
      perm = array_ops.concat([last, first], 0)
      return array_ops.transpose(x, perm=perm)
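

# Numerical sketch (illustrative only) of the dynamic-shape branch above: for
# ndims = 4 and shift = -1 we get a = mod(1, 4) = 1, so
#
#   perm = concat([range(1, 4), range(0, 1)])  # ==> [1, 2, 3, 0]
#
# which equals np.roll(np.arange(4), -1); for shift = 1, a = 4 - mod(1, 4) = 3
# and perm ==> [3, 0, 1, 2] = np.roll(np.arange(4), 1), matching the static
# branch.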


def pick_vector(cond, true_vector, false_vector, name="pick_vector"):
  """Picks possibly different length row `Tensor`s based on condition.

  Value `Tensor`s should have exactly one dimension.

  If `cond` is a python Boolean or `tf.constant` then either `true_vector` or
  `false_vector` is immediately returned. I.e., no graph nodes are created
  and no validation happens.

  Example:

  ```python
  pick_vector(tf.less(0, 5), tf.range(10, 12), tf.range(15, 18))  # [10, 11]
  pick_vector(tf.less(5, 0), tf.range(10, 12), tf.range(15, 18))  # [15, 16, 17]
  ```

  Args:
    cond: `Tensor`. Must have `dtype=tf.bool` and be scalar.
    true_vector: `Tensor` of one dimension. Returned when cond is `True`.
    false_vector: `Tensor` of one dimension. Returned when cond is `False`.
    name: Python `str`. The name to give this op.

  Returns:
    true_or_false_vector: `Tensor`.

  Raises:
    TypeError: if `cond.dtype != tf.bool`
    TypeError: if `cond` is not a constant and
      `true_vector.dtype != false_vector.dtype`
  """
  with ops.name_scope(name, values=(cond, true_vector, false_vector)):
    cond = ops.convert_to_tensor(cond, name="cond")
    if cond.dtype != dtypes.bool:
      raise TypeError("%s.dtype=%s which is not %s" %
                      (cond, cond.dtype, dtypes.bool))
    cond_value_static = tensor_util.constant_value(cond)
    if cond_value_static is not None:
      return true_vector if cond_value_static else false_vector
    true_vector = ops.convert_to_tensor(true_vector, name="true_vector")
    false_vector = ops.convert_to_tensor(false_vector, name="false_vector")
    if true_vector.dtype != false_vector.dtype:
      raise TypeError(
          "%s.dtype=%s does not match %s.dtype=%s" %
          (true_vector, true_vector.dtype, false_vector, false_vector.dtype))
    n = array_ops.shape(true_vector)[0]
    return array_ops.slice(
        array_ops.concat([true_vector, false_vector], 0),
        [array_ops.where_v2(cond, 0, n)], [array_ops.where(cond, n, -1)])


def prefer_static_broadcast_shape(shape1,
                                  shape2,
                                  name="prefer_static_broadcast_shape"):
  """Convenience function which statically broadcasts shape when possible.

  Args:
    shape1: `1-D` integer `Tensor`. Already converted to tensor!
    shape2: `1-D` integer `Tensor`. Already converted to tensor!
    name: A string name to prepend to created ops.

  Returns:
    The broadcast shape, either as `TensorShape` (if broadcast can be done
    statically), or as a `Tensor`.
  """
  with ops.name_scope(name, values=[shape1, shape2]):

    def make_shape_tensor(x):
      return ops.convert_to_tensor(x, name="shape", dtype=dtypes.int32)

    def get_tensor_shape(s):
      if isinstance(s, tensor_shape.TensorShape):
        return s
      s_ = tensor_util.constant_value(make_shape_tensor(s))
      if s_ is not None:
        return tensor_shape.TensorShape(s_)
      return None

    def get_shape_tensor(s):
      if not isinstance(s, tensor_shape.TensorShape):
        return make_shape_tensor(s)
      if s.is_fully_defined():
        return make_shape_tensor(s.as_list())
      raise ValueError("Cannot broadcast from partially "
                       "defined `TensorShape`.")

    shape1_ = get_tensor_shape(shape1)
    shape2_ = get_tensor_shape(shape2)
    if shape1_ is not None and shape2_ is not None:
      return array_ops.broadcast_static_shape(shape1_, shape2_)

    shape1_ = get_shape_tensor(shape1)
    shape2_ = get_shape_tensor(shape2)
    return array_ops.broadcast_dynamic_shape(shape1_, shape2_)


def prefer_static_rank(x):
  """Return static rank of tensor `x` if available, else `tf.rank(x)`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static rank is obtainable), else `Tensor`.
  """
  return prefer_static_value(array_ops.rank(x))


def prefer_static_shape(x):
  """Return static shape of tensor `x` if available, else `tf.shape(x)`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static shape is obtainable), else `Tensor`.
  """
  return prefer_static_value(array_ops.shape(x))


def prefer_static_value(x):
  """Return static value of tensor `x` if available, else `x`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static value is obtainable), else `Tensor`.
  """
  static_x = tensor_util.constant_value(x)
  if static_x is not None:
    return static_x
  return x


def gen_new_seed(seed, salt):
  """Generate a new seed, from the given seed and salt."""
  if seed is None:
    return None
  string = (str(seed) + salt).encode("utf-8")
  return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF


def fill_triangular(x, upper=False, name=None):
  """Creates a (batch of) triangular matrix from a vector of inputs.

  Created matrix can be lower- or upper-triangular. (It is more efficient to
  create the matrix as upper or lower, rather than transpose.)

  Triangular matrix elements are filled in a clockwise spiral. See example,
  below.

  If `x.get_shape()` is `[b1, b2, ..., bB, d]` then the output shape is
  `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.

  Example:

  ```python
  fill_triangular([1, 2, 3, 4, 5, 6])
  # ==> [[4, 0, 0],
  #      [6, 5, 0],
  #      [3, 2, 1]]

  fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
  # ==> [[1, 2, 3],
  #      [0, 5, 6],
  #      [0, 0, 4]]
  ```

  For comparison, a pure numpy version of this function can be found in
  `util_test.py`, function `_fill_triangular`.

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether output matrix should be upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    tril: `Tensor` with lower (or upper) triangular elements filled from `x`.

  Raises:
    ValueError: if `x` cannot be mapped to a triangular matrix.
  """

  with ops.name_scope(name, "fill_triangular", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if tensor_shape.dimension_value(
        x.shape.with_rank_at_least(1)[-1]) is not None:
      # Formula derived by solving for n: m = n(n+1)/2.
      m = np.int32(x.shape.dims[-1].value)
      n = np.sqrt(0.25 + 2. * m) - 0.5
      if n != np.floor(n):
        raise ValueError("Input right-most shape ({}) does not "
                         "correspond to a triangular matrix.".format(m))
      n = np.int32(n)
      static_final_shape = x.shape[:-1].concatenate([n, n])
    else:
      m = array_ops.shape(x)[-1]
      # For derivation, see above. Casting automatically lops off the 0.5, so
      # we omit it. We don't validate n is an integer because this has
      # graph-execution cost; an error will be thrown from the reshape,
      # below.
      n = math_ops.cast(
          math_ops.sqrt(0.25 + math_ops.cast(2 * m, dtype=dtypes.float32)),
          dtype=dtypes.int32)
      static_final_shape = x.shape.with_rank_at_least(1)[:-1].concatenate(
          [None, None])
    # We now concatenate the "tail" of `x` to `x` (and reverse one of them).
    #
    # We do this based on the insight that the input `x` provides `ceil(n/2)`
    # rows of an `n x n` matrix, some of which will get zeroed out being on
    # the wrong side of the diagonal. The first row will not get zeroed out
    # at all, and we need `floor(n/2)` more rows, so the first is what we
    # omit from `x_tail`. If we then stack those `ceil(n/2)` rows with the
    # `floor(n/2)` rows provided by a reversed tail, it is exactly the other
    # set of elements of the reversed tail which will be zeroed out for being
    # on the wrong side of the diagonal further up/down the matrix.
    # And, in doing so, we've filled the triangular matrix in a clockwise
    # spiral pattern. Neat!
    #
    # Try it out in numpy:
    #  n = 3
    #  x = np.arange(n * (n + 1) / 2)
    #  m = x.shape[0]
    #  n = np.int32(np.sqrt(.25 + 2 * m) - .5)
    #  x_tail = x[(m - (n**2 - m)):]
    #  np.concatenate([x_tail, x[::-1]], 0).reshape(n, n)  # lower
    #  # ==> array([[3, 4, 5],
    #  #            [5, 4, 3],
    #  #            [2, 1, 0]])
    #  np.concatenate([x, x_tail[::-1]], 0).reshape(n, n)  # upper
    #  # ==> array([[0, 1, 2],
    #  #            [3, 4, 5],
    #  #            [5, 4, 3]])
    #
    # Note that we can't simply do `x[..., -(n**2 - m):]` because this doesn't
    # correctly handle `m == n == 1`. Hence, we do nonnegative indexing.
    # Furthermore observe that:
    #   m - (n**2 - m)
    #   = n**2 / 2 + n / 2 - (n**2 - n**2 / 2 + n / 2)
    #   = 2 (n**2 / 2 + n / 2) - n**2
    #   = n**2 + n - n**2
    #   = n
    ndims = prefer_static_rank(x)
    if upper:
      x_list = [x, array_ops.reverse(x[..., n:], axis=[ndims - 1])]
    else:
      x_list = [x[..., n:], array_ops.reverse(x, axis=[ndims - 1])]
    new_shape = (
        static_final_shape.as_list() if static_final_shape.is_fully_defined()
        else array_ops.concat([array_ops.shape(x)[:-1], [n, n]], axis=0))
    x = array_ops.reshape(array_ops.concat(x_list, axis=-1), new_shape)
    x = array_ops.matrix_band_part(
        x, num_lower=(0 if upper else -1), num_upper=(-1 if upper else 0))
    x.set_shape(static_final_shape)
    return x


def fill_triangular_inverse(x, upper=False, name=None):
  """Creates a vector from a (batch of) triangular matrix.

  The vector is created from the lower-triangular or upper-triangular portion
  depending on the value of the parameter `upper`.

  If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
  `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.

  Example:

  ```python
  fill_triangular_inverse(
    [[4, 0, 0],
     [6, 5, 0],
     [3, 2, 1]])

  # ==> [1, 2, 3, 4, 5, 6]

  fill_triangular_inverse(
    [[1, 2, 3],
     [0, 5, 6],
     [0, 0, 4]], upper=True)

  # ==> [1, 2, 3, 4, 5, 6]
  ```

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether output matrix should be upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
      (or upper) triangular elements from `x`.
  """

  with ops.name_scope(name, "fill_triangular_inverse", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if tensor_shape.dimension_value(
        x.shape.with_rank_at_least(2)[-1]) is not None:
      n = np.int32(x.shape.dims[-1].value)
      m = np.int32((n * (n + 1)) // 2)
      static_final_shape = x.shape[:-2].concatenate([m])
    else:
      n = array_ops.shape(x)[-1]
      m = (n * (n + 1)) // 2
      static_final_shape = x.shape.with_rank_at_least(2)[:-2].concatenate(
          [None])
    ndims = prefer_static_rank(x)
    if upper:
      initial_elements = x[..., 0, :]
      triangular_portion = x[..., 1:, :]
    else:
      initial_elements = array_ops.reverse(x[..., -1, :], axis=[ndims - 2])
      triangular_portion = x[..., :-1, :]
    rotated_triangular_portion = array_ops.reverse(
        array_ops.reverse(triangular_portion, axis=[ndims - 1]),
        axis=[ndims - 2])
    consolidated_matrix = triangular_portion + rotated_triangular_portion
    end_sequence = array_ops.reshape(
        consolidated_matrix,
        array_ops.concat([array_ops.shape(x)[:-2], [n * (n - 1)]], axis=0))
    y = array_ops.concat([initial_elements, end_sequence[..., :m - n]],
                         axis=-1)
    y.set_shape(static_final_shape)
    return y
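

# Numpy sketch (illustrative, mirroring the lower-triangular branch above):
#
#   x = [[4, 0, 0],
#        [6, 5, 0],
#        [3, 2, 1]]                       # fill_triangular([1, 2, 3, 4, 5, 6])
#   initial_elements            ==> [1, 2, 3]     # reversed last row
#   triangular_portion          ==> [[4, 0, 0],
#                                    [6, 5, 0]]
#   rotated_triangular_portion  ==> [[0, 5, 6],
#                                    [0, 0, 4]]
#   consolidated_matrix         ==> [[4, 5, 6],
#                                    [6, 5, 4]]   # flattens to [4, 5, 6, ...]
#   # keeping the first m - n = 3 flattened entries and prepending
#   # initial_elements recovers [1, 2, 3, 4, 5, 6].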


def tridiag(below=None, diag=None, above=None, name=None):
  """Creates a matrix with values set above, below, and on the diagonal.

  Example:

  ```python
  tridiag(below=[1., 2., 3.],
          diag=[4., 5., 6., 7.],
          above=[8., 9., 10.])
  # ==> array([[  4.,   8.,   0.,   0.],
  #            [  1.,   5.,   9.,   0.],
  #            [  0.,   2.,   6.,  10.],
  #            [  0.,   0.,   3.,   7.]], dtype=float32)
  ```

  Warning: This Op is intended for convenience, not efficiency.

  Args:
    below: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the below
      diagonal part. `None` is logically equivalent to `below = 0`.
    diag: `Tensor` of shape `[B1, ..., Bb, d]` corresponding to the diagonal
      part. `None` is logically equivalent to `diag = 0`.
    above: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the above
      diagonal part. `None` is logically equivalent to `above = 0`.
    name: Python `str`. The name to give this op.

  Returns:
    tridiag: `Tensor` with values set above, below and on the diagonal.

  Raises:
    ValueError: if all inputs are `None`.
  """

  def _pad(x):
    """Prepends and appends a zero to every vector in a batch of vectors."""
    shape = array_ops.concat([array_ops.shape(x)[:-1], [1]], axis=0)
    z = array_ops.zeros(shape, dtype=x.dtype)
    return array_ops.concat([z, x, z], axis=-1)

  def _add(*x):
    """Adds list of Tensors, ignoring `None`."""
    s = None
    for y in x:
      if y is None:
        continue
      elif s is None:
        s = y
      else:
        s += y
    if s is None:
      raise ValueError("Must specify at least one of `below`, `diag`, `above`.")
    return s

  with ops.name_scope(name, "tridiag", [below, diag, above]):
    if below is not None:
      below = ops.convert_to_tensor(below, name="below")
      below = array_ops.matrix_diag(_pad(below))[..., :-1, 1:]
    if diag is not None:
      diag = ops.convert_to_tensor(diag, name="diag")
      diag = array_ops.matrix_diag(diag)
    if above is not None:
      above = ops.convert_to_tensor(above, name="above")
      above = array_ops.matrix_diag(_pad(above))[..., 1:, :-1]
    # TODO(jvdillon): Consider using scatter_nd instead of creating three full
    # matrices.
    return _add(below, diag, above)


def reduce_weighted_logsumexp(logx,
                              w=None,
                              axis=None,
                              keep_dims=False,
                              return_sign=False,
                              name=None):
  """Computes `log(abs(sum(weight * exp(elements across tensor dimensions))))`.

  If all weights `w` are known to be positive, it is more efficient to
  directly use `reduce_logsumexp`, i.e.,
  `tf.reduce_logsumexp(logx + tf.math.log(w))` is more efficient than
  `du.reduce_weighted_logsumexp(logx, w)`.

  Reduces `input_tensor` along the dimensions given in `axis`.
  Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
  entry in `axis`. If `keep_dims` is true, the reduced dimensions
  are retained with length 1.

  If `axis` has no entries, all dimensions are reduced, and a
  tensor with a single element is returned.

  This function is more numerically stable than log(sum(w * exp(input))). It
  avoids overflows caused by taking the exp of large inputs and underflows
  caused by taking the log of small inputs.

  For example:

  ```python
  x = tf.constant([[0., 0, 0],
                   [0, 0, 0]])

  w = tf.constant([[-1., 1, 1],
                   [1, 1, 1]])

  du.reduce_weighted_logsumexp(x, w)
  # ==> log(-1*1 + 1*1 + 1*1 + 1*1 + 1*1 + 1*1) = log(4)

  du.reduce_weighted_logsumexp(x, w, axis=0)
  # ==> [log(-1+1), log(1+1), log(1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1)
  # ==> [log(-1+1+1), log(1+1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True)
  # ==> [[log(-1+1+1)], [log(1+1+1)]]

  du.reduce_weighted_logsumexp(x, w, axis=[0, 1])
  # ==> log(-1+5)
  ```

  Args:
    logx: The tensor to reduce. Should have numeric type.
    w: The weight tensor. Should have numeric type identical to `logx`.
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(input_tensor),
      rank(input_tensor))`.
    keep_dims: If true, retains reduced dimensions with length 1.
    return_sign: If `True`, returns the sign of the result.
    name: A name for the operation (optional).

  Returns:
    lswe: The `log(abs(sum(weight * exp(x))))` reduced tensor.
    sign: (Optional) The sign of `sum(weight * exp(x))`.
  """
  with ops.name_scope(name, "reduce_weighted_logsumexp", [logx, w]):
    logx = ops.convert_to_tensor(logx, name="logx")
    if w is None:
      lswe = math_ops.reduce_logsumexp(logx, axis=axis, keepdims=keep_dims)
      if return_sign:
        sgn = array_ops.ones_like(lswe)
        return lswe, sgn
      return lswe
    w = ops.convert_to_tensor(w, dtype=logx.dtype, name="w")
    log_absw_x = logx + math_ops.log(math_ops.abs(w))
    max_log_absw_x = math_ops.reduce_max(log_absw_x, axis=axis, keepdims=True)
    # If the largest element is `-inf` or `inf` then we don't bother
    # subtracting off the max. We do this because otherwise we'd get
    # `inf - inf = NaN`. That this is ok follows from the fact that we're
    # actually free to subtract any value we like, so long as we add it back
    # after taking the `log(sum(...))`.
    max_log_absw_x = array_ops.where_v2(
        math_ops.is_inf(max_log_absw_x), array_ops.zeros_like(max_log_absw_x),
        max_log_absw_x)
    wx_over_max_absw_x = (
        math_ops.sign(w) * math_ops.exp(log_absw_x - max_log_absw_x))
    sum_wx_over_max_absw_x = math_ops.reduce_sum(
        wx_over_max_absw_x, axis=axis, keepdims=keep_dims)
    if not keep_dims:
      max_log_absw_x = array_ops.squeeze(max_log_absw_x, axis)
    sgn = math_ops.sign(sum_wx_over_max_absw_x)
    lswe = max_log_absw_x + math_ops.log(sgn * sum_wx_over_max_absw_x)
    if return_sign:
      return lswe, sgn
    return lswe


# TODO(jvdillon): Merge this test back into:
# tensorflow/python/ops/softplus_op_test.py
# once TF core is accepting new ops.
def softplus_inverse(x, name=None):
  """Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)).

  Mathematically this op is equivalent to:

  ```none
  softplus_inverse = log(exp(x) - 1.)
  ```

  Args:
    x: `Tensor`. Non-negative (not enforced), floating-point.
    name: A name for the operation (optional).

  Returns:
    `Tensor`. Has the same type/shape as input `x`.
  """
  with ops.name_scope(name, "softplus_inverse", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    # We begin by deriving a more numerically stable softplus_inverse:
    #   x = softplus(y) = Log[1 + exp{y}], (which means x > 0).
    #   ==> exp{x} = 1 + exp{y}                                (1)
    #   ==> y = Log[exp{x} - 1]                                (2)
    #         = Log[(exp{x} - 1) / exp{x}] + Log[exp{x}]
    #         = Log[(1 - exp{-x}) / 1] + Log[exp{x}]
    #         = Log[1 - exp{-x}] + x                           (3)
    # (2) is the "obvious" inverse, but (3) is more stable than (2) for large
    # x. For small x (e.g. x = 1e-10), (3) will become -inf since 1 - exp{-x}
    # will be zero. To fix this, we use 1 - exp{-x} approx x for small x > 0.
    #
    # In addition to the numerically stable derivation above, we clamp
    # small/large values to be congruent with the logic in:
    # tensorflow/core/kernels/softplus_op.h
    #
    # Finally, we set the input to one whenever the input is too large or too
    # small. This ensures that no unchosen codepath is +/- inf. This is
    # necessary to ensure the gradient doesn't get NaNs. Recall that the
    # gradient of `where` behaves like `pred*pred_true + (1-pred)*pred_false`
    # thus an `inf` in an unselected path results in `0*inf=nan`. We are
    # careful to overwrite `x` with ones only when we will never actually use
    # this value. Note that we use ones and not zeros since
    # `log(expm1(0.)) = -inf`.
    threshold = np.log(np.finfo(x.dtype.as_numpy_dtype).eps) + 2.
    is_too_small = math_ops.less(x, np.exp(threshold))
    is_too_large = math_ops.greater(x, -threshold)
    too_small_value = math_ops.log(x)
    too_large_value = x
    # This `where` will ultimately be a NOP because we won't select this
    # codepath whenever we used the surrogate `ones_like`.
    x = array_ops.where_v2(
        math_ops.logical_or(is_too_small, is_too_large),
        array_ops.ones_like(x), x)
    y = x + math_ops.log(-math_ops.expm1(-x))  # == log(expm1(x))
    return array_ops.where_v2(
        is_too_small, too_small_value,
        array_ops.where_v2(is_too_large, too_large_value, y))
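

# Illustrative check (not part of the original file): softplus(2.) =
# log(1 + e**2) ~= 2.1269, and softplus_inverse(2.1269...) recovers ~2.0; the
# stable form `log(-expm1(-x)) + x` equals `log(exp(x) - 1)` wherever the
# naive form does not overflow.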


# TODO(b/35290280): Add unit-tests.
def dimension_size(x, axis):
  """Returns the size of a specific dimension."""
  # Since tf.gather isn't "constant-in, constant-out", we must first check the
  # static shape or fallback to dynamic shape.
  s = tensor_shape.dimension_value(
      x.shape.with_rank_at_least(np.abs(axis))[axis])
  if s is not None:
    return s
  return array_ops.shape(x)[axis]


def process_quadrature_grid_and_probs(quadrature_grid_and_probs,
                                      dtype,
                                      validate_args,
                                      name=None):
  """Validates the quadrature grid and probs, or computes them if necessary.

  Args:
    quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
      representing the sample points and the corresponding (possibly
      normalized) weight. When `None`, defaults to:
      `np.polynomial.hermite.hermgauss(deg=8)`.
    dtype: The expected `dtype` of `grid` and `probs`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
      representing the sample points and the corresponding (possibly
      normalized) weight.

  Raises:
    ValueError: if `quadrature_grid_and_probs is not None` and
      `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
  """
  with ops.name_scope(name, "process_quadrature_grid_and_probs",
                      [quadrature_grid_and_probs]):
    if quadrature_grid_and_probs is None:
      grid, probs = np.polynomial.hermite.hermgauss(deg=8)
      grid = grid.astype(dtype.as_numpy_dtype)
      probs = probs.astype(dtype.as_numpy_dtype)
      probs /= np.linalg.norm(probs, ord=1, keepdims=True)
      grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
      probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
      return grid, probs

    grid, probs = tuple(quadrature_grid_and_probs)
    grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
    probs = ops.convert_to_tensor(probs, name="unnormalized_probs", dtype=dtype)
    probs /= linalg_ops.norm(probs, ord=1, axis=-1, keepdims=True, name="probs")

    def _static_event_size(x):
      """Returns the static size of a specific dimension or `None`."""
      return tensor_shape.dimension_value(x.shape.with_rank_at_least(1)[-1])

    m, n = _static_event_size(probs), _static_event_size(grid)
    if m is not None and n is not None:
      if m != n:
        raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
                         "same-length zero-th-dimension `Tensor`s "
                         "(saw lengths {}, {})".format(m, n))
    elif validate_args:
      assertions = [
          check_ops.assert_equal(
              dimension_size(probs, axis=-1),
              dimension_size(grid, axis=-1),
              message=("`quadrature_grid_and_probs` must be a `tuple` of "
                       "same-length zero-th-dimension `Tensor`s")),
      ]
      with ops.control_dependencies(assertions):
        grid = array_ops.identity(grid)
        probs = array_ops.identity(probs)
    return grid, probs


def pad(x, axis, front=False, back=False, value=0, count=1, name=None):
  """Pads `value` to the front and/or back of a `Tensor` dim, `count` times.

  Args:
    x: `Tensor` input.
    axis: Scalar `int`-like `Tensor` representing the single dimension to pad.
      (Negative indexing is supported.)
    front: Python `bool`; if `True` the beginning of the `axis` dimension is
      padded with `value`, `count` times. If `False` no front padding is made.
    back: Python `bool`; if `True` the end of the `axis` dimension is padded
      with `value`, `count` times. If `False` no end padding is made.
    value: Scalar `int`-like `Tensor` representing the actual value added to
      the front and/or back of the `axis` dimension of `x`.
    count: Scalar `int`-like `Tensor` representing number of elements added to
      the front and/or back of the `axis` dimension of `x`. E.g., if
      `front = back = True` then `2 * count` elements are added.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    pad: The padded version of input `x`.

  Raises:
    ValueError: if both `front` and `back` are `False`.
    TypeError: if `count` is not `int`-like.
  """
  with ops.name_scope(name, "pad", [x, value, count]):
    x = ops.convert_to_tensor(x, name="x")
    value = ops.convert_to_tensor(value, dtype=x.dtype, name="value")
    count = ops.convert_to_tensor(count, name="count")
    if not count.dtype.is_integer:
      raise TypeError("`count.dtype` (`{}`) must be `int`-like.".format(
          count.dtype.name))
    if not front and not back:
      raise ValueError("At least one of `front`, `back` must be `True`.")
    ndims = (
        x.shape.ndims if x.shape.ndims is not None else array_ops.rank(
            x, name="ndims"))
    axis = ops.convert_to_tensor(axis, name="axis")
    axis_ = tensor_util.constant_value(axis)
    if axis_ is not None:
      axis = axis_
      if axis < 0:
        axis = ndims + axis
      count_ = tensor_util.constant_value(count)
      if axis_ >= 0 or x.shape.ndims is not None:
        head = x.shape[:axis]
        middle = tensor_shape.TensorShape(None if count_ is None else (
            tensor_shape.dimension_at_index(x.shape, axis) + count_ *
            (front + back)))
        tail = x.shape[axis + 1:]
        final_shape = head.concatenate(middle.concatenate(tail))
      else:
        final_shape = None
    else:
      axis = array_ops.where_v2(axis < 0, ndims + axis, axis)
      final_shape = None
    x = array_ops.pad(
        x,
        paddings=array_ops.one_hot(
            indices=array_ops.stack(
                [axis if front else -1, axis if back else -1]),
            depth=ndims,
            axis=0,
            on_value=count,
            dtype=dtypes.int32),
        constant_values=value)
    if final_shape is not None:
      x.set_shape(final_shape)
    return x
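

# Worked illustration (assumed example values, not from the original source)
# of how the `one_hot` call above builds the `paddings` argument to `tf.pad`:
# with ndims=3, axis=1, front=True, back=True, count=2 the stacked indices are
# [1, 1] and
#
#   one_hot([1, 1], depth=3, axis=0, on_value=2)  # ==> [[0, 0],
#                                                 #      [2, 2],
#                                                 #      [0, 0]]
#
# i.e. pad two elements before and after dimension 1 only; a `False` `front`
# or `back` becomes index -1, whose one-hot column is all zeros.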


def parent_frame_arguments():
  """Returns parent frame arguments.

  When called inside a function, returns a dictionary with the caller's
  function arguments. These are positional arguments and keyword arguments
  (**kwargs), while variable arguments (*varargs) are excluded.

  When called at global scope, this will return an empty dictionary, since
  there are no arguments.

  WARNING: If caller function argument names are reassigned before invoking
  this method, then values will reflect the reassigned values. For this
  reason, we recommend calling `parent_frame_arguments` at the beginning of
  the function.
  """
  # All arguments and the names used for *varargs, and **kwargs
  arg_names, variable_arg_name, keyword_arg_name, local_vars = (
      tf_inspect._inspect.getargvalues(  # pylint: disable=protected-access
          # Get the first frame of the caller of this method.
          tf_inspect._inspect.stack()[1][0]))  # pylint: disable=protected-access

  # Remove the *varargs, and flatten the **kwargs. Both are
  # nested lists.
  local_vars.pop(variable_arg_name, {})
  keyword_args = local_vars.pop(keyword_arg_name, {})

  final_args = {}
  # Copy over arguments and their values. In general, local_vars
  # may contain more than just the arguments, since this method
  # can be called anywhere in a function.
  for arg_name in arg_names:
    final_args[arg_name] = local_vars.pop(arg_name)
  final_args.update(keyword_args)

  return final_args


class AppendDocstring(object):
  """Helper class to promote private subclass docstring to public counterpart.

  Example:

  ```python
  class TransformedDistribution(Distribution):
    @distribution_util.AppendDocstring(
        additional_note="A special note!",
        kwargs_dict={"foo": "An extra arg."})
    def _prob(self, y, foo=None):
      pass
  ```

  In this case, the `AppendDocstring` decorator appends the `additional_note`
  to the docstring of `prob` (not `_prob`) and adds a new `kwargs` section
  with each dictionary item as a bullet-point.

  For a more detailed example, see `TransformedDistribution`.
  """

  def __init__(self, additional_note="", kwargs_dict=None):
    """Initializes the AppendDocstring object.

    Args:
      additional_note: Python string added as additional docstring to public
        version of function.
      kwargs_dict: Python string/string dictionary representing specific
        kwargs expanded from the **kwargs input.

    Raises:
      ValueError: if kwargs_dict.key contains whitespace.
      ValueError: if kwargs_dict.value contains newlines.
    """
    self._additional_note = additional_note
    if kwargs_dict:
      bullets = []
      for key in sorted(kwargs_dict.keys()):
        value = kwargs_dict[key]
        if any(x.isspace() for x in key):
          raise ValueError("Parameter name \"%s\" contains whitespace." % key)
        value = value.lstrip()
        if "\n" in value:
          raise ValueError(
              "Parameter description for \"%s\" contains newlines." % key)
        bullets.append("* `%s`: %s" % (key, value))
      self._additional_note += ("\n\n##### `kwargs`:\n\n" + "\n".join(bullets))

  def __call__(self, fn):

    @functools.wraps(fn)
    def _fn(*args, **kwargs):
      return fn(*args, **kwargs)

    if _fn.__doc__ is None:
      _fn.__doc__ = self._additional_note
    else:
      _fn.__doc__ += "\n%s" % self._additional_note
    return _fn
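

# Hypothetical usage sketch of `parent_frame_arguments` above (illustrative,
# not part of the original module):
#
#   def f(a, b=2, **kwargs):
#     return parent_frame_arguments()
#
#   f(1, c=3)  # ==> {"a": 1, "b": 2, "c": 3}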