# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the Dirichlet distribution (`dirichlet_lib.Dirichlet`)."""

import importlib

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import dirichlet as dirichlet_lib
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def try_import(name):  # pylint: disable=invalid-name
  """Imports module `name`, returning None (and logging) if it is unavailable.

  Args:
    name: Fully qualified module name, e.g. "scipy.stats".

  Returns:
    The imported module, or None when the import raised `ImportError`.
  """
  module = None
  try:
    module = importlib.import_module(name)
  except ImportError as e:
    # Pass arguments separately so the logger formats the message lazily.
    tf_logging.warning("Could not import %s: %s", name, e)
  return module


# Optional scipy reference implementations; each is None if scipy is missing,
# in which case the tests below skip their scipy-based comparisons.
special = try_import("scipy.special")
stats = try_import("scipy.stats")


@test_util.run_all_in_graph_and_eager_modes
class DirichletTest(test.TestCase):
  """Checks shapes, densities, moments, sampling and KL of `Dirichlet`."""

  def testSimpleShapes(self):
    alpha = np.random.rand(3)
    dist = dirichlet_lib.Dirichlet(alpha)
    self.assertEqual(3, self.evaluate(dist.event_shape_tensor()))
    self.assertAllEqual([], self.evaluate(dist.batch_shape_tensor()))
    self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape)
    self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)

  def testComplexShapes(self):
    # The last axis of the concentration is the event axis; the rest is batch.
    alpha = np.random.rand(3, 2, 2)
    dist = dirichlet_lib.Dirichlet(alpha)
    self.assertEqual(2, self.evaluate(dist.event_shape_tensor()))
    self.assertAllEqual([3, 2], self.evaluate(dist.batch_shape_tensor()))
    self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape)
    self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape)

  def testConcentrationProperty(self):
    alpha = [[1., 2, 3]]
    dist = dirichlet_lib.Dirichlet(alpha)
    self.assertEqual([1, 3], dist.concentration.get_shape())
    self.assertAllClose(alpha, self.evaluate(dist.concentration))

  def testPdfXProper(self):
    alpha = [[1., 2, 3]]
    dist = dirichlet_lib.Dirichlet(alpha, validate_args=True)
    # Points strictly inside the simplex evaluate without error.
    self.evaluate(dist.prob([.1, .3, .6]))
    self.evaluate(dist.prob([.2, .3, .5]))
    # Negative and exactly-zero components both fail the positivity check.
    with self.assertRaisesOpError("samples must be positive"):
      self.evaluate(dist.prob([-1., 1.5, 0.5]))
    with self.assertRaisesOpError("samples must be positive"):
      self.evaluate(dist.prob([0., .1, .9]))
    # Positive components that do not sum to one fail the simplex check.
    with self.assertRaisesOpError("sample last-dimension must sum to `1`"):
      self.evaluate(dist.prob([.1, .2, .8]))

  def testLogPdfOnBoundaryIsFiniteWhenAlphaIsOne(self):
    # Test concentration = 1. for each dimension.
    concentration = 3 * np.ones((10, 10)).astype(np.float32)
    concentration[range(10), range(10)] = 1.
    # Row i of x is zero exactly where row i of concentration is 1.
    x = 1 / 9. * np.ones((10, 10)).astype(np.float32)
    x[range(10), range(10)] = 0.
    dist = dirichlet_lib.Dirichlet(concentration)
    log_prob = self.evaluate(dist.log_prob(x))
    self.assertAllEqual(
        np.ones_like(log_prob, dtype=np.bool_), np.isfinite(log_prob))

    # Test when concentration[k] = 1., and x is zero at various dimensions.
    dist = dirichlet_lib.Dirichlet(10 * [1.])
    log_prob = self.evaluate(dist.log_prob(x))
    self.assertAllEqual(
        np.ones_like(log_prob, dtype=np.bool_), np.isfinite(log_prob))

  def testPdfZeroBatches(self):
    alpha = [1., 2]
    x = [.5, .5]
    dist = dirichlet_lib.Dirichlet(alpha)
    pdf = dist.prob(x)
    # Dir(1, 2) density is 2 * x_2, so 2 * 0.5 = 1.
    self.assertAllClose(1., self.evaluate(pdf))
    self.assertEqual((), pdf.get_shape())

  def testPdfZeroBatchesNontrivialX(self):
    alpha = [1., 2]
    x = [.3, .7]
    dist = dirichlet_lib.Dirichlet(alpha)
    pdf = dist.prob(x)
    # Dir(1, 2) density is 2 * x_2, so 2 * 0.7 = 7 / 5.
    self.assertAllClose(7. / 5, self.evaluate(pdf))
    self.assertEqual((), pdf.get_shape())

  def testPdfUniformZeroBatches(self):
    # Corresponds to a uniform distribution on the 2-simplex, whose density is
    # the constant Gamma(3) = 2.
    alpha = [1., 1, 1]
    x = [[.2, .5, .3], [.3, .4, .3]]
    dist = dirichlet_lib.Dirichlet(alpha)
    pdf = dist.prob(x)
    self.assertAllClose([2., 2.], self.evaluate(pdf))
    self.assertEqual((2,), pdf.get_shape())

  def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
    alpha = [[1., 2]]
    x = [[.5, .5], [.3, .7]]
    dist = dirichlet_lib.Dirichlet(alpha)
    pdf = dist.prob(x)
    self.assertAllClose([1., 7. / 5], self.evaluate(pdf))
    self.assertEqual((2,), pdf.get_shape())

  def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
    alpha = [1., 2]
    x = [[.5, .5], [.2, .8]]
    pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
    self.assertAllClose([1., 8. / 5], self.evaluate(pdf))
    self.assertEqual((2,), pdf.get_shape())

  def testPdfXStretchedInBroadcastWhenSameRank(self):
    alpha = [[1., 2], [2., 3]]
    x = [[.5, .5]]
    pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
    self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
    self.assertEqual((2,), pdf.get_shape())

  def testPdfXStretchedInBroadcastWhenLowerRank(self):
    alpha = [[1., 2], [2., 3]]
    x = [.5, .5]
    pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
    self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
    self.assertEqual((2,), pdf.get_shape())

  def testMean(self):
    alpha = [1., 2, 3]
    dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
    self.assertEqual(dirichlet.mean().get_shape(), [3])
    if not stats:
      return
    expected_mean = stats.dirichlet.mean(alpha)
    self.assertAllClose(self.evaluate(dirichlet.mean()), expected_mean)

  def testCovarianceFromSampling(self):
    alpha = np.array([[1., 2, 3],
                      [2.5, 4, 0.01]], dtype=np.float32)
    dist = dirichlet_lib.Dirichlet(alpha)  # batch_shape=[2], event_shape=[3]
    x = dist.sample(int(250e3), seed=1)
    sample_mean = math_ops.reduce_mean(x, 0)
    x_centered = x - sample_mean[None, ...]
    # Average of per-sample outer products gives the empirical covariance.
    sample_cov = math_ops.reduce_mean(math_ops.matmul(
        x_centered[..., None], x_centered[..., None, :]), 0)
    sample_var = array_ops.matrix_diag_part(sample_cov)
    sample_stddev = math_ops.sqrt(sample_var)

    [
        sample_mean_,
        sample_cov_,
        sample_var_,
        sample_stddev_,
        analytic_mean,
        analytic_cov,
        analytic_var,
        analytic_stddev,
    ] = self.evaluate([
        sample_mean,
        sample_cov,
        sample_var,
        sample_stddev,
        dist.mean(),
        dist.covariance(),
        dist.variance(),
        dist.stddev(),
    ])

    self.assertAllClose(sample_mean_, analytic_mean, atol=0.04, rtol=0.)
    self.assertAllClose(sample_cov_, analytic_cov, atol=0.06, rtol=0.)
    self.assertAllClose(sample_var_, analytic_var, atol=0.04, rtol=0.)
    self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.)

  @test_util.run_without_tensor_float_32(
      "Calls Dirichlet.covariance, which calls matmul")
  def testVariance(self):
    alpha = [1., 2, 3]
    denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1)
    dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
    self.assertEqual(dirichlet.covariance().get_shape(), (3, 3))
    if not stats:
      return
    expected_covariance = np.diag(stats.dirichlet.var(alpha))
    # Off-diagonal entries are -alpha_i * alpha_j / (a0**2 * (a0 + 1)),
    # where a0 = sum(alpha); e.g. 2 = 1 * 2, 3 = 1 * 3, 6 = 2 * 3.
    expected_covariance += np.array(
        [[0., -2, -3], [-2, 0, -6], [-3, -6, 0]]) / denominator
    self.assertAllClose(
        self.evaluate(dirichlet.covariance()), expected_covariance)

  def testMode(self):
    alpha = np.array([1.1, 2, 3])
    # With every alpha_i > 1 the mode is (alpha - 1) / (sum(alpha) - k), k = 3.
    expected_mode = (alpha - 1) / (np.sum(alpha) - 3)
    dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
    self.assertEqual(dirichlet.mode().get_shape(), [3])
    self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)

  def testModeInvalid(self):
    # alpha contains a 1, so the mode is not defined; with
    # allow_nan_stats=False evaluating it must raise.
    alpha = np.array([1., 2, 3])
    dirichlet = dirichlet_lib.Dirichlet(
        concentration=alpha, allow_nan_stats=False)
    with self.assertRaisesOpError("Condition x < y.*"):
      self.evaluate(dirichlet.mode())

  def testModeEnableAllowNanStats(self):
    # Same undefined mode as above, but allow_nan_stats=True yields NaNs.
    alpha = np.array([1., 2, 3])
    dirichlet = dirichlet_lib.Dirichlet(
        concentration=alpha, allow_nan_stats=True)
    expected_mode = np.zeros_like(alpha) + np.nan

    self.assertEqual(dirichlet.mode().get_shape(), [3])
    self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)

  def testEntropy(self):
    alpha = [1., 2, 3]
    dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
    self.assertEqual(dirichlet.entropy().get_shape(), ())
    if not stats:
      return
    expected_entropy = stats.dirichlet.entropy(alpha)
    self.assertAllClose(self.evaluate(dirichlet.entropy()), expected_entropy)

  def testSample(self):
    alpha = [1., 2]
    dirichlet = dirichlet_lib.Dirichlet(alpha)
    n = constant_op.constant(100000)
    samples = dirichlet.sample(n)
    sample_values = self.evaluate(samples)
    self.assertEqual(sample_values.shape, (100000, 2))
    self.assertTrue(np.all(sample_values > 0.0))
    if not stats:
      return
    # The first component of a Dirichlet([1, 2]) draw is marginally
    # Beta(1, 2); a small Kolmogorov-Smirnov statistic means they agree.
    self.assertLess(
        stats.kstest(
            sample_values[:, 0],
            stats.beta(a=1., b=2.).cdf)[0],
        0.01)

  def testDirichletFullyReparameterized(self):
    alpha = constant_op.constant([1.0, 2.0, 3.0])
    with backprop.GradientTape() as tape:
      tape.watch(alpha)
      dirichlet = dirichlet_lib.Dirichlet(alpha)
      samples = dirichlet.sample(100)
    # Gradients must flow from the samples back to the concentration.
    grad_alpha = tape.gradient(samples, alpha)
    self.assertIsNotNone(grad_alpha)

  def testDirichletDirichletKL(self):
    conc1 = np.array([[1., 2., 3., 1.5, 2.5, 3.5],
                      [1.5, 2.5, 3.5, 4.5, 5.5, 6.5]])
    conc2 = np.array([[0.5, 1., 1.5, 2., 2.5, 3.]])

    d1 = dirichlet_lib.Dirichlet(conc1)
    d2 = dirichlet_lib.Dirichlet(conc2)
    # Monte Carlo estimate E_{x ~ d1}[log d1(x) - log d2(x)].
    x = d1.sample(int(1e4), seed=0)
    kl_sample = math_ops.reduce_mean(d1.log_prob(x) - d2.log_prob(x), 0)
    kl_actual = kullback_leibler.kl_divergence(d1, d2)

    kl_sample_val = self.evaluate(kl_sample)
    kl_actual_val = self.evaluate(kl_actual)

    self.assertEqual(conc1.shape[:-1], kl_actual.get_shape())

    if not special:
      return

    # Closed form:
    #   lgamma(a0) - lgamma(b0) - sum_i [lgamma(a_i) - lgamma(b_i)]
    #     + sum_i (a_i - b_i) * (digamma(a_i) - digamma(a0)),
    # with a = conc1, b = conc2, a0 = sum(a), b0 = sum(b).
    kl_expected = (
        special.gammaln(np.sum(conc1, -1))
        - special.gammaln(np.sum(conc2, -1))
        - np.sum(special.gammaln(conc1) - special.gammaln(conc2), -1)
        + np.sum((conc1 - conc2) * (special.digamma(conc1) - special.digamma(
            np.sum(conc1, -1, keepdims=True))), -1))

    self.assertAllClose(kl_expected, kl_actual_val, atol=0., rtol=1e-6)
    self.assertAllClose(kl_sample_val, kl_actual_val, atol=0., rtol=1e-1)

    # Make sure KL(d1||d1) is 0
    kl_same = self.evaluate(kullback_leibler.kl_divergence(d1, d1))
    self.assertAllClose(kl_same, np.zeros_like(kl_expected))


if __name__ == "__main__":
  test.main()