# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Monte Carlo integration and helpers.

@@expectation
@@expectation_importance_sampler
@@expectation_importance_sampler_logspace
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util import deprecation

__all__ = [
    'expectation',
    'expectation_importance_sampler',
    'expectation_importance_sampler_logspace',
]


def expectation_importance_sampler(f,
                                   log_p,
                                   sampling_dist_q,
                                   z=None,
                                   n=None,
                                   seed=None,
                                   name='expectation_importance_sampler'):
  r"""Monte Carlo estimate of \\(E_p[f(Z)] = E_q[f(Z) p(Z) / q(Z)]\\).

  With \\(p(z) := exp{log_p(z)}\\), this `Op` returns

  \\(n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ],  z_i ~ q,\\)
  \\(\approx E_q[ f(Z) p(Z) / q(Z) ]\\)
  \\(=       E_p[f(Z)]\\)

  This integral is done in log-space with max-subtraction to better handle the
  often extreme values that `f(z) p(z) / q(z)` can take on.

  If `f >= 0`, it is up to 2x more efficient to exponentiate the result of
  `expectation_importance_sampler_logspace` applied to `Log[f]`.

  The user supplies either a `Tensor` of samples `z` or a number of samples to
  draw, `n`.

  Args:
    f: Callable mapping samples from `sampling_dist_q` to `Tensors` with shape
      broadcastable to `q.batch_shape`.
      For example, `f` works "just like" `q.log_prob`.
    log_p:  Callable mapping samples from `sampling_dist_q` to `Tensors` with
      shape broadcastable to `q.batch_shape`.
      For example, `log_p` works "just like" `sampling_dist_q.log_prob`.
    sampling_dist_q:  The sampling distribution.
      `tfp.distributions.Distribution`.
      `float64` `dtype` recommended.
      `log_p` and `q` should be supported on the same set.
    z:  `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
    n:  Integer `Tensor`.  Number of samples to generate if `z` is not provided.
    seed:  Python integer to seed the random number generator.
    name:  A name to give this `Op`.

  Returns:
    The importance sampling estimate.  `Tensor` with `shape` equal
      to batch shape of `q`, and `dtype` = `q.dtype`.
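
  Example use (a minimal, illustrative sketch; the particular `p`, `q`, and
  integrand below are chosen only for demonstration):

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  p = tfd.Normal(loc=0., scale=1.)   # Target distribution.
  q = tfd.Normal(loc=0., scale=2.)   # Sampling (proposal) distribution.

  # Estimate E_p[Z**2] (exact value: 1.) using importance-weighted samples
  # drawn from q.
  estimate = expectation_importance_sampler(
      f=tf.square,
      log_p=p.log_prob,
      sampling_dist_q=q,
      n=int(1e4),
      seed=42)
  ```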
  """
  q = sampling_dist_q
  with ops.name_scope(name, values=[z, n]):
    z = _get_samples(q, z, n, seed)

    log_p_z = log_p(z)
    q_log_prob_z = q.log_prob(z)

    def _importance_sampler_positive_f(log_f_z):
      # Same as expectation_importance_sampler_logspace, but using Tensors
      # rather than samples and functions.  Allows us to sample once.
      log_values = log_f_z + log_p_z - q_log_prob_z
      return _logspace_mean(log_values)

    # With \\(f_{plus}(z) = max(0, f(z)), f_{minus}(z) = max(0, -f(z))\\),
    # \\(E_p[f(Z)] = E_p[f_{plus}(Z)] - E_p[f_{minus}(Z)]\\)
    # \\(          = E_p[f_{plus}(Z) + 1] - E_p[f_{minus}(Z) + 1]\\)
    # Without incurring bias, 1 is added to each to prevent zeros in logspace.
    # The logarithm is approximately linear near 1, so the estimate remains
    # accurate even where `f(z)` is small.
    f_z = f(z)
    log_f_plus_z = math_ops.log(nn.relu(f_z) + 1.)
    log_f_minus_z = math_ops.log(nn.relu(-1. * f_z) + 1.)

    log_f_plus_integral = _importance_sampler_positive_f(log_f_plus_z)
    log_f_minus_integral = _importance_sampler_positive_f(log_f_minus_z)

  return math_ops.exp(log_f_plus_integral) - math_ops.exp(log_f_minus_integral)


def expectation_importance_sampler_logspace(
    log_f,
    log_p,
    sampling_dist_q,
    z=None,
    n=None,
    seed=None,
    name='expectation_importance_sampler_logspace'):
  r"""Importance sampling with a positive function, in log-space.

  With \\(p(z) := exp{log_p(z)}\\) and \\(f(z) := exp{log_f(z)}\\),
  this `Op` returns

  \\(Log[ n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ] ],  z_i ~ q,\\)
  \\(\approx Log[ E_q[ f(Z) p(Z) / q(Z) ] ]\\)
  \\(=       Log[E_p[f(Z)]]\\)

  This integral is done in log-space with max-subtraction to better handle the
  often extreme values that `f(z) p(z) / q(z)` can take on.

  In contrast to `expectation_importance_sampler`, this `Op` returns values in
  log-space.

  The user supplies either a `Tensor` of samples `z` or a number of samples to
  draw, `n`.

  Args:
    log_f: Callable mapping samples from `sampling_dist_q` to `Tensors` with
      shape broadcastable to `q.batch_shape`.
      For example, `log_f` works "just like" `sampling_dist_q.log_prob`.
    log_p:  Callable mapping samples from `sampling_dist_q` to `Tensors` with
      shape broadcastable to `q.batch_shape`.
      For example, `log_p` works "just like" `q.log_prob`.
    sampling_dist_q:  The sampling distribution.
      `tfp.distributions.Distribution`.
      `float64` `dtype` recommended.
      `log_p` and `q` should be supported on the same set.
    z:  `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
    n:  Integer `Tensor`.  Number of samples to generate if `z` is not provided.
    seed:  Python integer to seed the random number generator.
    name:  A name to give this `Op`.

  Returns:
    Logarithm of the importance sampling estimate.  `Tensor` with `shape` equal
      to batch shape of `q`, and `dtype` = `q.dtype`.
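
  Example use (a minimal, illustrative sketch; the particular `p`, `q`, and
  `log_f` below are chosen only for demonstration):

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  p = tfd.Normal(loc=0., scale=1.)   # Target distribution.
  q = tfd.Normal(loc=0., scale=2.)   # Sampling (proposal) distribution.

  # Estimate Log[E_p[exp(Z)]] (exact value: 0.5, the log of the standard
  # normal moment generating function at 1) without leaving log-space.
  log_estimate = expectation_importance_sampler_logspace(
      log_f=lambda z: z,             # i.e., f(z) = exp(z).
      log_p=p.log_prob,
      sampling_dist_q=q,
      n=int(1e4),
      seed=42)
  ```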
  """
  q = sampling_dist_q
  with ops.name_scope(name, values=[z, n]):
    z = _get_samples(q, z, n, seed)
    log_values = log_f(z) + log_p(z) - q.log_prob(z)
    return _logspace_mean(log_values)


def _logspace_mean(log_values):
  """Evaluate `Log[E[values]]` in a stable manner.

  Args:
    log_values:  `Tensor` holding `Log[values]`.

  Returns:
    `Tensor` of same `dtype` as `log_values`, reduced across dim 0.
      `Log[Mean[values]]`.
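
  For example, if `log_values = log([1., 3.])`, the result is `log(2.)`, since
  `Mean[[1., 3.]] = 2.`.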
  """
  # center = Max[Log[values]], with stop-gradient.
  # Subtracting the max keeps the exponentiated terms at or below 1, which
  # avoids overflow.  The center cancels out of the final result, so wrapping
  # it in stop_gradient does not change that result; it only avoids
  # unnecessary gradient computation.
  center = array_ops.stop_gradient(_sample_max(log_values))

  # centered_values = exp{Log[values] - center}
  centered_values = math_ops.exp(log_values - center)

  # log_mean_of_values = Log[ E[centered_values] ] + center
  #                    = Log[ E[exp{log_values - center}] ] + center
  #                    = Log[E[values]] - center + center
  #                    = Log[E[values]]
  log_mean_of_values = math_ops.log(_sample_mean(centered_values)) + center

  return log_mean_of_values


@deprecation.deprecated(
    '2018-10-01',
    'The tf.contrib.bayesflow library has moved to '
    'TensorFlow Probability (https://github.com/tensorflow/probability). '
    'Use `tfp.monte_carlo.expectation` instead.',
    warn_once=True)
def expectation(f, samples, log_prob=None, use_reparametrization=True,
                axis=0, keep_dims=False, name=None):
  r"""Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\).

  This function computes the Monte-Carlo approximation of an expectation, i.e.,

  \\(E_p[f(X)] \approx m^{-1} sum_{j=1}^m f(x_j),  x_j \sim_{iid} p(X)\\)

  where:

  - `x_j = samples[j, ...]`,
  - `log(p(samples)) = log_prob(samples)` and
  - `m = prod(shape(samples)[axis])`.

  Tricks: Reparameterization and Score-Gradient

  When p is "reparameterized", i.e., a diffeomorphic transformation of a
  parameterless distribution (e.g.,
  `Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)`), we can swap gradient and
  expectation, i.e.,
  grad[ Avg{ \\(s_i : i=1...n\\) } ] = Avg{ grad[\\(s_i\\)] : i=1...n }, where
  \\(s_i = f(x_i)\\) and \\(x_i ~ p\\).

  However, if p is not reparameterized, TensorFlow's gradient will be incorrect
  since the chain-rule stops at samples of non-reparameterized distributions.
  (The non-differentiated result, `approx_expectation`, is the same regardless
  of `use_reparametrization`.) In this circumstance using the Score-Gradient
  trick results in an unbiased gradient, i.e.,

  ```none
  grad[ E_p[f(X)] ]
  = grad[ int dx p(x) f(x) ]
  = int dx grad[ p(x) f(x) ]
  = int dx [ p'(x) f(x) + p(x) f'(x) ]
  = int dx p(x) [p'(x) / p(x) f(x) + f'(x) ]
  = int dx p(x) grad[ f(x) p(x) / stop_grad[p(x)] ]
  = E_p[ grad[ f(x) p(x) / stop_grad[p(x)] ] ]
  ```
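
  A minimal sketch of how this identity can be realized with `stop_gradient`
  (assuming the usual `import tensorflow as tf`; here `f`, `log_prob`, and
  `samples` stand in for this function's arguments, and the snippet mirrors
  the implementation below):

  ```python
  x = tf.stop_gradient(samples)
  logpx = log_prob(x)
  fx = f(x)
  # Value is unchanged (the added term is exactly zero), but the gradient now
  # contributes the score-function term fx * grad[log p(x)].
  surrogate = fx + tf.stop_gradient(fx) * (logpx - tf.stop_gradient(logpx))
  approx_expectation = tf.reduce_mean(surrogate, axis=0)
  ```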

  When `p` is reparameterized, it is usually preferable to set
  `use_reparametrization = True`.

  Warning: users are responsible for verifying `p` is a "reparameterized"
  distribution.

  Example Use:

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Monte-Carlo approximation of a reparameterized distribution, e.g., Normal.

  num_draws = int(1e5)
  p = tfd.Normal(loc=0., scale=1.)
  q = tfd.Normal(loc=1., scale=2.)
  exact_kl_normal_normal = tfd.kl_divergence(p, q)
  # ==> 0.44314718
  approx_kl_normal_normal = tfp.monte_carlo.expectation(
      f=lambda x: p.log_prob(x) - q.log_prob(x),
      samples=p.sample(num_draws, seed=42),
      log_prob=p.log_prob,
      use_reparametrization=(p.reparameterization_type
                             == tfd.FULLY_REPARAMETERIZED))
  # ==> 0.44632751
  # Relative Error: <1%

  # Monte-Carlo approximation of non-reparameterized distribution, e.g., Gamma.

  num_draws = int(1e5)
  p = tfd.Gamma(concentration=1., rate=1.)
  q = tfd.Gamma(concentration=2., rate=3.)
  exact_kl_gamma_gamma = tfd.kl_divergence(p, q)
  # ==> 0.37999129
  approx_kl_gamma_gamma = tfp.monte_carlo.expectation(
      f=lambda x: p.log_prob(x) - q.log_prob(x),
      samples=p.sample(num_draws, seed=42),
      log_prob=p.log_prob,
      use_reparametrization=(p.reparameterization_type
                             == tfd.FULLY_REPARAMETERIZED))
  # ==> 0.37696719
  # Relative Error: <1%

  # For comparing the gradients, see `monte_carlo_test.py`.
  ```

  Note: The above example is for illustration only. To compute approximate
  KL-divergence, the following is preferred:

  ```python
  approx_kl_p_q = tfp.vi.monte_carlo_csiszar_f_divergence(
      f=tfp.vi.kl_reverse,
      p_log_prob=q.log_prob,
      q=p,
      num_draws=num_draws)
  ```

  Args:
    f: Python callable which can return `f(samples)`.
    samples: `Tensor` of samples used to form the Monte-Carlo approximation of
      \\(E_p[f(X)]\\).  A batch of samples should be indexed by `axis`
      dimensions.
    log_prob: Python callable which can return `log_prob(samples)`. Must
      correspond to the natural-logarithm of the pdf/pmf of each sample. Only
      required/used if `use_reparametrization=False`.
      Default value: `None`.
    use_reparametrization: Python `bool` indicating that the approximation
      should use the fact that the gradient of samples is unbiased. Whether
      `True` or `False`, this arg only affects the gradient of the resulting
      `approx_expectation`.
      Default value: `True`.
    axis: The dimensions to average. If `None`, averages all
      dimensions.
      Default value: `0` (the left-most dimension).
    keep_dims: If True, retains averaged dimensions using size `1`.
      Default value: `False`.
    name: A `name_scope` for operations created by this function.
      Default value: `None` (which implies "expectation").

  Returns:
    approx_expectation: `Tensor` corresponding to the Monte-Carlo approximation
      of \\(E_p[f(X)]\\).

  Raises:
    ValueError: if `f` is not a Python `callable`.
    ValueError: if `use_reparametrization=False` and `log_prob` is not a Python
      `callable`.
  """

  with ops.name_scope(name, 'expectation', [samples]):
    if not callable(f):
      raise ValueError('`f` must be a callable function.')
    if use_reparametrization:
      return math_ops.reduce_mean(f(samples), axis=axis, keepdims=keep_dims)
    else:
      if not callable(log_prob):
        raise ValueError('`log_prob` must be a callable function.')
      stop = array_ops.stop_gradient  # For readability.
      x = stop(samples)
      logpx = log_prob(x)
      fx = f(x)  # Call `f` once in case it has side-effects.
      # We now rewrite f(x) so that:
      #   `grad[f(x)] := grad[f(x)] + f(x) * grad[logpx]`.
      # To achieve this, we use a trick that
      #   `h(x) - stop(h(x)) == zeros_like(h(x))`
      # but its gradient is grad[h(x)].
      # Note that IEEE754 specifies that `x - x == 0.` and `x + 0. == x`, hence
      # this trick loses no precision. For more discussion regarding the
      # relevant portions of the IEEE754 standard, see the StackOverflow
      # question,
      # "Is there a floating point value of x, for which x-x == 0 is false?"
      # http://stackoverflow.com/q/2686644
      fx += stop(fx) * (logpx - stop(logpx))  # Add zeros_like(logpx).
      return math_ops.reduce_mean(fx, axis=axis, keepdims=keep_dims)


def _sample_mean(values):
  """Mean over sample indices.  In this module this is always [0]."""
  return math_ops.reduce_mean(values, axis=[0])


def _sample_max(values):
  """Max over sample indices.  In this module this is always [0]."""
  return math_ops.reduce_max(values, axis=[0])


def _get_samples(dist, z, n, seed):
  """Check args and return samples."""
  with ops.name_scope('get_samples', values=[z, n]):
    if (n is None) == (z is None):
      raise ValueError(
          'Must specify exactly one of arguments "n" and "z".  Found: '
          'n = %s, z = %s' % (n, z))
    if n is not None:
      return dist.sample(n, seed=seed)
    else:
      return ops.convert_to_tensor(z, name='z')