# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The DirichletMultinomial distribution class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "DirichletMultinomial",
]


_dirichlet_multinomial_sample_note = """For each batch of counts,
`value = [n_0, ..., n_{K-1}]`, `P[value]` is the probability that after
sampling `self.total_count` draws from this Dirichlet-Multinomial distribution,
the number of draws falling in class `j` is `n_j`. Since this definition is
[exchangeable](https://en.wikipedia.org/wiki/Exchangeable_random_variables),
different sequences yielding the same counts have the same probability, so the
pmf includes a combinatorial coefficient.

Note: `value` must be a non-negative tensor with dtype `self.dtype`, have no
fractional components, and be such that
`tf.reduce_sum(value, -1) = self.total_count`. Its shape must be broadcastable
with `self.concentration` and `self.total_count`."""


@tf_export(v1=["distributions.DirichletMultinomial"])
class DirichletMultinomial(distribution.Distribution):
  """Dirichlet-Multinomial compound distribution.

  The Dirichlet-Multinomial distribution is parameterized by a (batch of)
  length-`K` `concentration` vectors (`K > 1`) and a `total_count` number of
  trials, i.e., the number of trials per draw from the DirichletMultinomial. It
  is defined over a (batch of) length-`K` vector `counts` such that
  `tf.reduce_sum(counts, -1) = total_count`. The Dirichlet-Multinomial is
  identically the Beta-Binomial distribution when `K = 2`.

  #### Mathematical Details

  The Dirichlet-Multinomial is a distribution over `K`-class counts, i.e., a
  length-`K` vector of non-negative integer `counts = n = [n_0, ..., n_{K-1}]`.

  The probability mass function (pmf) is,

  ```none
  pmf(n; alpha, N) = Beta(alpha + n) / (prod_j n_j!) / Z
  Z = Beta(alpha) / N!
  ```

  where:

  * `concentration = alpha = [alpha_0, ..., alpha_{K-1}]`, `alpha_j > 0`,
  * `total_count = N`, `N` a positive integer,
  * `N!` is `N` factorial,
  * `Beta(x) = prod_j Gamma(x_j) / Gamma(sum_j x_j)` is the
    [multivariate beta function](
    https://en.wikipedia.org/wiki/Beta_function#Multivariate_beta_function),
    and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).
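
  For example, with `concentration = [1., 2., 3.]` and `total_count = 2.`, the
  probability of `counts = [0., 0., 2.]` is
  `Beta([1., 2., 5.]) / Beta([1., 2., 3.]) = 2/7`. A minimal plain-Python
  sketch of this pmf, computed in log space via `math.lgamma` (illustrative
  only; `log_beta` and `dm_pmf` are not part of this class):

  ```python
  import math

  def log_beta(xs):
    # Multivariate beta function: prod_j Gamma(x_j) / Gamma(sum_j x_j).
    return sum(math.lgamma(x) for x in xs) - math.lgamma(sum(xs))

  def dm_pmf(counts, alpha, total_count):
    # pmf(n; alpha, N) = [Beta(alpha + n) / Beta(alpha)] * N! / prod_j n_j!
    log_pmf = (log_beta([a + n for a, n in zip(alpha, counts)])
               - log_beta(alpha)
               + math.lgamma(total_count + 1)
               - sum(math.lgamma(n + 1) for n in counts))
    return math.exp(log_pmf)

  dm_pmf([0., 0., 2.], [1., 2., 3.], 2.)  # ==> 0.2857... (= 2/7)
  ```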

  Dirichlet-Multinomial is a [compound distribution](
  https://en.wikipedia.org/wiki/Compound_probability_distribution), i.e., its
  samples are generated as follows.

  1. Choose class probabilities:
     `probs = [p_0,...,p_{K-1}] ~ Dir(concentration)`
  2. Draw integers:
     `counts = [n_0,...,n_{K-1}] ~ Multinomial(total_count, probs)`
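
  A minimal NumPy sketch of this two-step generative process (illustrative
  only; NumPy is used here instead of TensorFlow for brevity):

  ```python
  import numpy as np

  concentration = [1., 2., 3.]
  total_count = 10

  # 1. Choose class probabilities from the Dirichlet prior.
  probs = np.random.dirichlet(concentration)
  # 2. Draw integer counts from the resulting Multinomial.
  counts = np.random.multinomial(total_count, probs)
  counts.sum()  # ==> 10; counts always sum to total_count.
  ```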

  The last `concentration` dimension parametrizes a single
  Dirichlet-Multinomial distribution. When calling distribution functions
  (e.g., `dist.prob(counts)`), `concentration`, `total_count` and `counts` are
  broadcast to the same shape. The last dimension of `counts` corresponds to a
  single Dirichlet-Multinomial distribution.

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  #### Pitfalls

  The number of classes, `K`, must not exceed:
  - the largest integer representable by `self.dtype`, i.e.,
    `2**(mantissa_bits+1)` (IEEE 754),
  - the maximum `Tensor` index, i.e., `2**31-1`.

  In other words,

  ```python
  K <= min(2**31-1, {
      tf.float16: 2**11,
      tf.float32: 2**24,
      tf.float64: 2**53}[param.dtype])
  ```

  Note: This condition is validated only when `self.validate_args = True`.

  #### Examples

  ```python
  alpha = [1., 2., 3.]
  n = 2.
  dist = DirichletMultinomial(n, alpha)
  ```

  Creates a 3-class distribution in which the 3rd class is the most likely to
  be drawn. The distribution functions can be evaluated on counts.

  ```python
  # counts same shape as alpha.
  counts = [0., 0., 2.]
  dist.prob(counts)  # Shape []

  # alpha will be broadcast to [[1., 2., 3.], [1., 2., 3.]] to match counts.
  counts = [[1., 1., 0.], [1., 0., 1.]]
  dist.prob(counts)  # Shape [2]

  # alpha will be broadcast to shape [5, 7, 3] to match counts.
  counts = [[...]]  # Shape [5, 7, 3]
  dist.prob(counts)  # Shape [5, 7]
  ```

  Creates a 2-batch of 3-class distributions.

  ```python
  alpha = [[1., 2., 3.], [4., 5., 6.]]  # Shape [2, 3]
  n = [3., 3.]
  dist = DirichletMultinomial(n, alpha)

  # counts will be broadcast to [[2., 1., 0.], [2., 1., 0.]] to match alpha.
  counts = [2., 1., 0.]
  dist.prob(counts)  # Shape [2]
  ```
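
  Sampling follows the usual `Distribution` shape semantics: `sample(n)` draws
  have shape `[n] + batch_shape + event_shape`. Continuing the two-batch
  example above (an added illustration):

  ```python
  dist.sample(5)  # Shape [5, 2, 3]; each length-3 slice sums to n.
  dist.mean()     # Shape [2, 3]; equals n * alpha / sum(alpha, -1).
  ```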

  """

  # TODO(b/27419586) Change docstring for dtype of concentration once int
  # allowed.
  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               total_count,
               concentration,
               validate_args=False,
               allow_nan_stats=True,
               name="DirichletMultinomial"):
    """Initialize a batch of DirichletMultinomial distributions.

    Args:
      total_count: Non-negative floating point tensor, whose dtype is the same
        as `concentration`. The shape is broadcastable to `[N1,..., Nm]` with
        `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different
        Dirichlet multinomial distributions. Its components should be equal to
        integer values.
      concentration: Positive floating point tensor, whose dtype is the same
        as `total_count`, with shape broadcastable to `[N1,..., Nm, K]` with
        `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different `K`
        class Dirichlet multinomial distributions.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[total_count, concentration]) as name:
      # Broadcasting works because:
      # * The broadcasting convention is to prepend dimensions of size [1],
      #   and we use the last dimension for the distribution, whereas the
      #   batch dimensions are the leading dimensions, which forces the
      #   distribution dimension to be defined explicitly (i.e. it cannot be
      #   created automatically by prepending). This forces enough
      #   explicitness.
      # * All calls involving `counts` eventually require a broadcast between
      #   `counts` and `concentration`.
      self._total_count = ops.convert_to_tensor(total_count,
                                                name="total_count")
      if validate_args:
        self._total_count = (
            distribution_util.embed_check_nonnegative_integer_form(
                self._total_count))
      self._concentration = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration,
                                name="concentration"),
          validate_args)
      self._total_concentration = math_ops.reduce_sum(self._concentration, -1)
      super(DirichletMultinomial, self).__init__(
          dtype=self._concentration.dtype,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          reparameterization_type=distribution.NOT_REPARAMETERIZED,
          parameters=parameters,
          graph_parents=[self._total_count,
                         self._concentration],
          name=name)

  @property
  def total_count(self):
    """Number of trials used to construct a sample."""
    return self._total_count

  @property
  def concentration(self):
    """Concentration parameter; expected prior counts for that coordinate."""
    return self._concentration

  @property
  def total_concentration(self):
    """Sum of last dim of concentration parameter."""
    return self._total_concentration

  def _batch_shape_tensor(self):
    return array_ops.shape(self.total_concentration)

  def _batch_shape(self):
    return self.total_concentration.get_shape()

  def _event_shape_tensor(self):
    return array_ops.shape(self.concentration)[-1:]

  def _event_shape(self):
    # Event shape depends only on `concentration`, not `total_count`.
    return self.concentration.get_shape().with_rank_at_least(1)[-1:]
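
  # Implementation note for `_sample_n` below: if `g_j ~ Gamma(alpha_j, 1)`
  # independently, then `g / sum_j g_j ~ Dirichlet(alpha)`. Because
  # `random_ops.multinomial` accepts *unnormalized* log-probabilities, the
  # normalization (a per-row constant shift in log space) can be skipped and
  # `log(g)` used directly as logits. A rough NumPy equivalent of one draw,
  # as an illustrative sketch only (not used by this class):
  #
  #   g = np.random.gamma(concentration)  # Shape [K].
  #   counts = np.random.multinomial(n_draws, g / g.sum())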
  def _sample_n(self, n, seed=None):
    n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
    k = self.event_shape_tensor()[0]
    unnormalized_logits = array_ops.reshape(
        math_ops.log(random_ops.random_gamma(
            shape=[n],
            alpha=self.concentration,
            dtype=self.dtype,
            seed=seed)),
        shape=[-1, k])
    draws = random_ops.multinomial(
        logits=unnormalized_logits,
        num_samples=n_draws,
        seed=distribution_util.gen_new_seed(seed, salt="dirichlet_multinomial"))
    # Convert the sampled class indices to per-class counts by one-hot
    # encoding each draw and summing over the draws dimension.
    x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k), -2)
    final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
    x = array_ops.reshape(x, final_shape)
    return math_ops.cast(x, self.dtype)

  @distribution_util.AppendDocstring(_dirichlet_multinomial_sample_note)
  def _log_prob(self, counts):
    counts = self._maybe_assert_valid_sample(counts)
    ordered_prob = (
        special_math_ops.lbeta(self.concentration + counts)
        - special_math_ops.lbeta(self.concentration))
    return ordered_prob + distribution_util.log_combinations(
        self.total_count, counts)

  @distribution_util.AppendDocstring(_dirichlet_multinomial_sample_note)
  def _prob(self, counts):
    return math_ops.exp(self._log_prob(counts))

  def _mean(self):
    return self.total_count * (self.concentration /
                               self.total_concentration[..., array_ops.newaxis])

  @distribution_util.AppendDocstring(
      """The covariance for each batch member is defined as the following:

      ```none
      Var(X_j) = n * alpha_j / alpha_0 * (1 - alpha_j / alpha_0) *
                 (n + alpha_0) / (1 + alpha_0)
      ```

      where `concentration = alpha` and
      `total_concentration = alpha_0 = sum_j alpha_j`.

      The covariance between elements within a batch member is defined as:

      ```none
      Cov(X_i, X_j) = -n * alpha_i * alpha_j / alpha_0 ** 2 *
                      (n + alpha_0) / (1 + alpha_0)
      ```
      """)
  def _covariance(self):
    x = self._variance_scale_term() * self._mean()
    return array_ops.matrix_set_diag(
        -math_ops.matmul(x[..., array_ops.newaxis],
                         x[..., array_ops.newaxis, :]),  # outer product
        self._variance())

  def _variance(self):
    scale = self._variance_scale_term()
    x = scale * self._mean()
    return x * (self.total_count * scale - x)

  def _variance_scale_term(self):
    """Helper to `_covariance` and `_variance` which computes a shared scale."""
    # We must take care to expand back the last dim whenever we use the
    # total_concentration.
    c0 = self.total_concentration[..., array_ops.newaxis]
    return math_ops.sqrt((1. + c0 / self.total_count) / (1. + c0))

  def _maybe_assert_valid_concentration(self, concentration, validate_args):
    """Checks the validity of the concentration parameter."""
    if not validate_args:
      return concentration
    concentration = distribution_util.embed_check_categorical_event_shape(
        concentration)
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(
            concentration,
            message="Concentration parameter must be positive."),
    ], concentration)

  def _maybe_assert_valid_sample(self, counts):
    """Check counts for proper shape, values, then return tensor version."""
    if not self.validate_args:
      return counts
    counts = distribution_util.embed_check_nonnegative_integer_form(counts)
    return control_flow_ops.with_dependencies([
        check_ops.assert_equal(
            self.total_count, math_ops.reduce_sum(counts, -1),
            message="counts last-dimension must sum to `self.total_count`"),
    ], counts)
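

# A hypothetical usage sketch of the `validate_args` checks above (not
# executed here): with
# `DirichletMultinomial(2., [1., 2., 3.], validate_args=True)`, evaluating
# `prob([1., 0., 0.])` fails the sum check (the counts sum to 1, not 2), and
# `prob([1.5, 0., 0.5])` fails the non-negative-integer check.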