• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Categorical distribution."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from absl.testing import parameterized
22import numpy as np
23
24from tensorflow.python.eager import backprop
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import tensor_util
28from tensorflow.python.framework import test_util
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import gradients_impl
31from tensorflow.python.ops import math_ops
32from tensorflow.python.ops import nn_ops
33from tensorflow.python.ops import random_ops
34from tensorflow.python.ops.distributions import categorical
35from tensorflow.python.ops.distributions import kullback_leibler
36from tensorflow.python.ops.distributions import normal
37from tensorflow.python.platform import test
38
39
def make_categorical(batch_shape, num_classes, dtype=dtypes.int32):
  """Returns a `Categorical` over `num_classes` classes with random logits.

  Args:
    batch_shape: Iterable of ints giving the batch shape of the distribution.
    num_classes: Number of classes; appended as the last logits dimension.
      May be a Python int or a scalar tensor.
    dtype: dtype of the samples produced by the returned distribution.

  Returns:
    A `categorical.Categorical` whose float32 logits are drawn uniformly from
    [-10, 10) and then shifted by -50.
  """
  logits_shape = list(batch_shape) + [num_classes]
  uniform_logits = random_ops.random_uniform(
      logits_shape, -10, 10, dtype=dtypes.float32)
  # The -50 shift keeps the logits well away from zero, so callers only see
  # correct probabilities if the distribution normalizes them.
  shifted_logits = uniform_logits - 50.
  return categorical.Categorical(shifted_logits, dtype=dtype)
44
45
class CategoricalTest(test.TestCase, parameterized.TestCase):
  """Tests for `categorical.Categorical`.

  Covers construction from probs/logits, static and dynamic shapes, dtypes,
  prob/log_prob/cdf (including broadcasting of events against batched
  parameters), entropy and its gradient, sampling, mode, and the
  Categorical-vs-Categorical KL divergence.
  """

  @test_util.run_deprecated_v1
  def testP(self):
    """`probs` round-trips and `logits` gets the matching static shape."""
    p = [0.2, 0.8]
    dist = categorical.Categorical(probs=p)
    with self.cached_session():
      self.assertAllClose(p, dist.probs.eval())
      self.assertAllEqual([2], dist.logits.get_shape())

  @test_util.run_deprecated_v1
  def testLogits(self):
    """`logits` round-trips and `probs` is their normalized softmax.

    The logits are shifted by -50, so `probs` only equals `p` if the
    distribution normalizes rather than simply exponentiating.
    """
    p = np.array([0.2, 0.8], dtype=np.float32)
    logits = np.log(p) - 50.
    dist = categorical.Categorical(logits=logits)
    with self.cached_session():
      self.assertAllEqual([2], dist.probs.get_shape())
      self.assertAllEqual([2], dist.logits.get_shape())
      self.assertAllClose(dist.probs.eval(), p)
      self.assertAllClose(dist.logits.eval(), logits)

  @test_util.run_deprecated_v1
  def testShapes(self):
    """Batch/event shapes with a Python-int and a tensor class count."""
    with self.cached_session():
      for batch_shape in ([], [1], [2, 3, 4]):
        dist = make_categorical(batch_shape, 10)
        self.assertAllEqual(batch_shape, dist.batch_shape)
        self.assertAllEqual(batch_shape, dist.batch_shape_tensor().eval())
        self.assertAllEqual([], dist.event_shape)
        self.assertAllEqual([], dist.event_shape_tensor().eval())
        self.assertEqual(10, dist.event_size.eval())
        # event_size is available as a constant because the shape is
        # known at graph build time.
        self.assertEqual(10, tensor_util.constant_value(dist.event_size))

      for batch_shape in ([], [1], [2, 3, 4]):
        # With a tensor-valued class count only the static batch rank is
        # checked; the concrete sizes are checked by evaluating tensors.
        dist = make_categorical(
            batch_shape, constant_op.constant(10, dtype=dtypes.int32))
        self.assertAllEqual(len(batch_shape), dist.batch_shape.ndims)
        self.assertAllEqual(batch_shape, dist.batch_shape_tensor().eval())
        self.assertAllEqual([], dist.event_shape)
        self.assertAllEqual([], dist.event_shape_tensor().eval())
        self.assertEqual(10, dist.event_size.eval())

  def testDtype(self):
    """Samples and mode follow `dtype`; densities stay in the logits dtype."""
    dist = make_categorical([], 5, dtype=dtypes.int32)
    self.assertEqual(dist.dtype, dtypes.int32)
    self.assertEqual(dist.dtype, dist.sample(5).dtype)
    self.assertEqual(dist.dtype, dist.mode().dtype)
    dist = make_categorical([], 5, dtype=dtypes.int64)
    self.assertEqual(dist.dtype, dtypes.int64)
    self.assertEqual(dist.dtype, dist.sample(5).dtype)
    self.assertEqual(dist.dtype, dist.mode().dtype)
    self.assertEqual(dist.probs.dtype, dtypes.float32)
    self.assertEqual(dist.logits.dtype, dtypes.float32)
    self.assertEqual(dist.logits.dtype, dist.entropy().dtype)
    self.assertEqual(
        dist.logits.dtype, dist.prob(np.array(0, dtype=np.int64)).dtype)
    self.assertEqual(
        dist.logits.dtype, dist.log_prob(np.array(0, dtype=np.int64)).dtype)
    # Floating-point sample dtypes are also accepted.
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      dist = make_categorical([], 5, dtype=dtype)
      self.assertEqual(dist.dtype, dtype)
      self.assertEqual(dist.dtype, dist.sample(5).dtype)

  @test_util.run_deprecated_v1
  def testUnknownShape(self):
    """Sampling works when logits come from a placeholder of unknown shape."""
    with self.cached_session():
      logits = array_ops.placeholder(dtype=dtypes.float32)
      dist = categorical.Categorical(logits)
      sample = dist.sample()
      # Will sample class 1.
      sample_value = sample.eval(feed_dict={logits: [-1000.0, 1000.0]})
      self.assertEqual(1, sample_value)

      # Batch entry 0 will sample class 1, batch entry 1 will sample class 0.
      sample_value_batch = sample.eval(
          feed_dict={logits: [[-1000.0, 1000.0], [1000.0, -1000.0]]})
      self.assertAllEqual([1, 0], sample_value_batch)

  @test_util.run_deprecated_v1
  def testPMFWithBatch(self):
    """prob evaluates one event per batch member."""
    histograms = [[0.2, 0.8], [0.6, 0.4]]
    dist = categorical.Categorical(math_ops.log(histograms) - 50.)
    with self.cached_session():
      self.assertAllClose(dist.prob([0, 1]).eval(), [0.2, 0.4])

  @test_util.run_deprecated_v1
  def testPMFNoBatch(self):
    """prob of a scalar event for an unbatched distribution."""
    histograms = [0.2, 0.8]
    dist = categorical.Categorical(math_ops.log(histograms) - 50.)
    with self.cached_session():
      self.assertAllClose(dist.prob(0).eval(), 0.2)

  @test_util.run_deprecated_v1
  def testCDFWithDynamicEventShapeKnownNdims(self):
    """Test cdf with a known rank but a dynamic number of classes."""
    batch_size = 2
    histograms = array_ops.placeholder(dtype=dtypes.float32,
                                       shape=(batch_size, None))
    event = array_ops.placeholder(dtype=dtypes.float32, shape=(batch_size,))
    dist = categorical.Categorical(probs=histograms)
    cdf_op = dist.cdf(event)

    # Feed histograms with different numbers of classes into the same graph.
    # The expected values sum the probabilities of the classes strictly below
    # each event (so an event of 0 yields 0.0).

    # Three classes.
    event_feed_one = [0, 1]
    histograms_feed_one = [[0.5, 0.3, 0.2], [1.0, 0.0, 0.0]]
    expected_cdf_one = [0.0, 1.0]
    feed_dict_one = {
        histograms: histograms_feed_one,
        event: event_feed_one
    }

    # Six classes.
    event_feed_two = [2, 5]
    histograms_feed_two = [[0.9, 0.0, 0.0, 0.0, 0.0, 0.1],
                           [0.15, 0.2, 0.05, 0.35, 0.13, 0.12]]
    expected_cdf_two = [0.9, 0.88]
    feed_dict_two = {
        histograms: histograms_feed_two,
        event: event_feed_two
    }

    with self.cached_session() as sess:
      actual_cdf_one = sess.run(cdf_op, feed_dict=feed_dict_one)
      actual_cdf_two = sess.run(cdf_op, feed_dict=feed_dict_two)

    self.assertAllClose(actual_cdf_one, expected_cdf_one)
    self.assertAllClose(actual_cdf_two, expected_cdf_two)

  @parameterized.named_parameters(
      ("test1", [0, 1], [[0.5, 0.3, 0.2], [1.0, 0.0, 0.0]], [0.0, 1.0]),
      ("test2", [2, 5], [[0.9, 0.0, 0.0, 0.0, 0.0, 0.1],
                         [0.15, 0.2, 0.05, 0.35, 0.13, 0.12]], [0.9, 0.88]))
  def testCDFWithDynamicEventShapeUnknownNdims(
      self, events, histograms, expected_cdf):
    """Test that dynamically-sized events with unknown shape work."""
    # shape=None hides even the rank of both inputs from graph construction.
    event_ph = array_ops.placeholder_with_default(events, shape=None)
    histograms_ph = array_ops.placeholder_with_default(histograms, shape=None)
    dist = categorical.Categorical(probs=histograms_ph)
    cdf_op = dist.cdf(event_ph)

    actual_cdf = self.evaluate(cdf_op)
    self.assertAllClose(actual_cdf, expected_cdf)

  @test_util.run_deprecated_v1
  def testCDFWithBatch(self):
    """cdf of one event per batch member."""
    histograms = [[0.1, 0.2, 0.3, 0.25, 0.15],
                  [0.0, 0.75, 0.2, 0.05, 0.0]]
    event = [0, 3]
    expected_cdf = [0.0, 0.95]
    dist = categorical.Categorical(probs=histograms)
    cdf_op = dist.cdf(event)

    with self.cached_session():
      self.assertAllClose(cdf_op.eval(), expected_cdf)

  @test_util.run_deprecated_v1
  def testCDFNoBatch(self):
    """cdf of a scalar event for an unbatched distribution."""
    histogram = [0.1, 0.2, 0.3, 0.4]
    event = 2
    expected_cdf = 0.3
    dist = categorical.Categorical(probs=histogram)
    cdf_op = dist.cdf(event)

    with self.cached_session():
      self.assertAlmostEqual(cdf_op.eval(), expected_cdf)

  @test_util.run_deprecated_v1
  def testCDFBroadcasting(self):
    """cdf broadcasts an extra leading event dimension over the batch."""
    # shape: [batch=2, n_bins=3]
    histograms = [[0.2, 0.1, 0.7],
                  [0.3, 0.45, 0.25]]

    # shape: [batch=3, batch=2]
    devent = [
        [0, 0],
        [1, 1],
        [2, 2]
    ]
    dist = categorical.Categorical(probs=histograms)

    # We test that the probabilities are correctly broadcasted over the
    # additional leading batch dimension of size 3.
    expected_cdf_result = np.array([
        [0.0, 0.0],
        [0.2, 0.3],
        [0.3, 0.75],
    ])

    with self.cached_session():
      self.assertAllClose(dist.cdf(devent).eval(), expected_cdf_result)

  def testBroadcastWithBatchParamsAndBiggerEvent(self):
    """Categorical broadcasts events against batch params like Normal does."""
    ## The parameters have a single batch dimension, and the event has two.

    # param shape is [3 x 4], where 4 is the number of bins (non-batch dim).
    cat_params_py = [
        [0.2, 0.15, 0.35, 0.3],
        [0.1, 0.05, 0.68, 0.17],
        [0.1, 0.05, 0.68, 0.17]
    ]

    # event shape = [5, 3], both are "batch" dimensions.
    disc_event_py = [
        [0, 1, 2],
        [1, 2, 3],
        [0, 0, 0],
        [1, 1, 1],
        [2, 1, 0]
    ]

    # shape is [3]
    normal_params_py = [
        -10.0,
        120.0,
        50.0
    ]

    # shape is [5, 3]
    real_event_py = [
        [-1.0, 0.0, 1.0],
        [100.0, 101, -50],
        [90, 90, 90],
        [-4, -400, 20.0],
        [0.0, 0.0, 0.0]
    ]

    cat_params_tf = array_ops.constant(cat_params_py)
    disc_event_tf = array_ops.constant(disc_event_py)
    cat = categorical.Categorical(probs=cat_params_tf)

    normal_params_tf = array_ops.constant(normal_params_py)
    real_event_tf = array_ops.constant(real_event_py)
    norm = normal.Normal(loc=normal_params_tf, scale=1.0)

    # Check that normal and categorical have the same broadcasting behaviour.
    to_run = {
        "cat_prob": cat.prob(disc_event_tf),
        "cat_log_prob": cat.log_prob(disc_event_tf),
        "cat_cdf": cat.cdf(disc_event_tf),
        "cat_log_cdf": cat.log_cdf(disc_event_tf),
        "norm_prob": norm.prob(real_event_tf),
        "norm_log_prob": norm.log_prob(real_event_tf),
        "norm_cdf": norm.cdf(real_event_tf),
        "norm_log_cdf": norm.log_cdf(real_event_tf),
    }

    with self.cached_session():
      run_result = self.evaluate(to_run)

    self.assertAllEqual(run_result["cat_prob"].shape,
                        run_result["norm_prob"].shape)
    self.assertAllEqual(run_result["cat_log_prob"].shape,
                        run_result["norm_log_prob"].shape)
    self.assertAllEqual(run_result["cat_cdf"].shape,
                        run_result["norm_cdf"].shape)
    self.assertAllEqual(run_result["cat_log_cdf"].shape,
                        run_result["norm_log_cdf"].shape)

  @test_util.run_deprecated_v1
  def testLogPMF(self):
    """log_prob accepts both integer and float-valued class indices."""
    logits = np.log([[0.2, 0.8], [0.6, 0.4]]) - 50.
    dist = categorical.Categorical(logits)
    with self.cached_session():
      self.assertAllClose(dist.log_prob([0, 1]).eval(), np.log([0.2, 0.4]))
      self.assertAllClose(dist.log_prob([0.0, 1.0]).eval(), np.log([0.2, 0.4]))

  @test_util.run_deprecated_v1
  def testEntropyNoBatch(self):
    """Entropy of an unbatched distribution matches -sum(p * log p)."""
    logits = np.log([0.2, 0.8]) - 50.
    dist = categorical.Categorical(logits)
    with self.cached_session():
      self.assertAllClose(dist.entropy().eval(),
                          -(0.2 * np.log(0.2) + 0.8 * np.log(0.8)))

  @test_util.run_deprecated_v1
  def testEntropyWithBatch(self):
    """Entropy is computed independently for each batch member."""
    logits = np.log([[0.2, 0.8], [0.6, 0.4]]) - 50.
    dist = categorical.Categorical(logits)
    with self.cached_session():
      self.assertAllClose(dist.entropy().eval(), [
          -(0.2 * np.log(0.2) + 0.8 * np.log(0.8)),
          -(0.6 * np.log(0.6) + 0.4 * np.log(0.4))
      ])

  @test_util.run_deprecated_v1
  def testEntropyGradient(self):
    """Entropy and its gradient match a direct softmax-based computation."""
    with self.cached_session() as sess:
      logits = constant_op.constant([[1., 2., 3.], [2., 5., 1.]])

      probabilities = nn_ops.softmax(logits)
      log_probabilities = nn_ops.log_softmax(logits)
      true_entropy = - math_ops.reduce_sum(
          probabilities * log_probabilities, axis=-1)

      categorical_distribution = categorical.Categorical(probs=probabilities)
      categorical_entropy = categorical_distribution.entropy()

      # Differentiate both entropy expressions w.r.t. the shared logits.
      true_entropy_g = gradients_impl.gradients(true_entropy, [logits])
      categorical_entropy_g = gradients_impl.gradients(
          categorical_entropy, [logits])

      res = sess.run({"true_entropy": true_entropy,
                      "categorical_entropy": categorical_entropy,
                      "true_entropy_g": true_entropy_g,
                      "categorical_entropy_g": categorical_entropy_g})
      self.assertAllClose(res["true_entropy"],
                          res["categorical_entropy"])
      self.assertAllClose(res["true_entropy_g"],
                          res["categorical_entropy_g"])

  def testSample(self):
    """Empirical sample frequencies match the class probabilities."""
    with self.cached_session():
      histograms = [[[0.2, 0.8], [0.4, 0.6]]]
      dist = categorical.Categorical(math_ops.log(histograms) - 50.)
      n = 10000
      samples = dist.sample(n, seed=123)
      samples.set_shape([n, 1, 2])
      self.assertEqual(samples.dtype, dtypes.int32)
      sample_values = self.evaluate(samples)
      self.assertFalse(np.any(sample_values < 0))
      self.assertFalse(np.any(sample_values > 1))
      self.assertAllClose(
          [[0.2, 0.4]], np.mean(sample_values == 0, axis=0), atol=1e-2)
      self.assertAllClose(
          [[0.8, 0.6]], np.mean(sample_values == 1, axis=0), atol=1e-2)

  def testSampleWithSampleShape(self):
    """Multi-dimensional sample shapes broadcast correctly through prob.

    For each batch member the mean of prob(X) over samples should approach
    sum_i p_i**2, which is what the expected values below encode.
    """
    with self.cached_session():
      histograms = [[[0.2, 0.8], [0.4, 0.6]]]
      dist = categorical.Categorical(math_ops.log(histograms) - 50.)
      samples = dist.sample((100, 100), seed=123)
      prob = dist.prob(samples)
      prob_val = self.evaluate(prob)
      self.assertAllClose(
          [0.2**2 + 0.8**2], [prob_val[:, :, :, 0].mean()], atol=1e-2)
      self.assertAllClose(
          [0.4**2 + 0.6**2], [prob_val[:, :, :, 1].mean()], atol=1e-2)

  def testNotReparameterized(self):
    """Samples carry no gradient back to the distribution parameters."""
    p = constant_op.constant([0.3, 0.3, 0.4])
    with backprop.GradientTape() as tape:
      tape.watch(p)
      # `p` holds probabilities, so pass it as `probs`; a positional argument
      # would be interpreted as logits.
      dist = categorical.Categorical(probs=p)
      samples = dist.sample(100)
    grad_p = tape.gradient(samples, p)
    self.assertIsNone(grad_p)

  def testLogPMFBroadcasting(self):
    """prob broadcasts scalar and higher-rank events against batch [1, 2]."""
    with self.cached_session():
      # 1 x 2 x 2
      histograms = [[[0.2, 0.8], [0.4, 0.6]]]
      dist = categorical.Categorical(math_ops.log(histograms) - 50.)

      prob = dist.prob(1)
      self.assertAllClose([[0.8, 0.6]], self.evaluate(prob))

      prob = dist.prob([1])
      self.assertAllClose([[0.8, 0.6]], self.evaluate(prob))

      prob = dist.prob([0, 1])
      self.assertAllClose([[0.2, 0.6]], self.evaluate(prob))

      prob = dist.prob([[0, 1]])
      self.assertAllClose([[0.2, 0.6]], self.evaluate(prob))

      prob = dist.prob([[[0, 1]]])
      self.assertAllClose([[[0.2, 0.6]]], self.evaluate(prob))

      prob = dist.prob([[1, 0], [0, 1]])
      self.assertAllClose([[0.8, 0.4], [0.2, 0.6]], self.evaluate(prob))

      prob = dist.prob([[[1, 1], [1, 0]], [[1, 0], [0, 1]]])
      self.assertAllClose([[[0.8, 0.6], [0.8, 0.4]], [[0.8, 0.4], [0.2, 0.6]]],
                          self.evaluate(prob))

  def testLogPMFShape(self):
    """log_prob has the statically inferred broadcast shape."""
    with self.cached_session():
      # shape [1, 2, 2]
      histograms = [[[0.2, 0.8], [0.4, 0.6]]]
      dist = categorical.Categorical(math_ops.log(histograms))

      log_prob = dist.log_prob([0, 1])
      self.assertEqual(2, log_prob.get_shape().ndims)
      self.assertAllEqual([1, 2], log_prob.get_shape())

      log_prob = dist.log_prob([[[1, 1], [1, 0]], [[1, 0], [0, 1]]])
      self.assertEqual(3, log_prob.get_shape().ndims)
      self.assertAllEqual([2, 2, 2], log_prob.get_shape())

  def testLogPMFShapeNoBatch(self):
    """log_prob static shape follows the event for an unbatched dist."""
    histograms = [0.2, 0.8]
    dist = categorical.Categorical(math_ops.log(histograms))

    log_prob = dist.log_prob(0)
    self.assertEqual(0, log_prob.get_shape().ndims)
    self.assertAllEqual([], log_prob.get_shape())

    log_prob = dist.log_prob([[[1, 1], [1, 0]], [[1, 0], [0, 1]]])
    self.assertEqual(3, log_prob.get_shape().ndims)
    self.assertAllEqual([2, 2, 2], log_prob.get_shape())

  @test_util.run_deprecated_v1
  def testMode(self):
    """Mode picks the highest-probability class for each batch member."""
    with self.cached_session():
      histograms = [[[0.2, 0.8], [0.6, 0.4]]]
      dist = categorical.Categorical(math_ops.log(histograms) - 50.)
      self.assertAllEqual(dist.mode().eval(), [[1, 0]])

  @test_util.run_deprecated_v1
  def testCategoricalCategoricalKL(self):
    """KL(a || b) matches a NumPy reference and KL(a || a) is zero."""

    def np_softmax(logits):
      # Subtract the per-row max before exponentiating to avoid overflow.
      exp_logits = np.exp(logits - logits.max(axis=-1, keepdims=True))
      return exp_logits / exp_logits.sum(axis=-1, keepdims=True)

    with self.cached_session():
      for categories in [2, 4]:
        for batch_size in [1, 10]:
          a_logits = np.random.randn(batch_size, categories)
          b_logits = np.random.randn(batch_size, categories)

          a = categorical.Categorical(logits=a_logits)
          b = categorical.Categorical(logits=b_logits)

          kl = kullback_leibler.kl_divergence(a, b)
          kl_val = self.evaluate(kl)
          # Make sure KL(a||a) is 0.
          kl_same = self.evaluate(kullback_leibler.kl_divergence(a, a))

          prob_a = np_softmax(a_logits)
          prob_b = np_softmax(b_logits)
          kl_expected = np.sum(prob_a * (np.log(prob_a) - np.log(prob_b)),
                               axis=-1)

          self.assertEqual(kl.get_shape(), (batch_size,))
          self.assertAllClose(kl_val, kl_expected)
          self.assertAllClose(kl_same, np.zeros_like(kl_expected))
495
if __name__ == "__main__":
  # When this module is executed directly, hand off to `test.main()` from
  # tensorflow.python.platform.
  test.main()
498