# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for probability distributions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import hashlib

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util import tf_inspect


def assert_integer_form(x,
                        data=None,
                        summarize=None,
                        message=None,
                        int_dtype=None,
                        name="assert_integer_form"):
  """Assert that x has integer components (or floats equal to integers).

  Args:
    x: Floating-point `Tensor`.
    data: The tensors to print out if the condition is `False`. Defaults to the
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    int_dtype: A `tf.dtype` used to cast the float to. The default (`None`)
      implies the smallest possible signed int will be used for casting.
    name: A name for this operation (optional).

  Returns:
    Op raising `InvalidArgumentError` if `cast(x, int_dtype) != x`.
  """
  with ops.name_scope(name, values=[x, data]):
    x = ops.convert_to_tensor(x, name="x")
    if x.dtype.is_integer:
      return control_flow_ops.no_op()
    message = message or "{} has non-integer components".format(x)
    if int_dtype is None:
      try:
        int_dtype = {
            dtypes.float16: dtypes.int16,
            dtypes.float32: dtypes.int32,
            dtypes.float64: dtypes.int64,
        }[x.dtype.base_dtype]
      except KeyError:
        raise TypeError("Unrecognized type {}".format(x.dtype.name))
    return check_ops.assert_equal(
        x,
        math_ops.cast(math_ops.cast(x, int_dtype), x.dtype),
        data=data,
        summarize=summarize,
        message=message,
        name=name)
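# Illustrative usage sketch (added commentary, not part of the original
# module); assumes a TF1-style graph evaluated in a `tf.Session`:
#
#   x = constant_op.constant([1., 2., 3.])            # integer-valued floats
#   with ops.control_dependencies([assert_integer_form(x)]):
#     x = array_ops.identity(x)  # raises InvalidArgumentError for e.g. [1.5]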


def assert_symmetric(matrix):
  matrix_t = array_ops.matrix_transpose(matrix)
  return control_flow_ops.with_dependencies(
      [check_ops.assert_equal(matrix, matrix_t)], matrix)


def embed_check_nonnegative_integer_form(
    x, name="embed_check_nonnegative_integer_form"):
  """Assert x is a non-negative tensor, and optionally of integers."""
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    assertions = [
        check_ops.assert_non_negative(
            x, message="'{}' must be non-negative.".format(x)),
    ]
    if not x.dtype.is_integer:
      assertions += [
          assert_integer_form(
              x,
              message="'{}' cannot contain fractional components.".format(x)),
      ]
    return control_flow_ops.with_dependencies(assertions, x)


def same_dynamic_shape(a, b):
  """Returns whether a and b have the same dynamic shape.

  Args:
    a: `Tensor`
    b: `Tensor`

  Returns:
    `bool` `Tensor` representing if both tensors have the same shape.
  """
  a = ops.convert_to_tensor(a, name="a")
  b = ops.convert_to_tensor(b, name="b")

  # Here we can't just do math_ops.equal(a.shape, b.shape), since
  # static shape inference may break the equality comparison between
  # shape(a) and shape(b) in math_ops.equal.
  def all_shapes_equal():
    return math_ops.reduce_all(
        math_ops.equal(
            array_ops.concat(
                [array_ops.shape(a), array_ops.shape(b)], 0),
            array_ops.concat(
                [array_ops.shape(b), array_ops.shape(a)], 0)))

  # One of the shapes isn't fully defined, so we need to use the dynamic
  # shape.
  return control_flow_ops.cond(
      math_ops.equal(array_ops.rank(a), array_ops.rank(b)),
      all_shapes_equal, lambda: constant_op.constant(False))
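# Illustrative usage sketch (added commentary, not part of the original
# module); the result is only known when the graph is run:
#
#   a = array_ops.placeholder(dtypes.float32, shape=None)
#   b = array_ops.placeholder(dtypes.float32, shape=None)
#   same = same_dynamic_shape(a, b)  # scalar `bool` Tensor, resolved at runtime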


def maybe_get_static_value(x, dtype=None):
  """Helper which tries to return a static value.

  Given `x`, extract its value statically, optionally casting to a specific
  dtype. If this is not possible, None is returned.

  Args:
    x: `Tensor` for which to extract a value statically.
    dtype: Optional dtype to cast to.

  Returns:
    Statically inferred value if possible, otherwise None.
  """
  if x is None:
    return x
  try:
    # This returns an np.ndarray.
    x_ = tensor_util.constant_value(x)
  except TypeError:
    x_ = x
  if x_ is None or dtype is None:
    return x_
  return np.array(x_, dtype)
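# Illustrative usage sketch (added commentary, not part of the original
# module); the `==>` values are the expected static results:
#
#   c = constant_op.constant([1, 2, 3])
#   maybe_get_static_value(c)                    # ==> np.array([1, 2, 3])
#   maybe_get_static_value(c, dtype=np.float32)  # ==> np.array([1., 2., 3.])
#   maybe_get_static_value(array_ops.placeholder(dtypes.int32))  # ==> None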


def get_logits_and_probs(logits=None,
                         probs=None,
                         multidimensional=False,
                         validate_args=False,
                         name="get_logits_and_probs",
                         dtype=None):
  """Converts logits to probabilities (or vice-versa), and returns both.

  Args:
    logits: Floating-point `Tensor` representing log-odds.
    probs: Floating-point `Tensor` representing probabilities.
    multidimensional: Python `bool`, default `False`. If `True`, `logits` or
      `probs` is treated as a `[N1, N2, ..., k]` dimensional tensor whose last
      dimension indexes the logit or probability of each of the `shape[-1]`
      classes.
    validate_args: Python `bool`, default `False`. When `True`, either assert
      that `0 <= probs <= 1` (if not `multidimensional`) or that the last
      dimension of `probs` sums to one.
    name: A name for this operation (optional).
    dtype: `tf.DType` to prefer when converting args to `Tensor`s.

  Returns:
    logits, probs: Tuple of `Tensor`s. If `probs` has an entry that is `0` or
      `1`, then the corresponding entry in the returned logit will be `-Inf` and
      `Inf` respectively.

  Raises:
    ValueError: if neither `probs` nor `logits` were passed in, or both were.
  """
  with ops.name_scope(name, values=[probs, logits]):
    if (probs is None) == (logits is None):
      raise ValueError("Must pass probs or logits, but not both.")

    if probs is None:
      logits = ops.convert_to_tensor(logits, name="logits", dtype=dtype)
      if not logits.dtype.is_floating:
        raise TypeError("logits must have floating type.")
      # We can early return since we constructed probs and therefore know
      # they're valid.
      if multidimensional:
        if validate_args:
          logits = embed_check_categorical_event_shape(logits)
        return logits, nn.softmax(logits, name="probs")
      return logits, math_ops.sigmoid(logits, name="probs")

    probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
    if not probs.dtype.is_floating:
      raise TypeError("probs must have floating type.")

    if validate_args:
      with ops.name_scope("validate_probs"):
        one = constant_op.constant(1., probs.dtype)
        dependencies = [check_ops.assert_non_negative(probs)]
        if multidimensional:
          probs = embed_check_categorical_event_shape(probs)
          dependencies += [
              check_ops.assert_near(
                  math_ops.reduce_sum(probs, -1),
                  one,
                  message="probs does not sum to 1.")
          ]
        else:
          dependencies += [
              check_ops.assert_less_equal(
                  probs, one, message="probs has components greater than 1.")
          ]
        probs = control_flow_ops.with_dependencies(dependencies, probs)

    with ops.name_scope("logits"):
      if multidimensional:
        # Here we don't compute the multidimensional case in a manner
        # consistent with the unidimensional case. We do so
        # following the TF convention. Typically, you might expect to see
        # logits = log(probs) - log(probs[pivot]). A side-effect of
        # being consistent with the TF approach is that the unidimensional case
        # implicitly handles the second dimension but the multidimensional case
        # explicitly keeps the pivot dimension.
        return math_ops.log(probs), probs
      return math_ops.log(probs) - math_ops.log1p(-1. * probs), probs
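# Illustrative usage sketch (added commentary, not part of the original
# module); the numeric comments are approximate expected values:
#
#   logits, probs = get_logits_and_probs(logits=constant_op.constant([0., 0.]))
#   # probs  ==> [0.5, 0.5]                  (sigmoid of the logits)
#
#   logits, probs = get_logits_and_probs(probs=constant_op.constant([.25, .75]))
#   # logits ==> log(p) - log1p(-p) ==> approximately [-1.0986, 1.0986]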


def _is_known_unsigned_by_dtype(dt):
  """Helper returning True if dtype is known to be unsigned."""
  return {
      dtypes.bool: True,
      dtypes.uint8: True,
      dtypes.uint16: True,
  }.get(dt.base_dtype, False)


def _is_known_signed_by_dtype(dt):
  """Helper returning True if dtype is known to be signed."""
  return {
      dtypes.float16: True,
      dtypes.float32: True,
      dtypes.float64: True,
      dtypes.int8: True,
      dtypes.int16: True,
      dtypes.int32: True,
      dtypes.int64: True,
  }.get(dt.base_dtype, False)


def _is_known_dtype(dt):
  """Helper returning True if dtype is known."""
  return _is_known_unsigned_by_dtype(dt) or _is_known_signed_by_dtype(dt)


def _largest_integer_by_dtype(dt):
  """Helper returning the largest integer exactly representable by dtype."""
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  if dt.is_floating:
    return int(2**(np.finfo(dt.as_numpy_dtype).nmant + 1))
  if dt.is_integer:
    return np.iinfo(dt.as_numpy_dtype).max
  if dt.base_dtype == dtypes.bool:
    return int(1)
  # We actually can't land here but keep the case for completeness.
  raise TypeError("Unrecognized dtype: {}".format(dt.name))


def _smallest_integer_by_dtype(dt):
  """Helper returning the smallest integer exactly representable by dtype."""
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  if _is_known_unsigned_by_dtype(dt):
    return 0
  return -1 * _largest_integer_by_dtype(dt)


def _is_integer_like_by_dtype(dt):
  """Helper returning True if dtype.is_integer or is `bool`."""
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  return dt.is_integer or dt.base_dtype == dtypes.bool


def embed_check_categorical_event_shape(
    categorical_param, name="embed_check_categorical_event_shape"):
  """Embeds checks that categorical distributions don't have too many classes.

  A categorical-type distribution is one which, e.g., returns the class label
  rather than a one-hot encoding.  E.g., `Categorical(probs)`.

  Since distributions output samples in the same dtype as the parameters, we
  must ensure that casting doesn't lose precision. That is, the
  `parameter.dtype` implies a maximum number of classes. However, since shape is
  `int32` and categorical variables are presumed to be indexes into a `Tensor`,
  we must also ensure that the number of classes is no larger than the largest
  possible `int32` index, i.e., `2**31-1`.

  In other words the number of classes, `K`, must satisfy the following
  condition:

  ```python
  K <= min(
      int(2**31 - 1),  # Largest index representable as an int32.
      {
          dtypes.float16: int(2**11),   # Largest int as a float16.
          dtypes.float32: int(2**24),
          dtypes.float64: int(2**53),
      }.get(categorical_param.dtype.base_dtype, 0))
  ```

  Args:
    categorical_param: Floating-point `Tensor` representing parameters of
      distribution over categories. The rightmost shape is presumed to be the
      number of categories.
    name: A name for this operation (optional).

  Returns:
    categorical_param: Input `Tensor` with appropriate assertions embedded.

  Raises:
    TypeError: if `categorical_param` has an unknown `dtype`.
    ValueError: if we can statically identify `categorical_param` as being too
      large (for being closed under int32/float casting).
  """
  with ops.name_scope(name, values=[categorical_param]):
    x = ops.convert_to_tensor(categorical_param, name="categorical_param")
    # The size must not exceed either of:
    # - The largest possible int32 (since categorical values are presumed to be
    #   indexes into a Tensor).
    # - The largest possible integer exactly representable under the given
    #   floating-point dtype (since we need to cast to/from).
    #
    # The chosen floating-point thresholds are 2**(1 + mantissa_bits).
    # For more details, see:
    # https://en.wikipedia.org/wiki/Floating-point_arithmetic#Internal_representation
    x_dtype = x.dtype.base_dtype
    max_event_size = (
        _largest_integer_by_dtype(x_dtype) if x_dtype.is_floating else 0)
    if max_event_size == 0:
      raise TypeError("Unable to validate size of unrecognized dtype "
                      "({}).".format(x_dtype.name))
    try:
      x_shape_static = x.get_shape().with_rank_at_least(1)
    except ValueError:
      raise ValueError("A categorical-distribution parameter must have "
                       "at least 1 dimension.")
    if tensor_shape.dimension_value(x_shape_static[-1]) is not None:
      event_size = x_shape_static.dims[-1].value
      if event_size < 2:
        raise ValueError("A categorical-distribution parameter must have at "
                         "least 2 events.")
      if event_size > max_event_size:
        raise ValueError("Number of classes exceeds `dtype` precision, i.e., "
                         "{} implies shape ({}) cannot exceed {}.".format(
                             x_dtype.name, event_size, max_event_size))
      return x
    else:
      event_size = array_ops.shape(x, name="x_shape")[-1]
      return control_flow_ops.with_dependencies([
          check_ops.assert_rank_at_least(
              x,
              1,
              message=("A categorical-distribution parameter must have "
                       "at least 1 dimension.")),
          check_ops.assert_greater_equal(
              array_ops.shape(x)[-1],
              2,
              message=("A categorical-distribution parameter must have at "
                       "least 2 events.")),
          check_ops.assert_less_equal(
              event_size,
              max_event_size,
              message="Number of classes exceeds `dtype` precision, "
              "i.e., {} dtype cannot exceed {} shape.".format(
                  x_dtype.name, max_event_size)),
      ], x)


def embed_check_integer_casting_closed(x,
                                       target_dtype,
                                       assert_nonnegative=True,
                                       name="embed_check_casting_closed"):
  """Ensures integers remain unaffected despite casting to/from int/float types.

  Example integer-types: `uint8`, `int32`, `bool`.
  Example floating-types: `float32`, `float64`.

  The largest possible integer representable by an IEEE754 floating-point is
  `2**(1 + mantissa_bits)` yet the largest possible integer as an int-type is
  `2**(bits - 1) - 1`. This function ensures that a `Tensor` purporting to have
  integer-form values can be cast to some other type without loss of precision.

  The smallest representable integer is the negative of the largest
  representable integer, except for types: `uint8`, `uint16`, `bool`. For these
  types, the smallest representable integer is `0`.

  Args:
    x: `Tensor` representing integer-form values.
    target_dtype: TF `dtype` under which `x` should have identical values.
    assert_nonnegative: `bool` indicating `x` should contain nonnegative values.
    name: A name for this operation (optional).

  Returns:
    x: Input `Tensor` with appropriate assertions embedded.

  Raises:
    TypeError: if `x` is neither integer- nor floating-type.
    TypeError: if `target_dtype` is neither integer- nor floating-type.
    TypeError: if neither `x` nor `target_dtype` are integer-type.
  """

  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if (not _is_integer_like_by_dtype(x.dtype) and not x.dtype.is_floating):
      raise TypeError("{}.dtype must be floating- or "
                      "integer-type.".format(x.dtype.name))
    if (not _is_integer_like_by_dtype(target_dtype) and
        not target_dtype.is_floating):
      raise TypeError("target_dtype ({}) must be floating- or "
                      "integer-type.".format(target_dtype.name))
    if (not _is_integer_like_by_dtype(x.dtype) and
        not _is_integer_like_by_dtype(target_dtype)):
      raise TypeError("At least one of {}.dtype ({}) and target_dtype ({}) "
                      "must be integer-type.".format(x, x.dtype.name,
                                                     target_dtype.name))

    assertions = []
    if assert_nonnegative:
      assertions += [
          check_ops.assert_non_negative(
              x, message="Elements must be non-negative."),
      ]

    if x.dtype.is_floating:
      # Being here means _is_integer_like_by_dtype(target_dtype) = True.
      # Since this check implies the magnitude check below, it is the only one
      # we need.
      assertions += [
          assert_integer_form(
              x,
              int_dtype=target_dtype,
              message="Elements must be {}-equivalent.".format(
                  target_dtype.name)),
      ]
    else:
      if (_largest_integer_by_dtype(x.dtype) >
          _largest_integer_by_dtype(target_dtype)):
        # Cast may lose integer precision.
        assertions += [
            check_ops.assert_less_equal(
                x,
                _largest_integer_by_dtype(target_dtype),
                message=("Elements cannot exceed {}.".format(
                    _largest_integer_by_dtype(target_dtype)))),
        ]
      if (not assert_nonnegative and (_smallest_integer_by_dtype(
          x.dtype) < _smallest_integer_by_dtype(target_dtype))):
        assertions += [
            check_ops.assert_greater_equal(
                x,
                _smallest_integer_by_dtype(target_dtype),
                message=("Elements cannot be smaller than {}.".format(
                    _smallest_integer_by_dtype(target_dtype)))),
        ]

    if not assertions:
      return x
    return control_flow_ops.with_dependencies(assertions, x)
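# Illustrative usage sketch (added commentary, not part of the original
# module):
#
#   counts = constant_op.constant([0., 1., 2.], dtype=dtypes.float32)
#   counts = embed_check_integer_casting_closed(
#       counts, target_dtype=dtypes.int32)
#   # `counts` now carries assertions that it is non-negative and free of
#   # fractional components, so casting to int32 and back is lossless.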


def log_combinations(n, counts, name="log_combinations"):
  """Log multinomial coefficient.

  Given `n` and `counts`, where `counts` has last dimension `k`, we compute
  the log of the multinomial coefficient as:

  ```log(n! / prod_i n_i!)```

  where `i` runs over all `k` classes.

  Args:
    n: Floating-point `Tensor` broadcastable with `counts`. This represents `n`
      outcomes.
    counts: Floating-point `Tensor` broadcastable with `n`. This represents
      counts in `k` classes, where `k` is the last dimension of the tensor.
    name: A name for this operation (optional).

  Returns:
    `Tensor` representing the log of the multinomial coefficient between `n`
    and `counts`.
  """
  # First a bit about the number of ways counts could have come in:
  # E.g. if counts = [1, 2], then this is 3 choose 2.
  # In general, this is (sum counts)! / prod(counts!)
  # The sum should be along the last dimension of counts. This is the
  # "distribution" dimension. Here n a priori represents the sum of counts.
  with ops.name_scope(name, values=[n, counts]):
    n = ops.convert_to_tensor(n, name="n")
    counts = ops.convert_to_tensor(counts, name="counts")
    total_permutations = math_ops.lgamma(n + 1)
    counts_factorial = math_ops.lgamma(counts + 1)
    redundant_permutations = math_ops.reduce_sum(counts_factorial, axis=[-1])
    return total_permutations - redundant_permutations
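# Illustrative worked example (added commentary, not part of the original
# module). For n = 3 and counts = [1, 2], the multinomial coefficient is
# 3! / (1! * 2!) = 3, so the expected result is log(3):
#
#   log_combinations(constant_op.constant(3.),
#                    constant_op.constant([1., 2.]))  # ==> ~1.0986 = log(3)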


def matrix_diag_transform(matrix, transform=None, name=None):
  """Transform diagonal of [batch-]matrix, leave rest of matrix unchanged.

  Create a trainable covariance defined by a Cholesky factor:

  ```python
  # Transform network layer into 2 x 2 array.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))

  # Make the diagonal positive. If the upper triangle was zero, this would be a
  # valid Cholesky factor.
  chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # LinearOperatorLowerTriangular ignores the upper triangle.
  operator = LinearOperatorLowerTriangular(chol)
  ```

  Example of heteroskedastic 2-D linear regression.

  ```python
  tfd = tfp.distributions

  # Get a trainable Cholesky factor.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
  chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # Get a trainable mean.
  mu = tf.contrib.layers.fully_connected(activations, 2)

  # This is a fully trainable multivariate normal!
  dist = tfd.MultivariateNormalTriL(mu, chol)

  # Standard log loss. Minimizing this will "train" mu and chol, and then dist
  # will be a distribution predicting labels as multivariate Gaussians.
  loss = -1 * tf.reduce_mean(dist.log_prob(labels))
  ```

  Args:
    matrix:  Rank `R` `Tensor`, `R >= 2`, where the last two dimensions are
      equal.
    transform:  Element-wise function mapping `Tensors` to `Tensors`. To be
      applied to the diagonal of `matrix`. If `None`, `matrix` is returned
      unchanged. Defaults to `None`.
    name:  A name to give created ops. Defaults to "matrix_diag_transform".

  Returns:
    A `Tensor` with same shape and `dtype` as `matrix`.
  """
  with ops.name_scope(name, "matrix_diag_transform", [matrix]):
    matrix = ops.convert_to_tensor(matrix, name="matrix")
    if transform is None:
      return matrix
    # Replace the diag with transformed diag.
    diag = array_ops.matrix_diag_part(matrix)
    transformed_diag = transform(diag)
    transformed_mat = array_ops.matrix_set_diag(matrix, transformed_diag)

  return transformed_mat


def rotate_transpose(x, shift, name="rotate_transpose"):
  """Circularly moves dims left or right.

  Effectively identical to:

  ```python
  numpy.transpose(x, numpy.roll(numpy.arange(len(x.shape)), shift))
  ```

  When the rank of `x` or the value of `shift` is not statically known,
  additional graph-runtime checks are performed. These checks entail moving
  data from GPU to CPU.

  Example:

  ```python
  x = tf.random.normal([1, 2, 3, 4])  # Tensor of shape [1, 2, 3, 4].
  rotate_transpose(x, -1).shape == [2, 3, 4, 1]
  rotate_transpose(x, -2).shape == [3, 4, 1, 2]
  rotate_transpose(x,  1).shape == [4, 1, 2, 3]
  rotate_transpose(x,  2).shape == [3, 4, 1, 2]
  rotate_transpose(x,  7).shape == rotate_transpose(x, 3).shape  # [2, 3, 4, 1]
  rotate_transpose(x, -7).shape == rotate_transpose(x, -3).shape  # [4, 1, 2, 3]
  ```

  Args:
    x: `Tensor`.
    shift: `Tensor`. Number of dimensions to transpose left (shift<0) or
      transpose right (shift>0).
    name: Python `str`. The name to give this op.

  Returns:
    rotated_x: Input `Tensor` with dimensions circularly rotated by shift.

  Raises:
    TypeError: if shift is not integer type.
  """
  with ops.name_scope(name, values=[x, shift]):
    x = ops.convert_to_tensor(x, name="x")
    shift = ops.convert_to_tensor(shift, name="shift")
    # We do not assign back to preserve constant-ness.
    check_ops.assert_integer(shift)
    shift_value_static = tensor_util.constant_value(shift)
    ndims = x.get_shape().ndims
    if ndims is not None and shift_value_static is not None:
      if ndims < 2:
        return x
      shift_value_static = np.sign(shift_value_static) * (
          abs(shift_value_static) % ndims)
      if shift_value_static == 0:
        return x
      perm = np.roll(np.arange(ndims), shift_value_static)
      return array_ops.transpose(x, perm=perm)
    else:
      # Consider if we always had a positive shift, and some specified
      # direction.
      # When shifting left we want the new array:
      #   last(x, n-shift) + first(x, shift)
      # and if shifting right then we want:
      #   last(x, shift) + first(x, n-shift)
      # Observe that last(a) == slice(a, n) and first(a) == slice(0, a).
      # Also, we can encode direction and shift as one: direction * shift.
      # Combining these facts, we have:
      #   a = cond(shift<0, -shift, n-shift)
      #   last(x, n-a) + first(x, a) == x[a:n] + x[0:a]
      # Finally, we transform shift by modulo length so it can be specified
      # independently from the array upon which it operates (like python).
      ndims = array_ops.rank(x)
      shift = array_ops.where_v2(
          math_ops.less(shift, 0),
          math_ops.mod(-shift, ndims),  # pylint: disable=invalid-unary-operand-type
          ndims - math_ops.mod(shift, ndims))
      first = math_ops.range(0, shift)
      last = math_ops.range(shift, ndims)
      perm = array_ops.concat([last, first], 0)
      return array_ops.transpose(x, perm=perm)


def pick_vector(cond, true_vector, false_vector, name="pick_vector"):
  """Picks possibly different length row `Tensor`s based on condition.

  Value `Tensor`s should have exactly one dimension.

  If `cond` is a python Boolean or `tf.constant` then either `true_vector` or
  `false_vector` is immediately returned. I.e., no graph nodes are created and
  no validation happens.

  Args:
    cond: `Tensor`. Must have `dtype=tf.bool` and be scalar.
    true_vector: `Tensor` of one dimension. Returned when cond is `True`.
    false_vector: `Tensor` of one dimension. Returned when cond is `False`.
    name: Python `str`. The name to give this op.

  Example:

  ```python
  pick_vector(tf.less(0, 5), tf.range(10, 12), tf.range(15, 18))  # [10, 11]
  pick_vector(tf.less(5, 0), tf.range(10, 12), tf.range(15, 18))  # [15, 16, 17]
  ```

  Returns:
    true_or_false_vector: `Tensor`.

  Raises:
    TypeError: if `cond.dtype != tf.bool`
    TypeError: if `cond` is not a constant and
      `true_vector.dtype != false_vector.dtype`
  """
  with ops.name_scope(name, values=(cond, true_vector, false_vector)):
    cond = ops.convert_to_tensor(cond, name="cond")
    if cond.dtype != dtypes.bool:
      raise TypeError("%s.dtype=%s which is not %s" %
                      (cond, cond.dtype, dtypes.bool))
    cond_value_static = tensor_util.constant_value(cond)
    if cond_value_static is not None:
      return true_vector if cond_value_static else false_vector
    true_vector = ops.convert_to_tensor(true_vector, name="true_vector")
    false_vector = ops.convert_to_tensor(false_vector, name="false_vector")
    if true_vector.dtype != false_vector.dtype:
      raise TypeError(
          "%s.dtype=%s does not match %s.dtype=%s" %
          (true_vector, true_vector.dtype, false_vector, false_vector.dtype))
    n = array_ops.shape(true_vector)[0]
    return array_ops.slice(
        array_ops.concat([true_vector, false_vector], 0),
        [array_ops.where_v2(cond, 0, n)], [array_ops.where(cond, n, -1)])


def prefer_static_broadcast_shape(shape1,
                                  shape2,
                                  name="prefer_static_broadcast_shape"):
  """Convenience function which statically broadcasts shape when possible.

  Args:
    shape1:  `1-D` integer `Tensor`.  Already converted to tensor!
    shape2:  `1-D` integer `Tensor`.  Already converted to tensor!
    name:  A string name to prepend to created ops.

  Returns:
    The broadcast shape, either as `TensorShape` (if broadcast can be done
      statically), or as a `Tensor`.
  """
  with ops.name_scope(name, values=[shape1, shape2]):

    def make_shape_tensor(x):
      return ops.convert_to_tensor(x, name="shape", dtype=dtypes.int32)

    def get_tensor_shape(s):
      if isinstance(s, tensor_shape.TensorShape):
        return s
      s_ = tensor_util.constant_value(make_shape_tensor(s))
      if s_ is not None:
        return tensor_shape.TensorShape(s_)
      return None

    def get_shape_tensor(s):
      if not isinstance(s, tensor_shape.TensorShape):
        return make_shape_tensor(s)
      if s.is_fully_defined():
        return make_shape_tensor(s.as_list())
      raise ValueError("Cannot broadcast from partially "
                       "defined `TensorShape`.")

    shape1_ = get_tensor_shape(shape1)
    shape2_ = get_tensor_shape(shape2)
    if shape1_ is not None and shape2_ is not None:
      return array_ops.broadcast_static_shape(shape1_, shape2_)

    shape1_ = get_shape_tensor(shape1)
    shape2_ = get_shape_tensor(shape2)
    return array_ops.broadcast_dynamic_shape(shape1_, shape2_)
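# Illustrative usage sketch (added commentary, not part of the original
# module):
#
#   s = prefer_static_broadcast_shape(constant_op.constant([2, 1]),
#                                     constant_op.constant([1, 3]))
#   # ==> TensorShape([2, 3]), computed statically; a `Tensor` is returned
#   #     only when one of the shapes cannot be resolved statically.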


def prefer_static_rank(x):
  """Return static rank of tensor `x` if available, else `tf.rank(x)`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static rank is obtainable), else `Tensor`.
  """
  return prefer_static_value(array_ops.rank(x))


def prefer_static_shape(x):
  """Return static shape of tensor `x` if available, else `tf.shape(x)`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static shape is obtainable), else `Tensor`.
  """
  return prefer_static_value(array_ops.shape(x))


def prefer_static_value(x):
  """Return static value of tensor `x` if available, else `x`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static value is obtainable), else `Tensor`.
  """
  static_x = tensor_util.constant_value(x)
  if static_x is not None:
    return static_x
  return x
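# Illustrative usage sketch (added commentary, not part of the original
# module); the `==>` values are static numpy results, no session required:
#
#   x = array_ops.zeros([3, 4])
#   prefer_static_rank(x)     # ==> 2
#   prefer_static_shape(x)    # ==> np.array([3, 4])
#   prefer_static_value(constant_op.constant(7))  # ==> np.array(7)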


def gen_new_seed(seed, salt):
  """Generate a new seed, from the given seed and salt."""
  if seed is None:
    return None
  string = (str(seed) + salt).encode("utf-8")
  return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF


def fill_triangular(x, upper=False, name=None):
  """Creates a (batch of) triangular matrix from a vector of inputs.

  Created matrix can be lower- or upper-triangular. (It is more efficient to
  create the matrix as upper or lower, rather than transpose.)

  Triangular matrix elements are filled in a clockwise spiral. See example,
  below.

  If `x.get_shape()` is `[b1, b2, ..., bB, d]` then the output shape is
  `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.

  Example:

  ```python
  fill_triangular([1, 2, 3, 4, 5, 6])
  # ==> [[4, 0, 0],
  #      [6, 5, 0],
  #      [3, 2, 1]]

  fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
  # ==> [[1, 2, 3],
  #      [0, 5, 6],
  #      [0, 0, 4]]
  ```

  For comparison, a pure numpy version of this function can be found in
  `util_test.py`, function `_fill_triangular`.

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether output matrix should be upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    tril: `Tensor` with lower (or upper) triangular elements filled from `x`.

  Raises:
    ValueError: if `x` cannot be mapped to a triangular matrix.
  """

  with ops.name_scope(name, "fill_triangular", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if tensor_shape.dimension_value(
        x.shape.with_rank_at_least(1)[-1]) is not None:
      # Formula derived by solving for n: m = n(n+1)/2.
      m = np.int32(x.shape.dims[-1].value)
      n = np.sqrt(0.25 + 2. * m) - 0.5
      if n != np.floor(n):
        raise ValueError("Input right-most shape ({}) does not "
                         "correspond to a triangular matrix.".format(m))
      n = np.int32(n)
      static_final_shape = x.shape[:-1].concatenate([n, n])
    else:
      m = array_ops.shape(x)[-1]
      # For derivation, see above. Casting automatically lops off the 0.5, so we
      # omit it.  We don't validate n is an integer because this has
      # graph-execution cost; an error will be thrown from the reshape, below.
      n = math_ops.cast(
          math_ops.sqrt(0.25 + math_ops.cast(2 * m, dtype=dtypes.float32)),
          dtype=dtypes.int32)
      static_final_shape = x.shape.with_rank_at_least(1)[:-1].concatenate(
          [None, None])
    # We now concatenate the "tail" of `x` to `x` (and reverse one of them).
    #
    # We do this based on the insight that the input `x` provides `ceil(n/2)`
    # rows of an `n x n` matrix, some of which will get zeroed out being on the
    # wrong side of the diagonal. The first row will not get zeroed out at all,
    # and we need `floor(n/2)` more rows, so the first is what we omit from
    # `x_tail`. If we then stack those `ceil(n/2)` rows with the `floor(n/2)`
    # rows provided by a reversed tail, it is exactly the other set of elements
    # of the reversed tail which will be zeroed out for being on the wrong side
    # of the diagonal further up/down the matrix. And, in doing-so, we've filled
    # the triangular matrix in a clock-wise spiral pattern. Neat!
    #
    # Try it out in numpy:
    #  n = 3
    #  x = np.arange(n * (n + 1) / 2)
    #  m = x.shape[0]
    #  n = np.int32(np.sqrt(.25 + 2 * m) - .5)
    #  x_tail = x[(m - (n**2 - m)):]
    #  np.concatenate([x_tail, x[::-1]], 0).reshape(n, n)  # lower
    #  # ==> array([[3, 4, 5],
    #               [5, 4, 3],
    #               [2, 1, 0]])
    #  np.concatenate([x, x_tail[::-1]], 0).reshape(n, n)  # upper
    #  # ==> array([[0, 1, 2],
    #               [3, 4, 5],
    #               [5, 4, 3]])
    #
    # Note that we can't simply do `x[..., -(n**2 - m):]` because this doesn't
    # correctly handle `m == n == 1`. Hence, we do nonnegative indexing.
    # Furthermore observe that:
    #   m - (n**2 - m)
    #   = n**2 / 2 + n / 2 - (n**2 - n**2 / 2 + n / 2)
    #   = 2 (n**2 / 2 + n / 2) - n**2
    #   = n**2 + n - n**2
    #   = n
    ndims = prefer_static_rank(x)
    if upper:
      x_list = [x, array_ops.reverse(x[..., n:], axis=[ndims - 1])]
    else:
      x_list = [x[..., n:], array_ops.reverse(x, axis=[ndims - 1])]
    new_shape = (
        static_final_shape.as_list() if static_final_shape.is_fully_defined()
        else array_ops.concat([array_ops.shape(x)[:-1], [n, n]], axis=0))
    x = array_ops.reshape(array_ops.concat(x_list, axis=-1), new_shape)
    x = array_ops.matrix_band_part(
        x, num_lower=(0 if upper else -1), num_upper=(-1 if upper else 0))
    x.set_shape(static_final_shape)
    return x


def fill_triangular_inverse(x, upper=False, name=None):
  """Creates a vector from a (batch of) triangular matrix.

  The vector is created from the lower-triangular or upper-triangular portion
  depending on the value of the parameter `upper`.

  If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
  `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.

  Example:

  ```python
  fill_triangular_inverse(
    [[4, 0, 0],
     [6, 5, 0],
     [3, 2, 1]])

  # ==> [1, 2, 3, 4, 5, 6]

  fill_triangular_inverse(
    [[1, 2, 3],
     [0, 5, 6],
     [0, 0, 4]], upper=True)

  # ==> [1, 2, 3, 4, 5, 6]
  ```

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether output matrix should be upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
      (or upper) triangular elements from `x`.
  """

  with ops.name_scope(name, "fill_triangular_inverse", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if tensor_shape.dimension_value(
        x.shape.with_rank_at_least(2)[-1]) is not None:
      n = np.int32(x.shape.dims[-1].value)
      m = np.int32((n * (n + 1)) // 2)
      static_final_shape = x.shape[:-2].concatenate([m])
    else:
      n = array_ops.shape(x)[-1]
      m = (n * (n + 1)) // 2
      static_final_shape = x.shape.with_rank_at_least(2)[:-2].concatenate(
          [None])
    ndims = prefer_static_rank(x)
    if upper:
      initial_elements = x[..., 0, :]
      triangular_portion = x[..., 1:, :]
    else:
      initial_elements = array_ops.reverse(x[..., -1, :], axis=[ndims - 2])
      triangular_portion = x[..., :-1, :]
    rotated_triangular_portion = array_ops.reverse(
        array_ops.reverse(triangular_portion, axis=[ndims - 1]),
        axis=[ndims - 2])
    consolidated_matrix = triangular_portion + rotated_triangular_portion
    end_sequence = array_ops.reshape(
        consolidated_matrix,
        array_ops.concat([array_ops.shape(x)[:-2], [n * (n - 1)]], axis=0))
    y = array_ops.concat([initial_elements, end_sequence[..., :m - n]], axis=-1)
    y.set_shape(static_final_shape)
    return y


def tridiag(below=None, diag=None, above=None, name=None):
  """Creates a matrix with values set above, below, and on the diagonal.

  Example:

  ```python
  tridiag(below=[1., 2., 3.],
          diag=[4., 5., 6., 7.],
          above=[8., 9., 10.])
  # ==> array([[  4.,   8.,   0.,   0.],
  #            [  1.,   5.,   9.,   0.],
  #            [  0.,   2.,   6.,  10.],
  #            [  0.,   0.,   3.,   7.]], dtype=float32)
  ```

  Warning: This Op is intended for convenience, not efficiency.

  Args:
    below: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the below
      diagonal part. `None` is logically equivalent to `below = 0`.
    diag: `Tensor` of shape `[B1, ..., Bb, d]` corresponding to the diagonal
      part.  `None` is logically equivalent to `diag = 0`.
    above: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the above
      diagonal part.  `None` is logically equivalent to `above = 0`.
    name: Python `str`. The name to give this op.

  Returns:
    tridiag: `Tensor` with values set above, below and on the diagonal.

  Raises:
    ValueError: if all inputs are `None`.
  """

  def _pad(x):
    """Prepends and appends a zero to every vector in a batch of vectors."""
    shape = array_ops.concat([array_ops.shape(x)[:-1], [1]], axis=0)
    z = array_ops.zeros(shape, dtype=x.dtype)
    return array_ops.concat([z, x, z], axis=-1)

  def _add(*x):
    """Adds list of Tensors, ignoring `None`."""
    s = None
    for y in x:
      if y is None:
        continue
      elif s is None:
        s = y
      else:
        s += y
    if s is None:
      raise ValueError("Must specify at least one of `below`, `diag`, `above`.")
    return s

  with ops.name_scope(name, "tridiag", [below, diag, above]):
    if below is not None:
      below = ops.convert_to_tensor(below, name="below")
      below = array_ops.matrix_diag(_pad(below))[..., :-1, 1:]
    if diag is not None:
      diag = ops.convert_to_tensor(diag, name="diag")
      diag = array_ops.matrix_diag(diag)
    if above is not None:
      above = ops.convert_to_tensor(above, name="above")
      above = array_ops.matrix_diag(_pad(above))[..., 1:, :-1]
    # TODO(jvdillon): Consider using scatter_nd instead of creating three full
    # matrices.
    return _add(below, diag, above)


def reduce_weighted_logsumexp(logx,
                              w=None,
                              axis=None,
                              keep_dims=False,
                              return_sign=False,
                              name=None):
  """Computes `log(abs(sum(weight * exp(elements across tensor dimensions))))`.

  If all weights `w` are known to be positive, it is more efficient to directly
  use `reduce_logsumexp`, i.e., `tf.reduce_logsumexp(logx + tf.math.log(w))` is
  more efficient than `du.reduce_weighted_logsumexp(logx, w)`.

  Reduces `input_tensor` along the dimensions given in `axis`.
  Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
  entry in `axis`. If `keep_dims` is true, the reduced dimensions
  are retained with length 1.

  If `axis` has no entries, all dimensions are reduced, and a
  tensor with a single element is returned.

  This function is more numerically stable than log(sum(w * exp(input))). It
  avoids overflows caused by taking the exp of large inputs and underflows
  caused by taking the log of small inputs.

  For example:

  ```python
  x = tf.constant([[0., 0, 0],
                   [0, 0, 0]])

  w = tf.constant([[-1., 1, 1],
                   [1, 1, 1]])

  du.reduce_weighted_logsumexp(x, w)
  # ==> log(-1*1 + 1*1 + 1*1 + 1*1 + 1*1 + 1*1) = log(4)

  du.reduce_weighted_logsumexp(x, w, axis=0)
  # ==> [log(-1+1), log(1+1), log(1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1)
  # ==> [log(-1+1+1), log(1+1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True)
  # ==> [[log(-1+1+1)], [log(1+1+1)]]

  du.reduce_weighted_logsumexp(x, w, axis=[0, 1])
  # ==> log(-1+5)
  ```

  Args:
    logx: The tensor to reduce. Should have numeric type.
    w: The weight tensor. Should have numeric type identical to `logx`.
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(input_tensor),
      rank(input_tensor))`.
    keep_dims: If true, retains reduced dimensions with length 1.
    return_sign: If `True`, returns the sign of the result.
    name: A name for the operation (optional).

  Returns:
    lswe: The `log(abs(sum(weight * exp(x))))` reduced tensor.
    sign: (Optional) The sign of `sum(weight * exp(x))`.
  """
  with ops.name_scope(name, "reduce_weighted_logsumexp", [logx, w]):
    logx = ops.convert_to_tensor(logx, name="logx")
    if w is None:
      lswe = math_ops.reduce_logsumexp(logx, axis=axis, keepdims=keep_dims)
      if return_sign:
        sgn = array_ops.ones_like(lswe)
        return lswe, sgn
      return lswe
    w = ops.convert_to_tensor(w, dtype=logx.dtype, name="w")
    log_absw_x = logx + math_ops.log(math_ops.abs(w))
    max_log_absw_x = math_ops.reduce_max(log_absw_x, axis=axis, keepdims=True)
    # If the largest element is `-inf` or `inf` then we don't bother subtracting
    # off the max. We do this because otherwise we'd get `inf - inf = NaN`. That
    # this is ok follows from the fact that we're actually free to subtract any
    # value we like, so long as we add it back after taking the `log(sum(...))`.
    max_log_absw_x = array_ops.where_v2(
        math_ops.is_inf(max_log_absw_x), array_ops.zeros_like(max_log_absw_x),
        max_log_absw_x)
    wx_over_max_absw_x = (
        math_ops.sign(w) * math_ops.exp(log_absw_x - max_log_absw_x))
    sum_wx_over_max_absw_x = math_ops.reduce_sum(
        wx_over_max_absw_x, axis=axis, keepdims=keep_dims)
    if not keep_dims:
      max_log_absw_x = array_ops.squeeze(max_log_absw_x, axis)
    sgn = math_ops.sign(sum_wx_over_max_absw_x)
    lswe = max_log_absw_x + math_ops.log(sgn * sum_wx_over_max_absw_x)
    if return_sign:
      return lswe, sgn
    return lswe


# TODO(jvdillon): Merge this test back into:
# tensorflow/python/ops/softplus_op_test.py
# once TF core is accepting new ops.
def softplus_inverse(x, name=None):
  """Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)).

  Mathematically this op is equivalent to:

  ```none
  softplus_inverse = log(exp(x) - 1.)
  ```

  Args:
    x: `Tensor`. Non-negative (not enforced), floating-point.
    name: A name for the operation (optional).

  Returns:
    `Tensor`. Has the same type/shape as input `x`.
  """
  with ops.name_scope(name, "softplus_inverse", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    # We begin by deriving a more numerically stable softplus_inverse:
    # x = softplus(y) = Log[1 + exp{y}], (which means x > 0).
    # ==> exp{x} = 1 + exp{y}                                (1)
    # ==> y = Log[exp{x} - 1]                                (2)
    #       = Log[(exp{x} - 1) / exp{x}] + Log[exp{x}]
    #       = Log[(1 - exp{-x}) / 1] + Log[exp{x}]
    #       = Log[1 - exp{-x}] + x                           (3)
    # (2) is the "obvious" inverse, but (3) is more stable than (2) for large x.
    # For small x (e.g. x = 1e-10), (3) will become -inf since 1 - exp{-x} will
    # be zero. To fix this, we use 1 - exp{-x} approx x for small x > 0.
    #
    # In addition to the numerically stable derivation above, we clamp
    # small/large values to be congruent with the logic in:
    # tensorflow/core/kernels/softplus_op.h
    #
    # Finally, we set the input to one whenever the input is too large or too
    # small. This ensures that no unchosen codepath is +/- inf. This is
    # necessary to ensure the gradient doesn't get NaNs. Recall that the
    # gradient of `where` behaves like `pred*pred_true + (1-pred)*pred_false`
    # thus an `inf` in an unselected path results in `0*inf=nan`. We are careful
    # to overwrite `x` with ones only when we will never actually use this
    # value. Note that we use ones and not zeros since `log(expm1(0.)) = -inf`.
    threshold = np.log(np.finfo(x.dtype.as_numpy_dtype).eps) + 2.
    is_too_small = math_ops.less(x, np.exp(threshold))
    is_too_large = math_ops.greater(x, -threshold)
    too_small_value = math_ops.log(x)
    too_large_value = x
    # This `where` will ultimately be a NOP because we won't select this
    # codepath whenever we used the surrogate `ones_like`.
    x = array_ops.where_v2(
        math_ops.logical_or(is_too_small, is_too_large), array_ops.ones_like(x),
        x)
    y = x + math_ops.log(-math_ops.expm1(-x))  # == log(expm1(x))
    return array_ops.where_v2(
        is_too_small, too_small_value,
        array_ops.where_v2(is_too_large, too_large_value, y))
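# Illustrative numeric check (added commentary, not part of the original
# module): softplus(5.) = log(1 + e**5), so inverting it should recover ~5.0:
#
#   softplus_inverse(nn.softplus(constant_op.constant(5.)))  # ==> ~5.0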


# TODO(b/35290280): Add unit-tests.
def dimension_size(x, axis):
  """Returns the size of a specific dimension."""
  # Since tf.gather isn't "constant-in, constant-out", we must first check the
  # static shape or fallback to dynamic shape.
  s = tensor_shape.dimension_value(
      x.shape.with_rank_at_least(np.abs(axis))[axis])
  if s is not None:
    return s
  return array_ops.shape(x)[axis]


def process_quadrature_grid_and_probs(quadrature_grid_and_probs,
                                      dtype,
                                      validate_args,
                                      name=None):
  """Validates the quadrature grid and probs, or computes them as necessary.

  Args:
    quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
      representing the sample points and the corresponding (possibly
      normalized) weights.  When `None`, defaults to:
        `np.polynomial.hermite.hermgauss(deg=8)`.
    dtype: The expected `dtype` of `grid` and `probs`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this class.

  Returns:
    quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
      representing the sample points and the corresponding (possibly
      normalized) weights.

  Raises:
    ValueError: if `quadrature_grid_and_probs is not None` and
      `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
  """
  with ops.name_scope(name, "process_quadrature_grid_and_probs",
                      [quadrature_grid_and_probs]):
    if quadrature_grid_and_probs is None:
      grid, probs = np.polynomial.hermite.hermgauss(deg=8)
      grid = grid.astype(dtype.as_numpy_dtype)
      probs = probs.astype(dtype.as_numpy_dtype)
      probs /= np.linalg.norm(probs, ord=1, keepdims=True)
      grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
      probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
      return grid, probs

    grid, probs = tuple(quadrature_grid_and_probs)
    grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
    probs = ops.convert_to_tensor(probs, name="unnormalized_probs", dtype=dtype)
    probs /= linalg_ops.norm(probs, ord=1, axis=-1, keepdims=True, name="probs")

    def _static_event_size(x):
      """Returns the static size of a specific dimension or `None`."""
      return tensor_shape.dimension_value(x.shape.with_rank_at_least(1)[-1])

    m, n = _static_event_size(probs), _static_event_size(grid)
    if m is not None and n is not None:
      if m != n:
        raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
                         "same-length zero-th-dimension `Tensor`s "
                         "(saw lengths {}, {})".format(m, n))
    elif validate_args:
      assertions = [
          check_ops.assert_equal(
              dimension_size(probs, axis=-1),
              dimension_size(grid, axis=-1),
              message=("`quadrature_grid_and_probs` must be a `tuple` of "
                       "same-length zero-th-dimension `Tensor`s")),
      ]
      with ops.control_dependencies(assertions):
        grid = array_ops.identity(grid)
        probs = array_ops.identity(probs)
    return grid, probs
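# Illustrative usage sketch (added commentary, not part of the original
# module):
#
#   grid, probs = process_quadrature_grid_and_probs(
#       None, dtype=dtypes.float32, validate_args=False)
#   # grid, probs ==> the 8-point Gauss-Hermite abscissae and the normalized
#   # weights, as float32 `Tensor`s.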


def pad(x, axis, front=False, back=False, value=0, count=1, name=None):
  """Pads `value` to the front and/or back of a `Tensor` dim, `count` times.

  Args:
    x: `Tensor` input.
    axis: Scalar `int`-like `Tensor` representing the single dimension to pad.
      (Negative indexing is supported.)
    front: Python `bool`; if `True` the beginning of the `axis` dimension is
      padded with `value`, `count` times. If `False` no front padding is made.
    back: Python `bool`; if `True` the end of the `axis` dimension is padded
      with `value`, `count` times. If `False` no end padding is made.
    value: Scalar `int`-like `Tensor` representing the actual value added to the
      front and/or back of the `axis` dimension of `x`.
    count: Scalar `int`-like `Tensor` representing number of elements added to
      the front and/or back of the `axis` dimension of `x`. E.g., if `front =
      back = True` then `2 * count` elements are added.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    pad: The padded version of input `x`.

  Raises:
    ValueError: if both `front` and `back` are `False`.
    TypeError: if `count` is not `int`-like.
  """
  with ops.name_scope(name, "pad", [x, value, count]):
    x = ops.convert_to_tensor(x, name="x")
    value = ops.convert_to_tensor(value, dtype=x.dtype, name="value")
    count = ops.convert_to_tensor(count, name="count")
    if not count.dtype.is_integer:
      raise TypeError("`count.dtype` (`{}`) must be `int`-like.".format(
          count.dtype.name))
    if not front and not back:
      raise ValueError("At least one of `front`, `back` must be `True`.")
    ndims = (
        x.shape.ndims if x.shape.ndims is not None else array_ops.rank(
            x, name="ndims"))
    axis = ops.convert_to_tensor(axis, name="axis")
    axis_ = tensor_util.constant_value(axis)
    if axis_ is not None:
      axis = axis_
      if axis < 0:
        axis = ndims + axis
      count_ = tensor_util.constant_value(count)
      if axis_ >= 0 or x.shape.ndims is not None:
        head = x.shape[:axis]
        middle = tensor_shape.TensorShape(None if count_ is None else (
            tensor_shape.dimension_at_index(x.shape, axis) + count_ *
            (front + back)))
        tail = x.shape[axis + 1:]
        final_shape = head.concatenate(middle.concatenate(tail))
      else:
        final_shape = None
    else:
      axis = array_ops.where_v2(axis < 0, ndims + axis, axis)
      final_shape = None
    x = array_ops.pad(
        x,
        paddings=array_ops.one_hot(
            indices=array_ops.stack(
                [axis if front else -1, axis if back else -1]),
            depth=ndims,
            axis=0,
            on_value=count,
            dtype=dtypes.int32),
        constant_values=value)
    if final_shape is not None:
      x.set_shape(final_shape)
    return x
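# Illustrative usage sketch (added commentary, not part of the original
# module):
#
#   x = constant_op.constant([[1, 2], [3, 4]])
#   pad(x, axis=-1, back=True, value=0, count=2)
#   # ==> [[1, 2, 0, 0],
#   #      [3, 4, 0, 0]]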


def parent_frame_arguments():
  """Returns parent frame arguments.

  When called inside a function, returns a dictionary with the caller's
  function arguments. These are positional arguments and keyword arguments
  (**kwargs), while variable arguments (*varargs) are excluded.

  When called at global scope, this will return an empty dictionary, since
  there are no arguments.

  WARNING: If the caller's function argument values are reassigned before
  invoking this method, then the returned values will reflect the reassigned
  values. For this reason, we recommend calling `parent_frame_arguments` at the
  beginning of the function.
  """
  # All arguments and the names used for *varargs, and **kwargs
  arg_names, variable_arg_name, keyword_arg_name, local_vars = (
      tf_inspect._inspect.getargvalues(  # pylint: disable=protected-access
          # Get the first frame of the caller of this method.
          tf_inspect._inspect.stack()[1][0]))  # pylint: disable=protected-access

  # Remove the *varargs, and flatten the **kwargs. Both are
  # nested lists.
  local_vars.pop(variable_arg_name, {})
  keyword_args = local_vars.pop(keyword_arg_name, {})

  final_args = {}
  # Copy over arguments and their values. In general, local_vars
  # may contain more than just the arguments, since this method
  # can be called anywhere in a function.
  for arg_name in arg_names:
    final_args[arg_name] = local_vars.pop(arg_name)
  final_args.update(keyword_args)

  return final_args
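# Illustrative usage sketch (added commentary, not part of the original
# module):
#
#   def foo(a, b=2, **kwargs):
#     return parent_frame_arguments()
#
#   foo(1, c=3)   # ==> {"a": 1, "b": 2, "c": 3}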


class AppendDocstring(object):
  """Helper class to promote private subclass docstring to public counterpart.

  Example:

  ```python
  class TransformedDistribution(Distribution):
    @distribution_util.AppendDocstring(
      additional_note="A special note!",
      kwargs_dict={"foo": "An extra arg."})
    def _prob(self, y, foo=None):
      pass
  ```

  In this case, the `AppendDocstring` decorator appends the `additional_note` to
  the docstring of `prob` (not `_prob`) and adds a new `kwargs`
  section with each dictionary item as a bullet-point.

  For a more detailed example, see `TransformedDistribution`.
  """

  def __init__(self, additional_note="", kwargs_dict=None):
    """Initializes the AppendDocstring object.

    Args:
      additional_note: Python string added as additional docstring to public
        version of function.
      kwargs_dict: Python string/string dictionary representing specific kwargs
        expanded from the **kwargs input.

    Raises:
      ValueError: if kwargs_dict.key contains whitespace.
      ValueError: if kwargs_dict.value contains newlines.
    """
    self._additional_note = additional_note
    if kwargs_dict:
      bullets = []
      for key in sorted(kwargs_dict.keys()):
        value = kwargs_dict[key]
        if any(x.isspace() for x in key):
          raise ValueError("Parameter name \"%s\" contains whitespace." % key)
        value = value.lstrip()
        if "\n" in value:
          raise ValueError(
              "Parameter description for \"%s\" contains newlines." % key)
        bullets.append("*  `%s`: %s" % (key, value))
      self._additional_note += ("\n\n##### `kwargs`:\n\n" + "\n".join(bullets))

  def __call__(self, fn):

    @functools.wraps(fn)
    def _fn(*args, **kwargs):
      return fn(*args, **kwargs)

    if _fn.__doc__ is None:
      _fn.__doc__ = self._additional_note
    else:
      _fn.__doc__ += "\n%s" % self._additional_note
    return _fn