# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Beta distribution class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Beta",
    "BetaWithSoftplusConcentration",
]


_beta_sample_note = """Note: `x` must have dtype `self.dtype` and be in
`[0, 1]`. It must have a shape compatible with `self.batch_shape()`."""


@tf_export(v1=["distributions.Beta"])
class Beta(distribution.Distribution):
  """Beta distribution.

  The Beta distribution is defined over the `(0, 1)` interval using parameters
  `concentration1` (aka "alpha") and `concentration0` (aka "beta").

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; alpha, beta) = x**(alpha - 1) (1 - x)**(beta - 1) / Z
  Z = Gamma(alpha) Gamma(beta) / Gamma(alpha + beta)
  ```

  where:

  * `concentration1 = alpha`,
  * `concentration0 = beta`,
  * `Z` is the normalization constant, and
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).

  The concentration parameters represent mean total counts of a `1` or a `0`,
  i.e.,

  ```none
  concentration1 = alpha = mean * total_concentration
  concentration0 = beta  = (1. - mean) * total_concentration
  ```

  where `mean` is in `(0, 1)` and `total_concentration` is a positive real
  number representing a mean `total_count = concentration1 + concentration0`.
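
  For example, the two parameterizations describe the same distribution (an
  illustrative sketch, assuming `tfd = tfp.distributions` as in the examples
  below; `mean` and `total_concentration` are plain Python values here, not
  constructor arguments):

  ```python
  mean = 0.25
  total_concentration = 4.
  dist = tfd.Beta(concentration1=mean * total_concentration,
                  concentration0=(1. - mean) * total_concentration)
  # Same distribution as tfd.Beta(1., 3.); dist.mean() evaluates to 0.25.
  ```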

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  Warning: The samples can be zero due to finite precision.
  This happens more often when some of the concentrations are very small.
  Make sure to round the samples to `np.finfo(dtype).tiny` before computing the
  density.
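
  For instance, a guard against zero samples before evaluating the density
  might look like this (a minimal sketch, assuming `float32` and
  `tfd = tfp.distributions` as in the examples below):

  ```python
  import numpy as np
  dist = tfd.Beta(0.01, 0.01)  # Small concentrations; zero samples likely.
  samples = dist.sample(10)
  safe_samples = tf.maximum(samples, np.finfo(np.float32).tiny)
  log_prob = dist.log_prob(safe_samples)  # Finite, since inputs are > 0.
  ```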

  Samples of this distribution are reparameterized (pathwise differentiable).
  The derivatives are computed using the approach described in
  (Figurnov et al., 2018).

  #### Examples

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Create a batch of three Beta distributions.
  alpha = [1, 2, 3]
  beta = [1, 2, 3]
  dist = tfd.Beta(alpha, beta)

  dist.sample([4, 5])  # Shape [4, 5, 3]

  # `x` has three batch entries, each with two samples.
  x = [[.1, .4, .5],
       [.2, .3, .5]]
  # Calculate the probability of each pair of samples under the corresponding
  # distribution in `dist`.
  dist.prob(x)  # Shape [2, 3]
  ```

  ```python
  # Create batch_shape=[2, 3] via parameter broadcast:
  alpha = [[1.], [2]]  # Shape [2, 1]
  beta = [3., 4, 5]    # Shape [3]
  dist = tfd.Beta(alpha, beta)

  # alpha broadcast as: [[1., 1, 1],
  #                      [2, 2, 2]]
  # beta broadcast as:  [[3., 4, 5],
  #                      [3, 4, 5]]
  # batch_shape [2, 3]
  dist.sample([4, 5])  # Shape [4, 5, 2, 3]

  x = [.2, .3, .5]
  # x will be broadcast as [[.2, .3, .5],
  #                         [.2, .3, .5]],
  # thus matching batch_shape [2, 3].
  dist.prob(x)  # Shape [2, 3]
  ```

  Compute the gradients of samples w.r.t. the parameters:

  ```python
  alpha = tf.constant(1.0)
  beta = tf.constant(2.0)
  dist = tfd.Beta(alpha, beta)
  samples = dist.sample(5)  # Shape [5]
  loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
  # Unbiased stochastic gradients of the loss function
  grads = tf.gradients(loss, [alpha, beta])
  ```

  References:
    Implicit Reparameterization Gradients:
      [Figurnov et al., 2018]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
      ([pdf]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               concentration1=None,
               concentration0=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Beta"):
    """Initialize a batch of Beta distributions.

    Args:
      concentration1: Positive floating-point `Tensor` indicating mean
        number of successes; aka "alpha". Implies `self.dtype` and
        `self.batch_shape`, i.e.,
        `concentration1.shape = [N1, N2, ..., Nm] = self.batch_shape`.
      concentration0: Positive floating-point `Tensor` indicating mean
        number of failures; aka "beta". Otherwise has same semantics as
        `concentration1`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration1, concentration0]) as name:
      self._concentration1 = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration1, name="concentration1"),
          validate_args)
      self._concentration0 = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration0, name="concentration0"),
          validate_args)
      check_ops.assert_same_float_dtype([
          self._concentration1, self._concentration0])
      self._total_concentration = self._concentration1 + self._concentration0
    super(Beta, self).__init__(
        dtype=self._total_concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._concentration1,
                       self._concentration0,
                       self._total_concentration],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    return dict(zip(
        ["concentration1", "concentration0"],
        [ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2))

  @property
  def concentration1(self):
    """Concentration parameter associated with a `1` outcome."""
    return self._concentration1

  @property
  def concentration0(self):
    """Concentration parameter associated with a `0` outcome."""
    return self._concentration0

  @property
  def total_concentration(self):
    """Sum of concentration parameters."""
    return self._total_concentration

  def _batch_shape_tensor(self):
    return array_ops.shape(self.total_concentration)

  def _batch_shape(self):
    return self.total_concentration.get_shape()

  def _event_shape_tensor(self):
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.TensorShape([])

  def _sample_n(self, n, seed=None):
    expanded_concentration1 = array_ops.ones_like(
        self.total_concentration, dtype=self.dtype) * self.concentration1
    expanded_concentration0 = array_ops.ones_like(
        self.total_concentration, dtype=self.dtype) * self.concentration0
    gamma1_sample = random_ops.random_gamma(
        shape=[n],
        alpha=expanded_concentration1,
        dtype=self.dtype,
        seed=seed)
    gamma2_sample = random_ops.random_gamma(
        shape=[n],
        alpha=expanded_concentration0,
        dtype=self.dtype,
        seed=distribution_util.gen_new_seed(seed, "beta"))
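    # Gamma-ratio construction: if G1 ~ Gamma(concentration1, 1) and
    # G0 ~ Gamma(concentration0, 1) are independent, then
    # G1 / (G1 + G0) ~ Beta(concentration1, concentration0).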
    beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample)
    return beta_sample

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _log_prob(self, x):
    return self._log_unnormalized_prob(x) - self._log_normalization()

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _prob(self, x):
    return math_ops.exp(self._log_prob(x))

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _log_cdf(self, x):
    return math_ops.log(self._cdf(x))

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _cdf(self, x):
    return math_ops.betainc(self.concentration1, self.concentration0, x)

  def _log_unnormalized_prob(self, x):
    x = self._maybe_assert_valid_sample(x)
    return (math_ops.xlogy(self.concentration1 - 1., x) +
            (self.concentration0 - 1.) *
            math_ops.log1p(-x))  # pylint: disable=invalid-unary-operand-type

  def _log_normalization(self):
    return (math_ops.lgamma(self.concentration1)
            + math_ops.lgamma(self.concentration0)
            - math_ops.lgamma(self.total_concentration))

  def _entropy(self):
    return (
        self._log_normalization()
        - (self.concentration1 - 1.) * math_ops.digamma(self.concentration1)
        - (self.concentration0 - 1.) * math_ops.digamma(self.concentration0)
        + ((self.total_concentration - 2.) *
           math_ops.digamma(self.total_concentration)))

  def _mean(self):
    return self._concentration1 / self._total_concentration

  def _variance(self):
    return self._mean() * (1. - self._mean()) / (1. + self.total_concentration)

  @distribution_util.AppendDocstring(
      """Note: The mode is undefined when `concentration1 <= 1` or
      `concentration0 <= 1`. If `self.allow_nan_stats` is `True`, `NaN`
      is used for undefined modes. If `self.allow_nan_stats` is `False` an
      exception is raised when one or more modes are undefined.""")
  def _mode(self):
    mode = (self.concentration1 - 1.) / (self.total_concentration - 2.)
    if self.allow_nan_stats:
      nan = array_ops.fill(
          self.batch_shape_tensor(),
          np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
          name="nan")
      is_defined = math_ops.logical_and(self.concentration1 > 1.,
                                        self.concentration0 > 1.)
      return array_ops.where_v2(is_defined, mode, nan)
    return control_flow_ops.with_dependencies([
        check_ops.assert_less(
            array_ops.ones([], dtype=self.dtype),
            self.concentration1,
            message="Mode undefined for concentration1 <= 1."),
        check_ops.assert_less(
            array_ops.ones([], dtype=self.dtype),
            self.concentration0,
            message="Mode undefined for concentration0 <= 1.")
    ], mode)

  def _maybe_assert_valid_concentration(self, concentration, validate_args):
    """Checks the validity of a concentration parameter."""
    if not validate_args:
      return concentration
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(
            concentration,
            message="Concentration parameter must be positive."),
    ], concentration)

  def _maybe_assert_valid_sample(self, x):
    """Checks the validity of a sample."""
    if not self.validate_args:
      return x
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(x, message="sample must be positive"),
        check_ops.assert_less(
            x,
            array_ops.ones([], self.dtype),
            message="sample must be less than `1`."),
    ], x)


class BetaWithSoftplusConcentration(Beta):
  """Beta with softplus transform of `concentration1` and `concentration0`."""

  @deprecation.deprecated(
      "2019-01-01",
      "Use `tfd.Beta(tf.nn.softplus(concentration1), "
      "tf.nn.softplus(concentration0))` instead.",
      warn_once=True)
  def __init__(self,
               concentration1,
               concentration0,
               validate_args=False,
               allow_nan_stats=True,
               name="BetaWithSoftplusConcentration"):
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration1,
                                      concentration0]) as name:
      super(BetaWithSoftplusConcentration, self).__init__(
          concentration1=nn.softplus(concentration1,
                                     name="softplus_concentration1"),
          concentration0=nn.softplus(concentration0,
                                     name="softplus_concentration0"),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
    self._parameters = parameters


@kullback_leibler.RegisterKL(Beta, Beta)
def _kl_beta_beta(d1, d2, name=None):
  """Calculate the batchwise KL divergence KL(d1 || d2) with d1 and d2 Beta.

  Args:
    d1: instance of a Beta distribution object.
    d2: instance of a Beta distribution object.
    name: (optional) Name to use for created operations.
      Default is "kl_beta_beta".

  Returns:
    Batchwise KL(d1 || d2)
  """
  def delta(fn, is_property=True):
    fn1 = getattr(d1, fn)
    fn2 = getattr(d2, fn)
    return (fn2 - fn1) if is_property else (fn2() - fn1())
  with ops.name_scope(name, "kl_beta_beta", values=[
      d1.concentration1,
      d1.concentration0,
      d1.total_concentration,
      d2.concentration1,
      d2.concentration0,
      d2.total_concentration,
  ]):
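    # Closed form of KL(d1 || d2) between Betas, writing a = concentration1,
    # b = concentration0, and B for the beta function:
    #   ln B(a2, b2) - ln B(a1, b1)
    #   - (a2 - a1) digamma(a1) - (b2 - b1) digamma(b1)
    #   + ((a2 + b2) - (a1 + b1)) digamma(a1 + b1)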
    return (delta("_log_normalization", is_property=False)
            - math_ops.digamma(d1.concentration1) * delta("concentration1")
            - math_ops.digamma(d1.concentration0) * delta("concentration0")
            + (math_ops.digamma(d1.total_concentration)
               * delta("total_concentration")))
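

# An illustrative usage sketch of the registered KL (assumes a context in
# which these ops can be built; `kl_divergence` is the public dispatcher in
# `kullback_leibler`):
#
#   d1 = Beta(concentration1=1., concentration0=2.)
#   d2 = Beta(concentration1=2., concentration0=3.)
#   kl = kullback_leibler.kl_divergence(d1, d2)  # Dispatches to _kl_beta_beta.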