# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base classes for probability distributions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import contextlib
import types

import numpy as np
import six

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "ReparameterizationType",
    "FULLY_REPARAMETERIZED",
    "NOT_REPARAMETERIZED",
    "Distribution",
]

_DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = [
    "batch_shape",
    "batch_shape_tensor",
    "cdf",
    "covariance",
    "cross_entropy",
    "entropy",
    "event_shape",
    "event_shape_tensor",
    "kl_divergence",
    "log_cdf",
    "log_prob",
    "log_survival_function",
    "mean",
    "mode",
    "prob",
    "sample",
    "stddev",
    "survival_function",
    "variance",
]


@six.add_metaclass(abc.ABCMeta)
class _BaseDistribution(object):
  """Abstract base class needed for resolving subclass hierarchy."""
  pass


def _copy_fn(fn):
  """Create a deep copy of fn.

  Args:
    fn: a callable

  Returns:
    A `FunctionType`: a deep copy of fn.

  Raises:
    TypeError: if `fn` is not a callable.
  """
  if not callable(fn):
    raise TypeError("fn is not callable: %s" % fn)
  # The blessed way to copy a function. copy.deepcopy fails to create a
  # non-reference copy. Since:
  #   types.FunctionType == type(lambda: None),
  # and the docstring for the function type states:
  #
  #   function(code, globals[, name[, argdefs[, closure]]])
  #
  #   Create a function object from a code object and a dictionary.
  #   ...
  #
  # Here we can use this to create a new function with the old function's
  # code, globals, closure, etc.
  return types.FunctionType(
      code=fn.__code__, globals=fn.__globals__,
      name=fn.__name__, argdefs=fn.__defaults__,
      closure=fn.__closure__)
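

# An illustrative sketch of `_copy_fn`'s behavior (hypothetical `f`): the
# returned copy is a distinct function object, so mutating its `__doc__`
# leaves the original untouched:
#
#   def f():
#     """Original."""
#   g = _copy_fn(f)
#   g.__doc__ = "Updated."  # f.__doc__ is still "Original."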


def _update_docstring(old_str, append_str):
  """Update old_str by inserting append_str just before the "Args:" section."""
  old_str = old_str or ""
  old_str_lines = old_str.split("\n")

  # Step 0: Prepend spaces to all lines of append_str. This is
  # necessary for correct markdown generation.
  append_str = "\n".join("    %s" % line for line in append_str.split("\n"))

  # Step 1: Find mention of "Args":
  has_args_ix = [
      ix for ix, line in enumerate(old_str_lines)
      if line.strip().lower() == "args:"]
  if has_args_ix:
    final_args_ix = has_args_ix[-1]
    return ("\n".join(old_str_lines[:final_args_ix])
            + "\n\n" + append_str + "\n\n"
            + "\n".join(old_str_lines[final_args_ix:]))
  else:
    return old_str + "\n\n" + append_str
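

# An illustrative sketch of `_update_docstring` (hypothetical strings):
#
#   _update_docstring("Title.\nArgs:\n  x: An arg.", "Extra note.")
#   # ==> "Title.\n\n    Extra note.\n\nArgs:\n  x: An arg."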


class _DistributionMeta(abc.ABCMeta):

  def __new__(mcs, classname, baseclasses, attrs):
    """Control the creation of subclasses of the Distribution class.

    The main purpose of this method is to properly propagate docstrings
    from private Distribution methods, like `_log_prob`, into their
    public wrappers as inherited by the Distribution base class
    (e.g. `log_prob`).

    Args:
      classname: The name of the subclass being created.
      baseclasses: A tuple of parent classes.
      attrs: A dict mapping new attributes to their values.

    Returns:
      The class object.

    Raises:
      TypeError: If `Distribution` is not a subclass of `BaseDistribution`, or
        the new class is derived via multiple inheritance and the first
        parent class is not a subclass of `BaseDistribution`.
      AttributeError: If `Distribution` does not implement e.g. `log_prob`.
      ValueError: If a `Distribution` public method lacks a docstring.
    """
    if not baseclasses:  # Nothing to be done for Distribution
      raise TypeError("Expected non-empty baseclass. Does Distribution "
                      "not subclass _BaseDistribution?")
    which_base = [
        base for base in baseclasses
        if base == _BaseDistribution or issubclass(base, Distribution)]
    base = which_base[0]
    if base == _BaseDistribution:  # Nothing to be done for Distribution
      return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs)
    if not issubclass(base, Distribution):
      raise TypeError("First parent class declared for %s must be "
                      "Distribution, but saw '%s'" % (classname, base.__name__))
    for attr in _DISTRIBUTION_PUBLIC_METHOD_WRAPPERS:
      special_attr = "_%s" % attr
      class_attr_value = attrs.get(attr, None)
      if attr in attrs:
        # The method is being overridden, do not update its docstring.
        continue
      base_attr_value = getattr(base, attr, None)
      if not base_attr_value:
        raise AttributeError(
            "Internal error: expected base class '%s' to implement method '%s'"
            % (base.__name__, attr))
      class_special_attr_value = attrs.get(special_attr, None)
      if class_special_attr_value is None:
        # No _special method available, no need to update the docstring.
        continue
      class_special_attr_docstring = tf_inspect.getdoc(class_special_attr_value)
      if not class_special_attr_docstring:
        # No docstring to append.
        continue
      class_attr_value = _copy_fn(base_attr_value)
      class_attr_docstring = tf_inspect.getdoc(base_attr_value)
      if class_attr_docstring is None:
        raise ValueError(
            "Expected base class fn to contain a docstring: %s.%s"
            % (base.__name__, attr))
      class_attr_value.__doc__ = _update_docstring(
          class_attr_value.__doc__,
          ("Additional documentation from `%s`:\n\n%s"
           % (classname, class_special_attr_docstring)))
      attrs[attr] = class_attr_value

    return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs)


@tf_export("distributions.ReparameterizationType")
class ReparameterizationType(object):
  """Instances of this class represent how sampling is reparameterized.

  Two static instances exist in the distributions library, signifying
  one of two possible properties for samples from a distribution:

  `FULLY_REPARAMETERIZED`: Samples from the distribution are fully
    reparameterized, and straight-through gradients are supported.

  `NOT_REPARAMETERIZED`: Samples from the distribution are not fully
    reparameterized, and straight-through gradients are either partially
    unsupported or not supported at all. In this case, for purposes of
    e.g. RL or variational inference, it is generally safest to wrap the
    sample results in a `stop_gradient` call and instead use policy
    gradients / surrogate losses.
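
  For example (a minimal sketch; assumes the `Normal` subclass, which is
  fully reparameterized):

  ```python
  dist = tf.distributions.Normal(loc=0., scale=1.)
  if dist.reparameterization_type == tf.distributions.FULLY_REPARAMETERIZED:
    # Gradients of a sample w.r.t. the parameters are well-defined.
    grads = tf.gradients(dist.sample(), [dist.loc, dist.scale])
  ```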
  """

  def __init__(self, rep_type):
    self._rep_type = rep_type

  def __repr__(self):
    return "<Reparameterization Type: %s>" % self._rep_type

  def __eq__(self, other):
    """Determine if this `ReparameterizationType` is equal to another.

    Since `ReparameterizationType` instances are constant static global
    instances, equality checks if two instances' id() values are equal.

    Args:
      other: Object to compare against.

    Returns:
      `self is other`.
    """
    return self is other


# Fully reparameterized distribution: samples from a fully
# reparameterized distribution support straight-through gradients with
# respect to all parameters.
FULLY_REPARAMETERIZED = ReparameterizationType("FULLY_REPARAMETERIZED")
tf_export("distributions.FULLY_REPARAMETERIZED").export_constant(
    __name__, "FULLY_REPARAMETERIZED")


# Not reparameterized distribution: samples from a non-
# reparameterized distribution do not support straight-through gradients for
# at least some of the parameters.
NOT_REPARAMETERIZED = ReparameterizationType("NOT_REPARAMETERIZED")
tf_export("distributions.NOT_REPARAMETERIZED").export_constant(
    __name__, "NOT_REPARAMETERIZED")


@six.add_metaclass(_DistributionMeta)
@tf_export("distributions.Distribution")
class Distribution(_BaseDistribution):
  """A generic probability distribution base class.

  `Distribution` is a base class for constructing and organizing properties
  (e.g., mean, variance) of random variables (e.g., Bernoulli, Gaussian).

  #### Subclassing

  Subclasses are expected to implement a leading-underscore version of the
  same-named function. The argument signature should be identical except for
  the omission of `name="..."`. For example, to enable `log_prob(value,
  name="log_prob")` a subclass should implement `_log_prob(value)`.

  Subclasses can append to public-level docstrings by providing
  docstrings for their method specializations. For example:

  ```python
  @util.AppendDocstring("Some other details.")
  def _log_prob(self, value):
    ...
  ```

  would add the string "Some other details." to the `log_prob` function
  docstring. This is implemented as a simple decorator to avoid the Python
  linter complaining about missing Args/Returns/Raises sections in the
  partial docstrings.

  #### Broadcasting, batching, and shapes

  All distributions support batches of independent distributions of that type.
  The batch shape is determined by broadcasting together the parameters.

  The shapes of arguments to `__init__`, `cdf`, `log_cdf`, `prob`, and
  `log_prob` reflect this broadcasting, as does the return value of `sample`
  and `sample_n`.

  `sample_n_shape = [n] + batch_shape + event_shape`, where `sample_n_shape` is
  the shape of the `Tensor` returned from `sample_n`, `n` is the number of
  samples, `batch_shape` defines how many independent distributions there are,
  and `event_shape` defines the shape of samples from each of those independent
  distributions. Samples are independent along the `batch_shape` dimensions,
  but not necessarily so along the `event_shape` dimensions (depending on the
  particulars of the underlying distribution).

  Using the `Uniform` distribution as an example:

  ```python
  minval = 3.0
  maxval = [[4.0, 6.0],
            [10.0, 12.0]]

  # Broadcasting:
  # This instance represents 4 Uniform distributions. Each has a lower bound at
  # 3.0 as the `minval` parameter was broadcast to match `maxval`'s shape.
  u = Uniform(minval, maxval)

  # `event_shape` is `TensorShape([])`.
  event_shape = u.event_shape
  # `event_shape_t` is a `Tensor` which will evaluate to [].
  event_shape_t = u.event_shape_tensor()

  # Sampling returns a sample per distribution. `samples` has shape
  # [5, 2, 2], which is [n] + batch_shape + event_shape, where n=5,
  # batch_shape=[2, 2], and event_shape=[].
  samples = u.sample_n(5)

  # The broadcasting holds across methods. Here we use `cdf` as an example. The
  # same holds for `log_cdf` and the likelihood functions.

  # `cum_prob_broadcast` has shape [2, 2] as the `value` argument was broadcast
  # to the shape of the `Uniform` instance.
  cum_prob_broadcast = u.cdf(4.0)

  # `cum_prob_per_dist` also has shape [2, 2], one entry per distribution. No
  # broadcasting occurred.
  cum_prob_per_dist = u.cdf([[4.0, 5.0],
                             [6.0, 7.0]])

  # INVALID as the `value` argument is not broadcastable to the distribution's
  # shape.
  cum_prob_invalid = u.cdf([4.0, 5.0, 6.0])
  ```

  #### Parameter values leading to undefined statistics or distributions

  Some distributions do not have well-defined statistics for all initialization
  parameter values. For example, the beta distribution is parameterized by
  positive real numbers `concentration1` and `concentration0`, and does not
  have a well-defined mode if `concentration1 < 1` or `concentration0 < 1`.

  The user is given the option of raising an exception or returning `NaN`.

  ```python
  a = tf.exp(tf.matmul(logits, weights_a))
  b = tf.exp(tf.matmul(logits, weights_b))

  # Will raise an exception if ANY batch member has a < 1 or b < 1.
  dist = distributions.Beta(a, b, allow_nan_stats=False)
  mode = dist.mode().eval()

  # Will return NaN for batch members with either a < 1 or b < 1.
  dist = distributions.Beta(a, b, allow_nan_stats=True)  # Default behavior.
  mode = dist.mode().eval()
  ```

  In all cases, an exception is raised if *invalid* parameters are passed, e.g.

  ```python
  # Will raise an exception if any Op is run.
  negative_a = -1.0 * a  # beta distribution by definition has a > 0.
  dist = distributions.Beta(negative_a, b, allow_nan_stats=True)
  dist.mean().eval()
  ```
  """

  def __init__(self,
               dtype,
               reparameterization_type,
               validate_args,
               allow_nan_stats,
               parameters=None,
               graph_parents=None,
               name=None):
    """Constructs the `Distribution`.

    **This is a private method for subclass use.**

    Args:
      dtype: The type of the event samples. `None` implies no type-enforcement.
      reparameterization_type: Instance of `ReparameterizationType`.
        If `distributions.FULLY_REPARAMETERIZED`, this
        `Distribution` can be reparameterized in terms of some standard
        distribution with a function whose Jacobian is constant for the support
        of the standard distribution. If `distributions.NOT_REPARAMETERIZED`,
        then no such reparameterization is available.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      parameters: Python `dict` of parameters used to instantiate this
        `Distribution`.
      graph_parents: Python `list` of graph prerequisites of this
        `Distribution`.
      name: Python `str` name prefixed to Ops created by this class. Default:
        subclass name.

    Raises:
      ValueError: if any member of `graph_parents` is `None` or not a `Tensor`.
    """
    graph_parents = [] if graph_parents is None else graph_parents
    for i, t in enumerate(graph_parents):
      if t is None or not tensor_util.is_tensor(t):
        raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
    self._dtype = dtype
    self._reparameterization_type = reparameterization_type
    self._allow_nan_stats = allow_nan_stats
    self._validate_args = validate_args
    self._parameters = parameters or {}
    self._graph_parents = graph_parents
    self._name = name or type(self).__name__

  @classmethod
  def param_shapes(cls, sample_shape, name="DistributionParamShapes"):
    """Shapes of parameters given the desired shape of a call to `sample()`.

    This is a class method that describes what key/value arguments are required
    to instantiate the given `Distribution` so that a particular shape is
    returned for that instance's call to `sample()`.

    Subclasses should override class method `_param_shapes`.

    Args:
      sample_shape: `Tensor` or python list/tuple. Desired shape of a call to
        `sample()`.
      name: name to prepend ops with.

    Returns:
      `dict` of parameter name to `Tensor` shapes.
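
    For example (a minimal sketch; assumes the `Normal` subclass, whose
    parameters are `loc` and `scale`):

    ```python
    shapes = tf.distributions.Normal.param_shapes([100, 2])
    # ==> {"loc": <1-D int32 Tensor with value [100, 2]>,
    #      "scale": <1-D int32 Tensor with value [100, 2]>}
    ```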
    """
    with ops.name_scope(name, values=[sample_shape]):
      return cls._param_shapes(sample_shape)

  @classmethod
  def param_static_shapes(cls, sample_shape):
    """param_shapes with static (i.e. `TensorShape`) shapes.

    This is a class method that describes what key/value arguments are required
    to instantiate the given `Distribution` so that a particular shape is
    returned for that instance's call to `sample()`. Assumes that the sample's
    shape is known statically.

    Subclasses should override class method `_param_shapes` to return
    constant-valued tensors when constant values are fed.

    Args:
      sample_shape: `TensorShape` or python list/tuple. Desired shape of a call
        to `sample()`.

    Returns:
      `dict` of parameter name to `TensorShape`.

    Raises:
      ValueError: if `sample_shape` is a `TensorShape` and is not fully defined.
    """
    if isinstance(sample_shape, tensor_shape.TensorShape):
      if not sample_shape.is_fully_defined():
        raise ValueError("TensorShape sample_shape must be fully defined")
      sample_shape = sample_shape.as_list()

    params = cls.param_shapes(sample_shape)

    static_params = {}
    for name, shape in params.items():
      static_shape = tensor_util.constant_value(shape)
      if static_shape is None:
        raise ValueError(
            "sample_shape must be a fully-defined TensorShape or list/tuple")
      static_params[name] = tensor_shape.TensorShape(static_shape)

    return static_params

  @staticmethod
  def _param_shapes(sample_shape):
    raise NotImplementedError("_param_shapes not implemented")

  @property
  def name(self):
    """Name prepended to all ops created by this `Distribution`."""
    return self._name

  @property
  def dtype(self):
    """The `DType` of `Tensor`s handled by this `Distribution`."""
    return self._dtype

  @property
  def parameters(self):
    """Dictionary of parameters used to instantiate this `Distribution`."""
    # Remove "self", "__class__", or other special variables. These can appear
    # if the subclass used `parameters = locals()`.
    return dict((k, v) for k, v in self._parameters.items()
                if not k.startswith("__") and k != "self")

  @property
  def reparameterization_type(self):
    """Describes how samples from the distribution are reparameterized.

    Currently this is one of the static instances
    `distributions.FULLY_REPARAMETERIZED`
    or `distributions.NOT_REPARAMETERIZED`.

    Returns:
      An instance of `ReparameterizationType`.
    """
    return self._reparameterization_type

  @property
  def allow_nan_stats(self):
    """Python `bool` describing behavior when a stat is undefined.

    Stats return +/- infinity when it makes sense. E.g., the variance of a
    Cauchy distribution is infinity. However, sometimes the statistic is
    undefined, e.g., if a distribution's pdf does not achieve a maximum within
    the support of the distribution, the mode is undefined. If the mean is
    undefined, then by definition the variance is undefined. E.g. the mean for
    Student's t for df = 1 is undefined (no clear way to say it is either + or
    - infinity), so the variance = E[(X - mean)**2] is also undefined.

    Returns:
      allow_nan_stats: Python `bool`.
    """
    return self._allow_nan_stats

  @property
  def validate_args(self):
    """Python `bool` indicating possibly expensive checks are enabled."""
    return self._validate_args

  def copy(self, **override_parameters_kwargs):
    """Creates a deep copy of the distribution.

    Note: the copied distribution may continue to depend on the original
    initialization arguments.

    Args:
      **override_parameters_kwargs: String/value dictionary of initialization
        arguments to override with new values.

    Returns:
      distribution: A new instance of `type(self)` initialized from the union
        of self.parameters and override_parameters_kwargs, i.e.,
        `dict(self.parameters, **override_parameters_kwargs)`.
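
    For example (a minimal sketch; assumes the `Normal` subclass):

    ```python
    n1 = tf.distributions.Normal(loc=0., scale=1.)
    n2 = n1.copy(scale=2.)  # Keeps `loc=0.`; overrides `scale`.
    ```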
    """
    parameters = dict(self.parameters, **override_parameters_kwargs)
    return type(self)(**parameters)

  def _batch_shape_tensor(self):
    raise NotImplementedError("batch_shape_tensor is not implemented")

  def batch_shape_tensor(self, name="batch_shape_tensor"):
    """Shape of a single sample from a single event index as a 1-D `Tensor`.

    The batch dimensions are indexes into independent, non-identical
    parameterizations of this distribution.

    Args:
      name: name to give to the op.

    Returns:
      batch_shape: `Tensor`.
    """
    with self._name_scope(name):
      if self.batch_shape.is_fully_defined():
        return ops.convert_to_tensor(self.batch_shape.as_list(),
                                     dtype=dtypes.int32,
                                     name="batch_shape")
      return self._batch_shape_tensor()

  def _batch_shape(self):
    return tensor_shape.TensorShape(None)

  @property
  def batch_shape(self):
    """Shape of a single sample from a single event index as a `TensorShape`.

    May be partially defined or unknown.

    The batch dimensions are indexes into independent, non-identical
    parameterizations of this distribution.

    Returns:
      batch_shape: `TensorShape`, possibly unknown.
    """
    return self._batch_shape()

  def _event_shape_tensor(self):
    raise NotImplementedError("event_shape_tensor is not implemented")

  def event_shape_tensor(self, name="event_shape_tensor"):
    """Shape of a single sample from a single batch as a 1-D int32 `Tensor`.

    Args:
      name: name to give to the op.

    Returns:
      event_shape: `Tensor`.
    """
    with self._name_scope(name):
      if self.event_shape.is_fully_defined():
        return ops.convert_to_tensor(self.event_shape.as_list(),
                                     dtype=dtypes.int32,
                                     name="event_shape")
      return self._event_shape_tensor()

  def _event_shape(self):
    return tensor_shape.TensorShape(None)

  @property
  def event_shape(self):
    """Shape of a single sample from a single batch as a `TensorShape`.

    May be partially defined or unknown.

    Returns:
      event_shape: `TensorShape`, possibly unknown.
    """
    return self._event_shape()

  def is_scalar_event(self, name="is_scalar_event"):
    """Indicates whether `event_shape == []`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      is_scalar_event: `bool` scalar `Tensor`.
    """
    with self._name_scope(name):
      return ops.convert_to_tensor(
          self._is_scalar_helper(self.event_shape, self.event_shape_tensor),
          name="is_scalar_event")

  def is_scalar_batch(self, name="is_scalar_batch"):
    """Indicates whether `batch_shape == []`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      is_scalar_batch: `bool` scalar `Tensor`.
    """
    with self._name_scope(name):
      return ops.convert_to_tensor(
          self._is_scalar_helper(self.batch_shape, self.batch_shape_tensor),
          name="is_scalar_batch")

  def _sample_n(self, n, seed=None):
    raise NotImplementedError("sample_n is not implemented")

  def _call_sample_n(self, sample_shape, seed, name, **kwargs):
    with self._name_scope(name, values=[sample_shape]):
      sample_shape = ops.convert_to_tensor(
          sample_shape, dtype=dtypes.int32, name="sample_shape")
      sample_shape, n = self._expand_sample_shape_to_vector(
          sample_shape, "sample_shape")
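      # Draw n = prod(sample_shape) samples with a single flat leading
      # dimension, then reshape that dimension back into `sample_shape`.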
      samples = self._sample_n(n, seed, **kwargs)
      batch_event_shape = array_ops.shape(samples)[1:]
      final_shape = array_ops.concat([sample_shape, batch_event_shape], 0)
      samples = array_ops.reshape(samples, final_shape)
      samples = self._set_sample_static_shape(samples, sample_shape)
      return samples

  def sample(self, sample_shape=(), seed=None, name="sample"):
    """Generate samples of the specified shape.

    Note that a call to `sample()` without arguments will generate a single
    sample.

    Args:
      sample_shape: 0-D or 1-D `int32` `Tensor`. Shape of the generated
        samples.
      seed: Python integer seed for RNG.
      name: name to give to the op.

    Returns:
      samples: a `Tensor` with prepended dimensions `sample_shape`.
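
    For example (a minimal sketch; assumes the `Normal` subclass):

    ```python
    dist = tf.distributions.Normal(loc=[0., 10.], scale=1.)  # batch_shape [2]
    x = dist.sample([5, 3])  # shape [5, 3, 2]: sample_shape + batch_shape.
    y = dist.sample()        # shape [2]: one draw per batch member.
    ```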
    """
    return self._call_sample_n(sample_shape, seed, name)

  def _log_prob(self, value):
    raise NotImplementedError("log_prob is not implemented")

  def _call_log_prob(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = ops.convert_to_tensor(value, name="value")
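      # Prefer the subclass's native `_log_prob`; if only `_prob` is
      # implemented, fall back to its logarithm.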
      try:
        return self._log_prob(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.log(self._prob(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def log_prob(self, value, name="log_prob"):
    """Log probability density/mass function.

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
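
    For example (a minimal sketch; assumes the `Normal` subclass; `value`
    broadcasts against the batch shape):

    ```python
    dist = tf.distributions.Normal(loc=[0., 10.], scale=1.)  # batch_shape [2]
    lp = dist.log_prob(0.)            # shape [2]; scalar value broadcast.
    lp = dist.log_prob([[0.], [1.]])  # shape [2, 2].
    ```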
    """
    return self._call_log_prob(value, name)

  def _prob(self, value):
    raise NotImplementedError("prob is not implemented")

  def _call_prob(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = ops.convert_to_tensor(value, name="value")
      try:
        return self._prob(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.exp(self._log_prob(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def prob(self, value, name="prob"):
    """Probability density/mass function.

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_prob(value, name)

  def _log_cdf(self, value):
    raise NotImplementedError("log_cdf is not implemented")

  def _call_log_cdf(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = ops.convert_to_tensor(value, name="value")
      try:
        return self._log_cdf(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.log(self._cdf(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def log_cdf(self, value, name="log_cdf"):
    """Log cumulative distribution function.

    Given random variable `X`, the log cumulative distribution function is:

    ```none
    log_cdf(x) := Log[ P[X <= x] ]
    ```

    Often, a numerical approximation can be used for `log_cdf(x)` that yields
    a more accurate answer than simply taking the logarithm of the `cdf` when
    `x << -1`.

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      logcdf: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_log_cdf(value, name)

  def _cdf(self, value):
    raise NotImplementedError("cdf is not implemented")

  def _call_cdf(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = ops.convert_to_tensor(value, name="value")
      try:
        return self._cdf(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.exp(self._log_cdf(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def cdf(self, value, name="cdf"):
    """Cumulative distribution function.

    Given random variable `X`, the cumulative distribution function `cdf` is:

    ```none
    cdf(x) := P[X <= x]
    ```

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      cdf: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_cdf(value, name)

  def _log_survival_function(self, value):
    raise NotImplementedError("log_survival_function is not implemented")

  def _call_log_survival_function(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = ops.convert_to_tensor(value, name="value")
      try:
        return self._log_survival_function(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.log1p(-self.cdf(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def log_survival_function(self, value, name="log_survival_function"):
    """Log survival function.

    Given random variable `X`, the survival function is defined as:

    ```none
    log_survival_function(x) = Log[ P[X > x] ]
                             = Log[ 1 - P[X <= x] ]
                             = Log[ 1 - cdf(x) ]
    ```

    Typically, different numerical approximations can be used for the log
    survival function, which are more accurate than evaluating
    `Log[ 1 - cdf(x) ]` directly when `x >> 1`.

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type
        `self.dtype`.
    """
    return self._call_log_survival_function(value, name)

  def _survival_function(self, value):
    raise NotImplementedError("survival_function is not implemented")

  def _call_survival_function(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = ops.convert_to_tensor(value, name="value")
      try:
        return self._survival_function(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return 1. - self.cdf(value, **kwargs)
        except NotImplementedError:
          raise original_exception

  def survival_function(self, value, name="survival_function"):
    """Survival function.

    Given random variable `X`, the survival function is defined as:

    ```none
    survival_function(x) = P[X > x]
                         = 1 - P[X <= x]
                         = 1 - cdf(x).
    ```

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type
        `self.dtype`.
    """
    return self._call_survival_function(value, name)

  def _entropy(self):
    raise NotImplementedError("entropy is not implemented")

  def entropy(self, name="entropy"):
    """Shannon entropy in nats."""
    with self._name_scope(name):
      return self._entropy()

  def _mean(self):
    raise NotImplementedError("mean is not implemented")

  def mean(self, name="mean"):
    """Mean."""
    with self._name_scope(name):
      return self._mean()

  def _quantile(self, value):
    raise NotImplementedError("quantile is not implemented")

  def _call_quantile(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = ops.convert_to_tensor(value, name="value")
      try:
        return self._quantile(value, **kwargs)
      except NotImplementedError as original_exception:
        raise original_exception

  def quantile(self, value, name="quantile"):
    """Quantile function. Aka "inverse cdf" or "percent point function".

    Given random variable `X` and `p in [0, 1]`, the `quantile` is:

    ```none
    quantile(p) := x such that P[X <= x] == p
    ```

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      quantile: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
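
    For example (a minimal sketch; assumes the `Normal` subclass, which
    implements `_quantile`):

    ```python
    dist = tf.distributions.Normal(loc=1., scale=2.)
    median = dist.quantile(0.5)  # ==> 1.0; the median equals `loc`.
    ```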
    """
    return self._call_quantile(value, name)

  def _variance(self):
    raise NotImplementedError("variance is not implemented")

  def variance(self, name="variance"):
    """Variance.

    Variance is defined as,

    ```none
    Var = E[(X - E[X])**2]
    ```

    where `X` is the random variable associated with this distribution, `E`
    denotes expectation, and `Var.shape = batch_shape + event_shape`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      variance: Floating-point `Tensor` with shape identical to
        `batch_shape + event_shape`, i.e., the same shape as `self.mean()`.
    """
    with self._name_scope(name):
      try:
        return self._variance()
      except NotImplementedError as original_exception:
        try:
          return math_ops.square(self._stddev())
        except NotImplementedError:
          raise original_exception

  def _stddev(self):
    raise NotImplementedError("stddev is not implemented")

  def stddev(self, name="stddev"):
    """Standard deviation.

    Standard deviation is defined as,

    ```none
    stddev = E[(X - E[X])**2]**0.5
    ```

    where `X` is the random variable associated with this distribution, `E`
    denotes expectation, and `stddev.shape = batch_shape + event_shape`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      stddev: Floating-point `Tensor` with shape identical to
        `batch_shape + event_shape`, i.e., the same shape as `self.mean()`.
    """
    with self._name_scope(name):
      try:
        return self._stddev()
      except NotImplementedError as original_exception:
        try:
          return math_ops.sqrt(self._variance())
        except NotImplementedError:
          raise original_exception

  def _covariance(self):
    raise NotImplementedError("covariance is not implemented")

  def covariance(self, name="covariance"):
    """Covariance.

    Covariance is (possibly) defined only for non-scalar-event distributions.

    For example, for a length-`k`, vector-valued distribution, it is calculated
    as,

    ```none
    Cov[i, j] = Covariance(X_i, X_j) = E[(X_i - E[X_i]) (X_j - E[X_j])]
    ```

    where `Cov` is a (batch of) `k x k` matrix, `0 <= (i, j) < k`, and `E`
    denotes expectation.

    Alternatively, for non-vector, multivariate distributions (e.g.,
    matrix-valued, Wishart), `Covariance` shall return a (batch of) matrices
    under some vectorization of the events, i.e.,

    ```none
    Cov[i, j] = Covariance(Vec(X)_i, Vec(X)_j) = [as above]
    ```

    where `Cov` is a (batch of) `k' x k'` matrices,
    `0 <= (i, j) < k' = reduce_prod(event_shape)`, and `Vec` is some function
    mapping indices of this distribution's event dimensions to indices of a
    length-`k'` vector.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      covariance: Floating-point `Tensor` with shape `[B1, ..., Bn, k', k']`
        where the first `n` dimensions are batch coordinates and
        `k' = reduce_prod(self.event_shape)`.
    """
    with self._name_scope(name):
      return self._covariance()

  def _mode(self):
    raise NotImplementedError("mode is not implemented")

  def mode(self, name="mode"):
    """Mode."""
    with self._name_scope(name):
      return self._mode()

  def _cross_entropy(self, other):
    return kullback_leibler.cross_entropy(
        self, other, allow_nan_stats=self.allow_nan_stats)

  def cross_entropy(self, other, name="cross_entropy"):
    """Computes the (Shannon) cross entropy.

    Denote this distribution (`self`) by `P` and the `other` distribution by
    `Q`. Assuming `P, Q` are absolutely continuous with respect to
    one another and permit densities `p(x) dr(x)` and `q(x) dr(x)`, the
    (Shannon) cross entropy is defined as:

    ```none
    H[P, Q] = E_p[-log q(X)] = -int_F p(x) log q(x) dr(x)
    ```

    where `F` denotes the support of the random variable `X ~ P`.

    Args:
      other: `tf.distributions.Distribution` instance.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      cross_entropy: `self.dtype` `Tensor` with shape `[B1, ..., Bn]`
        representing `n` different calculations of (Shannon) cross entropy.
    """
    with self._name_scope(name):
      return self._cross_entropy(other)

  def _kl_divergence(self, other):
    return kullback_leibler.kl_divergence(
        self, other, allow_nan_stats=self.allow_nan_stats)

  def kl_divergence(self, other, name="kl_divergence"):
    """Computes the Kullback--Leibler divergence.

    Denote this distribution (`self`) by `p` and the `other` distribution by
    `q`. Assuming `p, q` are absolutely continuous with respect to reference
    measure `r`, the KL divergence is defined as:

    ```none
    KL[p, q] = E_p[log(p(X)/q(X))]
             = -int_F p(x) log q(x) dr(x) + int_F p(x) log p(x) dr(x)
             = H[p, q] - H[p]
    ```

    where `F` denotes the support of the random variable `X ~ p`, `H[., .]`
    denotes (Shannon) cross entropy, and `H[.]` denotes (Shannon) entropy.

    Args:
      other: `tf.distributions.Distribution` instance.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      kl_divergence: `self.dtype` `Tensor` with shape `[B1, ..., Bn]`
        representing `n` different calculations of the Kullback-Leibler
        divergence.
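
    For example (a minimal sketch; assumes two scalar `Normal` instances,
    which have a registered analytic KL):

    ```python
    p = tf.distributions.Normal(loc=0., scale=1.)
    q = tf.distributions.Normal(loc=1., scale=2.)
    kl = p.kl_divergence(q)  # scalar Tensor; equals H[p, q] - H[p].
    ```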
    """
    with self._name_scope(name):
      return self._kl_divergence(other)

  @contextlib.contextmanager
  def _name_scope(self, name=None, values=None):
    """Helper function to standardize op scope."""
    with ops.name_scope(self.name):
      with ops.name_scope(name, values=(
          ([] if values is None else values) + self._graph_parents)) as scope:
        yield scope

  def _expand_sample_shape_to_vector(self, x, name):
    """Helper to `sample` which ensures input is 1-D."""
    x_static_val = tensor_util.constant_value(x)
    if x_static_val is None:
      prod = math_ops.reduce_prod(x)
    else:
      prod = np.prod(x_static_val, dtype=x.dtype.as_numpy_dtype())

    ndims = x.get_shape().ndims  # != sample_ndims
    if ndims is None:
      # Maybe expand_dims.
      ndims = array_ops.rank(x)
      expanded_shape = util.pick_vector(
          math_ops.equal(ndims, 0),
          np.array([1], dtype=np.int32), array_ops.shape(x))
      x = array_ops.reshape(x, expanded_shape)
    elif ndims == 0:
      # Definitely expand_dims.
      if x_static_val is not None:
        x = ops.convert_to_tensor(
            np.array([x_static_val], dtype=x.dtype.as_numpy_dtype()),
            name=name)
      else:
        x = array_ops.reshape(x, [1])
    elif ndims != 1:
      raise ValueError("Input is neither scalar nor vector.")

    return x, prod

  def _set_sample_static_shape(self, x, sample_shape):
    """Helper to `sample`; sets static shape info."""
    # Set shape hints.
    sample_shape = tensor_shape.TensorShape(
        tensor_util.constant_value(sample_shape))

    ndims = x.get_shape().ndims
    sample_ndims = sample_shape.ndims
    batch_ndims = self.batch_shape.ndims
    event_ndims = self.event_shape.ndims

    # Infer rank(x).
    if (ndims is None and
        sample_ndims is not None and
        batch_ndims is not None and
        event_ndims is not None):
      ndims = sample_ndims + batch_ndims + event_ndims
      x.set_shape([None] * ndims)

    # Infer sample shape.
    if ndims is not None and sample_ndims is not None:
      shape = sample_shape.concatenate([None] * (ndims - sample_ndims))
      x.set_shape(x.get_shape().merge_with(shape))

    # Infer event shape.
    if ndims is not None and event_ndims is not None:
      shape = tensor_shape.TensorShape(
          [None] * (ndims - event_ndims)).concatenate(self.event_shape)
      x.set_shape(x.get_shape().merge_with(shape))

    # Infer batch shape.
    if batch_ndims is not None:
      if ndims is not None:
        if sample_ndims is None and event_ndims is not None:
          sample_ndims = ndims - batch_ndims - event_ndims
        elif event_ndims is None and sample_ndims is not None:
          event_ndims = ndims - batch_ndims - sample_ndims
      if sample_ndims is not None and event_ndims is not None:
        shape = tensor_shape.TensorShape([None] * sample_ndims).concatenate(
            self.batch_shape).concatenate([None] * event_ndims)
        x.set_shape(x.get_shape().merge_with(shape))

    return x

  def _is_scalar_helper(self, static_shape, dynamic_shape_fn):
    """Implementation for `is_scalar_batch` and `is_scalar_event`."""
    if static_shape.ndims is not None:
      return static_shape.ndims == 0
    shape = dynamic_shape_fn()
    if (shape.get_shape().ndims is not None and
        shape.get_shape()[0].value is not None):
      # If the static_shape is correctly written then we should never execute
      # this branch. We keep it just in case there's some unimagined corner
      # case.
      return shape.get_shape().as_list() == [0]
    return math_ops.equal(array_ops.shape(shape)[0], 0)