# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Multivariate Normal distribution classes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.distributions.python.ops.bijectors import AffineLinearOperator
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal
from tensorflow.python.ops.distributions import transformed_distribution
from tensorflow.python.ops.linalg import linalg
from tensorflow.python.util import deprecation


__all__ = [
    "MultivariateNormalLinearOperator",
]


_mvn_sample_note = """
`value` is a batch vector with compatible shape if `value` is a `Tensor` whose
shape can be broadcast up to either:

```python
self.batch_shape + self.event_shape
```

or

```python
[M1, ..., Mm] + self.batch_shape + self.event_shape
```

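For example (a sketch, assuming an `mvn` with `batch_shape == [2]` and
`event_shape == [3]`):

```python
mvn.prob(x)  # x.shape == [2, 3]    ==> result shape: [2]
mvn.prob(y)  # y.shape == [5, 2, 3] ==> result shape: [5, 2], i.e. [M1] = [5]
```
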
52"""
53
54
55# TODO(b/35290280): Import in `../../__init__.py` after adding unit-tests.
56class MultivariateNormalLinearOperator(
57    transformed_distribution.TransformedDistribution):
58  """The multivariate normal distribution on `R^k`.
59
60  The Multivariate Normal distribution is defined over `R^k` and parameterized
61  by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k`
62  `scale` matrix; `covariance = scale @ scale.T`, where `@` denotes
63  matrix-multiplication.
64
65  #### Mathematical Details
66
67  The probability density function (pdf) is,
68
69  ```none
70  pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z,
71  y = inv(scale) @ (x - loc),
72  Z = (2 pi)**(0.5 k) |det(scale)|,
73  ```
74
75  where:
76
77  * `loc` is a vector in `R^k`,
78  * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`,
79  * `Z` denotes the normalization constant, and,
80  * `||y||**2` denotes the squared Euclidean norm of `y`.
81
82  The MultivariateNormal distribution is a member of the [location-scale
83  family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
84  constructed as,
85
86  ```none
87  X ~ MultivariateNormal(loc=0, scale=1)   # Identity scale, zero shift.
88  Y = scale @ X + loc
89  ```
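
  A draw can thus be sketched as follows (hypothetical tensors; `scale` here
  is any `LinearOperator` with event size `k`, and `matmul` is used rather
  than assuming a `matvec` method is available):

  ```python
  z = tf.random_normal([k], dtype=scale.dtype)         # X ~ Normal(0, 1)
  y = scale.matmul(z[..., tf.newaxis])[..., 0] + loc   # Y = scale @ X + loc
  ```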

  #### Examples

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Initialize a single 3-variate Gaussian.
  mu = [1., 2, 3]
  cov = [[ 0.36,  0.12,  0.06],
         [ 0.12,  0.29, -0.13],
         [ 0.06, -0.13,  0.26]]
  scale = tf.cholesky(cov)
  # ==> [[ 0.6,  0. ,  0. ],
  #      [ 0.2,  0.5,  0. ],
  #      [ 0.1, -0.3,  0.4]]

  mvn = tfd.MultivariateNormalLinearOperator(
      loc=mu,
      scale=tf.linalg.LinearOperatorLowerTriangular(scale))

  # Covariance agrees with cholesky(cov) parameterization.
  mvn.covariance().eval()
  # ==> [[ 0.36,  0.12,  0.06],
  #      [ 0.12,  0.29, -0.13],
  #      [ 0.06, -0.13,  0.26]]

  # Compute the pdf of an `R^3` observation; return a scalar.
  mvn.prob([-1., 0, 1]).eval()  # shape: []

  # Initialize a 2-batch of 3-variate Gaussians.
  mu = [[1., 2, 3],
        [11, 22, 33]]              # shape: [2, 3]
  scale_diag = [[1., 2, 3],
                [0.5, 1, 1.5]]     # shape: [2, 3]

  mvn = tfd.MultivariateNormalLinearOperator(
      loc=mu,
      scale=tf.linalg.LinearOperatorDiag(scale_diag))

  # Compute the pdf of two `R^3` observations; return a length-2 vector.
  x = [[-0.9, 0, 0.1],
       [-10, 0, 9]]     # shape: [2, 3]
  mvn.prob(x).eval()    # shape: [2]
  ```
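
  Since a KL divergence is registered for pairs of this distribution (see
  `_kl_brute_force` below), divergences can be computed directly. A sketch,
  reusing the batched `mvn` above with a second, hypothetical `mvn2`:

  ```python
  mvn2 = tfd.MultivariateNormalLinearOperator(
      loc=mu,
      scale=tf.linalg.LinearOperatorDiag([[2., 2, 2], [1, 1, 1]]))
  tfd.kl_divergence(mvn, mvn2).eval()  # shape: [2]
  ```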

  """

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(self,
               loc=None,
               scale=None,
               validate_args=False,
               allow_nan_stats=True,
               name="MultivariateNormalLinearOperator"):
152    """Construct Multivariate Normal distribution on `R^k`.
153
154    The `batch_shape` is the broadcast shape between `loc` and `scale`
155    arguments.
156
    The `event_shape` is given by the last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Recall that `covariance = scale @ scale.T`.

    Additional leading dimensions (if any) will index batches.

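    A shape sketch (hypothetical operator): a batch of 2 MVNs, each on `R^3`.

    ```python
    mvn = MultivariateNormalLinearOperator(
        loc=tf.zeros([2, 3]),
        scale=tf.linalg.LinearOperatorIdentity(num_rows=3, batch_shape=[2]))
    mvn.batch_shape  # ==> [2]
    mvn.event_shape  # ==> [3]
    ```
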
    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
        `[B1, ..., Bb, k, k]`.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
        exception if a statistic (e.g., mean, mode, etc.) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: if `scale` is unspecified.
      TypeError: if `scale.dtype` is not a floating-point type.
    """
    parameters = dict(locals())
    if scale is None:
      raise ValueError("Missing required `scale` parameter.")
    if not scale.dtype.is_floating:
      raise TypeError("`scale` parameter must have floating-point dtype.")

    with ops.name_scope(name, values=[loc] + scale.graph_parents) as name:
      # Derive static (non-dynamic) batch and event shapes from `loc` and
      # `scale` where possible, falling back to shape tensors otherwise.
      loc = ops.convert_to_tensor(loc, name="loc") if loc is not None else loc
      batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
          loc, scale)

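    # The MVN is assembled as a location-scale pushforward: X ~ Normal(0, 1)
    # iid over the event shape, mapped through Y = scale @ X + loc.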
    super(MultivariateNormalLinearOperator, self).__init__(
        distribution=normal.Normal(
            loc=array_ops.zeros([], dtype=scale.dtype),
            scale=array_ops.ones([], dtype=scale.dtype)),
        bijector=AffineLinearOperator(
            shift=loc, scale=scale, validate_args=validate_args),
        batch_shape=batch_shape,
        event_shape=event_shape,
        validate_args=validate_args,
        name=name)
    self._parameters = parameters

  @property
  def loc(self):
    """The `loc` `Tensor` in `Y = scale @ X + loc`."""
    return self.bijector.shift

  @property
  def scale(self):
    """The `scale` `LinearOperator` in `Y = scale @ X + loc`."""
    return self.bijector.scale

  @distribution_util.AppendDocstring(_mvn_sample_note)
  def _log_prob(self, x):
    return super(MultivariateNormalLinearOperator, self)._log_prob(x)

  @distribution_util.AppendDocstring(_mvn_sample_note)
  def _prob(self, x):
    return super(MultivariateNormalLinearOperator, self)._prob(x)

  def _mean(self):
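    # The mean equals `loc` broadcast to the full batch + event shape; a `loc`
    # of `None` denotes an implicit zero mean.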
    shape = self.batch_shape.concatenate(self.event_shape)
    has_static_shape = shape.is_fully_defined()
    if not has_static_shape:
      shape = array_ops.concat([
          self.batch_shape_tensor(),
          self.event_shape_tensor(),
      ], 0)

    if self.loc is None:
      return array_ops.zeros(shape, self.dtype)

    if has_static_shape and shape == self.loc.get_shape():
      return array_ops.identity(self.loc)

    # Add dummy tensor of zeros to broadcast.  This is only necessary if shape
    # != self.loc.shape, but we could not determine if this is the case.
    return array_ops.identity(self.loc) + array_ops.zeros(shape, self.dtype)

  def _covariance(self):
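    # covariance = scale @ scale^H; for a diagonal operator this reduces to
    # diag(scale_diag**2), avoiding a dense matmul.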
    if distribution_util.is_diagonal_scale(self.scale):
      return array_ops.matrix_diag(math_ops.square(self.scale.diag_part()))
    else:
      return self.scale.matmul(self.scale.to_dense(), adjoint_arg=True)

  def _variance(self):
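    # Variance is diag_part(scale @ scale^H). When `scale` is self-adjoint,
    # scale @ scale^H == scale @ scale, so the adjoint can be skipped.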
    if distribution_util.is_diagonal_scale(self.scale):
      return math_ops.square(self.scale.diag_part())
    elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and
          self.scale.is_self_adjoint):
      return array_ops.matrix_diag_part(
          self.scale.matmul(self.scale.to_dense()))
    else:
      return array_ops.matrix_diag_part(
          self.scale.matmul(self.scale.to_dense(), adjoint_arg=True))

  def _stddev(self):
    if distribution_util.is_diagonal_scale(self.scale):
      return math_ops.abs(self.scale.diag_part())
    elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and
          self.scale.is_self_adjoint):
      return math_ops.sqrt(array_ops.matrix_diag_part(
          self.scale.matmul(self.scale.to_dense())))
    else:
      return math_ops.sqrt(array_ops.matrix_diag_part(
          self.scale.matmul(self.scale.to_dense(), adjoint_arg=True)))

  def _mode(self):
    return self._mean()


@kullback_leibler.RegisterKL(MultivariateNormalLinearOperator,
                             MultivariateNormalLinearOperator)
@deprecation.deprecated(
    "2018-10-01",
    "The TensorFlow Distributions library has moved to "
    "TensorFlow Probability "
    "(https://github.com/tensorflow/probability). You "
    "should update all references to use `tfp.distributions` "
    "instead of `tf.contrib.distributions`.",
    warn_once=True)
def _kl_brute_force(a, b, name=None):
  """Batched KL divergence `KL(a || b)` for multivariate Normals.

  With `X`, `Y` both multivariate Normals in `R^k` with means `mu_a`, `mu_b`
  and covariances `C_a`, `C_b` respectively,

  ```
  KL(a || b) = 0.5 * ( L - k + T + Q ),
  L := Log[Det(C_b)] - Log[Det(C_a)],
  T := trace(C_b^{-1} C_a),
  Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a).
  ```

  This `Op` computes the trace by solving `C_b^{-1} C_a`. Although efficient
  methods for solving systems with `C_b` may be available, a dense version of
  (the square root of) `C_a` is used, so performance is `O(B s k**2)` where `B`
  is the batch size, and `s` is the cost of solving `C_b x = y` for vectors `x`
  and `y`.

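  A usage sketch (hypothetical diagonal-scale distributions; the registration
  above routes `kullback_leibler.kl_divergence` here):

  ```python
  mvn_a = MultivariateNormalLinearOperator(
      loc=[1., 1.], scale=tf.linalg.LinearOperatorDiag([1., 2.]))
  mvn_b = MultivariateNormalLinearOperator(
      loc=[0., 0.], scale=tf.linalg.LinearOperatorIdentity(num_rows=2))
  kl = kullback_leibler.kl_divergence(mvn_a, mvn_b)  # shape: []
  ```
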
  Args:
    a: Instance of `MultivariateNormalLinearOperator`.
    b: Instance of `MultivariateNormalLinearOperator`.
    name: (optional) name to use for created ops. Default "kl_mvn".

  Returns:
    Batchwise `KL(a || b)`.
  """

  def squared_frobenius_norm(x):
    """Helper to make KL calculation slightly more readable."""
    # http://mathworld.wolfram.com/FrobeniusNorm.html
    # The gradient of KL[p,q] is not defined when p==q. The culprit is
    # linalg_ops.norm, i.e., we cannot use the commented out code.
    # return math_ops.square(linalg_ops.norm(x, ord="fro", axis=[-2, -1]))
    return math_ops.reduce_sum(math_ops.square(x), axis=[-2, -1])

  # TODO(b/35041439): See also b/35040945. Remove this function once LinOp
  # supports something like:
  #   A.inverse().solve(B).norm(order='fro', axis=[-1, -2])
  def is_diagonal(x):
    """Helper to identify if `LinearOperator` has only a diagonal component."""
    return (isinstance(x, linalg.LinearOperatorIdentity) or
            isinstance(x, linalg.LinearOperatorScaledIdentity) or
            isinstance(x, linalg.LinearOperatorDiag))

  with ops.name_scope(name, "kl_mvn", values=[a.loc, b.loc] +
                      a.scale.graph_parents + b.scale.graph_parents):
    # Calculation is based on:
    # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians
    # and,
    # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm
    # i.e.,
    #   If Ca = AA', Cb = BB', then
    #   tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
    #                  = tr[inv(B) A A' inv(B)']
    #                  = tr[(inv(B) A) (inv(B) A)']
    #                  = sum_{ij} (inv(B) A)_{ij}**2
    #                  = ||inv(B) A||_F**2
    # where ||.||_F is the Frobenius norm and the second equality follows from
    # the cyclic permutation property.
    if is_diagonal(a.scale) and is_diagonal(b.scale):
      # Using `stddev` because it handles expansion of Identity cases.
      b_inv_a = (a.stddev() / b.stddev())[..., array_ops.newaxis]
    else:
      b_inv_a = b.scale.solve(a.scale.to_dense())
    kl_div = (b.scale.log_abs_determinant()
              - a.scale.log_abs_determinant()
              + 0.5 * (
                  - math_ops.cast(a.scale.domain_dimension_tensor(), a.dtype)
                  + squared_frobenius_norm(b_inv_a)
                  + squared_frobenius_norm(b.scale.solve(
                      (b.mean() - a.mean())[..., array_ops.newaxis]))))
    kl_div.set_shape(array_ops.broadcast_static_shape(
        a.batch_shape, b.batch_shape))
    return kl_div