# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Categorical distribution."""

from absl.testing import parameterized
import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import categorical
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal
from tensorflow.python.platform import test


def make_categorical(batch_shape, num_classes, dtype=dtypes.int32):
  """Builds a `Categorical` with random float32 logits.

  The logits are drawn uniformly with bounds -10 and 10 and then shifted
  down by 50.

  Args:
    batch_shape: Iterable of batch dimensions; it is converted to a list and
      `num_classes` is appended to form the logits shape.
    num_classes: Size of the trailing (class) dimension. Callers in this file
      pass either a Python int or a scalar int32 tensor.
    dtype: Forwarded to `Categorical` as its `dtype` argument.

  Returns:
    A `categorical.Categorical` instance built from the random logits.
  """
  logits = random_ops.random_uniform(
      list(batch_shape) + [num_classes], -10, 10, dtype=dtypes.float32) - 50.
  return categorical.Categorical(logits, dtype=dtype)


class CategoricalTest(test.TestCase, parameterized.TestCase):
  """Unit tests for `categorical.Categorical`.

  Several tests subtract 50 from the logits they pass in and still expect the
  unshifted probabilities back, i.e. they rely on a constant logit offset not
  changing the resulting distribution.
  """

  @test_util.run_deprecated_v1
  def testP(self):
    """`probs` passed at construction is returned by the `probs` property."""
    p = [0.2, 0.8]
    dist = categorical.Categorical(probs=p)
    with self.cached_session():
      self.assertAllClose(p, dist.probs)
      self.assertAllEqual([2], dist.logits.get_shape())

  @test_util.run_deprecated_v1
  def testLogits(self):
    """Shifted logits are kept as-is while `probs` match the original `p`."""
    p = np.array([0.2, 0.8], dtype=np.float32)
    logits = np.log(p) - 50.
    dist = categorical.Categorical(logits=logits)
    with self.cached_session():
      self.assertAllEqual([2], dist.probs.get_shape())
      self.assertAllEqual([2], dist.logits.get_shape())
      self.assertAllClose(dist.probs, p)
      self.assertAllClose(dist.logits, logits)

  @test_util.run_deprecated_v1
  def testShapes(self):
    """Checks batch/event shapes and `event_size` for several batch shapes."""
    with self.cached_session():
      # Number of classes given as a Python int.
      for batch_shape in ([], [1], [2, 3, 4]):
        dist = make_categorical(batch_shape, 10)
        self.assertAllEqual(batch_shape, dist.batch_shape)
        self.assertAllEqual(batch_shape, dist.batch_shape_tensor())
        self.assertAllEqual([], dist.event_shape)
        self.assertAllEqual([], dist.event_shape_tensor())
        self.assertEqual(10, dist.event_size.eval())
        # event_size is available as a constant because the shape is
        # known at graph build time.
        self.assertEqual(10, tensor_util.constant_value(dist.event_size))

      # Number of classes given as a tensor. Only the rank of the static batch
      # shape is asserted here; the full batch shape is checked through
      # `batch_shape_tensor()`.
      for batch_shape in ([], [1], [2, 3, 4]):
        dist = make_categorical(
            batch_shape, constant_op.constant(
                10, dtype=dtypes.int32))
        self.assertAllEqual(len(batch_shape), dist.batch_shape.ndims)
        self.assertAllEqual(batch_shape, dist.batch_shape_tensor())
        self.assertAllEqual([], dist.event_shape)
        self.assertAllEqual([], dist.event_shape_tensor())
        self.assertEqual(10, dist.event_size.eval())

  def testDtype(self):
    """The requested `dtype` is used for samples and the mode."""
    dist = make_categorical([], 5, dtype=dtypes.int32)
    self.assertEqual(dist.dtype, dtypes.int32)
    self.assertEqual(dist.dtype, dist.sample(5).dtype)
    self.assertEqual(dist.dtype, dist.mode().dtype)
    dist = make_categorical([], 5, dtype=dtypes.int64)
    self.assertEqual(dist.dtype, dtypes.int64)
    self.assertEqual(dist.dtype, dist.sample(5).dtype)
    self.assertEqual(dist.dtype, dist.mode().dtype)
    # Parameters and densities keep the float32 dtype of the logits.
    self.assertEqual(dist.probs.dtype, dtypes.float32)
    self.assertEqual(dist.logits.dtype, dtypes.float32)
    self.assertEqual(dist.logits.dtype, dist.entropy().dtype)
    self.assertEqual(
        dist.logits.dtype, dist.prob(np.array(
            0, dtype=np.int64)).dtype)
    self.assertEqual(
        dist.logits.dtype, dist.log_prob(np.array(
            0, dtype=np.int64)).dtype)
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      dist = make_categorical([], 5, dtype=dtype)
      self.assertEqual(dist.dtype, dtype)
      self.assertEqual(dist.dtype, dist.sample(5).dtype)

  @test_util.run_deprecated_v1
  def testUnknownShape(self):
    """Sampling works when the logits placeholder has no static shape."""
    with self.cached_session():
      logits = array_ops.placeholder(dtype=dtypes.float32)
      dist = categorical.Categorical(logits)
      sample = dist.sample()
      # Will sample class 1.
      sample_value = sample.eval(feed_dict={logits: [-1000.0, 1000.0]})
      self.assertEqual(1, sample_value)

      # Batch entry 0 will sample class 1, batch entry 1 will sample class 0.
      sample_value_batch = sample.eval(
          feed_dict={logits: [[-1000.0, 1000.0], [1000.0, -1000.0]]})
      self.assertAllEqual([1, 0], sample_value_batch)

  @test_util.run_deprecated_v1
  def testPMFWithBatch(self):
    """`prob` picks one class probability per batch member."""
    histograms = [[0.2, 0.8], [0.6, 0.4]]
    dist = categorical.Categorical(math_ops.log(histograms) - 50.)
    with self.cached_session():
      self.assertAllClose(dist.prob([0, 1]), [0.2, 0.4])

  @test_util.run_deprecated_v1
  def testPMFNoBatch(self):
    """`prob` of a scalar event on an unbatched distribution."""
    histograms = [0.2, 0.8]
    dist = categorical.Categorical(math_ops.log(histograms) - 50.)
    with self.cached_session():
      self.assertAllClose(dist.prob(0), 0.2)

  @test_util.run_deprecated_v1
  def testCDFWithDynamicEventShapeKnownNdims(self):
    """Test `cdf` with known rank but a dynamically-sized class dimension."""
    batch_size = 2
    histograms = array_ops.placeholder(dtype=dtypes.float32,
                                       shape=(batch_size, None))
    event = array_ops.placeholder(dtype=dtypes.float32, shape=(batch_size,))
    dist = categorical.Categorical(probs=histograms)
    cdf_op = dist.cdf(event)

    # The same op is run twice with differently sized feeds.
    # First feed: three classes.
    event_feed_one = [0, 1]
    histograms_feed_one = [[0.5, 0.3, 0.2], [1.0, 0.0, 0.0]]
    expected_cdf_one = [0.0, 1.0]
    feed_dict_one = {
        histograms: histograms_feed_one,
        event: event_feed_one
    }

    # Second feed: six classes.
    event_feed_two = [2, 5]
    histograms_feed_two = [[0.9, 0.0, 0.0, 0.0, 0.0, 0.1],
                           [0.15, 0.2, 0.05, 0.35, 0.13, 0.12]]
    expected_cdf_two = [0.9, 0.88]
    feed_dict_two = {
        histograms: histograms_feed_two,
        event: event_feed_two
    }

    with self.cached_session() as sess:
      actual_cdf_one = sess.run(cdf_op, feed_dict=feed_dict_one)
      actual_cdf_two = sess.run(cdf_op, feed_dict=feed_dict_two)

    self.assertAllClose(actual_cdf_one, expected_cdf_one)
    self.assertAllClose(actual_cdf_two, expected_cdf_two)

  @parameterized.named_parameters(
      ("test1", [0, 1], [[0.5, 0.3, 0.2], [1.0, 0.0, 0.0]], [0.0, 1.0]),
      ("test2", [2, 5], [[0.9, 0.0, 0.0, 0.0, 0.0, 0.1],
                         [0.15, 0.2, 0.05, 0.35, 0.13, 0.12]], [0.9, 0.88]))
  def testCDFWithDynamicEventShapeUnknownNdims(
      self, events, histograms, expected_cdf):
    """Test that dynamically-sized events with unknown shape work."""
    event_ph = array_ops.placeholder_with_default(events, shape=None)
    histograms_ph = array_ops.placeholder_with_default(histograms, shape=None)
    dist = categorical.Categorical(probs=histograms_ph)
    cdf_op = dist.cdf(event_ph)

    actual_cdf = self.evaluate(cdf_op)
    self.assertAllClose(actual_cdf, expected_cdf)

  @test_util.run_deprecated_v1
  def testCDFWithBatch(self):
    """`cdf` with one event per batch member."""
    histograms = [[0.1, 0.2, 0.3, 0.25, 0.15],
                  [0.0, 0.75, 0.2, 0.05, 0.0]]
    event = [0, 3]
    # Expected values are the sums of the probabilities of the classes
    # strictly below each event (e.g. 0.0 + 0.75 + 0.2 for event 3).
    expected_cdf = [0.0, 0.95]
    dist = categorical.Categorical(probs=histograms)
    cdf_op = dist.cdf(event)

    with self.cached_session():
      self.assertAllClose(cdf_op, expected_cdf)

  @test_util.run_deprecated_v1
  def testCDFNoBatch(self):
    """`cdf` of a scalar event on an unbatched distribution."""
    histogram = [0.1, 0.2, 0.3, 0.4]
    event = 2
    expected_cdf = 0.3
    dist = categorical.Categorical(probs=histogram)
    cdf_op = dist.cdf(event)

    with self.cached_session():
      self.assertAlmostEqual(cdf_op.eval(), expected_cdf)

  @test_util.run_deprecated_v1
  def testCDFBroadcasting(self):
    """`cdf` broadcasts the parameters over an extra leading event dim."""
    # shape: [batch=2, n_bins=3]
    histograms = [[0.2, 0.1, 0.7],
                  [0.3, 0.45, 0.25]]

    # shape: [batch=3, batch=2]
    devent = [
        [0, 0],
        [1, 1],
        [2, 2]
    ]
    dist = categorical.Categorical(probs=histograms)

    # We test that the probabilities are correctly broadcasted over the
    # additional leading batch dimension of size 3.
    expected_cdf_result = np.zeros((3, 2))
    expected_cdf_result[0, 0] = 0
    expected_cdf_result[0, 1] = 0
    expected_cdf_result[1, 0] = 0.2
    expected_cdf_result[1, 1] = 0.3
    expected_cdf_result[2, 0] = 0.3
    expected_cdf_result[2, 1] = 0.75

    with self.cached_session():
      self.assertAllClose(dist.cdf(devent), expected_cdf_result)

  def testBroadcastWithBatchParamsAndBiggerEvent(self):
    """Categorical and Normal produce equally shaped broadcast results."""
    ## The parameters have a single batch dimension, and the event has two.

    # param shape is [3 x 4], where 4 is the number of bins (non-batch dim).
    cat_params_py = [
        [0.2, 0.15, 0.35, 0.3],
        [0.1, 0.05, 0.68, 0.17],
        [0.1, 0.05, 0.68, 0.17]
    ]

    # event shape = [5, 3], both are "batch" dimensions.
    disc_event_py = [
        [0, 1, 2],
        [1, 2, 3],
        [0, 0, 0],
        [1, 1, 1],
        [2, 1, 0]
    ]

    # shape is [3]
    normal_params_py = [
        -10.0,
        120.0,
        50.0
    ]

    # shape is [5, 3]
    real_event_py = [
        [-1.0, 0.0, 1.0],
        [100.0, 101, -50],
        [90, 90, 90],
        [-4, -400, 20.0],
        [0.0, 0.0, 0.0]
    ]

    cat_params_tf = array_ops.constant(cat_params_py)
    disc_event_tf = array_ops.constant(disc_event_py)
    cat = categorical.Categorical(probs=cat_params_tf)

    normal_params_tf = array_ops.constant(normal_params_py)
    real_event_tf = array_ops.constant(real_event_py)
    norm = normal.Normal(loc=normal_params_tf, scale=1.0)

    # Check that normal and categorical have the same broadcasting behaviour.
    to_run = {
        "cat_prob": cat.prob(disc_event_tf),
        "cat_log_prob": cat.log_prob(disc_event_tf),
        "cat_cdf": cat.cdf(disc_event_tf),
        "cat_log_cdf": cat.log_cdf(disc_event_tf),
        "norm_prob": norm.prob(real_event_tf),
        "norm_log_prob": norm.log_prob(real_event_tf),
        "norm_cdf": norm.cdf(real_event_tf),
        "norm_log_cdf": norm.log_cdf(real_event_tf),
    }

    # The session object itself is not needed; `self.evaluate` does the work.
    with self.cached_session():
      run_result = self.evaluate(to_run)

    self.assertAllEqual(run_result["cat_prob"].shape,
                        run_result["norm_prob"].shape)
    self.assertAllEqual(run_result["cat_log_prob"].shape,
                        run_result["norm_log_prob"].shape)
    self.assertAllEqual(run_result["cat_cdf"].shape,
                        run_result["norm_cdf"].shape)
    self.assertAllEqual(run_result["cat_log_cdf"].shape,
                        run_result["norm_log_cdf"].shape)

  @test_util.run_deprecated_v1
  def testLogPMF(self):
    """`log_prob` accepts both integer and float-valued events."""
    logits = np.log([[0.2, 0.8], [0.6, 0.4]]) - 50.
    dist = categorical.Categorical(logits)
    with self.cached_session():
      self.assertAllClose(dist.log_prob([0, 1]), np.log([0.2, 0.4]))
      self.assertAllClose(dist.log_prob([0.0, 1.0]), np.log([0.2, 0.4]))

  @test_util.run_deprecated_v1
  def testEntropyNoBatch(self):
    """Entropy of an unbatched distribution matches the closed form."""
    logits = np.log([0.2, 0.8]) - 50.
    dist = categorical.Categorical(logits)
    with self.cached_session():
      self.assertAllClose(dist.entropy(),
                          -(0.2 * np.log(0.2) + 0.8 * np.log(0.8)))

  @test_util.run_deprecated_v1
  def testEntropyWithBatch(self):
    """Entropy is computed independently for each batch member."""
    logits = np.log([[0.2, 0.8], [0.6, 0.4]]) - 50.
    dist = categorical.Categorical(logits)
    with self.cached_session():
      self.assertAllClose(dist.entropy(), [
          -(0.2 * np.log(0.2) + 0.8 * np.log(0.8)),
          -(0.6 * np.log(0.6) + 0.4 * np.log(0.4))
      ])

  @test_util.run_deprecated_v1
  def testEntropyGradient(self):
    """Entropy and its gradient w.r.t. logits match a manual computation."""
    with self.cached_session() as sess:
      logits = constant_op.constant([[1., 2., 3.], [2., 5., 1.]])

      probabilities = nn_ops.softmax(logits)
      log_probabilities = nn_ops.log_softmax(logits)
      true_entropy = - math_ops.reduce_sum(
          probabilities * log_probabilities, axis=-1)

      categorical_distribution = categorical.Categorical(probs=probabilities)
      categorical_entropy = categorical_distribution.entropy()

      # Differentiate both the reference and the distribution's entropy with
      # respect to the same logits so the gradients can be compared.
      true_entropy_g = gradients_impl.gradients(true_entropy, [logits])
      categorical_entropy_g = gradients_impl.gradients(
          categorical_entropy, [logits])

      res = sess.run({"true_entropy": true_entropy,
                      "categorical_entropy": categorical_entropy,
                      "true_entropy_g": true_entropy_g,
                      "categorical_entropy_g": categorical_entropy_g})
      self.assertAllClose(res["true_entropy"],
                          res["categorical_entropy"])
      self.assertAllClose(res["true_entropy_g"],
                          res["categorical_entropy_g"])

  def testSample(self):
    """Empirical class frequencies of samples match the probabilities."""
    with self.cached_session():
      histograms = [[[0.2, 0.8], [0.4, 0.6]]]
      dist = categorical.Categorical(math_ops.log(histograms) - 50.)
      n = 10000
      samples = dist.sample(n, seed=123)
      samples.set_shape([n, 1, 2])
      self.assertEqual(samples.dtype, dtypes.int32)
      sample_values = self.evaluate(samples)
      # Only the two class indices 0 and 1 may appear.
      self.assertFalse(np.any(sample_values < 0))
      self.assertFalse(np.any(sample_values > 1))
      self.assertAllClose(
          [[0.2, 0.4]], np.mean(
              sample_values == 0, axis=0), atol=1e-2)
      self.assertAllClose(
          [[0.8, 0.6]], np.mean(
              sample_values == 1, axis=0), atol=1e-2)

  def testSampleWithSampleShape(self):
    """Samples drawn with a 2-D sample shape can be fed back into `prob`."""
    with self.cached_session():
      histograms = [[[0.2, 0.8], [0.4, 0.6]]]
      dist = categorical.Categorical(math_ops.log(histograms) - 50.)
      samples = dist.sample((100, 100), seed=123)
      prob = dist.prob(samples)
      prob_val = self.evaluate(prob)
      # The mean of prob(sample) estimates sum_k p_k**2 for each batch member.
      self.assertAllClose(
          [0.2**2 + 0.8**2], [prob_val[:, :, :, 0].mean()], atol=1e-2)
      self.assertAllClose(
          [0.4**2 + 0.6**2], [prob_val[:, :, :, 1].mean()], atol=1e-2)

  def testNotReparameterized(self):
    """No gradient flows from samples back to the parameters."""
    p = constant_op.constant([0.3, 0.3, 0.4])
    with backprop.GradientTape() as tape:
      tape.watch(p)
      # `p` is passed positionally, in the same slot where other tests in
      # this file pass logits.
      dist = categorical.Categorical(p)
      samples = dist.sample(100)
    grad_p = tape.gradient(samples, p)
    self.assertIsNone(grad_p)

  def testLogPMFBroadcasting(self):
    """`prob` broadcasts events of various ranks against the parameters."""
    with self.cached_session():
      # 1 x 2 x 2
      histograms = [[[0.2, 0.8], [0.4, 0.6]]]
      dist = categorical.Categorical(math_ops.log(histograms) - 50.)

      prob = dist.prob(1)
      self.assertAllClose([[0.8, 0.6]], self.evaluate(prob))

      prob = dist.prob([1])
      self.assertAllClose([[0.8, 0.6]], self.evaluate(prob))

      prob = dist.prob([0, 1])
      self.assertAllClose([[0.2, 0.6]], self.evaluate(prob))

      prob = dist.prob([[0, 1]])
      self.assertAllClose([[0.2, 0.6]], self.evaluate(prob))

      prob = dist.prob([[[0, 1]]])
      self.assertAllClose([[[0.2, 0.6]]], self.evaluate(prob))

      prob = dist.prob([[1, 0], [0, 1]])
      self.assertAllClose([[0.8, 0.4], [0.2, 0.6]], self.evaluate(prob))

      prob = dist.prob([[[1, 1], [1, 0]], [[1, 0], [0, 1]]])
      self.assertAllClose([[[0.8, 0.6], [0.8, 0.4]], [[0.8, 0.4], [0.2, 0.6]]],
                          self.evaluate(prob))

  def testLogPMFShape(self):
    """Static shape of `log_prob` for a batched distribution."""
    with self.cached_session():
      # shape [1, 2, 2]
      histograms = [[[0.2, 0.8], [0.4, 0.6]]]
      dist = categorical.Categorical(math_ops.log(histograms))

      log_prob = dist.log_prob([0, 1])
      self.assertEqual(2, log_prob.get_shape().ndims)
      self.assertAllEqual([1, 2], log_prob.get_shape())

      log_prob = dist.log_prob([[[1, 1], [1, 0]], [[1, 0], [0, 1]]])
      self.assertEqual(3, log_prob.get_shape().ndims)
      self.assertAllEqual([2, 2, 2], log_prob.get_shape())

  def testLogPMFShapeNoBatch(self):
    """Static shape of `log_prob` for an unbatched distribution."""
    histograms = [0.2, 0.8]
    dist = categorical.Categorical(math_ops.log(histograms))

    log_prob = dist.log_prob(0)
    self.assertEqual(0, log_prob.get_shape().ndims)
    self.assertAllEqual([], log_prob.get_shape())

    log_prob = dist.log_prob([[[1, 1], [1, 0]], [[1, 0], [0, 1]]])
    self.assertEqual(3, log_prob.get_shape().ndims)
    self.assertAllEqual([2, 2, 2], log_prob.get_shape())

  @test_util.run_deprecated_v1
  def testMode(self):
    """`mode` returns the index of the most probable class per batch member."""
    with self.cached_session():
      histograms = [[[0.2, 0.8], [0.6, 0.4]]]
      dist = categorical.Categorical(math_ops.log(histograms) - 50.)
      self.assertAllEqual(dist.mode(), [[1, 0]])

  @test_util.run_deprecated_v1
  def testCategoricalCategoricalKL(self):
    """KL(a || b) matches a NumPy reference and KL(a || a) is zero."""

    def np_softmax(logits):
      exp_logits = np.exp(logits)
      return exp_logits / exp_logits.sum(axis=-1, keepdims=True)

    # A locally seeded generator keeps the random logits reproducible without
    # touching NumPy's global random state.
    rng = np.random.RandomState(0)

    with self.cached_session():
      for categories in [2, 4]:
        for batch_size in [1, 10]:
          a_logits = rng.randn(batch_size, categories)
          b_logits = rng.randn(batch_size, categories)

          a = categorical.Categorical(logits=a_logits)
          b = categorical.Categorical(logits=b_logits)

          kl = kullback_leibler.kl_divergence(a, b)
          kl_val = self.evaluate(kl)
          # Make sure KL(a||a) is 0
          kl_same = self.evaluate(kullback_leibler.kl_divergence(a, a))

          prob_a = np_softmax(a_logits)
          prob_b = np_softmax(b_logits)
          kl_expected = np.sum(prob_a * (np.log(prob_a) - np.log(prob_b)),
                               axis=-1)

          self.assertEqual(kl.get_shape(), (batch_size,))
          self.assertAllClose(kl_val, kl_expected)
          self.assertAllClose(kl_same, np.zeros_like(kl_expected))


if __name__ == "__main__":
  test.main()