# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Beta distribution class."""

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Beta",
    "BetaWithSoftplusConcentration",
]


_beta_sample_note = """Note: `x` must have dtype `self.dtype` and be in
`[0, 1].` It must have a shape compatible with `self.batch_shape()`."""


@tf_export(v1=["distributions.Beta"])
class Beta(distribution.Distribution):
  """Beta distribution.

  The Beta distribution is defined over the `(0, 1)` interval using parameters
  `concentration1` (aka "alpha") and `concentration0` (aka "beta").

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; alpha, beta) = x**(alpha - 1) (1 - x)**(beta - 1) / Z
  Z = Gamma(alpha) Gamma(beta) / Gamma(alpha + beta)
  ```

  where:

  * `concentration1 = alpha`,
  * `concentration0 = beta`,
  * `Z` is the normalization constant, and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).

  The concentration parameters represent mean total counts of a `1` or a `0`,
  i.e.,

  ```none
  concentration1 = alpha = mean * total_concentration
  concentration0 = beta  = (1. - mean) * total_concentration
  ```

  where `mean` in `(0, 1)` and `total_concentration` is a positive real number
  representing a mean `total_count = concentration1 + concentration0`.

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  Warning: The samples can be zero due to finite precision.
  This happens more often when some of the concentrations are very small.
  Make sure to round the samples to `np.finfo(dtype).tiny` before computing the
  density.

  Samples of this distribution are reparameterized (pathwise differentiable).
  The derivatives are computed using the approach described in
  (Figurnov et al., 2018).

  #### Examples

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Create a batch of three Beta distributions.
  alpha = [1, 2, 3]
  beta = [1, 2, 3]
  dist = tfd.Beta(alpha, beta)

  dist.sample([4, 5])  # Shape [4, 5, 3]

  # `x` has three batch entries, each with two samples.
  x = [[.1, .4, .5],
       [.2, .3, .5]]
  # Calculate the probability of each pair of samples under the corresponding
  # distribution in `dist`.
  dist.prob(x)         # Shape [2, 3]
  ```

  ```python
  # Create batch_shape=[2, 3] via parameter broadcast:
  alpha = [[1.], [2]]      # Shape [2, 1]
  beta = [3., 4, 5]        # Shape [3]
  dist = tfd.Beta(alpha, beta)

  # alpha broadcast as: [[1., 1, 1,],
  #                      [2, 2, 2]]
  # beta broadcast as:  [[3., 4, 5],
  #                      [3, 4, 5]]
  # batch_shape [2, 3]
  dist.sample([4, 5])  # Shape [4, 5, 2, 3]

  x = [.2, .3, .5]
  # x will be broadcast as [[.2, .3, .5],
  #                         [.2, .3, .5]],
  # thus matching batch_shape [2, 3].
  dist.prob(x)         # Shape [2, 3]
  ```

  Compute the gradients of samples w.r.t. the parameters:

  ```python
  alpha = tf.constant(1.0)
  beta = tf.constant(2.0)
  dist = tfd.Beta(alpha, beta)
  samples = dist.sample(5)  # Shape [5]
  loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
  # Unbiased stochastic gradients of the loss function
  grads = tf.gradients(loss, [alpha, beta])
  ```

  References:
    Implicit Reparameterization Gradients:
      [Figurnov et al., 2018]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
      ([pdf]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               concentration1=None,
               concentration0=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Beta"):
    """Initialize a batch of Beta distributions.

    Args:
      concentration1: Positive floating-point `Tensor` indicating mean
        number of successes; aka "alpha". Implies `self.dtype` and
        `self.batch_shape`, i.e.,
        `concentration1.shape = [N1, N2, ..., Nm] = self.batch_shape`.
      concentration0: Positive floating-point `Tensor` indicating mean
        number of failures; aka "beta". Otherwise has same semantics as
        `concentration1`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration1, concentration0]) as name:
      # Positivity assertions are attached only when `validate_args` is True.
      self._concentration1 = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration1, name="concentration1"),
          validate_args)
      self._concentration0 = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration0, name="concentration0"),
          validate_args)
      check_ops.assert_same_float_dtype([
          self._concentration1, self._concentration0])
      # The sum broadcasts the two parameters, so its shape is the batch shape.
      self._total_concentration = self._concentration1 + self._concentration0
    super(Beta, self).__init__(
        dtype=self._total_concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._concentration1,
                       self._concentration0,
                       self._total_concentration],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    return dict(zip(
        ["concentration1", "concentration0"],
        [ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2))

  @property
  def concentration1(self):
    """Concentration parameter associated with a `1` outcome."""
    return self._concentration1

  @property
  def concentration0(self):
    """Concentration parameter associated with a `0` outcome."""
    return self._concentration0

  @property
  def total_concentration(self):
    """Sum of concentration parameters."""
    return self._total_concentration

  def _batch_shape_tensor(self):
    return array_ops.shape(self.total_concentration)

  def _batch_shape(self):
    return self.total_concentration.get_shape()

  def _event_shape_tensor(self):
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.TensorShape([])

  def _sample_n(self, n, seed=None):
    # Broadcast each parameter against `total_concentration` so both Gamma
    # draws share the full batch shape.
    expanded_concentration1 = array_ops.ones_like(
        self.total_concentration, dtype=self.dtype) * self.concentration1
    expanded_concentration0 = array_ops.ones_like(
        self.total_concentration, dtype=self.dtype) * self.concentration0
    # Beta(a, b) is sampled as G1 / (G1 + G2) with G1 ~ Gamma(a, 1) and
    # G2 ~ Gamma(b, 1); this form is reparameterizable via the implicit
    # gradients of `random_gamma` (Figurnov et al., 2018).
    gamma1_sample = random_ops.random_gamma(
        shape=[n],
        alpha=expanded_concentration1,
        dtype=self.dtype,
        seed=seed)
    gamma2_sample = random_ops.random_gamma(
        shape=[n],
        alpha=expanded_concentration0,
        # Derive a distinct, deterministic seed for the second draw.
        seed=distribution_util.gen_new_seed(seed, "beta"),
        dtype=self.dtype)
    beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample)
    return beta_sample

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _log_prob(self, x):
    return self._log_unnormalized_prob(x) - self._log_normalization()

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _prob(self, x):
    return math_ops.exp(self._log_prob(x))

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _log_cdf(self, x):
    return math_ops.log(self._cdf(x))

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _cdf(self, x):
    # The Beta CDF is the regularized incomplete beta function I_x(a, b).
    return math_ops.betainc(self.concentration1, self.concentration0, x)

  def _log_unnormalized_prob(self, x):
    x = self._maybe_assert_valid_sample(x)
    # xlogy handles alpha == 1 at x == 0 (0 * log(0) -> 0); log1p(-x) is more
    # accurate than log(1 - x) for x near 0.
    return (math_ops.xlogy(self.concentration1 - 1., x) +
            (self.concentration0 - 1.) * math_ops.log1p(-x))  # pylint: disable=invalid-unary-operand-type

  def _log_normalization(self):
    # log Beta(a, b) = lgamma(a) + lgamma(b) - lgamma(a + b).
    return (math_ops.lgamma(self.concentration1)
            + math_ops.lgamma(self.concentration0)
            - math_ops.lgamma(self.total_concentration))

  def _entropy(self):
    return (
        self._log_normalization()
        - (self.concentration1 - 1.) * math_ops.digamma(self.concentration1)
        - (self.concentration0 - 1.) * math_ops.digamma(self.concentration0)
        + ((self.total_concentration - 2.) *
           math_ops.digamma(self.total_concentration)))

  def _mean(self):
    return self.concentration1 / self.total_concentration

  def _variance(self):
    return self._mean() * (1. - self._mean()) / (1. + self.total_concentration)

  @distribution_util.AppendDocstring(
      """Note: The mode is undefined when `concentration1 <= 1` or
      `concentration0 <= 1`. If `self.allow_nan_stats` is `True`, `NaN`
      is used for undefined modes. If `self.allow_nan_stats` is `False` an
      exception is raised when one or more modes are undefined.""")
  def _mode(self):
    mode = (self.concentration1 - 1.) / (self.total_concentration - 2.)
    if self.allow_nan_stats:
      nan = array_ops.fill(
          self.batch_shape_tensor(),
          np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
          name="nan")
      is_defined = math_ops.logical_and(self.concentration1 > 1.,
                                        self.concentration0 > 1.)
      return array_ops.where_v2(is_defined, mode, nan)
    return control_flow_ops.with_dependencies([
        check_ops.assert_less(
            array_ops.ones([], dtype=self.dtype),
            self.concentration1,
            message="Mode undefined for concentration1 <= 1."),
        check_ops.assert_less(
            array_ops.ones([], dtype=self.dtype),
            self.concentration0,
            message="Mode undefined for concentration0 <= 1.")
    ], mode)

  def _maybe_assert_valid_concentration(self, concentration, validate_args):
    """Checks the validity of a concentration parameter."""
    if not validate_args:
      return concentration
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(
            concentration,
            message="Concentration parameter must be positive."),
    ], concentration)

  def _maybe_assert_valid_sample(self, x):
    """Checks the validity of a sample."""
    if not self.validate_args:
      return x
    # The support is the open interval (0, 1).
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(x, message="sample must be positive"),
        check_ops.assert_less(
            x,
            array_ops.ones([], self.dtype),
            message="sample must be less than `1`."),
    ], x)


class BetaWithSoftplusConcentration(Beta):
  """Beta with softplus transform of `concentration1` and `concentration0`."""

  @deprecation.deprecated(
      "2019-01-01",
      # NOTE: message previously said `softplus(concentration2)`, which is not
      # a parameter of this class; the second argument is `concentration0`.
      "Use `tfd.Beta(tf.nn.softplus(concentration1), "
      "tf.nn.softplus(concentration0))` instead.",
      warn_once=True)
  def __init__(self,
               concentration1,
               concentration0,
               validate_args=False,
               allow_nan_stats=True,
               name="BetaWithSoftplusConcentration"):
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration1,
                                      concentration0]) as name:
      super(BetaWithSoftplusConcentration, self).__init__(
          concentration1=nn.softplus(concentration1,
                                     name="softplus_concentration1"),
          concentration0=nn.softplus(concentration0,
                                     name="softplus_concentration0"),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
    self._parameters = parameters


@kullback_leibler.RegisterKL(Beta, Beta)
def _kl_beta_beta(d1, d2, name=None):
  """Calculate the batchwise KL divergence KL(d1 || d2) with d1 and d2 Beta.

  Args:
    d1: instance of a Beta distribution object.
    d2: instance of a Beta distribution object.
    name: (optional) Name to use for created operations.
      default is "kl_beta_beta".

  Returns:
    Batchwise KL(d1 || d2)
  """
  def delta(fn, is_property=True):
    # Difference of the named property (or zero-arg method result) between the
    # two distributions: d2.fn - d1.fn.
    fn1 = getattr(d1, fn)
    fn2 = getattr(d2, fn)
    return (fn2 - fn1) if is_property else (fn2() - fn1())
  with ops.name_scope(name, "kl_beta_beta", values=[
      d1.concentration1,
      d1.concentration0,
      d1.total_concentration,
      d2.concentration1,
      d2.concentration0,
      d2.total_concentration,
  ]):
    return (delta("_log_normalization", is_property=False)
            - math_ops.digamma(d1.concentration1) * delta("concentration1")
            - math_ops.digamma(d1.concentration0) * delta("concentration0")
            + (math_ops.digamma(d1.total_concentration)
               * delta("total_concentration")))