1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Multivariate Normal distribution classes.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.contrib.distributions.python.ops import distribution_util 22from tensorflow.contrib.distributions.python.ops.bijectors import AffineLinearOperator 23from tensorflow.python.framework import ops 24from tensorflow.python.ops import array_ops 25from tensorflow.python.ops import math_ops 26from tensorflow.python.ops.distributions import kullback_leibler 27from tensorflow.python.ops.distributions import normal 28from tensorflow.python.ops.distributions import transformed_distribution 29from tensorflow.python.ops.linalg import linalg 30from tensorflow.python.util import deprecation 31 32 33__all__ = [ 34 "MultivariateNormalLinearOperator", 35] 36 37 38_mvn_sample_note = """ 39`value` is a batch vector with compatible shape if `value` is a `Tensor` whose 40shape can be broadcast up to either: 41 42```python 43self.batch_shape + self.event_shape 44``` 45 46or 47 48```python 49[M1, ..., Mm] + self.batch_shape + self.event_shape 50``` 51 52""" 53 54 55# TODO(b/35290280): Import in `../../__init__.py` after adding unit-tests. 56class MultivariateNormalLinearOperator( 57 transformed_distribution.TransformedDistribution): 58 """The multivariate normal distribution on `R^k`. 59 60 The Multivariate Normal distribution is defined over `R^k` and parameterized 61 by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k` 62 `scale` matrix; `covariance = scale @ scale.T`, where `@` denotes 63 matrix-multiplication. 64 65 #### Mathematical Details 66 67 The probability density function (pdf) is, 68 69 ```none 70 pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z, 71 y = inv(scale) @ (x - loc), 72 Z = (2 pi)**(0.5 k) |det(scale)|, 73 ``` 74 75 where: 76 77 * `loc` is a vector in `R^k`, 78 * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`, 79 * `Z` denotes the normalization constant, and, 80 * `||y||**2` denotes the squared Euclidean norm of `y`. 81 82 The MultivariateNormal distribution is a member of the [location-scale 83 family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be 84 constructed as, 85 86 ```none 87 X ~ MultivariateNormal(loc=0, scale=1) # Identity scale, zero shift. 88 Y = scale @ X + loc 89 ``` 90 91 #### Examples 92 93 ```python 94 import tensorflow_probability as tfp 95 tfd = tfp.distributions 96 97 # Initialize a single 3-variate Gaussian. 98 mu = [1., 2, 3] 99 cov = [[ 0.36, 0.12, 0.06], 100 [ 0.12, 0.29, -0.13], 101 [ 0.06, -0.13, 0.26]] 102 scale = tf.cholesky(cov) 103 # ==> [[ 0.6, 0. , 0. ], 104 # [ 0.2, 0.5, 0. ], 105 # [ 0.1, -0.3, 0.4]]) 106 107 mvn = tfd.MultivariateNormalLinearOperator( 108 loc=mu, 109 scale=tf.linalg.LinearOperatorLowerTriangular(scale)) 110 111 # Covariance agrees with cholesky(cov) parameterization. 112 mvn.covariance().eval() 113 # ==> [[ 0.36, 0.12, 0.06], 114 # [ 0.12, 0.29, -0.13], 115 # [ 0.06, -0.13, 0.26]] 116 117 # Compute the pdf of an`R^3` observation; return a scalar. 118 mvn.prob([-1., 0, 1]).eval() # shape: [] 119 120 # Initialize a 2-batch of 3-variate Gaussians. 121 mu = [[1., 2, 3], 122 [11, 22, 33]] # shape: [2, 3] 123 scale_diag = [[1., 2, 3], 124 [0.5, 1, 1.5]] # shape: [2, 3] 125 126 mvn = tfd.MultivariateNormalLinearOperator( 127 loc=mu, 128 scale=tf.linalg.LinearOperatorDiag(scale_diag)) 129 130 # Compute the pdf of two `R^3` observations; return a length-2 vector. 131 x = [[-0.9, 0, 0.1], 132 [-10, 0, 9]] # shape: [2, 3] 133 mvn.prob(x).eval() # shape: [2] 134 ``` 135 136 """ 137 138 @deprecation.deprecated( 139 "2018-10-01", 140 "The TensorFlow Distributions library has moved to " 141 "TensorFlow Probability " 142 "(https://github.com/tensorflow/probability). You " 143 "should update all references to use `tfp.distributions` " 144 "instead of `tf.contrib.distributions`.", 145 warn_once=True) 146 def __init__(self, 147 loc=None, 148 scale=None, 149 validate_args=False, 150 allow_nan_stats=True, 151 name="MultivariateNormalLinearOperator"): 152 """Construct Multivariate Normal distribution on `R^k`. 153 154 The `batch_shape` is the broadcast shape between `loc` and `scale` 155 arguments. 156 157 The `event_shape` is given by last dimension of the matrix implied by 158 `scale`. The last dimension of `loc` (if provided) must broadcast with this. 159 160 Recall that `covariance = scale @ scale.T`. 161 162 Additional leading dimensions (if any) will index batches. 163 164 Args: 165 loc: Floating-point `Tensor`. If this is set to `None`, `loc` is 166 implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where 167 `b >= 0` and `k` is the event size. 168 scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape 169 `[B1, ..., Bb, k, k]`. 170 validate_args: Python `bool`, default `False`. Whether to validate input 171 with asserts. If `validate_args` is `False`, and the inputs are 172 invalid, correct behavior is not guaranteed. 173 allow_nan_stats: Python `bool`, default `True`. If `False`, raise an 174 exception if a statistic (e.g. mean/mode/etc...) is undefined for any 175 batch member If `True`, batch members with valid parameters leading to 176 undefined statistics will return NaN for this statistic. 177 name: The name to give Ops created by the initializer. 178 179 Raises: 180 ValueError: if `scale` is unspecified. 181 TypeError: if not `scale.dtype.is_floating` 182 """ 183 parameters = dict(locals()) 184 if scale is None: 185 raise ValueError("Missing required `scale` parameter.") 186 if not scale.dtype.is_floating: 187 raise TypeError("`scale` parameter must have floating-point dtype.") 188 189 with ops.name_scope(name, values=[loc] + scale.graph_parents) as name: 190 # Since expand_dims doesn't preserve constant-ness, we obtain the 191 # non-dynamic value if possible. 192 loc = ops.convert_to_tensor(loc, name="loc") if loc is not None else loc 193 batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale( 194 loc, scale) 195 196 super(MultivariateNormalLinearOperator, self).__init__( 197 distribution=normal.Normal( 198 loc=array_ops.zeros([], dtype=scale.dtype), 199 scale=array_ops.ones([], dtype=scale.dtype)), 200 bijector=AffineLinearOperator( 201 shift=loc, scale=scale, validate_args=validate_args), 202 batch_shape=batch_shape, 203 event_shape=event_shape, 204 validate_args=validate_args, 205 name=name) 206 self._parameters = parameters 207 208 @property 209 def loc(self): 210 """The `loc` `Tensor` in `Y = scale @ X + loc`.""" 211 return self.bijector.shift 212 213 @property 214 def scale(self): 215 """The `scale` `LinearOperator` in `Y = scale @ X + loc`.""" 216 return self.bijector.scale 217 218 @distribution_util.AppendDocstring(_mvn_sample_note) 219 def _log_prob(self, x): 220 return super(MultivariateNormalLinearOperator, self)._log_prob(x) 221 222 @distribution_util.AppendDocstring(_mvn_sample_note) 223 def _prob(self, x): 224 return super(MultivariateNormalLinearOperator, self)._prob(x) 225 226 def _mean(self): 227 shape = self.batch_shape.concatenate(self.event_shape) 228 has_static_shape = shape.is_fully_defined() 229 if not has_static_shape: 230 shape = array_ops.concat([ 231 self.batch_shape_tensor(), 232 self.event_shape_tensor(), 233 ], 0) 234 235 if self.loc is None: 236 return array_ops.zeros(shape, self.dtype) 237 238 if has_static_shape and shape == self.loc.get_shape(): 239 return array_ops.identity(self.loc) 240 241 # Add dummy tensor of zeros to broadcast. This is only necessary if shape 242 # != self.loc.shape, but we could not determine if this is the case. 243 return array_ops.identity(self.loc) + array_ops.zeros(shape, self.dtype) 244 245 def _covariance(self): 246 if distribution_util.is_diagonal_scale(self.scale): 247 return array_ops.matrix_diag(math_ops.square(self.scale.diag_part())) 248 else: 249 return self.scale.matmul(self.scale.to_dense(), adjoint_arg=True) 250 251 def _variance(self): 252 if distribution_util.is_diagonal_scale(self.scale): 253 return math_ops.square(self.scale.diag_part()) 254 elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and 255 self.scale.is_self_adjoint): 256 return array_ops.matrix_diag_part( 257 self.scale.matmul(self.scale.to_dense())) 258 else: 259 return array_ops.matrix_diag_part( 260 self.scale.matmul(self.scale.to_dense(), adjoint_arg=True)) 261 262 def _stddev(self): 263 if distribution_util.is_diagonal_scale(self.scale): 264 return math_ops.abs(self.scale.diag_part()) 265 elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and 266 self.scale.is_self_adjoint): 267 return math_ops.sqrt(array_ops.matrix_diag_part( 268 self.scale.matmul(self.scale.to_dense()))) 269 else: 270 return math_ops.sqrt(array_ops.matrix_diag_part( 271 self.scale.matmul(self.scale.to_dense(), adjoint_arg=True))) 272 273 def _mode(self): 274 return self._mean() 275 276 277@kullback_leibler.RegisterKL(MultivariateNormalLinearOperator, 278 MultivariateNormalLinearOperator) 279@deprecation.deprecated( 280 "2018-10-01", 281 "The TensorFlow Distributions library has moved to " 282 "TensorFlow Probability " 283 "(https://github.com/tensorflow/probability). You " 284 "should update all references to use `tfp.distributions` " 285 "instead of `tf.contrib.distributions`.", 286 warn_once=True) 287def _kl_brute_force(a, b, name=None): 288 """Batched KL divergence `KL(a || b)` for multivariate Normals. 289 290 With `X`, `Y` both multivariate Normals in `R^k` with means `mu_a`, `mu_b` and 291 covariance `C_a`, `C_b` respectively, 292 293 ``` 294 KL(a || b) = 0.5 * ( L - k + T + Q ), 295 L := Log[Det(C_b)] - Log[Det(C_a)] 296 T := trace(C_b^{-1} C_a), 297 Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a), 298 ``` 299 300 This `Op` computes the trace by solving `C_b^{-1} C_a`. Although efficient 301 methods for solving systems with `C_b` may be available, a dense version of 302 (the square root of) `C_a` is used, so performance is `O(B s k**2)` where `B` 303 is the batch size, and `s` is the cost of solving `C_b x = y` for vectors `x` 304 and `y`. 305 306 Args: 307 a: Instance of `MultivariateNormalLinearOperator`. 308 b: Instance of `MultivariateNormalLinearOperator`. 309 name: (optional) name to use for created ops. Default "kl_mvn". 310 311 Returns: 312 Batchwise `KL(a || b)`. 313 """ 314 315 def squared_frobenius_norm(x): 316 """Helper to make KL calculation slightly more readable.""" 317 # http://mathworld.wolfram.com/FrobeniusNorm.html 318 # The gradient of KL[p,q] is not defined when p==q. The culprit is 319 # linalg_ops.norm, i.e., we cannot use the commented out code. 320 # return math_ops.square(linalg_ops.norm(x, ord="fro", axis=[-2, -1])) 321 return math_ops.reduce_sum(math_ops.square(x), axis=[-2, -1]) 322 323 # TODO(b/35041439): See also b/35040945. Remove this function once LinOp 324 # supports something like: 325 # A.inverse().solve(B).norm(order='fro', axis=[-1, -2]) 326 def is_diagonal(x): 327 """Helper to identify if `LinearOperator` has only a diagonal component.""" 328 return (isinstance(x, linalg.LinearOperatorIdentity) or 329 isinstance(x, linalg.LinearOperatorScaledIdentity) or 330 isinstance(x, linalg.LinearOperatorDiag)) 331 332 with ops.name_scope(name, "kl_mvn", values=[a.loc, b.loc] + 333 a.scale.graph_parents + b.scale.graph_parents): 334 # Calculation is based on: 335 # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians 336 # and, 337 # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm 338 # i.e., 339 # If Ca = AA', Cb = BB', then 340 # tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A'] 341 # = tr[inv(B) A A' inv(B)'] 342 # = tr[(inv(B) A) (inv(B) A)'] 343 # = sum_{ij} (inv(B) A)_{ij}**2 344 # = ||inv(B) A||_F**2 345 # where ||.||_F is the Frobenius norm and the second equality follows from 346 # the cyclic permutation property. 347 if is_diagonal(a.scale) and is_diagonal(b.scale): 348 # Using `stddev` because it handles expansion of Identity cases. 349 b_inv_a = (a.stddev() / b.stddev())[..., array_ops.newaxis] 350 else: 351 b_inv_a = b.scale.solve(a.scale.to_dense()) 352 kl_div = (b.scale.log_abs_determinant() 353 - a.scale.log_abs_determinant() 354 + 0.5 * ( 355 - math_ops.cast(a.scale.domain_dimension_tensor(), a.dtype) 356 + squared_frobenius_norm(b_inv_a) 357 + squared_frobenius_norm(b.scale.solve( 358 (b.mean() - a.mean())[..., array_ops.newaxis])))) 359 kl_div.set_shape(array_ops.broadcast_static_shape( 360 a.batch_shape, b.batch_shape)) 361 return kl_div 362