# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Hamiltonian Monte Carlo."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import numpy as np
from scipy import stats

from tensorflow.contrib.bayesflow.python.ops import hmc
from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _compute_energy_change
from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _leapfrog_integrator

from tensorflow.contrib.distributions.python.ops import independent as independent_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import gradients_impl as gradients_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import gamma as gamma_lib
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging_ops


def _reduce_variance(x, axis=None, keepdims=False):
  sample_mean = math_ops.reduce_mean(x, axis, keepdims=True)
  return math_ops.reduce_mean(
      math_ops.squared_difference(x, sample_mean), axis, keepdims)
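

# Illustrative sketch (an addition; not used by the tests below):
# `_reduce_variance` computes the biased sample variance
# E[(x - mean(x))**2], so it should agree with `np.var`. A quick NumPy
# check of that identity:
def _numpy_reduce_variance_check():
  x = np.random.rand(3, 4)
  expected = np.var(x, axis=0)
  actual = np.mean(
      np.square(x - np.mean(x, axis=0, keepdims=True)), axis=0)
  np.testing.assert_allclose(expected, actual)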


class HMCTest(test.TestCase):

  def setUp(self):
    self._shape_param = 5.
    self._rate_param = 10.

    random_seed.set_random_seed(10003)
    np.random.seed(10003)

  def assertAllFinite(self, x):
    self.assertAllEqual(np.ones_like(x).astype(bool), np.isfinite(x))

  def _log_gamma_log_prob(self, x, event_dims=()):
    """Computes log-pdf of a log-gamma random variable.

    Args:
      x: Value of the random variable.
      event_dims: Dimensions to sum over; these are not treated as
        independent chains.

    Returns:
      log_prob: The log-pdf up to a normalizing constant.
    """
    return math_ops.reduce_sum(self._shape_param * x -
                               self._rate_param * math_ops.exp(x),
                               event_dims)
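
  # Worked aside (an addition; comment-only): if Y ~ Gamma(shape, rate),
  # then X = log(Y) has log-density
  #   shape * x - rate * exp(x) + shape * log(rate) - lgamma(shape)
  # by the change of variables y = exp(x); the x-dependent part is exactly
  # the expression summed over `event_dims` above.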

  def _integrator_conserves_energy(self, x, independent_chain_ndims, sess,
                                   feed_dict=None):
    step_size = array_ops.placeholder(np.float32, [], name="step_size")
    hmc_lf_steps = array_ops.placeholder(np.int32, [], name="hmc_lf_steps")

    if feed_dict is None:
      feed_dict = {}
    feed_dict[hmc_lf_steps] = 1000

    event_dims = math_ops.range(independent_chain_ndims,
                                array_ops.rank(x))

    m = random_ops.random_normal(array_ops.shape(x))
    log_prob_0 = self._log_gamma_log_prob(x, event_dims)
    grad_0 = gradients_ops.gradients(log_prob_0, x)
    old_energy = -log_prob_0 + 0.5 * math_ops.reduce_sum(m**2., event_dims)

    new_m, _, log_prob_1, _ = _leapfrog_integrator(
        current_momentums=[m],
        target_log_prob_fn=lambda x: self._log_gamma_log_prob(x, event_dims),
        current_state_parts=[x],
        step_sizes=[step_size],
        num_leapfrog_steps=hmc_lf_steps,
        current_target_log_prob=log_prob_0,
        current_grads_target_log_prob=grad_0)
    new_m = new_m[0]

    new_energy = -log_prob_1 + 0.5 * math_ops.reduce_sum(new_m * new_m,
                                                         event_dims)

    x_shape = sess.run(x, feed_dict).shape
    event_size = np.prod(x_shape[independent_chain_ndims:])
    feed_dict[step_size] = 0.1 / event_size
    old_energy_, new_energy_ = sess.run([old_energy, new_energy],
                                        feed_dict)
    logging_ops.vlog(1, "average energy relative change: {}".format(
        (1. - new_energy_ / old_energy_).mean()))
    self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02)

  def _integrator_conserves_energy_wrapper(self, independent_chain_ndims):
    """Tests the long-term energy conservation of the leapfrog integrator.

    The leapfrog integrator is symplectic, so for sufficiently small step
    sizes it should be possible to run it more or less indefinitely without
    the energy of the system blowing up or collapsing.

    Args:
      independent_chain_ndims: Python `int` scalar representing the number of
        dims associated with independent chains.
    """
    with self.test_session(graph=ops.Graph()) as sess:
      x_ph = array_ops.placeholder(np.float32, name="x_ph")
      feed_dict = {x_ph: np.random.rand(50, 10, 2)}
      self._integrator_conserves_energy(x_ph, independent_chain_ndims,
                                        sess, feed_dict)

  def testIntegratorEnergyConservationNullShape(self):
    self._integrator_conserves_energy_wrapper(0)

  def testIntegratorEnergyConservation1(self):
    self._integrator_conserves_energy_wrapper(1)

  def testIntegratorEnergyConservation2(self):
    self._integrator_conserves_energy_wrapper(2)

  def testIntegratorEnergyConservation3(self):
    self._integrator_conserves_energy_wrapper(3)

  def testSampleChainSeedReproducibleWorksCorrectly(self):
    with self.test_session(graph=ops.Graph()) as sess:
      num_results = 10
      independent_chain_ndims = 1

      def log_gamma_log_prob(x):
        event_dims = math_ops.range(independent_chain_ndims,
                                    array_ops.rank(x))
        return self._log_gamma_log_prob(x, event_dims)

      kwargs = dict(
          target_log_prob_fn=log_gamma_log_prob,
          current_state=np.random.rand(4, 3, 2),
          step_size=0.1,
          num_leapfrog_steps=2,
          num_burnin_steps=150,
          seed=52,
      )

      samples0, kernel_results0 = hmc.sample_chain(
          **dict(list(kwargs.items()) + list(dict(
              num_results=2 * num_results,
              num_steps_between_results=0).items())))

      samples1, kernel_results1 = hmc.sample_chain(
          **dict(list(kwargs.items()) + list(dict(
              num_results=num_results,
              num_steps_between_results=1).items())))

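      # With a fixed seed, the thinned chain above (half as many results,
      # one step between results) must reproduce every other draw of the
      # unthinned chain, which is what the stride-2 slicing below asserts.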
      [
          samples0_,
          samples1_,
          target_log_prob0_,
          target_log_prob1_,
      ] = sess.run([
          samples0,
          samples1,
          kernel_results0.current_target_log_prob,
          kernel_results1.current_target_log_prob,
      ])
      self.assertAllClose(samples0_[::2], samples1_,
                          atol=1e-5, rtol=1e-5)
      self.assertAllClose(target_log_prob0_[::2], target_log_prob1_,
                          atol=1e-5, rtol=1e-5)

  def _chain_gets_correct_expectations(self, x, independent_chain_ndims,
                                       sess, feed_dict=None):
    counter = collections.Counter()
    def log_gamma_log_prob(x):
      counter["target_calls"] += 1
      event_dims = math_ops.range(independent_chain_ndims,
                                  array_ops.rank(x))
      return self._log_gamma_log_prob(x, event_dims)

    num_results = array_ops.placeholder(
        np.int32, [], name="num_results")
    step_size = array_ops.placeholder(
        np.float32, [], name="step_size")
    num_leapfrog_steps = array_ops.placeholder(
        np.int32, [], name="num_leapfrog_steps")

    if feed_dict is None:
      feed_dict = {}
    feed_dict.update({num_results: 150,
                      step_size: 0.05,
                      num_leapfrog_steps: 2})

    samples, kernel_results = hmc.sample_chain(
        num_results=num_results,
        target_log_prob_fn=log_gamma_log_prob,
        current_state=x,
        step_size=step_size,
        num_leapfrog_steps=num_leapfrog_steps,
        num_burnin_steps=150,
        seed=42)

    self.assertAllEqual(dict(target_calls=2), counter)

    expected_x = (math_ops.digamma(self._shape_param)
                  - np.log(self._rate_param))

    expected_exp_x = self._shape_param / self._rate_param

    acceptance_probs_, samples_, expected_x_ = sess.run(
        [kernel_results.acceptance_probs, samples, expected_x],
        feed_dict)

    actual_x = samples_.mean()
    actual_exp_x = np.exp(samples_).mean()

    logging_ops.vlog(1, "True      E[x, exp(x)]: {}\t{}".format(
        expected_x_, expected_exp_x))
    logging_ops.vlog(1, "Estimated E[x, exp(x)]: {}\t{}".format(
        actual_x, actual_exp_x))
    self.assertNear(actual_x, expected_x_, 2e-2)
    self.assertNear(actual_exp_x, expected_exp_x, 2e-2)
    self.assertAllEqual(np.ones_like(acceptance_probs_, np.bool),
                        acceptance_probs_ > 0.5)
    self.assertAllEqual(np.ones_like(acceptance_probs_, np.bool),
                        acceptance_probs_ <= 1.)

  def _chain_gets_correct_expectations_wrapper(self, independent_chain_ndims):
    with self.test_session(graph=ops.Graph()) as sess:
      x_ph = array_ops.placeholder(np.float32, name="x_ph")
      feed_dict = {x_ph: np.random.rand(50, 10, 2)}
      self._chain_gets_correct_expectations(x_ph, independent_chain_ndims,
                                            sess, feed_dict)

  def testHMCChainExpectationsNullShape(self):
    self._chain_gets_correct_expectations_wrapper(0)

  def testHMCChainExpectations1(self):
    self._chain_gets_correct_expectations_wrapper(1)

  def testHMCChainExpectations2(self):
    self._chain_gets_correct_expectations_wrapper(2)

  def testKernelResultsUsingTruncatedDistribution(self):
    def log_prob(x):
      return array_ops.where(
          x >= 0.,
          -x - x**2,  # Non-constant gradient.
          array_ops.fill(x.shape, math_ops.cast(-np.inf, x.dtype)))
    # This log_prob has the property that it is likely to attract the HMC
    # flow toward, and below, zero. But for x <= 0, log_prob(x) = -inf,
    # which should result in rejection, as well as a non-finite log_prob.
    # Thus, this distribution gives us an opportunity to test the kernel
    # results' ability to correctly capture rejections due to finite AND
    # non-finite reasons.
    # Why use a non-constant gradient? With a constant gradient the leapfrog
    # integrator is exact, so there would be no rejections due to
    # integration error.

    num_results = 1000
    # A large step size gives rejections due to integration error, in
    # addition to rejections from proposals landing in the region where
    # log_prob = -inf.
    step_size = 0.1
    num_leapfrog_steps = 5
    num_chains = 2

    with self.test_session(graph=ops.Graph()) as sess:

      # Start multiple independent chains.
      initial_state = ops.convert_to_tensor([0.1] * num_chains)

      states, kernel_results = hmc.sample_chain(
          num_results=num_results,
          target_log_prob_fn=log_prob,
          current_state=initial_state,
          step_size=step_size,
          num_leapfrog_steps=num_leapfrog_steps,
          seed=42)

      states_, kernel_results_ = sess.run([states, kernel_results])
      pstates_ = kernel_results_.proposed_state

      neg_inf_mask = np.isneginf(kernel_results_.proposed_target_log_prob)

      # First:  Test that the mathematical properties of the above log prob
      # function in conjunction with HMC show up as expected in kernel_results_.

      # We better have log_prob = -inf some of the time.
      self.assertLess(0, neg_inf_mask.sum())
      # We better have some rejections due to something other than -inf.
      self.assertLess(neg_inf_mask.sum(), (~kernel_results_.is_accepted).sum())
      # We better have been accepted a decent amount, even near the end of the
      # chain, or else this HMC run just got stuck at some point.
      self.assertLess(
          0.1, kernel_results_.is_accepted[int(0.9 * num_results):].mean())
      # We better not have any NaNs in proposed state or log_prob.
      # We may have some NaN in grads, which involve multiplication/addition due
      # to gradient rules.  This is the known "NaN grad issue with tf.where."
      self.assertAllEqual(np.zeros_like(states_),
                          np.isnan(kernel_results_.proposed_target_log_prob))
      self.assertAllEqual(np.zeros_like(states_),
                          np.isnan(states_))
      # We better not have any +inf in states, grads, or log_prob.
      self.assertAllEqual(np.zeros_like(states_),
                          np.isposinf(kernel_results_.proposed_target_log_prob))
      self.assertAllEqual(
          np.zeros_like(states_),
          np.isposinf(kernel_results_.proposed_grads_target_log_prob[0]))
      self.assertAllEqual(np.zeros_like(states_),
                          np.isposinf(states_))

      # Second:  Test that kernel_results is congruent with itself and
      # acceptance/rejection of states.

      # Proposed state is negative iff proposed target log prob is -inf.
      np.testing.assert_array_less(pstates_[neg_inf_mask], 0.)
      np.testing.assert_array_less(0., pstates_[~neg_inf_mask])

      # Acceptance probs are zero whenever proposed state is negative.
      self.assertAllEqual(
          np.zeros_like(pstates_[neg_inf_mask]),
          kernel_results_.acceptance_probs[neg_inf_mask])

      # The move is accepted ==> state = proposed state.
      self.assertAllEqual(
          states_[kernel_results_.is_accepted],
          pstates_[kernel_results_.is_accepted],
      )
      # The move was rejected <==> state[t] == state[t - 1].
      for t in range(1, num_results):
        for i in range(num_chains):
          if kernel_results_.is_accepted[t, i]:
            self.assertNotEqual(states_[t, i], states_[t - 1, i])
          else:
            self.assertEqual(states_[t, i], states_[t - 1, i])

  def _kernel_leaves_target_invariant(self, initial_draws,
                                      independent_chain_ndims,
                                      sess, feed_dict=None):
    def log_gamma_log_prob(x):
      event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
      return self._log_gamma_log_prob(x, event_dims)

    def fake_log_prob(x):
      """Cooled version of the target distribution."""
      return 1.1 * log_gamma_log_prob(x)

    step_size = array_ops.placeholder(np.float32, [], name="step_size")

    if feed_dict is None:
      feed_dict = {}

    feed_dict[step_size] = 0.4

    sample, kernel_results = hmc.kernel(
        target_log_prob_fn=log_gamma_log_prob,
        current_state=initial_draws,
        step_size=step_size,
        num_leapfrog_steps=5,
        seed=43)

    bad_sample, bad_kernel_results = hmc.kernel(
        target_log_prob_fn=fake_log_prob,
        current_state=initial_draws,
        step_size=step_size,
        num_leapfrog_steps=5,
        seed=44)

    [
        acceptance_probs_,
        bad_acceptance_probs_,
        initial_draws_,
        updated_draws_,
        fake_draws_,
    ] = sess.run([
        kernel_results.acceptance_probs,
        bad_kernel_results.acceptance_probs,
        initial_draws,
        sample,
        bad_sample,
    ], feed_dict)

    # Confirm step size is small enough that we usually accept.
    self.assertGreater(acceptance_probs_.mean(), 0.5)
    self.assertGreater(bad_acceptance_probs_.mean(), 0.5)

    # Confirm step size is large enough that we sometimes reject.
    self.assertLess(acceptance_probs_.mean(), 0.99)
    self.assertLess(bad_acceptance_probs_.mean(), 0.99)

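    # `stats.ks_2samp` returns (statistic, p_value). Under the null
    # hypothesis that both samples come from the same distribution, the
    # p-value is approximately uniform, so a tiny p-value flags a
    # detectable shift in the empirical CDF.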
    _, ks_p_value_true = stats.ks_2samp(initial_draws_.flatten(),
                                        updated_draws_.flatten())
    _, ks_p_value_fake = stats.ks_2samp(initial_draws_.flatten(),
                                        fake_draws_.flatten())

    logging_ops.vlog(1, "acceptance rate for true target: {}".format(
        acceptance_probs_.mean()))
    logging_ops.vlog(1, "acceptance rate for fake target: {}".format(
        bad_acceptance_probs_.mean()))
    logging_ops.vlog(1, "K-S p-value for true target: {}".format(
        ks_p_value_true))
    logging_ops.vlog(1, "K-S p-value for fake target: {}".format(
        ks_p_value_fake))
    # Make sure that the MCMC update hasn't changed the empirical CDF much.
    self.assertGreater(ks_p_value_true, 1e-3)
    # Confirm that targeting the wrong distribution does
    # significantly change the empirical CDF.
    self.assertLess(ks_p_value_fake, 1e-6)

  def _kernel_leaves_target_invariant_wrapper(self, independent_chain_ndims):
    """Tests that the kernel leaves the target distribution invariant.

    Draws some independent samples from the target distribution,
    applies an iteration of the MCMC kernel, then runs a
    Kolmogorov-Smirnov test to determine if the distribution of the
    MCMC-updated samples has changed.

    We also confirm that running the kernel with a different log-pdf
    does change the distribution of the samples (and that we can
    detect the change).

    Args:
      independent_chain_ndims: Python `int` scalar representing the number of
        dims associated with independent chains.
    """
    with self.test_session(graph=ops.Graph()) as sess:
      initial_draws = np.log(np.random.gamma(self._shape_param,
                                             size=[50000, 2, 2]))
      initial_draws -= np.log(self._rate_param)
      x_ph = array_ops.placeholder(np.float32, name="x_ph")

      feed_dict = {x_ph: initial_draws}

      self._kernel_leaves_target_invariant(x_ph, independent_chain_ndims,
                                           sess, feed_dict)

  def testKernelLeavesTargetInvariant1(self):
    self._kernel_leaves_target_invariant_wrapper(1)

  def testKernelLeavesTargetInvariant2(self):
    self._kernel_leaves_target_invariant_wrapper(2)

  def testKernelLeavesTargetInvariant3(self):
    self._kernel_leaves_target_invariant_wrapper(3)

  def _ais_gets_correct_log_normalizer(self, init, independent_chain_ndims,
                                       sess, feed_dict=None):
    counter = collections.Counter()

    def proposal_log_prob(x):
      counter["proposal_calls"] += 1
      event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
      return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi),
                                        axis=event_dims)

    def target_log_prob(x):
      counter["target_calls"] += 1
      event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
      return self._log_gamma_log_prob(x, event_dims)

    if feed_dict is None:
      feed_dict = {}

    num_steps = 200

    _, ais_weights, _ = hmc.sample_annealed_importance_chain(
        proposal_log_prob_fn=proposal_log_prob,
        num_steps=num_steps,
        target_log_prob_fn=target_log_prob,
        step_size=0.5,
        current_state=init,
        num_leapfrog_steps=2,
        seed=45)

    # We have three calls because the calculation of `ais_weights` entails
    # another call to the `convex_combined_log_prob_fn`. We could refactor
    # things to avoid this, if needed (e.g., b/72994218).
    self.assertAllEqual(dict(target_calls=3, proposal_calls=3), counter)

    event_shape = array_ops.shape(init)[independent_chain_ndims:]
    event_size = math_ops.reduce_prod(event_shape)

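    # Derivation (a comment-only aside): the unnormalized log-gamma density
    # exp(shape * x - rate * exp(x)) integrates over the real line to
    # Gamma(shape) / rate**shape, so each event dimension contributes
    # lgamma(shape) - shape * log(rate) to the true log normalizer below.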
    log_true_normalizer = (
        -self._shape_param * math_ops.log(self._rate_param)
        + math_ops.lgamma(self._shape_param))
    log_true_normalizer *= math_ops.cast(event_size, log_true_normalizer.dtype)

    log_estimated_normalizer = (math_ops.reduce_logsumexp(ais_weights)
                                - np.log(num_steps))

    ratio_estimate_true = math_ops.exp(ais_weights - log_true_normalizer)
    ais_weights_size = array_ops.size(ais_weights)
    standard_error = math_ops.sqrt(
        _reduce_variance(ratio_estimate_true)
        / math_ops.cast(ais_weights_size, ratio_estimate_true.dtype))

    [
        ratio_estimate_true_,
        log_true_normalizer_,
        log_estimated_normalizer_,
        standard_error_,
        ais_weights_size_,
        event_size_,
    ] = sess.run([
        ratio_estimate_true,
        log_true_normalizer,
        log_estimated_normalizer,
        standard_error,
        ais_weights_size,
        event_size,
    ], feed_dict)

    logging_ops.vlog(1, "        log_true_normalizer: {}\n"
                        "   log_estimated_normalizer: {}\n"
                        "           ais_weights_size: {}\n"
                        "                 event_size: {}\n".format(
                            log_true_normalizer_,
                            log_estimated_normalizer_,
                            ais_weights_size_,
                            event_size_))
    self.assertNear(ratio_estimate_true_.mean(), 1., 4. * standard_error_)

  def _ais_gets_correct_log_normalizer_wrapper(self, independent_chain_ndims):
    """Tests that AIS yields reasonable estimates of normalizers."""
    with self.test_session(graph=ops.Graph()) as sess:
      x_ph = array_ops.placeholder(np.float32, name="x_ph")
      initial_draws = np.random.normal(size=[30, 2, 1])
      self._ais_gets_correct_log_normalizer(
          x_ph,
          independent_chain_ndims,
          sess,
          feed_dict={x_ph: initial_draws})

  def testAIS1(self):
    self._ais_gets_correct_log_normalizer_wrapper(1)

  def testAIS2(self):
    self._ais_gets_correct_log_normalizer_wrapper(2)

  def testAIS3(self):
    self._ais_gets_correct_log_normalizer_wrapper(3)

  def testSampleAIChainSeedReproducibleWorksCorrectly(self):
    with self.test_session(graph=ops.Graph()) as sess:
      independent_chain_ndims = 1
      x = np.random.rand(4, 3, 2)

      def proposal_log_prob(x):
        event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
        return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi),
                                          axis=event_dims)

      def target_log_prob(x):
        event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
        return self._log_gamma_log_prob(x, event_dims)

      ais_kwargs = dict(
          proposal_log_prob_fn=proposal_log_prob,
          num_steps=200,
          target_log_prob_fn=target_log_prob,
          step_size=0.5,
          current_state=x,
          num_leapfrog_steps=2,
          seed=53)

      _, ais_weights0, _ = hmc.sample_annealed_importance_chain(
          **ais_kwargs)

      _, ais_weights1, _ = hmc.sample_annealed_importance_chain(
          **ais_kwargs)

      [ais_weights0_, ais_weights1_] = sess.run([
          ais_weights0, ais_weights1])

      self.assertAllClose(ais_weights0_, ais_weights1_,
                          atol=1e-5, rtol=1e-5)

  def testNanRejection(self):
    """Tests that an update that yields NaN potentials gets rejected.

    We run HMC with a target distribution that returns NaN
    log-likelihoods if any element of x < 0, and unit-scale
    exponential log-likelihoods otherwise. The exponential potential
    pushes x towards 0, ensuring that any reasonably large update will
    push us over the edge into NaN territory.
    """
    def _unbounded_exponential_log_prob(x):
      """An exponential distribution with log-likelihood NaN for x < 0."""
      per_element_potentials = array_ops.where(
          x < 0.,
          array_ops.fill(array_ops.shape(x), x.dtype.as_numpy_dtype(np.nan)),
          -x)
      return math_ops.reduce_sum(per_element_potentials)

    with self.test_session(graph=ops.Graph()) as sess:
      initial_x = math_ops.linspace(0.01, 5, 10)
      updated_x, kernel_results = hmc.kernel(
          target_log_prob_fn=_unbounded_exponential_log_prob,
          current_state=initial_x,
          step_size=2.,
          num_leapfrog_steps=5,
          seed=46)
      initial_x_, updated_x_, acceptance_probs_ = sess.run(
          [initial_x, updated_x, kernel_results.acceptance_probs])

      logging_ops.vlog(1, "initial_x = {}".format(initial_x_))
      logging_ops.vlog(1, "updated_x = {}".format(updated_x_))
      logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_))

      self.assertAllEqual(initial_x_, updated_x_)
      self.assertEqual(acceptance_probs_, 0.)

  def testNanFromGradsDontPropagate(self):
    """Test that update with NaN gradients does not cause NaN in results."""
    def _nan_log_prob_with_nan_gradient(x):
      return np.nan * math_ops.reduce_sum(x)

    with self.test_session(graph=ops.Graph()) as sess:
      initial_x = math_ops.linspace(0.01, 5, 10)
      updated_x, kernel_results = hmc.kernel(
          target_log_prob_fn=_nan_log_prob_with_nan_gradient,
          current_state=initial_x,
          step_size=2.,
          num_leapfrog_steps=5,
          seed=47)
      initial_x_, updated_x_, acceptance_probs_ = sess.run(
          [initial_x, updated_x, kernel_results.acceptance_probs])

      logging_ops.vlog(1, "initial_x = {}".format(initial_x_))
      logging_ops.vlog(1, "updated_x = {}".format(updated_x_))
      logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_))

      self.assertAllEqual(initial_x_, updated_x_)
      self.assertEqual(acceptance_probs_, 0.)

      self.assertAllFinite(
          gradients_ops.gradients(updated_x, initial_x)[0].eval())
      self.assertAllEqual([True], [g is None for g in gradients_ops.gradients(
          kernel_results.proposed_grads_target_log_prob, initial_x)])
      self.assertAllEqual([False], [g is None for g in gradients_ops.gradients(
          kernel_results.proposed_grads_target_log_prob,
          kernel_results.proposed_state)])

      # Gradients of the acceptance probs and new log prob are not finite.
      # self.assertAllFinite(
      #     gradients_ops.gradients(acceptance_probs, initial_x)[0].eval())
      # self.assertAllFinite(
      #     gradients_ops.gradients(new_log_prob, initial_x)[0].eval())

  def _testChainWorksDtype(self, dtype):
    with self.test_session(graph=ops.Graph()) as sess:
      states, kernel_results = hmc.sample_chain(
          num_results=10,
          target_log_prob_fn=lambda x: -math_ops.reduce_sum(x**2., axis=-1),
          current_state=np.zeros(5).astype(dtype),
          step_size=0.01,
          num_leapfrog_steps=10,
          seed=48)
      states_, acceptance_probs_ = sess.run(
          [states, kernel_results.acceptance_probs])
      self.assertEqual(dtype, states_.dtype)
      self.assertEqual(dtype, acceptance_probs_.dtype)

  def testChainWorksIn64Bit(self):
    self._testChainWorksDtype(np.float64)

  def testChainWorksIn16Bit(self):
    self._testChainWorksDtype(np.float16)

  def testChainWorksCorrelatedMultivariate(self):
    dtype = np.float32
    true_mean = dtype([0, 0])
    true_cov = dtype([[1, 0.5],
                      [0.5, 1]])
    num_results = 2000
    counter = collections.Counter()
    with self.test_session(graph=ops.Graph()) as sess:
      def target_log_prob(x, y):
        counter["target_calls"] += 1
        # Corresponds to unnormalized MVN.
        # z = matmul(inv(chol(true_cov)), [x, y] - true_mean)
        z = array_ops.stack([x, y], axis=-1) - true_mean
        z = array_ops.squeeze(
            gen_linalg_ops.matrix_triangular_solve(
                np.linalg.cholesky(true_cov),
                z[..., array_ops.newaxis]),
            axis=-1)
        return -0.5 * math_ops.reduce_sum(z**2., axis=-1)
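
      # Solving chol(true_cov) @ z = (state - mean) above whitens the state,
      # so -0.5 * sum(z**2) is the multivariate normal log-density up to a
      # normalizing constant.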
      states, _ = hmc.sample_chain(
          num_results=num_results,
          target_log_prob_fn=target_log_prob,
          current_state=[dtype(-2), dtype(2)],
          step_size=[0.5, 0.5],
          num_leapfrog_steps=2,
          num_burnin_steps=200,
          num_steps_between_results=1,
          seed=54)
      self.assertAllEqual(dict(target_calls=2), counter)
      states = array_ops.stack(states, axis=-1)
      self.assertEqual(num_results, states.shape[0].value)
      sample_mean = math_ops.reduce_mean(states, axis=0)
      x = states - sample_mean
      sample_cov = math_ops.matmul(x, x, transpose_a=True) / dtype(num_results)
      [sample_mean_, sample_cov_] = sess.run([
          sample_mean, sample_cov])
      self.assertAllClose(true_mean, sample_mean_,
                          atol=0.05, rtol=0.)
      self.assertAllClose(true_cov, sample_cov_,
                          atol=0., rtol=0.1)


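# Illustrative module-level sketch (an addition; not exercised by the tests):
# a pure-NumPy leapfrog integrator for a standard-normal target, showing the
# bounded energy error that `_integrator_conserves_energy` checks through the
# real `_leapfrog_integrator`. All names here are local to this sketch.
def _numpy_leapfrog_energy_demo(num_steps=1000, step_size=0.1):
  """Returns (initial_energy, final_energy); the two should nearly agree."""
  grad_log_prob = lambda x: -x  # d/dx log N(x | 0, 1), up to a constant.
  energy = lambda x, m: 0.5 * x**2 + 0.5 * m**2  # -log_prob + kinetic term.
  x, m = 1., 0.5
  initial_energy = energy(x, m)
  for _ in range(num_steps):
    m += 0.5 * step_size * grad_log_prob(x)  # Half step in momentum.
    x += step_size * m                       # Full step in position.
    m += 0.5 * step_size * grad_log_prob(x)  # Half step in momentum.
  return initial_energy, energy(x, m)

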
class _EnergyComputationTest(object):

  def testHandlesNanFromPotential(self):
    with self.test_session(graph=ops.Graph()) as sess:
      x = [1, np.inf, -np.inf, np.nan]
      target_log_prob, proposed_target_log_prob = [
          self.dtype(x.flatten()) for x in np.meshgrid(x, x)]
      num_chains = len(target_log_prob)
      dummy_momentums = [-1, 1]
      momentums = [self.dtype([dummy_momentums] * num_chains)]
      proposed_momentums = [self.dtype([dummy_momentums] * num_chains)]

      target_log_prob = ops.convert_to_tensor(target_log_prob)
      momentums = [ops.convert_to_tensor(momentums[0])]
      proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob)
      proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])]

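      # The energy change compared below is, up to sign convention,
      # [-proposed_target_log_prob + proposed kinetic energy]
      # - [-target_log_prob + current kinetic energy]; any NaN or inf feeding
      # into that difference should surface as +inf so the proposal is
      # rejected.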
      energy = _compute_energy_change(
          target_log_prob,
          momentums,
          proposed_target_log_prob,
          proposed_momentums,
          independent_chain_ndims=1)
      grads = gradients_ops.gradients(energy, momentums)

      [actual_energy, grads_] = sess.run([energy, grads])

      # Ensure energy is `inf` (note: that's positive inf) in weird cases and
      # finite otherwise.
      expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1))
      self.assertAllEqual(expected_energy, actual_energy)

      # Ensure gradient is finite.
      self.assertAllEqual(np.ones_like(grads_).astype(np.bool),
                          np.isfinite(grads_))

  def testHandlesNanFromKinetic(self):
    with self.test_session(graph=ops.Graph()) as sess:
      x = [1, np.inf, -np.inf, np.nan]
      momentums, proposed_momentums = [
          [np.reshape(self.dtype(x), [-1, 1])]
          for x in np.meshgrid(x, x)]
      num_chains = len(momentums[0])
      target_log_prob = np.ones(num_chains, self.dtype)
      proposed_target_log_prob = np.ones(num_chains, self.dtype)

      target_log_prob = ops.convert_to_tensor(target_log_prob)
      momentums = [ops.convert_to_tensor(momentums[0])]
      proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob)
      proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])]

      energy = _compute_energy_change(
          target_log_prob,
          momentums,
          proposed_target_log_prob,
          proposed_momentums,
          independent_chain_ndims=1)
      grads = gradients_ops.gradients(energy, momentums)

      [actual_energy, grads_] = sess.run([energy, grads])

      # Ensure energy is `inf` (note: that's positive inf) in weird cases and
      # finite otherwise.
      expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1))
      self.assertAllEqual(expected_energy, actual_energy)

      # Ensure gradient is finite.
      g = grads_[0].reshape([len(x), len(x)])[:, 0]
      self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isfinite(g))

      # The remaining gradients are nan because the momentum was itself nan or
      # inf.
      g = grads_[0].reshape([len(x), len(x)])[:, 1:]
      self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isnan(g))


class EnergyComputationTest16(test.TestCase, _EnergyComputationTest):
  dtype = np.float16


class EnergyComputationTest32(test.TestCase, _EnergyComputationTest):
  dtype = np.float32


class EnergyComputationTest64(test.TestCase, _EnergyComputationTest):
  dtype = np.float64


class _HMCHandlesLists(object):

  def testStateParts(self):
    with self.test_session(graph=ops.Graph()) as sess:
      dist_x = normal_lib.Normal(loc=self.dtype(0), scale=self.dtype(1))
      dist_y = independent_lib.Independent(
          gamma_lib.Gamma(concentration=self.dtype([1, 2]),
                          rate=self.dtype([0.5, 0.75])),
          reinterpreted_batch_ndims=1)
      def target_log_prob(x, y):
        return dist_x.log_prob(x) + dist_y.log_prob(y)
      x0 = [dist_x.sample(seed=1), dist_y.sample(seed=2)]
      samples, _ = hmc.sample_chain(
          num_results=int(2e3),
          target_log_prob_fn=target_log_prob,
          current_state=x0,
          step_size=0.85,
          num_leapfrog_steps=3,
          num_burnin_steps=int(250),
          seed=49)
      actual_means = [math_ops.reduce_mean(s, axis=0) for s in samples]
      actual_vars = [_reduce_variance(s, axis=0) for s in samples]
      expected_means = [dist_x.mean(), dist_y.mean()]
      expected_vars = [dist_x.variance(), dist_y.variance()]
      [
          actual_means_,
          actual_vars_,
          expected_means_,
          expected_vars_,
      ] = sess.run([
          actual_means,
          actual_vars,
          expected_means,
          expected_vars,
      ])
      self.assertAllClose(expected_means_, actual_means_, atol=0.05, rtol=0.16)
      self.assertAllClose(expected_vars_, actual_vars_, atol=0., rtol=0.25)


class HMCHandlesLists32(_HMCHandlesLists, test.TestCase):
  dtype = np.float32


class HMCHandlesLists64(_HMCHandlesLists, test.TestCase):
  dtype = np.float64


if __name__ == "__main__":
  test.main()