# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Beta distribution class."""

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Beta",
    "BetaWithSoftplusConcentration",
]


_beta_sample_note = """Note: `x` must have dtype `self.dtype` and be in
`[0, 1]`. It must have a shape compatible with `self.batch_shape()`."""


@tf_export(v1=["distributions.Beta"])
class Beta(distribution.Distribution):
  """Beta distribution.

  The Beta distribution is defined over the `(0, 1)` interval using parameters
  `concentration1` (aka "alpha") and `concentration0` (aka "beta").

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; alpha, beta) = x**(alpha - 1) (1 - x)**(beta - 1) / Z
  Z = Gamma(alpha) Gamma(beta) / Gamma(alpha + beta)
  ```

  where:

  * `concentration1 = alpha`,
  * `concentration0 = beta`,
  * `Z` is the normalization constant, and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).

  The concentration parameters represent mean total counts of a `1` or a `0`,
  i.e.,

  ```none
  concentration1 = alpha = mean * total_concentration
  concentration0 = beta  = (1. - mean) * total_concentration
  ```

  where `mean` in `(0, 1)` and `total_concentration` is a positive real number
  representing a mean `total_count = concentration1 + concentration0`.

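  For example, `mean = 0.25` with `total_concentration = 4.` gives
  `concentration1 = 1.` and `concentration0 = 3.`:

  ```python
  mean = 0.25
  total_concentration = 4.
  concentration1 = mean * total_concentration         # alpha = 1.
  concentration0 = (1. - mean) * total_concentration  # beta  = 3.
  ```
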
  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  Warning: The samples can be zero due to finite precision.
  This happens more often when some of the concentrations are very small.
  Make sure to round the samples to `np.finfo(dtype).tiny` before computing the
  density.

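  A minimal sketch of that rounding (assuming `tfd = tfp.distributions`, as in
  the examples below):

  ```python
  import numpy as np
  import tensorflow as tf

  dist = tfd.Beta(0.1, 0.1)  # Small concentrations make zeros more likely.
  samples = dist.sample(100)
  # Lift exact zeros to the smallest positive float before scoring.
  samples = tf.maximum(samples, np.finfo(np.float32).tiny)
  log_prob = dist.log_prob(samples)
  ```
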
  Samples of this distribution are reparameterized (pathwise differentiable).
  The derivatives are computed using the approach described in
  (Figurnov et al., 2018).

  #### Examples

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Create a batch of three Beta distributions.
  alpha = [1, 2, 3]
  beta = [1, 2, 3]
  dist = tfd.Beta(alpha, beta)

  dist.sample([4, 5])  # Shape [4, 5, 3]

  # `x` has three batch entries, each with two samples.
  x = [[.1, .4, .5],
       [.2, .3, .5]]
  # Calculate the probability of each pair of samples under the corresponding
  # distribution in `dist`.
  dist.prob(x)         # Shape [2, 3]
  ```

  ```python
  # Create batch_shape=[2, 3] via parameter broadcast:
  alpha = [[1.], [2]]      # Shape [2, 1]
  beta = [3., 4, 5]        # Shape [3]
  dist = tfd.Beta(alpha, beta)

  # alpha broadcast as: [[1., 1, 1],
  #                      [2, 2, 2]]
  # beta broadcast as:  [[3., 4, 5],
  #                      [3, 4, 5]]
  # batch_shape [2, 3]
  dist.sample([4, 5])  # Shape [4, 5, 2, 3]

  x = [.2, .3, .5]
  # x will be broadcast as [[.2, .3, .5],
  #                         [.2, .3, .5]],
  # thus matching batch_shape [2, 3].
  dist.prob(x)         # Shape [2, 3]
  ```

  Compute the gradients of samples w.r.t. the parameters:

  ```python
  alpha = tf.constant(1.0)
  beta = tf.constant(2.0)
  dist = tfd.Beta(alpha, beta)
  samples = dist.sample(5)  # Shape [5]
  loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
  # Unbiased stochastic gradients of the loss function
  grads = tf.gradients(loss, [alpha, beta])
  ```

  References:
    Implicit Reparameterization Gradients:
      [Figurnov et al., 2018]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
      ([pdf]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               concentration1=None,
               concentration0=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Beta"):
    """Initialize a batch of Beta distributions.

    Args:
      concentration1: Positive floating-point `Tensor` indicating mean
        number of successes; aka "alpha". Implies `self.dtype` and
        `self.batch_shape`, i.e.,
        `concentration1.shape = [N1, N2, ..., Nm] = self.batch_shape`.
      concentration0: Positive floating-point `Tensor` indicating mean
        number of failures; aka "beta". Otherwise has same semantics as
        `concentration1`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration1, concentration0]) as name:
      self._concentration1 = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration1, name="concentration1"),
          validate_args)
      self._concentration0 = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration0, name="concentration0"),
          validate_args)
      check_ops.assert_same_float_dtype([
          self._concentration1, self._concentration0])
      self._total_concentration = self._concentration1 + self._concentration0
    super(Beta, self).__init__(
        dtype=self._total_concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._concentration1,
                       self._concentration0,
                       self._total_concentration],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    return dict(zip(
        ["concentration1", "concentration0"],
        [ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2))

  @property
  def concentration1(self):
    """Concentration parameter associated with a `1` outcome."""
    return self._concentration1

  @property
  def concentration0(self):
    """Concentration parameter associated with a `0` outcome."""
    return self._concentration0

  @property
  def total_concentration(self):
    """Sum of concentration parameters."""
    return self._total_concentration

  def _batch_shape_tensor(self):
    return array_ops.shape(self.total_concentration)

  def _batch_shape(self):
    return self.total_concentration.get_shape()

  def _event_shape_tensor(self):
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.TensorShape([])

  def _sample_n(self, n, seed=None):
    expanded_concentration1 = array_ops.ones_like(
        self.total_concentration, dtype=self.dtype) * self.concentration1
    expanded_concentration0 = array_ops.ones_like(
        self.total_concentration, dtype=self.dtype) * self.concentration0
    gamma1_sample = random_ops.random_gamma(
        shape=[n],
        alpha=expanded_concentration1,
        dtype=self.dtype,
        seed=seed)
    gamma2_sample = random_ops.random_gamma(
        shape=[n],
        alpha=expanded_concentration0,
        dtype=self.dtype,
        seed=distribution_util.gen_new_seed(seed, "beta"))
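    # If X ~ Gamma(concentration1, 1) and Y ~ Gamma(concentration0, 1) are
    # independent, then X / (X + Y) ~ Beta(concentration1, concentration0).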
    beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample)
    return beta_sample

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _log_prob(self, x):
    return self._log_unnormalized_prob(x) - self._log_normalization()

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _prob(self, x):
    return math_ops.exp(self._log_prob(x))

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _log_cdf(self, x):
    return math_ops.log(self._cdf(x))

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _cdf(self, x):
    return math_ops.betainc(self.concentration1, self.concentration0, x)

  def _log_unnormalized_prob(self, x):
    x = self._maybe_assert_valid_sample(x)
    return (math_ops.xlogy(self.concentration1 - 1., x) +
            (self.concentration0 - 1.) * math_ops.log1p(-x))  # pylint: disable=invalid-unary-operand-type

  def _log_normalization(self):
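    # Log of the beta function:
    # log(B(a, b)) = lgamma(a) + lgamma(b) - lgamma(a + b).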
    return (math_ops.lgamma(self.concentration1)
            + math_ops.lgamma(self.concentration0)
            - math_ops.lgamma(self.total_concentration))

  def _entropy(self):
    return (
        self._log_normalization()
        - (self.concentration1 - 1.) * math_ops.digamma(self.concentration1)
        - (self.concentration0 - 1.) * math_ops.digamma(self.concentration0)
        + ((self.total_concentration - 2.) *
           math_ops.digamma(self.total_concentration)))

  def _mean(self):
    return self._concentration1 / self._total_concentration

  def _variance(self):
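    # Var = (a * b) / ((a + b)**2 * (a + b + 1))
    #     = mean * (1 - mean) / (1 + total_concentration).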
    return self._mean() * (1. - self._mean()) / (1. + self.total_concentration)

  @distribution_util.AppendDocstring(
      """Note: The mode is `(concentration1 - 1) / (total_concentration - 2)`
      and is undefined when `concentration1 <= 1` or `concentration0 <= 1`.
      If `self.allow_nan_stats` is `True`, `NaN` is used for undefined modes.
      If `self.allow_nan_stats` is `False` an exception is raised when one or
      more modes are undefined.""")
  def _mode(self):
    mode = (self.concentration1 - 1.) / (self.total_concentration - 2.)
    if self.allow_nan_stats:
      nan = array_ops.fill(
          self.batch_shape_tensor(),
          np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
          name="nan")
      is_defined = math_ops.logical_and(self.concentration1 > 1.,
                                        self.concentration0 > 1.)
      return array_ops.where_v2(is_defined, mode, nan)
    return control_flow_ops.with_dependencies([
        check_ops.assert_less(
            array_ops.ones([], dtype=self.dtype),
            self.concentration1,
            message="Mode undefined for concentration1 <= 1."),
        check_ops.assert_less(
            array_ops.ones([], dtype=self.dtype),
            self.concentration0,
            message="Mode undefined for concentration0 <= 1.")
    ], mode)

  def _maybe_assert_valid_concentration(self, concentration, validate_args):
    """Checks the validity of a concentration parameter."""
    if not validate_args:
      return concentration
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(
            concentration,
            message="Concentration parameter must be positive."),
    ], concentration)

  def _maybe_assert_valid_sample(self, x):
    """Checks the validity of a sample."""
    if not self.validate_args:
      return x
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(x, message="sample must be positive"),
        check_ops.assert_less(
            x,
            array_ops.ones([], self.dtype),
            message="sample must be less than `1`."),
    ], x)


class BetaWithSoftplusConcentration(Beta):
  """Beta with softplus transform of `concentration1` and `concentration0`."""

  @deprecation.deprecated(
      "2019-01-01",
      "Use `tfd.Beta(tf.nn.softplus(concentration1), "
      "tf.nn.softplus(concentration0))` instead.",
      warn_once=True)
  def __init__(self,
               concentration1,
               concentration0,
               validate_args=False,
               allow_nan_stats=True,
               name="BetaWithSoftplusConcentration"):
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration1,
                                      concentration0]) as name:
      super(BetaWithSoftplusConcentration, self).__init__(
          concentration1=nn.softplus(concentration1,
                                     name="softplus_concentration1"),
          concentration0=nn.softplus(concentration0,
                                     name="softplus_concentration0"),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
    self._parameters = parameters


@kullback_leibler.RegisterKL(Beta, Beta)
def _kl_beta_beta(d1, d2, name=None):
  """Calculate the batchwise KL divergence KL(d1 || d2) with d1 and d2 Beta.

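  With `a1 = d1.concentration1`, `b1 = d1.concentration0` (and similarly for
  `d2`), `B` the beta function, and `digamma` the digamma function, the
  closed form computed below is:

  ```none
  KL(d1 || d2) = log(B(a2, b2) / B(a1, b1))
                 + (a1 - a2) * digamma(a1)
                 + (b1 - b2) * digamma(b1)
                 + ((a2 + b2) - (a1 + b1)) * digamma(a1 + b1)
  ```
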
  Args:
    d1: instance of a Beta distribution object.
    d2: instance of a Beta distribution object.
    name: (optional) Name to use for created operations.
      Default is "kl_beta_beta".

  Returns:
    Batchwise KL(d1 || d2)
  """
  def delta(fn, is_property=True):
    fn1 = getattr(d1, fn)
    fn2 = getattr(d2, fn)
    return (fn2 - fn1) if is_property else (fn2() - fn1())
  with ops.name_scope(name, "kl_beta_beta", values=[
      d1.concentration1,
      d1.concentration0,
      d1.total_concentration,
      d2.concentration1,
      d2.concentration0,
      d2.total_concentration,
  ]):
    return (delta("_log_normalization", is_property=False)
            - math_ops.digamma(d1.concentration1) * delta("concentration1")
            - math_ops.digamma(d1.concentration0) * delta("concentration0")
            + (math_ops.digamma(d1.total_concentration)
               * delta("total_concentration")))