# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Categorical distribution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import categorical
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal
from tensorflow.python.platform import test


def make_categorical(batch_shape, num_classes, dtype=dtypes.int32):
  """Builds a Categorical with random float32 logits.

  Args:
    batch_shape: Python list/tuple giving the batch shape of the distribution.
    num_classes: Number of classes; a Python int or a scalar int32 tensor
      (both forms are exercised by `testShapes`).
    dtype: dtype of the samples/mode produced by the distribution.

  Returns:
    A `categorical.Categorical` whose logits have shape
    `batch_shape + [num_classes]`.
  """
  # Logits are drawn from [-10, 10) and then shifted by -50; presumably the
  # shift checks that the distribution does not require normalized logits.
  logits = random_ops.random_uniform(
      list(batch_shape) + [num_classes], -10, 10, dtype=dtypes.float32) - 50.
  return categorical.Categorical(logits, dtype=dtype)


class CategoricalTest(test.TestCase, parameterized.TestCase):
  """Unit tests for `categorical.Categorical`."""

  @test_util.run_deprecated_v1
  def testP(self):
    """Probabilities passed via `probs` are returned unchanged."""
    p = [0.2, 0.8]
    dist = categorical.Categorical(probs=p)
    with self.cached_session():
      self.assertAllClose(p, dist.probs.eval())
      self.assertAllEqual([2], dist.logits.get_shape())

  @test_util.run_deprecated_v1
  def testLogits(self):
    """Logits passed directly round-trip and map to the expected probs."""
    p = np.array([0.2, 0.8], dtype=np.float32)
    logits = np.log(p) - 50.
    dist = categorical.Categorical(logits=logits)
    with self.cached_session():
      self.assertAllEqual([2], dist.probs.get_shape())
      self.assertAllEqual([2], dist.logits.get_shape())
      self.assertAllClose(dist.probs.eval(), p)
      self.assertAllClose(dist.logits.eval(), logits)

  @test_util.run_deprecated_v1
  def testShapes(self):
    """Checks static/dynamic batch and event shapes and `event_size`."""
    with self.cached_session():
      # num_classes is a Python int, so every shape is known statically.
      for batch_shape in ([], [1], [2, 3, 4]):
        dist = make_categorical(batch_shape, 10)
        self.assertAllEqual(batch_shape, dist.batch_shape)
        self.assertAllEqual(batch_shape, dist.batch_shape_tensor().eval())
        self.assertAllEqual([], dist.event_shape)
        self.assertAllEqual([], dist.event_shape_tensor().eval())
        self.assertEqual(10, dist.event_size.eval())
        # event_size is available as a constant because the shape is
        # known at graph build time.
        self.assertEqual(10, tensor_util.constant_value(dist.event_size))

      # num_classes is a tensor; the static batch shape is only checked
      # for its rank here.
      for batch_shape in ([], [1], [2, 3, 4]):
        dist = make_categorical(
            batch_shape, constant_op.constant(
                10, dtype=dtypes.int32))
        self.assertAllEqual(len(batch_shape), dist.batch_shape.ndims)
        self.assertAllEqual(batch_shape, dist.batch_shape_tensor().eval())
        self.assertAllEqual([], dist.event_shape)
        self.assertAllEqual([], dist.event_shape_tensor().eval())
        self.assertEqual(10, dist.event_size.eval())

  def testDtype(self):
    """Sample/mode follow `dtype`; probability-valued outputs stay float32."""
    dist = make_categorical([], 5, dtype=dtypes.int32)
    self.assertEqual(dist.dtype, dtypes.int32)
    self.assertEqual(dist.dtype, dist.sample(5).dtype)
    self.assertEqual(dist.dtype, dist.mode().dtype)
    dist = make_categorical([], 5, dtype=dtypes.int64)
    self.assertEqual(dist.dtype, dtypes.int64)
    self.assertEqual(dist.dtype, dist.sample(5).dtype)
    self.assertEqual(dist.dtype, dist.mode().dtype)
    self.assertEqual(dist.probs.dtype, dtypes.float32)
    self.assertEqual(dist.logits.dtype, dtypes.float32)
    self.assertEqual(dist.logits.dtype, dist.entropy().dtype)
    self.assertEqual(
        dist.logits.dtype, dist.prob(np.array(
            0, dtype=np.int64)).dtype)
    self.assertEqual(
        dist.logits.dtype, dist.log_prob(np.array(
            0, dtype=np.int64)).dtype)
    # Floating-point sample dtypes are also accepted.
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      dist = make_categorical([], 5, dtype=dtype)
      self.assertEqual(dist.dtype, dtype)
      self.assertEqual(dist.dtype, dist.sample(5).dtype)

  @test_util.run_deprecated_v1
  def testUnknownShape(self):
    """Sampling works when the logits shape is only known at run time."""
    with self.cached_session():
      logits = array_ops.placeholder(dtype=dtypes.float32)
      dist = categorical.Categorical(logits)
      sample = dist.sample()
      # Will sample class 1.
      sample_value = sample.eval(feed_dict={logits: [-1000.0, 1000.0]})
      self.assertEqual(1, sample_value)

      # Batch entry 0 will sample class 1, batch entry 1 will sample class 0.
      sample_value_batch = sample.eval(
          feed_dict={logits: [[-1000.0, 1000.0], [1000.0, -1000.0]]})
      self.assertAllEqual([1, 0], sample_value_batch)

  @test_util.run_deprecated_v1
  def testPMFWithBatch(self):
    histograms = [[0.2, 0.8], [0.6, 0.4]]
    dist = categorical.Categorical(math_ops.log(histograms) - 50.)
    with self.cached_session():
      self.assertAllClose(dist.prob([0, 1]).eval(), [0.2, 0.4])

  @test_util.run_deprecated_v1
  def testPMFNoBatch(self):
    histograms = [0.2, 0.8]
    dist = categorical.Categorical(math_ops.log(histograms) - 50.)
    with self.cached_session():
      self.assertAllClose(dist.prob(0).eval(), 0.2)

  @test_util.run_deprecated_v1
  def testCDFWithDynamicEventShapeKnownNdims(self):
    """Test CDF when the number of classes is dynamic but the rank is known."""
    batch_size = 2
    # Known rank (batch, classes); the class dimension is left unknown.
    histograms = array_ops.placeholder(dtype=dtypes.float32,
                                       shape=(batch_size, None))
    event = array_ops.placeholder(dtype=dtypes.float32, shape=(batch_size,))
    dist = categorical.Categorical(probs=histograms)
    cdf_op = dist.cdf(event)

    # Feed values into the placeholder with different shapes
    # three classes.
    event_feed_one = [0, 1]
    histograms_feed_one = [[0.5, 0.3, 0.2], [1.0, 0.0, 0.0]]
    expected_cdf_one = [0.0, 1.0]
    feed_dict_one = {
        histograms: histograms_feed_one,
        event: event_feed_one
    }

    # six classes.
    event_feed_two = [2, 5]
    histograms_feed_two = [[0.9, 0.0, 0.0, 0.0, 0.0, 0.1],
                           [0.15, 0.2, 0.05, 0.35, 0.13, 0.12]]
    expected_cdf_two = [0.9, 0.88]
    feed_dict_two = {
        histograms: histograms_feed_two,
        event: event_feed_two
    }

    with self.cached_session() as sess:
      actual_cdf_one = sess.run(cdf_op, feed_dict=feed_dict_one)
      actual_cdf_two = sess.run(cdf_op, feed_dict=feed_dict_two)

    self.assertAllClose(actual_cdf_one, expected_cdf_one)
    self.assertAllClose(actual_cdf_two, expected_cdf_two)

  @parameterized.named_parameters(
      ("test1", [0, 1], [[0.5, 0.3, 0.2], [1.0, 0.0, 0.0]], [0.0, 1.0]),
      ("test2", [2, 5], [[0.9, 0.0, 0.0, 0.0, 0.0, 0.1],
                         [0.15, 0.2, 0.05, 0.35, 0.13, 0.12]], [0.9, 0.88]))
  def testCDFWithDynamicEventShapeUnknownNdims(
      self, events, histograms, expected_cdf):
    """Test that dynamically-sized events with unknown shape work."""
    # shape=None hides even the rank from graph construction.
    event_ph = array_ops.placeholder_with_default(events, shape=None)
    histograms_ph = array_ops.placeholder_with_default(histograms, shape=None)
    dist = categorical.Categorical(probs=histograms_ph)
    cdf_op = dist.cdf(event_ph)

    actual_cdf = self.evaluate(cdf_op)
    self.assertAllClose(actual_cdf, expected_cdf)

  @test_util.run_deprecated_v1
  def testCDFWithBatch(self):
    histograms = [[0.1, 0.2, 0.3, 0.25, 0.15],
                  [0.0, 0.75, 0.2, 0.05, 0.0]]
    event = [0, 3]
    # The expected values are the probability mass strictly below each event.
    expected_cdf = [0.0, 0.95]
    dist = categorical.Categorical(probs=histograms)
    cdf_op = dist.cdf(event)

    with self.cached_session():
      self.assertAllClose(cdf_op.eval(), expected_cdf)

  @test_util.run_deprecated_v1
  def testCDFNoBatch(self):
    histogram = [0.1, 0.2, 0.3, 0.4]
    event = 2
    expected_cdf = 0.3
    dist = categorical.Categorical(probs=histogram)
    cdf_op = dist.cdf(event)

    with self.cached_session():
      self.assertAlmostEqual(cdf_op.eval(), expected_cdf)

  @test_util.run_deprecated_v1
  def testCDFBroadcasting(self):
    # shape: [batch=2, n_bins=3]
    histograms = [[0.2, 0.1, 0.7],
                  [0.3, 0.45, 0.25]]

    # shape: [batch=3, batch=2]
    devent = [
        [0, 0],
        [1, 1],
        [2, 2]
    ]
    dist = categorical.Categorical(probs=histograms)

    # We test that the probabilities are correctly broadcasted over the
    # additional leading batch dimension of size 3.
    expected_cdf_result = np.zeros((3, 2))
    expected_cdf_result[0, 0] = 0
    expected_cdf_result[0, 1] = 0
    expected_cdf_result[1, 0] = 0.2
    expected_cdf_result[1, 1] = 0.3
    expected_cdf_result[2, 0] = 0.3
    expected_cdf_result[2, 1] = 0.75

    with self.cached_session():
      self.assertAllClose(dist.cdf(devent).eval(), expected_cdf_result)

  def testBroadcastWithBatchParamsAndBiggerEvent(self):
    """Categorical broadcasts event batch dims the same way Normal does."""
    ## The parameters have a single batch dimension, and the event has two.

    # param shape is [3 x 4], where 4 is the number of bins (non-batch dim).
    cat_params_py = [
        [0.2, 0.15, 0.35, 0.3],
        [0.1, 0.05, 0.68, 0.17],
        [0.1, 0.05, 0.68, 0.17]
    ]

    # event shape = [5, 3], both are "batch" dimensions.
    disc_event_py = [
        [0, 1, 2],
        [1, 2, 3],
        [0, 0, 0],
        [1, 1, 1],
        [2, 1, 0]
    ]

    # shape is [3]
    normal_params_py = [
        -10.0,
        120.0,
        50.0
    ]

    # shape is [5, 3]
    real_event_py = [
        [-1.0, 0.0, 1.0],
        [100.0, 101, -50],
        [90, 90, 90],
        [-4, -400, 20.0],
        [0.0, 0.0, 0.0]
    ]

    cat_params_tf = array_ops.constant(cat_params_py)
    disc_event_tf = array_ops.constant(disc_event_py)
    cat = categorical.Categorical(probs=cat_params_tf)

    normal_params_tf = array_ops.constant(normal_params_py)
    real_event_tf = array_ops.constant(real_event_py)
    norm = normal.Normal(loc=normal_params_tf, scale=1.0)

    # Check that normal and categorical have the same broadcasting behaviour.
    to_run = {
        "cat_prob": cat.prob(disc_event_tf),
        "cat_log_prob": cat.log_prob(disc_event_tf),
        "cat_cdf": cat.cdf(disc_event_tf),
        "cat_log_cdf": cat.log_cdf(disc_event_tf),
        "norm_prob": norm.prob(real_event_tf),
        "norm_log_prob": norm.log_prob(real_event_tf),
        "norm_cdf": norm.cdf(real_event_tf),
        "norm_log_cdf": norm.log_cdf(real_event_tf),
    }

    with self.cached_session():
      run_result = self.evaluate(to_run)

    self.assertAllEqual(run_result["cat_prob"].shape,
                        run_result["norm_prob"].shape)
    self.assertAllEqual(run_result["cat_log_prob"].shape,
                        run_result["norm_log_prob"].shape)
    self.assertAllEqual(run_result["cat_cdf"].shape,
                        run_result["norm_cdf"].shape)
    self.assertAllEqual(run_result["cat_log_cdf"].shape,
                        run_result["norm_log_cdf"].shape)

  @test_util.run_deprecated_v1
  def testLogPMF(self):
    """log_prob accepts both integer and float-valued event indices."""
    logits = np.log([[0.2, 0.8], [0.6, 0.4]]) - 50.
    dist = categorical.Categorical(logits)
    with self.cached_session():
      self.assertAllClose(dist.log_prob([0, 1]).eval(), np.log([0.2, 0.4]))
      self.assertAllClose(dist.log_prob([0.0, 1.0]).eval(), np.log([0.2, 0.4]))

  @test_util.run_deprecated_v1
  def testEntropyNoBatch(self):
    logits = np.log([0.2, 0.8]) - 50.
    dist = categorical.Categorical(logits)
    with self.cached_session():
      self.assertAllClose(dist.entropy().eval(),
                          -(0.2 * np.log(0.2) + 0.8 * np.log(0.8)))

  @test_util.run_deprecated_v1
  def testEntropyWithBatch(self):
    logits = np.log([[0.2, 0.8], [0.6, 0.4]]) - 50.
    dist = categorical.Categorical(logits)
    with self.cached_session():
      self.assertAllClose(dist.entropy().eval(), [
          -(0.2 * np.log(0.2) + 0.8 * np.log(0.8)),
          -(0.6 * np.log(0.6) + 0.4 * np.log(0.4))
      ])

  @test_util.run_deprecated_v1
  def testEntropyGradient(self):
    """Entropy and its gradient match a softmax-based reference."""
    with self.cached_session() as sess:
      logits = constant_op.constant([[1., 2., 3.], [2., 5., 1.]])

      probabilities = nn_ops.softmax(logits)
      log_probabilities = nn_ops.log_softmax(logits)
      true_entropy = - math_ops.reduce_sum(
          probabilities * log_probabilities, axis=-1)

      categorical_distribution = categorical.Categorical(probs=probabilities)
      categorical_entropy = categorical_distribution.entropy()

      # Differentiate both the reference entropy and Categorical.entropy()
      # w.r.t. the same logits; the two gradients should agree.
      true_entropy_g = gradients_impl.gradients(true_entropy, [logits])
      categorical_entropy_g = gradients_impl.gradients(
          categorical_entropy, [logits])

      res = sess.run({"true_entropy": true_entropy,
                      "categorical_entropy": categorical_entropy,
                      "true_entropy_g": true_entropy_g,
                      "categorical_entropy_g": categorical_entropy_g})
      self.assertAllClose(res["true_entropy"],
                          res["categorical_entropy"])
      self.assertAllClose(res["true_entropy_g"],
                          res["categorical_entropy_g"])

  def testSample(self):
    """Empirical class frequencies of many samples match the histograms."""
    with self.cached_session():
      histograms = [[[0.2, 0.8], [0.4, 0.6]]]
      dist = categorical.Categorical(math_ops.log(histograms) - 50.)
      n = 10000
      samples = dist.sample(n, seed=123)
      samples.set_shape([n, 1, 2])
      self.assertEqual(samples.dtype, dtypes.int32)
      sample_values = self.evaluate(samples)
      self.assertFalse(np.any(sample_values < 0))
      self.assertFalse(np.any(sample_values > 1))
      self.assertAllClose(
          [[0.2, 0.4]], np.mean(
              sample_values == 0, axis=0), atol=1e-2)
      self.assertAllClose(
          [[0.8, 0.6]], np.mean(
              sample_values == 1, axis=0), atol=1e-2)

  def testSampleWithSampleShape(self):
    """Mean prob of drawn samples approximates sum_i p_i**2 per batch entry."""
    with self.cached_session():
      histograms = [[[0.2, 0.8], [0.4, 0.6]]]
      dist = categorical.Categorical(math_ops.log(histograms) - 50.)
      samples = dist.sample((100, 100), seed=123)
      prob = dist.prob(samples)
      prob_val = self.evaluate(prob)
      self.assertAllClose(
          [0.2**2 + 0.8**2], [prob_val[:, :, :, 0].mean()], atol=1e-2)
      self.assertAllClose(
          [0.4**2 + 0.6**2], [prob_val[:, :, :, 1].mean()], atol=1e-2)

  def testNotReparameterized(self):
    """No gradient flows from samples back to the distribution parameters."""
    p = constant_op.constant([0.3, 0.3, 0.4])
    with backprop.GradientTape() as tape:
      tape.watch(p)
      # NOTE(review): p is passed positionally, which elsewhere in this file
      # is the logits argument; the gradient check holds either way.
      dist = categorical.Categorical(p)
      samples = dist.sample(100)
    grad_p = tape.gradient(samples, p)
    self.assertIsNone(grad_p)

  def testLogPMFBroadcasting(self):
    """prob broadcasts events of various ranks against a [1, 2] batch."""
    with self.cached_session():
      # 1 x 2 x 2
      histograms = [[[0.2, 0.8], [0.4, 0.6]]]
      dist = categorical.Categorical(math_ops.log(histograms) - 50.)

      prob = dist.prob(1)
      self.assertAllClose([[0.8, 0.6]], self.evaluate(prob))

      prob = dist.prob([1])
      self.assertAllClose([[0.8, 0.6]], self.evaluate(prob))

      prob = dist.prob([0, 1])
      self.assertAllClose([[0.2, 0.6]], self.evaluate(prob))

      prob = dist.prob([[0, 1]])
      self.assertAllClose([[0.2, 0.6]], self.evaluate(prob))

      prob = dist.prob([[[0, 1]]])
      self.assertAllClose([[[0.2, 0.6]]], self.evaluate(prob))

      prob = dist.prob([[1, 0], [0, 1]])
      self.assertAllClose([[0.8, 0.4], [0.2, 0.6]], self.evaluate(prob))

      prob = dist.prob([[[1, 1], [1, 0]], [[1, 0], [0, 1]]])
      self.assertAllClose([[[0.8, 0.6], [0.8, 0.4]], [[0.8, 0.4], [0.2, 0.6]]],
                          self.evaluate(prob))

  def testLogPMFShape(self):
    """Static log_prob shape is the broadcast of event and batch shapes."""
    with self.cached_session():
      # shape [1, 2, 2]
      histograms = [[[0.2, 0.8], [0.4, 0.6]]]
      dist = categorical.Categorical(math_ops.log(histograms))

      log_prob = dist.log_prob([0, 1])
      self.assertEqual(2, log_prob.get_shape().ndims)
      self.assertAllEqual([1, 2], log_prob.get_shape())

      log_prob = dist.log_prob([[[1, 1], [1, 0]], [[1, 0], [0, 1]]])
      self.assertEqual(3, log_prob.get_shape().ndims)
      self.assertAllEqual([2, 2, 2], log_prob.get_shape())

  def testLogPMFShapeNoBatch(self):
    """Static log_prob shape equals the event shape when there is no batch."""
    histograms = [0.2, 0.8]
    dist = categorical.Categorical(math_ops.log(histograms))

    log_prob = dist.log_prob(0)
    self.assertEqual(0, log_prob.get_shape().ndims)
    self.assertAllEqual([], log_prob.get_shape())

    log_prob = dist.log_prob([[[1, 1], [1, 0]], [[1, 0], [0, 1]]])
    self.assertEqual(3, log_prob.get_shape().ndims)
    self.assertAllEqual([2, 2, 2], log_prob.get_shape())

  @test_util.run_deprecated_v1
  def testMode(self):
    with self.cached_session():
      histograms = [[[0.2, 0.8], [0.6, 0.4]]]
      dist = categorical.Categorical(math_ops.log(histograms) - 50.)
      self.assertAllEqual(dist.mode().eval(), [[1, 0]])

  @test_util.run_deprecated_v1
  def testCategoricalCategoricalKL(self):
    """KL(a||b) matches a NumPy reference and KL(a||a) is zero."""

    def np_softmax(logits):
      exp_logits = np.exp(logits)
      return exp_logits / exp_logits.sum(axis=-1, keepdims=True)

    with self.cached_session():
      for categories in [2, 4]:
        for batch_size in [1, 10]:
          a_logits = np.random.randn(batch_size, categories)
          b_logits = np.random.randn(batch_size, categories)

          a = categorical.Categorical(logits=a_logits)
          b = categorical.Categorical(logits=b_logits)

          kl = kullback_leibler.kl_divergence(a, b)
          kl_val = self.evaluate(kl)
          # Make sure KL(a||a) is 0
          kl_same = self.evaluate(kullback_leibler.kl_divergence(a, a))

          prob_a = np_softmax(a_logits)
          prob_b = np_softmax(b_logits)
          kl_expected = np.sum(prob_a * (np.log(prob_a) - np.log(prob_b)),
                               axis=-1)

          self.assertEqual(kl.get_shape(), (batch_size,))
          self.assertAllClose(kl_val, kl_expected)
          self.assertAllClose(kl_same, np.zeros_like(kl_expected))


if __name__ == "__main__":
  test.main()