1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""The Multinomial distribution class."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import check_ops
25from tensorflow.python.ops import control_flow_ops
26from tensorflow.python.ops import map_fn
27from tensorflow.python.ops import math_ops
28from tensorflow.python.ops import nn_ops
29from tensorflow.python.ops import random_ops
30from tensorflow.python.ops.distributions import distribution
31from tensorflow.python.ops.distributions import util as distribution_util
32from tensorflow.python.util import deprecation
33from tensorflow.python.util.tf_export import tf_export
34
35
# Public API of this module: only the distribution class itself is exported.
__all__ = [
    "Multinomial",
]
39
40
# Documentation fragment appended to `Multinomial._log_prob`'s docstring via
# `distribution_util.AppendDocstring` (see its use below). The text is part of
# the generated API docs, so it is kept verbatim.
_multinomial_sample_note = """For each batch of counts, `value = [n_0, ...
,n_{k-1}]`, `P[value]` is the probability that after sampling `self.total_count`
draws from this Multinomial distribution, the number of draws falling in class
`j` is `n_j`. Since this definition is [exchangeable](
https://en.wikipedia.org/wiki/Exchangeable_random_variables); different
sequences have the same counts so the probability includes a combinatorial
coefficient.

Note: `value` must be a non-negative tensor with dtype `self.dtype`, have no
fractional components, and such that
`tf.reduce_sum(value, -1) = self.total_count`. Its shape must be broadcastable
with `self.probs` and `self.total_count`."""
53
54
@tf_export(v1=["distributions.Multinomial"])
class Multinomial(distribution.Distribution):
  """Multinomial distribution.

  This Multinomial distribution is parameterized by `probs`, a (batch of)
  length-`K` `prob` (probability) vectors (`K > 1`) such that
  `tf.reduce_sum(probs, -1) = 1`, and a `total_count` number of trials, i.e.,
  the number of trials per draw from the Multinomial. It is defined over a
  (batch of) length-`K` vector `counts` such that
  `tf.reduce_sum(counts, -1) = total_count`. The Multinomial is identically the
  Binomial distribution when `K = 2`.

  #### Mathematical Details

  The Multinomial is a distribution over `K`-class counts, i.e., a length-`K`
  vector of non-negative integer `counts = n = [n_0, ..., n_{K-1}]`.

  The probability mass function (pmf) is,

  ```none
  pmf(n; pi, N) = prod_j (pi_j)**n_j / Z
  Z = (prod_j n_j!) / N!
  ```

  where:
  * `probs = pi = [pi_0, ..., pi_{K-1}]`, `pi_j > 0`, `sum_j pi_j = 1`,
  * `total_count = N`, `N` a positive integer,
  * `Z` is the normalization constant, and,
  * `N!` denotes `N` factorial.

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  #### Pitfalls

  The number of classes, `K`, must not exceed:
  - the largest integer representable by `self.dtype`, i.e.,
    `2**(mantissa_bits+1)` (IEEE 754),
  - the maximum `Tensor` index, i.e., `2**31-1`.

  In other words,

  ```python
  K <= min(2**31-1, {
    tf.float16: 2**11,
    tf.float32: 2**24,
    tf.float64: 2**53 }[param.dtype])
  ```

  Note: This condition is validated only when `self.validate_args = True`.

  #### Examples

  Create a 3-class distribution, with the 3rd class most likely to be drawn,
  using logits.

  ```python
  logits = [-50., -43, 0]
  dist = Multinomial(total_count=4., logits=logits)
  ```

  Create a 3-class distribution, with the 3rd class most likely to be drawn.

  ```python
  p = [.2, .3, .5]
  dist = Multinomial(total_count=4., probs=p)
  ```

  The distribution functions can be evaluated on counts.

  ```python
  # counts same shape as p.
  counts = [1., 0, 3]
  dist.prob(counts)  # Shape []

  # p will be broadcast to [[.2, .3, .5], [.2, .3, .5]] to match counts.
  counts = [[1., 2, 1], [2, 2, 0]]
  dist.prob(counts)  # Shape [2]

  # p will be broadcast to shape [5, 7, 3] to match counts.
  counts = [[...]]  # Shape [5, 7, 3]
  dist.prob(counts)  # Shape [5, 7]
  ```

  Create a 2-batch of 3-class distributions.

  ```python
  p = [[.1, .2, .7], [.3, .3, .4]]  # Shape [2, 3]
  dist = Multinomial(total_count=[4., 5], probs=p)

  counts = [[2., 1, 1], [3, 1, 1]]
  dist.prob(counts)  # Shape [2]

  dist.sample(5) # Shape [5, 2, 3]
  ```
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               total_count,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Multinomial"):
    """Initialize a batch of Multinomial distributions.

    Args:
      total_count: Non-negative floating point tensor with shape broadcastable
        to `[N1,..., Nm]` with `m >= 0`. Defines this as a batch of
        `N1 x ... x Nm` different Multinomial distributions. Its components
        should be equal to integer values.
      logits: Floating point tensor representing unnormalized log-probabilities
        of a positive event with shape broadcastable to
        `[N1,..., Nm, K]` `m >= 0`, and the same dtype as `total_count`. Defines
        this as a batch of `N1 x ... x Nm` different `K` class Multinomial
        distributions. Only one of `logits` or `probs` should be passed in.
      probs: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm, K]` `m >= 0` and same dtype as `total_count`. Defines
        this as a batch of `N1 x ... x Nm` different `K` class Multinomial
        distributions. `probs`'s components in the last portion of its shape
        should sum to `1`. Only one of `logits` or `probs` should be passed in.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[total_count, logits, probs]) as name:
      self._total_count = ops.convert_to_tensor(total_count, name="total_count")
      if validate_args:
        # Embeds assertions that `total_count` is non-negative and has no
        # fractional components (it is float-dtyped but integer-valued).
        self._total_count = (
            distribution_util.embed_check_nonnegative_integer_form(
                self._total_count))
      # Exactly one of `logits`/`probs` may be given; the helper derives the
      # other so both representations are always available.
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits,
          probs=probs,
          multidimensional=True,
          validate_args=validate_args,
          name=name)
      # Cache `N * p` (shape: batch_shape + [K]); reused by `_mean`,
      # `_variance`, `_covariance`, and the shape helpers below.
      self._mean_val = self._total_count[..., array_ops.newaxis] * self._probs
    super(Multinomial, self).__init__(
        dtype=self._probs.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._total_count,
                       self._logits,
                       self._probs],
        name=name)

  @property
  def total_count(self):
    """Number of trials used to construct a sample."""
    return self._total_count

  @property
  def logits(self):
    """Vector of coordinatewise logits."""
    return self._logits

  @property
  def probs(self):
    """Probability of drawing a `1` in that coordinate."""
    return self._probs

  def _batch_shape_tensor(self):
    # `_mean_val` has shape batch_shape + [K]; drop the trailing event dim.
    return array_ops.shape(self._mean_val)[:-1]

  def _batch_shape(self):
    # Static counterpart of `_batch_shape_tensor`.
    return self._mean_val.get_shape().with_rank_at_least(1)[:-1]

  def _event_shape_tensor(self):
    # Event shape is `[K]`, the last dim of `_mean_val`.
    return array_ops.shape(self._mean_val)[-1:]

  def _event_shape(self):
    # Static counterpart of `_event_shape_tensor`.
    return self._mean_val.get_shape().with_rank_at_least(1)[-1:]

  def _sample_n(self, n, seed=None):
    """Draws `n` samples by summing one-hot categorical draws per batch member."""
    n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
    k = self.event_shape_tensor()[0]

    # Broadcast `total_count` and `logits` against each other (multiplying by
    # ones of the complementary shape) so they share a common batch shape.
    n_draws = array_ops.ones_like(
        self.logits[..., 0], dtype=n_draws.dtype) * n_draws
    logits = array_ops.ones_like(
        n_draws[..., array_ops.newaxis], dtype=self.logits.dtype) * self.logits

    # Flatten all batch dims so each batch member can be sampled independently
    # via map_fn. Each member needs `n * n_draws` categorical draws in total.
    flat_logits = array_ops.reshape(logits, [-1, k])  # [B1B2...Bm, k]
    flat_ndraws = n * array_ops.reshape(n_draws, [-1])  # [B1B2...Bm]

    # Sample one batch member: draw `n * n_draw` class indices, split them
    # into `n` groups of `n_draw`, and count occurrences per class via
    # one-hot + sum. NOTE(review): the same `seed` is passed to every
    # `multinomial` op created inside map_fn; in graph mode each op instance
    # still gets a distinct op-level seed — presumably intentional, verify
    # before reusing this pattern elsewhere.
    def _sample_single(args):
      logits, n_draw = args[0], args[1]  # [K], []
      x = random_ops.multinomial(logits[array_ops.newaxis, ...], n_draw,
                                 seed)  # [1, n*n_draw]
      x = array_ops.reshape(x, shape=[n, -1])  # [n, n_draw]
      x = math_ops.reduce_sum(array_ops.one_hot(x, depth=k), axis=-2)  # [n, k]
      return x

    x = map_fn.map_fn(
        _sample_single, [flat_logits, flat_ndraws],
        dtype=self.dtype)  # [B1B2...Bm, n, k]

    # Move the sample dim to the front and restore the original batch dims.
    x = array_ops.transpose(x, perm=[1, 0, 2])
    final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
    x = array_ops.reshape(x, final_shape)  # [n, B1, B2,..., Bm, k]
    return x

  @distribution_util.AppendDocstring(_multinomial_sample_note)
  def _log_prob(self, counts):
    # log pmf = log(unnormalized pmf) - log(Z); see the class docstring.
    return self._log_unnormalized_prob(counts) - self._log_normalization(counts)

  def _log_unnormalized_prob(self, counts):
    # sum_j n_j * log(pi_j), computed from logits via log_softmax for
    # numerical stability.
    counts = self._maybe_assert_valid_sample(counts)
    return math_ops.reduce_sum(counts * nn_ops.log_softmax(self.logits), -1)

  def _log_normalization(self, counts):
    # log(Z) = -log(N! / prod_j n_j!) = -log multinomial coefficient.
    counts = self._maybe_assert_valid_sample(counts)
    return -distribution_util.log_combinations(self.total_count, counts)

  def _mean(self):
    # E[n_j] = N * pi_j, precomputed in __init__.
    return array_ops.identity(self._mean_val)

  def _covariance(self):
    # Broadcast probs against total_count's batch shape.
    p = self.probs * array_ops.ones_like(
        self.total_count)[..., array_ops.newaxis]
    # Cov[n_i, n_j] = -N * pi_i * pi_j for i != j; diagonal is the variance.
    return array_ops.matrix_set_diag(
        -math_ops.matmul(self._mean_val[..., array_ops.newaxis],
                         p[..., array_ops.newaxis, :]),  # outer product
        self._variance())

  def _variance(self):
    # Broadcast probs against total_count's batch shape.
    p = self.probs * array_ops.ones_like(
        self.total_count)[..., array_ops.newaxis]
    # Var[n_j] = N * pi_j * (1 - pi_j) = mean - mean * pi_j.
    return self._mean_val - self._mean_val * p

  def _maybe_assert_valid_sample(self, counts):
    """Check counts for proper shape, values, then return tensor version."""
    if not self.validate_args:
      return counts
    # Counts must be non-negative and integer-valued...
    counts = distribution_util.embed_check_nonnegative_integer_form(counts)
    # ...and must sum to `total_count` along the class dimension.
    return control_flow_ops.with_dependencies([
        check_ops.assert_equal(
            self.total_count, math_ops.reduce_sum(counts, -1),
            message="counts must sum to `self.total_count`"),
    ], counts)
317