# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Utilities for probability distributions."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import smart_cond
24from tensorflow.python.framework import tensor_shape
25from tensorflow.python.framework import tensor_util
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import check_ops
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops.linalg import linalg
31from tensorflow.python.ops.distributions import distribution as distribution_lib
32
33# The following two lines are redundant, in a sense. The first enables
34# good coding practice  *within* this file (`util.prefer_static_value`
35# rather than  `prefer_static_value`). The  second ensures  that users
36# also get the core utils when they import this file.
37from tensorflow.python.ops.distributions import util
38from tensorflow.python.ops.distributions.util import *  # pylint: disable=wildcard-import
39
40
41def _convert_to_tensor(x, name):
42  return None if x is None else ops.convert_to_tensor(x, name=name)
43
44
def mixture_stddev(mixture_weight_vector, mean_vector, stddev_vector):
  """Computes the standard deviation of a mixture distribution.

  This function works regardless of the component distribution, so long as
  each component's mean and standard deviation can be provided.

  Args:
    mixture_weight_vector: A 2D tensor of mixture weights. Has shape
      `[batch_size, num_components]`.
    mean_vector: A 2D tensor of mixture component means. Has shape
      `[batch_size, num_components]`.
    stddev_vector: A 2D tensor of mixture component standard deviations. Has
      shape `[batch_size, num_components]`.

  Returns:
    A 1D tensor of shape `[batch_size]` representing the standard deviation of
    the mixture distribution with the given weights and component means and
    standard deviations.

  Raises:
    ValueError: If the shapes of the input tensors are not as expected.
  """
  mixture_weight_vector.shape.assert_has_rank(2)
  if not mean_vector.shape.is_compatible_with(mixture_weight_vector.shape):
    raise ValueError("Expecting means to have same shape as mixture weights.")
  if not stddev_vector.shape.is_compatible_with(mixture_weight_vector.shape):
    raise ValueError("Expecting stddevs to have same shape as mixture weights.")

  # Reshape the distribution parameters for batched vectorized dot products.
  pi_for_dot_prod = array_ops.expand_dims(mixture_weight_vector, axis=1)
  mu_for_dot_prod = array_ops.expand_dims(mean_vector, axis=2)
  sigma_for_dot_prod = array_ops.expand_dims(stddev_vector, axis=2)

  # Weighted average of component means under the mixture distribution.
  mean_wa = math_ops.matmul(pi_for_dot_prod, mu_for_dot_prod)
  mean_wa = array_ops.reshape(mean_wa, (-1,))
  # Weighted average of component variances under the mixture distribution.
  var_wa = math_ops.matmul(pi_for_dot_prod,
                           math_ops.square(sigma_for_dot_prod))
  var_wa = array_ops.reshape(var_wa, (-1,))
  # Weighted average of component squared means under the mixture distribution.
  sq_mean_wa = math_ops.matmul(pi_for_dot_prod,
                               math_ops.square(mu_for_dot_prod))
  sq_mean_wa = array_ops.reshape(sq_mean_wa, (-1,))
  mixture_variance = var_wa + sq_mean_wa - math_ops.square(mean_wa)
  return math_ops.sqrt(mixture_variance)


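# Usage sketch for `mixture_stddev` above (illustrative, not part of the
# original file). Values and names are made up; the call assumes the TF1-style
# graph/eager APIs this module targets.
#
#   weights = ops.convert_to_tensor([[0.3, 0.7]])   # [batch_size=1, K=2]
#   means = ops.convert_to_tensor([[-1.0, 2.0]])    # component means
#   stddevs = ops.convert_to_tensor([[0.5, 1.5]])   # component stddevs
#   mix_sd = mixture_stddev(weights, means, stddevs)  # shape [1]
#   # mix_sd == sqrt(sum_k w_k*(sigma_k^2 + mu_k^2) - (sum_k w_k*mu_k)^2)
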
def make_tril_scale(
    loc=None,
    scale_tril=None,
    scale_diag=None,
    scale_identity_multiplier=None,
    shape_hint=None,
    validate_args=False,
    assert_positive=False,
    name=None):
  """Creates a LinOp representing a lower triangular matrix.

  Args:
    loc: Floating-point `Tensor`. This is used for inferring shape in the case
      where only `scale_identity_multiplier` is set.
    scale_tril: Floating-point `Tensor` representing the lower triangular
      matrix. `scale_tril` has shape [N1, N2, ...  k, k], which represents a
      k x k lower triangular matrix.
      When `None` no `scale_tril` term is added to the LinOp.
      The upper triangular elements above the diagonal are ignored.
    scale_diag: Floating-point `Tensor` representing the diagonal matrix.
      `scale_diag` has shape [N1, N2, ...  k], which represents a k x k
      diagonal matrix.
      When `None` no diagonal term is added to the LinOp.
    scale_identity_multiplier: floating point rank 0 `Tensor` representing a
      scaling done to the identity matrix.
      When `scale_identity_multiplier = scale_diag = scale_tril = None` then
      `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added
      to `scale`.
    shape_hint: scalar integer `Tensor` representing a hint at the dimension of
      the identity matrix when only `scale_identity_multiplier` is set.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    assert_positive: Python `bool` indicating whether LinOp should be checked
      for being positive definite.
    name: Python `str` name given to ops managed by this object.

  Returns:
    `LinearOperator` representing a lower triangular matrix.

  Raises:
    ValueError: If only `scale_identity_multiplier` is set and `loc` and
      `shape_hint` are both None.
  """

  def _maybe_attach_assertion(x):
    if not validate_args:
      return x
    if assert_positive:
      return control_flow_ops.with_dependencies([
          check_ops.assert_positive(
              array_ops.matrix_diag_part(x),
              message="diagonal part must be positive"),
      ], x)
    return control_flow_ops.with_dependencies([
        check_ops.assert_none_equal(
            array_ops.matrix_diag_part(x),
            array_ops.zeros([], x.dtype),
            message="diagonal part must be non-zero"),
    ], x)

  with ops.name_scope(name, "make_tril_scale",
                      values=[loc, scale_diag, scale_identity_multiplier]):

    loc = _convert_to_tensor(loc, name="loc")
    scale_tril = _convert_to_tensor(scale_tril, name="scale_tril")
    scale_diag = _convert_to_tensor(scale_diag, name="scale_diag")
    scale_identity_multiplier = _convert_to_tensor(
        scale_identity_multiplier,
        name="scale_identity_multiplier")

  if scale_tril is not None:
    scale_tril = array_ops.matrix_band_part(scale_tril, -1, 0)  # Zero out TriU.
    tril_diag = array_ops.matrix_diag_part(scale_tril)
    if scale_diag is not None:
      tril_diag += scale_diag
    if scale_identity_multiplier is not None:
      tril_diag += scale_identity_multiplier[..., array_ops.newaxis]

    scale_tril = array_ops.matrix_set_diag(scale_tril, tril_diag)

    return linalg.LinearOperatorLowerTriangular(
        tril=_maybe_attach_assertion(scale_tril),
        is_non_singular=True,
        is_self_adjoint=False,
        is_positive_definite=assert_positive)

  return make_diag_scale(
      loc=loc,
      scale_diag=scale_diag,
      scale_identity_multiplier=scale_identity_multiplier,
      shape_hint=shape_hint,
      validate_args=validate_args,
      assert_positive=assert_positive,
      name=name)


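# Usage sketch for `make_tril_scale` above (illustrative, not part of the
# original file). A batch of one 2x2 lower triangular factor plus a scalar
# bump on the diagonal yields a `LinearOperatorLowerTriangular`:
#
#   tril = ops.convert_to_tensor([[[2., 0.], [1., 3.]]])  # shape [1, 2, 2]
#   scale = make_tril_scale(scale_tril=tril, scale_identity_multiplier=0.5)
#   # scale.to_dense() -> [[[2.5, 0.], [1., 3.5]]]
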
def make_diag_scale(
    loc=None,
    scale_diag=None,
    scale_identity_multiplier=None,
    shape_hint=None,
    validate_args=False,
    assert_positive=False,
    name=None):
  """Creates a LinOp representing a diagonal matrix.

  Args:
    loc: Floating-point `Tensor`. This is used for inferring shape in the case
      where only `scale_identity_multiplier` is set.
    scale_diag: Floating-point `Tensor` representing the diagonal matrix.
      `scale_diag` has shape [N1, N2, ...  k], which represents a k x k
      diagonal matrix.
      When `None` no diagonal term is added to the LinOp.
    scale_identity_multiplier: floating point rank 0 `Tensor` representing a
      scaling done to the identity matrix.
      When `scale_identity_multiplier = scale_diag = None` then
      `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added
      to `scale`.
    shape_hint: scalar integer `Tensor` representing a hint at the dimension of
      the identity matrix when only `scale_identity_multiplier` is set.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    assert_positive: Python `bool` indicating whether LinOp should be checked
      for being positive definite.
    name: Python `str` name given to ops managed by this object.

  Returns:
    `LinearOperator` representing a diagonal matrix.

  Raises:
    ValueError: If only `scale_identity_multiplier` is set and `loc` and
      `shape_hint` are both None.
  """

  def _maybe_attach_assertion(x):
    if not validate_args:
      return x
    if assert_positive:
      return control_flow_ops.with_dependencies([
          check_ops.assert_positive(
              x, message="diagonal part must be positive"),
      ], x)
    return control_flow_ops.with_dependencies([
        check_ops.assert_none_equal(
            x,
            array_ops.zeros([], x.dtype),
            message="diagonal part must be non-zero")], x)

  with ops.name_scope(name, "make_diag_scale",
                      values=[loc, scale_diag, scale_identity_multiplier]):
    loc = _convert_to_tensor(loc, name="loc")
    scale_diag = _convert_to_tensor(scale_diag, name="scale_diag")
    scale_identity_multiplier = _convert_to_tensor(
        scale_identity_multiplier,
        name="scale_identity_multiplier")

    if scale_diag is not None:
      if scale_identity_multiplier is not None:
        scale_diag += scale_identity_multiplier[..., array_ops.newaxis]
      return linalg.LinearOperatorDiag(
          diag=_maybe_attach_assertion(scale_diag),
          is_non_singular=True,
          is_self_adjoint=True,
          is_positive_definite=assert_positive)

    if loc is None and shape_hint is None:
      raise ValueError(
          "Cannot infer `event_shape` unless `loc` or "
          "`shape_hint` is specified.")

    if shape_hint is None:
      shape_hint = loc.shape[-1]

    if scale_identity_multiplier is None:
      return linalg.LinearOperatorIdentity(
          num_rows=shape_hint,
          dtype=loc.dtype.base_dtype,
          is_self_adjoint=True,
          is_positive_definite=True,
          assert_proper_shapes=validate_args)

    return linalg.LinearOperatorScaledIdentity(
        num_rows=shape_hint,
        multiplier=_maybe_attach_assertion(scale_identity_multiplier),
        is_non_singular=True,
        is_self_adjoint=True,
        is_positive_definite=assert_positive,
        assert_proper_shapes=validate_args)


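# Usage sketch for `make_diag_scale` above (illustrative, not part of the
# original file). With only `scale_diag` given the result is a
# `LinearOperatorDiag`; with neither `scale_diag` nor
# `scale_identity_multiplier`, an identity operator is built from `loc`'s
# trailing dimension (or from `shape_hint`).
#
#   scale = make_diag_scale(scale_diag=[1., 2., 3.])
#   # scale.to_dense() -> [[1., 0., 0.], [0., 2., 0.], [0., 0., 3.]]
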
def shapes_from_loc_and_scale(loc, scale, name="shapes_from_loc_and_scale"):
  """Infer distribution batch and event shapes from a location and scale.

  Location and scale family distributions determine their batch/event shape by
  broadcasting the `loc` and `scale` args.  This helper does that broadcast,
  statically if possible.

  Batch shape broadcasts as per the normal rules.
  We allow the `loc` event shape to broadcast up to that of `scale`.  We do not
  allow `scale`'s event shape to change.  Therefore, the last dimension of `loc`
  must either be size `1`, or the same as `scale.range_dimension`.

  See `MultivariateNormalLinearOperator` for a usage example.

  Args:
    loc:  `N-D` `Tensor` with `N >= 1` (already converted to tensor) or `None`.
      If `None`, both batch and event shape are determined by `scale`.
    scale:  A `LinearOperator` instance.
    name:  A string name to prepend to created ops.

  Returns:
    batch_shape:  `TensorShape` (if broadcast is done statically), or `Tensor`.
    event_shape:  `TensorShape` (if broadcast is done statically), or `Tensor`.

  Raises:
    ValueError:  If the last dimension of `loc` is determined statically to be
      different than the range of `scale`.
  """
  with ops.name_scope(name, values=[loc] + scale.graph_parents):
    # Get event shape.
    event_size = scale.range_dimension_tensor()
    event_size_const = tensor_util.constant_value(event_size)
    if event_size_const is not None:
      event_shape = event_size_const.reshape([1])
    else:
      event_shape = event_size[array_ops.newaxis]

    # Static check that event shapes match.
    if loc is not None:
      loc_event_size = tensor_shape.dimension_value(loc.get_shape()[-1])
      if loc_event_size is not None and event_size_const is not None:
        if loc_event_size != 1 and loc_event_size != event_size_const:
          raise ValueError(
              "Event size of 'loc' (%d) could not be broadcast up to that of "
              "'scale' (%d)." % (loc_event_size, event_size_const))

    # Get batch shape.
    batch_shape = scale.batch_shape_tensor()
    if loc is None:
      batch_shape_const = tensor_util.constant_value(batch_shape)
      batch_shape = (
          batch_shape_const if batch_shape_const is not None else batch_shape)
    else:
      loc_batch_shape = loc.get_shape().with_rank_at_least(1)[:-1]
      if (loc.get_shape().ndims is None or
          not loc_batch_shape.is_fully_defined()):
        loc_batch_shape = array_ops.shape(loc)[:-1]
      else:
        loc_batch_shape = ops.convert_to_tensor(loc_batch_shape,
                                                name="loc_batch_shape")
      # This is defined in the core util module.
      # pylint: disable=undefined-variable
      batch_shape = prefer_static_broadcast_shape(batch_shape, loc_batch_shape)
      # pylint: enable=undefined-variable

  return batch_shape, event_shape


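# Usage sketch for `shapes_from_loc_and_scale` above (illustrative, not part
# of the original file). A `loc` with batch shape [3] and a diagonal scale
# with batch shape [5, 1] broadcast to batch shape [5, 3] and event shape [2]:
#
#   loc = array_ops.zeros([3, 2])
#   scale = linalg.LinearOperatorDiag(array_ops.ones([5, 1, 2]))
#   batch_shape, event_shape = shapes_from_loc_and_scale(loc, scale)
#   # batch_shape -> [5, 3], event_shape -> [2] (computed statically here,
#   # since all shapes are fully defined)
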
def get_broadcast_shape(*tensors):
  """Get broadcast shape as a Python list of integers (preferred) or `Tensor`.

  Args:
    *tensors:  One or more `Tensor` objects (already converted!).

  Returns:
    broadcast shape:  Python list (if shapes determined statically), otherwise
      an `int32` `Tensor`.
  """
  # Try static.
  s_shape = tensors[0].shape
  for t in tensors[1:]:
    s_shape = array_ops.broadcast_static_shape(s_shape, t.shape)
  if s_shape.is_fully_defined():
    return s_shape.as_list()

  # Fallback on dynamic.
  d_shape = array_ops.shape(tensors[0])
  for t in tensors[1:]:
    d_shape = array_ops.broadcast_dynamic_shape(d_shape, array_ops.shape(t))
  return d_shape


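# Usage sketch for `get_broadcast_shape` above (illustrative, not part of the
# original file). With fully defined static shapes the result is a plain
# Python list; otherwise an `int32` `Tensor` is returned.
#
#   x = array_ops.ones([3, 1, 5])
#   y = array_ops.ones([4, 1])
#   get_broadcast_shape(x, y)  # -> [3, 4, 5]
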
def is_diagonal_scale(scale):
  """Returns `True` if `scale` is a `LinearOperator` that is known to be diag.

  Args:
    scale:  `LinearOperator` instance.

  Returns:
    Python `bool`.

  Raises:
    TypeError:  If `scale` is not a `LinearOperator`.
  """
  if not isinstance(scale, linalg.LinearOperator):
    raise TypeError("Expected argument 'scale' to be instance of "
                    "LinearOperator. Found: %s" % scale)
  return (isinstance(scale, linalg.LinearOperatorIdentity) or
          isinstance(scale, linalg.LinearOperatorScaledIdentity) or
          isinstance(scale, linalg.LinearOperatorDiag))


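# Usage sketch for `is_diagonal_scale` above (illustrative, not part of the
# original file):
#
#   is_diagonal_scale(linalg.LinearOperatorDiag([1., 2.]))  # -> True
#   is_diagonal_scale(
#       linalg.LinearOperatorLowerTriangular([[1., 0.], [2., 3.]]))  # -> False
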
def maybe_check_scalar_distribution(
    distribution, expected_base_dtype, validate_args):
  """Helper which checks validity of a scalar `distribution` init arg.

  Valid here means:

  * `distribution` has scalar batch and event shapes.
  * `distribution` is `FULLY_REPARAMETERIZED`.
  * `distribution` has expected dtype.

  Args:
    distribution:  `Distribution`-like object.
    expected_base_dtype:  `TensorFlow` `dtype`.
    validate_args:  Python `bool`.  Whether to do additional checks:
      (i)  check that reparameterization_type is `FULLY_REPARAMETERIZED`.
      (ii) add `tf.Assert` ops to the graph to enforce that distribution
           is scalar in the event that this cannot be determined statically.

  Returns:
    List of `tf.Assert` ops to run to enforce validity checks that could not
      be statically determined.  Empty if `not validate_args`.

  Raises:
    TypeError:  If `distribution`'s dtype does not equal `expected_base_dtype`.
    ValueError:  If `validate_args` and `distribution` is not
      `FULLY_REPARAMETERIZED`.
    ValueError:  If `distribution` is statically determined to not have both
      scalar batch and scalar event shapes.
  """
  if distribution.dtype != expected_base_dtype:
    raise TypeError("dtype mismatch; "
                    "distribution.dtype=\"{}\" is not \"{}\"".format(
                        distribution.dtype.name, expected_base_dtype.name))

  # Although `reparameterization_type` is a static property, we guard it by
  # `validate_args`. This allows users to use a `distribution` which is not
  # reparameterized itself. However, we tacitly assume that although the
  # distribution is not reparameterized, it only depends on non-trainable
  # variables.
  if validate_args and (distribution.reparameterization_type
                        != distribution_lib.FULLY_REPARAMETERIZED):
    raise ValueError("Base distribution should be reparameterized or be "
                     "a function of non-trainable variables; "
                     "distribution.reparameterization_type = \"{}\" "
                     "!= \"FULLY_REPARAMETERIZED\".".format(
                         distribution.reparameterization_type))
  with ops.name_scope(name="check_distribution"):
    assertions = []

    def check_is_scalar(is_scalar, name):
      is_scalar_ = static_value(is_scalar)
      if is_scalar_ is not None:
        if not is_scalar_:
          raise ValueError("distribution must be scalar; "
                           "distribution.{}=False is not True".format(name))
      elif validate_args:
        assertions.append(check_ops.assert_equal(
            is_scalar, True,
            message=("distribution must be scalar; "
                     "distribution.{}=False is not True".format(name))))

    check_is_scalar(distribution.is_scalar_event(), "is_scalar_event")
    check_is_scalar(distribution.is_scalar_batch(), "is_scalar_batch")
    return assertions


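# Usage sketch for `maybe_check_scalar_distribution` above (illustrative, not
# part of the original file). `normal_lib` is an assumed import of
# `tensorflow.python.ops.distributions.normal`; a scalar Normal passes the
# static checks, and with `validate_args=False` no runtime assertions are
# returned.
#
#   base = normal_lib.Normal(loc=0., scale=1.)
#   assertions = maybe_check_scalar_distribution(
#       base, expected_base_dtype=dtypes.float32, validate_args=False)
#   # assertions -> []
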
def pad_mixture_dimensions(x, mixture_distribution, categorical_distribution,
                           event_ndims):
  """Pad dimensions of event tensors for mixture distributions.

  See `Mixture._sample_n` and `MixtureSameFamily._sample_n` for usage examples.

  Args:
    x: event tensor to pad.
    mixture_distribution: Base distribution of the mixture.
    categorical_distribution: `Categorical` distribution that mixes the base
      distribution.
    event_ndims: Integer specifying the number of event dimensions in the event
      tensor.

  Returns:
    A padded version of `x` that can broadcast with `categorical_distribution`.
  """
  with ops.name_scope("pad_mix_dims", values=[x]):

    def _get_ndims(d):
      if d.batch_shape.ndims is not None:
        return d.batch_shape.ndims
      return array_ops.shape(d.batch_shape_tensor())[0]

    dist_batch_ndims = _get_ndims(mixture_distribution)
    cat_batch_ndims = _get_ndims(categorical_distribution)
    pad_ndims = array_ops.where(
        categorical_distribution.is_scalar_batch(),
        dist_batch_ndims,
        dist_batch_ndims - cat_batch_ndims)
    s = array_ops.shape(x)
    x = array_ops.reshape(x, shape=array_ops.concat([
        s[:-1],
        array_ops.ones([pad_ndims], dtype=dtypes.int32),
        s[-1:],
        array_ops.ones([event_ndims], dtype=dtypes.int32),
    ], axis=0))
    return x


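# Shape sketch for `pad_mixture_dimensions` above (illustrative, not part of
# the original file). `mixture_dist` and `categorical_dist` stand for assumed
# distribution objects: a mixture with batch shape [B1, B2] (2 batch dims)
# and a scalar-batch categorical. With `x` of shape [N, K] and
# `event_ndims=1`, the reshape inserts the pad dims before the last axis of
# `x` and the event dims after it:
#
#   x = array_ops.ones([7, 3])  # [N=7, K=3]
#   padded = pad_mixture_dimensions(x, mixture_dist, categorical_dist,
#                                   event_ndims=1)
#   # padded has shape [7, 1, 1, 3, 1]
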
def static_value(x):
  """Returns the static value of a `Tensor` or `None`."""
  return tensor_util.constant_value(ops.convert_to_tensor(x))


def move_dimension(x, source_idx, dest_idx):
  """Move a single tensor dimension within its shape.

  This is a special case of `tf.transpose()`, which applies
  arbitrary permutations to tensor dimensions.

  Args:
    x: Tensor of rank `ndims`.
    source_idx: Integer index into `x.shape` (negative indexing is
      supported).
    dest_idx: Integer index into `x.shape` (negative indexing is
      supported).

  Returns:
    x_perm: Tensor of rank `ndims`, in which the dimension at original
     index `source_idx` has been moved to new index `dest_idx`, with
     all other dimensions retained in their original order.

  Example:

  ```python
  x = tf.placeholder(tf.float32, shape=[200, 30, 4, 1, 6])
  x_perm = move_dimension(x, 1, 1)  # no-op
  x_perm = move_dimension(x, 0, 3)  # result shape [30, 4, 1, 200, 6]
  x_perm = move_dimension(x, 0, -2)  # equivalent to previous
  x_perm = move_dimension(x, 4, 2)  # result shape [200, 30, 6, 4, 1]
  ```
  """
  ndims = util.prefer_static_rank(x)
  if isinstance(source_idx, int):
    dtype = dtypes.int32
  else:
    dtype = dtypes.as_dtype(source_idx.dtype)

  # Handle negative indexing. Since ndims might be dynamic, this makes
  # source_idx and dest_idx also possibly dynamic.
  if source_idx < 0:
    source_idx = ndims + source_idx
  if dest_idx < 0:
    dest_idx = ndims + dest_idx

  # Construct the appropriate permutation of dimensions, depending on
  # whether the source is before or after the destination.
  def move_left_permutation():
    return util.prefer_static_value(
        array_ops.concat([
            math_ops.range(0, dest_idx, dtype=dtype),
            [source_idx],
            math_ops.range(dest_idx, source_idx, dtype=dtype),
            math_ops.range(source_idx+1, ndims, dtype=dtype)], axis=0))

  def move_right_permutation():
    return util.prefer_static_value(
        array_ops.concat([
            math_ops.range(0, source_idx, dtype=dtype),
            math_ops.range(source_idx+1, dest_idx+1, dtype=dtype),
            [source_idx],
            math_ops.range(dest_idx+1, ndims, dtype=dtype)], axis=0))

  def x_permuted():
    return array_ops.transpose(
        x, perm=smart_cond.smart_cond(source_idx < dest_idx,
                                      move_right_permutation,
                                      move_left_permutation))

  # One final conditional to handle the special case where source
  # and destination indices are equal.
  return smart_cond.smart_cond(math_ops.equal(source_idx, dest_idx),
                               lambda: x,
                               x_permuted)