1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""The Dirichlet distribution class.""" 16 17import numpy as np 18 19from tensorflow.python.framework import ops 20from tensorflow.python.ops import array_ops 21from tensorflow.python.ops import check_ops 22from tensorflow.python.ops import control_flow_ops 23from tensorflow.python.ops import math_ops 24from tensorflow.python.ops import random_ops 25from tensorflow.python.ops import special_math_ops 26from tensorflow.python.ops.distributions import distribution 27from tensorflow.python.ops.distributions import kullback_leibler 28from tensorflow.python.ops.distributions import util as distribution_util 29from tensorflow.python.util import deprecation 30from tensorflow.python.util.tf_export import tf_export 31 32 33__all__ = [ 34 "Dirichlet", 35] 36 37 38_dirichlet_sample_note = """Note: `value` must be a non-negative tensor with 39dtype `self.dtype` and be in the `(self.event_shape() - 1)`-simplex, i.e., 40`tf.reduce_sum(value, -1) = 1`. It must have a shape compatible with 41`self.batch_shape() + self.event_shape()`.""" 42 43 44@tf_export(v1=["distributions.Dirichlet"]) 45class Dirichlet(distribution.Distribution): 46 """Dirichlet distribution. 47 48 The Dirichlet distribution is defined over the 49 [`(k-1)`-simplex](https://en.wikipedia.org/wiki/Simplex) using a positive, 50 length-`k` vector `concentration` (`k > 1`). The Dirichlet is identically the 51 Beta distribution when `k = 2`. 52 53 #### Mathematical Details 54 55 The Dirichlet is a distribution over the open `(k-1)`-simplex, i.e., 56 57 ```none 58 S^{k-1} = { (x_0, ..., x_{k-1}) in R^k : sum_j x_j = 1 and all_j x_j > 0 }. 59 ``` 60 61 The probability density function (pdf) is, 62 63 ```none 64 pdf(x; alpha) = prod_j x_j**(alpha_j - 1) / Z 65 Z = prod_j Gamma(alpha_j) / Gamma(sum_j alpha_j) 66 ``` 67 68 where: 69 70 * `x in S^{k-1}`, i.e., the `(k-1)`-simplex, 71 * `concentration = alpha = [alpha_0, ..., alpha_{k-1}]`, `alpha_j > 0`, 72 * `Z` is the normalization constant aka the [multivariate beta function]( 73 https://en.wikipedia.org/wiki/Beta_function#Multivariate_beta_function), 74 and, 75 * `Gamma` is the [gamma function]( 76 https://en.wikipedia.org/wiki/Gamma_function). 77 78 The `concentration` represents mean total counts of class occurrence, i.e., 79 80 ```none 81 concentration = alpha = mean * total_concentration 82 ``` 83 84 where `mean` in `S^{k-1}` and `total_concentration` is a positive real number 85 representing a mean total count. 86 87 Distribution parameters are automatically broadcast in all functions; see 88 examples for details. 89 90 Warning: Some components of the samples can be zero due to finite precision. 91 This happens more often when some of the concentrations are very small. 92 Make sure to round the samples to `np.finfo(dtype).tiny` before computing the 93 density. 94 95 Samples of this distribution are reparameterized (pathwise differentiable). 96 The derivatives are computed using the approach described in 97 (Figurnov et al., 2018). 98 99 #### Examples 100 101 ```python 102 import tensorflow_probability as tfp 103 tfd = tfp.distributions 104 105 # Create a single trivariate Dirichlet, with the 3rd class being three times 106 # more frequent than the first. I.e., batch_shape=[], event_shape=[3]. 107 alpha = [1., 2, 3] 108 dist = tfd.Dirichlet(alpha) 109 110 dist.sample([4, 5]) # shape: [4, 5, 3] 111 112 # x has one sample, one batch, three classes: 113 x = [.2, .3, .5] # shape: [3] 114 dist.prob(x) # shape: [] 115 116 # x has two samples from one batch: 117 x = [[.1, .4, .5], 118 [.2, .3, .5]] 119 dist.prob(x) # shape: [2] 120 121 # alpha will be broadcast to shape [5, 7, 3] to match x. 122 x = [[...]] # shape: [5, 7, 3] 123 dist.prob(x) # shape: [5, 7] 124 ``` 125 126 ```python 127 # Create batch_shape=[2], event_shape=[3]: 128 alpha = [[1., 2, 3], 129 [4, 5, 6]] # shape: [2, 3] 130 dist = tfd.Dirichlet(alpha) 131 132 dist.sample([4, 5]) # shape: [4, 5, 2, 3] 133 134 x = [.2, .3, .5] 135 # x will be broadcast as [[.2, .3, .5], 136 # [.2, .3, .5]], 137 # thus matching batch_shape [2, 3]. 138 dist.prob(x) # shape: [2] 139 ``` 140 141 Compute the gradients of samples w.r.t. the parameters: 142 143 ```python 144 alpha = tf.constant([1.0, 2.0, 3.0]) 145 dist = tfd.Dirichlet(alpha) 146 samples = dist.sample(5) # Shape [5, 3] 147 loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function 148 # Unbiased stochastic gradients of the loss function 149 grads = tf.gradients(loss, alpha) 150 ``` 151 152 References: 153 Implicit Reparameterization Gradients: 154 [Figurnov et al., 2018] 155 (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients) 156 ([pdf] 157 (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf)) 158 """ 159 160 @deprecation.deprecated( 161 "2019-01-01", 162 "The TensorFlow Distributions library has moved to " 163 "TensorFlow Probability " 164 "(https://github.com/tensorflow/probability). You " 165 "should update all references to use `tfp.distributions` " 166 "instead of `tf.distributions`.", 167 warn_once=True) 168 def __init__(self, 169 concentration, 170 validate_args=False, 171 allow_nan_stats=True, 172 name="Dirichlet"): 173 """Initialize a batch of Dirichlet distributions. 174 175 Args: 176 concentration: Positive floating-point `Tensor` indicating mean number 177 of class occurrences; aka "alpha". Implies `self.dtype`, and 178 `self.batch_shape`, `self.event_shape`, i.e., if 179 `concentration.shape = [N1, N2, ..., Nm, k]` then 180 `batch_shape = [N1, N2, ..., Nm]` and 181 `event_shape = [k]`. 182 validate_args: Python `bool`, default `False`. When `True` distribution 183 parameters are checked for validity despite possibly degrading runtime 184 performance. When `False` invalid inputs may silently render incorrect 185 outputs. 186 allow_nan_stats: Python `bool`, default `True`. When `True`, statistics 187 (e.g., mean, mode, variance) use the value "`NaN`" to indicate the 188 result is undefined. When `False`, an exception is raised if one or 189 more of the statistic's batch members are undefined. 190 name: Python `str` name prefixed to Ops created by this class. 191 """ 192 parameters = dict(locals()) 193 with ops.name_scope(name, values=[concentration]) as name: 194 self._concentration = self._maybe_assert_valid_concentration( 195 ops.convert_to_tensor(concentration, name="concentration"), 196 validate_args) 197 self._total_concentration = math_ops.reduce_sum(self._concentration, -1) 198 super(Dirichlet, self).__init__( 199 dtype=self._concentration.dtype, 200 validate_args=validate_args, 201 allow_nan_stats=allow_nan_stats, 202 reparameterization_type=distribution.FULLY_REPARAMETERIZED, 203 parameters=parameters, 204 graph_parents=[self._concentration, 205 self._total_concentration], 206 name=name) 207 208 @property 209 def concentration(self): 210 """Concentration parameter; expected counts for that coordinate.""" 211 return self._concentration 212 213 @property 214 def total_concentration(self): 215 """Sum of last dim of concentration parameter.""" 216 return self._total_concentration 217 218 def _batch_shape_tensor(self): 219 return array_ops.shape(self.total_concentration) 220 221 def _batch_shape(self): 222 return self.total_concentration.get_shape() 223 224 def _event_shape_tensor(self): 225 return array_ops.shape(self.concentration)[-1:] 226 227 def _event_shape(self): 228 return self.concentration.get_shape().with_rank_at_least(1)[-1:] 229 230 def _sample_n(self, n, seed=None): 231 gamma_sample = random_ops.random_gamma( 232 shape=[n], 233 alpha=self.concentration, 234 dtype=self.dtype, 235 seed=seed) 236 return gamma_sample / math_ops.reduce_sum(gamma_sample, -1, keepdims=True) 237 238 @distribution_util.AppendDocstring(_dirichlet_sample_note) 239 def _log_prob(self, x): 240 return self._log_unnormalized_prob(x) - self._log_normalization() 241 242 @distribution_util.AppendDocstring(_dirichlet_sample_note) 243 def _prob(self, x): 244 return math_ops.exp(self._log_prob(x)) 245 246 def _log_unnormalized_prob(self, x): 247 x = self._maybe_assert_valid_sample(x) 248 return math_ops.reduce_sum(math_ops.xlogy(self.concentration - 1., x), -1) 249 250 def _log_normalization(self): 251 return special_math_ops.lbeta(self.concentration) 252 253 def _entropy(self): 254 k = math_ops.cast(self.event_shape_tensor()[0], self.dtype) 255 return ( 256 self._log_normalization() 257 + ((self.total_concentration - k) 258 * math_ops.digamma(self.total_concentration)) 259 - math_ops.reduce_sum( 260 (self.concentration - 1.) * math_ops.digamma(self.concentration), 261 axis=-1)) 262 263 def _mean(self): 264 return self.concentration / self.total_concentration[..., array_ops.newaxis] 265 266 def _covariance(self): 267 x = self._variance_scale_term() * self._mean() 268 # pylint: disable=invalid-unary-operand-type 269 return array_ops.matrix_set_diag( 270 -math_ops.matmul( 271 x[..., array_ops.newaxis], 272 x[..., array_ops.newaxis, :]), # outer prod 273 self._variance()) 274 275 def _variance(self): 276 scale = self._variance_scale_term() 277 x = scale * self._mean() 278 return x * (scale - x) 279 280 def _variance_scale_term(self): 281 """Helper to `_covariance` and `_variance` which computes a shared scale.""" 282 return math_ops.rsqrt(1. + self.total_concentration[..., array_ops.newaxis]) 283 284 @distribution_util.AppendDocstring( 285 """Note: The mode is undefined when any `concentration <= 1`. If 286 `self.allow_nan_stats` is `True`, `NaN` is used for undefined modes. If 287 `self.allow_nan_stats` is `False` an exception is raised when one or more 288 modes are undefined.""") 289 def _mode(self): 290 k = math_ops.cast(self.event_shape_tensor()[0], self.dtype) 291 mode = (self.concentration - 1.) / ( 292 self.total_concentration[..., array_ops.newaxis] - k) 293 if self.allow_nan_stats: 294 nan = array_ops.fill( 295 array_ops.shape(mode), 296 np.array(np.nan, dtype=self.dtype.as_numpy_dtype()), 297 name="nan") 298 return array_ops.where_v2( 299 math_ops.reduce_all(self.concentration > 1., axis=-1), mode, nan) 300 return control_flow_ops.with_dependencies([ 301 check_ops.assert_less( 302 array_ops.ones([], self.dtype), 303 self.concentration, 304 message="Mode undefined when any concentration <= 1"), 305 ], mode) 306 307 def _maybe_assert_valid_concentration(self, concentration, validate_args): 308 """Checks the validity of the concentration parameter.""" 309 if not validate_args: 310 return concentration 311 return control_flow_ops.with_dependencies([ 312 check_ops.assert_positive( 313 concentration, 314 message="Concentration parameter must be positive."), 315 check_ops.assert_rank_at_least( 316 concentration, 1, 317 message="Concentration parameter must have >=1 dimensions."), 318 check_ops.assert_less( 319 1, array_ops.shape(concentration)[-1], 320 message="Concentration parameter must have event_size >= 2."), 321 ], concentration) 322 323 def _maybe_assert_valid_sample(self, x): 324 """Checks the validity of a sample.""" 325 if not self.validate_args: 326 return x 327 return control_flow_ops.with_dependencies([ 328 check_ops.assert_positive(x, message="samples must be positive"), 329 check_ops.assert_near( 330 array_ops.ones([], dtype=self.dtype), 331 math_ops.reduce_sum(x, -1), 332 message="sample last-dimension must sum to `1`"), 333 ], x) 334 335 336@kullback_leibler.RegisterKL(Dirichlet, Dirichlet) 337def _kl_dirichlet_dirichlet(d1, d2, name=None): 338 """Batchwise KL divergence KL(d1 || d2) with d1 and d2 Dirichlet. 339 340 Args: 341 d1: instance of a Dirichlet distribution object. 342 d2: instance of a Dirichlet distribution object. 343 name: (optional) Name to use for created operations. 344 default is "kl_dirichlet_dirichlet". 345 346 Returns: 347 Batchwise KL(d1 || d2) 348 """ 349 with ops.name_scope(name, "kl_dirichlet_dirichlet", values=[ 350 d1.concentration, d2.concentration]): 351 # The KL between Dirichlet distributions can be derived as follows. We have 352 # 353 # Dir(x; a) = 1 / B(a) * prod_i[x[i]^(a[i] - 1)] 354 # 355 # where B(a) is the multivariate Beta function: 356 # 357 # B(a) = Gamma(a[1]) * ... * Gamma(a[n]) / Gamma(a[1] + ... + a[n]) 358 # 359 # The KL is 360 # 361 # KL(Dir(x; a), Dir(x; b)) = E_Dir(x; a){log(Dir(x; a) / Dir(x; b))} 362 # 363 # so we'll need to know the log density of the Dirichlet. This is 364 # 365 # log(Dir(x; a)) = sum_i[(a[i] - 1) log(x[i])] - log B(a) 366 # 367 # The only term that matters for the expectations is the log(x[i]). To 368 # compute the expectation of this term over the Dirichlet density, we can 369 # use the following facts about the Dirichlet in exponential family form: 370 # 1. log(x[i]) is a sufficient statistic 371 # 2. expected sufficient statistics (of any exp family distribution) are 372 # equal to derivatives of the log normalizer with respect to 373 # corresponding natural parameters: E{T[i](x)} = dA/d(eta[i]) 374 # 375 # To proceed, we can rewrite the Dirichlet density in exponential family 376 # form as follows: 377 # 378 # Dir(x; a) = exp{eta(a) . T(x) - A(a)} 379 # 380 # where '.' is the dot product of vectors eta and T, and A is a scalar: 381 # 382 # eta[i](a) = a[i] - 1 383 # T[i](x) = log(x[i]) 384 # A(a) = log B(a) 385 # 386 # Now, we can use fact (2) above to write 387 # 388 # E_Dir(x; a)[log(x[i])] 389 # = dA(a) / da[i] 390 # = d/da[i] log B(a) 391 # = d/da[i] (sum_j lgamma(a[j])) - lgamma(sum_j a[j]) 392 # = digamma(a[i])) - digamma(sum_j a[j]) 393 # 394 # Putting it all together, we have 395 # 396 # KL[Dir(x; a) || Dir(x; b)] 397 # = E_Dir(x; a){log(Dir(x; a) / Dir(x; b)} 398 # = E_Dir(x; a){sum_i[(a[i] - b[i]) log(x[i])} - (lbeta(a) - lbeta(b)) 399 # = sum_i[(a[i] - b[i]) * E_Dir(x; a){log(x[i])}] - lbeta(a) + lbeta(b) 400 # = sum_i[(a[i] - b[i]) * (digamma(a[i]) - digamma(sum_j a[j]))] 401 # - lbeta(a) + lbeta(b)) 402 403 digamma_sum_d1 = math_ops.digamma( 404 math_ops.reduce_sum(d1.concentration, axis=-1, keepdims=True)) 405 digamma_diff = math_ops.digamma(d1.concentration) - digamma_sum_d1 406 concentration_diff = d1.concentration - d2.concentration 407 408 return (math_ops.reduce_sum(concentration_diff * digamma_diff, axis=-1) - 409 special_math_ops.lbeta(d1.concentration) + 410 special_math_ops.lbeta(d2.concentration)) 411