# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Beta distribution class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Beta",
    "BetaWithSoftplusConcentration",
]


_beta_sample_note = """Note: `x` must have dtype `self.dtype` and be in
`[0, 1]`. It must have a shape compatible with `self.batch_shape`."""


@tf_export(v1=["distributions.Beta"])
class Beta(distribution.Distribution):
  """Beta distribution.

  The Beta distribution is defined over the `(0, 1)` interval using parameters
  `concentration1` (aka "alpha") and `concentration0` (aka "beta").

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; alpha, beta) = x**(alpha - 1) (1 - x)**(beta - 1) / Z
  Z = Gamma(alpha) Gamma(beta) / Gamma(alpha + beta)
  ```

  where:

  * `concentration1 = alpha`,
  * `concentration0 = beta`,
  * `Z` is the normalization constant, and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).
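
  A quick numerical check of this density (a sketch; `scipy` is an assumption
  here, not a dependency of this module):

  ```python
  from scipy.special import gamma

  alpha, beta, x = 2., 3., 0.25
  Z = gamma(alpha) * gamma(beta) / gamma(alpha + beta)
  pdf = x**(alpha - 1.) * (1. - x)**(beta - 1.) / Z
  # pdf == 1.6875, agreeing with scipy.stats.beta(alpha, beta).pdf(x).
  ```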

  The concentration parameters represent mean total counts of a `1` or a `0`,
  i.e.,

  ```none
  concentration1 = alpha = mean * total_concentration
  concentration0 = beta  = (1. - mean) * total_concentration
  ```

  where `mean` in `(0, 1)` and `total_concentration` is a positive real number
  representing a mean `total_count = concentration1 + concentration0`.
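
  For instance, a `mean` of `0.25` with `total_concentration = 8.` gives
  `alpha = 2.` and `beta = 6.` (a sketch; `tfd` is the `tfp.distributions`
  alias used in the examples below):

  ```python
  mean, total_concentration = 0.25, 8.
  dist = tfd.Beta(concentration1=mean * total_concentration,
                  concentration0=(1. - mean) * total_concentration)
  # dist.mean() recovers 0.25.
  ```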

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  Warning: The samples can be zero due to finite precision.
  This happens more often when some of the concentrations are very small.
  Make sure to round the samples up to `np.finfo(dtype).tiny` before
  computing the density.
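
  One way to apply that guard (a sketch, assuming `import tensorflow as tf`,
  `import numpy as np`, float32 samples, and a `dist` built as in the examples
  below):

  ```python
  samples = dist.sample(10)
  # Clamp exact zeros up to the smallest positive normal float before
  # evaluating the density.
  safe_samples = tf.maximum(samples, np.finfo(np.float32).tiny)
  log_prob = dist.log_prob(safe_samples)
  ```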

  Samples of this distribution are reparameterized (pathwise differentiable).
  The derivatives are computed using the approach described in
  (Figurnov et al., 2018).

  #### Examples

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Create a batch of three Beta distributions. The parameters must be
  # floating point; integer lists would yield an `int32` dtype.
  alpha = [1., 2., 3.]
  beta = [1., 2., 3.]
  dist = tfd.Beta(alpha, beta)

  dist.sample([4, 5])  # Shape [4, 5, 3]

  # `x` holds two samples for each of the three batch entries.
  x = [[.1, .4, .5],
       [.2, .3, .5]]
  # Calculate the probability of each sample under the corresponding
  # distribution in `dist`.
  dist.prob(x)         # Shape [2, 3]
  ```

  ```python
  # Create batch_shape=[2, 3] via parameter broadcast:
  alpha = [[1.], [2.]]     # Shape [2, 1]
  beta = [3., 4., 5.]      # Shape [3]
  dist = tfd.Beta(alpha, beta)

  # alpha broadcast as: [[1., 1., 1.],
  #                      [2., 2., 2.]]
  # beta broadcast as:  [[3., 4., 5.],
  #                      [3., 4., 5.]]
  # batch_shape [2, 3]
  dist.sample([4, 5])  # Shape [4, 5, 2, 3]

  x = [.2, .3, .5]
  # x will be broadcast as [[.2, .3, .5],
  #                         [.2, .3, .5]],
  # thus matching batch_shape [2, 3].
  dist.prob(x)         # Shape [2, 3]
  ```

  Compute the gradients of samples w.r.t. the parameters:

  ```python
  alpha = tf.constant(1.0)
  beta = tf.constant(2.0)
  dist = tfd.Beta(alpha, beta)
  samples = dist.sample(5)  # Shape [5]
  loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
  # Unbiased stochastic gradients of the loss function
  grads = tf.gradients(loss, [alpha, beta])
  ```

  References:
    Implicit Reparameterization Gradients:
      [Figurnov et al., 2018]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
      ([pdf]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               concentration1=None,
               concentration0=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Beta"):
    """Initialize a batch of Beta distributions.

    Args:
      concentration1: Positive floating-point `Tensor` indicating mean
        number of successes; aka "alpha". Implies `self.dtype` and
        `self.batch_shape`, i.e.,
        `concentration1.shape = [N1, N2, ..., Nm] = self.batch_shape`.
      concentration0: Positive floating-point `Tensor` indicating mean
        number of failures; aka "beta". Otherwise has same semantics as
        `concentration1`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration1, concentration0]) as name:
      self._concentration1 = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration1, name="concentration1"),
          validate_args)
      self._concentration0 = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration0, name="concentration0"),
          validate_args)
      check_ops.assert_same_float_dtype([
          self._concentration1, self._concentration0])
      self._total_concentration = self._concentration1 + self._concentration0
    super(Beta, self).__init__(
        dtype=self._total_concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._concentration1,
                       self._concentration0,
                       self._total_concentration],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    return dict(zip(
        ["concentration1", "concentration0"],
        [ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2))

  @property
  def concentration1(self):
    """Concentration parameter associated with a `1` outcome."""
    return self._concentration1

  @property
  def concentration0(self):
    """Concentration parameter associated with a `0` outcome."""
    return self._concentration0

  @property
  def total_concentration(self):
    """Sum of concentration parameters."""
    return self._total_concentration

  def _batch_shape_tensor(self):
    return array_ops.shape(self.total_concentration)

  def _batch_shape(self):
    return self.total_concentration.get_shape()

  def _event_shape_tensor(self):
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.TensorShape([])

  def _sample_n(self, n, seed=None):
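    # Broadcast each concentration against the total-concentration shape so
    # that both gamma draws below see the full batch shape.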
    expanded_concentration1 = array_ops.ones_like(
        self.total_concentration, dtype=self.dtype) * self.concentration1
    expanded_concentration0 = array_ops.ones_like(
        self.total_concentration, dtype=self.dtype) * self.concentration0
    gamma1_sample = random_ops.random_gamma(
        shape=[n],
        alpha=expanded_concentration1,
        dtype=self.dtype,
        seed=seed)
    gamma2_sample = random_ops.random_gamma(
        shape=[n],
        alpha=expanded_concentration0,
        dtype=self.dtype,
        seed=distribution_util.gen_new_seed(seed, "beta"))
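    # If X ~ Gamma(concentration1, 1) and Y ~ Gamma(concentration0, 1) are
    # independent, then X / (X + Y) ~ Beta(concentration1, concentration0).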
    beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample)
    return beta_sample

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _log_prob(self, x):
    return self._log_unnormalized_prob(x) - self._log_normalization()

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _prob(self, x):
    return math_ops.exp(self._log_prob(x))

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _log_cdf(self, x):
    return math_ops.log(self._cdf(x))

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _cdf(self, x):
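    # The Beta CDF is the regularized incomplete beta function I_x(a, b),
    # which is what `betainc` computes.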
    return math_ops.betainc(self.concentration1, self.concentration0, x)

  def _log_unnormalized_prob(self, x):
    x = self._maybe_assert_valid_sample(x)
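    # xlogy and log1p stay accurate as x approaches 0 or 1, where naive
    # log(x) and log(1 - x) would lose precision.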
    return (math_ops.xlogy(self.concentration1 - 1., x) +
            (self.concentration0 - 1.) * math_ops.log1p(-x))  # pylint: disable=invalid-unary-operand-type

  def _log_normalization(self):
    return (math_ops.lgamma(self.concentration1)
            + math_ops.lgamma(self.concentration0)
            - math_ops.lgamma(self.total_concentration))

  def _entropy(self):
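    # H = log B(a, b) - (a - 1) digamma(a) - (b - 1) digamma(b)
    #     + (a + b - 2) digamma(a + b),
    # where a = concentration1, b = concentration0, and B is the beta
    # function; log B(a, b) is `_log_normalization`.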
    return (
        self._log_normalization()
        - (self.concentration1 - 1.) * math_ops.digamma(self.concentration1)
        - (self.concentration0 - 1.) * math_ops.digamma(self.concentration0)
        + ((self.total_concentration - 2.) *
           math_ops.digamma(self.total_concentration)))

  def _mean(self):
    return self._concentration1 / self._total_concentration

  def _variance(self):
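    # Var = mean * (1 - mean) / (1 + total_concentration), i.e. the usual
    # alpha * beta / ((alpha + beta)**2 * (alpha + beta + 1)) rewritten in
    # mean/total-concentration form.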
    return self._mean() * (1. - self._mean()) / (1. + self.total_concentration)

  @distribution_util.AppendDocstring(
      """Note: The mode is undefined when `concentration1 <= 1` or
      `concentration0 <= 1`. If `self.allow_nan_stats` is `True`, `NaN`
      is used for undefined modes. If `self.allow_nan_stats` is `False` an
      exception is raised when one or more modes are undefined.""")
  def _mode(self):
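    # For concentration1 > 1 and concentration0 > 1 the density is unimodal
    # with mode (concentration1 - 1) / (total_concentration - 2).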
    mode = (self.concentration1 - 1.) / (self.total_concentration - 2.)
    if self.allow_nan_stats:
      nan = array_ops.fill(
          self.batch_shape_tensor(),
          np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
          name="nan")
      is_defined = math_ops.logical_and(self.concentration1 > 1.,
                                        self.concentration0 > 1.)
      return array_ops.where_v2(is_defined, mode, nan)
    return control_flow_ops.with_dependencies([
        check_ops.assert_less(
            array_ops.ones([], dtype=self.dtype),
            self.concentration1,
            message="Mode undefined for concentration1 <= 1."),
        check_ops.assert_less(
            array_ops.ones([], dtype=self.dtype),
            self.concentration0,
            message="Mode undefined for concentration0 <= 1.")
    ], mode)

  def _maybe_assert_valid_concentration(self, concentration, validate_args):
    """Checks the validity of a concentration parameter."""
    if not validate_args:
      return concentration
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(
            concentration,
            message="Concentration parameter must be positive."),
    ], concentration)

  def _maybe_assert_valid_sample(self, x):
    """Checks the validity of a sample."""
    if not self.validate_args:
      return x
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(x, message="sample must be positive"),
        check_ops.assert_less(
            x,
            array_ops.ones([], self.dtype),
            message="sample must be less than `1`."),
    ], x)


class BetaWithSoftplusConcentration(Beta):
  """Beta with softplus transform of `concentration1` and `concentration0`."""

  @deprecation.deprecated(
      "2019-01-01",
      "Use `tfd.Beta(tf.nn.softplus(concentration1), "
      "tf.nn.softplus(concentration0))` instead.",
      warn_once=True)
  def __init__(self,
               concentration1,
               concentration0,
               validate_args=False,
               allow_nan_stats=True,
               name="BetaWithSoftplusConcentration"):
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration1,
                                      concentration0]) as name:
      super(BetaWithSoftplusConcentration, self).__init__(
          concentration1=nn.softplus(concentration1,
                                     name="softplus_concentration1"),
          concentration0=nn.softplus(concentration0,
                                     name="softplus_concentration0"),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
    self._parameters = parameters


@kullback_leibler.RegisterKL(Beta, Beta)
def _kl_beta_beta(d1, d2, name=None):
  """Calculate the batchwise KL divergence KL(d1 || d2) with d1 and d2 Beta.

  Args:
    d1: Instance of a Beta distribution object.
    d2: Instance of a Beta distribution object.
    name: (optional) Name to use for created operations.
      Default is "kl_beta_beta".

  Returns:
    Batchwise KL(d1 || d2).
  """
  def delta(fn, is_property=True):
    fn1 = getattr(d1, fn)
    fn2 = getattr(d2, fn)
    return (fn2 - fn1) if is_property else (fn2() - fn1())
  with ops.name_scope(name, "kl_beta_beta", values=[
      d1.concentration1,
      d1.concentration0,
      d1.total_concentration,
      d2.concentration1,
      d2.concentration0,
      d2.total_concentration,
  ]):
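    # Closed form, with a = concentration1, b = concentration0:
    # KL(d1 || d2) = log B(a2, b2) - log B(a1, b1)
    #                + (a1 - a2) digamma(a1) + (b1 - b2) digamma(b1)
    #                + ((a2 + b2) - (a1 + b1)) digamma(a1 + b1).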
    return (delta("_log_normalization", is_property=False)
            - math_ops.digamma(d1.concentration1) * delta("concentration1")
            - math_ops.digamma(d1.concentration0) * delta("concentration0")
            + (math_ops.digamma(d1.total_concentration)
               * delta("total_concentration")))