• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""The Exponential distribution class."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops import nn
28from tensorflow.python.ops import random_ops
29from tensorflow.python.ops.distributions import gamma
30from tensorflow.python.util import deprecation
31from tensorflow.python.util.tf_export import tf_export
32
33
# Names exported by this module.
__all__ = ["Exponential", "ExponentialWithSoftplusRate"]
38
39
@tf_export(v1=["distributions.Exponential"])
class Exponential(gamma.Gamma):
  """Exponential distribution.

  The Exponential distribution is governed by a single event `rate`
  parameter.

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; lambda, x > 0) = exp(-lambda x) / Z
  Z = 1 / lambda
  ```

  where `rate = lambda` and `Z` is the normalizing constant.

  The Exponential distribution is the special case of the Gamma distribution
  with unit concentration, i.e.,

  ```python
  Exponential(rate) = Gamma(concentration=1., rate)
  ```

  The `rate` parameter acts as an "inverse scale", which can be intuited as,

  ```none
  X ~ Exponential(rate=1)
  Y = X / rate
  ```

  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               rate,
               validate_args=False,
               allow_nan_stats=True,
               name="Exponential"):
    """Construct Exponential distribution with parameter `rate`.

    Args:
      rate: Floating point tensor, equivalent to `1 / mean`. Must contain only
        positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Must stay the first statement: it snapshots exactly the constructor
    # arguments, before any further local variable is bound.
    parameters = dict(locals())
    # All statistics of the Exponential are defined for valid inputs, although
    # that does not hold for the parent class "Gamma". The original authors
    # note that forwarding `allow_nan_stats` to the parent can therefore add
    # asserts that are unnecessary here; the caller's value is nevertheless
    # passed through unchanged.
    with ops.name_scope(name, values=[rate]) as name:
      self._rate = ops.convert_to_tensor(rate, name="rate")
    # Scalar one in the dtype of `rate`: Exponential is Gamma with
    # concentration 1.
    unit_concentration = array_ops.ones([], dtype=self._rate.dtype)
    super(Exponential, self).__init__(
        concentration=unit_concentration,
        rate=self._rate,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
    # Replace whatever the parent constructor recorded with this class's own
    # constructor arguments.
    self._parameters = parameters
    self._graph_parents += [self._rate]

  @staticmethod
  def _param_shapes(sample_shape):
    """Return a dict mapping "rate" to `sample_shape` as an int32 tensor."""
    rate_shape = ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)
    return {"rate": rate_shape}

  @property
  def rate(self):
    """The `rate` tensor this distribution was constructed with."""
    return self._rate

  def _log_survival_function(self, value):
    """Return `_log_prob(value)` minus `log(rate)`."""
    log_density = self._log_prob(value)
    log_rate = math_ops.log(self._rate)
    return log_density - log_rate

  def _sample_n(self, n, seed=None):
    """Draw `n` samples as `-log(U) / rate` with `U` uniform.

    Args:
      n: Number of draws; becomes the leading dimension of the result.
      seed: Optional seed forwarded to `random_ops.random_uniform`.

    Returns:
      Tensor of shape `[n] + shape(rate)`.
    """
    draw_shape = array_ops.concat([[n], array_ops.shape(self._rate)], 0)
    # The uniform variates have to come from the open interval `(0, 1)`, not
    # `[0, 1)`. The lower bound is `np.finfo(...).tiny`, the smallest positive
    # "normal" number (one whose mantissa has an implicit leading 1). Positive
    # normal numbers x, y satisfy `x + y >= max(x, y)`, whereas a subnormal
    # lower bound (i.e., np.nextafter) could still let us sample 0.
    smallest_normal = np.finfo(self.dtype.as_numpy_dtype).tiny
    uniform_draws = random_ops.random_uniform(
        draw_shape,
        minval=smallest_normal,
        maxval=1.,
        seed=seed,
        dtype=self.dtype)
    neg_log_draws = -math_ops.log(uniform_draws)
    return neg_log_draws / self._rate
145
146
class ExponentialWithSoftplusRate(Exponential):
  """Exponential with softplus transform on `rate`.

  The `rate` argument is passed through `nn.softplus` before it is handed to
  the `Exponential` constructor.
  """

  @deprecation.deprecated(
      "2019-01-01",
      "Use `tfd.Exponential(tf.nn.softplus(rate))`.",
      warn_once=True)
  def __init__(self,
               rate,
               validate_args=False,
               allow_nan_stats=True,
               name="ExponentialWithSoftplusRate"):
    """Construct an Exponential distribution with rate `softplus(rate)`.

    Args:
      rate: Value to which `nn.softplus` is applied; the result is used as the
        `rate` of the underlying `Exponential`.
      validate_args: Python `bool`, default `False`. Forwarded unchanged to
        `Exponential`.
      allow_nan_stats: Python `bool`, default `True`. Forwarded unchanged to
        `Exponential`.
      name: Python `str` used for the name scope that wraps the softplus op
        and the parent constructor.
    """
    # Must stay the first statement: it snapshots exactly the constructor
    # arguments (including the untransformed `rate`).
    parameters = dict(locals())
    with ops.name_scope(name, values=[rate]) as name:
      super(ExponentialWithSoftplusRate, self).__init__(
          rate=nn.softplus(rate, name="softplus_rate"),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
    # The parent stored its own arguments (the softplus-transformed rate);
    # replace them with the arguments this constructor actually received.
    self._parameters = parameters
167