# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base classes for probability distributions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import contextlib
import types

import numpy as np
import six

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util
from tensorflow.python.util import deprecation
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "ReparameterizationType",
    "FULLY_REPARAMETERIZED",
    "NOT_REPARAMETERIZED",
    "Distribution",
]

_DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = [
    "batch_shape",
    "batch_shape_tensor",
    "cdf",
    "covariance",
    "cross_entropy",
    "entropy",
    "event_shape",
    "event_shape_tensor",
    "kl_divergence",
    "log_cdf",
    "log_prob",
    "log_survival_function",
    "mean",
    "mode",
    "prob",
    "sample",
    "stddev",
    "survival_function",
    "variance",
]


@six.add_metaclass(abc.ABCMeta)
class _BaseDistribution(object):
  """Abstract base class needed for resolving subclass hierarchy."""
  pass


def _copy_fn(fn):
  """Create a deep copy of fn.

  Args:
    fn: a callable

  Returns:
    A `FunctionType`: a deep copy of fn.

  Raises:
    TypeError: if `fn` is not a callable.
  """
  if not callable(fn):
    raise TypeError("fn is not callable: %s" % fn)
  # The blessed way to copy a function. copy.deepcopy fails to create a
  # non-reference copy. Since:
  #   types.FunctionType == type(lambda: None),
  # and the docstring for the function type states:
  #
  #   function(code, globals[, name[, argdefs[, closure]]])
  #
  #   Create a function object from a code object and a dictionary.
  #   ...
  #
  # Here we can use this to create a new function with the old function's
  # code, globals, closure, etc.
  return types.FunctionType(
      code=fn.__code__, globals=fn.__globals__,
      name=fn.__name__, argdefs=fn.__defaults__,
      closure=fn.__closure__)


def _update_docstring(old_str, append_str):
  """Update old_str by inserting append_str just before the "Args:" section."""
  old_str = old_str or ""
  old_str_lines = old_str.split("\n")

  # Step 0: Prepend spaces to all lines of append_str. This is
  # necessary for correct markdown generation.
  append_str = "\n".join("    %s" % line for line in append_str.split("\n"))

  # Step 1: Find mention of "Args":
  has_args_ix = [
      ix for ix, line in enumerate(old_str_lines)
      if line.strip().lower() == "args:"]
  if has_args_ix:
    final_args_ix = has_args_ix[-1]
    return ("\n".join(old_str_lines[:final_args_ix])
            + "\n\n" + append_str + "\n\n"
            + "\n".join(old_str_lines[final_args_ix:]))
  else:
    return old_str + "\n\n" + append_str
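
# For illustration (with hypothetical inputs): given a docstring whose final
# section header is "Args:", _update_docstring(doc, "Note.") returns a copy
# of `doc` with the four-space-indented "Note." inserted, surrounded by blank
# lines, just above that final "Args:" line; if there is no "Args:" section,
# the note is appended to the end instead.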


def _convert_to_tensor(value, name=None, preferred_dtype=None):
  """Converts to tensor avoiding an eager bug that loses float precision."""
  # TODO(b/116672045): Remove this function.
  if (context.executing_eagerly() and preferred_dtype is not None and
      (preferred_dtype.is_integer or preferred_dtype.is_bool)):
    v = ops.convert_to_tensor(value, name=name)
    if v.dtype.is_floating:
      return v
  return ops.convert_to_tensor(
      value, name=name, preferred_dtype=preferred_dtype)


class _DistributionMeta(abc.ABCMeta):

  def __new__(mcs, classname, baseclasses, attrs):
    """Control the creation of subclasses of the Distribution class.

    The main purpose of this method is to properly propagate docstrings
    from private Distribution methods, like `_log_prob`, into their
    public wrappers as inherited by the Distribution base class
    (e.g. `log_prob`).

    Args:
      classname: The name of the subclass being created.
      baseclasses: A tuple of parent classes.
      attrs: A dict mapping new attributes to their values.

    Returns:
      The class object.

    Raises:
      TypeError: If `Distribution` is not a subclass of `BaseDistribution`, or
        the new class is derived via multiple inheritance and the first
        parent class is not a subclass of `BaseDistribution`.
      AttributeError:  If `Distribution` does not implement e.g. `log_prob`.
      ValueError:  If a `Distribution` public method lacks a docstring.
    """
    if not baseclasses:  # Nothing to be done for Distribution
      raise TypeError("Expected non-empty baseclass. Does Distribution "
                      "not subclass _BaseDistribution?")
    which_base = [
        base for base in baseclasses
        if base == _BaseDistribution or issubclass(base, Distribution)]
    base = which_base[0]
    if base == _BaseDistribution:  # Nothing to be done for Distribution
      return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs)
    if not issubclass(base, Distribution):
      raise TypeError("First parent class declared for %s must be "
                      "Distribution, but saw '%s'" % (classname, base.__name__))
    for attr in _DISTRIBUTION_PUBLIC_METHOD_WRAPPERS:
      special_attr = "_%s" % attr
      class_attr_value = attrs.get(attr, None)
      if attr in attrs:
        # The method is being overridden, do not update its docstring.
        continue
      base_attr_value = getattr(base, attr, None)
      if not base_attr_value:
        raise AttributeError(
            "Internal error: expected base class '%s' to implement method '%s'"
            % (base.__name__, attr))
      class_special_attr_value = attrs.get(special_attr, None)
      if class_special_attr_value is None:
        # No _special method available, no need to update the docstring.
        continue
      class_special_attr_docstring = tf_inspect.getdoc(class_special_attr_value)
      if not class_special_attr_docstring:
        # No docstring to append.
        continue
      class_attr_value = _copy_fn(base_attr_value)
      class_attr_docstring = tf_inspect.getdoc(base_attr_value)
      if class_attr_docstring is None:
        raise ValueError(
            "Expected base class fn to contain a docstring: %s.%s"
            % (base.__name__, attr))
      class_attr_value.__doc__ = _update_docstring(
          class_attr_value.__doc__,
          ("Additional documentation from `%s`:\n\n%s"
           % (classname, class_special_attr_docstring)))
      attrs[attr] = class_attr_value

    return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs)


@tf_export(v1=["distributions.ReparameterizationType"])
class ReparameterizationType(object):
  """Instances of this class represent how sampling is reparameterized.

  Two static instances exist in the distributions library, signifying
  one of two possible properties for samples from a distribution:

  `FULLY_REPARAMETERIZED`: Samples from the distribution are fully
    reparameterized, and straight-through gradients are supported.

  `NOT_REPARAMETERIZED`: Samples from the distribution are not fully
    reparameterized, and straight-through gradients are either partially
    unsupported or are not supported at all. In this case, for purposes of
    e.g. RL or variational inference, it is generally safest to wrap the
    sample results in a `stop_gradient` call and use policy
    gradients / surrogate loss instead.
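
  For example, a sketch of how the flag is typically consumed (`make_loss`
  here is a hypothetical user function):

  ```python
  if dist.reparameterization_type == FULLY_REPARAMETERIZED:
    loss = make_loss(dist.sample())  # Gradients may flow through the sample.
  else:
    # Gradients through the sample are not supported; use a surrogate loss.
    loss = make_loss(tf.stop_gradient(dist.sample()))
  ```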
231  """
232
233  @deprecation.deprecated(
234      "2019-01-01",
235      "The TensorFlow Distributions library has moved to "
236      "TensorFlow Probability "
237      "(https://github.com/tensorflow/probability). You "
238      "should update all references to use `tfp.distributions` "
239      "instead of `tf.distributions`.",
240      warn_once=True)
241  def __init__(self, rep_type):
242    self._rep_type = rep_type
243
244  def __repr__(self):
245    return "<Reparameteriation Type: %s>" % self._rep_type
246
247  def __eq__(self, other):
248    """Determine if this `ReparameterizationType` is equal to another.
249
250    Since RepaparameterizationType instances are constant static global
251    instances, equality checks if two instances' id() values are equal.
252
253    Args:
254      other: Object to compare against.
255
256    Returns:
257      `self is other`.
258    """
259    return self is other
260
261
262# Fully reparameterized distribution: samples from a fully
263# reparameterized distribution support straight-through gradients with
264# respect to all parameters.
265FULLY_REPARAMETERIZED = ReparameterizationType("FULLY_REPARAMETERIZED")
266tf_export(v1=["distributions.FULLY_REPARAMETERIZED"]).export_constant(
267    __name__, "FULLY_REPARAMETERIZED")
268
269
270# Not reparameterized distribution: samples from a non-
271# reparameterized distribution do not support straight-through gradients for
272# at least some of the parameters.
273NOT_REPARAMETERIZED = ReparameterizationType("NOT_REPARAMETERIZED")
274tf_export(v1=["distributions.NOT_REPARAMETERIZED"]).export_constant(
275    __name__, "NOT_REPARAMETERIZED")
276
277
278@six.add_metaclass(_DistributionMeta)
279@tf_export(v1=["distributions.Distribution"])
280class Distribution(_BaseDistribution):
281  """A generic probability distribution base class.
282
283  `Distribution` is a base class for constructing and organizing properties
284  (e.g., mean, variance) of random variables (e.g, Bernoulli, Gaussian).
285
286  #### Subclassing
287
288  Subclasses are expected to implement a leading-underscore version of the
289  same-named function. The argument signature should be identical except for
290  the omission of `name="..."`. For example, to enable `log_prob(value,
291  name="log_prob")` a subclass should implement `_log_prob(value)`.
292
293  Subclasses can append to public-level docstrings by providing
294  docstrings for their method specializations. For example:
295
296  ```python
297  @util.AppendDocstring("Some other details.")
298  def _log_prob(self, value):
299    ...
300  ```
301
302  would add the string "Some other details." to the `log_prob` function
303  docstring. This is implemented as a simple decorator to avoid python
304  linter complaining about missing Args/Returns/Raises sections in the
305  partial docstrings.
306
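  A minimal sketch of a subclass under these conventions (the distribution is
  hypothetical and for illustration only; `random_ops` would additionally
  need to be imported):

  ```python
  class MyExponential(distributions.Distribution):

    def __init__(self, rate):
      self._rate = rate
      super(MyExponential, self).__init__(
          dtype=rate.dtype,
          reparameterization_type=FULLY_REPARAMETERIZED,
          validate_args=False,
          allow_nan_stats=True,
          name="MyExponential")

    def _log_prob(self, value):
      return math_ops.log(self._rate) - self._rate * value

    def _sample_n(self, n, seed=None):
      shape = array_ops.concat([[n], array_ops.shape(self._rate)], 0)
      u = random_ops.random_uniform(shape, seed=seed)  # U(0, 1) draws.
      return -math_ops.log1p(-u) / self._rate          # Inverse-CDF sampling.
  ```
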
  #### Broadcasting, batching, and shapes

  All distributions support batches of independent distributions of that type.
  The batch shape is determined by broadcasting together the parameters.

  The shape of arguments to `__init__`, `cdf`, `log_cdf`, `prob`, and
  `log_prob` reflects this broadcasting, as does the return value of `sample`.

  `sample_n_shape = [n] + batch_shape + event_shape`, where `sample_n_shape` is
  the shape of the `Tensor` returned from `sample(n)`, `n` is the number of
  samples, `batch_shape` defines how many independent distributions there are,
  and `event_shape` defines the shape of samples from each of those independent
  distributions. Samples are independent along the `batch_shape` dimensions, but
  not necessarily so along the `event_shape` dimensions (depending on the
  particulars of the underlying distribution).

  Using the `Uniform` distribution as an example:

  ```python
  minval = 3.0
  maxval = [[4.0, 6.0],
            [10.0, 12.0]]

  # Broadcasting:
  # This instance represents 4 Uniform distributions. Each has a lower bound at
  # 3.0 as the `minval` parameter was broadcast to match `maxval`'s shape.
  u = Uniform(minval, maxval)

  # `event_shape` is `TensorShape([])`.
  event_shape = u.event_shape
  # `event_shape_t` is a `Tensor` which will evaluate to [].
  event_shape_t = u.event_shape_tensor()

  # Sampling returns a sample per distribution. `samples` has shape
  # [5, 2, 2], which is [n] + batch_shape + event_shape, where n=5,
  # batch_shape=[2, 2], and event_shape=[].
  samples = u.sample(5)

  # The broadcasting holds across methods. Here we use `cdf` as an example. The
  # same holds for `log_cdf` and the likelihood functions.

  # `cum_prob_broadcast` has shape [2, 2] as the `value` argument was broadcast
  # to the shape of the `Uniform` instance.
  cum_prob_broadcast = u.cdf(4.0)

  # `cum_prob_per_dist`'s shape is [2, 2], one per distribution. No
  # broadcasting occurred.
  cum_prob_per_dist = u.cdf([[4.0, 5.0],
                             [6.0, 7.0]])

  # INVALID as the `value` argument is not broadcastable to the distribution's
  # shape.
  cum_prob_invalid = u.cdf([4.0, 5.0, 6.0])
  ```

  #### Shapes

  There are three important concepts associated with TensorFlow Distributions
  shapes:
  - Event shape describes the shape of a single draw from the distribution;
    it may be dependent across dimensions. For scalar distributions, the event
    shape is `[]`. For a 5-dimensional MultivariateNormal, the event shape is
    `[5]`.
  - Batch shape describes independent, not identically distributed draws, aka a
    "collection" or "bunch" of distributions.
  - Sample shape describes independent, identically distributed draws of batches
    from the distribution family.

  The event shape and the batch shape are properties of a Distribution object,
  whereas the sample shape is associated with a specific call to `sample` or
  `log_prob`.

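  For example, a sketch of how the three shapes compose (using the `Normal`
  distribution from this library):

  ```python
  d = Normal(loc=[0., 0., 0.], scale=1.)  # batch_shape=[3], event_shape=[]
  x = d.sample([5, 2])                    # sample_shape=[5, 2]
  # x.shape == sample_shape + batch_shape + event_shape == [5, 2, 3]
  lp = d.log_prob(x)                      # lp.shape == [5, 2, 3]
  ```
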
  For detailed usage examples of TensorFlow Distributions shapes, see
  [this tutorial](
  https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Understanding_TensorFlow_Distributions_Shapes.ipynb)

  #### Parameter values leading to undefined statistics or distributions

  Some distributions do not have well-defined statistics for all initialization
  parameter values. For example, the beta distribution is parameterized by
  positive real numbers `concentration1` and `concentration0`, and does not
  have a well-defined mode if `concentration1 < 1` or `concentration0 < 1`.

  The user is given the option of raising an exception or returning `NaN`.

  ```python
  a = tf.exp(tf.matmul(logits, weights_a))
  b = tf.exp(tf.matmul(logits, weights_b))

  # Will raise exception if ANY batch member has a < 1 or b < 1.
  dist = distributions.Beta(a, b, allow_nan_stats=False)
  mode = dist.mode().eval()

  # Will return NaN for batch members with either a < 1 or b < 1.
  dist = distributions.Beta(a, b, allow_nan_stats=True)  # Default behavior.
  mode = dist.mode().eval()
  ```

  In all cases, an exception is raised if *invalid* parameters are passed, e.g.

  ```python
  # Will raise an exception if any Op is run.
  negative_a = -1.0 * a  # beta distribution by definition has a > 0.
  dist = distributions.Beta(negative_a, b, allow_nan_stats=True)
  dist.mean().eval()
  ```

  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               dtype,
               reparameterization_type,
               validate_args,
               allow_nan_stats,
               parameters=None,
               graph_parents=None,
               name=None):
    """Constructs the `Distribution`.

    **This is a private method for subclass use.**

    Args:
      dtype: The type of the event samples. `None` implies no type-enforcement.
      reparameterization_type: Instance of `ReparameterizationType`.
        If `distributions.FULLY_REPARAMETERIZED`, this
        `Distribution` can be reparameterized in terms of some standard
        distribution with a function whose Jacobian is constant for the support
        of the standard distribution. If `distributions.NOT_REPARAMETERIZED`,
        then no such reparameterization is available.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      parameters: Python `dict` of parameters used to instantiate this
        `Distribution`.
      graph_parents: Python `list` of graph prerequisites of this
        `Distribution`.
      name: Python `str` name prefixed to Ops created by this class. Default:
        subclass name.

    Raises:
      ValueError: if any member of graph_parents is `None` or not a `Tensor`.
    """
    graph_parents = [] if graph_parents is None else graph_parents
    for i, t in enumerate(graph_parents):
      if t is None or not tensor_util.is_tensor(t):
        raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
    if not name or name[-1] != "/":  # `name` is not a name scope.
      non_unique_name = name or type(self).__name__
      with ops.name_scope(non_unique_name) as name:
        pass
    self._dtype = dtype
    self._reparameterization_type = reparameterization_type
    self._allow_nan_stats = allow_nan_stats
    self._validate_args = validate_args
    self._parameters = parameters or {}
    self._graph_parents = graph_parents
    self._name = name

  @property
  def _parameters(self):
    return self._parameter_dict

  @_parameters.setter
  def _parameters(self, value):
    """Intercept assignments to self._parameters to avoid reference cycles.

    Parameters are often created using locals(), so we need to clean out any
    references to `self` before assigning it to an attribute.

    Args:
      value: A dictionary of parameters to assign to the `_parameters`
        property.
    """
    if "self" in value:
      del value["self"]
    self._parameter_dict = value

  @classmethod
  def param_shapes(cls, sample_shape, name="DistributionParamShapes"):
    """Shapes of parameters given the desired shape of a call to `sample()`.

    This is a class method that describes what key/value arguments are required
    to instantiate the given `Distribution` so that a particular shape is
    returned for that instance's call to `sample()`.

    Subclasses should override class method `_param_shapes`.

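    For example, a sketch assuming the `Normal` subclass, whose parameters
    `loc` and `scale` must each broadcast to the requested sample shape:

    ```python
    shapes = Normal.param_shapes([100, 2])
    # {"loc": <Tensor of value [100, 2]>, "scale": <Tensor of value [100, 2]>}
    ```
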
    Args:
      sample_shape: `Tensor` or python list/tuple. Desired shape of a call to
        `sample()`.
      name: name to prepend to ops created by this method.

    Returns:
      `dict` of parameter name to `Tensor` shapes.
    """
    with ops.name_scope(name, values=[sample_shape]):
      return cls._param_shapes(sample_shape)

  @classmethod
  def param_static_shapes(cls, sample_shape):
    """param_shapes with static (i.e. `TensorShape`) shapes.

    This is a class method that describes what key/value arguments are required
    to instantiate the given `Distribution` so that a particular shape is
    returned for that instance's call to `sample()`. Assumes that the sample's
    shape is known statically.

    Subclasses should override class method `_param_shapes` to return
    constant-valued tensors when constant values are fed.

    Args:
      sample_shape: `TensorShape` or python list/tuple. Desired shape of a call
        to `sample()`.

    Returns:
      `dict` of parameter name to `TensorShape`.

    Raises:
      ValueError: if `sample_shape` is a `TensorShape` and is not fully defined.
    """
    if isinstance(sample_shape, tensor_shape.TensorShape):
      if not sample_shape.is_fully_defined():
        raise ValueError("TensorShape sample_shape must be fully defined")
      sample_shape = sample_shape.as_list()

    params = cls.param_shapes(sample_shape)

    static_params = {}
    for name, shape in params.items():
      static_shape = tensor_util.constant_value(shape)
      if static_shape is None:
        raise ValueError(
            "sample_shape must be a fully-defined TensorShape or list/tuple")
      static_params[name] = tensor_shape.TensorShape(static_shape)

    return static_params

  @staticmethod
  def _param_shapes(sample_shape):
    raise NotImplementedError("_param_shapes not implemented")

  @property
  def name(self):
    """Name prepended to all ops created by this `Distribution`."""
    return self._name

  @property
  def dtype(self):
    """The `DType` of `Tensor`s handled by this `Distribution`."""
    return self._dtype

  @property
  def parameters(self):
    """Dictionary of parameters used to instantiate this `Distribution`."""
    # Remove "self", "__class__", or other special variables. These can appear
    # if the subclass used:
    # `parameters = dict(locals())`.
    return {k: v for k, v in self._parameters.items()
            if not k.startswith("__") and k != "self"}

  @property
  def reparameterization_type(self):
    """Describes how samples from the distribution are reparameterized.

    Currently this is one of the static instances
    `distributions.FULLY_REPARAMETERIZED`
    or `distributions.NOT_REPARAMETERIZED`.

    Returns:
      An instance of `ReparameterizationType`.
    """
    return self._reparameterization_type

  @property
  def allow_nan_stats(self):
    """Python `bool` describing behavior when a stat is undefined.

    Stats return +/- infinity when it makes sense. E.g., the variance of a
    Cauchy distribution is infinity. However, sometimes the statistic is
    undefined, e.g., if a distribution's pdf does not achieve a maximum within
    the support of the distribution, the mode is undefined. If the mean is
    undefined, then by definition the variance is undefined. E.g. the mean for
    Student's T for df = 1 is undefined (no clear way to say it is either + or -
    infinity), so the variance = E[(X - mean)**2] is also undefined.

    Returns:
      allow_nan_stats: Python `bool`.
    """
    return self._allow_nan_stats

  @property
  def validate_args(self):
    """Python `bool` indicating possibly expensive checks are enabled."""
    return self._validate_args

  def copy(self, **override_parameters_kwargs):
    """Creates a deep copy of the distribution.

    Note: the copy distribution may continue to depend on the original
      initialization arguments.

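    For example, a sketch assuming a `Normal` distribution:

    ```python
    n = Normal(loc=0., scale=1.)
    n2 = n.copy(scale=2.)  # Equivalent to Normal(loc=0., scale=2.)
    ```
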
    Args:
      **override_parameters_kwargs: String/value dictionary of initialization
        arguments to override with new values.

    Returns:
      distribution: A new instance of `type(self)` initialized from the union
        of self.parameters and override_parameters_kwargs, i.e.,
        `dict(self.parameters, **override_parameters_kwargs)`.
    """
    parameters = dict(self.parameters, **override_parameters_kwargs)
    return type(self)(**parameters)

  def _batch_shape_tensor(self):
    raise NotImplementedError(
        "batch_shape_tensor is not implemented: {}".format(type(self).__name__))

  def batch_shape_tensor(self, name="batch_shape_tensor"):
    """Shape of a single sample from a single event index as a 1-D `Tensor`.

    The batch dimensions are indexes into independent, non-identical
    parameterizations of this distribution.

    Args:
      name: name to give to the op.

    Returns:
      batch_shape: `Tensor`.
    """
    with self._name_scope(name):
      if self.batch_shape.is_fully_defined():
        return ops.convert_to_tensor(self.batch_shape.as_list(),
                                     dtype=dtypes.int32,
                                     name="batch_shape")
      return self._batch_shape_tensor()

  def _batch_shape(self):
    return tensor_shape.TensorShape(None)

  @property
  def batch_shape(self):
    """Shape of a single sample from a single event index as a `TensorShape`.

    May be partially defined or unknown.

    The batch dimensions are indexes into independent, non-identical
    parameterizations of this distribution.

    Returns:
      batch_shape: `TensorShape`, possibly unknown.
    """
    return tensor_shape.as_shape(self._batch_shape())

  def _event_shape_tensor(self):
    raise NotImplementedError(
        "event_shape_tensor is not implemented: {}".format(type(self).__name__))

  def event_shape_tensor(self, name="event_shape_tensor"):
    """Shape of a single sample from a single batch as a 1-D int32 `Tensor`.

    Args:
      name: name to give to the op.

    Returns:
      event_shape: `Tensor`.
    """
    with self._name_scope(name):
      if self.event_shape.is_fully_defined():
        return ops.convert_to_tensor(self.event_shape.as_list(),
                                     dtype=dtypes.int32,
                                     name="event_shape")
      return self._event_shape_tensor()

  def _event_shape(self):
    return tensor_shape.TensorShape(None)

  @property
  def event_shape(self):
    """Shape of a single sample from a single batch as a `TensorShape`.

    May be partially defined or unknown.

    Returns:
      event_shape: `TensorShape`, possibly unknown.
    """
    return tensor_shape.as_shape(self._event_shape())

  def is_scalar_event(self, name="is_scalar_event"):
    """Indicates that `event_shape == []`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      is_scalar_event: `bool` scalar `Tensor`.
    """
    with self._name_scope(name):
      return ops.convert_to_tensor(
          self._is_scalar_helper(self.event_shape, self.event_shape_tensor),
          name="is_scalar_event")

  def is_scalar_batch(self, name="is_scalar_batch"):
    """Indicates that `batch_shape == []`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      is_scalar_batch: `bool` scalar `Tensor`.
    """
    with self._name_scope(name):
      return ops.convert_to_tensor(
          self._is_scalar_helper(self.batch_shape, self.batch_shape_tensor),
          name="is_scalar_batch")

  def _sample_n(self, n, seed=None):
    raise NotImplementedError("sample_n is not implemented: {}".format(
        type(self).__name__))

  def _call_sample_n(self, sample_shape, seed, name, **kwargs):
    with self._name_scope(name, values=[sample_shape]):
      sample_shape = ops.convert_to_tensor(
          sample_shape, dtype=dtypes.int32, name="sample_shape")
      sample_shape, n = self._expand_sample_shape_to_vector(
          sample_shape, "sample_shape")
      samples = self._sample_n(n, seed, **kwargs)
      batch_event_shape = array_ops.shape(samples)[1:]
      final_shape = array_ops.concat([sample_shape, batch_event_shape], 0)
      samples = array_ops.reshape(samples, final_shape)
      samples = self._set_sample_static_shape(samples, sample_shape)
      return samples

  def sample(self, sample_shape=(), seed=None, name="sample"):
    """Generate samples of the specified shape.

    Note that a call to `sample()` without arguments will generate a single
    sample.

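    For example, a sketch assuming a scalar `Normal` distribution:

    ```python
    d = Normal(loc=0., scale=1.)
    d.sample().shape        # []
    d.sample(5).shape       # [5]
    d.sample([3, 2]).shape  # [3, 2]
    ```
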
    Args:
      sample_shape: 0D or 1D `int32` `Tensor`. Shape of the generated samples.
      seed: Python integer seed for RNG.
      name: name to give to the op.

    Returns:
      samples: a `Tensor` with prepended dimensions `sample_shape`.
    """
    return self._call_sample_n(sample_shape, seed, name)

  def _log_prob(self, value):
    raise NotImplementedError("log_prob is not implemented: {}".format(
        type(self).__name__))

  def _call_log_prob(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._log_prob(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.log(self._prob(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def log_prob(self, value, name="log_prob"):
    """Log probability density/mass function.

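    For example, a sketch of batch evaluation (using the `Normal`
    distribution from this library):

    ```python
    d = Normal(loc=[0., 1.], scale=1.)  # batch_shape=[2]
    d.log_prob(0.)  # shape [2]; the scalar value broadcasts to each member
    ```
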
    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_log_prob(value, name)

  def _prob(self, value):
    raise NotImplementedError("prob is not implemented: {}".format(
        type(self).__name__))

  def _call_prob(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._prob(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.exp(self._log_prob(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def prob(self, value, name="prob"):
    """Probability density/mass function.

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_prob(value, name)

  def _log_cdf(self, value):
    raise NotImplementedError("log_cdf is not implemented: {}".format(
        type(self).__name__))

  def _call_log_cdf(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._log_cdf(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.log(self._cdf(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def log_cdf(self, value, name="log_cdf"):
    """Log cumulative distribution function.

    Given random variable `X`, the cumulative distribution function `cdf` is:

    ```none
    log_cdf(x) := Log[ P[X <= x] ]
    ```

    Often, a numerical approximation can be used for `log_cdf(x)` that yields
    a more accurate answer than simply taking the logarithm of the `cdf` when
    `x << -1`.

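    For example, a sketch of the tail behavior (assuming the `Normal`
    distribution from this library and the TF 1.x `tf.log`):

    ```python
    d = Normal(loc=0., scale=1.)
    d.log_cdf(-20.)      # Finite, accurate tail log-probability.
    tf.log(d.cdf(-20.))  # cdf underflows to 0 in float32, yielding -inf.
    ```
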
    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      logcdf: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_log_cdf(value, name)

  def _cdf(self, value):
    raise NotImplementedError("cdf is not implemented: {}".format(
        type(self).__name__))

  def _call_cdf(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._cdf(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.exp(self._log_cdf(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def cdf(self, value, name="cdf"):
    """Cumulative distribution function.

    Given random variable `X`, the cumulative distribution function `cdf` is:

    ```none
    cdf(x) := P[X <= x]
    ```

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      cdf: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_cdf(value, name)

  def _log_survival_function(self, value):
    raise NotImplementedError(
        "log_survival_function is not implemented: {}".format(
            type(self).__name__))

  def _call_log_survival_function(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._log_survival_function(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.log1p(-self.cdf(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def log_survival_function(self, value, name="log_survival_function"):
    """Log survival function.

    Given random variable `X`, the survival function is defined as:

    ```none
    log_survival_function(x) = Log[ P[X > x] ]
                             = Log[ 1 - P[X <= x] ]
                             = Log[ 1 - cdf(x) ]
    ```

    Typically, different numerical approximations can be used for the log
    survival function, which are more accurate than `1 - cdf(x)` when `x >> 1`.

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type
        `self.dtype`.
    """
    return self._call_log_survival_function(value, name)

  def _survival_function(self, value):
    raise NotImplementedError("survival_function is not implemented: {}".format(
        type(self).__name__))

  def _call_survival_function(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._survival_function(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return 1. - self.cdf(value, **kwargs)
        except NotImplementedError:
          raise original_exception

  def survival_function(self, value, name="survival_function"):
    """Survival function.

    Given random variable `X`, the survival function is defined as:

    ```none
    survival_function(x) = P[X > x]
                         = 1 - P[X <= x]
                         = 1 - cdf(x).
    ```

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type
        `self.dtype`.
    """
    return self._call_survival_function(value, name)

  def _entropy(self):
    raise NotImplementedError("entropy is not implemented: {}".format(
        type(self).__name__))

  def entropy(self, name="entropy"):
    """Shannon entropy in nats."""
    with self._name_scope(name):
      return self._entropy()

  def _mean(self):
    raise NotImplementedError("mean is not implemented: {}".format(
        type(self).__name__))

  def mean(self, name="mean"):
    """Mean."""
    with self._name_scope(name):
      return self._mean()

  def _quantile(self, value):
    raise NotImplementedError("quantile is not implemented: {}".format(
        type(self).__name__))

  def _call_quantile(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      return self._quantile(value, **kwargs)

  def quantile(self, value, name="quantile"):
    """Quantile function. Aka "inverse cdf" or "percent point function".

    Given random variable `X` and `p in [0, 1]`, the `quantile` is:

    ```none
    quantile(p) := x such that P[X <= x] == p
    ```

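    For example, a sketch of the inverse relationship (assuming the `Normal`
    distribution, up to floating-point error):

    ```python
    d = Normal(loc=0., scale=1.)
    d.quantile(d.cdf(1.3))  # ~= 1.3
    ```
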
    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      quantile: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_quantile(value, name)

  def _variance(self):
    raise NotImplementedError("variance is not implemented: {}".format(
        type(self).__name__))

  def variance(self, name="variance"):
    """Variance.

    Variance is defined as,

    ```none
    Var = E[(X - E[X])**2]
    ```

    where `X` is the random variable associated with this distribution, `E`
    denotes expectation, and `Var.shape = batch_shape + event_shape`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      variance: Floating-point `Tensor` with shape identical to
        `batch_shape + event_shape`, i.e., the same shape as `self.mean()`.
    """
    with self._name_scope(name):
      try:
        return self._variance()
      except NotImplementedError as original_exception:
        try:
          return math_ops.square(self._stddev())
        except NotImplementedError:
          raise original_exception

  def _stddev(self):
    raise NotImplementedError("stddev is not implemented: {}".format(
        type(self).__name__))

  def stddev(self, name="stddev"):
    """Standard deviation.

    Standard deviation is defined as,

    ```none
    stddev = E[(X - E[X])**2]**0.5
    ```

    where `X` is the random variable associated with this distribution, `E`
    denotes expectation, and `stddev.shape = batch_shape + event_shape`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      stddev: Floating-point `Tensor` with shape identical to
        `batch_shape + event_shape`, i.e., the same shape as `self.mean()`.
    """
    with self._name_scope(name):
      try:
        return self._stddev()
      except NotImplementedError as original_exception:
        try:
          return math_ops.sqrt(self._variance())
        except NotImplementedError:
          raise original_exception

  def _covariance(self):
    raise NotImplementedError("covariance is not implemented: {}".format(
        type(self).__name__))

  def covariance(self, name="covariance"):
    """Covariance.

    Covariance is (possibly) defined only for non-scalar-event distributions.

    For example, for a length-`k`, vector-valued distribution, it is calculated
    as,

    ```none
    Cov[i, j] = Covariance(X_i, X_j) = E[(X_i - E[X_i]) (X_j - E[X_j])]
    ```

    where `Cov` is a (batch of) `k x k` matrix, `0 <= (i, j) < k`, and `E`
    denotes expectation.

    Alternatively, for non-vector, multivariate distributions (e.g.,
    matrix-valued, Wishart), `Covariance` shall return a (batch of) matrices
    under some vectorization of the events, i.e.,

    ```none
    Cov[i, j] = Covariance(Vec(X)_i, Vec(X)_j) = [as above]
    ```

    where `Cov` is a (batch of) `k' x k'` matrices,
    `0 <= (i, j) < k' = reduce_prod(event_shape)`, and `Vec` is some function
    mapping indices of this distribution's event dimensions to indices of a
    length-`k'` vector.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      covariance: Floating-point `Tensor` with shape `[B1, ..., Bn, k', k']`
        where the first `n` dimensions are batch coordinates and
        `k' = reduce_prod(self.event_shape)`.
    """
    with self._name_scope(name):
      return self._covariance()

  def _mode(self):
    raise NotImplementedError("mode is not implemented: {}".format(
        type(self).__name__))

  def mode(self, name="mode"):
    """Mode."""
    with self._name_scope(name):
      return self._mode()

  def _cross_entropy(self, other):
    return kullback_leibler.cross_entropy(
        self, other, allow_nan_stats=self.allow_nan_stats)

  def cross_entropy(self, other, name="cross_entropy"):
    """Computes the (Shannon) cross entropy.

    Denote this distribution (`self`) by `P` and the `other` distribution by
    `Q`. Assuming `P, Q` are absolutely continuous with respect to
    one another and permit densities `p(x) dr(x)` and `q(x) dr(x)`, (Shannon)
    cross entropy is defined as:

    ```none
    H[P, Q] = E_p[-log q(X)] = -int_F p(x) log q(x) dr(x)
    ```

    where `F` denotes the support of the random variable `X ~ P`.

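    For example, a sketch of the decomposition `H[P, Q] = H[P] + KL[P || Q]`
    (assuming two `Normal` distributions):

    ```python
    p = Normal(loc=0., scale=1.)
    q = Normal(loc=1., scale=2.)
    h_pq = p.cross_entropy(q)  # == p.entropy() + p.kl_divergence(q)
    ```
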
    Args:
      other: `tfp.distributions.Distribution` instance.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      cross_entropy: `self.dtype` `Tensor` with shape `[B1, ..., Bn]`
        representing `n` different calculations of (Shannon) cross entropy.
    """
    with self._name_scope(name):
      return self._cross_entropy(other)

  def _kl_divergence(self, other):
    return kullback_leibler.kl_divergence(
        self, other, allow_nan_stats=self.allow_nan_stats)

  def kl_divergence(self, other, name="kl_divergence"):
    """Computes the Kullback-Leibler divergence.

    Denote this distribution (`self`) by `p` and the `other` distribution by
    `q`. Assuming `p, q` are absolutely continuous with respect to reference
    measure `r`, the KL divergence is defined as:

    ```none
    KL[p, q] = E_p[log(p(X)/q(X))]
             = -int_F p(x) log q(x) dr(x) + int_F p(x) log p(x) dr(x)
             = H[p, q] - H[p]
    ```

    where `F` denotes the support of the random variable `X ~ p`, `H[., .]`
    denotes (Shannon) cross entropy, and `H[.]` denotes (Shannon) entropy.

    Args:
      other: `tfp.distributions.Distribution` instance.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      kl_divergence: `self.dtype` `Tensor` with shape `[B1, ..., Bn]`
        representing `n` different calculations of the Kullback-Leibler
        divergence.
    """
    with self._name_scope(name):
      return self._kl_divergence(other)

  def __str__(self):
    return ("tfp.distributions.{type_name}("
            "\"{self_name}\""
            "{maybe_batch_shape}"
            "{maybe_event_shape}"
            ", dtype={dtype})".format(
                type_name=type(self).__name__,
                self_name=self.name,
                maybe_batch_shape=(", batch_shape={}".format(self.batch_shape)
                                   if self.batch_shape.ndims is not None
                                   else ""),
                maybe_event_shape=(", event_shape={}".format(self.event_shape)
                                   if self.event_shape.ndims is not None
                                   else ""),
                dtype=self.dtype.name))

  def __repr__(self):
    return ("<tfp.distributions.{type_name} "
            "'{self_name}'"
            " batch_shape={batch_shape}"
            " event_shape={event_shape}"
            " dtype={dtype}>".format(
                type_name=type(self).__name__,
                self_name=self.name,
                batch_shape=self.batch_shape,
                event_shape=self.event_shape,
                dtype=self.dtype.name))

  @contextlib.contextmanager
  def _name_scope(self, name=None, values=None):
    """Helper function to standardize op scope."""
    with ops.name_scope(self.name):
      with ops.name_scope(name, values=(
          ([] if values is None else values) + self._graph_parents)) as scope:
        yield scope

  def _expand_sample_shape_to_vector(self, x, name):
    """Helper to `sample` which ensures input is 1-D."""
    x_static_val = tensor_util.constant_value(x)
    if x_static_val is None:
      prod = math_ops.reduce_prod(x)
    else:
      prod = np.prod(x_static_val, dtype=x.dtype.as_numpy_dtype())

    ndims = x.get_shape().ndims  # != sample_ndims
    if ndims is None:
      # Maybe expand_dims.
      ndims = array_ops.rank(x)
      expanded_shape = util.pick_vector(
          math_ops.equal(ndims, 0),
          np.array([1], dtype=np.int32), array_ops.shape(x))
      x = array_ops.reshape(x, expanded_shape)
    elif ndims == 0:
      # Definitely expand_dims.
      if x_static_val is not None:
        x = ops.convert_to_tensor(
            np.array([x_static_val], dtype=x.dtype.as_numpy_dtype()),
            name=name)
      else:
        x = array_ops.reshape(x, [1])
    elif ndims != 1:
      raise ValueError("Input is neither scalar nor vector.")

    return x, prod

  def _set_sample_static_shape(self, x, sample_shape):
    """Helper to `sample`; sets static shape info."""
    # Set shape hints.
    sample_shape = tensor_shape.TensorShape(
        tensor_util.constant_value(sample_shape))

    ndims = x.get_shape().ndims
    sample_ndims = sample_shape.ndims
    batch_ndims = self.batch_shape.ndims
    event_ndims = self.event_shape.ndims

    # Infer rank(x).
    if (ndims is None and
        sample_ndims is not None and
        batch_ndims is not None and
        event_ndims is not None):
      ndims = sample_ndims + batch_ndims + event_ndims
      x.set_shape([None] * ndims)

    # Infer sample shape.
    if ndims is not None and sample_ndims is not None:
      shape = sample_shape.concatenate([None]*(ndims - sample_ndims))
      x.set_shape(x.get_shape().merge_with(shape))

    # Infer event shape.
    if ndims is not None and event_ndims is not None:
      shape = tensor_shape.TensorShape(
          [None]*(ndims - event_ndims)).concatenate(self.event_shape)
      x.set_shape(x.get_shape().merge_with(shape))

    # Infer batch shape.
    if batch_ndims is not None:
      if ndims is not None:
        if sample_ndims is None and event_ndims is not None:
          sample_ndims = ndims - batch_ndims - event_ndims
        elif event_ndims is None and sample_ndims is not None:
          event_ndims = ndims - batch_ndims - sample_ndims
      if sample_ndims is not None and event_ndims is not None:
        shape = tensor_shape.TensorShape([None]*sample_ndims).concatenate(
            self.batch_shape).concatenate([None]*event_ndims)
        x.set_shape(x.get_shape().merge_with(shape))

    return x

  def _is_scalar_helper(self, static_shape, dynamic_shape_fn):
    """Implementation for `is_scalar_batch` and `is_scalar_event`."""
    if static_shape.ndims is not None:
      return static_shape.ndims == 0
    shape = dynamic_shape_fn()
    if (shape.get_shape().ndims is not None and
        shape.get_shape().dims[0].value is not None):
      # If the static_shape is correctly written then we should never execute
      # this branch. We keep it just in case there's some unimagined corner
      # case.
      return shape.get_shape().as_list() == [0]
    return math_ops.equal(array_ops.shape(shape)[0], 0)