1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for probability distributions."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import hashlib
23
24import numpy as np
25
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.framework import tensor_util
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import check_ops
33from tensorflow.python.ops import control_flow_ops
34from tensorflow.python.ops import linalg_ops
35from tensorflow.python.ops import math_ops
36from tensorflow.python.ops import nn
37from tensorflow.python.util import tf_inspect
38
39
40def assert_integer_form(x,
41                        data=None,
42                        summarize=None,
43                        message=None,
44                        int_dtype=None,
45                        name="assert_integer_form"):
46  """Assert that x has integer components (or floats equal to integers).
47
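  Example (an illustrative sketch; the returned op only raises when it is run
  and the check fails):

  ```python
  assert_integer_form(tf.constant([1., 2., 3.]))  # Passes when run.
  assert_integer_form(tf.constant([1., 2.5]))     # Raises InvalidArgumentError when run.
  ```
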
48  Args:
49    x: Floating-point `Tensor`
    data: The tensors to print out if the condition is `False`. Defaults to the
      error message and the first few entries of `x`.
52    summarize: Print this many entries of each tensor.
53    message: A string to prefix to the default message.
    int_dtype: A `tf.dtype` to cast `x` to. The default (`None`) uses the
      smallest signed integer dtype that can exactly represent values of
      `x.dtype` (`int16`, `int32`, or `int64`).
56    name: A name for this operation (optional).
57
58  Returns:
59    Op raising `InvalidArgumentError` if `cast(x, int_dtype) != x`.
60  """
61  with ops.name_scope(name, values=[x, data]):
62    x = ops.convert_to_tensor(x, name="x")
63    if x.dtype.is_integer:
64      return control_flow_ops.no_op()
65    message = message or "{} has non-integer components".format(x)
66    if int_dtype is None:
67      try:
68        int_dtype = {
69            dtypes.float16: dtypes.int16,
70            dtypes.float32: dtypes.int32,
71            dtypes.float64: dtypes.int64,
72        }[x.dtype.base_dtype]
73      except KeyError:
74        raise TypeError("Unrecognized type {}".format(x.dtype.name))
75    return check_ops.assert_equal(
76        x,
77        math_ops.cast(math_ops.cast(x, int_dtype), x.dtype),
78        data=data,
79        summarize=summarize,
80        message=message,
81        name=name)
82
83
84def assert_symmetric(matrix):
85  matrix_t = array_ops.matrix_transpose(matrix)
86  return control_flow_ops.with_dependencies(
87      [check_ops.assert_equal(matrix, matrix_t)], matrix)
88
89
90def embed_check_nonnegative_integer_form(
91    x, name="embed_check_nonnegative_integer_form"):
92  """Assert x is a non-negative tensor, and optionally of integers."""
93  with ops.name_scope(name, values=[x]):
94    x = ops.convert_to_tensor(x, name="x")
95    assertions = [
96        check_ops.assert_non_negative(
97            x, message="'{}' must be non-negative.".format(x)),
98    ]
99    if not x.dtype.is_integer:
100      assertions += [
101          assert_integer_form(
102              x,
103              message="'{}' cannot contain fractional components.".format(x)),
104      ]
105    return control_flow_ops.with_dependencies(assertions, x)
106
107
108def same_dynamic_shape(a, b):
109  """Returns whether a and b have the same dynamic shape.
110
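  Example (illustrative; the values shown are what the returned boolean
  `Tensor` evaluates to):

  ```python
  same_dynamic_shape(tf.zeros([2, 3]), tf.ones([2, 3]))  # ==> True
  same_dynamic_shape(tf.zeros([2, 3]), tf.ones([3, 2]))  # ==> False
  ```
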
111  Args:
112    a: `Tensor`
113    b: `Tensor`
114
115  Returns:
116    `bool` `Tensor` representing if both tensors have the same shape.
117  """
118  a = ops.convert_to_tensor(a, name="a")
119  b = ops.convert_to_tensor(b, name="b")
120
121  # Here we can't just do math_ops.equal(a.shape, b.shape), since
122  # static shape inference may break the equality comparison between
123  # shape(a) and shape(b) in math_ops.equal.
124  def all_shapes_equal():
125    return math_ops.reduce_all(
126        math_ops.equal(
127            array_ops.concat(
128                [array_ops.shape(a), array_ops.shape(b)], 0),
129            array_ops.concat(
130                [array_ops.shape(b), array_ops.shape(a)], 0)))
131
132  # One of the shapes isn't fully defined, so we need to use the dynamic
133  # shape.
134  return control_flow_ops.cond(
135      math_ops.equal(array_ops.rank(a), array_ops.rank(b)),
136      all_shapes_equal, lambda: constant_op.constant(False))
137
138
139def maybe_get_static_value(x, dtype=None):
140  """Helper which tries to return a static value.
141
  Given `x`, extract its value statically, optionally casting to a specific
143  dtype. If this is not possible, None is returned.
144
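  Example (an illustrative sketch; assumes `numpy` imported as `np`, and graph
  mode for the placeholder case where no static value is available):

  ```python
  maybe_get_static_value(tf.constant(1.))                  # ==> 1.0
  maybe_get_static_value(tf.constant(1.), dtype=np.int32)  # ==> 1
  maybe_get_static_value(tf.placeholder(tf.float32))       # ==> None
  ```
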
145  Args:
146    x: `Tensor` for which to extract a value statically.
147    dtype: Optional dtype to cast to.
148
149  Returns:
150    Statically inferred value if possible, otherwise None.
151  """
152  if x is None:
153    return x
154  try:
155    # This returns an np.ndarray.
156    x_ = tensor_util.constant_value(x)
157  except TypeError:
158    x_ = x
159  if x_ is None or dtype is None:
160    return x_
161  return np.array(x_, dtype)
162
163
164def get_logits_and_probs(logits=None,
165                         probs=None,
166                         multidimensional=False,
167                         validate_args=False,
168                         name="get_logits_and_probs",
169                         dtype=None):
170  """Converts logit to probabilities (or vice-versa), and returns both.
171
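  Example (illustrative; the numeric values shown are what the returned
  tensors evaluate to, e.g. when run in a session):

  ```python
  logits, probs = get_logits_and_probs(probs=[0.25, 0.5])
  # logits ==> [log(0.25 / 0.75), log(0.5 / 0.5)] ~= [-1.0986, 0.]
  # probs  ==> [0.25, 0.5]
  ```
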
172  Args:
173    logits: Floating-point `Tensor` representing log-odds.
174    probs: Floating-point `Tensor` representing probabilities.
    multidimensional: Python `bool`, default `False`. If `True`, `logits` or
      `probs` is treated as a `[N1, N2, ..., k]`-shaped tensor whose last
      dimension holds the logits or probabilities of `k` classes.
    validate_args: Python `bool`, default `False`. When `True`, either assert
      `0 <= probs <= 1` (if not `multidimensional`) or assert that the last
      dimension of `probs` sums to one.
182    name: A name for this operation (optional).
183    dtype: `tf.DType` to prefer when converting args to `Tensor`s.
184
185  Returns:
186    logits, probs: Tuple of `Tensor`s. If `probs` has an entry that is `0` or
187      `1`, then the corresponding entry in the returned logit will be `-Inf` and
188      `Inf` respectively.
189
190  Raises:
191    ValueError: if neither `probs` nor `logits` were passed in, or both were.
192  """
193  with ops.name_scope(name, values=[probs, logits]):
194    if (probs is None) == (logits is None):
195      raise ValueError("Must pass probs or logits, but not both.")
196
197    if probs is None:
198      logits = ops.convert_to_tensor(logits, name="logits", dtype=dtype)
199      if not logits.dtype.is_floating:
200        raise TypeError("logits must having floating type.")
201      # We can early return since we constructed probs and therefore know
202      # they're valid.
203      if multidimensional:
204        if validate_args:
205          logits = embed_check_categorical_event_shape(logits)
206        return logits, nn.softmax(logits, name="probs")
207      return logits, math_ops.sigmoid(logits, name="probs")
208
209    probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
210    if not probs.dtype.is_floating:
211      raise TypeError("probs must having floating type.")
212
213    if validate_args:
214      with ops.name_scope("validate_probs"):
215        one = constant_op.constant(1., probs.dtype)
216        dependencies = [check_ops.assert_non_negative(probs)]
217        if multidimensional:
218          probs = embed_check_categorical_event_shape(probs)
219          dependencies += [
220              check_ops.assert_near(
221                  math_ops.reduce_sum(probs, -1),
222                  one,
223                  message="probs does not sum to 1.")
224          ]
225        else:
226          dependencies += [
227              check_ops.assert_less_equal(
228                  probs, one, message="probs has components greater than 1.")
229          ]
230        probs = control_flow_ops.with_dependencies(dependencies, probs)
231
232    with ops.name_scope("logits"):
233      if multidimensional:
        # In the multidimensional case we follow the TF convention and return
        # unnormalized log-probabilities, i.e., logits = log(probs), rather
        # than the pivot-based form logits = log(probs) - log(probs[pivot])
        # you might expect by analogy with the unidimensional case. A
        # side-effect of this convention is that the unidimensional case
        # implicitly handles the second (complementary) class, whereas the
        # multidimensional case explicitly keeps the pivot dimension.
241        return math_ops.log(probs), probs
242      return math_ops.log(probs) - math_ops.log1p(-1. * probs), probs
243
244
245def _is_known_unsigned_by_dtype(dt):
246  """Helper returning True if dtype is known to be unsigned."""
247  return {
248      dtypes.bool: True,
249      dtypes.uint8: True,
250      dtypes.uint16: True,
251  }.get(dt.base_dtype, False)
252
253
254def _is_known_signed_by_dtype(dt):
255  """Helper returning True if dtype is known to be signed."""
256  return {
257      dtypes.float16: True,
258      dtypes.float32: True,
259      dtypes.float64: True,
260      dtypes.int8: True,
261      dtypes.int16: True,
262      dtypes.int32: True,
263      dtypes.int64: True,
264  }.get(dt.base_dtype, False)
265
266
267def _is_known_dtype(dt):
268  """Helper returning True if dtype is known."""
269  return _is_known_unsigned_by_dtype(dt) or _is_known_signed_by_dtype(dt)
270
271
272def _largest_integer_by_dtype(dt):
273  """Helper returning the largest integer exactly representable by dtype."""
274  if not _is_known_dtype(dt):
275    raise TypeError("Unrecognized dtype: {}".format(dt.name))
276  if dt.is_floating:
277    return int(2**(np.finfo(dt.as_numpy_dtype).nmant + 1))
278  if dt.is_integer:
279    return np.iinfo(dt.as_numpy_dtype).max
280  if dt.base_dtype == dtypes.bool:
281    return int(1)
282  # We actually can't land here but keep the case for completeness.
283  raise TypeError("Unrecognized dtype: {}".format(dt.name))
284
285
286def _smallest_integer_by_dtype(dt):
287  """Helper returning the smallest integer exactly representable by dtype."""
288  if not _is_known_dtype(dt):
289    raise TypeError("Unrecognized dtype: {}".format(dt.name))
290  if _is_known_unsigned_by_dtype(dt):
291    return 0
292  return -1 * _largest_integer_by_dtype(dt)
293
294
295def _is_integer_like_by_dtype(dt):
296  """Helper returning True if dtype.is_integer or is `bool`."""
297  if not _is_known_dtype(dt):
298    raise TypeError("Unrecognized dtype: {}".format(dt.name))
299  return dt.is_integer or dt.base_dtype == dtypes.bool
300
301
302def embed_check_categorical_event_shape(
303    categorical_param, name="embed_check_categorical_event_shape"):
304  """Embeds checks that categorical distributions don't have too many classes.
305
306  A categorical-type distribution is one which, e.g., returns the class label
307  rather than a one-hot encoding.  E.g., `Categorical(probs)`.
308
309  Since distributions output samples in the same dtype as the parameters, we
310  must ensure that casting doesn't lose precision. That is, the
311  `parameter.dtype` implies a maximum number of classes. However, since shape is
312  `int32` and categorical variables are presumed to be indexes into a `Tensor`,
313  we must also ensure that the number of classes is no larger than the largest
314  possible `int32` index, i.e., `2**31-1`.
315
316  In other words the number of classes, `K`, must satisfy the following
317  condition:
318
319  ```python
320  K <= min(
      int(2**31 - 1),  # Largest possible index (int32).
322      {
323          dtypes.float16: int(2**11),   # Largest int as a float16.
324          dtypes.float32: int(2**24),
325          dtypes.float64: int(2**53),
326      }.get(categorical_param.dtype.base_dtype, 0))
327  ```
328
329  Args:
330    categorical_param: Floating-point `Tensor` representing parameters of
331      distribution over categories. The rightmost shape is presumed to be the
332      number of categories.
333    name: A name for this operation (optional).
334
335  Returns:
336    categorical_param: Input `Tensor` with appropriate assertions embedded.
337
338  Raises:
339    TypeError: if `categorical_param` has an unknown `dtype`.
340    ValueError: if we can statically identify `categorical_param` as being too
341      large (for being closed under int32/float casting).
342  """
343  with ops.name_scope(name, values=[categorical_param]):
344    x = ops.convert_to_tensor(categorical_param, name="categorical_param")
345    # The size must not exceed both of:
346    # - The largest possible int32 (since categorical values are presumed to be
347    #   indexes into a Tensor).
348    # - The largest possible integer exactly representable under the given
349    #   floating-point dtype (since we need to cast to/from).
350    #
351    # The chosen floating-point thresholds are 2**(1 + mantissa_bits).
352    # For more details, see:
353    # https://en.wikipedia.org/wiki/Floating-point_arithmetic#Internal_representation
354    x_dtype = x.dtype.base_dtype
355    max_event_size = (
356        _largest_integer_by_dtype(x_dtype) if x_dtype.is_floating else 0)
357    if max_event_size == 0:
358      raise TypeError("Unable to validate size of unrecognized dtype "
359                      "({}).".format(x_dtype.name))
360    try:
361      x_shape_static = x.get_shape().with_rank_at_least(1)
362    except ValueError:
363      raise ValueError("A categorical-distribution parameter must have "
364                       "at least 1 dimension.")
365    if tensor_shape.dimension_value(x_shape_static[-1]) is not None:
366      event_size = x_shape_static.dims[-1].value
367      if event_size < 2:
368        raise ValueError("A categorical-distribution parameter must have at "
369                         "least 2 events.")
370      if event_size > max_event_size:
371        raise ValueError("Number of classes exceeds `dtype` precision, i.e., "
372                         "{} implies shape ({}) cannot exceed {}.".format(
373                             x_dtype.name, event_size, max_event_size))
374      return x
375    else:
376      event_size = array_ops.shape(x, name="x_shape")[-1]
377      return control_flow_ops.with_dependencies([
378          check_ops.assert_rank_at_least(
379              x,
380              1,
381              message=("A categorical-distribution parameter must have "
382                       "at least 1 dimension.")),
383          check_ops.assert_greater_equal(
384              array_ops.shape(x)[-1],
385              2,
386              message=("A categorical-distribution parameter must have at "
387                       "least 2 events.")),
388          check_ops.assert_less_equal(
389              event_size,
390              max_event_size,
391              message="Number of classes exceeds `dtype` precision, "
392              "i.e., {} dtype cannot exceed {} shape.".format(
393                  x_dtype.name, max_event_size)),
394      ], x)
395
396
397def embed_check_integer_casting_closed(x,
398                                       target_dtype,
399                                       assert_nonnegative=True,
400                                       name="embed_check_casting_closed"):
401  """Ensures integers remain unaffected despite casting to/from int/float types.
402
403  Example integer-types: `uint8`, `int32`, `bool`.
404  Example floating-types: `float32`, `float64`.
405
406  The largest possible integer representable by an IEEE754 floating-point is
407  `2**(1 + mantissa_bits)` yet the largest possible integer as an int-type is
408  `2**(bits - 1) - 1`. This function ensures that a `Tensor` purporting to have
409  integer-form values can be cast to some other type without loss of precision.
410
411  The smallest representable integer is the negative of the largest
412  representable integer, except for types: `uint8`, `uint16`, `bool`. For these
413  types, the smallest representable integer is `0`.
414
415  Args:
416    x: `Tensor` representing integer-form values.
417    target_dtype: TF `dtype` under which `x` should have identical values.
418    assert_nonnegative: `bool` indicating `x` should contain nonnegative values.
419    name: A name for this operation (optional).
420
421  Returns:
422    x: Input `Tensor` with appropriate assertions embedded.
423
424  Raises:
425    TypeError: if `x` is neither integer- nor floating-type.
426    TypeError: if `target_dtype` is neither integer- nor floating-type.
427    TypeError: if neither `x` nor `target_dtype` are integer-type.
428  """
429
430  with ops.name_scope(name, values=[x]):
431    x = ops.convert_to_tensor(x, name="x")
432    if (not _is_integer_like_by_dtype(x.dtype) and not x.dtype.is_floating):
433      raise TypeError("{}.dtype must be floating- or "
434                      "integer-type.".format(x.dtype.name))
435    if (not _is_integer_like_by_dtype(target_dtype) and
436        not target_dtype.is_floating):
437      raise TypeError("target_dtype ({}) must be floating- or "
438                      "integer-type.".format(target_dtype.name))
439    if (not _is_integer_like_by_dtype(x.dtype) and
440        not _is_integer_like_by_dtype(target_dtype)):
441      raise TypeError("At least one of {}.dtype ({}) and target_dtype ({}) "
442                      "must be integer-type.".format(x, x.dtype.name,
443                                                     target_dtype.name))
444
445    assertions = []
446    if assert_nonnegative:
447      assertions += [
448          check_ops.assert_non_negative(
449              x, message="Elements must be non-negative."),
450      ]
451
452    if x.dtype.is_floating:
453      # Being here means _is_integer_like_by_dtype(target_dtype) = True.
      # Since this check implies the magnitude check below, it is all we need.
455      assertions += [
456          assert_integer_form(
457              x,
458              int_dtype=target_dtype,
459              message="Elements must be {}-equivalent.".format(
460                  target_dtype.name)),
461      ]
462    else:
463      if (_largest_integer_by_dtype(x.dtype) >
464          _largest_integer_by_dtype(target_dtype)):
465        # Cast may lose integer precision.
466        assertions += [
467            check_ops.assert_less_equal(
468                x,
469                _largest_integer_by_dtype(target_dtype),
470                message=("Elements cannot exceed {}.".format(
471                    _largest_integer_by_dtype(target_dtype)))),
472        ]
473      if (not assert_nonnegative and (_smallest_integer_by_dtype(
474          x.dtype) < _smallest_integer_by_dtype(target_dtype))):
475        assertions += [
476            check_ops.assert_greater_equal(
477                x,
478                _smallest_integer_by_dtype(target_dtype),
479                message=("Elements cannot be smaller than {}.".format(
480                    _smallest_integer_by_dtype(target_dtype)))),
481        ]
482
483    if not assertions:
484      return x
485    return control_flow_ops.with_dependencies(assertions, x)
486
487
488def log_combinations(n, counts, name="log_combinations"):
489  """Multinomial coefficient.
490
491  Given `n` and `counts`, where `counts` has last dimension `k`, we compute
492  the multinomial coefficient as:
493
494  ```n! / sum_i n_i!```
495
496  where `i` runs over all `k` classes.
497
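  Example (illustrative; the value shown is what the returned `Tensor`
  evaluates to):

  ```python
  # counts = [2., 1.] distributes n = 3 outcomes over k = 2 classes.
  log_combinations(n=3., counts=[2., 1.])
  # ==> log(3! / (2! * 1!)) = log(3) ~= 1.0986
  ```
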
498  Args:
499    n: Floating-point `Tensor` broadcastable with `counts`. This represents `n`
500      outcomes.
501    counts: Floating-point `Tensor` broadcastable with `n`. This represents
502      counts in `k` classes, where `k` is the last dimension of the tensor.
503    name: A name for this operation (optional).
504
505  Returns:
    `Tensor` representing the log of the multinomial coefficient of `n` and
    `counts`.
507  """
508  # First a bit about the number of ways counts could have come in:
509  # E.g. if counts = [1, 2], then this is 3 choose 2.
  # In general, this is (sum counts)! / prod(counts!)
511  # The sum should be along the last dimension of counts. This is the
512  # "distribution" dimension. Here n a priori represents the sum of counts.
513  with ops.name_scope(name, values=[n, counts]):
514    n = ops.convert_to_tensor(n, name="n")
515    counts = ops.convert_to_tensor(counts, name="counts")
516    total_permutations = math_ops.lgamma(n + 1)
517    counts_factorial = math_ops.lgamma(counts + 1)
518    redundant_permutations = math_ops.reduce_sum(counts_factorial, axis=[-1])
519    return total_permutations - redundant_permutations
520
521
522def matrix_diag_transform(matrix, transform=None, name=None):
523  """Transform diagonal of [batch-]matrix, leave rest of matrix unchanged.
524
525  Create a trainable covariance defined by a Cholesky factor:
526
527  ```python
528  # Transform network layer into 2 x 2 array.
529  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
530  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
531
532  # Make the diagonal positive. If the upper triangle was zero, this would be a
533  # valid Cholesky factor.
534  chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)
535
536  # LinearOperatorLowerTriangular ignores the upper triangle.
537  operator = LinearOperatorLowerTriangular(chol)
538  ```
539
540  Example of heteroskedastic 2-D linear regression.
541
542  ```python
543  tfd = tfp.distributions
544
545  # Get a trainable Cholesky factor.
546  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
547  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
548  chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)
549
550  # Get a trainable mean.
551  mu = tf.contrib.layers.fully_connected(activations, 2)
552
553  # This is a fully trainable multivariate normal!
554  dist = tfd.MultivariateNormalTriL(mu, chol)
555
556  # Standard log loss. Minimizing this will "train" mu and chol, and then dist
557  # will be a distribution predicting labels as multivariate Gaussians.
558  loss = -1 * tf.reduce_mean(dist.log_prob(labels))
559  ```
560
561  Args:
562    matrix:  Rank `R` `Tensor`, `R >= 2`, where the last two dimensions are
563      equal.
564    transform:  Element-wise function mapping `Tensors` to `Tensors`. To be
565      applied to the diagonal of `matrix`. If `None`, `matrix` is returned
566      unchanged. Defaults to `None`.
567    name:  A name to give created ops. Defaults to "matrix_diag_transform".
568
569  Returns:
570    A `Tensor` with same shape and `dtype` as `matrix`.
571  """
572  with ops.name_scope(name, "matrix_diag_transform", [matrix]):
573    matrix = ops.convert_to_tensor(matrix, name="matrix")
574    if transform is None:
575      return matrix
576    # Replace the diag with transformed diag.
577    diag = array_ops.matrix_diag_part(matrix)
578    transformed_diag = transform(diag)
579    transformed_mat = array_ops.matrix_set_diag(matrix, transformed_diag)
580
581  return transformed_mat
582
583
584def rotate_transpose(x, shift, name="rotate_transpose"):
585  """Circularly moves dims left or right.
586
587  Effectively identical to:
588
589  ```python
590  numpy.transpose(x, numpy.roll(numpy.arange(len(x.shape)), shift))
591  ```
592
  When the rank of `x` or the value of `shift` is not known statically,
  additional graph-runtime computation is performed. This entails moving data
  from GPU to CPU.
595
596  Example:
597
598  ```python
599  x = tf.random.normal([1, 2, 3, 4])  # Tensor of shape [1, 2, 3, 4].
600  rotate_transpose(x, -1).shape == [2, 3, 4, 1]
601  rotate_transpose(x, -2).shape == [3, 4, 1, 2]
602  rotate_transpose(x,  1).shape == [4, 1, 2, 3]
603  rotate_transpose(x,  2).shape == [3, 4, 1, 2]
604  rotate_transpose(x,  7).shape == rotate_transpose(x, 3).shape  # [2, 3, 4, 1]
605  rotate_transpose(x, -7).shape == rotate_transpose(x, -3).shape  # [4, 1, 2, 3]
606  ```
607
608  Args:
609    x: `Tensor`.
610    shift: `Tensor`. Number of dimensions to transpose left (shift<0) or
611      transpose right (shift>0).
612    name: Python `str`. The name to give this op.
613
614  Returns:
615    rotated_x: Input `Tensor` with dimensions circularly rotated by shift.
616
617  Raises:
618    TypeError: if shift is not integer type.
619  """
620  with ops.name_scope(name, values=[x, shift]):
621    x = ops.convert_to_tensor(x, name="x")
622    shift = ops.convert_to_tensor(shift, name="shift")
623    # We do not assign back to preserve constant-ness.
624    check_ops.assert_integer(shift)
625    shift_value_static = tensor_util.constant_value(shift)
626    ndims = x.get_shape().ndims
627    if ndims is not None and shift_value_static is not None:
628      if ndims < 2:
629        return x
630      shift_value_static = np.sign(shift_value_static) * (
631          abs(shift_value_static) % ndims)
632      if shift_value_static == 0:
633        return x
634      perm = np.roll(np.arange(ndims), shift_value_static)
635      return array_ops.transpose(x, perm=perm)
636    else:
637      # Consider if we always had a positive shift, and some specified
638      # direction.
639      # When shifting left we want the new array:
640      #   last(x, n-shift) + first(x, shift)
641      # and if shifting right then we want:
642      #   last(x, shift) + first(x, n-shift)
643      # Observe that last(a) == slice(a, n) and first(a) == slice(0, a).
644      # Also, we can encode direction and shift as one: direction * shift.
645      # Combining these facts, we have:
646      #   a = cond(shift<0, -shift, n-shift)
647      #   last(x, n-a) + first(x, a) == x[a:n] + x[0:a]
648      # Finally, we transform shift by modulo length so it can be specified
649      # independently from the array upon which it operates (like python).
650      ndims = array_ops.rank(x)
651      shift = array_ops.where_v2(
652          math_ops.less(shift, 0), math_ops.mod(-shift, ndims),
653          ndims - math_ops.mod(shift, ndims))
654      first = math_ops.range(0, shift)
655      last = math_ops.range(shift, ndims)
656      perm = array_ops.concat([last, first], 0)
657      return array_ops.transpose(x, perm=perm)
658
659
660def pick_vector(cond, true_vector, false_vector, name="pick_vector"):
661  """Picks possibly different length row `Tensor`s based on condition.
662
663  Value `Tensor`s should have exactly one dimension.
664
665  If `cond` is a python Boolean or `tf.constant` then either `true_vector` or
666  `false_vector` is immediately returned. I.e., no graph nodes are created and
667  no validation happens.
668
669  Args:
670    cond: `Tensor`. Must have `dtype=tf.bool` and be scalar.
671    true_vector: `Tensor` of one dimension. Returned when cond is `True`.
672    false_vector: `Tensor` of one dimension. Returned when cond is `False`.
673    name: Python `str`. The name to give this op.

  Example:

  ```python
  pick_vector(tf.less(0, 5), tf.range(10, 12), tf.range(15, 18))  # [10, 11]
  pick_vector(tf.less(5, 0), tf.range(10, 12), tf.range(15, 18))  # [15, 16, 17]
  ```
677
678  Returns:
679    true_or_false_vector: `Tensor`.
680
681  Raises:
682    TypeError: if `cond.dtype != tf.bool`
683    TypeError: if `cond` is not a constant and
684      `true_vector.dtype != false_vector.dtype`
685  """
686  with ops.name_scope(name, values=(cond, true_vector, false_vector)):
687    cond = ops.convert_to_tensor(cond, name="cond")
688    if cond.dtype != dtypes.bool:
689      raise TypeError("%s.dtype=%s which is not %s" %
690                      (cond, cond.dtype, dtypes.bool))
691    cond_value_static = tensor_util.constant_value(cond)
692    if cond_value_static is not None:
693      return true_vector if cond_value_static else false_vector
694    true_vector = ops.convert_to_tensor(true_vector, name="true_vector")
695    false_vector = ops.convert_to_tensor(false_vector, name="false_vector")
696    if true_vector.dtype != false_vector.dtype:
697      raise TypeError(
698          "%s.dtype=%s does not match %s.dtype=%s" %
699          (true_vector, true_vector.dtype, false_vector, false_vector.dtype))
700    n = array_ops.shape(true_vector)[0]
701    return array_ops.slice(
702        array_ops.concat([true_vector, false_vector], 0),
703        [array_ops.where_v2(cond, 0, n)], [array_ops.where(cond, n, -1)])
704
705
706def prefer_static_broadcast_shape(shape1,
707                                  shape2,
708                                  name="prefer_static_broadcast_shape"):
709  """Convenience function which statically broadcasts shape when possible.
710
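  Example (an illustrative sketch; assumes both shapes are statically known,
  so a `TensorShape` is returned):

  ```python
  prefer_static_broadcast_shape(tf.constant([2, 1]), tf.constant([1, 3]))
  # ==> TensorShape([2, 3])
  ```
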
711  Args:
712    shape1:  `1-D` integer `Tensor`.  Already converted to tensor!
713    shape2:  `1-D` integer `Tensor`.  Already converted to tensor!
714    name:  A string name to prepend to created ops.
715
716  Returns:
717    The broadcast shape, either as `TensorShape` (if broadcast can be done
718      statically), or as a `Tensor`.
719  """
720  with ops.name_scope(name, values=[shape1, shape2]):
721
722    def make_shape_tensor(x):
723      return ops.convert_to_tensor(x, name="shape", dtype=dtypes.int32)
724
725    def get_tensor_shape(s):
726      if isinstance(s, tensor_shape.TensorShape):
727        return s
728      s_ = tensor_util.constant_value(make_shape_tensor(s))
729      if s_ is not None:
730        return tensor_shape.TensorShape(s_)
731      return None
732
733    def get_shape_tensor(s):
734      if not isinstance(s, tensor_shape.TensorShape):
735        return make_shape_tensor(s)
736      if s.is_fully_defined():
737        return make_shape_tensor(s.as_list())
738      raise ValueError("Cannot broadcast from partially "
739                       "defined `TensorShape`.")
740
741    shape1_ = get_tensor_shape(shape1)
742    shape2_ = get_tensor_shape(shape2)
743    if shape1_ is not None and shape2_ is not None:
744      return array_ops.broadcast_static_shape(shape1_, shape2_)
745
746    shape1_ = get_shape_tensor(shape1)
747    shape2_ = get_shape_tensor(shape2)
748    return array_ops.broadcast_dynamic_shape(shape1_, shape2_)
749
750
751def prefer_static_rank(x):
752  """Return static rank of tensor `x` if available, else `tf.rank(x)`.
753
754  Args:
755    x: `Tensor` (already converted).
756
757  Returns:
758    Numpy array (if static rank is obtainable), else `Tensor`.
759  """
760  return prefer_static_value(array_ops.rank(x))
761
762
763def prefer_static_shape(x):
764  """Return static shape of tensor `x` if available, else `tf.shape(x)`.
765
766  Args:
767    x: `Tensor` (already converted).
768
769  Returns:
770    Numpy array (if static shape is obtainable), else `Tensor`.
771  """
772  return prefer_static_value(array_ops.shape(x))
773
774
775def prefer_static_value(x):
776  """Return static value of tensor `x` if available, else `x`.
777
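  Example (illustrative; the placeholder case assumes graph mode, where no
  static value is available):

  ```python
  prefer_static_value(tf.constant([1, 2]))       # ==> np.array([1, 2])
  prefer_static_value(tf.placeholder(tf.int32))  # ==> the placeholder `Tensor`
  ```
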
778  Args:
779    x: `Tensor` (already converted).
780
781  Returns:
782    Numpy array (if static value is obtainable), else `Tensor`.
783  """
784  static_x = tensor_util.constant_value(x)
785  if static_x is not None:
786    return static_x
787  return x
788
789
790def gen_new_seed(seed, salt):
791  """Generate a new seed, from the given seed and salt."""
792  if seed is None:
793    return None
794  string = (str(seed) + salt).encode("utf-8")
795  return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF
796
797
798def fill_triangular(x, upper=False, name=None):
799  """Creates a (batch of) triangular matrix from a vector of inputs.
800
801  Created matrix can be lower- or upper-triangular. (It is more efficient to
802  create the matrix as upper or lower, rather than transpose.)
803
804  Triangular matrix elements are filled in a clockwise spiral. See example,
805  below.
806
807  If `x.get_shape()` is `[b1, b2, ..., bB, d]` then the output shape is
808  `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.
810
811  Example:
812
813  ```python
814  fill_triangular([1, 2, 3, 4, 5, 6])
815  # ==> [[4, 0, 0],
816  #      [6, 5, 0],
817  #      [3, 2, 1]]
818
819  fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
820  # ==> [[1, 2, 3],
821  #      [0, 5, 6],
822  #      [0, 0, 4]]
823  ```
824
825  For comparison, a pure numpy version of this function can be found in
826  `util_test.py`, function `_fill_triangular`.
827
828  Args:
829    x: `Tensor` representing lower (or upper) triangular elements.
830    upper: Python `bool` representing whether output matrix should be upper
831      triangular (`True`) or lower triangular (`False`, default).
832    name: Python `str`. The name to give this op.
833
834  Returns:
835    tril: `Tensor` with lower (or upper) triangular elements filled from `x`.
836
837  Raises:
838    ValueError: if `x` cannot be mapped to a triangular matrix.
839  """
840
841  with ops.name_scope(name, "fill_triangular", values=[x]):
842    x = ops.convert_to_tensor(x, name="x")
843    if tensor_shape.dimension_value(
844        x.shape.with_rank_at_least(1)[-1]) is not None:
845      # Formula derived by solving for n: m = n(n+1)/2.
846      m = np.int32(x.shape.dims[-1].value)
847      n = np.sqrt(0.25 + 2. * m) - 0.5
848      if n != np.floor(n):
849        raise ValueError("Input right-most shape ({}) does not "
850                         "correspond to a triangular matrix.".format(m))
851      n = np.int32(n)
852      static_final_shape = x.shape[:-1].concatenate([n, n])
853    else:
854      m = array_ops.shape(x)[-1]
855      # For derivation, see above. Casting automatically lops off the 0.5, so we
856      # omit it.  We don't validate n is an integer because this has
857      # graph-execution cost; an error will be thrown from the reshape, below.
858      n = math_ops.cast(
859          math_ops.sqrt(0.25 + math_ops.cast(2 * m, dtype=dtypes.float32)),
860          dtype=dtypes.int32)
861      static_final_shape = x.shape.with_rank_at_least(1)[:-1].concatenate(
862          [None, None])
863    # We now concatenate the "tail" of `x` to `x` (and reverse one of them).
864    #
865    # We do this based on the insight that the input `x` provides `ceil(n/2)`
866    # rows of an `n x n` matrix, some of which will get zeroed out being on the
867    # wrong side of the diagonal. The first row will not get zeroed out at all,
868    # and we need `floor(n/2)` more rows, so the first is what we omit from
869    # `x_tail`. If we then stack those `ceil(n/2)` rows with the `floor(n/2)`
870    # rows provided by a reversed tail, it is exactly the other set of elements
871    # of the reversed tail which will be zeroed out for being on the wrong side
872    # of the diagonal further up/down the matrix. And, in doing-so, we've filled
873    # the triangular matrix in a clock-wise spiral pattern. Neat!
874    #
875    # Try it out in numpy:
876    #  n = 3
877    #  x = np.arange(n * (n + 1) / 2)
878    #  m = x.shape[0]
879    #  n = np.int32(np.sqrt(.25 + 2 * m) - .5)
880    #  x_tail = x[(m - (n**2 - m)):]
881    #  np.concatenate([x_tail, x[::-1]], 0).reshape(n, n)  # lower
882    #  # ==> array([[3, 4, 5],
883    #               [5, 4, 3],
884    #               [2, 1, 0]])
885    #  np.concatenate([x, x_tail[::-1]], 0).reshape(n, n)  # upper
886    #  # ==> array([[0, 1, 2],
887    #               [3, 4, 5],
888    #               [5, 4, 3]])
889    #
890    # Note that we can't simply do `x[..., -(n**2 - m):]` because this doesn't
891    # correctly handle `m == n == 1`. Hence, we do nonnegative indexing.
892    # Furthermore observe that:
893    #   m - (n**2 - m)
    #   = n**2 / 2 + n / 2 - (n**2 - (n**2 / 2 + n / 2))
895    #   = 2 (n**2 / 2 + n / 2) - n**2
896    #   = n**2 + n - n**2
897    #   = n
898    ndims = prefer_static_rank(x)
899    if upper:
900      x_list = [x, array_ops.reverse(x[..., n:], axis=[ndims - 1])]
901    else:
902      x_list = [x[..., n:], array_ops.reverse(x, axis=[ndims - 1])]
903    new_shape = (
904        static_final_shape.as_list() if static_final_shape.is_fully_defined()
905        else array_ops.concat([array_ops.shape(x)[:-1], [n, n]], axis=0))
906    x = array_ops.reshape(array_ops.concat(x_list, axis=-1), new_shape)
907    x = array_ops.matrix_band_part(
908        x, num_lower=(0 if upper else -1), num_upper=(-1 if upper else 0))
909    x.set_shape(static_final_shape)
910    return x
911
912
913def fill_triangular_inverse(x, upper=False, name=None):
914  """Creates a vector from a (batch of) triangular matrix.
915
916  The vector is created from the lower-triangular or upper-triangular portion
917  depending on the value of the parameter `upper`.
918
919  If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
920  `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.
921
922  Example:
923
924  ```python
925  fill_triangular_inverse(
926    [[4, 0, 0],
927     [6, 5, 0],
928     [3, 2, 1]])
929
930  # ==> [1, 2, 3, 4, 5, 6]
931
932  fill_triangular_inverse(
933    [[1, 2, 3],
934     [0, 5, 6],
935     [0, 0, 4]], upper=True)
936
937  # ==> [1, 2, 3, 4, 5, 6]
938  ```
939
940  Args:
941    x: `Tensor` representing lower (or upper) triangular elements.
942    upper: Python `bool` representing whether output matrix should be upper
943      triangular (`True`) or lower triangular (`False`, default).
944    name: Python `str`. The name to give this op.
945
946  Returns:
947    flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
948      (or upper) triangular elements from `x`.
949  """
950
951  with ops.name_scope(name, "fill_triangular_inverse", values=[x]):
952    x = ops.convert_to_tensor(x, name="x")
953    if tensor_shape.dimension_value(
954        x.shape.with_rank_at_least(2)[-1]) is not None:
955      n = np.int32(x.shape.dims[-1].value)
956      m = np.int32((n * (n + 1)) // 2)
957      static_final_shape = x.shape[:-2].concatenate([m])
958    else:
959      n = array_ops.shape(x)[-1]
960      m = (n * (n + 1)) // 2
961      static_final_shape = x.shape.with_rank_at_least(2)[:-2].concatenate(
962          [None])
963    ndims = prefer_static_rank(x)
964    if upper:
965      initial_elements = x[..., 0, :]
966      triangular_portion = x[..., 1:, :]
967    else:
968      initial_elements = array_ops.reverse(x[..., -1, :], axis=[ndims - 2])
969      triangular_portion = x[..., :-1, :]
970    rotated_triangular_portion = array_ops.reverse(
971        array_ops.reverse(triangular_portion, axis=[ndims - 1]),
972        axis=[ndims - 2])
973    consolidated_matrix = triangular_portion + rotated_triangular_portion
974    end_sequence = array_ops.reshape(
975        consolidated_matrix,
976        array_ops.concat([array_ops.shape(x)[:-2], [n * (n - 1)]], axis=0))
977    y = array_ops.concat([initial_elements, end_sequence[..., :m - n]], axis=-1)
978    y.set_shape(static_final_shape)
979    return y
980
981
982def tridiag(below=None, diag=None, above=None, name=None):
983  """Creates a matrix with values set above, below, and on the diagonal.
984
985  Example:
986
987  ```python
988  tridiag(below=[1., 2., 3.],
989          diag=[4., 5., 6., 7.],
990          above=[8., 9., 10.])
991  # ==> array([[  4.,   8.,   0.,   0.],
992  #            [  1.,   5.,   9.,   0.],
993  #            [  0.,   2.,   6.,  10.],
994  #            [  0.,   0.,   3.,   7.]], dtype=float32)
995  ```
996
997  Warning: This Op is intended for convenience, not efficiency.
998
999  Args:
1000    below: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the below
1001      diagonal part. `None` is logically equivalent to `below = 0`.
1002    diag: `Tensor` of shape `[B1, ..., Bb, d]` corresponding to the diagonal
1003      part.  `None` is logically equivalent to `diag = 0`.
1004    above: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the above
1005      diagonal part.  `None` is logically equivalent to `above = 0`.
1006    name: Python `str`. The name to give this op.
1007
1008  Returns:
1009    tridiag: `Tensor` with values set above, below and on the diagonal.
1010
1011  Raises:
1012    ValueError: if all inputs are `None`.
1013  """
1014
1015  def _pad(x):
1016    """Prepends and appends a zero to every vector in a batch of vectors."""
1017    shape = array_ops.concat([array_ops.shape(x)[:-1], [1]], axis=0)
1018    z = array_ops.zeros(shape, dtype=x.dtype)
1019    return array_ops.concat([z, x, z], axis=-1)
1020
1021  def _add(*x):
1022    """Adds list of Tensors, ignoring `None`."""
1023    s = None
1024    for y in x:
1025      if y is None:
1026        continue
1027      elif s is None:
1028        s = y
1029      else:
1030        s += y
1031    if s is None:
1032      raise ValueError("Must specify at least one of `below`, `diag`, `above`.")
1033    return s
1034
1035  with ops.name_scope(name, "tridiag", [below, diag, above]):
1036    if below is not None:
1037      below = ops.convert_to_tensor(below, name="below")
1038      below = array_ops.matrix_diag(_pad(below))[..., :-1, 1:]
1039    if diag is not None:
1040      diag = ops.convert_to_tensor(diag, name="diag")
1041      diag = array_ops.matrix_diag(diag)
1042    if above is not None:
1043      above = ops.convert_to_tensor(above, name="above")
1044      above = array_ops.matrix_diag(_pad(above))[..., 1:, :-1]
1045    # TODO(jvdillon): Consider using scatter_nd instead of creating three full
1046    # matrices.
1047    return _add(below, diag, above)
1048
1049
1050def reduce_weighted_logsumexp(logx,
1051                              w=None,
1052                              axis=None,
1053                              keep_dims=False,
1054                              return_sign=False,
1055                              name=None):
1056  """Computes `log(abs(sum(weight * exp(elements across tensor dimensions))))`.
1057
  If all weights `w` are known to be positive, it is more efficient to use
  `reduce_logsumexp` directly, i.e.,
  `tf.reduce_logsumexp(logx + tf.math.log(w))` is more efficient than
  `du.reduce_weighted_logsumexp(logx, w)`.
1062
1063  Reduces `input_tensor` along the dimensions given in `axis`.
1064  Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
1065  entry in `axis`. If `keep_dims` is true, the reduced dimensions
1066  are retained with length 1.
1067
1068  If `axis` has no entries, all dimensions are reduced, and a
1069  tensor with a single element is returned.
1070
1071  This function is more numerically stable than log(sum(w * exp(input))). It
1072  avoids overflows caused by taking the exp of large inputs and underflows
1073  caused by taking the log of small inputs.
1074
1075  For example:
1076
1077  ```python
1078  x = tf.constant([[0., 0, 0],
1079                   [0, 0, 0]])
1080
1081  w = tf.constant([[-1., 1, 1],
1082                   [1, 1, 1]])
1083
1084  du.reduce_weighted_logsumexp(x, w)
1085  # ==> log(-1*1 + 1*1 + 1*1 + 1*1 + 1*1 + 1*1) = log(4)
1086
1087  du.reduce_weighted_logsumexp(x, w, axis=0)
1088  # ==> [log(-1+1), log(1+1), log(1+1)]
1089
1090  du.reduce_weighted_logsumexp(x, w, axis=1)
1091  # ==> [log(-1+1+1), log(1+1+1)]
1092
1093  du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True)
1094  # ==> [[log(-1+1+1)], [log(1+1+1)]]
1095
1096  du.reduce_weighted_logsumexp(x, w, axis=[0, 1])
1097  # ==> log(-1+5)
1098  ```
1099
1100  Args:
1101    logx: The tensor to reduce. Should have numeric type.
1102    w: The weight tensor. Should have numeric type identical to `logx`.
1103    axis: The dimensions to reduce. If `None` (the default), reduces all
1104      dimensions. Must be in the range `[-rank(input_tensor),
1105      rank(input_tensor))`.
1106    keep_dims: If true, retains reduced dimensions with length 1.
1107    return_sign: If `True`, returns the sign of the result.
1108    name: A name for the operation (optional).
1109
1110  Returns:
1111    lswe: The `log(abs(sum(weight * exp(x))))` reduced tensor.
1112    sign: (Optional) The sign of `sum(weight * exp(x))`.
1113  """
1114  with ops.name_scope(name, "reduce_weighted_logsumexp", [logx, w]):
1115    logx = ops.convert_to_tensor(logx, name="logx")
1116    if w is None:
1117      lswe = math_ops.reduce_logsumexp(logx, axis=axis, keepdims=keep_dims)
1118      if return_sign:
1119        sgn = array_ops.ones_like(lswe)
1120        return lswe, sgn
1121      return lswe
1122    w = ops.convert_to_tensor(w, dtype=logx.dtype, name="w")
1123    log_absw_x = logx + math_ops.log(math_ops.abs(w))
1124    max_log_absw_x = math_ops.reduce_max(log_absw_x, axis=axis, keepdims=True)
1125    # If the largest element is `-inf` or `inf` then we don't bother subtracting
1126    # off the max. We do this because otherwise we'd get `inf - inf = NaN`. That
1127    # this is ok follows from the fact that we're actually free to subtract any
1128    # value we like, so long as we add it back after taking the `log(sum(...))`.
1129    max_log_absw_x = array_ops.where_v2(
1130        math_ops.is_inf(max_log_absw_x), array_ops.zeros_like(max_log_absw_x),
1131        max_log_absw_x)
1132    wx_over_max_absw_x = (
1133        math_ops.sign(w) * math_ops.exp(log_absw_x - max_log_absw_x))
1134    sum_wx_over_max_absw_x = math_ops.reduce_sum(
1135        wx_over_max_absw_x, axis=axis, keepdims=keep_dims)
1136    if not keep_dims:
1137      max_log_absw_x = array_ops.squeeze(max_log_absw_x, axis)
1138    sgn = math_ops.sign(sum_wx_over_max_absw_x)
1139    lswe = max_log_absw_x + math_ops.log(sgn * sum_wx_over_max_absw_x)
1140    if return_sign:
1141      return lswe, sgn
1142    return lswe
1143
1144
1145# TODO(jvdillon): Merge this test back into:
1146# tensorflow/python/ops/softplus_op_test.py
1147# once TF core is accepting new ops.
1148def softplus_inverse(x, name=None):
1149  """Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)).
1150
1151  Mathematically this op is equivalent to:
1152
1153  ```none
1154  softplus_inverse = log(exp(x) - 1.)
1155  ```
1156
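  Example (an illustrative sketch; values shown are what the results evaluate
  to, up to floating-point error):

  ```python
  softplus_inverse(tf.nn.softplus(2.))  # ==> 2.
  softplus_inverse(tf.math.log(2.))     # ==> 0., since softplus(0.) = log(2.)
  ```
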
1157  Args:
1158    x: `Tensor`. Non-negative (not enforced), floating-point.
1159    name: A name for the operation (optional).
1160
1161  Returns:
1162    `Tensor`. Has the same type/shape as input `x`.
1163  """
1164  with ops.name_scope(name, "softplus_inverse", values=[x]):
1165    x = ops.convert_to_tensor(x, name="x")
1166    # We begin by deriving a more numerically stable softplus_inverse:
1167    # x = softplus(y) = Log[1 + exp{y}], (which means x > 0).
1168    # ==> exp{x} = 1 + exp{y}                                (1)
1169    # ==> y = Log[exp{x} - 1]                                (2)
1170    #       = Log[(exp{x} - 1) / exp{x}] + Log[exp{x}]
1171    #       = Log[(1 - exp{-x}) / 1] + Log[exp{x}]
1172    #       = Log[1 - exp{-x}] + x                           (3)
1173    # (2) is the "obvious" inverse, but (3) is more stable than (2) for large x.
1174    # For small x (e.g. x = 1e-10), (3) will become -inf since 1 - exp{-x} will
1175    # be zero. To fix this, we use 1 - exp{-x} approx x for small x > 0.
1176    #
1177    # In addition to the numerically stable derivation above, we clamp
1178    # small/large values to be congruent with the logic in:
1179    # tensorflow/core/kernels/softplus_op.h
1180    #
1181    # Finally, we set the input to one whenever the input is too large or too
1182    # small. This ensures that no unchosen codepath is +/- inf. This is
1183    # necessary to ensure the gradient doesn't get NaNs. Recall that the
1184    # gradient of `where` behaves like `pred*pred_true + (1-pred)*pred_false`
1185    # thus an `inf` in an unselected path results in `0*inf=nan`. We are careful
1186    # to overwrite `x` with ones only when we will never actually use this
1187    # value. Note that we use ones and not zeros since `log(expm1(0.)) = -inf`.
1188    threshold = np.log(np.finfo(x.dtype.as_numpy_dtype).eps) + 2.
1189    is_too_small = math_ops.less(x, np.exp(threshold))
1190    is_too_large = math_ops.greater(x, -threshold)
1191    too_small_value = math_ops.log(x)
1192    too_large_value = x
1193    # This `where` will ultimately be a NOP because we won't select this
1194    # codepath whenever we used the surrogate `ones_like`.
1195    x = array_ops.where_v2(
1196        math_ops.logical_or(is_too_small, is_too_large), array_ops.ones_like(x),
1197        x)
1198    y = x + math_ops.log(-math_ops.expm1(-x))  # == log(expm1(x))
1199    return array_ops.where_v2(
1200        is_too_small, too_small_value,
1201        array_ops.where_v2(is_too_large, too_large_value, y))
1202
1203
1204# TODO(b/35290280): Add unit-tests.
1205def dimension_size(x, axis):
1206  """Returns the size of a specific dimension."""
1207  # Since tf.gather isn't "constant-in, constant-out", we must first check the
1208  # static shape or fallback to dynamic shape.
1209  s = tensor_shape.dimension_value(
1210      x.shape.with_rank_at_least(np.abs(axis))[axis])
1211  if s is not None:
1212    return s
1213  return array_ops.shape(x)[axis]
1214
1215
1216def process_quadrature_grid_and_probs(quadrature_grid_and_probs,
1217                                      dtype,
1218                                      validate_args,
1219                                      name=None):
1220  """Validates quadrature grid, probs or computes them as necessary.
1221
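  Example (an illustrative sketch; passing `None` uses the default
  Gauss-Hermite rule):

  ```python
  grid, probs = process_quadrature_grid_and_probs(
      None, dtype=tf.float64, validate_args=False)
  # grid  ==> the 8 Gauss-Hermite abscissae.
  # probs ==> the corresponding weights, normalized to sum to 1.
  ```
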
1222  Args:
1223    quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
1224      representing the sample points and the corresponding (possibly
1225      normalized) weight.  When `None`, defaults to:
1226        `np.polynomial.hermite.hermgauss(deg=8)`.
1227    dtype: The expected `dtype` of `grid` and `probs`.
1228    validate_args: Python `bool`, default `False`. When `True` distribution
1229      parameters are checked for validity despite possibly degrading runtime
1230      performance. When `False` invalid inputs may silently render incorrect
1231      outputs.
1232    name: Python `str` name prefixed to Ops created by this class.
1233
1234  Returns:
1235     quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
1236      representing the sample points and the corresponding (possibly
1237      normalized) weight.
1238
1239  Raises:
1240    ValueError: if `quadrature_grid_and_probs is not None` and
1241      `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
1242  """
1243  with ops.name_scope(name, "process_quadrature_grid_and_probs",
1244                      [quadrature_grid_and_probs]):
1245    if quadrature_grid_and_probs is None:
1246      grid, probs = np.polynomial.hermite.hermgauss(deg=8)
1247      grid = grid.astype(dtype.as_numpy_dtype)
1248      probs = probs.astype(dtype.as_numpy_dtype)
1249      probs /= np.linalg.norm(probs, ord=1, keepdims=True)
1250      grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
1251      probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
1252      return grid, probs
1253
1254    grid, probs = tuple(quadrature_grid_and_probs)
1255    grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
1256    probs = ops.convert_to_tensor(probs, name="unnormalized_probs", dtype=dtype)
1257    probs /= linalg_ops.norm(probs, ord=1, axis=-1, keepdims=True, name="probs")
1258
1259    def _static_event_size(x):
1260      """Returns the static size of a specific dimension or `None`."""
1261      return tensor_shape.dimension_value(x.shape.with_rank_at_least(1)[-1])
1262
1263    m, n = _static_event_size(probs), _static_event_size(grid)
1264    if m is not None and n is not None:
1265      if m != n:
1266        raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
1267                         "same-length zero-th-dimension `Tensor`s "
1268                         "(saw lengths {}, {})".format(m, n))
1269    elif validate_args:
1270      assertions = [
1271          check_ops.assert_equal(
1272              dimension_size(probs, axis=-1),
1273              dimension_size(grid, axis=-1),
1274              message=("`quadrature_grid_and_probs` must be a `tuple` of "
1275                       "same-length zero-th-dimension `Tensor`s")),
1276      ]
1277      with ops.control_dependencies(assertions):
1278        grid = array_ops.identity(grid)
1279        probs = array_ops.identity(probs)
1280    return grid, probs
1281
1282
1283def pad(x, axis, front=False, back=False, value=0, count=1, name=None):
1284  """Pads `value` to the front and/or back of a `Tensor` dim, `count` times.
1285
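  Example (an illustrative sketch; the result shown is what the padded
  `Tensor` evaluates to):

  ```python
  x = tf.constant([[1., 2.],
                   [3., 4.]])
  pad(x, axis=-1, back=True, value=0., count=2)
  # ==> [[1., 2., 0., 0.],
  #      [3., 4., 0., 0.]]
  ```
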
1286  Args:
1287    x: `Tensor` input.
1288    axis: Scalar `int`-like `Tensor` representing the single dimension to pad.
1289      (Negative indexing is supported.)
1290    front: Python `bool`; if `True` the beginning of the `axis` dimension is
1291      padded with `value`, `count` times. If `False` no front padding is made.
1292    back: Python `bool`; if `True` the end of the `axis` dimension is padded
1293      with `value`, `count` times. If `False` no end padding is made.
1294    value: Scalar `int`-like `Tensor` representing the actual value added to the
1295      front and/or back of the `axis` dimension of `x`.
1296    count: Scalar `int`-like `Tensor` representing number of elements added to
1297      the front and/or back of the `axis` dimension of `x`. E.g., if `front =
1298      back = True` then `2 * count` elements are added.
1299    name: Python `str` name prefixed to Ops created by this function.
1300
1301  Returns:
1302    pad: The padded version of input `x`.
1303
1304  Raises:
1305    ValueError: if both `front` and `back` are `False`.
1306    TypeError: if `count` is not `int`-like.
1307  """
1308  with ops.name_scope(name, "pad", [x, value, count]):
1309    x = ops.convert_to_tensor(x, name="x")
1310    value = ops.convert_to_tensor(value, dtype=x.dtype, name="value")
1311    count = ops.convert_to_tensor(count, name="count")
1312    if not count.dtype.is_integer:
1313      raise TypeError("`count.dtype` (`{}`) must be `int`-like.".format(
1314          count.dtype.name))
1315    if not front and not back:
1316      raise ValueError("At least one of `front`, `back` must be `True`.")
1317    ndims = (
1318        x.shape.ndims if x.shape.ndims is not None else array_ops.rank(
1319            x, name="ndims"))
1320    axis = ops.convert_to_tensor(axis, name="axis")
1321    axis_ = tensor_util.constant_value(axis)
1322    if axis_ is not None:
1323      axis = axis_
1324      if axis < 0:
1325        axis = ndims + axis
1326      count_ = tensor_util.constant_value(count)
1327      if axis_ >= 0 or x.shape.ndims is not None:
1328        head = x.shape[:axis]
1329        middle = tensor_shape.TensorShape(None if count_ is None else (
1330            tensor_shape.dimension_at_index(x.shape, axis) + count_ *
1331            (front + back)))
1332        tail = x.shape[axis + 1:]
1333        final_shape = head.concatenate(middle.concatenate(tail))
1334      else:
1335        final_shape = None
1336    else:
1337      axis = array_ops.where_v2(axis < 0, ndims + axis, axis)
1338      final_shape = None
1339    x = array_ops.pad(
1340        x,
1341        paddings=array_ops.one_hot(
1342            indices=array_ops.stack(
1343                [axis if front else -1, axis if back else -1]),
1344            depth=ndims,
1345            axis=0,
1346            on_value=count,
1347            dtype=dtypes.int32),
1348        constant_values=value)
1349    if final_shape is not None:
1350      x.set_shape(final_shape)
1351    return x
1352
1353
1354def parent_frame_arguments():
1355  """Returns parent frame arguments.
1356
1357  When called inside a function, returns a dictionary with the caller's function
1358  arguments. These are positional arguments and keyword arguments (**kwargs),
1359  while variable arguments (*varargs) are excluded.
1360
1361  When called at global scope, this will return an empty dictionary, since there
1362  are no arguments.
1363
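  Example (an illustrative sketch):

  ```python
  def foo(a, b=2, **kwargs):
    return parent_frame_arguments()

  foo(1, c=3)  # ==> {'a': 1, 'b': 2, 'c': 3}
  ```
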
  WARNING: If the calling function reassigns any of its argument variables
  before invoking this method, the returned values reflect the reassigned
  values. For this reason, we recommend calling `parent_frame_arguments` at
  the beginning of the function.
1368  """
1369  # All arguments and the names used for *varargs, and **kwargs
1370  arg_names, variable_arg_name, keyword_arg_name, local_vars = (
1371      tf_inspect._inspect.getargvalues(  # pylint: disable=protected-access
1372          # Get the first frame of the caller of this method.
1373          tf_inspect._inspect.stack()[1][0]))  # pylint: disable=protected-access
1374
1375  # Remove the *varargs, and flatten the **kwargs. Both are
1376  # nested lists.
1377  local_vars.pop(variable_arg_name, {})
1378  keyword_args = local_vars.pop(keyword_arg_name, {})
1379
1380  final_args = {}
1381  # Copy over arguments and their values. In general, local_vars
1382  # may contain more than just the arguments, since this method
1383  # can be called anywhere in a function.
1384  for arg_name in arg_names:
1385    final_args[arg_name] = local_vars.pop(arg_name)
1386  final_args.update(keyword_args)
1387
1388  return final_args
1389
1390
1391class AppendDocstring(object):
1392  """Helper class to promote private subclass docstring to public counterpart.
1393
1394  Example:
1395
1396  ```python
1397  class TransformedDistribution(Distribution):
1398    @distribution_util.AppendDocstring(
1399      additional_note="A special note!",
1400      kwargs_dict={"foo": "An extra arg."})
1401    def _prob(self, y, foo=None):
1402      pass
1403  ```
1404
1405  In this case, the `AppendDocstring` decorator appends the `additional_note` to
1406  the docstring of `prob` (not `_prob`) and adds a new `kwargs`
1407  section with each dictionary item as a bullet-point.
1408
1409  For a more detailed example, see `TransformedDistribution`.
1410  """
1411
1412  def __init__(self, additional_note="", kwargs_dict=None):
1413    """Initializes the AppendDocstring object.
1414
1415    Args:
1416      additional_note: Python string added as additional docstring to public
1417        version of function.
1418      kwargs_dict: Python string/string dictionary representing specific kwargs
1419        expanded from the **kwargs input.
1420
1421    Raises:
1422      ValueError: if kwargs_dict.key contains whitespace.
1423      ValueError: if kwargs_dict.value contains newlines.
1424    """
1425    self._additional_note = additional_note
1426    if kwargs_dict:
1427      bullets = []
1428      for key in sorted(kwargs_dict.keys()):
1429        value = kwargs_dict[key]
1430        if any(x.isspace() for x in key):
1431          raise ValueError("Parameter name \"%s\" contains whitespace." % key)
1432        value = value.lstrip()
1433        if "\n" in value:
1434          raise ValueError(
1435              "Parameter description for \"%s\" contains newlines." % key)
1436        bullets.append("*  `%s`: %s" % (key, value))
1437      self._additional_note += ("\n\n##### `kwargs`:\n\n" + "\n".join(bullets))
1438
1439  def __call__(self, fn):
1440
1441    @functools.wraps(fn)
1442    def _fn(*args, **kwargs):
1443      return fn(*args, **kwargs)
1444
1445    if _fn.__doc__ is None:
1446      _fn.__doc__ = self._additional_note
1447    else:
1448      _fn.__doc__ += "\n%s" % self._additional_note
1449    return _fn
1450