# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The DirichletMultinomial distribution class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "DirichletMultinomial",
]

_dirichlet_multinomial_sample_note = """For each batch of counts,
`value = [n_0, ..., n_{K-1}]`, `P[value]` is the probability that after
sampling `self.total_count` draws from this Dirichlet-Multinomial distribution,
the number of draws falling in class `j` is `n_j`. Since this definition is
[exchangeable](https://en.wikipedia.org/wiki/Exchangeable_random_variables),
every sequence of draws with the same counts has the same probability, so the
pmf includes a combinatorial coefficient.

Note: `value` must be a non-negative tensor with dtype `self.dtype`, have no
fractional components, and such that
`tf.reduce_sum(value, -1) = self.total_count`. Its shape must be broadcastable
with `self.concentration` and `self.total_count`."""


@tf_export(v1=["distributions.DirichletMultinomial"])
class DirichletMultinomial(distribution.Distribution):
  """Dirichlet-Multinomial compound distribution.

  The Dirichlet-Multinomial distribution is parameterized by a (batch of)
  length-`K` `concentration` vectors (`K > 1`) and a `total_count` number of
  trials, i.e., the number of trials per draw from the DirichletMultinomial. It
  is defined over a (batch of) length-`K` vector `counts` such that
  `tf.reduce_sum(counts, -1) = total_count`. The Dirichlet-Multinomial is
  identically the Beta-Binomial distribution when `K = 2`.

  #### Mathematical Details

  The Dirichlet-Multinomial is a distribution over `K`-class counts, i.e., a
  length-`K` vector of non-negative integer `counts = n = [n_0, ..., n_{K-1}]`.

  The probability mass function (pmf) is,

  ```none
  pmf(n; alpha, N) = Beta(alpha + n) / (prod_j n_j!) / Z
  Z = Beta(alpha) / N!
  ```

  where:

  * `concentration = alpha = [alpha_0, ..., alpha_{K-1}]`, `alpha_j > 0`,
  * `total_count = N`, `N` a positive integer,
  * `N!` is `N` factorial, and,
  * `Beta(x) = prod_j Gamma(x_j) / Gamma(sum_j x_j)` is the
    [multivariate beta function](
    https://en.wikipedia.org/wiki/Beta_function#Multivariate_beta_function),
    and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).
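
  As a concrete check of this formula, here is a minimal pure-Python sketch of
  the log-pmf (the helpers `lbeta` and `log_pmf` are illustrative only and are
  not part of this class):

  ```python
  from math import lgamma

  def lbeta(x):
    # Multivariate beta function in log space:
    # sum_j lgamma(x_j) - lgamma(sum_j x_j).
    return sum(lgamma(v) for v in x) - lgamma(sum(x))

  def log_pmf(n, alpha):
    # log N! - sum_j log(n_j!) + log Beta(alpha + n) - log Beta(alpha)
    log_coeff = lgamma(sum(n) + 1.) - sum(lgamma(v + 1.) for v in n)
    return log_coeff + lbeta([a + v for a, v in zip(alpha, n)]) - lbeta(alpha)

  log_pmf([0., 0., 2.], [1., 2., 3.])  # log(2/7); should match `log_prob`.
  ```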

  Dirichlet-Multinomial is a [compound distribution](
  https://en.wikipedia.org/wiki/Compound_probability_distribution), i.e., its
  samples are generated as follows.

    1. Choose class probabilities:
       `probs = [p_0,...,p_{K-1}] ~ Dir(concentration)`
    2. Draw integers:
       `counts = [n_0,...,n_{K-1}] ~ Multinomial(total_count, probs)`
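
  As a purely illustrative NumPy sketch of these two steps for a single
  (unbatched) distribution (the class itself samples with TF ops in
  `_sample_n`):

  ```python
  import numpy as np

  rng = np.random.RandomState(0)
  concentration = np.array([1., 2., 3.])
  total_count = 10

  probs = rng.dirichlet(concentration)          # 1. probs ~ Dir(concentration)
  counts = rng.multinomial(total_count, probs)  # 2. counts ~ Multinomial(N, probs)
  ```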

  The last `concentration` dimension parametrizes a single Dirichlet-Multinomial
  distribution. When calling distribution functions (e.g., `dist.prob(counts)`),
  `concentration`, `total_count` and `counts` are broadcast to the same shape.
  The last dimension of `counts` corresponds to a single Dirichlet-Multinomial
  distribution.

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  #### Pitfalls

  The number of classes, `K`, must not exceed:
  - the largest integer representable by `self.dtype`, i.e.,
    `2**(mantissa_bits+1)` (IEEE 754),
  - the maximum `Tensor` index, i.e., `2**31-1`.

  In other words,

  ```python
  K <= min(2**31-1, {
    tf.float16: 2**11,
    tf.float32: 2**24,
    tf.float64: 2**53 }[param.dtype])
  ```

  Note: This condition is validated only when `self.validate_args = True`.

  #### Examples

  ```python
  alpha = [1., 2., 3.]
  n = 2.
  dist = DirichletMultinomial(n, alpha)
  ```

  Creates a 3-class distribution in which the 3rd class is the most likely to
  be drawn. The distribution functions can be evaluated on counts.

  ```python
  # counts same shape as alpha.
  counts = [0., 0., 2.]
  dist.prob(counts)  # Shape []

  # alpha will be broadcast to [[1., 2., 3.], [1., 2., 3.]] to match counts.
  counts = [[1., 1., 0.], [1., 0., 1.]]
  dist.prob(counts)  # Shape [2]

  # alpha will be broadcast to shape [5, 7, 3] to match counts.
  counts = [[...]]  # Shape [5, 7, 3]
  dist.prob(counts)  # Shape [5, 7]
  ```

  Creates a 2-batch of 3-class distributions.

  ```python
  alpha = [[1., 2., 3.], [4., 5., 6.]]  # Shape [2, 3]
  n = [3., 3.]
  dist = DirichletMultinomial(n, alpha)

  # counts will be broadcast to [[2., 1., 0.], [2., 1., 0.]] to match alpha.
  counts = [2., 1., 0.]
  dist.prob(counts)  # Shape [2]
  ```

  """

  # TODO(b/27419586) Change docstring for dtype of concentration once int
  # allowed.
  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               total_count,
               concentration,
               validate_args=False,
               allow_nan_stats=True,
               name="DirichletMultinomial"):
    """Initialize a batch of DirichletMultinomial distributions.

    Args:
      total_count: Non-negative floating point tensor, whose dtype is the same
        as `concentration`. The shape is broadcastable to `[N1,..., Nm]` with
        `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different
        Dirichlet multinomial distributions. Its components should be equal to
        integer values.
      concentration: Positive floating point tensor, whose dtype is the same
        as `total_count`, with shape broadcastable to `[N1,..., Nm, K]`,
        `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different `K`
        class Dirichlet multinomial distributions.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[total_count, concentration]) as name:
      # Broadcasting works because:
      # * The broadcasting convention is to prepend dimensions of size [1], and
      #   we use the last dimension for the distribution, whereas
      #   the batch dimensions are the leading dimensions, which forces the
      #   distribution dimension to be defined explicitly (i.e. it cannot be
      #   created automatically by prepending). This forces enough explicitness.
      # * All calls involving `counts` eventually require a broadcast between
      #   `counts` and concentration.
      self._total_count = ops.convert_to_tensor(total_count, name="total_count")
      if validate_args:
        self._total_count = (
            distribution_util.embed_check_nonnegative_integer_form(
                self._total_count))
      self._concentration = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration,
                                name="concentration"),
          validate_args)
      self._total_concentration = math_ops.reduce_sum(self._concentration, -1)
    super(DirichletMultinomial, self).__init__(
        dtype=self._concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._total_count,
                       self._concentration],
        name=name)

  @property
  def total_count(self):
    """Number of trials used to construct a sample."""
    return self._total_count

  @property
  def concentration(self):
    """Concentration parameter; expected prior counts for that coordinate."""
    return self._concentration

  @property
  def total_concentration(self):
    """Sum of last dim of concentration parameter."""
    return self._total_concentration

  def _batch_shape_tensor(self):
    return array_ops.shape(self.total_concentration)

  def _batch_shape(self):
    return self.total_concentration.get_shape()

  def _event_shape_tensor(self):
    return array_ops.shape(self.concentration)[-1:]

  def _event_shape(self):
    # Event shape depends only on total_concentration, not "n".
    return self.concentration.get_shape().with_rank_at_least(1)[-1:]

  def _sample_n(self, n, seed=None):
    n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
    k = self.event_shape_tensor()[0]
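    # Log-Gamma draws with shape parameter `concentration` serve as
    # unnormalized multinomial logits: softmax of these logits is exactly a
    # Dirichlet(concentration) sample, so the compound construction is applied
    # without materializing normalized probabilities.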
    unnormalized_logits = array_ops.reshape(
        math_ops.log(random_ops.random_gamma(
            shape=[n],
            alpha=self.concentration,
            dtype=self.dtype,
            seed=seed)),
        shape=[-1, k])
    draws = random_ops.multinomial(
        logits=unnormalized_logits,
        num_samples=n_draws,
        seed=distribution_util.gen_new_seed(seed, salt="dirichlet_multinomial"))
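    # `draws` holds per-draw class indices of shape [n * batch_size, n_draws];
    # one-hot encode each draw and sum over the draws axis to get class counts.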
    x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k), -2)
    final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
    x = array_ops.reshape(x, final_shape)
    return math_ops.cast(x, self.dtype)

  @distribution_util.AppendDocstring(_dirichlet_multinomial_sample_note)
  def _log_prob(self, counts):
    counts = self._maybe_assert_valid_sample(counts)
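    # log Beta(concentration + counts) - log Beta(concentration) is the log
    # probability of any one ordered sequence of draws with these counts;
    # log_combinations adds the multinomial coefficient that accounts for the
    # exchangeability over orderings.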
    ordered_prob = (
        special_math_ops.lbeta(self.concentration + counts)
        - special_math_ops.lbeta(self.concentration))
    return ordered_prob + distribution_util.log_combinations(
        self.total_count, counts)

  @distribution_util.AppendDocstring(_dirichlet_multinomial_sample_note)
  def _prob(self, counts):
    return math_ops.exp(self._log_prob(counts))

  def _mean(self):
    return self.total_count * (self.concentration /
                               self.total_concentration[..., array_ops.newaxis])

  @distribution_util.AppendDocstring(
      """The covariance for each batch member is defined as the following:

      ```none
      Var(X_j) = n * alpha_j / alpha_0 * (1 - alpha_j / alpha_0) *
      (n + alpha_0) / (1 + alpha_0)
      ```

      where `concentration = alpha` and
      `total_concentration = alpha_0 = sum_j alpha_j`.

      The covariance between distinct classes `i != j` within a batch member
      is defined as:

      ```none
      Cov(X_i, X_j) = -n * alpha_i * alpha_j / alpha_0 ** 2 *
      (n + alpha_0) / (1 + alpha_0)
      ```
      """)
  def _covariance(self):
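    # The off-diagonal covariance is -x_i * x_j with x = scale * mean (see the
    # docstring above): build the negated outer product, then overwrite the
    # diagonal with the variance.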
    x = self._variance_scale_term() * self._mean()
    return array_ops.matrix_set_diag(
        -math_ops.matmul(x[..., array_ops.newaxis],
                         x[..., array_ops.newaxis, :]),  # outer prod
        self._variance())

  def _variance(self):
    scale = self._variance_scale_term()
    x = scale * self._mean()
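    # With scale = sqrt((1 + c0 / n) / (1 + c0)), this expands to
    # n * (alpha_j / c0) * (1 - alpha_j / c0) * (n + c0) / (1 + c0),
    # i.e. the variance formula given in the docstring above.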
    return x * (self.total_count * scale - x)

  def _variance_scale_term(self):
    """Helper to `_covariance` and `_variance` which computes a shared scale."""
    # We must take care to expand back the last dim whenever we use the
    # total_concentration.
    c0 = self.total_concentration[..., array_ops.newaxis]
    return math_ops.sqrt((1. + c0 / self.total_count) / (1. + c0))

  def _maybe_assert_valid_concentration(self, concentration, validate_args):
    """Checks the validity of the concentration parameter."""
    if not validate_args:
      return concentration
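    # embed_check_categorical_event_shape also enforces the class-count limits
    # described in the "Pitfalls" section of the class docstring.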
    concentration = distribution_util.embed_check_categorical_event_shape(
        concentration)
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(
            concentration,
            message="Concentration parameter must be positive."),
    ], concentration)

  def _maybe_assert_valid_sample(self, counts):
    """Check counts for proper shape, values, then return tensor version."""
    if not self.validate_args:
      return counts
    counts = distribution_util.embed_check_nonnegative_integer_form(counts)
    return control_flow_ops.with_dependencies([
        check_ops.assert_equal(
            self.total_count, math_ops.reduce_sum(counts, -1),
            message="counts last-dimension must sum to `self.total_count`"),
    ], counts)