# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Categorical distribution."""

from absl.testing import parameterized
import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import categorical
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal
from tensorflow.python.platform import test


def make_categorical(batch_shape, num_classes, dtype=dtypes.int32):
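  """Builds a Categorical with random logits.

  The logits have shape `batch_shape + [num_classes]` and are shifted by
  -50.; the constant shift leaves the implied probabilities unchanged, since
  softmax is invariant to adding a constant to every logit.
  """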
  logits = random_ops.random_uniform(
      list(batch_shape) + [num_classes], -10, 10, dtype=dtypes.float32) - 50.
  return categorical.Categorical(logits, dtype=dtype)


class CategoricalTest(test.TestCase, parameterized.TestCase):

  @test_util.run_deprecated_v1
  def testP(self):
    p = [0.2, 0.8]
    dist = categorical.Categorical(probs=p)
    with self.cached_session():
      self.assertAllClose(p, dist.probs)
      self.assertAllEqual([2], dist.logits.get_shape())

  @test_util.run_deprecated_v1
  def testLogits(self):
    p = np.array([0.2, 0.8], dtype=np.float32)
    logits = np.log(p) - 50.
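    # softmax(log(p) - 50.) == p, so `probs` should recover p (up to float
    # error) even though the logits are shifted and unnormalized.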
    dist = categorical.Categorical(logits=logits)
    with self.cached_session():
      self.assertAllEqual([2], dist.probs.get_shape())
      self.assertAllEqual([2], dist.logits.get_shape())
      self.assertAllClose(dist.probs, p)
      self.assertAllClose(dist.logits, logits)

  @test_util.run_deprecated_v1
  def testShapes(self):
    with self.cached_session():
      for batch_shape in ([], [1], [2, 3, 4]):
        dist = make_categorical(batch_shape, 10)
        self.assertAllEqual(batch_shape, dist.batch_shape)
        self.assertAllEqual(batch_shape, dist.batch_shape_tensor())
        self.assertAllEqual([], dist.event_shape)
        self.assertAllEqual([], dist.event_shape_tensor())
        self.assertEqual(10, dist.event_size.eval())
        # event_size is available as a constant because the shape is
        # known at graph build time.
        self.assertEqual(10, tensor_util.constant_value(dist.event_size))

      for batch_shape in ([], [1], [2, 3, 4]):
        dist = make_categorical(
            batch_shape, constant_op.constant(
                10, dtype=dtypes.int32))
        self.assertAllEqual(len(batch_shape), dist.batch_shape.ndims)
        self.assertAllEqual(batch_shape, dist.batch_shape_tensor())
        self.assertAllEqual([], dist.event_shape)
        self.assertAllEqual([], dist.event_shape_tensor())
        self.assertEqual(10, dist.event_size.eval())

  def testDtype(self):
    dist = make_categorical([], 5, dtype=dtypes.int32)
    self.assertEqual(dist.dtype, dtypes.int32)
    self.assertEqual(dist.dtype, dist.sample(5).dtype)
    self.assertEqual(dist.dtype, dist.mode().dtype)
    dist = make_categorical([], 5, dtype=dtypes.int64)
    self.assertEqual(dist.dtype, dtypes.int64)
    self.assertEqual(dist.dtype, dist.sample(5).dtype)
    self.assertEqual(dist.dtype, dist.mode().dtype)
    self.assertEqual(dist.probs.dtype, dtypes.float32)
    self.assertEqual(dist.logits.dtype, dtypes.float32)
    self.assertEqual(dist.logits.dtype, dist.entropy().dtype)
    self.assertEqual(
        dist.logits.dtype, dist.prob(np.array(
            0, dtype=np.int64)).dtype)
    self.assertEqual(
        dist.logits.dtype, dist.log_prob(np.array(
            0, dtype=np.int64)).dtype)
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      dist = make_categorical([], 5, dtype=dtype)
      self.assertEqual(dist.dtype, dtype)
      self.assertEqual(dist.dtype, dist.sample(5).dtype)

  @test_util.run_deprecated_v1
  def testUnknownShape(self):
    with self.cached_session():
      logits = array_ops.placeholder(dtype=dtypes.float32)
      dist = categorical.Categorical(logits)
      sample = dist.sample()
      # Will sample class 1.
      sample_value = sample.eval(feed_dict={logits: [-1000.0, 1000.0]})
      self.assertEqual(1, sample_value)

      # Batch entry 0 will sample class 1, batch entry 1 will sample class 0.
      sample_value_batch = sample.eval(
          feed_dict={logits: [[-1000.0, 1000.0], [1000.0, -1000.0]]})
      self.assertAllEqual([1, 0], sample_value_batch)

  @test_util.run_deprecated_v1
  def testPMFWithBatch(self):
    histograms = [[0.2, 0.8], [0.6, 0.4]]
    dist = categorical.Categorical(math_ops.log(histograms) - 50.)
    with self.cached_session():
      self.assertAllClose(dist.prob([0, 1]), [0.2, 0.4])

  @test_util.run_deprecated_v1
  def testPMFNoBatch(self):
    histograms = [0.2, 0.8]
    dist = categorical.Categorical(math_ops.log(histograms) - 50.)
    with self.cached_session():
      self.assertAllClose(dist.prob(0), 0.2)

  @test_util.run_deprecated_v1
  def testCDFWithDynamicEventShapeKnownNdims(self):
    """Test that dynamically-sized events with unknown shape work."""
    batch_size = 2
    histograms = array_ops.placeholder(dtype=dtypes.float32,
                                       shape=(batch_size, None))
    event = array_ops.placeholder(dtype=dtypes.float32, shape=(batch_size,))
    dist = categorical.Categorical(probs=histograms)
    cdf_op = dist.cdf(event)

    # Feed values into the placeholders with different shapes. First, three
    # classes:
    event_feed_one = [0, 1]
    histograms_feed_one = [[0.5, 0.3, 0.2], [1.0, 0.0, 0.0]]
    expected_cdf_one = [0.0, 1.0]
    feed_dict_one = {
        histograms: histograms_feed_one,
        event: event_feed_one
    }

    # Then six classes:
    event_feed_two = [2, 5]
    histograms_feed_two = [[0.9, 0.0, 0.0, 0.0, 0.0, 0.1],
                           [0.15, 0.2, 0.05, 0.35, 0.13, 0.12]]
    expected_cdf_two = [0.9, 0.88]
    feed_dict_two = {
        histograms: histograms_feed_two,
        event: event_feed_two
    }

    with self.cached_session() as sess:
      actual_cdf_one = sess.run(cdf_op, feed_dict=feed_dict_one)
      actual_cdf_two = sess.run(cdf_op, feed_dict=feed_dict_two)

    self.assertAllClose(actual_cdf_one, expected_cdf_one)
    self.assertAllClose(actual_cdf_two, expected_cdf_two)

  @parameterized.named_parameters(
      ("test1", [0, 1], [[0.5, 0.3, 0.2], [1.0, 0.0, 0.0]], [0.0, 1.0]),
      ("test2", [2, 5], [[0.9, 0.0, 0.0, 0.0, 0.0, 0.1],
                         [0.15, 0.2, 0.05, 0.35, 0.13, 0.12]], [0.9, 0.88]))
  def testCDFWithDynamicEventShapeUnknownNdims(
      self, events, histograms, expected_cdf):
    """Test that dynamically-sized events with unknown shape work."""
    event_ph = array_ops.placeholder_with_default(events, shape=None)
    histograms_ph = array_ops.placeholder_with_default(histograms, shape=None)
    dist = categorical.Categorical(probs=histograms_ph)
    cdf_op = dist.cdf(event_ph)

    actual_cdf = self.evaluate(cdf_op)
    self.assertAllClose(actual_cdf, expected_cdf)

  @test_util.run_deprecated_v1
  def testCDFWithBatch(self):
    histograms = [[0.1, 0.2, 0.3, 0.25, 0.15],
                  [0.0, 0.75, 0.2, 0.05, 0.0]]
    event = [0, 3]
    expected_cdf = [0.0, 0.95]
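    # The expected values follow the strict-inequality convention
    # cdf(k) = P(X < k); e.g. for event 3 in the second row,
    # 0.0 + 0.75 + 0.2 = 0.95, which excludes class 3's own mass of 0.05.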
    dist = categorical.Categorical(probs=histograms)
    cdf_op = dist.cdf(event)

    with self.cached_session():
      self.assertAllClose(cdf_op, expected_cdf)

  @test_util.run_deprecated_v1
  def testCDFNoBatch(self):
    histogram = [0.1, 0.2, 0.3, 0.4]
    event = 2
    expected_cdf = 0.3
    dist = categorical.Categorical(probs=histogram)
    cdf_op = dist.cdf(event)

    with self.cached_session():
      self.assertAlmostEqual(cdf_op.eval(), expected_cdf)

  @test_util.run_deprecated_v1
  def testCDFBroadcasting(self):
    # shape: [batch=2, n_bins=3]
    histograms = [[0.2, 0.1, 0.7],
                  [0.3, 0.45, 0.25]]

    # shape: [batch=3, batch=2]
    devent = [
        [0, 0],
        [1, 1],
        [2, 2]
    ]
    dist = categorical.Categorical(probs=histograms)

    # We test that the probabilities are correctly broadcasted over the
    # additional leading batch dimension of size 3.
    expected_cdf_result = np.zeros((3, 2))
    expected_cdf_result[0, 0] = 0
    expected_cdf_result[0, 1] = 0
    expected_cdf_result[1, 0] = 0.2
    expected_cdf_result[1, 1] = 0.3
    expected_cdf_result[2, 0] = 0.3
    expected_cdf_result[2, 1] = 0.75

    with self.cached_session():
      self.assertAllClose(dist.cdf(devent), expected_cdf_result)

  def testBroadcastWithBatchParamsAndBiggerEvent(self):
    ## The parameters have a single batch dimension, and the event has two.

    # param shape is [3 x 4], where 4 is the number of bins (non-batch dim).
    cat_params_py = [
        [0.2, 0.15, 0.35, 0.3],
        [0.1, 0.05, 0.68, 0.17],
        [0.1, 0.05, 0.68, 0.17]
    ]

    # event shape = [5, 3], both are "batch" dimensions.
    disc_event_py = [
        [0, 1, 2],
        [1, 2, 3],
        [0, 0, 0],
        [1, 1, 1],
        [2, 1, 0]
    ]

    # shape is [3]
    normal_params_py = [
        -10.0,
        120.0,
        50.0
    ]

    # shape is [5, 3]
    real_event_py = [
        [-1.0, 0.0, 1.0],
        [100.0, 101, -50],
        [90, 90, 90],
        [-4, -400, 20.0],
        [0.0, 0.0, 0.0]
    ]

    cat_params_tf = array_ops.constant(cat_params_py)
    disc_event_tf = array_ops.constant(disc_event_py)
    cat = categorical.Categorical(probs=cat_params_tf)

    normal_params_tf = array_ops.constant(normal_params_py)
    real_event_tf = array_ops.constant(real_event_py)
    norm = normal.Normal(loc=normal_params_tf, scale=1.0)

    # Check that normal and categorical have the same broadcasting behaviour.
    to_run = {
        "cat_prob": cat.prob(disc_event_tf),
        "cat_log_prob": cat.log_prob(disc_event_tf),
        "cat_cdf": cat.cdf(disc_event_tf),
        "cat_log_cdf": cat.log_cdf(disc_event_tf),
        "norm_prob": norm.prob(real_event_tf),
        "norm_log_prob": norm.log_prob(real_event_tf),
        "norm_cdf": norm.cdf(real_event_tf),
        "norm_log_cdf": norm.log_cdf(real_event_tf),
    }

    with self.cached_session() as sess:
      run_result = self.evaluate(to_run)

    self.assertAllEqual(run_result["cat_prob"].shape,
                        run_result["norm_prob"].shape)
    self.assertAllEqual(run_result["cat_log_prob"].shape,
                        run_result["norm_log_prob"].shape)
    self.assertAllEqual(run_result["cat_cdf"].shape,
                        run_result["norm_cdf"].shape)
    self.assertAllEqual(run_result["cat_log_cdf"].shape,
                        run_result["norm_log_cdf"].shape)

  @test_util.run_deprecated_v1
  def testLogPMF(self):
    logits = np.log([[0.2, 0.8], [0.6, 0.4]]) - 50.
    dist = categorical.Categorical(logits)
    with self.cached_session():
      self.assertAllClose(dist.log_prob([0, 1]), np.log([0.2, 0.4]))
      self.assertAllClose(dist.log_prob([0.0, 1.0]), np.log([0.2, 0.4]))

  @test_util.run_deprecated_v1
  def testEntropyNoBatch(self):
    logits = np.log([0.2, 0.8]) - 50.
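    # Entropy of a categorical distribution: H = -sum_k p_k * log(p_k).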
    dist = categorical.Categorical(logits)
    with self.cached_session():
      self.assertAllClose(dist.entropy(),
                          -(0.2 * np.log(0.2) + 0.8 * np.log(0.8)))

  @test_util.run_deprecated_v1
  def testEntropyWithBatch(self):
    logits = np.log([[0.2, 0.8], [0.6, 0.4]]) - 50.
    dist = categorical.Categorical(logits)
    with self.cached_session():
      self.assertAllClose(dist.entropy(), [
          -(0.2 * np.log(0.2) + 0.8 * np.log(0.8)),
          -(0.6 * np.log(0.6) + 0.4 * np.log(0.4))
      ])

  @test_util.run_deprecated_v1
  def testEntropyGradient(self):
    with self.cached_session() as sess:
      logits = constant_op.constant([[1., 2., 3.], [2., 5., 1.]])

      probabilities = nn_ops.softmax(logits)
      log_probabilities = nn_ops.log_softmax(logits)
      true_entropy = - math_ops.reduce_sum(
          probabilities * log_probabilities, axis=-1)

      categorical_distribution = categorical.Categorical(probs=probabilities)
      categorical_entropy = categorical_distribution.entropy()

      # Both the hand-computed entropy and the distribution's entropy should
      # be differentiable with respect to the logits, with matching gradients.
      true_entropy_g = gradients_impl.gradients(true_entropy, [logits])
      categorical_entropy_g = gradients_impl.gradients(
          categorical_entropy, [logits])

      res = sess.run({"true_entropy": true_entropy,
                      "categorical_entropy": categorical_entropy,
                      "true_entropy_g": true_entropy_g,
                      "categorical_entropy_g": categorical_entropy_g})
      self.assertAllClose(res["true_entropy"],
                          res["categorical_entropy"])
      self.assertAllClose(res["true_entropy_g"],
                          res["categorical_entropy_g"])

  def testSample(self):
    with self.cached_session():
      histograms = [[[0.2, 0.8], [0.4, 0.6]]]
      dist = categorical.Categorical(math_ops.log(histograms) - 50.)
      n = 10000
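      # With 10000 draws, the empirical class frequencies should match the
      # true probabilities to within the atol=1e-2 used below.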
      samples = dist.sample(n, seed=123)
      samples.set_shape([n, 1, 2])
      self.assertEqual(samples.dtype, dtypes.int32)
      sample_values = self.evaluate(samples)
      self.assertFalse(np.any(sample_values < 0))
      self.assertFalse(np.any(sample_values > 1))
      self.assertAllClose(
          [[0.2, 0.4]], np.mean(
              sample_values == 0, axis=0), atol=1e-2)
      self.assertAllClose(
          [[0.8, 0.6]], np.mean(
              sample_values == 1, axis=0), atol=1e-2)

  def testSampleWithSampleShape(self):
    with self.cached_session():
      histograms = [[[0.2, 0.8], [0.4, 0.6]]]
      dist = categorical.Categorical(math_ops.log(histograms) - 50.)
      samples = dist.sample((100, 100), seed=123)
      prob = dist.prob(samples)
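      # For X ~ Categorical(p), E[prob(X)] = sum_k p_k**2, e.g.
      # 0.2**2 + 0.8**2 for the first batch entry.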
      prob_val = self.evaluate(prob)
      self.assertAllClose(
          [0.2**2 + 0.8**2], [prob_val[:, :, :, 0].mean()], atol=1e-2)
      self.assertAllClose(
          [0.4**2 + 0.6**2], [prob_val[:, :, :, 1].mean()], atol=1e-2)

  def testNotReparameterized(self):
    p = constant_op.constant([0.3, 0.3, 0.4])
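    # Categorical sampling is not reparameterized, so gradients of the
    # (integer-valued) samples with respect to the probabilities are None.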
    with backprop.GradientTape() as tape:
      tape.watch(p)
      dist = categorical.Categorical(p)
      samples = dist.sample(100)
    grad_p = tape.gradient(samples, p)
    self.assertIsNone(grad_p)

  def testLogPMFBroadcasting(self):
    with self.cached_session():
      # 1 x 2 x 2
      histograms = [[[0.2, 0.8], [0.4, 0.6]]]
      dist = categorical.Categorical(math_ops.log(histograms) - 50.)
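      # The distribution's batch shape is [1, 2] with 2 classes; the event
      # values below broadcast against that batch shape.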

      prob = dist.prob(1)
      self.assertAllClose([[0.8, 0.6]], self.evaluate(prob))

      prob = dist.prob([1])
      self.assertAllClose([[0.8, 0.6]], self.evaluate(prob))

      prob = dist.prob([0, 1])
      self.assertAllClose([[0.2, 0.6]], self.evaluate(prob))

      prob = dist.prob([[0, 1]])
      self.assertAllClose([[0.2, 0.6]], self.evaluate(prob))

      prob = dist.prob([[[0, 1]]])
      self.assertAllClose([[[0.2, 0.6]]], self.evaluate(prob))

      prob = dist.prob([[1, 0], [0, 1]])
      self.assertAllClose([[0.8, 0.4], [0.2, 0.6]], self.evaluate(prob))

      prob = dist.prob([[[1, 1], [1, 0]], [[1, 0], [0, 1]]])
      self.assertAllClose([[[0.8, 0.6], [0.8, 0.4]], [[0.8, 0.4], [0.2, 0.6]]],
                          self.evaluate(prob))

  def testLogPMFShape(self):
    with self.cached_session():
      # shape [1, 2, 2]
      histograms = [[[0.2, 0.8], [0.4, 0.6]]]
      dist = categorical.Categorical(math_ops.log(histograms))

      log_prob = dist.log_prob([0, 1])
      self.assertEqual(2, log_prob.get_shape().ndims)
      self.assertAllEqual([1, 2], log_prob.get_shape())

      log_prob = dist.log_prob([[[1, 1], [1, 0]], [[1, 0], [0, 1]]])
      self.assertEqual(3, log_prob.get_shape().ndims)
      self.assertAllEqual([2, 2, 2], log_prob.get_shape())

  def testLogPMFShapeNoBatch(self):
    histograms = [0.2, 0.8]
    dist = categorical.Categorical(math_ops.log(histograms))

    log_prob = dist.log_prob(0)
    self.assertEqual(0, log_prob.get_shape().ndims)
    self.assertAllEqual([], log_prob.get_shape())

    log_prob = dist.log_prob([[[1, 1], [1, 0]], [[1, 0], [0, 1]]])
    self.assertEqual(3, log_prob.get_shape().ndims)
    self.assertAllEqual([2, 2, 2], log_prob.get_shape())

  @test_util.run_deprecated_v1
  def testMode(self):
    with self.cached_session():
      histograms = [[[0.2, 0.8], [0.6, 0.4]]]
      dist = categorical.Categorical(math_ops.log(histograms) - 50.)
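      # The mode is the argmax over classes: class 1 for the first histogram,
      # class 0 for the second.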
      self.assertAllEqual(dist.mode(), [[1, 0]])

  @test_util.run_deprecated_v1
  def testCategoricalCategoricalKL(self):

    def np_softmax(logits):
      exp_logits = np.exp(logits)
      return exp_logits / exp_logits.sum(axis=-1, keepdims=True)

    with self.cached_session() as sess:
      for categories in [2, 4]:
        for batch_size in [1, 10]:
          a_logits = np.random.randn(batch_size, categories)
          b_logits = np.random.randn(batch_size, categories)

          a = categorical.Categorical(logits=a_logits)
          b = categorical.Categorical(logits=b_logits)

          kl = kullback_leibler.kl_divergence(a, b)
          kl_val = self.evaluate(kl)
          # Make sure KL(a||a) is 0
          kl_same = sess.run(kullback_leibler.kl_divergence(a, a))

          prob_a = np_softmax(a_logits)
          prob_b = np_softmax(b_logits)
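          # Reference value computed directly as
          # KL(a || b) = sum_k p_a(k) * (log p_a(k) - log p_b(k)).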
          kl_expected = np.sum(prob_a * (np.log(prob_a) - np.log(prob_b)),
                               axis=-1)

          self.assertEqual(kl.get_shape(), (batch_size,))
          self.assertAllClose(kl_val, kl_expected)
          self.assertAllClose(kl_same, np.zeros_like(kl_expected))


if __name__ == "__main__":
  test.main()