1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15from __future__ import absolute_import 16from __future__ import division 17from __future__ import print_function 18 19import importlib 20 21import numpy as np 22 23from tensorflow.python.eager import backprop 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import tensor_shape 26from tensorflow.python.framework import test_util 27from tensorflow.python.ops import array_ops 28from tensorflow.python.ops import math_ops 29from tensorflow.python.ops.distributions import dirichlet as dirichlet_lib 30from tensorflow.python.ops.distributions import kullback_leibler 31from tensorflow.python.platform import test 32from tensorflow.python.platform import tf_logging 33 34 35def try_import(name): # pylint: disable=invalid-name 36 module = None 37 try: 38 module = importlib.import_module(name) 39 except ImportError as e: 40 tf_logging.warning("Could not import %s: %s" % (name, str(e))) 41 return module 42 43 44special = try_import("scipy.special") 45stats = try_import("scipy.stats") 46 47 48@test_util.run_all_in_graph_and_eager_modes 49class DirichletTest(test.TestCase): 50 51 def testSimpleShapes(self): 52 alpha = np.random.rand(3) 53 dist = dirichlet_lib.Dirichlet(alpha) 54 self.assertEqual(3, self.evaluate(dist.event_shape_tensor())) 55 self.assertAllEqual([], self.evaluate(dist.batch_shape_tensor())) 56 self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape) 57 self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape) 58 59 def testComplexShapes(self): 60 alpha = np.random.rand(3, 2, 2) 61 dist = dirichlet_lib.Dirichlet(alpha) 62 self.assertEqual(2, self.evaluate(dist.event_shape_tensor())) 63 self.assertAllEqual([3, 2], self.evaluate(dist.batch_shape_tensor())) 64 self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape) 65 self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape) 66 67 def testConcentrationProperty(self): 68 alpha = [[1., 2, 3]] 69 dist = dirichlet_lib.Dirichlet(alpha) 70 self.assertEqual([1, 3], dist.concentration.get_shape()) 71 self.assertAllClose(alpha, self.evaluate(dist.concentration)) 72 73 def testPdfXProper(self): 74 alpha = [[1., 2, 3]] 75 dist = dirichlet_lib.Dirichlet(alpha, validate_args=True) 76 self.evaluate(dist.prob([.1, .3, .6])) 77 self.evaluate(dist.prob([.2, .3, .5])) 78 # Either condition can trigger. 79 with self.assertRaisesOpError("samples must be positive"): 80 self.evaluate(dist.prob([-1., 1.5, 0.5])) 81 with self.assertRaisesOpError("samples must be positive"): 82 self.evaluate(dist.prob([0., .1, .9])) 83 with self.assertRaisesOpError("sample last-dimension must sum to `1`"): 84 self.evaluate(dist.prob([.1, .2, .8])) 85 86 def testLogPdfOnBoundaryIsFiniteWhenAlphaIsOne(self): 87 # Test concentration = 1. for each dimension. 88 concentration = 3 * np.ones((10, 10)).astype(np.float32) 89 concentration[range(10), range(10)] = 1. 90 x = 1 / 9. * np.ones((10, 10)).astype(np.float32) 91 x[range(10), range(10)] = 0. 92 dist = dirichlet_lib.Dirichlet(concentration) 93 log_prob = self.evaluate(dist.log_prob(x)) 94 self.assertAllEqual( 95 np.ones_like(log_prob, dtype=np.bool), np.isfinite(log_prob)) 96 97 # Test when concentration[k] = 1., and x is zero at various dimensions. 98 dist = dirichlet_lib.Dirichlet(10 * [1.]) 99 log_prob = self.evaluate(dist.log_prob(x)) 100 self.assertAllEqual( 101 np.ones_like(log_prob, dtype=np.bool), np.isfinite(log_prob)) 102 103 def testPdfZeroBatches(self): 104 alpha = [1., 2] 105 x = [.5, .5] 106 dist = dirichlet_lib.Dirichlet(alpha) 107 pdf = dist.prob(x) 108 self.assertAllClose(1., self.evaluate(pdf)) 109 self.assertEqual((), pdf.get_shape()) 110 111 def testPdfZeroBatchesNontrivialX(self): 112 alpha = [1., 2] 113 x = [.3, .7] 114 dist = dirichlet_lib.Dirichlet(alpha) 115 pdf = dist.prob(x) 116 self.assertAllClose(7. / 5, self.evaluate(pdf)) 117 self.assertEqual((), pdf.get_shape()) 118 119 def testPdfUniformZeroBatches(self): 120 # Corresponds to a uniform distribution 121 alpha = [1., 1, 1] 122 x = [[.2, .5, .3], [.3, .4, .3]] 123 dist = dirichlet_lib.Dirichlet(alpha) 124 pdf = dist.prob(x) 125 self.assertAllClose([2., 2.], self.evaluate(pdf)) 126 self.assertEqual((2), pdf.get_shape()) 127 128 def testPdfAlphaStretchedInBroadcastWhenSameRank(self): 129 alpha = [[1., 2]] 130 x = [[.5, .5], [.3, .7]] 131 dist = dirichlet_lib.Dirichlet(alpha) 132 pdf = dist.prob(x) 133 self.assertAllClose([1., 7. / 5], self.evaluate(pdf)) 134 self.assertEqual((2), pdf.get_shape()) 135 136 def testPdfAlphaStretchedInBroadcastWhenLowerRank(self): 137 alpha = [1., 2] 138 x = [[.5, .5], [.2, .8]] 139 pdf = dirichlet_lib.Dirichlet(alpha).prob(x) 140 self.assertAllClose([1., 8. / 5], self.evaluate(pdf)) 141 self.assertEqual((2), pdf.get_shape()) 142 143 def testPdfXStretchedInBroadcastWhenSameRank(self): 144 alpha = [[1., 2], [2., 3]] 145 x = [[.5, .5]] 146 pdf = dirichlet_lib.Dirichlet(alpha).prob(x) 147 self.assertAllClose([1., 3. / 2], self.evaluate(pdf)) 148 self.assertEqual((2), pdf.get_shape()) 149 150 def testPdfXStretchedInBroadcastWhenLowerRank(self): 151 alpha = [[1., 2], [2., 3]] 152 x = [.5, .5] 153 pdf = dirichlet_lib.Dirichlet(alpha).prob(x) 154 self.assertAllClose([1., 3. / 2], self.evaluate(pdf)) 155 self.assertEqual((2), pdf.get_shape()) 156 157 def testMean(self): 158 alpha = [1., 2, 3] 159 dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) 160 self.assertEqual(dirichlet.mean().get_shape(), [3]) 161 if not stats: 162 return 163 expected_mean = stats.dirichlet.mean(alpha) 164 self.assertAllClose(self.evaluate(dirichlet.mean()), expected_mean) 165 166 def testCovarianceFromSampling(self): 167 alpha = np.array([[1., 2, 3], 168 [2.5, 4, 0.01]], dtype=np.float32) 169 dist = dirichlet_lib.Dirichlet(alpha) # batch_shape=[2], event_shape=[3] 170 x = dist.sample(int(250e3), seed=1) 171 sample_mean = math_ops.reduce_mean(x, 0) 172 x_centered = x - sample_mean[None, ...] 173 sample_cov = math_ops.reduce_mean(math_ops.matmul( 174 x_centered[..., None], x_centered[..., None, :]), 0) 175 sample_var = array_ops.matrix_diag_part(sample_cov) 176 sample_stddev = math_ops.sqrt(sample_var) 177 178 [ 179 sample_mean_, 180 sample_cov_, 181 sample_var_, 182 sample_stddev_, 183 analytic_mean, 184 analytic_cov, 185 analytic_var, 186 analytic_stddev, 187 ] = self.evaluate([ 188 sample_mean, 189 sample_cov, 190 sample_var, 191 sample_stddev, 192 dist.mean(), 193 dist.covariance(), 194 dist.variance(), 195 dist.stddev(), 196 ]) 197 198 self.assertAllClose(sample_mean_, analytic_mean, atol=0.04, rtol=0.) 199 self.assertAllClose(sample_cov_, analytic_cov, atol=0.06, rtol=0.) 200 self.assertAllClose(sample_var_, analytic_var, atol=0.03, rtol=0.) 201 self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.) 202 203 def testVariance(self): 204 alpha = [1., 2, 3] 205 denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1) 206 dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) 207 self.assertEqual(dirichlet.covariance().get_shape(), (3, 3)) 208 if not stats: 209 return 210 expected_covariance = np.diag(stats.dirichlet.var(alpha)) 211 expected_covariance += [[0., -2, -3], [-2, 0, -6], [-3, -6, 0] 212 ] / denominator 213 self.assertAllClose( 214 self.evaluate(dirichlet.covariance()), expected_covariance) 215 216 def testMode(self): 217 alpha = np.array([1.1, 2, 3]) 218 expected_mode = (alpha - 1) / (np.sum(alpha) - 3) 219 dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) 220 self.assertEqual(dirichlet.mode().get_shape(), [3]) 221 self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode) 222 223 def testModeInvalid(self): 224 alpha = np.array([1., 2, 3]) 225 dirichlet = dirichlet_lib.Dirichlet( 226 concentration=alpha, allow_nan_stats=False) 227 with self.assertRaisesOpError("Condition x < y.*"): 228 self.evaluate(dirichlet.mode()) 229 230 def testModeEnableAllowNanStats(self): 231 alpha = np.array([1., 2, 3]) 232 dirichlet = dirichlet_lib.Dirichlet( 233 concentration=alpha, allow_nan_stats=True) 234 expected_mode = np.zeros_like(alpha) + np.nan 235 236 self.assertEqual(dirichlet.mode().get_shape(), [3]) 237 self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode) 238 239 def testEntropy(self): 240 alpha = [1., 2, 3] 241 dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) 242 self.assertEqual(dirichlet.entropy().get_shape(), ()) 243 if not stats: 244 return 245 expected_entropy = stats.dirichlet.entropy(alpha) 246 self.assertAllClose(self.evaluate(dirichlet.entropy()), expected_entropy) 247 248 def testSample(self): 249 alpha = [1., 2] 250 dirichlet = dirichlet_lib.Dirichlet(alpha) 251 n = constant_op.constant(100000) 252 samples = dirichlet.sample(n) 253 sample_values = self.evaluate(samples) 254 self.assertEqual(sample_values.shape, (100000, 2)) 255 self.assertTrue(np.all(sample_values > 0.0)) 256 if not stats: 257 return 258 self.assertLess( 259 stats.kstest( 260 # Beta is a univariate distribution. 261 sample_values[:, 0], 262 stats.beta(a=1., b=2.).cdf)[0], 263 0.01) 264 265 def testDirichletFullyReparameterized(self): 266 alpha = constant_op.constant([1.0, 2.0, 3.0]) 267 with backprop.GradientTape() as tape: 268 tape.watch(alpha) 269 dirichlet = dirichlet_lib.Dirichlet(alpha) 270 samples = dirichlet.sample(100) 271 grad_alpha = tape.gradient(samples, alpha) 272 self.assertIsNotNone(grad_alpha) 273 274 def testDirichletDirichletKL(self): 275 conc1 = np.array([[1., 2., 3., 1.5, 2.5, 3.5], 276 [1.5, 2.5, 3.5, 4.5, 5.5, 6.5]]) 277 conc2 = np.array([[0.5, 1., 1.5, 2., 2.5, 3.]]) 278 279 d1 = dirichlet_lib.Dirichlet(conc1) 280 d2 = dirichlet_lib.Dirichlet(conc2) 281 x = d1.sample(int(1e4), seed=0) 282 kl_sample = math_ops.reduce_mean(d1.log_prob(x) - d2.log_prob(x), 0) 283 kl_actual = kullback_leibler.kl_divergence(d1, d2) 284 285 kl_sample_val = self.evaluate(kl_sample) 286 kl_actual_val = self.evaluate(kl_actual) 287 288 self.assertEqual(conc1.shape[:-1], kl_actual.get_shape()) 289 290 if not special: 291 return 292 293 kl_expected = ( 294 special.gammaln(np.sum(conc1, -1)) 295 - special.gammaln(np.sum(conc2, -1)) 296 - np.sum(special.gammaln(conc1) - special.gammaln(conc2), -1) 297 + np.sum((conc1 - conc2) * (special.digamma(conc1) - special.digamma( 298 np.sum(conc1, -1, keepdims=True))), -1)) 299 300 self.assertAllClose(kl_expected, kl_actual_val, atol=0., rtol=1e-6) 301 self.assertAllClose(kl_sample_val, kl_actual_val, atol=0., rtol=1e-1) 302 303 # Make sure KL(d1||d1) is 0 304 kl_same = self.evaluate(kullback_leibler.kl_divergence(d1, d1)) 305 self.assertAllClose(kl_same, np.zeros_like(kl_expected)) 306 307 308if __name__ == "__main__": 309 test.main() 310