# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Monte Carlo Ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib import layers as layers_lib
from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo_lib
from tensorflow.contrib.bayesflow.python.ops.monte_carlo_impl import _get_samples
from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.platform import test


layers = layers_lib
mc = monte_carlo_lib


class ExpectationImportanceSampleTest(test.TestCase):

  def test_normal_integral_mean_and_var_correctly_estimated(self):
    n = int(1e6)
    with self.cached_session():
      mu_p = constant_op.constant([-1.0, 1.0], dtype=dtypes.float64)
      mu_q = constant_op.constant([0.0, 0.0], dtype=dtypes.float64)
      sigma_p = constant_op.constant([0.5, 0.5], dtype=dtypes.float64)
      sigma_q = constant_op.constant([1.0, 1.0], dtype=dtypes.float64)
      p = normal_lib.Normal(loc=mu_p, scale=sigma_p)
      q = normal_lib.Normal(loc=mu_q, scale=sigma_q)

      # Compute E_p[X].
      e_x = mc.expectation_importance_sampler(
          f=lambda x: x, log_p=p.log_prob, sampling_dist_q=q, n=n, seed=42)

      # Compute E_p[X^2].
      e_x2 = mc.expectation_importance_sampler(
          f=math_ops.square, log_p=p.log_prob, sampling_dist_q=q, n=n, seed=42)

      stddev = math_ops.sqrt(e_x2 - math_ops.square(e_x))

      # Relative tolerance (rtol) chosen to be twice as large as the minimum
      # needed to pass.
      # Convergence of mean is +- 0.003 if n = 100M.
      # Convergence of stddev is +- 0.00001 if n = 100M.
      self.assertEqual(p.batch_shape, e_x.get_shape())
      self.assertAllClose(p.mean().eval(), e_x.eval(), rtol=0.01)
      self.assertAllClose(p.stddev().eval(), stddev.eval(), rtol=0.02)

  def test_multivariate_normal_prob_positive_product_of_components(self):
    # Test that importance sampling can correctly estimate the probability
    # that the product of components of a MultivariateNormal is > 0.
    n = 1000
    with self.cached_session():
      p = mvn_diag_lib.MultivariateNormalDiag(
          loc=[0.], scale_diag=[1.0, 1.0])
      q = mvn_diag_lib.MultivariateNormalDiag(
          loc=[0.5], scale_diag=[3., 3.])

      # Compute E_p[X_1 * X_2 > 0], with X_i the ith component of X ~ p(x).
      # Should equal 1/2 because p is a spherical Gaussian centered at (0, 0).
      # The indicator maps sign(x1 * x2) in {-1, +1} to {0, 1}:
      #   1[x1 * x2 > 0] = 0.5 * (sign(x1 * x2) + 1).
      def indicator(x):
        x1_times_x2 = math_ops.reduce_prod(x, axis=[-1])
        return 0.5 * (math_ops.sign(x1_times_x2) + 1.0)

      prob = mc.expectation_importance_sampler(
          f=indicator, log_p=p.log_prob, sampling_dist_q=q, n=n, seed=42)

      # Relative tolerance (rtol) chosen to be twice as large as the minimum
      # needed to pass.
      # Convergence is +- 0.004 if n = 100k.
      self.assertEqual(p.batch_shape, prob.get_shape())
      self.assertAllClose(0.5, prob.eval(), rtol=0.05)
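
  # A minimal NumPy sketch (not from the original suite) of the
  # importance-sampling identity that the ops above estimate:
  #   E_p[f(X)] = E_q[f(X) * p(X) / q(X)],  with weights exp(log_p - log_q).
  # The target and proposal match the first component of
  # `test_normal_integral_mean_and_var_correctly_estimated`; the test name
  # and all locals are illustrative assumptions, not part of the library API.
  def test_numpy_importance_sampling_identity_sketch(self):
    rng = np.random.RandomState(42)
    n = int(1e6)
    mu_p, sigma_p = -1.0, 0.5  # Target p = N(-1, 0.5**2); E_p[X] = -1.
    mu_q, sigma_q = 0.0, 1.0   # Proposal q = N(0, 1).
    x = rng.normal(mu_q, sigma_q, size=n)
    log_p = (-0.5 * ((x - mu_p) / sigma_p)**2
             - np.log(sigma_p * np.sqrt(2 * np.pi)))
    log_q = (-0.5 * ((x - mu_q) / sigma_q)**2
             - np.log(sigma_q * np.sqrt(2 * np.pi)))
    # Plain importance sampling: average f(x) times the density ratio p/q.
    e_x = np.mean(x * np.exp(log_p - log_q))
    self.assertAllClose(mu_p, e_x, rtol=0.05)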


class ExpectationImportanceSampleLogspaceTest(test.TestCase):

  def test_normal_distribution_second_moment_estimated_correctly(self):
    # Test the importance-sampled estimate against an analytical result.
    n = int(1e6)
    with self.cached_session():
      mu_p = constant_op.constant([0.0, 0.0], dtype=dtypes.float64)
      mu_q = constant_op.constant([-1.0, 1.0], dtype=dtypes.float64)
      sigma_p = constant_op.constant([1.0, 2 / 3.], dtype=dtypes.float64)
      sigma_q = constant_op.constant([1.0, 1.0], dtype=dtypes.float64)
      p = normal_lib.Normal(loc=mu_p, scale=sigma_p)
      q = normal_lib.Normal(loc=mu_q, scale=sigma_q)

      # Compute E_p[X^2].
      # Should equal [1, (2/3)^2].
      log_e_x2 = mc.expectation_importance_sampler_logspace(
          log_f=lambda x: math_ops.log(math_ops.square(x)),
          log_p=p.log_prob,
          sampling_dist_q=q,
          n=n,
          seed=42)
      e_x2 = math_ops.exp(log_e_x2)

      # Relative tolerance (rtol) chosen to be twice as large as the minimum
      # needed to pass.
      self.assertEqual(p.batch_shape, e_x2.get_shape())
      self.assertAllClose([1., (2 / 3.)**2], e_x2.eval(), rtol=0.02)


class GetSamplesTest(test.TestCase):
  """Test the private method '_get_samples'."""

  def test_raises_if_both_z_and_n_are_none(self):
    with self.cached_session():
      dist = normal_lib.Normal(loc=0., scale=1.)
      z = None
      n = None
      seed = None
      with self.assertRaisesRegexp(ValueError, 'exactly one'):
        _get_samples(dist, z, n, seed)

  def test_raises_if_both_z_and_n_are_not_none(self):
    with self.cached_session():
      dist = normal_lib.Normal(loc=0., scale=1.)
      z = dist.sample(seed=42)
      n = 1
      seed = None
      with self.assertRaisesRegexp(ValueError, 'exactly one'):
        _get_samples(dist, z, n, seed)

  def test_returns_n_samples_if_n_provided(self):
    with self.cached_session():
      dist = normal_lib.Normal(loc=0., scale=1.)
      z = None
      n = 10
      seed = None
      z = _get_samples(dist, z, n, seed)
      self.assertEqual((10,), z.get_shape())

  def test_returns_z_if_z_provided(self):
    with self.cached_session():
      dist = normal_lib.Normal(loc=0., scale=1.)
      z = dist.sample(10, seed=42)
      n = None
      seed = None
      z = _get_samples(dist, z, n, seed)
      self.assertEqual((10,), z.get_shape())


class ExpectationTest(test.TestCase):

  def test_works_correctly(self):
    with self.cached_session() as sess:
      x = constant_op.constant([-1e6, -100, -10, -1, 1, 10, 100, 1e6])
      p = normal_lib.Normal(loc=x, scale=1.)

      # We use the prefix "efx" to mean "E_p[f(X)]".
      f = lambda u: u
      efx_true = x
      samples = p.sample(int(1e5), seed=1)
      efx_reparam = mc.expectation(f, samples, p.log_prob)
      efx_score = mc.expectation(f, samples, p.log_prob,
                                 use_reparametrization=False)

      [
          efx_true_,
          efx_reparam_,
          efx_score_,
          efx_true_grad_,
          efx_reparam_grad_,
          efx_score_grad_,
      ] = sess.run([
          efx_true,
          efx_reparam,
          efx_score,
          gradients_impl.gradients(efx_true, x)[0],
          gradients_impl.gradients(efx_reparam, x)[0],
          gradients_impl.gradients(efx_score, x)[0],
      ])

      self.assertAllEqual(np.ones_like(efx_true_grad_), efx_true_grad_)

      self.assertAllClose(efx_true_, efx_reparam_, rtol=0.005, atol=0.)
      self.assertAllClose(efx_true_, efx_score_, rtol=0.005, atol=0.)

      self.assertAllEqual(np.ones_like(efx_true_grad_, dtype=np.bool),
                          np.isfinite(efx_reparam_grad_))
      self.assertAllEqual(np.ones_like(efx_true_grad_, dtype=np.bool),
                          np.isfinite(efx_score_grad_))

      self.assertAllClose(efx_true_grad_, efx_reparam_grad_,
                          rtol=0.03, atol=0.)
      # Variance is too high to be meaningful, so we'll only check those which
      # converge.
      self.assertAllClose(efx_true_grad_[2:-2],
                          efx_score_grad_[2:-2],
                          rtol=0.05, atol=0.)
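
  # A minimal NumPy sketch (not from the original suite) of the score-function
  # identity behind `use_reparametrization=False` above:
  #   d/dmu E_p[f(X)] = E_p[f(X) * d/dmu log p(X)],
  # checked for p = N(mu, 1) and f(x) = x, where the exact gradient is 1.
  # The test name and locals are illustrative assumptions.
  def test_numpy_score_function_gradient_sketch(self):
    rng = np.random.RandomState(42)
    n = int(1e6)
    mu = 0.5
    x = rng.normal(mu, 1.0, size=n)
    # For N(mu, 1), the score is d/dmu log p(x) = x - mu.
    grad_estimate = np.mean(x * (x - mu))
    self.assertAllClose(1.0, grad_estimate, rtol=0.05)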

  def test_docstring_example_normal(self):
    with self.cached_session() as sess:
      num_draws = int(1e5)
      mu_p = constant_op.constant(0.)
      mu_q = constant_op.constant(1.)
      p = normal_lib.Normal(loc=mu_p, scale=1.)
      q = normal_lib.Normal(loc=mu_q, scale=2.)
      exact_kl_normal_normal = kullback_leibler.kl_divergence(p, q)
      approx_kl_normal_normal = monte_carlo_lib.expectation(
          f=lambda x: p.log_prob(x) - q.log_prob(x),
          samples=p.sample(num_draws, seed=42),
          log_prob=p.log_prob,
          use_reparametrization=(p.reparameterization_type
                                 == distribution_lib.FULLY_REPARAMETERIZED))
      [exact_kl_normal_normal_, approx_kl_normal_normal_] = sess.run([
          exact_kl_normal_normal, approx_kl_normal_normal])
      self.assertEqual(
          True,
          p.reparameterization_type == distribution_lib.FULLY_REPARAMETERIZED)
      self.assertAllClose(exact_kl_normal_normal_, approx_kl_normal_normal_,
                          rtol=0.01, atol=0.)

      # Compare gradients. (Not present in `docstring`.)
      gradp = lambda fp: gradients_impl.gradients(fp, mu_p)[0]
      gradq = lambda fq: gradients_impl.gradients(fq, mu_q)[0]
      [
          gradp_exact_kl_normal_normal_,
          gradq_exact_kl_normal_normal_,
          gradp_approx_kl_normal_normal_,
          gradq_approx_kl_normal_normal_,
      ] = sess.run([
          gradp(exact_kl_normal_normal),
          gradq(exact_kl_normal_normal),
          gradp(approx_kl_normal_normal),
          gradq(approx_kl_normal_normal),
      ])
      self.assertAllClose(gradp_exact_kl_normal_normal_,
                          gradp_approx_kl_normal_normal_,
                          rtol=0.01, atol=0.)
      self.assertAllClose(gradq_exact_kl_normal_normal_,
                          gradq_approx_kl_normal_normal_,
                          rtol=0.01, atol=0.)


if __name__ == '__main__':
  test.main()
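
# Note: the exact KL checked in `test_docstring_example_normal` has the closed
# form for univariate normals,
#   KL(N(mu_p, s_p) || N(mu_q, s_q))
#       = log(s_q / s_p) + (s_p**2 + (mu_p - mu_q)**2) / (2 * s_q**2) - 0.5,
# which for mu_p = 0, s_p = 1, mu_q = 1, s_q = 2 gives
#   log(2) + 2/8 - 0.5 = log(2) - 0.25 ~= 0.443.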