# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for probability distributions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import hashlib
import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util import tf_inspect


def assert_integer_form(
    x, data=None, summarize=None, message=None,
    int_dtype=None, name="assert_integer_form"):
  """Assert that x has integer components (or floats equal to integers).

  Args:
    x: Floating-point `Tensor`.
    data: The tensors to print out if the condition is `False`. Defaults to the
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    int_dtype: A `tf.dtype` to cast `x` to. The default (`None`) implies the
      smallest possible signed int will be used for casting.
    name: A name for this operation (optional).

  Returns:
    Op raising `InvalidArgumentError` if `cast(x, int_dtype) != x`.
  """
  with ops.name_scope(name, values=[x, data]):
    x = ops.convert_to_tensor(x, name="x")
    if x.dtype.is_integer:
      return control_flow_ops.no_op()
    message = message or "{} has non-integer components".format(x)
    if int_dtype is None:
      try:
        int_dtype = {
            dtypes.float16: dtypes.int16,
            dtypes.float32: dtypes.int32,
            dtypes.float64: dtypes.int64,
        }[x.dtype.base_dtype]
      except KeyError:
        raise TypeError("Unrecognized type {}".format(x.dtype.name))
    return check_ops.assert_equal(
        x, math_ops.cast(math_ops.cast(x, int_dtype), x.dtype),
        data=data, summarize=summarize, message=message, name=name)


def assert_symmetric(matrix):
  """Adds an assertion that `matrix` equals its transpose (is symmetric)."""
  matrix_t = array_ops.matrix_transpose(matrix)
  return control_flow_ops.with_dependencies(
      [check_ops.assert_equal(matrix, matrix_t)], matrix)


def embed_check_nonnegative_integer_form(
    x, name="embed_check_nonnegative_integer_form"):
  """Asserts that x is a non-negative tensor with integer-form values."""
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    assertions = [
        check_ops.assert_non_negative(
            x, message="'{}' must be non-negative.".format(x)),
    ]
    if not x.dtype.is_integer:
      assertions += [
          assert_integer_form(
              x, message="'{}' cannot contain fractional components.".format(
                  x)),
      ]
    return control_flow_ops.with_dependencies(assertions, x)


def same_dynamic_shape(a, b):
  """Returns whether a and b have the same dynamic shape.

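  For example (an illustrative sketch):

  ```python
  same_dynamic_shape(tf.zeros([2, 3]), tf.ones([2, 3]))  # ==> True
  same_dynamic_shape(tf.zeros([2]), tf.zeros([2, 2]))    # ==> False
  ```
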
  Args:
    a: `Tensor`.
    b: `Tensor`.

  Returns:
    `bool` `Tensor` representing whether both tensors have the same shape.
  """
  a = ops.convert_to_tensor(a, name="a")
  b = ops.convert_to_tensor(b, name="b")

  # Here we can't just do math_ops.equal(a.shape, b.shape), since
  # static shape inference may break the equality comparison between
  # shape(a) and shape(b) in math_ops.equal.
  def all_shapes_equal():
    return math_ops.reduce_all(math_ops.equal(
        array_ops.concat([array_ops.shape(a), array_ops.shape(b)], 0),
        array_ops.concat([array_ops.shape(b), array_ops.shape(a)], 0)))

  # One of the shapes isn't fully defined, so we need to use the dynamic
  # shape.
  return control_flow_ops.cond(
      math_ops.equal(array_ops.rank(a), array_ops.rank(b)),
      all_shapes_equal,
      lambda: constant_op.constant(False))


def maybe_get_static_value(x, dtype=None):
  """Helper which tries to return a static value.

  Given `x`, extract its value statically, optionally casting to a specific
  dtype. If this is not possible, None is returned.

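  For example (an illustrative sketch):

  ```python
  maybe_get_static_value(tf.constant([1., 2.]))             # ==> np.array([1., 2.])
  maybe_get_static_value(tf.constant(1), dtype=np.float64)  # ==> np.array(1.)
  maybe_get_static_value(tf.placeholder(tf.float32))        # ==> None
  ```
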
  Args:
    x: `Tensor` for which to extract a value statically.
    dtype: Optional dtype to cast to.

  Returns:
    Statically inferred value if possible, otherwise None.
  """
  if x is None:
    return x
  try:
    # This returns an np.ndarray.
    x_ = tensor_util.constant_value(x)
  except TypeError:
    x_ = x
  if x_ is None or dtype is None:
    return x_
  return np.array(x_, dtype)


def get_logits_and_probs(logits=None,
                         probs=None,
                         multidimensional=False,
                         validate_args=False,
                         name="get_logits_and_probs",
                         dtype=None):
  """Converts logits to probabilities (or vice-versa), and returns both.

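  For example (an illustrative sketch):

  ```python
  # Binary case: probs = sigmoid(logits).
  logits, probs = get_logits_and_probs(logits=[0., 0.])
  # ==> logits: [0., 0.], probs: [0.5, 0.5]

  # Multidimensional case: the last dimension indexes the classes.
  logits, probs = get_logits_and_probs(
      probs=[0.25, 0.25, 0.5], multidimensional=True)
  # ==> logits: log([0.25, 0.25, 0.5]), probs: [0.25, 0.25, 0.5]
  ```
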
  Args:
    logits: Floating-point `Tensor` representing log-odds.
    probs: Floating-point `Tensor` representing probabilities.
    multidimensional: Python `bool`, default `False`. If `True`, `logits` or
      `probs` is treated as a `[N1, N2, ..., k]`-shaped tensor whose last
      dimension represents the logit or probability of each of `shape[-1]`
      classes.
    validate_args: Python `bool`, default `False`. When `True`, either assert
      `0 <= probs <= 1` (if not `multidimensional`) or that the last dimension
      of `probs` sums to one.
    name: A name for this operation (optional).
    dtype: `tf.DType` to prefer when converting args to `Tensor`s.

  Returns:
    logits, probs: Tuple of `Tensor`s. If `probs` has an entry that is `0` or
      `1`, then the corresponding entry in the returned logit will be `-Inf` and
      `Inf` respectively.

  Raises:
    ValueError: if neither `probs` nor `logits` were passed in, or both were.
  """
  with ops.name_scope(name, values=[probs, logits]):
    if (probs is None) == (logits is None):
      raise ValueError("Must pass probs or logits, but not both.")

    if probs is None:
      logits = ops.convert_to_tensor(logits, name="logits", dtype=dtype)
      if not logits.dtype.is_floating:
        raise TypeError("logits must have floating type.")
      # We can return early since we constructed probs and therefore know
      # they're valid.
      if multidimensional:
        if validate_args:
          logits = embed_check_categorical_event_shape(logits)
        return logits, nn.softmax(logits, name="probs")
      return logits, math_ops.sigmoid(logits, name="probs")

    probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
    if not probs.dtype.is_floating:
      raise TypeError("probs must have floating type.")

    if validate_args:
      with ops.name_scope("validate_probs"):
        one = constant_op.constant(1., probs.dtype)
        dependencies = [check_ops.assert_non_negative(probs)]
        if multidimensional:
          probs = embed_check_categorical_event_shape(probs)
          dependencies += [
              check_ops.assert_near(
                  math_ops.reduce_sum(probs, -1),
                  one,
                  message="probs does not sum to 1.")
          ]
        else:
          dependencies += [check_ops.assert_less_equal(
              probs, one, message="probs has components greater than 1.")]
        probs = control_flow_ops.with_dependencies(dependencies, probs)

    with ops.name_scope("logits"):
      if multidimensional:
        # Here we don't compute the multidimensional case in a manner
        # consistent with the unidimensional case; instead we follow the TF
        # convention. Typically, you might expect to see
        # logits = log(probs) - log(probs[pivot]). A side-effect of
        # being consistent with the TF approach is that the unidimensional case
        # implicitly handles the second dimension but the multidimensional case
        # explicitly keeps the pivot dimension.
        return math_ops.log(probs), probs
      return math_ops.log(probs) - math_ops.log1p(-1. * probs), probs


def _is_known_unsigned_by_dtype(dt):
  """Helper returning True if dtype is known to be unsigned."""
  return {
      dtypes.bool: True,
      dtypes.uint8: True,
      dtypes.uint16: True,
  }.get(dt.base_dtype, False)


def _is_known_signed_by_dtype(dt):
  """Helper returning True if dtype is known to be signed."""
  return {
      dtypes.float16: True,
      dtypes.float32: True,
      dtypes.float64: True,
      dtypes.int8: True,
      dtypes.int16: True,
      dtypes.int32: True,
      dtypes.int64: True,
  }.get(dt.base_dtype, False)


def _is_known_dtype(dt):
  """Helper returning True if dtype is known."""
  return _is_known_unsigned_by_dtype(dt) or _is_known_signed_by_dtype(dt)


def _largest_integer_by_dtype(dt):
  """Helper returning the largest integer exactly representable by dtype."""
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  if dt.is_floating:
    return int(2**(np.finfo(dt.as_numpy_dtype).nmant + 1))
  if dt.is_integer:
    return np.iinfo(dt.as_numpy_dtype).max
  if dt.base_dtype == dtypes.bool:
    return int(1)
  # We actually can't land here but keep the case for completeness.
  raise TypeError("Unrecognized dtype: {}".format(dt.name))


def _smallest_integer_by_dtype(dt):
  """Helper returning the smallest integer exactly representable by dtype."""
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  if _is_known_unsigned_by_dtype(dt):
    return 0
  return -1 * _largest_integer_by_dtype(dt)


def _is_integer_like_by_dtype(dt):
  """Helper returning True if dtype.is_integer or is `bool`."""
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  return dt.is_integer or dt.base_dtype == dtypes.bool


def embed_check_categorical_event_shape(
    categorical_param,
    name="embed_check_categorical_event_shape"):
  """Embeds checks that categorical distributions don't have too many classes.

  A categorical-type distribution is one which, e.g., returns the class label
  rather than a one-hot encoding.  E.g., `Categorical(probs)`.

  Since distributions output samples in the same dtype as the parameters, we
  must ensure that casting doesn't lose precision. That is, the
  `parameter.dtype` implies a maximum number of classes. However, since shape is
  `int32` and categorical variables are presumed to be indexes into a `Tensor`,
  we must also ensure that the number of classes is no larger than the largest
  possible `int32` index, i.e., `2**31-1`.

  In other words the number of classes, `K`, must satisfy the following
  condition:

  ```python
  K <= min(
      int(2**31 - 1),  # Largest `int32` index.
      {
          dtypes.float16: int(2**11),   # Largest int as a float16.
          dtypes.float32: int(2**24),
          dtypes.float64: int(2**53),
      }.get(categorical_param.dtype.base_dtype, 0))
  ```

  Args:
    categorical_param: Floating-point `Tensor` representing parameters of
      distribution over categories. The rightmost shape is presumed to be the
      number of categories.
    name: A name for this operation (optional).

  Returns:
    categorical_param: Input `Tensor` with appropriate assertions embedded.

  Raises:
    TypeError: if `categorical_param` has an unknown `dtype`.
    ValueError: if we can statically identify `categorical_param` as being too
      large (for being closed under int32/float casting).
  """
  with ops.name_scope(name, values=[categorical_param]):
    x = ops.convert_to_tensor(categorical_param, name="categorical_param")
    # The size must not exceed either of:
    # - The largest possible int32 (since categorical values are presumed to be
    #   indexes into a Tensor).
    # - The largest possible integer exactly representable under the given
    #   floating-point dtype (since we need to cast to/from).
    #
    # The chosen floating-point thresholds are 2**(1 + mantissa_bits).
    # For more details, see:
    # https://en.wikipedia.org/wiki/Floating-point_arithmetic#Internal_representation
    x_dtype = x.dtype.base_dtype
    max_event_size = (_largest_integer_by_dtype(x_dtype)
                      if x_dtype.is_floating else 0)
    if max_event_size == 0:
      raise TypeError("Unable to validate size of unrecognized dtype "
                      "({}).".format(x_dtype.name))
    try:
      x_shape_static = x.get_shape().with_rank_at_least(1)
    except ValueError:
      raise ValueError("A categorical-distribution parameter must have "
                       "at least 1 dimension.")
    if tensor_shape.dimension_value(x_shape_static[-1]) is not None:
      event_size = x_shape_static.dims[-1].value
      if event_size < 2:
        raise ValueError("A categorical-distribution parameter must have at "
                         "least 2 events.")
      if event_size > max_event_size:
        raise ValueError(
            "Number of classes exceeds `dtype` precision, i.e., "
            "{} implies shape ({}) cannot exceed {}.".format(
                x_dtype.name, event_size, max_event_size))
      return x
    else:
      event_size = array_ops.shape(x, name="x_shape")[-1]
      return control_flow_ops.with_dependencies([
          check_ops.assert_rank_at_least(
              x, 1, message=("A categorical-distribution parameter must have "
                             "at least 1 dimension.")),
          check_ops.assert_greater_equal(
              array_ops.shape(x)[-1], 2,
              message=("A categorical-distribution parameter must have at "
                       "least 2 events.")),
          check_ops.assert_less_equal(
              event_size, max_event_size,
              message="Number of classes exceeds `dtype` precision, "
                      "i.e., {} dtype cannot exceed {} shape.".format(
                          x_dtype.name, max_event_size)),
      ], x)


def embed_check_integer_casting_closed(
    x,
    target_dtype,
    assert_nonnegative=True,
    name="embed_check_casting_closed"):
  """Ensures integers remain unaffected despite casting to/from int/float types.

  Example integer-types: `uint8`, `int32`, `bool`.
  Example floating-types: `float32`, `float64`.

  The largest possible integer representable by an IEEE754 floating-point is
  `2**(1 + mantissa_bits)` yet the largest possible integer as an int-type is
  `2**(bits - 1) - 1`. This function ensures that a `Tensor` purporting to have
  integer-form values can be cast to some other type without loss of precision.

  The smallest representable integer is the negative of the largest
  representable integer, except for types: `uint8`, `uint16`, `bool`. For these
  types, the smallest representable integer is `0`.

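  For example (an illustrative sketch):

  ```python
  counts = tf.constant([0., 1., 2.], dtype=tf.float16)
  # Embeds an assertion that `counts` is non-negative and has values exactly
  # representable as `tf.int16` (i.e., integer-valued).
  counts = embed_check_integer_casting_closed(counts, target_dtype=tf.int16)
  ```
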
  Args:
    x: `Tensor` representing integer-form values.
    target_dtype: TF `dtype` under which `x` should have identical values.
    assert_nonnegative: `bool` indicating `x` should contain nonnegative values.
    name: A name for this operation (optional).

  Returns:
    x: Input `Tensor` with appropriate assertions embedded.

  Raises:
    TypeError: if `x` is neither integer- nor floating-type.
    TypeError: if `target_dtype` is neither integer- nor floating-type.
    TypeError: if neither `x` nor `target_dtype` are integer-type.
  """

  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if (not _is_integer_like_by_dtype(x.dtype)
        and not x.dtype.is_floating):
      raise TypeError("{}.dtype must be floating- or "
                      "integer-type.".format(x.dtype.name))
    if (not _is_integer_like_by_dtype(target_dtype)
        and not target_dtype.is_floating):
      raise TypeError("target_dtype ({}) must be floating- or "
                      "integer-type.".format(target_dtype.name))
    if (not _is_integer_like_by_dtype(x.dtype)
        and not _is_integer_like_by_dtype(target_dtype)):
      raise TypeError("At least one of {}.dtype ({}) and target_dtype ({}) "
                      "must be integer-type.".format(
                          x, x.dtype.name, target_dtype.name))

    assertions = []
    if assert_nonnegative:
      assertions += [
          check_ops.assert_non_negative(
              x, message="Elements must be non-negative."),
      ]

    if x.dtype.is_floating:
      # Being here means _is_integer_like_by_dtype(target_dtype) = True.
      # Since this check implies the magnitude check below, we need only this
      # one.
      assertions += [
          assert_integer_form(
              x, int_dtype=target_dtype,
              message="Elements must be {}-equivalent.".format(
                  target_dtype.name)),
      ]
    else:
      if (_largest_integer_by_dtype(x.dtype)
          > _largest_integer_by_dtype(target_dtype)):
        # Cast may lose integer precision.
        assertions += [
            check_ops.assert_less_equal(
                x, _largest_integer_by_dtype(target_dtype),
                message=("Elements cannot exceed {}.".format(
                    _largest_integer_by_dtype(target_dtype)))),
        ]
      if (not assert_nonnegative and
          (_smallest_integer_by_dtype(x.dtype)
           < _smallest_integer_by_dtype(target_dtype))):
        assertions += [
            check_ops.assert_greater_equal(
                x, _smallest_integer_by_dtype(target_dtype),
                message=("Elements cannot be smaller than {}.".format(
                    _smallest_integer_by_dtype(target_dtype)))),
        ]

    if not assertions:
      return x
    return control_flow_ops.with_dependencies(assertions, x)


def log_combinations(n, counts, name="log_combinations"):
  """Log multinomial coefficient.

  Given `n` and `counts`, where `counts` has last dimension `k`, we compute
  the log of the multinomial coefficient:

  ```log(n! / prod_i n_i!)```

  where `i` runs over all `k` classes.

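  For example (an illustrative sketch):

  ```python
  # Number of ways to allocate n = 4 outcomes into counts [2, 1, 1]:
  # 4! / (2! * 1! * 1!) = 12, so the result is log(12).
  log_combinations(4., [2., 1., 1.])  # ==> log(12) ~= 2.4849
  ```
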
  Args:
    n: Floating-point `Tensor` broadcastable with `counts`. This represents `n`
      outcomes.
    counts: Floating-point `Tensor` broadcastable with `n`. This represents
      counts in `k` classes, where `k` is the last dimension of the tensor.
    name: A name for this operation (optional).

  Returns:
    `Tensor` representing the log of the multinomial coefficient between `n`
      and `counts`.
  """
  # First a bit about the number of ways counts could have come in:
  # E.g. if counts = [1, 2], then this is 3 choose 2.
  # In general, this is (sum counts)! / prod(counts!)
  # The sum should be along the last dimension of counts. This is the
  # "distribution" dimension. Here n a priori represents the sum of counts.
  with ops.name_scope(name, values=[n, counts]):
    n = ops.convert_to_tensor(n, name="n")
    counts = ops.convert_to_tensor(counts, name="counts")
    total_permutations = math_ops.lgamma(n + 1)
    counts_factorial = math_ops.lgamma(counts + 1)
    redundant_permutations = math_ops.reduce_sum(counts_factorial, axis=[-1])
    return total_permutations - redundant_permutations


def matrix_diag_transform(matrix, transform=None, name=None):
  """Transform diagonal of [batch-]matrix, leave rest of matrix unchanged.

  Create a trainable covariance defined by a Cholesky factor:

  ```python
  # Transform network layer into 2 x 2 array.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))

  # Make the diagonal positive. If the upper triangle was zero, this would be a
  # valid Cholesky factor.
  chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # LinearOperatorLowerTriangular ignores the upper triangle.
  operator = LinearOperatorLowerTriangular(chol)
  ```

  Example of heteroskedastic 2-D linear regression:

  ```python
  tfd = tfp.distributions

  # Get a trainable Cholesky factor.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
  chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # Get a trainable mean.
  mu = tf.contrib.layers.fully_connected(activations, 2)

  # This is a fully trainable multivariate normal!
  dist = tfd.MultivariateNormalTriL(mu, chol)

  # Standard log loss. Minimizing this will "train" mu and chol, and then dist
  # will be a distribution predicting labels as multivariate Gaussians.
  loss = -1 * tf.reduce_mean(dist.log_prob(labels))
  ```

  Args:
    matrix:  Rank `R` `Tensor`, `R >= 2`, where the last two dimensions are
      equal.
    transform:  Element-wise function mapping `Tensors` to `Tensors`. To
      be applied to the diagonal of `matrix`. If `None`, `matrix` is returned
      unchanged. Defaults to `None`.
    name:  A name to give created ops.
      Defaults to "matrix_diag_transform".

  Returns:
    A `Tensor` with same shape and `dtype` as `matrix`.
  """
  with ops.name_scope(name, "matrix_diag_transform", [matrix]):
    matrix = ops.convert_to_tensor(matrix, name="matrix")
    if transform is None:
      return matrix
    # Replace the diag with transformed diag.
    diag = array_ops.matrix_diag_part(matrix)
    transformed_diag = transform(diag)
    transformed_mat = array_ops.matrix_set_diag(matrix, transformed_diag)

  return transformed_mat


def rotate_transpose(x, shift, name="rotate_transpose"):
  """Circularly moves dims left or right.

  Effectively identical to:

  ```python
  numpy.transpose(x, numpy.roll(numpy.arange(len(x.shape)), shift))
  ```

  When the rank or shift cannot be determined statically, additional
  graph-runtime checks are performed. These checks entail moving data from GPU
  to CPU.

  Example:

  ```python
  x = tf.random_normal([1, 2, 3, 4])  # Tensor of shape [1, 2, 3, 4].
  rotate_transpose(x, -1).shape == [2, 3, 4, 1]
  rotate_transpose(x, -2).shape == [3, 4, 1, 2]
  rotate_transpose(x,  1).shape == [4, 1, 2, 3]
  rotate_transpose(x,  2).shape == [3, 4, 1, 2]
  rotate_transpose(x,  7).shape == rotate_transpose(x, 3).shape  # [2, 3, 4, 1]
  rotate_transpose(x, -7).shape == rotate_transpose(x, -3).shape  # [4, 1, 2, 3]
  ```

  Args:
    x: `Tensor`.
    shift: `Tensor`. Number of dimensions to transpose left (shift<0) or
      transpose right (shift>0).
    name: Python `str`. The name to give this op.

  Returns:
    rotated_x: Input `Tensor` with dimensions circularly rotated by shift.

  Raises:
    TypeError: if shift is not integer type.
  """
  with ops.name_scope(name, values=[x, shift]):
    x = ops.convert_to_tensor(x, name="x")
    shift = ops.convert_to_tensor(shift, name="shift")
    # We do not assign back to preserve constant-ness.
    check_ops.assert_integer(shift)
    shift_value_static = tensor_util.constant_value(shift)
    ndims = x.get_shape().ndims
    if ndims is not None and shift_value_static is not None:
      if ndims < 2: return x
      shift_value_static = np.sign(shift_value_static) * (
          abs(shift_value_static) % ndims)
      if shift_value_static == 0: return x
      perm = np.roll(np.arange(ndims), shift_value_static)
      return array_ops.transpose(x, perm=perm)
    else:
      # Consider if we always had a positive shift, and some specified
      # direction.
      # When shifting left we want the new array:
      #   last(x, n-shift) + first(x, shift)
      # and if shifting right then we want:
      #   last(x, shift) + first(x, n-shift)
      # Observe that last(a) == slice(a, n) and first(a) == slice(0, a).
      # Also, we can encode direction and shift as one: direction * shift.
      # Combining these facts, we have:
      #   a = cond(shift<0, -shift, n-shift)
      #   last(x, n-a) + first(x, a) == x[a:n] + x[0:a]
      # Finally, we transform shift by modulo length so it can be specified
      # independently from the array upon which it operates (like python).
      ndims = array_ops.rank(x)
      shift = array_ops.where(math_ops.less(shift, 0),
                              math_ops.mod(-shift, ndims),
                              ndims - math_ops.mod(shift, ndims))
      first = math_ops.range(0, shift)
      last = math_ops.range(shift, ndims)
      perm = array_ops.concat([last, first], 0)
      return array_ops.transpose(x, perm=perm)


def pick_vector(cond,
                true_vector,
                false_vector,
                name="pick_vector"):
  """Picks possibly different length row `Tensor`s based on condition.

  Value `Tensor`s should have exactly one dimension.

  If `cond` is a python Boolean or `tf.constant` then either `true_vector` or
  `false_vector` is immediately returned. I.e., no graph nodes are created and
  no validation happens.

  Args:
    cond: `Tensor`. Must have `dtype=tf.bool` and be scalar.
    true_vector: `Tensor` of one dimension. Returned when cond is `True`.
    false_vector: `Tensor` of one dimension. Returned when cond is `False`.
    name: Python `str`. The name to give this op.

  Example:

  ```python
  pick_vector(tf.less(0, 5), tf.range(10, 12), tf.range(15, 18))  # [10, 11]
  pick_vector(tf.less(5, 0), tf.range(10, 12), tf.range(15, 18))  # [15, 16, 17]
  ```

  Returns:
    true_or_false_vector: `Tensor`.

  Raises:
    TypeError: if `cond.dtype != tf.bool`
    TypeError: if `cond` is not a constant and
      `true_vector.dtype != false_vector.dtype`
  """
  with ops.name_scope(name, values=(cond, true_vector, false_vector)):
    cond = ops.convert_to_tensor(cond, name="cond")
    if cond.dtype != dtypes.bool:
      raise TypeError("%s.dtype=%s which is not %s" %
                      (cond, cond.dtype, dtypes.bool))
    cond_value_static = tensor_util.constant_value(cond)
    if cond_value_static is not None:
      return true_vector if cond_value_static else false_vector
    true_vector = ops.convert_to_tensor(true_vector, name="true_vector")
    false_vector = ops.convert_to_tensor(false_vector, name="false_vector")
    if true_vector.dtype != false_vector.dtype:
      raise TypeError(
          "%s.dtype=%s does not match %s.dtype=%s"
          % (true_vector, true_vector.dtype,
             false_vector, false_vector.dtype))
    n = array_ops.shape(true_vector)[0]
    return array_ops.slice(
        array_ops.concat([true_vector, false_vector], 0),
        [array_ops.where(cond, 0, n)], [array_ops.where(cond, n, -1)])


def prefer_static_broadcast_shape(
    shape1, shape2, name="prefer_static_broadcast_shape"):
  """Convenience function which statically broadcasts shape when possible.

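  For example (an illustrative sketch):

  ```python
  s = prefer_static_broadcast_shape(tf.constant([3, 1]), tf.constant([1, 2]))
  # ==> TensorShape([3, 2]), computed statically; no dynamic broadcast op needed.
  ```
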
  Args:
    shape1:  `1-D` integer `Tensor`.  Already converted to tensor!
    shape2:  `1-D` integer `Tensor`.  Already converted to tensor!
    name:  A string name to prepend to created ops.

  Returns:
    The broadcast shape, either as `TensorShape` (if broadcast can be done
      statically), or as a `Tensor`.
  """
  with ops.name_scope(name, values=[shape1, shape2]):
    def make_shape_tensor(x):
      return ops.convert_to_tensor(x, name="shape", dtype=dtypes.int32)

    def get_tensor_shape(s):
      if isinstance(s, tensor_shape.TensorShape):
        return s
      s_ = tensor_util.constant_value(make_shape_tensor(s))
      if s_ is not None:
        return tensor_shape.TensorShape(s_)
      return None

    def get_shape_tensor(s):
      if not isinstance(s, tensor_shape.TensorShape):
        return make_shape_tensor(s)
      if s.is_fully_defined():
        return make_shape_tensor(s.as_list())
      raise ValueError("Cannot broadcast from partially "
                       "defined `TensorShape`.")

    shape1_ = get_tensor_shape(shape1)
    shape2_ = get_tensor_shape(shape2)
    if shape1_ is not None and shape2_ is not None:
      return array_ops.broadcast_static_shape(shape1_, shape2_)

    shape1_ = get_shape_tensor(shape1)
    shape2_ = get_shape_tensor(shape2)
    return array_ops.broadcast_dynamic_shape(shape1_, shape2_)


def prefer_static_rank(x):
  """Return static rank of tensor `x` if available, else `tf.rank(x)`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static rank is obtainable), else `Tensor`.
  """
  return prefer_static_value(array_ops.rank(x))


def prefer_static_shape(x):
  """Return static shape of tensor `x` if available, else `tf.shape(x)`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static shape is obtainable), else `Tensor`.
  """
  return prefer_static_value(array_ops.shape(x))


def prefer_static_value(x):
  """Return static value of tensor `x` if available, else `x`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static value is obtainable), else `Tensor`.
  """
  static_x = tensor_util.constant_value(x)
  if static_x is not None:
    return static_x
  return x


def gen_new_seed(seed, salt):
  """Generate a new seed, from the given seed and salt."""
  if seed is None:
    return None
  string = (str(seed) + salt).encode("utf-8")
  return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF


def fill_triangular(x, upper=False, name=None):
  """Creates a (batch of) triangular matrix from a vector of inputs.

  Created matrix can be lower- or upper-triangular. (It is more efficient to
  create the matrix as upper or lower, rather than transpose.)

  Triangular matrix elements are filled in a clockwise spiral. See example
  below.

  If `x.get_shape()` is `[b1, b2, ..., bB, d]` then the output shape is
  `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.

  Example:

  ```python
  fill_triangular([1, 2, 3, 4, 5, 6])
  # ==> [[4, 0, 0],
  #      [6, 5, 0],
  #      [3, 2, 1]]

  fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
  # ==> [[1, 2, 3],
  #      [0, 5, 6],
  #      [0, 0, 4]]
  ```

  For comparison, a pure numpy version of this function can be found in
  `util_test.py`, function `_fill_triangular`.

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether output matrix should be upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    tril: `Tensor` with lower (or upper) triangular elements filled from `x`.

  Raises:
    ValueError: if `x` cannot be mapped to a triangular matrix.
  """

  with ops.name_scope(name, "fill_triangular", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if tensor_shape.dimension_value(
        x.shape.with_rank_at_least(1)[-1]) is not None:
      # Formula derived by solving for n: m = n(n+1)/2.
      m = np.int32(x.shape.dims[-1].value)
      n = np.sqrt(0.25 + 2. * m) - 0.5
      if n != np.floor(n):
        raise ValueError("Input right-most shape ({}) does not "
                         "correspond to a triangular matrix.".format(m))
      n = np.int32(n)
      static_final_shape = x.shape[:-1].concatenate([n, n])
    else:
      m = array_ops.shape(x)[-1]
      # For derivation, see above. Casting automatically lops off the 0.5, so we
      # omit it.  We don't validate n is an integer because this has
      # graph-execution cost; an error will be thrown from the reshape below.
      n = math_ops.cast(
          math_ops.sqrt(0.25 + math_ops.cast(2 * m, dtype=dtypes.float32)),
          dtype=dtypes.int32)
      static_final_shape = x.shape.with_rank_at_least(1)[:-1].concatenate(
          [None, None])
    # We now concatenate the "tail" of `x` to `x` (and reverse one of them).
    #
    # We do this based on the insight that the input `x` provides `ceil(n/2)`
    # rows of an `n x n` matrix, some of which will get zeroed out being on the
    # wrong side of the diagonal. The first row will not get zeroed out at all,
    # and we need `floor(n/2)` more rows, so the first is what we omit from
    # `x_tail`. If we then stack those `ceil(n/2)` rows with the `floor(n/2)`
    # rows provided by a reversed tail, it is exactly the other set of elements
    # of the reversed tail which will be zeroed out for being on the wrong side
    # of the diagonal further up/down the matrix. And, in doing so, we've filled
    # the triangular matrix in a clockwise spiral pattern. Neat!
    #
    # Try it out in numpy:
    #  n = 3
    #  x = np.arange(n * (n + 1) / 2)
    #  m = x.shape[0]
    #  n = np.int32(np.sqrt(.25 + 2 * m) - .5)
    #  x_tail = x[(m - (n**2 - m)):]
    #  np.concatenate([x_tail, x[::-1]], 0).reshape(n, n)  # lower
    #  # ==> array([[3, 4, 5],
    #               [5, 4, 3],
    #               [2, 1, 0]])
    #  np.concatenate([x, x_tail[::-1]], 0).reshape(n, n)  # upper
    #  # ==> array([[0, 1, 2],
    #               [3, 4, 5],
    #               [5, 4, 3]])
    #
    # Note that we can't simply do `x[..., -(n**2 - m):]` because this doesn't
    # correctly handle `m == n == 1`. Hence, we do nonnegative indexing.
    # Furthermore observe that:
    #   m - (n**2 - m)
    #   = n**2 / 2 + n / 2 - (n**2 - n**2 / 2 + n / 2)
    #   = 2 (n**2 / 2 + n / 2) - n**2
    #   = n**2 + n - n**2
    #   = n
    ndims = prefer_static_rank(x)
    if upper:
      x_list = [x, array_ops.reverse(x[..., n:], axis=[ndims - 1])]
    else:
      x_list = [x[..., n:], array_ops.reverse(x, axis=[ndims - 1])]
    new_shape = (
        static_final_shape.as_list()
        if static_final_shape.is_fully_defined()
        else array_ops.concat([array_ops.shape(x)[:-1], [n, n]], axis=0))
    x = array_ops.reshape(array_ops.concat(x_list, axis=-1), new_shape)
    x = array_ops.matrix_band_part(
        x,
        num_lower=(0 if upper else -1),
        num_upper=(-1 if upper else 0))
    x.set_shape(static_final_shape)
    return x


def fill_triangular_inverse(x, upper=False, name=None):
  """Creates a vector from a (batch of) triangular matrix.

  The vector is created from the lower-triangular or upper-triangular portion
  depending on the value of the parameter `upper`.

  If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
  `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.

  Example:

  ```python
  fill_triangular_inverse(
    [[4, 0, 0],
     [6, 5, 0],
     [3, 2, 1]])

  # ==> [1, 2, 3, 4, 5, 6]

  fill_triangular_inverse(
    [[1, 2, 3],
     [0, 5, 6],
     [0, 0, 4]], upper=True)

  # ==> [1, 2, 3, 4, 5, 6]
  ```

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether `x` should be read as upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
      (or upper) triangular elements from `x`.
  """

  with ops.name_scope(name, "fill_triangular_inverse", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if tensor_shape.dimension_value(
        x.shape.with_rank_at_least(2)[-1]) is not None:
      n = np.int32(x.shape.dims[-1].value)
      m = np.int32((n * (n + 1)) // 2)
      static_final_shape = x.shape[:-2].concatenate([m])
    else:
      n = array_ops.shape(x)[-1]
      m = (n * (n + 1)) // 2
      static_final_shape = x.shape.with_rank_at_least(2)[:-2].concatenate(
          [None])
    ndims = prefer_static_rank(x)
    if upper:
      initial_elements = x[..., 0, :]
      triangular_portion = x[..., 1:, :]
    else:
      initial_elements = array_ops.reverse(x[..., -1, :], axis=[ndims - 2])
      triangular_portion = x[..., :-1, :]
    rotated_triangular_portion = array_ops.reverse(
        array_ops.reverse(triangular_portion, axis=[ndims - 1]),
        axis=[ndims - 2])
    consolidated_matrix = triangular_portion + rotated_triangular_portion
    end_sequence = array_ops.reshape(
        consolidated_matrix,
        array_ops.concat([array_ops.shape(x)[:-2], [n * (n - 1)]], axis=0))
    y = array_ops.concat([initial_elements, end_sequence[..., :m - n]], axis=-1)
    y.set_shape(static_final_shape)
    return y


def tridiag(below=None, diag=None, above=None, name=None):
  """Creates a matrix with values set above, below, and on the diagonal.

  Example:

  ```python
  tridiag(below=[1., 2., 3.],
          diag=[4., 5., 6., 7.],
          above=[8., 9., 10.])
  # ==> array([[  4.,   8.,   0.,   0.],
  #            [  1.,   5.,   9.,   0.],
  #            [  0.,   2.,   6.,  10.],
  #            [  0.,   0.,   3.,   7.]], dtype=float32)
  ```

  Warning: This Op is intended for convenience, not efficiency.

  Args:
    below: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the below
      diagonal part. `None` is logically equivalent to `below = 0`.
    diag: `Tensor` of shape `[B1, ..., Bb, d]` corresponding to the diagonal
      part.  `None` is logically equivalent to `diag = 0`.
    above: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the above
      diagonal part.  `None` is logically equivalent to `above = 0`.
    name: Python `str`. The name to give this op.

  Returns:
    tridiag: `Tensor` with values set above, below and on the diagonal.

  Raises:
    ValueError: if all inputs are `None`.
  """

  def _pad(x):
    """Prepends and appends a zero to every vector in a batch of vectors."""
    shape = array_ops.concat([array_ops.shape(x)[:-1], [1]], axis=0)
    z = array_ops.zeros(shape, dtype=x.dtype)
    return array_ops.concat([z, x, z], axis=-1)

  def _add(*x):
    """Adds list of Tensors, ignoring `None`."""
    s = None
    for y in x:
      if y is None:
        continue
      elif s is None:
        s = y
      else:
        s += y
    if s is None:
      raise ValueError("Must specify at least one of `below`, `diag`, `above`.")
    return s

  with ops.name_scope(name, "tridiag", [below, diag, above]):
    if below is not None:
      below = ops.convert_to_tensor(below, name="below")
      below = array_ops.matrix_diag(_pad(below))[..., :-1, 1:]
    if diag is not None:
      diag = ops.convert_to_tensor(diag, name="diag")
      diag = array_ops.matrix_diag(diag)
    if above is not None:
      above = ops.convert_to_tensor(above, name="above")
      above = array_ops.matrix_diag(_pad(above))[..., 1:, :-1]
    # TODO(jvdillon): Consider using scatter_nd instead of creating three full
    # matrices.
    return _add(below, diag, above)


def reduce_weighted_logsumexp(
    logx,
    w=None,
    axis=None,
    keep_dims=False,
    return_sign=False,
    name=None):
  """Computes `log(abs(sum(weight * exp(elements across tensor dimensions))))`.

  If all weights `w` are known to be positive, it is more efficient to directly
  use `reduce_logsumexp`, i.e., `tf.reduce_logsumexp(logx + tf.log(w))` is more
  efficient than `du.reduce_weighted_logsumexp(logx, w)`.

  Reduces `logx` along the dimensions given in `axis`.
  Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
  entry in `axis`. If `keep_dims` is true, the reduced dimensions
  are retained with length 1.

  If `axis` has no entries, all dimensions are reduced, and a
  tensor with a single element is returned.

  This function is more numerically stable than `log(sum(w * exp(input)))`. It
  avoids overflows caused by taking the exp of large inputs and underflows
  caused by taking the log of small inputs.

  For example:

  ```python
  x = tf.constant([[0., 0, 0],
                   [0, 0, 0]])

  w = tf.constant([[-1., 1, 1],
                   [1, 1, 1]])

  du.reduce_weighted_logsumexp(x, w)
  # ==> log(-1*1 + 1*1 + 1*1 + 1*1 + 1*1 + 1*1) = log(4)

  du.reduce_weighted_logsumexp(x, w, axis=0)
  # ==> [log(-1+1), log(1+1), log(1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1)
  # ==> [log(-1+1+1), log(1+1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True)
  # ==> [[log(-1+1+1)], [log(1+1+1)]]

  du.reduce_weighted_logsumexp(x, w, axis=[0, 1])
  # ==> log(-1+5)
  ```

  Args:
    logx: The tensor to reduce. Should have numeric type.
    w: The weight tensor. Should have numeric type identical to `logx`.
    axis: The dimensions to reduce. If `None` (the default),
      reduces all dimensions. Must be in the range
      `[-rank(logx), rank(logx))`.
    keep_dims: If true, retains reduced dimensions with length 1.
    return_sign: If `True`, returns the sign of the result.
    name: A name for the operation (optional).

  Returns:
    lswe: The `log(abs(sum(weight * exp(x))))` reduced tensor.
    sign: (Optional) The sign of `sum(weight * exp(x))`.
  """
  with ops.name_scope(name, "reduce_weighted_logsumexp", [logx, w]):
    logx = ops.convert_to_tensor(logx, name="logx")
    if w is None:
      lswe = math_ops.reduce_logsumexp(logx, axis=axis, keepdims=keep_dims)
      if return_sign:
        sgn = array_ops.ones_like(lswe)
        return lswe, sgn
      return lswe
    w = ops.convert_to_tensor(w, dtype=logx.dtype, name="w")
    log_absw_x = logx + math_ops.log(math_ops.abs(w))
    max_log_absw_x = math_ops.reduce_max(log_absw_x, axis=axis, keepdims=True)
    # If the largest element is `-inf` or `inf` then we don't bother subtracting
    # off the max. We do this because otherwise we'd get `inf - inf = NaN`. That
    # this is ok follows from the fact that we're actually free to subtract any
    # value we like, so long as we add it back after taking the `log(sum(...))`.
    max_log_absw_x = array_ops.where(
        math_ops.is_inf(max_log_absw_x),
        array_ops.zeros_like(max_log_absw_x),
        max_log_absw_x)
    wx_over_max_absw_x = (
        math_ops.sign(w) * math_ops.exp(log_absw_x - max_log_absw_x))
    sum_wx_over_max_absw_x = math_ops.reduce_sum(
        wx_over_max_absw_x, axis=axis, keepdims=keep_dims)
    if not keep_dims:
      max_log_absw_x = array_ops.squeeze(max_log_absw_x, axis)
    sgn = math_ops.sign(sum_wx_over_max_absw_x)
    lswe = max_log_absw_x + math_ops.log(sgn * sum_wx_over_max_absw_x)
    if return_sign:
      return lswe, sgn
    return lswe


# TODO(jvdillon): Merge this test back into:
# tensorflow/python/ops/softplus_op_test.py
# once TF core is accepting new ops.
def softplus_inverse(x, name=None):
  """Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)).

  Mathematically this op is equivalent to:

  ```none
  softplus_inverse = log(exp(x) - 1.)
  ```

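  For example (an illustrative sketch):

  ```python
  softplus_inverse(tf.nn.softplus(5.))  # ==> approximately 5.
  softplus_inverse(np.log(2.))          # ==> 0., since softplus(0) = log(2).
  ```
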
  Args:
    x: `Tensor`. Non-negative (not enforced), floating-point.
    name: A name for the operation (optional).

  Returns:
    `Tensor`. Has the same type/shape as input `x`.
  """
  with ops.name_scope(name, "softplus_inverse", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    # We begin by deriving a more numerically stable softplus_inverse:
    # x = softplus(y) = Log[1 + exp{y}], (which means x > 0).
    # ==> exp{x} = 1 + exp{y}                                (1)
    # ==> y = Log[exp{x} - 1]                                (2)
    #       = Log[(exp{x} - 1) / exp{x}] + Log[exp{x}]
    #       = Log[(1 - exp{-x}) / 1] + Log[exp{x}]
    #       = Log[1 - exp{-x}] + x                           (3)
    # (2) is the "obvious" inverse, but (3) is more stable than (2) for large x.
    # For small x (e.g. x = 1e-10), (3) will become -inf since 1 - exp{-x} will
    # be zero. To fix this, we use 1 - exp{-x} approx x for small x > 0.
    #
    # In addition to the numerically stable derivation above, we clamp
    # small/large values to be congruent with the logic in:
    # tensorflow/core/kernels/softplus_op.h
    #
    # Finally, we set the input to one whenever the input is too large or too
    # small. This ensures that no unchosen codepath is +/- inf. This is
    # necessary to ensure the gradient doesn't get NaNs. Recall that the
    # gradient of `where` behaves like `pred*pred_true + (1-pred)*pred_false`
    # thus an `inf` in an unselected path results in `0*inf=nan`. We are careful
    # to overwrite `x` with ones only when we will never actually use this
    # value. Note that we use ones and not zeros since `log(expm1(0.)) = -inf`.
    threshold = np.log(np.finfo(x.dtype.as_numpy_dtype).eps) + 2.
    is_too_small = math_ops.less(x, np.exp(threshold))
    is_too_large = math_ops.greater(x, -threshold)
    too_small_value = math_ops.log(x)
    too_large_value = x
    # This `where` will ultimately be a NOP because we won't select this
    # codepath whenever we used the surrogate `ones_like`.
    x = array_ops.where(math_ops.logical_or(is_too_small, is_too_large),
                        array_ops.ones_like(x), x)
    y = x + math_ops.log(-math_ops.expm1(-x))  # == log(expm1(x))
    return array_ops.where(is_too_small, too_small_value,
                           array_ops.where(is_too_large, too_large_value, y))


# TODO(b/35290280): Add unit-tests.
def dimension_size(x, axis):
  """Returns the size of a specific dimension."""
  # Since tf.gather isn't "constant-in, constant-out", we must first check the
  # static shape or fall back to dynamic shape.
  s = tensor_shape.dimension_value(
      x.shape.with_rank_at_least(np.abs(axis))[axis])
  if s is not None:
    return s
  return array_ops.shape(x)[axis]


def process_quadrature_grid_and_probs(
    quadrature_grid_and_probs, dtype, validate_args, name=None):
  """Validates quadrature grid, probs or computes them as necessary.

  Args:
    quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
      representing the sample points and the corresponding (possibly
      normalized) weight.  When `None`, defaults to:
      `np.polynomial.hermite.hermgauss(deg=8)`.
    dtype: The expected `dtype` of `grid` and `probs`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this class.

  Returns:
    quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
      representing the sample points and the corresponding (possibly
      normalized) weight.

  Raises:
    ValueError: if `quadrature_grid_and_probs is not None` and
      `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
  """
  with ops.name_scope(name, "process_quadrature_grid_and_probs",
                      [quadrature_grid_and_probs]):
    if quadrature_grid_and_probs is None:
      grid, probs = np.polynomial.hermite.hermgauss(deg=8)
      grid = grid.astype(dtype.as_numpy_dtype)
      probs = probs.astype(dtype.as_numpy_dtype)
      probs /= np.linalg.norm(probs, ord=1, keepdims=True)
      grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
      probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
      return grid, probs

    grid, probs = tuple(quadrature_grid_and_probs)
    grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
    probs = ops.convert_to_tensor(probs, name="unnormalized_probs",
                                  dtype=dtype)
    probs /= linalg_ops.norm(probs, ord=1, axis=-1, keepdims=True, name="probs")

    def _static_event_size(x):
      """Returns the static size of a specific dimension or `None`."""
      return tensor_shape.dimension_value(x.shape.with_rank_at_least(1)[-1])

    m, n = _static_event_size(probs), _static_event_size(grid)
    if m is not None and n is not None:
      if m != n:
        raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
                         "same-length zero-th-dimension `Tensor`s "
                         "(saw lengths {}, {})".format(m, n))
    elif validate_args:
      assertions = [
          check_ops.assert_equal(
              dimension_size(probs, axis=-1),
              dimension_size(grid, axis=-1),
              message=("`quadrature_grid_and_probs` must be a `tuple` of "
                       "same-length zero-th-dimension `Tensor`s")),
      ]
      with ops.control_dependencies(assertions):
        grid = array_ops.identity(grid)
        probs = array_ops.identity(probs)
    return grid, probs


def pad(x, axis, front=False, back=False, value=0, count=1, name=None):
  """Pads `value` to the front and/or back of a `Tensor` dim, `count` times.

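  For example (an illustrative sketch):

  ```python
  x = tf.constant([[1., 2., 3.],
                   [4., 5., 6.]])
  pad(x, axis=-1, back=True, value=0., count=2)
  # ==> [[1., 2., 3., 0., 0.],
  #      [4., 5., 6., 0., 0.]]

  pad(x, axis=0, front=True, back=True, value=7., count=1)
  # ==> [[7., 7., 7.],
  #      [1., 2., 3.],
  #      [4., 5., 6.],
  #      [7., 7., 7.]]
  ```
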
  Args:
    x: `Tensor` input.
    axis: Scalar `int`-like `Tensor` representing the single dimension to pad.
      (Negative indexing is supported.)
    front: Python `bool`; if `True` the beginning of the `axis` dimension is
      padded with `value`, `count` times. If `False` no front padding is made.
    back: Python `bool`; if `True` the end of the `axis` dimension is
      padded with `value`, `count` times. If `False` no end padding is made.
    value: Scalar `int`-like `Tensor` representing the actual value added to the
      front and/or back of the `axis` dimension of `x`.
    count: Scalar `int`-like `Tensor` representing number of elements added to
      the front and/or back of the `axis` dimension of `x`. E.g., if
      `front = back = True` then `2 * count` elements are added.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    pad: The padded version of input `x`.

  Raises:
    ValueError: if both `front` and `back` are `False`.
    TypeError: if `count` is not `int`-like.
  """
  with ops.name_scope(name, "pad", [x, value, count]):
    x = ops.convert_to_tensor(x, name="x")
    value = ops.convert_to_tensor(value, dtype=x.dtype, name="value")
    count = ops.convert_to_tensor(count, name="count")
    if not count.dtype.is_integer:
      raise TypeError("`count.dtype` (`{}`) must be `int`-like.".format(
          count.dtype.name))
    if not front and not back:
      raise ValueError("At least one of `front`, `back` must be `True`.")
    ndims = (x.shape.ndims if x.shape.ndims is not None
             else array_ops.rank(x, name="ndims"))
    axis = ops.convert_to_tensor(axis, name="axis")
    axis_ = tensor_util.constant_value(axis)
    if axis_ is not None:
      axis = axis_
      if axis < 0:
        axis = ndims + axis
      count_ = tensor_util.constant_value(count)
      if axis_ >= 0 or x.shape.ndims is not None:
        head = x.shape[:axis]
        middle = tensor_shape.TensorShape(
            None if count_ is None
            else (tensor_shape.dimension_at_index(
                x.shape, axis) + count_ * (front + back)))
        tail = x.shape[axis+1:]
        final_shape = head.concatenate(middle.concatenate(tail))
      else:
        final_shape = None
    else:
      axis = array_ops.where(axis < 0, ndims + axis, axis)
      final_shape = None
    x = array_ops.pad(
        x,
        paddings=array_ops.one_hot(
            indices=array_ops.stack([axis if front else -1,
                                     axis if back else -1]),
            depth=ndims,
            axis=0,
            on_value=count,
            dtype=dtypes.int32),
        constant_values=value)
    if final_shape is not None:
      x.set_shape(final_shape)
    return x


def parent_frame_arguments():
  """Returns parent frame arguments.

  When called inside a function, returns a dictionary with the caller's function
  arguments. These are positional arguments and keyword arguments (**kwargs),
  while variable arguments (*varargs) are excluded.

  When called at global scope, this will return an empty dictionary, since there
  are no arguments.

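  For example (an illustrative sketch):

  ```python
  def foo(a, b=2, **kwargs):
    return parent_frame_arguments()

  foo(1, c=3)  # ==> {'a': 1, 'b': 2, 'c': 3}
  ```
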
  WARNING: If caller function argument names are overwritten before invoking
  this method, then values will reflect the overwritten values. For this reason,
  we recommend calling `parent_frame_arguments` at the beginning of the
  function.
  """
  # All arguments and the names used for *varargs, and **kwargs
  arg_names, variable_arg_name, keyword_arg_name, local_vars = (
      tf_inspect._inspect.getargvalues(  # pylint: disable=protected-access
          # Get the first frame of the caller of this method.
          tf_inspect._inspect.stack()[1][0]))  # pylint: disable=protected-access

  # Remove the *varargs, and flatten the **kwargs. Both are
  # nested lists.
  local_vars.pop(variable_arg_name, {})
  keyword_args = local_vars.pop(keyword_arg_name, {})

  final_args = {}
  # Copy over arguments and their values. In general, local_vars
  # may contain more than just the arguments, since this method
  # can be called anywhere in a function.
  for arg_name in arg_names:
    final_args[arg_name] = local_vars.pop(arg_name)
  final_args.update(keyword_args)

  return final_args


class AppendDocstring(object):
  """Helper class to promote private subclass docstring to public counterpart.

  Example:

  ```python
  class TransformedDistribution(Distribution):
    @distribution_util.AppendDocstring(
      additional_note="A special note!",
      kwargs_dict={"foo": "An extra arg."})
    def _prob(self, y, foo=None):
      pass
  ```

  In this case, the `AppendDocstring` decorator appends the `additional_note` to
  the docstring of `prob` (not `_prob`) and adds a new `kwargs`
  section with each dictionary item as a bullet-point.

  For a more detailed example, see `TransformedDistribution`.
  """

  def __init__(self, additional_note="", kwargs_dict=None):
    """Initializes the AppendDocstring object.

    Args:
      additional_note: Python string added as additional docstring to public
        version of function.
      kwargs_dict: Python string/string dictionary representing
        specific kwargs expanded from the **kwargs input.

    Raises:
      ValueError: if kwargs_dict.key contains whitespace.
      ValueError: if kwargs_dict.value contains newlines.
    """
    self._additional_note = additional_note
    if kwargs_dict:
      bullets = []
      for key in sorted(kwargs_dict.keys()):
        value = kwargs_dict[key]
        if any(x.isspace() for x in key):
          raise ValueError(
              "Parameter name \"%s\" contains whitespace." % key)
        value = value.lstrip()
        if "\n" in value:
          raise ValueError(
              "Parameter description for \"%s\" contains newlines." % key)
        bullets.append("*  `%s`: %s" % (key, value))
      self._additional_note += ("\n\n##### `kwargs`:\n\n" +
                                "\n".join(bullets))

  def __call__(self, fn):
    @functools.wraps(fn)
    def _fn(*args, **kwargs):
      return fn(*args, **kwargs)
    if _fn.__doc__ is None:
      _fn.__doc__ = self._additional_note
    else:
      _fn.__doc__ += "\n%s" % self._additional_note
    return _fn