1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities for probability distributions.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22import hashlib 23import numpy as np 24 25from tensorflow.python.framework import constant_op 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import tensor_shape 29from tensorflow.python.framework import tensor_util 30from tensorflow.python.ops import array_ops 31from tensorflow.python.ops import check_ops 32from tensorflow.python.ops import control_flow_ops 33from tensorflow.python.ops import linalg_ops 34from tensorflow.python.ops import math_ops 35from tensorflow.python.ops import nn 36from tensorflow.python.util import tf_inspect 37 38 39def assert_integer_form( 40 x, data=None, summarize=None, message=None, 41 int_dtype=None, name="assert_integer_form"): 42 """Assert that x has integer components (or floats equal to integers). 43 44 Args: 45 x: Floating-point `Tensor` 46 data: The tensors to print out if the condition is `False`. Defaults to 47 error message and first few entries of `x` and `y`. 48 summarize: Print this many entries of each tensor. 49 message: A string to prefix to the default message. 50 int_dtype: A `tf.dtype` used to cast the float to. The default (`None`) 51 implies the smallest possible signed int will be used for casting. 52 name: A name for this operation (optional). 53 54 Returns: 55 Op raising `InvalidArgumentError` if `cast(x, int_dtype) != x`. 56 """ 57 with ops.name_scope(name, values=[x, data]): 58 x = ops.convert_to_tensor(x, name="x") 59 if x.dtype.is_integer: 60 return control_flow_ops.no_op() 61 message = message or "{} has non-integer components".format(x) 62 if int_dtype is None: 63 try: 64 int_dtype = { 65 dtypes.float16: dtypes.int16, 66 dtypes.float32: dtypes.int32, 67 dtypes.float64: dtypes.int64, 68 }[x.dtype.base_dtype] 69 except KeyError: 70 raise TypeError("Unrecognized type {}".format(x.dtype.name)) 71 return check_ops.assert_equal( 72 x, math_ops.cast(math_ops.cast(x, int_dtype), x.dtype), 73 data=data, summarize=summarize, message=message, name=name) 74 75 76def assert_symmetric(matrix): 77 matrix_t = array_ops.matrix_transpose(matrix) 78 return control_flow_ops.with_dependencies( 79 [check_ops.assert_equal(matrix, matrix_t)], matrix) 80 81 82def embed_check_nonnegative_integer_form( 83 x, name="embed_check_nonnegative_integer_form"): 84 """Assert x is a non-negative tensor, and optionally of integers.""" 85 with ops.name_scope(name, values=[x]): 86 x = ops.convert_to_tensor(x, name="x") 87 assertions = [ 88 check_ops.assert_non_negative( 89 x, message="'{}' must be non-negative.".format(x)), 90 ] 91 if not x.dtype.is_integer: 92 assertions += [ 93 assert_integer_form( 94 x, message="'{}' cannot contain fractional components.".format( 95 x)), 96 ] 97 return control_flow_ops.with_dependencies(assertions, x) 98 99 100def same_dynamic_shape(a, b): 101 """Returns whether a and b have the same dynamic shape. 102 103 Args: 104 a: `Tensor` 105 b: `Tensor` 106 107 Returns: 108 `bool` `Tensor` representing if both tensors have the same shape. 109 """ 110 a = ops.convert_to_tensor(a, name="a") 111 b = ops.convert_to_tensor(b, name="b") 112 113 # Here we can't just do math_ops.equal(a.shape, b.shape), since 114 # static shape inference may break the equality comparison between 115 # shape(a) and shape(b) in math_ops.equal. 116 def all_shapes_equal(): 117 return math_ops.reduce_all(math_ops.equal( 118 array_ops.concat([array_ops.shape(a), array_ops.shape(b)], 0), 119 array_ops.concat([array_ops.shape(b), array_ops.shape(a)], 0))) 120 121 # One of the shapes isn't fully defined, so we need to use the dynamic 122 # shape. 123 return control_flow_ops.cond( 124 math_ops.equal(array_ops.rank(a), array_ops.rank(b)), 125 all_shapes_equal, 126 lambda: constant_op.constant(False)) 127 128 129def maybe_get_static_value(x, dtype=None): 130 """Helper which tries to return a static value. 131 132 Given `x`, extract it's value statically, optionally casting to a specific 133 dtype. If this is not possible, None is returned. 134 135 Args: 136 x: `Tensor` for which to extract a value statically. 137 dtype: Optional dtype to cast to. 138 139 Returns: 140 Statically inferred value if possible, otherwise None. 141 """ 142 if x is None: 143 return x 144 try: 145 # This returns an np.ndarray. 146 x_ = tensor_util.constant_value(x) 147 except TypeError: 148 x_ = x 149 if x_ is None or dtype is None: 150 return x_ 151 return np.array(x_, dtype) 152 153 154def get_logits_and_probs(logits=None, 155 probs=None, 156 multidimensional=False, 157 validate_args=False, 158 name="get_logits_and_probs", 159 dtype=None): 160 """Converts logit to probabilities (or vice-versa), and returns both. 161 162 Args: 163 logits: Floating-point `Tensor` representing log-odds. 164 probs: Floating-point `Tensor` representing probabilities. 165 multidimensional: Python `bool`, default `False`. 166 If `True`, represents whether the last dimension of `logits` or `probs`, 167 a `[N1, N2, ... k]` dimensional tensor, representing the 168 logit or probability of `shape[-1]` classes. 169 validate_args: Python `bool`, default `False`. When `True`, either assert 170 `0 <= probs <= 1` (if not `multidimensional`) or that the last dimension 171 of `probs` sums to one. 172 name: A name for this operation (optional). 173 dtype: `tf.DType` to prefer when converting args to `Tensor`s. 174 175 Returns: 176 logits, probs: Tuple of `Tensor`s. If `probs` has an entry that is `0` or 177 `1`, then the corresponding entry in the returned logit will be `-Inf` and 178 `Inf` respectively. 179 180 Raises: 181 ValueError: if neither `probs` nor `logits` were passed in, or both were. 182 """ 183 with ops.name_scope(name, values=[probs, logits]): 184 if (probs is None) == (logits is None): 185 raise ValueError("Must pass probs or logits, but not both.") 186 187 if probs is None: 188 logits = ops.convert_to_tensor(logits, name="logits", dtype=dtype) 189 if not logits.dtype.is_floating: 190 raise TypeError("logits must having floating type.") 191 # We can early return since we constructed probs and therefore know 192 # they're valid. 193 if multidimensional: 194 if validate_args: 195 logits = embed_check_categorical_event_shape(logits) 196 return logits, nn.softmax(logits, name="probs") 197 return logits, math_ops.sigmoid(logits, name="probs") 198 199 probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype) 200 if not probs.dtype.is_floating: 201 raise TypeError("probs must having floating type.") 202 203 if validate_args: 204 with ops.name_scope("validate_probs"): 205 one = constant_op.constant(1., probs.dtype) 206 dependencies = [check_ops.assert_non_negative(probs)] 207 if multidimensional: 208 probs = embed_check_categorical_event_shape(probs) 209 dependencies += [ 210 check_ops.assert_near( 211 math_ops.reduce_sum(probs, -1), 212 one, 213 message="probs does not sum to 1.") 214 ] 215 else: 216 dependencies += [check_ops.assert_less_equal( 217 probs, one, message="probs has components greater than 1.")] 218 probs = control_flow_ops.with_dependencies(dependencies, probs) 219 220 with ops.name_scope("logits"): 221 if multidimensional: 222 # Here we don't compute the multidimensional case, in a manner 223 # consistent with respect to the unidimensional case. We do so 224 # following the TF convention. Typically, you might expect to see 225 # logits = log(probs) - log(probs[pivot]). A side-effect of 226 # being consistent with the TF approach is that the unidimensional case 227 # implicitly handles the second dimension but the multidimensional case 228 # explicitly keeps the pivot dimension. 229 return math_ops.log(probs), probs 230 return math_ops.log(probs) - math_ops.log1p(-1. * probs), probs 231 232 233def _is_known_unsigned_by_dtype(dt): 234 """Helper returning True if dtype is known to be unsigned.""" 235 return { 236 dtypes.bool: True, 237 dtypes.uint8: True, 238 dtypes.uint16: True, 239 }.get(dt.base_dtype, False) 240 241 242def _is_known_signed_by_dtype(dt): 243 """Helper returning True if dtype is known to be signed.""" 244 return { 245 dtypes.float16: True, 246 dtypes.float32: True, 247 dtypes.float64: True, 248 dtypes.int8: True, 249 dtypes.int16: True, 250 dtypes.int32: True, 251 dtypes.int64: True, 252 }.get(dt.base_dtype, False) 253 254 255def _is_known_dtype(dt): 256 """Helper returning True if dtype is known.""" 257 return _is_known_unsigned_by_dtype(dt) or _is_known_signed_by_dtype(dt) 258 259 260def _largest_integer_by_dtype(dt): 261 """Helper returning the largest integer exactly representable by dtype.""" 262 if not _is_known_dtype(dt): 263 raise TypeError("Unrecognized dtype: {}".format(dt.name)) 264 if dt.is_floating: 265 return int(2**(np.finfo(dt.as_numpy_dtype).nmant + 1)) 266 if dt.is_integer: 267 return np.iinfo(dt.as_numpy_dtype).max 268 if dt.base_dtype == dtypes.bool: 269 return int(1) 270 # We actually can't land here but keep the case for completeness. 271 raise TypeError("Unrecognized dtype: {}".format(dt.name)) 272 273 274def _smallest_integer_by_dtype(dt): 275 """Helper returning the smallest integer exactly representable by dtype.""" 276 if not _is_known_dtype(dt): 277 raise TypeError("Unrecognized dtype: {}".format(dt.name)) 278 if _is_known_unsigned_by_dtype(dt): 279 return 0 280 return -1 * _largest_integer_by_dtype(dt) 281 282 283def _is_integer_like_by_dtype(dt): 284 """Helper returning True if dtype.is_integer or is `bool`.""" 285 if not _is_known_dtype(dt): 286 raise TypeError("Unrecognized dtype: {}".format(dt.name)) 287 return dt.is_integer or dt.base_dtype == dtypes.bool 288 289 290def embed_check_categorical_event_shape( 291 categorical_param, 292 name="embed_check_categorical_event_shape"): 293 """Embeds checks that categorical distributions don't have too many classes. 294 295 A categorical-type distribution is one which, e.g., returns the class label 296 rather than a one-hot encoding. E.g., `Categorical(probs)`. 297 298 Since distributions output samples in the same dtype as the parameters, we 299 must ensure that casting doesn't lose precision. That is, the 300 `parameter.dtype` implies a maximum number of classes. However, since shape is 301 `int32` and categorical variables are presumed to be indexes into a `Tensor`, 302 we must also ensure that the number of classes is no larger than the largest 303 possible `int32` index, i.e., `2**31-1`. 304 305 In other words the number of classes, `K`, must satisfy the following 306 condition: 307 308 ```python 309 K <= min( 310 int(2**31 - 1), # Largest float as an index. 311 { 312 dtypes.float16: int(2**11), # Largest int as a float16. 313 dtypes.float32: int(2**24), 314 dtypes.float64: int(2**53), 315 }.get(categorical_param.dtype.base_dtype, 0)) 316 ``` 317 318 Args: 319 categorical_param: Floating-point `Tensor` representing parameters of 320 distribution over categories. The rightmost shape is presumed to be the 321 number of categories. 322 name: A name for this operation (optional). 323 324 Returns: 325 categorical_param: Input `Tensor` with appropriate assertions embedded. 326 327 Raises: 328 TypeError: if `categorical_param` has an unknown `dtype`. 329 ValueError: if we can statically identify `categorical_param` as being too 330 large (for being closed under int32/float casting). 331 """ 332 with ops.name_scope(name, values=[categorical_param]): 333 x = ops.convert_to_tensor(categorical_param, name="categorical_param") 334 # The size must not exceed both of: 335 # - The largest possible int32 (since categorical values are presumed to be 336 # indexes into a Tensor). 337 # - The largest possible integer exactly representable under the given 338 # floating-point dtype (since we need to cast to/from). 339 # 340 # The chosen floating-point thresholds are 2**(1 + mantissa_bits). 341 # For more details, see: 342 # https://en.wikipedia.org/wiki/Floating-point_arithmetic#Internal_representation 343 x_dtype = x.dtype.base_dtype 344 max_event_size = (_largest_integer_by_dtype(x_dtype) 345 if x_dtype.is_floating else 0) 346 if max_event_size == 0: 347 raise TypeError("Unable to validate size of unrecognized dtype " 348 "({}).".format(x_dtype.name)) 349 try: 350 x_shape_static = x.get_shape().with_rank_at_least(1) 351 except ValueError: 352 raise ValueError("A categorical-distribution parameter must have " 353 "at least 1 dimension.") 354 if tensor_shape.dimension_value(x_shape_static[-1]) is not None: 355 event_size = x_shape_static.dims[-1].value 356 if event_size < 2: 357 raise ValueError("A categorical-distribution parameter must have at " 358 "least 2 events.") 359 if event_size > max_event_size: 360 raise ValueError( 361 "Number of classes exceeds `dtype` precision, i.e., " 362 "{} implies shape ({}) cannot exceed {}.".format( 363 x_dtype.name, event_size, max_event_size)) 364 return x 365 else: 366 event_size = array_ops.shape(x, name="x_shape")[-1] 367 return control_flow_ops.with_dependencies([ 368 check_ops.assert_rank_at_least( 369 x, 1, message=("A categorical-distribution parameter must have " 370 "at least 1 dimension.")), 371 check_ops.assert_greater_equal( 372 array_ops.shape(x)[-1], 2, 373 message=("A categorical-distribution parameter must have at " 374 "least 2 events.")), 375 check_ops.assert_less_equal( 376 event_size, max_event_size, 377 message="Number of classes exceeds `dtype` precision, " 378 "i.e., {} dtype cannot exceed {} shape.".format( 379 x_dtype.name, max_event_size)), 380 ], x) 381 382 383def embed_check_integer_casting_closed( 384 x, 385 target_dtype, 386 assert_nonnegative=True, 387 name="embed_check_casting_closed"): 388 """Ensures integers remain unaffected despite casting to/from int/float types. 389 390 Example integer-types: `uint8`, `int32`, `bool`. 391 Example floating-types: `float32`, `float64`. 392 393 The largest possible integer representable by an IEEE754 floating-point is 394 `2**(1 + mantissa_bits)` yet the largest possible integer as an int-type is 395 `2**(bits - 1) - 1`. This function ensures that a `Tensor` purporting to have 396 integer-form values can be cast to some other type without loss of precision. 397 398 The smallest representable integer is the negative of the largest 399 representable integer, except for types: `uint8`, `uint16`, `bool`. For these 400 types, the smallest representable integer is `0`. 401 402 Args: 403 x: `Tensor` representing integer-form values. 404 target_dtype: TF `dtype` under which `x` should have identical values. 405 assert_nonnegative: `bool` indicating `x` should contain nonnegative values. 406 name: A name for this operation (optional). 407 408 Returns: 409 x: Input `Tensor` with appropriate assertions embedded. 410 411 Raises: 412 TypeError: if `x` is neither integer- nor floating-type. 413 TypeError: if `target_dtype` is neither integer- nor floating-type. 414 TypeError: if neither `x` nor `target_dtype` are integer-type. 415 """ 416 417 with ops.name_scope(name, values=[x]): 418 x = ops.convert_to_tensor(x, name="x") 419 if (not _is_integer_like_by_dtype(x.dtype) 420 and not x.dtype.is_floating): 421 raise TypeError("{}.dtype must be floating- or " 422 "integer-type.".format(x.dtype.name)) 423 if (not _is_integer_like_by_dtype(target_dtype) 424 and not target_dtype.is_floating): 425 raise TypeError("target_dtype ({}) must be floating- or " 426 "integer-type.".format(target_dtype.name)) 427 if (not _is_integer_like_by_dtype(x.dtype) 428 and not _is_integer_like_by_dtype(target_dtype)): 429 raise TypeError("At least one of {}.dtype ({}) and target_dtype ({}) " 430 "must be integer-type.".format( 431 x, x.dtype.name, target_dtype.name)) 432 433 assertions = [] 434 if assert_nonnegative: 435 assertions += [ 436 check_ops.assert_non_negative( 437 x, message="Elements must be non-negative."), 438 ] 439 440 if x.dtype.is_floating: 441 # Being here means _is_integer_like_by_dtype(target_dtype) = True. 442 # Since this check implies the magnitude check below, we need only it. 443 assertions += [ 444 assert_integer_form( 445 x, int_dtype=target_dtype, 446 message="Elements must be {}-equivalent.".format( 447 target_dtype.name)), 448 ] 449 else: 450 if (_largest_integer_by_dtype(x.dtype) 451 > _largest_integer_by_dtype(target_dtype)): 452 # Cast may lose integer precision. 453 assertions += [ 454 check_ops.assert_less_equal( 455 x, _largest_integer_by_dtype(target_dtype), 456 message=("Elements cannot exceed {}.".format( 457 _largest_integer_by_dtype(target_dtype)))), 458 ] 459 if (not assert_nonnegative and 460 (_smallest_integer_by_dtype(x.dtype) 461 < _smallest_integer_by_dtype(target_dtype))): 462 assertions += [ 463 check_ops.assert_greater_equal( 464 x, _smallest_integer_by_dtype(target_dtype), 465 message=("Elements cannot be smaller than {}.".format( 466 _smallest_integer_by_dtype(target_dtype)))), 467 ] 468 469 if not assertions: 470 return x 471 return control_flow_ops.with_dependencies(assertions, x) 472 473 474def log_combinations(n, counts, name="log_combinations"): 475 """Multinomial coefficient. 476 477 Given `n` and `counts`, where `counts` has last dimension `k`, we compute 478 the multinomial coefficient as: 479 480 ```n! / sum_i n_i!``` 481 482 where `i` runs over all `k` classes. 483 484 Args: 485 n: Floating-point `Tensor` broadcastable with `counts`. This represents `n` 486 outcomes. 487 counts: Floating-point `Tensor` broadcastable with `n`. This represents 488 counts in `k` classes, where `k` is the last dimension of the tensor. 489 name: A name for this operation (optional). 490 491 Returns: 492 `Tensor` representing the multinomial coefficient between `n` and `counts`. 493 """ 494 # First a bit about the number of ways counts could have come in: 495 # E.g. if counts = [1, 2], then this is 3 choose 2. 496 # In general, this is (sum counts)! / sum(counts!) 497 # The sum should be along the last dimension of counts. This is the 498 # "distribution" dimension. Here n a priori represents the sum of counts. 499 with ops.name_scope(name, values=[n, counts]): 500 n = ops.convert_to_tensor(n, name="n") 501 counts = ops.convert_to_tensor(counts, name="counts") 502 total_permutations = math_ops.lgamma(n + 1) 503 counts_factorial = math_ops.lgamma(counts + 1) 504 redundant_permutations = math_ops.reduce_sum(counts_factorial, axis=[-1]) 505 return total_permutations - redundant_permutations 506 507 508def matrix_diag_transform(matrix, transform=None, name=None): 509 """Transform diagonal of [batch-]matrix, leave rest of matrix unchanged. 510 511 Create a trainable covariance defined by a Cholesky factor: 512 513 ```python 514 # Transform network layer into 2 x 2 array. 515 matrix_values = tf.contrib.layers.fully_connected(activations, 4) 516 matrix = tf.reshape(matrix_values, (batch_size, 2, 2)) 517 518 # Make the diagonal positive. If the upper triangle was zero, this would be a 519 # valid Cholesky factor. 520 chol = matrix_diag_transform(matrix, transform=tf.nn.softplus) 521 522 # LinearOperatorLowerTriangular ignores the upper triangle. 523 operator = LinearOperatorLowerTriangular(chol) 524 ``` 525 526 Example of heteroskedastic 2-D linear regression. 527 528 ```python 529 tfd = tfp.distributions 530 531 # Get a trainable Cholesky factor. 532 matrix_values = tf.contrib.layers.fully_connected(activations, 4) 533 matrix = tf.reshape(matrix_values, (batch_size, 2, 2)) 534 chol = matrix_diag_transform(matrix, transform=tf.nn.softplus) 535 536 # Get a trainable mean. 537 mu = tf.contrib.layers.fully_connected(activations, 2) 538 539 # This is a fully trainable multivariate normal! 540 dist = tfd.MultivariateNormalTriL(mu, chol) 541 542 # Standard log loss. Minimizing this will "train" mu and chol, and then dist 543 # will be a distribution predicting labels as multivariate Gaussians. 544 loss = -1 * tf.reduce_mean(dist.log_prob(labels)) 545 ``` 546 547 Args: 548 matrix: Rank `R` `Tensor`, `R >= 2`, where the last two dimensions are 549 equal. 550 transform: Element-wise function mapping `Tensors` to `Tensors`. To 551 be applied to the diagonal of `matrix`. If `None`, `matrix` is returned 552 unchanged. Defaults to `None`. 553 name: A name to give created ops. 554 Defaults to "matrix_diag_transform". 555 556 Returns: 557 A `Tensor` with same shape and `dtype` as `matrix`. 558 """ 559 with ops.name_scope(name, "matrix_diag_transform", [matrix]): 560 matrix = ops.convert_to_tensor(matrix, name="matrix") 561 if transform is None: 562 return matrix 563 # Replace the diag with transformed diag. 564 diag = array_ops.matrix_diag_part(matrix) 565 transformed_diag = transform(diag) 566 transformed_mat = array_ops.matrix_set_diag(matrix, transformed_diag) 567 568 return transformed_mat 569 570 571def rotate_transpose(x, shift, name="rotate_transpose"): 572 """Circularly moves dims left or right. 573 574 Effectively identical to: 575 576 ```python 577 numpy.transpose(x, numpy.roll(numpy.arange(len(x.shape)), shift)) 578 ``` 579 580 When `validate_args=False` additional graph-runtime checks are 581 performed. These checks entail moving data from to GPU to CPU. 582 583 Example: 584 585 ```python 586 x = tf.random_normal([1, 2, 3, 4]) # Tensor of shape [1, 2, 3, 4]. 587 rotate_transpose(x, -1).shape == [2, 3, 4, 1] 588 rotate_transpose(x, -2).shape == [3, 4, 1, 2] 589 rotate_transpose(x, 1).shape == [4, 1, 2, 3] 590 rotate_transpose(x, 2).shape == [3, 4, 1, 2] 591 rotate_transpose(x, 7).shape == rotate_transpose(x, 3).shape # [2, 3, 4, 1] 592 rotate_transpose(x, -7).shape == rotate_transpose(x, -3).shape # [4, 1, 2, 3] 593 ``` 594 595 Args: 596 x: `Tensor`. 597 shift: `Tensor`. Number of dimensions to transpose left (shift<0) or 598 transpose right (shift>0). 599 name: Python `str`. The name to give this op. 600 601 Returns: 602 rotated_x: Input `Tensor` with dimensions circularly rotated by shift. 603 604 Raises: 605 TypeError: if shift is not integer type. 606 """ 607 with ops.name_scope(name, values=[x, shift]): 608 x = ops.convert_to_tensor(x, name="x") 609 shift = ops.convert_to_tensor(shift, name="shift") 610 # We do not assign back to preserve constant-ness. 611 check_ops.assert_integer(shift) 612 shift_value_static = tensor_util.constant_value(shift) 613 ndims = x.get_shape().ndims 614 if ndims is not None and shift_value_static is not None: 615 if ndims < 2: return x 616 shift_value_static = np.sign(shift_value_static) * ( 617 abs(shift_value_static) % ndims) 618 if shift_value_static == 0: return x 619 perm = np.roll(np.arange(ndims), shift_value_static) 620 return array_ops.transpose(x, perm=perm) 621 else: 622 # Consider if we always had a positive shift, and some specified 623 # direction. 624 # When shifting left we want the new array: 625 # last(x, n-shift) + first(x, shift) 626 # and if shifting right then we want: 627 # last(x, shift) + first(x, n-shift) 628 # Observe that last(a) == slice(a, n) and first(a) == slice(0, a). 629 # Also, we can encode direction and shift as one: direction * shift. 630 # Combining these facts, we have: 631 # a = cond(shift<0, -shift, n-shift) 632 # last(x, n-a) + first(x, a) == x[a:n] + x[0:a] 633 # Finally, we transform shift by modulo length so it can be specified 634 # independently from the array upon which it operates (like python). 635 ndims = array_ops.rank(x) 636 shift = array_ops.where(math_ops.less(shift, 0), 637 math_ops.mod(-shift, ndims), 638 ndims - math_ops.mod(shift, ndims)) 639 first = math_ops.range(0, shift) 640 last = math_ops.range(shift, ndims) 641 perm = array_ops.concat([last, first], 0) 642 return array_ops.transpose(x, perm=perm) 643 644 645def pick_vector(cond, 646 true_vector, 647 false_vector, 648 name="pick_vector"): 649 """Picks possibly different length row `Tensor`s based on condition. 650 651 Value `Tensor`s should have exactly one dimension. 652 653 If `cond` is a python Boolean or `tf.constant` then either `true_vector` or 654 `false_vector` is immediately returned. I.e., no graph nodes are created and 655 no validation happens. 656 657 Args: 658 cond: `Tensor`. Must have `dtype=tf.bool` and be scalar. 659 true_vector: `Tensor` of one dimension. Returned when cond is `True`. 660 false_vector: `Tensor` of one dimension. Returned when cond is `False`. 661 name: Python `str`. The name to give this op. 662 663 Example: 664 665 ```python 666 pick_vector(tf.less(0, 5), tf.range(10, 12), tf.range(15, 18)) # [10, 11] 667 pick_vector(tf.less(5, 0), tf.range(10, 12), tf.range(15, 18)) # [15, 16, 17] 668 ``` 669 670 Returns: 671 true_or_false_vector: `Tensor`. 672 673 Raises: 674 TypeError: if `cond.dtype != tf.bool` 675 TypeError: if `cond` is not a constant and 676 `true_vector.dtype != false_vector.dtype` 677 """ 678 with ops.name_scope(name, values=(cond, true_vector, false_vector)): 679 cond = ops.convert_to_tensor(cond, name="cond") 680 if cond.dtype != dtypes.bool: 681 raise TypeError("%s.dtype=%s which is not %s" % 682 (cond, cond.dtype, dtypes.bool)) 683 cond_value_static = tensor_util.constant_value(cond) 684 if cond_value_static is not None: 685 return true_vector if cond_value_static else false_vector 686 true_vector = ops.convert_to_tensor(true_vector, name="true_vector") 687 false_vector = ops.convert_to_tensor(false_vector, name="false_vector") 688 if true_vector.dtype != false_vector.dtype: 689 raise TypeError( 690 "%s.dtype=%s does not match %s.dtype=%s" 691 % (true_vector, true_vector.dtype, 692 false_vector, false_vector.dtype)) 693 n = array_ops.shape(true_vector)[0] 694 return array_ops.slice( 695 array_ops.concat([true_vector, false_vector], 0), 696 [array_ops.where(cond, 0, n)], [array_ops.where(cond, n, -1)]) 697 698 699def prefer_static_broadcast_shape( 700 shape1, shape2, name="prefer_static_broadcast_shape"): 701 """Convenience function which statically broadcasts shape when possible. 702 703 Args: 704 shape1: `1-D` integer `Tensor`. Already converted to tensor! 705 shape2: `1-D` integer `Tensor`. Already converted to tensor! 706 name: A string name to prepend to created ops. 707 708 Returns: 709 The broadcast shape, either as `TensorShape` (if broadcast can be done 710 statically), or as a `Tensor`. 711 """ 712 with ops.name_scope(name, values=[shape1, shape2]): 713 def make_shape_tensor(x): 714 return ops.convert_to_tensor(x, name="shape", dtype=dtypes.int32) 715 716 def get_tensor_shape(s): 717 if isinstance(s, tensor_shape.TensorShape): 718 return s 719 s_ = tensor_util.constant_value(make_shape_tensor(s)) 720 if s_ is not None: 721 return tensor_shape.TensorShape(s_) 722 return None 723 724 def get_shape_tensor(s): 725 if not isinstance(s, tensor_shape.TensorShape): 726 return make_shape_tensor(s) 727 if s.is_fully_defined(): 728 return make_shape_tensor(s.as_list()) 729 raise ValueError("Cannot broadcast from partially " 730 "defined `TensorShape`.") 731 732 shape1_ = get_tensor_shape(shape1) 733 shape2_ = get_tensor_shape(shape2) 734 if shape1_ is not None and shape2_ is not None: 735 return array_ops.broadcast_static_shape(shape1_, shape2_) 736 737 shape1_ = get_shape_tensor(shape1) 738 shape2_ = get_shape_tensor(shape2) 739 return array_ops.broadcast_dynamic_shape(shape1_, shape2_) 740 741 742def prefer_static_rank(x): 743 """Return static rank of tensor `x` if available, else `tf.rank(x)`. 744 745 Args: 746 x: `Tensor` (already converted). 747 748 Returns: 749 Numpy array (if static rank is obtainable), else `Tensor`. 750 """ 751 return prefer_static_value(array_ops.rank(x)) 752 753 754def prefer_static_shape(x): 755 """Return static shape of tensor `x` if available, else `tf.shape(x)`. 756 757 Args: 758 x: `Tensor` (already converted). 759 760 Returns: 761 Numpy array (if static shape is obtainable), else `Tensor`. 762 """ 763 return prefer_static_value(array_ops.shape(x)) 764 765 766def prefer_static_value(x): 767 """Return static value of tensor `x` if available, else `x`. 768 769 Args: 770 x: `Tensor` (already converted). 771 772 Returns: 773 Numpy array (if static value is obtainable), else `Tensor`. 774 """ 775 static_x = tensor_util.constant_value(x) 776 if static_x is not None: 777 return static_x 778 return x 779 780 781def gen_new_seed(seed, salt): 782 """Generate a new seed, from the given seed and salt.""" 783 if seed is None: 784 return None 785 string = (str(seed) + salt).encode("utf-8") 786 return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF 787 788 789def fill_triangular(x, upper=False, name=None): 790 """Creates a (batch of) triangular matrix from a vector of inputs. 791 792 Created matrix can be lower- or upper-triangular. (It is more efficient to 793 create the matrix as upper or lower, rather than transpose.) 794 795 Triangular matrix elements are filled in a clockwise spiral. See example, 796 below. 797 798 If `x.get_shape()` is `[b1, b2, ..., bB, d]` then the output shape is 799 `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e., 800 `n = int(np.sqrt(0.25 + 2. * m) - 0.5)`. 801 802 Example: 803 804 ```python 805 fill_triangular([1, 2, 3, 4, 5, 6]) 806 # ==> [[4, 0, 0], 807 # [6, 5, 0], 808 # [3, 2, 1]] 809 810 fill_triangular([1, 2, 3, 4, 5, 6], upper=True) 811 # ==> [[1, 2, 3], 812 # [0, 5, 6], 813 # [0, 0, 4]] 814 ``` 815 816 For comparison, a pure numpy version of this function can be found in 817 `util_test.py`, function `_fill_triangular`. 818 819 Args: 820 x: `Tensor` representing lower (or upper) triangular elements. 821 upper: Python `bool` representing whether output matrix should be upper 822 triangular (`True`) or lower triangular (`False`, default). 823 name: Python `str`. The name to give this op. 824 825 Returns: 826 tril: `Tensor` with lower (or upper) triangular elements filled from `x`. 827 828 Raises: 829 ValueError: if `x` cannot be mapped to a triangular matrix. 830 """ 831 832 with ops.name_scope(name, "fill_triangular", values=[x]): 833 x = ops.convert_to_tensor(x, name="x") 834 if tensor_shape.dimension_value( 835 x.shape.with_rank_at_least(1)[-1]) is not None: 836 # Formula derived by solving for n: m = n(n+1)/2. 837 m = np.int32(x.shape.dims[-1].value) 838 n = np.sqrt(0.25 + 2. * m) - 0.5 839 if n != np.floor(n): 840 raise ValueError("Input right-most shape ({}) does not " 841 "correspond to a triangular matrix.".format(m)) 842 n = np.int32(n) 843 static_final_shape = x.shape[:-1].concatenate([n, n]) 844 else: 845 m = array_ops.shape(x)[-1] 846 # For derivation, see above. Casting automatically lops off the 0.5, so we 847 # omit it. We don't validate n is an integer because this has 848 # graph-execution cost; an error will be thrown from the reshape, below. 849 n = math_ops.cast( 850 math_ops.sqrt(0.25 + math_ops.cast(2 * m, dtype=dtypes.float32)), 851 dtype=dtypes.int32) 852 static_final_shape = x.shape.with_rank_at_least(1)[:-1].concatenate( 853 [None, None]) 854 # We now concatenate the "tail" of `x` to `x` (and reverse one of them). 855 # 856 # We do this based on the insight that the input `x` provides `ceil(n/2)` 857 # rows of an `n x n` matrix, some of which will get zeroed out being on the 858 # wrong side of the diagonal. The first row will not get zeroed out at all, 859 # and we need `floor(n/2)` more rows, so the first is what we omit from 860 # `x_tail`. If we then stack those `ceil(n/2)` rows with the `floor(n/2)` 861 # rows provided by a reversed tail, it is exactly the other set of elements 862 # of the reversed tail which will be zeroed out for being on the wrong side 863 # of the diagonal further up/down the matrix. And, in doing-so, we've filled 864 # the triangular matrix in a clock-wise spiral pattern. Neat! 865 # 866 # Try it out in numpy: 867 # n = 3 868 # x = np.arange(n * (n + 1) / 2) 869 # m = x.shape[0] 870 # n = np.int32(np.sqrt(.25 + 2 * m) - .5) 871 # x_tail = x[(m - (n**2 - m)):] 872 # np.concatenate([x_tail, x[::-1]], 0).reshape(n, n) # lower 873 # # ==> array([[3, 4, 5], 874 # [5, 4, 3], 875 # [2, 1, 0]]) 876 # np.concatenate([x, x_tail[::-1]], 0).reshape(n, n) # upper 877 # # ==> array([[0, 1, 2], 878 # [3, 4, 5], 879 # [5, 4, 3]]) 880 # 881 # Note that we can't simply do `x[..., -(n**2 - m):]` because this doesn't 882 # correctly handle `m == n == 1`. Hence, we do nonnegative indexing. 883 # Furthermore observe that: 884 # m - (n**2 - m) 885 # = n**2 / 2 + n / 2 - (n**2 - n**2 / 2 + n / 2) 886 # = 2 (n**2 / 2 + n / 2) - n**2 887 # = n**2 + n - n**2 888 # = n 889 ndims = prefer_static_rank(x) 890 if upper: 891 x_list = [x, array_ops.reverse(x[..., n:], axis=[ndims - 1])] 892 else: 893 x_list = [x[..., n:], array_ops.reverse(x, axis=[ndims - 1])] 894 new_shape = ( 895 static_final_shape.as_list() 896 if static_final_shape.is_fully_defined() 897 else array_ops.concat([array_ops.shape(x)[:-1], [n, n]], axis=0)) 898 x = array_ops.reshape(array_ops.concat(x_list, axis=-1), new_shape) 899 x = array_ops.matrix_band_part( 900 x, 901 num_lower=(0 if upper else -1), 902 num_upper=(-1 if upper else 0)) 903 x.set_shape(static_final_shape) 904 return x 905 906 907def fill_triangular_inverse(x, upper=False, name=None): 908 """Creates a vector from a (batch of) triangular matrix. 909 910 The vector is created from the lower-triangular or upper-triangular portion 911 depending on the value of the parameter `upper`. 912 913 If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is 914 `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`. 915 916 Example: 917 918 ```python 919 fill_triangular_inverse( 920 [[4, 0, 0], 921 [6, 5, 0], 922 [3, 2, 1]]) 923 924 # ==> [1, 2, 3, 4, 5, 6] 925 926 fill_triangular_inverse( 927 [[1, 2, 3], 928 [0, 5, 6], 929 [0, 0, 4]], upper=True) 930 931 # ==> [1, 2, 3, 4, 5, 6] 932 ``` 933 934 Args: 935 x: `Tensor` representing lower (or upper) triangular elements. 936 upper: Python `bool` representing whether output matrix should be upper 937 triangular (`True`) or lower triangular (`False`, default). 938 name: Python `str`. The name to give this op. 939 940 Returns: 941 flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower 942 (or upper) triangular elements from `x`. 943 """ 944 945 with ops.name_scope(name, "fill_triangular_inverse", values=[x]): 946 x = ops.convert_to_tensor(x, name="x") 947 if tensor_shape.dimension_value( 948 x.shape.with_rank_at_least(2)[-1]) is not None: 949 n = np.int32(x.shape.dims[-1].value) 950 m = np.int32((n * (n + 1)) // 2) 951 static_final_shape = x.shape[:-2].concatenate([m]) 952 else: 953 n = array_ops.shape(x)[-1] 954 m = (n * (n + 1)) // 2 955 static_final_shape = x.shape.with_rank_at_least(2)[:-2].concatenate( 956 [None]) 957 ndims = prefer_static_rank(x) 958 if upper: 959 initial_elements = x[..., 0, :] 960 triangular_portion = x[..., 1:, :] 961 else: 962 initial_elements = array_ops.reverse(x[..., -1, :], axis=[ndims - 2]) 963 triangular_portion = x[..., :-1, :] 964 rotated_triangular_portion = array_ops.reverse( 965 array_ops.reverse(triangular_portion, axis=[ndims - 1]), 966 axis=[ndims - 2]) 967 consolidated_matrix = triangular_portion + rotated_triangular_portion 968 end_sequence = array_ops.reshape( 969 consolidated_matrix, 970 array_ops.concat([array_ops.shape(x)[:-2], [n * (n - 1)]], axis=0)) 971 y = array_ops.concat([initial_elements, end_sequence[..., :m - n]], axis=-1) 972 y.set_shape(static_final_shape) 973 return y 974 975 976def tridiag(below=None, diag=None, above=None, name=None): 977 """Creates a matrix with values set above, below, and on the diagonal. 978 979 Example: 980 981 ```python 982 tridiag(below=[1., 2., 3.], 983 diag=[4., 5., 6., 7.], 984 above=[8., 9., 10.]) 985 # ==> array([[ 4., 8., 0., 0.], 986 # [ 1., 5., 9., 0.], 987 # [ 0., 2., 6., 10.], 988 # [ 0., 0., 3., 7.]], dtype=float32) 989 ``` 990 991 Warning: This Op is intended for convenience, not efficiency. 992 993 Args: 994 below: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the below 995 diagonal part. `None` is logically equivalent to `below = 0`. 996 diag: `Tensor` of shape `[B1, ..., Bb, d]` corresponding to the diagonal 997 part. `None` is logically equivalent to `diag = 0`. 998 above: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the above 999 diagonal part. `None` is logically equivalent to `above = 0`. 1000 name: Python `str`. The name to give this op. 1001 1002 Returns: 1003 tridiag: `Tensor` with values set above, below and on the diagonal. 1004 1005 Raises: 1006 ValueError: if all inputs are `None`. 1007 """ 1008 1009 def _pad(x): 1010 """Prepends and appends a zero to every vector in a batch of vectors.""" 1011 shape = array_ops.concat([array_ops.shape(x)[:-1], [1]], axis=0) 1012 z = array_ops.zeros(shape, dtype=x.dtype) 1013 return array_ops.concat([z, x, z], axis=-1) 1014 1015 def _add(*x): 1016 """Adds list of Tensors, ignoring `None`.""" 1017 s = None 1018 for y in x: 1019 if y is None: 1020 continue 1021 elif s is None: 1022 s = y 1023 else: 1024 s += y 1025 if s is None: 1026 raise ValueError("Must specify at least one of `below`, `diag`, `above`.") 1027 return s 1028 1029 with ops.name_scope(name, "tridiag", [below, diag, above]): 1030 if below is not None: 1031 below = ops.convert_to_tensor(below, name="below") 1032 below = array_ops.matrix_diag(_pad(below))[..., :-1, 1:] 1033 if diag is not None: 1034 diag = ops.convert_to_tensor(diag, name="diag") 1035 diag = array_ops.matrix_diag(diag) 1036 if above is not None: 1037 above = ops.convert_to_tensor(above, name="above") 1038 above = array_ops.matrix_diag(_pad(above))[..., 1:, :-1] 1039 # TODO(jvdillon): Consider using scatter_nd instead of creating three full 1040 # matrices. 1041 return _add(below, diag, above) 1042 1043 1044def reduce_weighted_logsumexp( 1045 logx, 1046 w=None, 1047 axis=None, 1048 keep_dims=False, 1049 return_sign=False, 1050 name=None): 1051 """Computes `log(abs(sum(weight * exp(elements across tensor dimensions))))`. 1052 1053 If all weights `w` are known to be positive, it is more efficient to directly 1054 use `reduce_logsumexp`, i.e., `tf.reduce_logsumexp(logx + tf.log(w))` is more 1055 efficient than `du.reduce_weighted_logsumexp(logx, w)`. 1056 1057 Reduces `input_tensor` along the dimensions given in `axis`. 1058 Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each 1059 entry in `axis`. If `keep_dims` is true, the reduced dimensions 1060 are retained with length 1. 1061 1062 If `axis` has no entries, all dimensions are reduced, and a 1063 tensor with a single element is returned. 1064 1065 This function is more numerically stable than log(sum(w * exp(input))). It 1066 avoids overflows caused by taking the exp of large inputs and underflows 1067 caused by taking the log of small inputs. 1068 1069 For example: 1070 1071 ```python 1072 x = tf.constant([[0., 0, 0], 1073 [0, 0, 0]]) 1074 1075 w = tf.constant([[-1., 1, 1], 1076 [1, 1, 1]]) 1077 1078 du.reduce_weighted_logsumexp(x, w) 1079 # ==> log(-1*1 + 1*1 + 1*1 + 1*1 + 1*1 + 1*1) = log(4) 1080 1081 du.reduce_weighted_logsumexp(x, w, axis=0) 1082 # ==> [log(-1+1), log(1+1), log(1+1)] 1083 1084 du.reduce_weighted_logsumexp(x, w, axis=1) 1085 # ==> [log(-1+1+1), log(1+1+1)] 1086 1087 du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True) 1088 # ==> [[log(-1+1+1)], [log(1+1+1)]] 1089 1090 du.reduce_weighted_logsumexp(x, w, axis=[0, 1]) 1091 # ==> log(-1+5) 1092 ``` 1093 1094 Args: 1095 logx: The tensor to reduce. Should have numeric type. 1096 w: The weight tensor. Should have numeric type identical to `logx`. 1097 axis: The dimensions to reduce. If `None` (the default), 1098 reduces all dimensions. Must be in the range 1099 `[-rank(input_tensor), rank(input_tensor))`. 1100 keep_dims: If true, retains reduced dimensions with length 1. 1101 return_sign: If `True`, returns the sign of the result. 1102 name: A name for the operation (optional). 1103 1104 Returns: 1105 lswe: The `log(abs(sum(weight * exp(x))))` reduced tensor. 1106 sign: (Optional) The sign of `sum(weight * exp(x))`. 1107 """ 1108 with ops.name_scope(name, "reduce_weighted_logsumexp", [logx, w]): 1109 logx = ops.convert_to_tensor(logx, name="logx") 1110 if w is None: 1111 lswe = math_ops.reduce_logsumexp(logx, axis=axis, keepdims=keep_dims) 1112 if return_sign: 1113 sgn = array_ops.ones_like(lswe) 1114 return lswe, sgn 1115 return lswe 1116 w = ops.convert_to_tensor(w, dtype=logx.dtype, name="w") 1117 log_absw_x = logx + math_ops.log(math_ops.abs(w)) 1118 max_log_absw_x = math_ops.reduce_max(log_absw_x, axis=axis, keepdims=True) 1119 # If the largest element is `-inf` or `inf` then we don't bother subtracting 1120 # off the max. We do this because otherwise we'd get `inf - inf = NaN`. That 1121 # this is ok follows from the fact that we're actually free to subtract any 1122 # value we like, so long as we add it back after taking the `log(sum(...))`. 1123 max_log_absw_x = array_ops.where( 1124 math_ops.is_inf(max_log_absw_x), 1125 array_ops.zeros_like(max_log_absw_x), 1126 max_log_absw_x) 1127 wx_over_max_absw_x = ( 1128 math_ops.sign(w) * math_ops.exp(log_absw_x - max_log_absw_x)) 1129 sum_wx_over_max_absw_x = math_ops.reduce_sum( 1130 wx_over_max_absw_x, axis=axis, keepdims=keep_dims) 1131 if not keep_dims: 1132 max_log_absw_x = array_ops.squeeze(max_log_absw_x, axis) 1133 sgn = math_ops.sign(sum_wx_over_max_absw_x) 1134 lswe = max_log_absw_x + math_ops.log(sgn * sum_wx_over_max_absw_x) 1135 if return_sign: 1136 return lswe, sgn 1137 return lswe 1138 1139 1140# TODO(jvdillon): Merge this test back into: 1141# tensorflow/python/ops/softplus_op_test.py 1142# once TF core is accepting new ops. 1143def softplus_inverse(x, name=None): 1144 """Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)). 1145 1146 Mathematically this op is equivalent to: 1147 1148 ```none 1149 softplus_inverse = log(exp(x) - 1.) 1150 ``` 1151 1152 Args: 1153 x: `Tensor`. Non-negative (not enforced), floating-point. 1154 name: A name for the operation (optional). 1155 1156 Returns: 1157 `Tensor`. Has the same type/shape as input `x`. 1158 """ 1159 with ops.name_scope(name, "softplus_inverse", values=[x]): 1160 x = ops.convert_to_tensor(x, name="x") 1161 # We begin by deriving a more numerically stable softplus_inverse: 1162 # x = softplus(y) = Log[1 + exp{y}], (which means x > 0). 1163 # ==> exp{x} = 1 + exp{y} (1) 1164 # ==> y = Log[exp{x} - 1] (2) 1165 # = Log[(exp{x} - 1) / exp{x}] + Log[exp{x}] 1166 # = Log[(1 - exp{-x}) / 1] + Log[exp{x}] 1167 # = Log[1 - exp{-x}] + x (3) 1168 # (2) is the "obvious" inverse, but (3) is more stable than (2) for large x. 1169 # For small x (e.g. x = 1e-10), (3) will become -inf since 1 - exp{-x} will 1170 # be zero. To fix this, we use 1 - exp{-x} approx x for small x > 0. 1171 # 1172 # In addition to the numerically stable derivation above, we clamp 1173 # small/large values to be congruent with the logic in: 1174 # tensorflow/core/kernels/softplus_op.h 1175 # 1176 # Finally, we set the input to one whenever the input is too large or too 1177 # small. This ensures that no unchosen codepath is +/- inf. This is 1178 # necessary to ensure the gradient doesn't get NaNs. Recall that the 1179 # gradient of `where` behaves like `pred*pred_true + (1-pred)*pred_false` 1180 # thus an `inf` in an unselected path results in `0*inf=nan`. We are careful 1181 # to overwrite `x` with ones only when we will never actually use this 1182 # value. Note that we use ones and not zeros since `log(expm1(0.)) = -inf`. 1183 threshold = np.log(np.finfo(x.dtype.as_numpy_dtype).eps) + 2. 1184 is_too_small = math_ops.less(x, np.exp(threshold)) 1185 is_too_large = math_ops.greater(x, -threshold) 1186 too_small_value = math_ops.log(x) 1187 too_large_value = x 1188 # This `where` will ultimately be a NOP because we won't select this 1189 # codepath whenever we used the surrogate `ones_like`. 1190 x = array_ops.where(math_ops.logical_or(is_too_small, is_too_large), 1191 array_ops.ones_like(x), x) 1192 y = x + math_ops.log(-math_ops.expm1(-x)) # == log(expm1(x)) 1193 return array_ops.where(is_too_small, too_small_value, 1194 array_ops.where(is_too_large, too_large_value, y)) 1195 1196 1197# TODO(b/35290280): Add unit-tests. 1198def dimension_size(x, axis): 1199 """Returns the size of a specific dimension.""" 1200 # Since tf.gather isn't "constant-in, constant-out", we must first check the 1201 # static shape or fallback to dynamic shape. 1202 s = tensor_shape.dimension_value( 1203 x.shape.with_rank_at_least(np.abs(axis))[axis]) 1204 if s is not None: 1205 return s 1206 return array_ops.shape(x)[axis] 1207 1208 1209def process_quadrature_grid_and_probs( 1210 quadrature_grid_and_probs, dtype, validate_args, name=None): 1211 """Validates quadrature grid, probs or computes them as necessary. 1212 1213 Args: 1214 quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s 1215 representing the sample points and the corresponding (possibly 1216 normalized) weight. When `None`, defaults to: 1217 `np.polynomial.hermite.hermgauss(deg=8)`. 1218 dtype: The expected `dtype` of `grid` and `probs`. 1219 validate_args: Python `bool`, default `False`. When `True` distribution 1220 parameters are checked for validity despite possibly degrading runtime 1221 performance. When `False` invalid inputs may silently render incorrect 1222 outputs. 1223 name: Python `str` name prefixed to Ops created by this class. 1224 1225 Returns: 1226 quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s 1227 representing the sample points and the corresponding (possibly 1228 normalized) weight. 1229 1230 Raises: 1231 ValueError: if `quadrature_grid_and_probs is not None` and 1232 `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])` 1233 """ 1234 with ops.name_scope(name, "process_quadrature_grid_and_probs", 1235 [quadrature_grid_and_probs]): 1236 if quadrature_grid_and_probs is None: 1237 grid, probs = np.polynomial.hermite.hermgauss(deg=8) 1238 grid = grid.astype(dtype.as_numpy_dtype) 1239 probs = probs.astype(dtype.as_numpy_dtype) 1240 probs /= np.linalg.norm(probs, ord=1, keepdims=True) 1241 grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype) 1242 probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype) 1243 return grid, probs 1244 1245 grid, probs = tuple(quadrature_grid_and_probs) 1246 grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype) 1247 probs = ops.convert_to_tensor(probs, name="unnormalized_probs", 1248 dtype=dtype) 1249 probs /= linalg_ops.norm(probs, ord=1, axis=-1, keepdims=True, name="probs") 1250 1251 def _static_event_size(x): 1252 """Returns the static size of a specific dimension or `None`.""" 1253 return tensor_shape.dimension_value(x.shape.with_rank_at_least(1)[-1]) 1254 1255 m, n = _static_event_size(probs), _static_event_size(grid) 1256 if m is not None and n is not None: 1257 if m != n: 1258 raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of " 1259 "same-length zero-th-dimension `Tensor`s " 1260 "(saw lengths {}, {})".format(m, n)) 1261 elif validate_args: 1262 assertions = [ 1263 check_ops.assert_equal( 1264 dimension_size(probs, axis=-1), 1265 dimension_size(grid, axis=-1), 1266 message=("`quadrature_grid_and_probs` must be a `tuple` of " 1267 "same-length zero-th-dimension `Tensor`s")), 1268 ] 1269 with ops.control_dependencies(assertions): 1270 grid = array_ops.identity(grid) 1271 probs = array_ops.identity(probs) 1272 return grid, probs 1273 1274 1275def pad(x, axis, front=False, back=False, value=0, count=1, name=None): 1276 """Pads `value` to the front and/or back of a `Tensor` dim, `count` times. 1277 1278 Args: 1279 x: `Tensor` input. 1280 axis: Scalar `int`-like `Tensor` representing the single dimension to pad. 1281 (Negative indexing is supported.) 1282 front: Python `bool`; if `True` the beginning of the `axis` dimension is 1283 padded with `value`, `count` times. If `False` no front padding is made. 1284 back: Python `bool`; if `True` the end of the `axis` dimension is 1285 padded with `value`, `count` times. If `False` no end padding is made. 1286 value: Scalar `int`-like `Tensor` representing the actual value added to the 1287 front and/or back of the `axis` dimension of `x`. 1288 count: Scalar `int`-like `Tensor` representing number of elements added to 1289 the front and/or back of the `axis` dimension of `x`. E.g., if 1290 `front = back = True` then `2 * count` elements are added. 1291 name: Python `str` name prefixed to Ops created by this function. 1292 1293 Returns: 1294 pad: The padded version of input `x`. 1295 1296 Raises: 1297 ValueError: if both `front` and `back` are `False`. 1298 TypeError: if `count` is not `int`-like. 1299 """ 1300 with ops.name_scope(name, "pad", [x, value, count]): 1301 x = ops.convert_to_tensor(x, name="x") 1302 value = ops.convert_to_tensor(value, dtype=x.dtype, name="value") 1303 count = ops.convert_to_tensor(count, name="count") 1304 if not count.dtype.is_integer: 1305 raise TypeError("`count.dtype` (`{}`) must be `int`-like.".format( 1306 count.dtype.name)) 1307 if not front and not back: 1308 raise ValueError("At least one of `front`, `back` must be `True`.") 1309 ndims = (x.shape.ndims if x.shape.ndims is not None 1310 else array_ops.rank(x, name="ndims")) 1311 axis = ops.convert_to_tensor(axis, name="axis") 1312 axis_ = tensor_util.constant_value(axis) 1313 if axis_ is not None: 1314 axis = axis_ 1315 if axis < 0: 1316 axis = ndims + axis 1317 count_ = tensor_util.constant_value(count) 1318 if axis_ >= 0 or x.shape.ndims is not None: 1319 head = x.shape[:axis] 1320 middle = tensor_shape.TensorShape( 1321 None if count_ is None 1322 else (tensor_shape.dimension_at_index( 1323 x.shape, axis) + count_ * (front + back))) 1324 tail = x.shape[axis+1:] 1325 final_shape = head.concatenate(middle.concatenate(tail)) 1326 else: 1327 final_shape = None 1328 else: 1329 axis = array_ops.where(axis < 0, ndims + axis, axis) 1330 final_shape = None 1331 x = array_ops.pad( 1332 x, 1333 paddings=array_ops.one_hot( 1334 indices=array_ops.stack([axis if front else -1, 1335 axis if back else -1]), 1336 depth=ndims, 1337 axis=0, 1338 on_value=count, 1339 dtype=dtypes.int32), 1340 constant_values=value) 1341 if final_shape is not None: 1342 x.set_shape(final_shape) 1343 return x 1344 1345 1346def parent_frame_arguments(): 1347 """Returns parent frame arguments. 1348 1349 When called inside a function, returns a dictionary with the caller's function 1350 arguments. These are positional arguments and keyword arguments (**kwargs), 1351 while variable arguments (*varargs) are excluded. 1352 1353 When called at global scope, this will return an empty dictionary, since there 1354 are no arguments. 1355 1356 WARNING: If caller function argument names are overloaded before invoking 1357 this method, then values will reflect the overloaded value. For this reason, 1358 we recommend calling `parent_frame_arguments` at the beginning of the 1359 function. 1360 """ 1361 # All arguments and the names used for *varargs, and **kwargs 1362 arg_names, variable_arg_name, keyword_arg_name, local_vars = ( 1363 tf_inspect._inspect.getargvalues( # pylint: disable=protected-access 1364 # Get the first frame of the caller of this method. 1365 tf_inspect._inspect.stack()[1][0])) # pylint: disable=protected-access 1366 1367 # Remove the *varargs, and flatten the **kwargs. Both are 1368 # nested lists. 1369 local_vars.pop(variable_arg_name, {}) 1370 keyword_args = local_vars.pop(keyword_arg_name, {}) 1371 1372 final_args = {} 1373 # Copy over arguments and their values. In general, local_vars 1374 # may contain more than just the arguments, since this method 1375 # can be called anywhere in a function. 1376 for arg_name in arg_names: 1377 final_args[arg_name] = local_vars.pop(arg_name) 1378 final_args.update(keyword_args) 1379 1380 return final_args 1381 1382 1383class AppendDocstring(object): 1384 """Helper class to promote private subclass docstring to public counterpart. 1385 1386 Example: 1387 1388 ```python 1389 class TransformedDistribution(Distribution): 1390 @distribution_util.AppendDocstring( 1391 additional_note="A special note!", 1392 kwargs_dict={"foo": "An extra arg."}) 1393 def _prob(self, y, foo=None): 1394 pass 1395 ``` 1396 1397 In this case, the `AppendDocstring` decorator appends the `additional_note` to 1398 the docstring of `prob` (not `_prob`) and adds a new `kwargs` 1399 section with each dictionary item as a bullet-point. 1400 1401 For a more detailed example, see `TransformedDistribution`. 1402 """ 1403 1404 def __init__(self, additional_note="", kwargs_dict=None): 1405 """Initializes the AppendDocstring object. 1406 1407 Args: 1408 additional_note: Python string added as additional docstring to public 1409 version of function. 1410 kwargs_dict: Python string/string dictionary representing 1411 specific kwargs expanded from the **kwargs input. 1412 1413 Raises: 1414 ValueError: if kwargs_dict.key contains whitespace. 1415 ValueError: if kwargs_dict.value contains newlines. 1416 """ 1417 self._additional_note = additional_note 1418 if kwargs_dict: 1419 bullets = [] 1420 for key in sorted(kwargs_dict.keys()): 1421 value = kwargs_dict[key] 1422 if any(x.isspace() for x in key): 1423 raise ValueError( 1424 "Parameter name \"%s\" contains whitespace." % key) 1425 value = value.lstrip() 1426 if "\n" in value: 1427 raise ValueError( 1428 "Parameter description for \"%s\" contains newlines." % key) 1429 bullets.append("* `%s`: %s" % (key, value)) 1430 self._additional_note += ("\n\n##### `kwargs`:\n\n" + 1431 "\n".join(bullets)) 1432 1433 def __call__(self, fn): 1434 @functools.wraps(fn) 1435 def _fn(*args, **kwargs): 1436 return fn(*args, **kwargs) 1437 if _fn.__doc__ is None: 1438 _fn.__doc__ = self._additional_note 1439 else: 1440 _fn.__doc__ += "\n%s" % self._additional_note 1441 return _fn 1442