# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Dirichlet distribution class."""

import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Dirichlet",
]


_dirichlet_sample_note = """Note: `value` must be a non-negative tensor with
dtype `self.dtype` and be in the `(self.event_shape() - 1)`-simplex, i.e.,
`tf.reduce_sum(value, -1) = 1`. It must have a shape compatible with
`self.batch_shape() + self.event_shape()`."""


@tf_export(v1=["distributions.Dirichlet"])
class Dirichlet(distribution.Distribution):
  """Dirichlet distribution.

  The Dirichlet distribution is defined over the
  [`(k-1)`-simplex](https://en.wikipedia.org/wiki/Simplex) using a positive,
  length-`k` vector `concentration` (`k > 1`). The Dirichlet is identically the
  Beta distribution when `k = 2`.

  #### Mathematical Details

  The Dirichlet is a distribution over the open `(k-1)`-simplex, i.e.,

  ```none
  S^{k-1} = { (x_0, ..., x_{k-1}) in R^k : sum_j x_j = 1 and all_j x_j > 0 }.
  ```

  The probability density function (pdf) is,

  ```none
  pdf(x; alpha) = prod_j x_j**(alpha_j - 1) / Z
  Z = prod_j Gamma(alpha_j) / Gamma(sum_j alpha_j)
  ```

  where:

  * `x in S^{k-1}`, i.e., the `(k-1)`-simplex,
  * `concentration = alpha = [alpha_0, ..., alpha_{k-1}]`, `alpha_j > 0`,
  * `Z` is the normalization constant aka the [multivariate beta function](
    https://en.wikipedia.org/wiki/Beta_function#Multivariate_beta_function),
    and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).

  The `concentration` represents mean total counts of class occurrence, i.e.,

  ```none
  concentration = alpha = mean * total_concentration
  ```

  where `mean` in `S^{k-1}` and `total_concentration` is a positive real number
  representing a mean total count.
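
  For example, a minimal sketch (with assumed, illustrative values) of
  building `concentration` from a mean on the simplex and a total
  concentration:

  ```python
  # Hypothetical values chosen for illustration only.
  mean = [0.2, 0.3, 0.5]   # a point in S^{2}; sums to 1
  total_concentration = 10.
  concentration = [m * total_concentration for m in mean]  # [2., 3., 5.]
  ```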

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  Warning: Some components of the samples can be zero due to finite precision.
  This happens more often when some of the concentrations are very small.
  Make sure to round the samples to `np.finfo(dtype).tiny` before computing the
  density.
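
  A minimal sketch of such rounding (assuming `dist` is a float32 `Dirichlet`,
  as constructed in the examples below):

  ```python
  tiny = np.finfo(np.float32).tiny
  samples = tf.maximum(dist.sample(10), tiny)  # lift exact zeros up to `tiny`
  log_probs = dist.log_prob(samples)
  ```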

  Samples of this distribution are reparameterized (pathwise differentiable).
  The derivatives are computed using the approach described in
  (Figurnov et al., 2018).

  #### Examples

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Create a single trivariate Dirichlet, with the 3rd class being three times
  # more frequent than the first. I.e., batch_shape=[], event_shape=[3].
  alpha = [1., 2, 3]
  dist = tfd.Dirichlet(alpha)

  dist.sample([4, 5])  # shape: [4, 5, 3]

  # x has one sample, one batch, three classes:
  x = [.2, .3, .5]   # shape: [3]
  dist.prob(x)       # shape: []

  # x has two samples from one batch:
  x = [[.1, .4, .5],
       [.2, .3, .5]]
  dist.prob(x)         # shape: [2]

  # alpha will be broadcast to shape [5, 7, 3] to match x.
  x = [[...]]   # shape: [5, 7, 3]
  dist.prob(x)  # shape: [5, 7]
  ```

  ```python
  # Create batch_shape=[2], event_shape=[3]:
  alpha = [[1., 2, 3],
           [4, 5, 6]]   # shape: [2, 3]
  dist = tfd.Dirichlet(alpha)

  dist.sample([4, 5])  # shape: [4, 5, 2, 3]

  x = [.2, .3, .5]
  # x will be broadcast as [[.2, .3, .5],
  #                         [.2, .3, .5]],
  # thus matching batch_shape + event_shape = [2, 3].
  dist.prob(x)         # shape: [2]
  ```

  Compute the gradients of samples w.r.t. the parameters:

  ```python
  alpha = tf.constant([1.0, 2.0, 3.0])
  dist = tfd.Dirichlet(alpha)
  samples = dist.sample(5)  # Shape [5, 3]
  loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
  # Unbiased stochastic gradients of the loss function
  grads = tf.gradients(loss, alpha)
  ```

  References:
    Implicit Reparameterization Gradients:
      [Figurnov et al., 2018]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
      ([pdf]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               concentration,
               validate_args=False,
               allow_nan_stats=True,
               name="Dirichlet"):
    """Initialize a batch of Dirichlet distributions.

    Args:
      concentration: Positive floating-point `Tensor` indicating mean number
        of class occurrences; aka "alpha". Implies `self.dtype`, and
        `self.batch_shape`, `self.event_shape`, i.e., if
        `concentration.shape = [N1, N2, ..., Nm, k]` then
        `batch_shape = [N1, N2, ..., Nm]` and
        `event_shape = [k]`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
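
    For example, a minimal sketch of the implied shapes (values assumed
    purely for illustration):

    ```python
    alpha = tf.ones([4, 7, 3])   # concentration.shape = [4, 7, 3]
    dist = Dirichlet(alpha)
    dist.batch_shape   # TensorShape([4, 7])
    dist.event_shape   # TensorShape([3])
    ```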
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration]) as name:
      self._concentration = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration, name="concentration"),
          validate_args)
      self._total_concentration = math_ops.reduce_sum(self._concentration, -1)
    super(Dirichlet, self).__init__(
        dtype=self._concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._concentration,
                       self._total_concentration],
        name=name)

  @property
  def concentration(self):
    """Concentration parameter; expected counts for that coordinate."""
    return self._concentration

  @property
  def total_concentration(self):
    """Sum of last dim of concentration parameter."""
    return self._total_concentration

  def _batch_shape_tensor(self):
    return array_ops.shape(self.total_concentration)

  def _batch_shape(self):
    return self.total_concentration.get_shape()

  def _event_shape_tensor(self):
    return array_ops.shape(self.concentration)[-1:]

  def _event_shape(self):
    return self.concentration.get_shape().with_rank_at_least(1)[-1:]

  def _sample_n(self, n, seed=None):
    gamma_sample = random_ops.random_gamma(
        shape=[n],
        alpha=self.concentration,
        dtype=self.dtype,
        seed=seed)
    return gamma_sample / math_ops.reduce_sum(gamma_sample, -1, keepdims=True)

  @distribution_util.AppendDocstring(_dirichlet_sample_note)
  def _log_prob(self, x):
    return self._log_unnormalized_prob(x) - self._log_normalization()

  @distribution_util.AppendDocstring(_dirichlet_sample_note)
  def _prob(self, x):
    return math_ops.exp(self._log_prob(x))

  def _log_unnormalized_prob(self, x):
    x = self._maybe_assert_valid_sample(x)
    return math_ops.reduce_sum(math_ops.xlogy(self.concentration - 1., x), -1)

  def _log_normalization(self):
    return special_math_ops.lbeta(self.concentration)

  def _entropy(self):
    k = math_ops.cast(self.event_shape_tensor()[0], self.dtype)
    return (
        self._log_normalization()
        + ((self.total_concentration - k)
           * math_ops.digamma(self.total_concentration))
        - math_ops.reduce_sum(
            (self.concentration - 1.) * math_ops.digamma(self.concentration),
            axis=-1))

  def _mean(self):
    return self.concentration / self.total_concentration[..., array_ops.newaxis]

  def _covariance(self):
    x = self._variance_scale_term() * self._mean()
    # pylint: disable=invalid-unary-operand-type
    return array_ops.matrix_set_diag(
        -math_ops.matmul(
            x[..., array_ops.newaxis],
            x[..., array_ops.newaxis, :]),  # outer prod
        self._variance())

  def _variance(self):
    scale = self._variance_scale_term()
    x = scale * self._mean()
    return x * (scale - x)

  def _variance_scale_term(self):
    """Helper to `_covariance` and `_variance` which computes a shared scale."""
    return math_ops.rsqrt(1. + self.total_concentration[..., array_ops.newaxis])

  @distribution_util.AppendDocstring(
      """Note: The mode is undefined when any `concentration <= 1`. If
      `self.allow_nan_stats` is `True`, `NaN` is used for undefined modes. If
      `self.allow_nan_stats` is `False` an exception is raised when one or more
      modes are undefined.""")
  def _mode(self):
    k = math_ops.cast(self.event_shape_tensor()[0], self.dtype)
    mode = (self.concentration - 1.) / (
        self.total_concentration[..., array_ops.newaxis] - k)
    if self.allow_nan_stats:
      nan = array_ops.fill(
          array_ops.shape(mode),
          np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
          name="nan")
      return array_ops.where_v2(
          math_ops.reduce_all(self.concentration > 1., axis=-1), mode, nan)
    return control_flow_ops.with_dependencies([
        check_ops.assert_less(
            array_ops.ones([], self.dtype),
            self.concentration,
            message="Mode undefined when any concentration <= 1"),
    ], mode)

  def _maybe_assert_valid_concentration(self, concentration, validate_args):
    """Checks the validity of the concentration parameter."""
    if not validate_args:
      return concentration
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(
            concentration,
            message="Concentration parameter must be positive."),
        check_ops.assert_rank_at_least(
            concentration, 1,
            message="Concentration parameter must have >=1 dimensions."),
        check_ops.assert_less(
            1, array_ops.shape(concentration)[-1],
            message="Concentration parameter must have event_size >= 2."),
    ], concentration)

  def _maybe_assert_valid_sample(self, x):
    """Checks the validity of a sample."""
    if not self.validate_args:
      return x
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(x, message="samples must be positive"),
        check_ops.assert_near(
            array_ops.ones([], dtype=self.dtype),
            math_ops.reduce_sum(x, -1),
            message="sample last-dimension must sum to `1`"),
    ], x)


@kullback_leibler.RegisterKL(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(d1, d2, name=None):
  """Batchwise KL divergence KL(d1 || d2) with d1 and d2 Dirichlet.

  Args:
    d1: instance of a Dirichlet distribution object.
    d2: instance of a Dirichlet distribution object.
    name: (optional) Name to use for created operations.
      Default is "kl_dirichlet_dirichlet".

  Returns:
    Batchwise KL(d1 || d2)
  """
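  # Note: this function is usually not called directly; the `RegisterKL`
  # decorator above registers it so that `kullback_leibler.kl_divergence(d1, d2)`
  # dispatches here when both arguments are `Dirichlet` instances.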
  with ops.name_scope(name, "kl_dirichlet_dirichlet", values=[
      d1.concentration, d2.concentration]):
    # The KL between Dirichlet distributions can be derived as follows. We have
    #
    #   Dir(x; a) = 1 / B(a) * prod_i[x[i]^(a[i] - 1)]
    #
    # where B(a) is the multivariate Beta function:
    #
    #   B(a) = Gamma(a[1]) * ... * Gamma(a[n]) / Gamma(a[1] + ... + a[n])
    #
    # The KL is
    #
    #   KL(Dir(x; a), Dir(x; b)) = E_Dir(x; a){log(Dir(x; a) / Dir(x; b))}
    #
    # so we'll need to know the log density of the Dirichlet. This is
    #
    #   log(Dir(x; a)) = sum_i[(a[i] - 1) log(x[i])] - log B(a)
    #
    # The only term that matters for the expectations is the log(x[i]). To
    # compute the expectation of this term over the Dirichlet density, we can
    # use the following facts about the Dirichlet in exponential family form:
    #   1. log(x[i]) is a sufficient statistic
    #   2. expected sufficient statistics (of any exp family distribution) are
    #      equal to derivatives of the log normalizer with respect to
    #      corresponding natural parameters: E{T[i](x)} = dA/d(eta[i])
    #
    # To proceed, we can rewrite the Dirichlet density in exponential family
    # form as follows:
    #
    #   Dir(x; a) = exp{eta(a) . T(x) - A(a)}
    #
    # where '.' is the dot product of vectors eta and T, and A is a scalar:
    #
    #   eta[i](a) = a[i] - 1
    #     T[i](x) = log(x[i])
    #        A(a) = log B(a)
    #
    # Now, we can use fact (2) above to write
    #
    #   E_Dir(x; a)[log(x[i])]
    #       = dA(a) / da[i]
    #       = d/da[i] log B(a)
    #       = d/da[i] [sum_j lgamma(a[j]) - lgamma(sum_j a[j])]
    #       = digamma(a[i]) - digamma(sum_j a[j])
    #
    # Putting it all together, we have
    #
    # KL[Dir(x; a) || Dir(x; b)]
    #     = E_Dir(x; a){log(Dir(x; a) / Dir(x; b))}
    #     = E_Dir(x; a){sum_i[(a[i] - b[i]) log(x[i])]} - (lbeta(a) - lbeta(b))
    #     = sum_i[(a[i] - b[i]) * E_Dir(x; a){log(x[i])}] - lbeta(a) + lbeta(b)
    #     = sum_i[(a[i] - b[i]) * (digamma(a[i]) - digamma(sum_j a[j]))]
    #          - lbeta(a) + lbeta(b)

    digamma_sum_d1 = math_ops.digamma(
        math_ops.reduce_sum(d1.concentration, axis=-1, keepdims=True))
    digamma_diff = math_ops.digamma(d1.concentration) - digamma_sum_d1
    concentration_diff = d1.concentration - d2.concentration

    return (math_ops.reduce_sum(concentration_diff * digamma_diff, axis=-1) -
            special_math_ops.lbeta(d1.concentration) +
            special_math_ops.lbeta(d2.concentration))