• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Monte Carlo Ops."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.contrib import layers as layers_lib
24from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo_lib
25from tensorflow.contrib.bayesflow.python.ops.monte_carlo_impl import _get_samples
26from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import dtypes
29from tensorflow.python.ops import gradients_impl
30from tensorflow.python.ops import math_ops
31from tensorflow.python.ops.distributions import distribution as distribution_lib
32from tensorflow.python.ops.distributions import kullback_leibler
33from tensorflow.python.ops.distributions import normal as normal_lib
34from tensorflow.python.platform import test
35
36
# Short aliases for the imported modules.
mc = monte_carlo_lib
layers = layers_lib
39
40
class ExpectationImportanceSampleTest(test.TestCase):
  """Tests for `mc.expectation_importance_sampler`."""

  def test_normal_integral_mean_and_var_correctly_estimated(self):
    """Mean and stddev of a Normal `p`, estimated using samples from `q`."""
    num_samples = int(1e6)

    def float64_constant(values):
      return constant_op.constant(values, dtype=dtypes.float64)

    with self.cached_session():
      target_loc = float64_constant([-1.0, 1.0])
      proposal_loc = float64_constant([0.0, 0.0])
      target_scale = float64_constant([0.5, 0.5])
      proposal_scale = float64_constant([1.0, 1.0])
      target = normal_lib.Normal(loc=target_loc, scale=target_scale)
      proposal = normal_lib.Normal(loc=proposal_loc, scale=proposal_scale)

      # First moment, E_p[X].
      first_moment = mc.expectation_importance_sampler(
          f=lambda x: x,
          log_p=target.log_prob,
          sampling_dist_q=proposal,
          n=num_samples,
          seed=42)

      # Second moment, E_p[X^2].
      second_moment = mc.expectation_importance_sampler(
          f=math_ops.square,
          log_p=target.log_prob,
          sampling_dist_q=proposal,
          n=num_samples,
          seed=42)

      # Standard deviation from the two moments: sqrt(E[X^2] - E[X]^2).
      estimated_stddev = math_ops.sqrt(
          second_moment - math_ops.square(first_moment))

      # Relative tolerance (rtol) chosen 2 times as large as minimum needed to
      # pass.
      # Convergence of mean is +- 0.003 if n = 100M
      # Convergence of stddev is +- 0.00001 if n = 100M
      self.assertEqual(target.batch_shape, first_moment.get_shape())
      self.assertAllClose(
          target.mean().eval(), first_moment.eval(), rtol=0.01)
      self.assertAllClose(
          target.stddev().eval(), estimated_stddev.eval(), rtol=0.02)

  def test_multivariate_normal_prob_positive_product_of_components(self):
    """Estimates the probability that the product of components is > 0."""
    # Importance sampling should correctly estimate the probability that the
    # product of components in a MultivariateNormal is > 0.
    num_samples = 1000
    with self.cached_session():
      target = mvn_diag_lib.MultivariateNormalDiag(
          loc=[0.], scale_diag=[1.0, 1.0])
      proposal = mvn_diag_lib.MultivariateNormalDiag(
          loc=[0.5], scale_diag=[3., 3.])

      # E_p[X_1 * X_2 > 0], with X_i the ith component of X ~ p(x), should
      # equal 1/2 because p is a spherical Gaussian centered at (0, 0).
      def product_is_positive(x):
        # Rescales sign(x_1 * x_2) from the range [-1, 1] to [0, 1].
        component_product = math_ops.reduce_prod(x, axis=[-1])
        return 0.5 * (math_ops.sign(component_product) + 1.0)

      estimated_prob = mc.expectation_importance_sampler(
          f=product_is_positive,
          log_p=target.log_prob,
          sampling_dist_q=proposal,
          n=num_samples,
          seed=42)

      # Relative tolerance (rtol) chosen 2 times as large as minimum needed to
      # pass.
      # Convergence is +- 0.004 if n = 100k.
      self.assertEqual(target.batch_shape, estimated_prob.get_shape())
      self.assertAllClose(0.5, estimated_prob.eval(), rtol=0.05)
95
96
class ExpectationImportanceSampleLogspaceTest(test.TestCase):
  """Tests for `mc.expectation_importance_sampler_logspace`."""

  def test_normal_distribution_second_moment_estimated_correctly(self):
    """Compares the importance sampled E_p[X^2] with the analytical value."""
    num_samples = int(1e6)

    def float64_constant(values):
      return constant_op.constant(values, dtype=dtypes.float64)

    def log_of_square(x):
      return math_ops.log(math_ops.square(x))

    with self.cached_session():
      target_loc = float64_constant([0.0, 0.0])
      proposal_loc = float64_constant([-1.0, 1.0])
      target_scale = float64_constant([1.0, 2 / 3.])
      proposal_scale = float64_constant([1.0, 1.0])
      target = normal_lib.Normal(loc=target_loc, scale=target_scale)
      proposal = normal_lib.Normal(loc=proposal_loc, scale=proposal_scale)

      # The sampler works in log space, so exponentiate to recover E_p[X^2],
      # which should equal [1, (2/3)^2].
      log_second_moment = mc.expectation_importance_sampler_logspace(
          log_f=log_of_square,
          log_p=target.log_prob,
          sampling_dist_q=proposal,
          n=num_samples,
          seed=42)
      second_moment = math_ops.exp(log_second_moment)

      # Relative tolerance (rtol) chosen 2 times as large as minimum needed to
      # pass.
      expected_second_moment = [1., (2 / 3.)**2]
      self.assertEqual(target.batch_shape, second_moment.get_shape())
      self.assertAllClose(
          expected_second_moment, second_moment.eval(), rtol=0.02)
124
125
class GetSamplesTest(test.TestCase):
  """Tests for the private helper `_get_samples`."""

  def test_raises_if_both_z_and_n_are_none(self):
    """Supplying neither `z` nor `n` is rejected with a ValueError."""
    with self.cached_session():
      dist = normal_lib.Normal(loc=0., scale=1.)
      z, n, seed = None, None, None
      with self.assertRaisesRegexp(ValueError, 'exactly one'):
        _get_samples(dist, z, n, seed)

  def test_raises_if_both_z_and_n_are_not_none(self):
    """Supplying both `z` and `n` is rejected with a ValueError."""
    with self.cached_session():
      dist = normal_lib.Normal(loc=0., scale=1.)
      z = dist.sample(seed=42)
      n, seed = 1, None
      with self.assertRaisesRegexp(ValueError, 'exactly one'):
        _get_samples(dist, z, n, seed)

  def test_returns_n_samples_if_n_provided(self):
    """With only `n` given, the result has shape `(n,)`."""
    with self.cached_session():
      dist = normal_lib.Normal(loc=0., scale=1.)
      z, n, seed = None, 10, None
      samples = _get_samples(dist, z, n, seed)
      self.assertEqual((10,), samples.get_shape())

  def test_returns_z_if_z_provided(self):
    """With only `z` given, the result has the shape of `z`."""
    with self.cached_session():
      dist = normal_lib.Normal(loc=0., scale=1.)
      z = dist.sample(10, seed=42)
      n, seed = None, None
      samples = _get_samples(dist, z, n, seed)
      self.assertEqual((10,), samples.get_shape())
164
165
class ExpectationTest(test.TestCase):
  """Tests for `mc.expectation` values and gradients."""

  def test_works_correctly(self):
    """Checks E_p[X] and its gradient w.r.t. `loc` for both estimators."""
    with self.cached_session() as sess:
      x = constant_op.constant([-1e6, -100, -10, -1, 1, 10, 100, 1e6])
      p = normal_lib.Normal(loc=x, scale=1.)

      # We use the prefix "efx" to mean "E_p[f(X)]".
      f = lambda u: u
      efx_true = x
      samples = p.sample(int(1e5), seed=1)
      efx_reparam = mc.expectation(f, samples, p.log_prob)
      efx_score = mc.expectation(f, samples, p.log_prob,
                                 use_reparametrization=False)

      # Evaluate everything in a single run so that the values and their
      # gradients are computed from the same draw of `samples`.
      [
          efx_true_,
          efx_reparam_,
          efx_score_,
          efx_true_grad_,
          efx_reparam_grad_,
          efx_score_grad_,
      ] = sess.run([
          efx_true,
          efx_reparam,
          efx_score,
          gradients_impl.gradients(efx_true, x)[0],
          gradients_impl.gradients(efx_reparam, x)[0],
          gradients_impl.gradients(efx_score, x)[0],
      ])

      self.assertAllEqual(np.ones_like(efx_true_grad_), efx_true_grad_)

      self.assertAllClose(efx_true_, efx_reparam_, rtol=0.005, atol=0.)
      self.assertAllClose(efx_true_, efx_score_, rtol=0.005, atol=0.)

      # `np.bool_` is used instead of the `np.bool` alias, which newer NumPy
      # releases no longer provide.
      self.assertAllEqual(np.ones_like(efx_true_grad_, dtype=np.bool_),
                          np.isfinite(efx_reparam_grad_))
      self.assertAllEqual(np.ones_like(efx_true_grad_, dtype=np.bool_),
                          np.isfinite(efx_score_grad_))

      self.assertAllClose(efx_true_grad_, efx_reparam_grad_,
                          rtol=0.03, atol=0.)
      # Variance is too high to be meaningful, so we'll only check those which
      # converge.
      self.assertAllClose(efx_true_grad_[2:-2],
                          efx_score_grad_[2:-2],
                          rtol=0.05, atol=0.)

  def test_docstring_example_normal(self):
    """Approximates KL[p || q] between two Normals and compares to exact."""
    with self.cached_session() as sess:
      num_draws = int(1e5)
      mu_p = constant_op.constant(0.)
      mu_q = constant_op.constant(1.)
      p = normal_lib.Normal(loc=mu_p, scale=1.)
      q = normal_lib.Normal(loc=mu_q, scale=2.)
      exact_kl_normal_normal = kullback_leibler.kl_divergence(p, q)
      # Monte Carlo estimate of E_p[log p(X) - log q(X)].
      approx_kl_normal_normal = monte_carlo_lib.expectation(
          f=lambda x: p.log_prob(x) - q.log_prob(x),
          samples=p.sample(num_draws, seed=42),
          log_prob=p.log_prob,
          use_reparametrization=(p.reparameterization_type
                                 == distribution_lib.FULLY_REPARAMETERIZED))
      [exact_kl_normal_normal_, approx_kl_normal_normal_] = sess.run([
          exact_kl_normal_normal, approx_kl_normal_normal])
      self.assertEqual(
          True,
          p.reparameterization_type == distribution_lib.FULLY_REPARAMETERIZED)
      self.assertAllClose(exact_kl_normal_normal_, approx_kl_normal_normal_,
                          rtol=0.01, atol=0.)

      # Compare gradients. (Not present in `docstring`.)
      gradp = lambda fp: gradients_impl.gradients(fp, mu_p)[0]
      gradq = lambda fq: gradients_impl.gradients(fq, mu_q)[0]
      [
          gradp_exact_kl_normal_normal_,
          gradq_exact_kl_normal_normal_,
          gradp_approx_kl_normal_normal_,
          gradq_approx_kl_normal_normal_,
      ] = sess.run([
          gradp(exact_kl_normal_normal),
          gradq(exact_kl_normal_normal),
          gradp(approx_kl_normal_normal),
          gradq(approx_kl_normal_normal),
      ])
      self.assertAllClose(gradp_exact_kl_normal_normal_,
                          gradp_approx_kl_normal_normal_,
                          rtol=0.01, atol=0.)
      self.assertAllClose(gradq_exact_kl_normal_normal_,
                          gradq_approx_kl_normal_normal_,
                          rtol=0.01, atol=0.)
257
258
# Invoke the TensorFlow test entry point when this file is run as a script.
if __name__ == '__main__':
  test.main()
261