# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for probability distributions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import hashlib

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util import tf_inspect


def assert_integer_form(x,
                        data=None,
                        summarize=None,
                        message=None,
                        int_dtype=None,
                        name="assert_integer_form"):
  """Asserts that x has integer components (or floats equal to integers).

  Args:
    x: Floating-point `Tensor`.
    data: The tensors to print out if the condition is `False`. Defaults to
      the error message and the first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    int_dtype: A `tf.dtype` used to cast the float to. The default (`None`)
      implies the smallest possible signed int will be used for casting.
    name: A name for this operation (optional).

  Returns:
    Op raising `InvalidArgumentError` if `cast(x, int_dtype) != x`.
  """
  with ops.name_scope(name, values=[x, data]):
    x = ops.convert_to_tensor(x, name="x")
    if x.dtype.is_integer:
      return control_flow_ops.no_op()
    message = message or "{} has non-integer components".format(x)
    if int_dtype is None:
      try:
        int_dtype = {
            dtypes.float16: dtypes.int16,
            dtypes.float32: dtypes.int32,
            dtypes.float64: dtypes.int64,
        }[x.dtype.base_dtype]
      except KeyError:
        raise TypeError("Unrecognized type {}".format(x.dtype.name))
    return check_ops.assert_equal(
        x,
        math_ops.cast(math_ops.cast(x, int_dtype), x.dtype),
        data=data,
        summarize=summarize,
        message=message,
        name=name)


def assert_symmetric(matrix):
  matrix_t = array_ops.matrix_transpose(matrix)
  return control_flow_ops.with_dependencies(
      [check_ops.assert_equal(matrix, matrix_t)], matrix)


def embed_check_nonnegative_integer_form(
    x, name="embed_check_nonnegative_integer_form"):
  """Asserts x is a non-negative tensor with integer-valued entries."""
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    assertions = [
        check_ops.assert_non_negative(
            x, message="'{}' must be non-negative.".format(x)),
    ]
    if not x.dtype.is_integer:
      assertions += [
          assert_integer_form(
              x,
              message="'{}' cannot contain fractional components.".format(x)),
      ]
    return control_flow_ops.with_dependencies(assertions, x)


def same_dynamic_shape(a, b):
  """Returns whether a and b have the same dynamic shape.

  Args:
    a: `Tensor`
    b: `Tensor`

  Returns:
    `bool` `Tensor` representing if both tensors have the same shape.
  """
  a = ops.convert_to_tensor(a, name="a")
  b = ops.convert_to_tensor(b, name="b")

  # Here we can't just do math_ops.equal(a.shape, b.shape), since
  # static shape inference may break the equality comparison between
  # shape(a) and shape(b) in math_ops.equal.
  def all_shapes_equal():
    return math_ops.reduce_all(
        math_ops.equal(
            array_ops.concat(
                [array_ops.shape(a), array_ops.shape(b)], 0),
            array_ops.concat(
                [array_ops.shape(b), array_ops.shape(a)], 0)))

  # One of the shapes isn't fully defined, so we need to use the dynamic
  # shape.
  return control_flow_ops.cond(
      math_ops.equal(array_ops.rank(a), array_ops.rank(b)),
      all_shapes_equal, lambda: constant_op.constant(False))


def maybe_get_static_value(x, dtype=None):
  """Helper which tries to return a static value.

  Given `x`, extract its value statically, optionally casting to a specific
  dtype. If this is not possible, None is returned.

  Args:
    x: `Tensor` for which to extract a value statically.
    dtype: Optional dtype to cast to.

  Returns:
    Statically inferred value if possible, otherwise None.
  """
  if x is None:
    return x
  try:
    # This returns an np.ndarray.
    x_ = tensor_util.constant_value(x)
  except TypeError:
    x_ = x
  if x_ is None or dtype is None:
    return x_
  return np.array(x_, dtype)
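

# Illustrative sketch of `maybe_get_static_value` (assuming TF1-style graph
# construction; values shown are the expected results, not captured output):
#
#   maybe_get_static_value(tf.constant(3), dtype=np.float32)    # ==> 3.0
#   maybe_get_static_value(tf.compat.v1.placeholder(tf.int32))  # ==> None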


def get_logits_and_probs(logits=None,
                         probs=None,
                         multidimensional=False,
                         validate_args=False,
                         name="get_logits_and_probs",
                         dtype=None):
  """Converts logits to probabilities (or vice-versa), and returns both.

  Args:
    logits: Floating-point `Tensor` representing log-odds.
    probs: Floating-point `Tensor` representing probabilities.
    multidimensional: Python `bool`, default `False`. When `True`, `logits` or
      `probs` is a `[N1, N2, ..., k]`-shaped tensor whose last dimension
      indexes the `k = shape[-1]` classes.
    validate_args: Python `bool`, default `False`. When `True`, either assert
      `0 <= probs <= 1` (if not `multidimensional`) or that the last dimension
      of `probs` sums to one.
    name: A name for this operation (optional).
    dtype: `tf.DType` to prefer when converting args to `Tensor`s.

  Returns:
    logits, probs: Tuple of `Tensor`s. If `probs` has an entry that is `0` or
      `1`, then the corresponding entry in the returned logit will be `-Inf`
      and `Inf` respectively.

  Raises:
    ValueError: if neither `probs` nor `logits` were passed in, or both were.
  """
  with ops.name_scope(name, values=[probs, logits]):
    if (probs is None) == (logits is None):
      raise ValueError("Must pass probs or logits, but not both.")

    if probs is None:
      logits = ops.convert_to_tensor(logits, name="logits", dtype=dtype)
      if not logits.dtype.is_floating:
        raise TypeError("logits must have floating type.")
      # We can early return since we constructed probs and therefore know
      # they're valid.
      if multidimensional:
        if validate_args:
          logits = embed_check_categorical_event_shape(logits)
        return logits, nn.softmax(logits, name="probs")
      return logits, math_ops.sigmoid(logits, name="probs")

    probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
    if not probs.dtype.is_floating:
      raise TypeError("probs must have floating type.")

    if validate_args:
      with ops.name_scope("validate_probs"):
        one = constant_op.constant(1., probs.dtype)
        dependencies = [check_ops.assert_non_negative(probs)]
        if multidimensional:
          probs = embed_check_categorical_event_shape(probs)
          dependencies += [
              check_ops.assert_near(
                  math_ops.reduce_sum(probs, -1),
                  one,
                  message="probs does not sum to 1.")
          ]
        else:
          dependencies += [
              check_ops.assert_less_equal(
                  probs, one, message="probs has components greater than 1.")
          ]
        probs = control_flow_ops.with_dependencies(dependencies, probs)

    with ops.name_scope("logits"):
      if multidimensional:
        # In the multidimensional case we deliberately do not subtract a pivot
        # class, i.e., we return log(probs) rather than
        # log(probs) - log(probs[pivot]), following the TF softmax convention.
        # A side effect of being consistent with the TF approach is that the
        # unidimensional case implicitly handles the second class while the
        # multidimensional case explicitly keeps the pivot dimension.
        return math_ops.log(probs), probs
      return math_ops.log(probs) - math_ops.log1p(-1. * probs), probs
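

# A minimal usage sketch for `get_logits_and_probs` (values are the
# mathematical results, rounded; not captured output):
#
#   logits, probs = get_logits_and_probs(probs=[0.5, 0.731])
#   # logits ~= [0., 1.]   since logit(p) = log(p) - log1p(-p)
#
#   logits, probs = get_logits_and_probs(
#       logits=[[1., 2., 3.]], multidimensional=True)
#   # probs ~= softmax([1., 2., 3.]) ~= [[0.090, 0.245, 0.665]]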


def _is_known_unsigned_by_dtype(dt):
  """Helper returning True if dtype is known to be unsigned."""
  return {
      dtypes.bool: True,
      dtypes.uint8: True,
      dtypes.uint16: True,
  }.get(dt.base_dtype, False)


def _is_known_signed_by_dtype(dt):
  """Helper returning True if dtype is known to be signed."""
  return {
      dtypes.float16: True,
      dtypes.float32: True,
      dtypes.float64: True,
      dtypes.int8: True,
      dtypes.int16: True,
      dtypes.int32: True,
      dtypes.int64: True,
  }.get(dt.base_dtype, False)


def _is_known_dtype(dt):
  """Helper returning True if dtype is known."""
  return _is_known_unsigned_by_dtype(dt) or _is_known_signed_by_dtype(dt)


def _largest_integer_by_dtype(dt):
  """Helper returning the largest integer exactly representable by dtype."""
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  if dt.is_floating:
    return int(2**(np.finfo(dt.as_numpy_dtype).nmant + 1))
  if dt.is_integer:
    return np.iinfo(dt.as_numpy_dtype).max
  if dt.base_dtype == dtypes.bool:
    return int(1)
  # We actually can't land here but keep the case for completeness.
  raise TypeError("Unrecognized dtype: {}".format(dt.name))


def _smallest_integer_by_dtype(dt):
  """Helper returning the smallest integer exactly representable by dtype."""
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  if _is_known_unsigned_by_dtype(dt):
    return 0
  return -1 * _largest_integer_by_dtype(dt)


def _is_integer_like_by_dtype(dt):
  """Helper returning True if dtype.is_integer or is `bool`."""
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  return dt.is_integer or dt.base_dtype == dtypes.bool


def embed_check_categorical_event_shape(
    categorical_param, name="embed_check_categorical_event_shape"):
  """Embeds checks that categorical distributions don't have too many classes.

  A categorical-type distribution is one which, e.g., returns the class label
  rather than a one-hot encoding. E.g., `Categorical(probs)`.

  Since distributions output samples in the same dtype as the parameters, we
  must ensure that casting doesn't lose precision. That is, the
  `parameter.dtype` implies a maximum number of classes. However, since shape
  is `int32` and categorical variables are presumed to be indexes into a
  `Tensor`, we must also ensure that the number of classes is no larger than
  the largest possible `int32` index, i.e., `2**31-1`.

  In other words the number of classes, `K`, must satisfy the following
  condition:

  ```python
  K <= min(
      int(2**31 - 1),  # Largest `int32` index.
      {
          dtypes.float16: int(2**11),  # Largest int as a float16.
          dtypes.float32: int(2**24),
          dtypes.float64: int(2**53),
      }.get(categorical_param.dtype.base_dtype, 0))
  ```

  Args:
    categorical_param: Floating-point `Tensor` representing parameters of
      distribution over categories. The rightmost shape is presumed to be the
      number of categories.
    name: A name for this operation (optional).

  Returns:
    categorical_param: Input `Tensor` with appropriate assertions embedded.

  Raises:
    TypeError: if `categorical_param` has an unknown `dtype`.
    ValueError: if we can statically identify `categorical_param` as being too
      large (for being closed under int32/float casting).
  """
  with ops.name_scope(name, values=[categorical_param]):
    x = ops.convert_to_tensor(categorical_param, name="categorical_param")
    # The size must not exceed either of:
    # - The largest possible int32 (since categorical values are presumed to
    #   be indexes into a Tensor).
    # - The largest possible integer exactly representable under the given
    #   floating-point dtype (since we need to cast to/from).
    #
    # The chosen floating-point thresholds are 2**(1 + mantissa_bits).
    # For more details, see:
    # https://en.wikipedia.org/wiki/Floating-point_arithmetic#Internal_representation
    x_dtype = x.dtype.base_dtype
    max_event_size = (
        _largest_integer_by_dtype(x_dtype) if x_dtype.is_floating else 0)
    if max_event_size == 0:
      raise TypeError("Unable to validate size of unrecognized dtype "
                      "({}).".format(x_dtype.name))
    try:
      x_shape_static = x.get_shape().with_rank_at_least(1)
    except ValueError:
      raise ValueError("A categorical-distribution parameter must have "
                       "at least 1 dimension.")
    if tensor_shape.dimension_value(x_shape_static[-1]) is not None:
      event_size = x_shape_static.dims[-1].value
      if event_size < 2:
        raise ValueError("A categorical-distribution parameter must have at "
                         "least 2 events.")
      if event_size > max_event_size:
        raise ValueError("Number of classes exceeds `dtype` precision, i.e., "
                         "{} implies shape ({}) cannot exceed {}.".format(
                             x_dtype.name, event_size, max_event_size))
      return x
    else:
      event_size = array_ops.shape(x, name="x_shape")[-1]
      return control_flow_ops.with_dependencies([
          check_ops.assert_rank_at_least(
              x,
              1,
              message=("A categorical-distribution parameter must have "
                       "at least 1 dimension.")),
          check_ops.assert_greater_equal(
              array_ops.shape(x)[-1],
              2,
              message=("A categorical-distribution parameter must have at "
                       "least 2 events.")),
          check_ops.assert_less_equal(
              event_size,
              max_event_size,
              message="Number of classes exceeds `dtype` precision, "
              "i.e., {} dtype cannot represent more than {} classes.".format(
                  x_dtype.name, max_event_size)),
      ], x)


def embed_check_integer_casting_closed(x,
                                       target_dtype,
                                       assert_nonnegative=True,
                                       name="embed_check_casting_closed"):
  """Ensures integers remain unaffected despite casting to/from int/float types.

  Example integer-types: `uint8`, `int32`, `bool`.
  Example floating-types: `float32`, `float64`.

  The largest possible integer representable by an IEEE754 floating-point is
  `2**(1 + mantissa_bits)` yet the largest possible integer as an int-type is
  `2**(bits - 1) - 1`. This function ensures that a `Tensor` purporting to
  have integer-form values can be cast to some other type without loss of
  precision.

  The smallest representable integer is the negative of the largest
  representable integer, except for types: `uint8`, `uint16`, `bool`. For
  these types, the smallest representable integer is `0`.

  Args:
    x: `Tensor` representing integer-form values.
    target_dtype: TF `dtype` under which `x` should have identical values.
    assert_nonnegative: `bool` indicating `x` should contain nonnegative
      values.
    name: A name for this operation (optional).

  Returns:
    x: Input `Tensor` with appropriate assertions embedded.

  Raises:
    TypeError: if `x` is neither integer- nor floating-type.
    TypeError: if `target_dtype` is neither integer- nor floating-type.
    TypeError: if neither `x` nor `target_dtype` are integer-type.
  """

  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if (not _is_integer_like_by_dtype(x.dtype) and not x.dtype.is_floating):
      raise TypeError("{}.dtype must be floating- or "
                      "integer-type.".format(x.dtype.name))
    if (not _is_integer_like_by_dtype(target_dtype) and
        not target_dtype.is_floating):
      raise TypeError("target_dtype ({}) must be floating- or "
                      "integer-type.".format(target_dtype.name))
    if (not _is_integer_like_by_dtype(x.dtype) and
        not _is_integer_like_by_dtype(target_dtype)):
      raise TypeError("At least one of {}.dtype ({}) and target_dtype ({}) "
                      "must be integer-type.".format(x, x.dtype.name,
                                                     target_dtype.name))

    assertions = []
    if assert_nonnegative:
      assertions += [
          check_ops.assert_non_negative(
              x, message="Elements must be non-negative."),
      ]

    if x.dtype.is_floating:
      # Being here means _is_integer_like_by_dtype(target_dtype) = True.
      # Since this check implies the magnitude check below, we need only it.
      assertions += [
          assert_integer_form(
              x,
              int_dtype=target_dtype,
              message="Elements must be {}-equivalent.".format(
                  target_dtype.name)),
      ]
    else:
      if (_largest_integer_by_dtype(x.dtype) >
          _largest_integer_by_dtype(target_dtype)):
        # Cast may lose integer precision.
        assertions += [
            check_ops.assert_less_equal(
                x,
                _largest_integer_by_dtype(target_dtype),
                message=("Elements cannot exceed {}.".format(
                    _largest_integer_by_dtype(target_dtype)))),
        ]
      if (not assert_nonnegative and (_smallest_integer_by_dtype(
          x.dtype) < _smallest_integer_by_dtype(target_dtype))):
        assertions += [
            check_ops.assert_greater_equal(
                x,
                _smallest_integer_by_dtype(target_dtype),
                message=("Elements cannot be smaller than {}.".format(
                    _smallest_integer_by_dtype(target_dtype)))),
        ]

    if not assertions:
      return x
    return control_flow_ops.with_dependencies(assertions, x)
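

# Rough intuition for the precision bounds used above (a sketch, not part of
# the library API): the largest integer exactly representable by a float
# dtype is 2**(1 + mantissa_bits), so for example:
#
#   _largest_integer_by_dtype(dtypes.float16)  # ==> 2**11 == 2048
#   _largest_integer_by_dtype(dtypes.float32)  # ==> 2**24 == 16777216
#   _largest_integer_by_dtype(dtypes.int8)     # ==> 127
#
# `embed_check_integer_casting_closed(x, target_dtype=dtypes.int16)` would
# therefore assert that every element of a float `x` is integer-valued and
# small enough to survive the round trip through `int16`.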


def log_combinations(n, counts, name="log_combinations"):
  """Multinomial coefficient.

  Given `n` and `counts`, where `counts` has last dimension `k`, we compute
  the multinomial coefficient as:

  ```n! / prod_i n_i!```

  where `i` runs over all `k` classes.

  Args:
    n: Floating-point `Tensor` broadcastable with `counts`. This represents
      `n` outcomes.
    counts: Floating-point `Tensor` broadcastable with `n`. This represents
      counts in `k` classes, where `k` is the last dimension of the tensor.
    name: A name for this operation (optional).

  Returns:
    `Tensor` representing the log of the multinomial coefficient between `n`
    and `counts`.
  """
  # First a bit about the number of ways counts could have come in:
  # E.g. if counts = [1, 2], then this is 3 choose 2.
  # In general, this is (sum counts)! / prod(counts!)
  # The sum should be along the last dimension of counts. This is the
  # "distribution" dimension. Here n a priori represents the sum of counts.
  with ops.name_scope(name, values=[n, counts]):
    n = ops.convert_to_tensor(n, name="n")
    counts = ops.convert_to_tensor(counts, name="counts")
    total_permutations = math_ops.lgamma(n + 1)
    counts_factorial = math_ops.lgamma(counts + 1)
    redundant_permutations = math_ops.reduce_sum(counts_factorial, axis=[-1])
    return total_permutations - redundant_permutations
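

# A quick numeric sanity check for `log_combinations` (a sketch; values are
# the mathematical results, shown to a few digits):
#
#   log_combinations(n=3., counts=[1., 2.])
#   # ==> log(3! / (1! * 2!)) = log(3) ~= 1.0986
#
#   log_combinations(n=4., counts=[2., 2.])
#   # ==> log(4! / (2! * 2!)) = log(6) ~= 1.7918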


def matrix_diag_transform(matrix, transform=None, name=None):
  """Transform diagonal of [batch-]matrix, leave rest of matrix unchanged.

  Create a trainable covariance defined by a Cholesky factor:

  ```python
  # Transform network layer into 2 x 2 array.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))

  # Make the diagonal positive. If the upper triangle was zero, this would be
  # a valid Cholesky factor.
  chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # LinearOperatorLowerTriangular ignores the upper triangle.
  operator = LinearOperatorLowerTriangular(chol)
  ```

  Example of heteroskedastic 2-D linear regression.

  ```python
  tfd = tfp.distributions

  # Get a trainable Cholesky factor.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
  chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # Get a trainable mean.
  mu = tf.contrib.layers.fully_connected(activations, 2)

  # This is a fully trainable multivariate normal!
  dist = tfd.MultivariateNormalTriL(mu, chol)

  # Standard log loss. Minimizing this will "train" mu and chol, and then
  # dist will be a distribution predicting labels as multivariate Gaussians.
  loss = -1 * tf.reduce_mean(dist.log_prob(labels))
  ```

  Args:
    matrix: Rank `R` `Tensor`, `R >= 2`, where the last two dimensions are
      equal.
    transform: Element-wise function mapping `Tensors` to `Tensors`. To be
      applied to the diagonal of `matrix`. If `None`, `matrix` is returned
      unchanged. Defaults to `None`.
    name: A name to give created ops. Defaults to "matrix_diag_transform".

  Returns:
    A `Tensor` with same shape and `dtype` as `matrix`.
  """
  with ops.name_scope(name, "matrix_diag_transform", [matrix]):
    matrix = ops.convert_to_tensor(matrix, name="matrix")
    if transform is None:
      return matrix
    # Replace the diag with transformed diag.
    diag = array_ops.matrix_diag_part(matrix)
    transformed_diag = transform(diag)
    transformed_mat = array_ops.matrix_set_diag(matrix, transformed_diag)

  return transformed_mat


def rotate_transpose(x, shift, name="rotate_transpose"):
  """Circularly moves dims left or right.

  Effectively identical to:

  ```python
  numpy.transpose(x, numpy.roll(numpy.arange(len(x.shape)), shift))
  ```

  When the rank of `x` or the value of `shift` cannot be inferred statically,
  additional graph-runtime ops are used; these may entail moving data from
  GPU to CPU.

  Example:

  ```python
  x = tf.random.normal([1, 2, 3, 4])  # Tensor of shape [1, 2, 3, 4].
  rotate_transpose(x, -1).shape == [2, 3, 4, 1]
  rotate_transpose(x, -2).shape == [3, 4, 1, 2]
  rotate_transpose(x,  1).shape == [4, 1, 2, 3]
  rotate_transpose(x,  2).shape == [3, 4, 1, 2]
  rotate_transpose(x,  7).shape == rotate_transpose(x, 3).shape  # [2, 3, 4, 1]
  rotate_transpose(x, -7).shape == rotate_transpose(x, -3).shape  # [4, 1, 2, 3]
  ```

  Args:
    x: `Tensor`.
    shift: `Tensor`. Number of dimensions to transpose left (shift<0) or
      transpose right (shift>0).
    name: Python `str`. The name to give this op.

  Returns:
    rotated_x: Input `Tensor` with dimensions circularly rotated by shift.

  Raises:
    TypeError: if shift is not integer type.
  """
  with ops.name_scope(name, values=[x, shift]):
    x = ops.convert_to_tensor(x, name="x")
    shift = ops.convert_to_tensor(shift, name="shift")
    # We do not assign back to preserve constant-ness.
    check_ops.assert_integer(shift)
    shift_value_static = tensor_util.constant_value(shift)
    ndims = x.get_shape().ndims
    if ndims is not None and shift_value_static is not None:
      if ndims < 2:
        return x
      shift_value_static = np.sign(shift_value_static) * (
          abs(shift_value_static) % ndims)
      if shift_value_static == 0:
        return x
      perm = np.roll(np.arange(ndims), shift_value_static)
      return array_ops.transpose(x, perm=perm)
    else:
      # Consider if we always had a positive shift, and some specified
      # direction.
      # When shifting left we want the new array:
      #   last(x, n-shift) + first(x, shift)
      # and if shifting right then we want:
      #   last(x, shift) + first(x, n-shift)
      # Observe that last(a) == slice(a, n) and first(a) == slice(0, a).
      # Also, we can encode direction and shift as one: direction * shift.
      # Combining these facts, we have:
      #   a = cond(shift<0, -shift, n-shift)
      #   last(x, n-a) + first(x, a) == x[a:n] + x[0:a]
      # Finally, we transform shift by modulo length so it can be specified
      # independently from the array upon which it operates (like python).
      ndims = array_ops.rank(x)
      shift = array_ops.where_v2(
          math_ops.less(shift, 0), math_ops.mod(-shift, ndims),
          ndims - math_ops.mod(shift, ndims))
      first = math_ops.range(0, shift)
      last = math_ops.range(shift, ndims)
      perm = array_ops.concat([last, first], 0)
      return array_ops.transpose(x, perm=perm)


def pick_vector(cond, true_vector, false_vector, name="pick_vector"):
  """Picks possibly different length row `Tensor`s based on condition.

  Value `Tensor`s should have exactly one dimension.

  If `cond` is a python Boolean or `tf.constant` then either `true_vector` or
  `false_vector` is immediately returned. I.e., no graph nodes are created
  and no validation happens.

  Example:

  ```python
  pick_vector(tf.less(0, 5), tf.range(10, 12), tf.range(15, 18))  # [10, 11]
  pick_vector(tf.less(5, 0), tf.range(10, 12), tf.range(15, 18))  # [15, 16, 17]
  ```

  Args:
    cond: `Tensor`. Must have `dtype=tf.bool` and be scalar.
    true_vector: `Tensor` of one dimension. Returned when cond is `True`.
    false_vector: `Tensor` of one dimension. Returned when cond is `False`.
    name: Python `str`. The name to give this op.

  Returns:
    true_or_false_vector: `Tensor`.

  Raises:
    TypeError: if `cond.dtype != tf.bool`
    TypeError: if `cond` is not a constant and
      `true_vector.dtype != false_vector.dtype`
  """
  with ops.name_scope(name, values=(cond, true_vector, false_vector)):
    cond = ops.convert_to_tensor(cond, name="cond")
    if cond.dtype != dtypes.bool:
      raise TypeError("%s.dtype=%s which is not %s" %
                      (cond, cond.dtype, dtypes.bool))
    cond_value_static = tensor_util.constant_value(cond)
    if cond_value_static is not None:
      return true_vector if cond_value_static else false_vector
    true_vector = ops.convert_to_tensor(true_vector, name="true_vector")
    false_vector = ops.convert_to_tensor(false_vector, name="false_vector")
    if true_vector.dtype != false_vector.dtype:
      raise TypeError(
          "%s.dtype=%s does not match %s.dtype=%s" %
          (true_vector, true_vector.dtype, false_vector, false_vector.dtype))
    n = array_ops.shape(true_vector)[0]
    return array_ops.slice(
        array_ops.concat([true_vector, false_vector], 0),
        [array_ops.where_v2(cond, 0, n)], [array_ops.where(cond, n, -1)])


def prefer_static_broadcast_shape(shape1,
                                  shape2,
                                  name="prefer_static_broadcast_shape"):
  """Convenience function which statically broadcasts shape when possible.

  Args:
    shape1: `1-D` integer `Tensor`. Already converted to tensor!
    shape2: `1-D` integer `Tensor`. Already converted to tensor!
    name: A string name to prepend to created ops.

  Returns:
    The broadcast shape, either as `TensorShape` (if broadcast can be done
    statically), or as a `Tensor`.
  """
  with ops.name_scope(name, values=[shape1, shape2]):

    def make_shape_tensor(x):
      return ops.convert_to_tensor(x, name="shape", dtype=dtypes.int32)

    def get_tensor_shape(s):
      if isinstance(s, tensor_shape.TensorShape):
        return s
      s_ = tensor_util.constant_value(make_shape_tensor(s))
      if s_ is not None:
        return tensor_shape.TensorShape(s_)
      return None

    def get_shape_tensor(s):
      if not isinstance(s, tensor_shape.TensorShape):
        return make_shape_tensor(s)
      if s.is_fully_defined():
        return make_shape_tensor(s.as_list())
      raise ValueError("Cannot broadcast from partially "
                       "defined `TensorShape`.")

    shape1_ = get_tensor_shape(shape1)
    shape2_ = get_tensor_shape(shape2)
    if shape1_ is not None and shape2_ is not None:
      return array_ops.broadcast_static_shape(shape1_, shape2_)

    shape1_ = get_shape_tensor(shape1)
    shape2_ = get_shape_tensor(shape2)
    return array_ops.broadcast_dynamic_shape(shape1_, shape2_)


def prefer_static_rank(x):
  """Return static rank of tensor `x` if available, else `tf.rank(x)`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static rank is obtainable), else `Tensor`.
  """
  return prefer_static_value(array_ops.rank(x))


def prefer_static_shape(x):
  """Return static shape of tensor `x` if available, else `tf.shape(x)`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static shape is obtainable), else `Tensor`.
  """
  return prefer_static_value(array_ops.shape(x))


def prefer_static_value(x):
  """Return static value of tensor `x` if available, else `x`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static value is obtainable), else `Tensor`.
  """
  static_x = tensor_util.constant_value(x)
  if static_x is not None:
    return static_x
  return x


def gen_new_seed(seed, salt):
  """Generate a new seed, from the given seed and salt."""
  if seed is None:
    return None
  string = (str(seed) + salt).encode("utf-8")
  return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF


def fill_triangular(x, upper=False, name=None):
  """Creates a (batch of) triangular matrix from a vector of inputs.

  Created matrix can be lower- or upper-triangular. (It is more efficient to
  create the matrix as upper or lower, rather than transpose.)

  Triangular matrix elements are filled in a clockwise spiral. See example,
  below.

  If `x.get_shape()` is `[b1, b2, ..., bB, d]` then the output shape is
  `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.

  Example:

  ```python
  fill_triangular([1, 2, 3, 4, 5, 6])
  # ==> [[4, 0, 0],
  #      [6, 5, 0],
  #      [3, 2, 1]]

  fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
  # ==> [[1, 2, 3],
  #      [0, 5, 6],
  #      [0, 0, 4]]
  ```

  For comparison, a pure numpy version of this function can be found in
  `util_test.py`, function `_fill_triangular`.

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether output matrix should be upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    tril: `Tensor` with lower (or upper) triangular elements filled from `x`.

  Raises:
    ValueError: if `x` cannot be mapped to a triangular matrix.
  """

  with ops.name_scope(name, "fill_triangular", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if tensor_shape.dimension_value(
        x.shape.with_rank_at_least(1)[-1]) is not None:
      # Formula derived by solving for n: m = n(n+1)/2.
      m = np.int32(x.shape.dims[-1].value)
      n = np.sqrt(0.25 + 2. * m) - 0.5
      if n != np.floor(n):
        raise ValueError("Input right-most shape ({}) does not "
                         "correspond to a triangular matrix.".format(m))
      n = np.int32(n)
      static_final_shape = x.shape[:-1].concatenate([n, n])
    else:
      m = array_ops.shape(x)[-1]
      # For derivation, see above. Casting automatically lops off the 0.5, so
      # we omit it. We don't validate n is an integer because this has
      # graph-execution cost; an error will be thrown from the reshape below.
      n = math_ops.cast(
          math_ops.sqrt(0.25 + math_ops.cast(2 * m, dtype=dtypes.float32)),
          dtype=dtypes.int32)
      static_final_shape = x.shape.with_rank_at_least(1)[:-1].concatenate(
          [None, None])
    # We now concatenate the "tail" of `x` to `x` (and reverse one of them).
    #
    # We do this based on the insight that the input `x` provides `ceil(n/2)`
    # rows of an `n x n` matrix, some of which will get zeroed out being on
    # the wrong side of the diagonal. The first row will not get zeroed out
    # at all, and we need `floor(n/2)` more rows, so the first is what we
    # omit from `x_tail`. If we then stack those `ceil(n/2)` rows with the
    # `floor(n/2)` rows provided by a reversed tail, it is exactly the other
    # set of elements of the reversed tail which will be zeroed out for being
    # on the wrong side of the diagonal further up/down the matrix.
    # And, in doing so, we've filled the triangular matrix in a clockwise
    # spiral pattern. Neat!
    #
    # Try it out in numpy:
    #  n = 3
    #  x = np.arange(n * (n + 1) / 2)
    #  m = x.shape[0]
    #  n = np.int32(np.sqrt(.25 + 2 * m) - .5)
    #  x_tail = x[(m - (n**2 - m)):]
    #  np.concatenate([x_tail, x[::-1]], 0).reshape(n, n)  # lower
    #  # ==> array([[3, 4, 5],
    #  #            [5, 4, 3],
    #  #            [2, 1, 0]])
    #  np.concatenate([x, x_tail[::-1]], 0).reshape(n, n)  # upper
    #  # ==> array([[0, 1, 2],
    #  #            [3, 4, 5],
    #  #            [5, 4, 3]])
    #
    # Note that we can't simply do `x[..., -(n**2 - m):]` because this doesn't
    # correctly handle `m == n == 1`. Hence, we do nonnegative indexing.
    # Furthermore observe that:
    #   m - (n**2 - m)
    #   = n**2 / 2 + n / 2 - (n**2 - n**2 / 2 + n / 2)
    #   = 2 (n**2 / 2 + n / 2) - n**2
    #   = n**2 + n - n**2
    #   = n
    ndims = prefer_static_rank(x)
    if upper:
      x_list = [x, array_ops.reverse(x[..., n:], axis=[ndims - 1])]
    else:
      x_list = [x[..., n:], array_ops.reverse(x, axis=[ndims - 1])]
    new_shape = (
        static_final_shape.as_list() if static_final_shape.is_fully_defined()
        else array_ops.concat([array_ops.shape(x)[:-1], [n, n]], axis=0))
    x = array_ops.reshape(array_ops.concat(x_list, axis=-1), new_shape)
    x = array_ops.matrix_band_part(
        x, num_lower=(0 if upper else -1), num_upper=(-1 if upper else 0))
    x.set_shape(static_final_shape)
    return x
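

# A small round-trip sketch relating `fill_triangular` and
# `fill_triangular_inverse` (values are the mathematical results):
#
#   m = fill_triangular([1., 2., 3., 4., 5., 6.])
#   # ==> [[4., 0., 0.],
#   #      [6., 5., 0.],
#   #      [3., 2., 1.]]
#   fill_triangular_inverse(m)
#   # ==> [1., 2., 3., 4., 5., 6.]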


def fill_triangular_inverse(x, upper=False, name=None):
  """Creates a vector from a (batch of) triangular matrix.

  The vector is created from the lower-triangular or upper-triangular portion
  depending on the value of the parameter `upper`.

  If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
  `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.

  Example:

  ```python
  fill_triangular_inverse(
    [[4, 0, 0],
     [6, 5, 0],
     [3, 2, 1]])

  # ==> [1, 2, 3, 4, 5, 6]

  fill_triangular_inverse(
    [[1, 2, 3],
     [0, 5, 6],
     [0, 0, 4]], upper=True)

  # ==> [1, 2, 3, 4, 5, 6]
  ```

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether output matrix should be upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
      (or upper) triangular elements from `x`.
  """

  with ops.name_scope(name, "fill_triangular_inverse", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if tensor_shape.dimension_value(
        x.shape.with_rank_at_least(2)[-1]) is not None:
      n = np.int32(x.shape.dims[-1].value)
      m = np.int32((n * (n + 1)) // 2)
      static_final_shape = x.shape[:-2].concatenate([m])
    else:
      n = array_ops.shape(x)[-1]
      m = (n * (n + 1)) // 2
      static_final_shape = x.shape.with_rank_at_least(2)[:-2].concatenate(
          [None])
    ndims = prefer_static_rank(x)
    if upper:
      initial_elements = x[..., 0, :]
      triangular_portion = x[..., 1:, :]
    else:
      initial_elements = array_ops.reverse(x[..., -1, :], axis=[ndims - 2])
      triangular_portion = x[..., :-1, :]
    rotated_triangular_portion = array_ops.reverse(
        array_ops.reverse(triangular_portion, axis=[ndims - 1]),
        axis=[ndims - 2])
    consolidated_matrix = triangular_portion + rotated_triangular_portion
    end_sequence = array_ops.reshape(
        consolidated_matrix,
        array_ops.concat([array_ops.shape(x)[:-2], [n * (n - 1)]], axis=0))
    y = array_ops.concat([initial_elements, end_sequence[..., :m - n]],
                         axis=-1)
    y.set_shape(static_final_shape)
    return y


def tridiag(below=None, diag=None, above=None, name=None):
  """Creates a matrix with values set above, below, and on the diagonal.

  Example:

  ```python
  tridiag(below=[1., 2., 3.],
          diag=[4., 5., 6., 7.],
          above=[8., 9., 10.])
  # ==> array([[  4.,   8.,   0.,   0.],
  #            [  1.,   5.,   9.,   0.],
  #            [  0.,   2.,   6.,  10.],
  #            [  0.,   0.,   3.,   7.]], dtype=float32)
  ```

  Warning: This Op is intended for convenience, not efficiency.

  Args:
    below: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the below
      diagonal part. `None` is logically equivalent to `below = 0`.
    diag: `Tensor` of shape `[B1, ..., Bb, d]` corresponding to the diagonal
      part. `None` is logically equivalent to `diag = 0`.
    above: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the above
      diagonal part. `None` is logically equivalent to `above = 0`.
    name: Python `str`. The name to give this op.

  Returns:
    tridiag: `Tensor` with values set above, below and on the diagonal.

  Raises:
    ValueError: if all inputs are `None`.
  """

  def _pad(x):
    """Prepends and appends a zero to every vector in a batch of vectors."""
    shape = array_ops.concat([array_ops.shape(x)[:-1], [1]], axis=0)
    z = array_ops.zeros(shape, dtype=x.dtype)
    return array_ops.concat([z, x, z], axis=-1)

  def _add(*x):
    """Adds list of Tensors, ignoring `None`."""
    s = None
    for y in x:
      if y is None:
        continue
      elif s is None:
        s = y
      else:
        s += y
    if s is None:
      raise ValueError("Must specify at least one of `below`, `diag`, "
                       "`above`.")
    return s

  with ops.name_scope(name, "tridiag", [below, diag, above]):
    if below is not None:
      below = ops.convert_to_tensor(below, name="below")
      below = array_ops.matrix_diag(_pad(below))[..., :-1, 1:]
    if diag is not None:
      diag = ops.convert_to_tensor(diag, name="diag")
      diag = array_ops.matrix_diag(diag)
    if above is not None:
      above = ops.convert_to_tensor(above, name="above")
      above = array_ops.matrix_diag(_pad(above))[..., 1:, :-1]
    # TODO(jvdillon): Consider using scatter_nd instead of creating three full
    # matrices.
    return _add(below, diag, above)


def reduce_weighted_logsumexp(logx,
                              w=None,
                              axis=None,
                              keep_dims=False,
                              return_sign=False,
                              name=None):
  """Computes `log(abs(sum(weight * exp(elements across tensor dimensions))))`.

  If all weights `w` are known to be positive, it is more efficient to
  directly use `reduce_logsumexp`, i.e.,
  `tf.reduce_logsumexp(logx + tf.math.log(w))` is more efficient than
  `du.reduce_weighted_logsumexp(logx, w)`.

  Reduces `input_tensor` along the dimensions given in `axis`.
  Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
  entry in `axis`. If `keep_dims` is true, the reduced dimensions
  are retained with length 1.

  If `axis` has no entries, all dimensions are reduced, and a
  tensor with a single element is returned.

  This function is more numerically stable than `log(sum(w * exp(input)))`.
  It avoids overflows caused by taking the exp of large inputs and underflows
  caused by taking the log of small inputs.

  For example:

  ```python
  x = tf.constant([[0., 0, 0],
                   [0, 0, 0]])

  w = tf.constant([[-1., 1, 1],
                   [1, 1, 1]])

  du.reduce_weighted_logsumexp(x, w)
  # ==> log(-1*1 + 1*1 + 1*1 + 1*1 + 1*1 + 1*1) = log(4)

  du.reduce_weighted_logsumexp(x, w, axis=0)
  # ==> [log(-1+1), log(1+1), log(1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1)
  # ==> [log(-1+1+1), log(1+1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True)
  # ==> [[log(-1+1+1)], [log(1+1+1)]]

  du.reduce_weighted_logsumexp(x, w, axis=[0, 1])
  # ==> log(-1+5)
  ```

  Args:
    logx: The tensor to reduce. Should have numeric type.
    w: The weight tensor. Should have numeric type identical to `logx`.
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(input_tensor),
      rank(input_tensor))`.
    keep_dims: If true, retains reduced dimensions with length 1.
    return_sign: If `True`, returns the sign of the result.
    name: A name for the operation (optional).

  Returns:
    lswe: The `log(abs(sum(weight * exp(x))))` reduced tensor.
    sign: (Optional) The sign of `sum(weight * exp(x))`.
  """
  with ops.name_scope(name, "reduce_weighted_logsumexp", [logx, w]):
    logx = ops.convert_to_tensor(logx, name="logx")
    if w is None:
      lswe = math_ops.reduce_logsumexp(logx, axis=axis, keepdims=keep_dims)
      if return_sign:
        sgn = array_ops.ones_like(lswe)
        return lswe, sgn
      return lswe
    w = ops.convert_to_tensor(w, dtype=logx.dtype, name="w")
    log_absw_x = logx + math_ops.log(math_ops.abs(w))
    max_log_absw_x = math_ops.reduce_max(log_absw_x, axis=axis, keepdims=True)
    # If the largest element is `-inf` or `inf` then we don't bother
    # subtracting off the max. We do this because otherwise we'd get
    # `inf - inf = NaN`. That this is ok follows from the fact that we're
    # actually free to subtract any value we like, so long as we add it back
    # after taking the `log(sum(...))`.
    max_log_absw_x = array_ops.where_v2(
        math_ops.is_inf(max_log_absw_x), array_ops.zeros_like(max_log_absw_x),
        max_log_absw_x)
    wx_over_max_absw_x = (
        math_ops.sign(w) * math_ops.exp(log_absw_x - max_log_absw_x))
    sum_wx_over_max_absw_x = math_ops.reduce_sum(
        wx_over_max_absw_x, axis=axis, keepdims=keep_dims)
    if not keep_dims:
      max_log_absw_x = array_ops.squeeze(max_log_absw_x, axis)
    sgn = math_ops.sign(sum_wx_over_max_absw_x)
    lswe = max_log_absw_x + math_ops.log(sgn * sum_wx_over_max_absw_x)
    if return_sign:
      return lswe, sgn
    return lswe


# TODO(jvdillon): Merge this test back into:
# tensorflow/python/ops/softplus_op_test.py
# once TF core is accepting new ops.
def softplus_inverse(x, name=None):
  """Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)).

  Mathematically this op is equivalent to:

  ```none
  softplus_inverse = log(exp(x) - 1.)
  ```

  Args:
    x: `Tensor`. Non-negative (not enforced), floating-point.
    name: A name for the operation (optional).

  Returns:
    `Tensor`. Has the same type/shape as input `x`.
  """
  with ops.name_scope(name, "softplus_inverse", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    # We begin by deriving a more numerically stable softplus_inverse:
    # x = softplus(y) = Log[1 + exp{y}], (which means x > 0).
    # ==> exp{x} = 1 + exp{y}                                (1)
    # ==> y = Log[exp{x} - 1]                                (2)
    #       = Log[(exp{x} - 1) / exp{x}] + Log[exp{x}]
    #       = Log[(1 - exp{-x}) / 1] + Log[exp{x}]
    #       = Log[1 - exp{-x}] + x                           (3)
    # (2) is the "obvious" inverse, but (3) is more stable than (2) for large
    # x. For small x (e.g. x = 1e-10), (3) will become -inf since 1 - exp{-x}
    # will be zero. To fix this, we use 1 - exp{-x} approx x for small x > 0.
    #
    # In addition to the numerically stable derivation above, we clamp
    # small/large values to be congruent with the logic in:
    # tensorflow/core/kernels/softplus_op.h
    #
    # Finally, we set the input to one whenever the input is too large or too
    # small. This ensures that no unchosen codepath is +/- inf. This is
    # necessary to ensure the gradient doesn't get NaNs. Recall that the
    # gradient of `where` behaves like `pred*pred_true + (1-pred)*pred_false`
    # thus an `inf` in an unselected path results in `0*inf=nan`. We are
    # careful to overwrite `x` with ones only when we will never actually use
    # this value. Note that we use ones and not zeros since
    # `log(expm1(0.)) = -inf`.
    threshold = np.log(np.finfo(x.dtype.as_numpy_dtype).eps) + 2.
    is_too_small = math_ops.less(x, np.exp(threshold))
    is_too_large = math_ops.greater(x, -threshold)
    too_small_value = math_ops.log(x)
    too_large_value = x
    # This `where` will ultimately be a NOP because we won't select this
    # codepath whenever we used the surrogate `ones_like`.
    x = array_ops.where_v2(
        math_ops.logical_or(is_too_small, is_too_large),
        array_ops.ones_like(x), x)
    y = x + math_ops.log(-math_ops.expm1(-x))  # == log(expm1(x))
    return array_ops.where_v2(
        is_too_small, too_small_value,
        array_ops.where_v2(is_too_large, too_large_value, y))
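

# A quick numerical sketch of the round trip (mathematical values, not the
# output of any particular TF version):
#
#   softplus(2.0)         = log(1 + e**2)  ~= 2.1269
#   softplus_inverse(2.1269)               ~= 2.0
#
# i.e. softplus_inverse(softplus(x)) recovers x, with the clamping above
# keeping the gradient finite for very large or very small inputs.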


# TODO(b/35290280): Add unit-tests.
def dimension_size(x, axis):
  """Returns the size of a specific dimension."""
  # Since tf.gather isn't "constant-in, constant-out", we must first check the
  # static shape or fallback to dynamic shape.
  s = tensor_shape.dimension_value(
      x.shape.with_rank_at_least(np.abs(axis))[axis])
  if s is not None:
    return s
  return array_ops.shape(x)[axis]


def process_quadrature_grid_and_probs(quadrature_grid_and_probs,
                                      dtype,
                                      validate_args,
                                      name=None):
  """Validates quadrature grid, probs or computes them as necessary.

  Args:
    quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
      representing the sample points and the corresponding (possibly
      normalized) weight. When `None`, defaults to:
      `np.polynomial.hermite.hermgauss(deg=8)`.
    dtype: The expected `dtype` of `grid` and `probs`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
      representing the sample points and the corresponding (possibly
      normalized) weight.

  Raises:
    ValueError: if `quadrature_grid_and_probs is not None` and
      `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
  """
  with ops.name_scope(name, "process_quadrature_grid_and_probs",
                      [quadrature_grid_and_probs]):
    if quadrature_grid_and_probs is None:
      grid, probs = np.polynomial.hermite.hermgauss(deg=8)
      grid = grid.astype(dtype.as_numpy_dtype)
      probs = probs.astype(dtype.as_numpy_dtype)
      probs /= np.linalg.norm(probs, ord=1, keepdims=True)
      grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
      probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
      return grid, probs

    grid, probs = tuple(quadrature_grid_and_probs)
    grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
    probs = ops.convert_to_tensor(probs, name="unnormalized_probs",
                                  dtype=dtype)
    probs /= linalg_ops.norm(probs, ord=1, axis=-1, keepdims=True,
                             name="probs")

    def _static_event_size(x):
      """Returns the static size of a specific dimension or `None`."""
      return tensor_shape.dimension_value(x.shape.with_rank_at_least(1)[-1])

    m, n = _static_event_size(probs), _static_event_size(grid)
    if m is not None and n is not None:
      if m != n:
        raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
                         "same-length zero-th-dimension `Tensor`s "
                         "(saw lengths {}, {})".format(m, n))
    elif validate_args:
      assertions = [
          check_ops.assert_equal(
              dimension_size(probs, axis=-1),
              dimension_size(grid, axis=-1),
              message=("`quadrature_grid_and_probs` must be a `tuple` of "
                       "same-length zero-th-dimension `Tensor`s")),
      ]
      with ops.control_dependencies(assertions):
        grid = array_ops.identity(grid)
        probs = array_ops.identity(probs)
    return grid, probs
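

# Illustrative sketch of the default quadrature rule (a sketch, mirroring the
# `quadrature_grid_and_probs is None` branch above): the grid/probs come from
# an 8-point Gauss-Hermite rule with weights normalized to sum to one:
#
#   grid, probs = np.polynomial.hermite.hermgauss(deg=8)
#   probs /= np.linalg.norm(probs, ord=1, keepdims=True)
#   # probs now sums to 1; grid is symmetric about 0.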


def pad(x, axis, front=False, back=False, value=0, count=1, name=None):
  """Pads `value` to the front and/or back of a `Tensor` dim, `count` times.

  Args:
    x: `Tensor` input.
    axis: Scalar `int`-like `Tensor` representing the single dimension to pad.
      (Negative indexing is supported.)
    front: Python `bool`; if `True` the beginning of the `axis` dimension is
      padded with `value`, `count` times. If `False` no front padding is made.
    back: Python `bool`; if `True` the end of the `axis` dimension is padded
      with `value`, `count` times. If `False` no end padding is made.
    value: Scalar `int`-like `Tensor` representing the actual value added to
      the front and/or back of the `axis` dimension of `x`.
    count: Scalar `int`-like `Tensor` representing number of elements added to
      the front and/or back of the `axis` dimension of `x`. E.g., if
      `front = back = True` then `2 * count` elements are added.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    pad: The padded version of input `x`.

  Raises:
    ValueError: if both `front` and `back` are `False`.
    TypeError: if `count` is not `int`-like.
  """
  with ops.name_scope(name, "pad", [x, value, count]):
    x = ops.convert_to_tensor(x, name="x")
    value = ops.convert_to_tensor(value, dtype=x.dtype, name="value")
    count = ops.convert_to_tensor(count, name="count")
    if not count.dtype.is_integer:
      raise TypeError("`count.dtype` (`{}`) must be `int`-like.".format(
          count.dtype.name))
    if not front and not back:
      raise ValueError("At least one of `front`, `back` must be `True`.")
    ndims = (
        x.shape.ndims if x.shape.ndims is not None else array_ops.rank(
            x, name="ndims"))
    axis = ops.convert_to_tensor(axis, name="axis")
    axis_ = tensor_util.constant_value(axis)
    if axis_ is not None:
      axis = axis_
      if axis < 0:
        axis = ndims + axis
      count_ = tensor_util.constant_value(count)
      if axis_ >= 0 or x.shape.ndims is not None:
        head = x.shape[:axis]
        middle = tensor_shape.TensorShape(None if count_ is None else (
            tensor_shape.dimension_at_index(x.shape, axis) + count_ *
            (front + back)))
        tail = x.shape[axis + 1:]
        final_shape = head.concatenate(middle.concatenate(tail))
      else:
        final_shape = None
    else:
      axis = array_ops.where_v2(axis < 0, ndims + axis, axis)
      final_shape = None
    x = array_ops.pad(
        x,
        paddings=array_ops.one_hot(
            indices=array_ops.stack(
                [axis if front else -1, axis if back else -1]),
            depth=ndims,
            axis=0,
            on_value=count,
            dtype=dtypes.int32),
        constant_values=value)
    if final_shape is not None:
      x.set_shape(final_shape)
    return x


def parent_frame_arguments():
  """Returns parent frame arguments.

  When called inside a function, returns a dictionary with the caller's
  function arguments. These are positional arguments and keyword arguments
  (**kwargs), while variable arguments (*varargs) are excluded.

  When called at global scope, this will return an empty dictionary, since
  there are no arguments.

  WARNING: If caller function argument names are overloaded before invoking
  this method, then values will reflect the overloaded value. For this
  reason, we recommend calling `parent_frame_arguments` at the beginning of
  the function.
  """
  # All arguments and the names used for *varargs, and **kwargs.
  arg_names, variable_arg_name, keyword_arg_name, local_vars = (
      tf_inspect._inspect.getargvalues(  # pylint: disable=protected-access
          # Get the first frame of the caller of this method.
          tf_inspect._inspect.stack()[1][0]))  # pylint: disable=protected-access

  # Remove the *varargs, and flatten the **kwargs. Both are
  # nested lists.
  local_vars.pop(variable_arg_name, {})
  keyword_args = local_vars.pop(keyword_arg_name, {})

  final_args = {}
  # Copy over arguments and their values. In general, local_vars
  # may contain more than just the arguments, since this method
  # can be called anywhere in a function.
  for arg_name in arg_names:
    final_args[arg_name] = local_vars.pop(arg_name)
  final_args.update(keyword_args)

  return final_args


class AppendDocstring(object):
  """Helper class to promote private subclass docstring to public counterpart.

  Example:

  ```python
  class TransformedDistribution(Distribution):
    @distribution_util.AppendDocstring(
        additional_note="A special note!",
        kwargs_dict={"foo": "An extra arg."})
    def _prob(self, y, foo=None):
      pass
  ```

  In this case, the `AppendDocstring` decorator appends the `additional_note`
  to the docstring of `prob` (not `_prob`) and adds a new `kwargs`
  section with each dictionary item as a bullet-point.

  For a more detailed example, see `TransformedDistribution`.
  """

  def __init__(self, additional_note="", kwargs_dict=None):
    """Initializes the AppendDocstring object.

    Args:
      additional_note: Python string added as additional docstring to public
        version of function.
      kwargs_dict: Python string/string dictionary representing specific
        kwargs expanded from the **kwargs input.

    Raises:
      ValueError: if kwargs_dict.key contains whitespace.
      ValueError: if kwargs_dict.value contains newlines.
    """
    self._additional_note = additional_note
    if kwargs_dict:
      bullets = []
      for key in sorted(kwargs_dict.keys()):
        value = kwargs_dict[key]
        if any(x.isspace() for x in key):
          raise ValueError("Parameter name \"%s\" contains whitespace." % key)
        value = value.lstrip()
        if "\n" in value:
          raise ValueError(
              "Parameter description for \"%s\" contains newlines." % key)
        bullets.append("* `%s`: %s" % (key, value))
      self._additional_note += ("\n\n##### `kwargs`:\n\n" + "\n".join(bullets))

  def __call__(self, fn):

    @functools.wraps(fn)
    def _fn(*args, **kwargs):
      return fn(*args, **kwargs)

    if _fn.__doc__ is None:
      _fn.__doc__ = self._additional_note
    else:
      _fn.__doc__ += "\n%s" % self._additional_note
    return _fn
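

# Illustrative sketch of what `AppendDocstring` produces (assuming the example
# in the class docstring): after decoration, the public `prob` docstring ends
# with roughly:
#
#   A special note!
#
#   ##### `kwargs`:
#
#   * `foo`: An extra arg.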