1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A helper class for inferring Distribution shape.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import contextlib 21 22from tensorflow.python.framework import dtypes 23from tensorflow.python.framework import ops 24from tensorflow.python.framework import tensor_util 25from tensorflow.python.ops import array_ops 26from tensorflow.python.ops import check_ops 27from tensorflow.python.ops import control_flow_ops 28from tensorflow.python.ops import math_ops 29from tensorflow.python.ops.distributions import util as distribution_util 30from tensorflow.python.util import deprecation 31 32 33class _DistributionShape(object): 34 """Manage and manipulate `Distribution` shape. 35 36 #### Terminology 37 38 Recall that a `Tensor` has: 39 - `shape`: size of `Tensor` dimensions, 40 - `ndims`: size of `shape`; number of `Tensor` dimensions, 41 - `dims`: indexes into `shape`; useful for transpose, reduce. 42 43 `Tensor`s sampled from a `Distribution` can be partitioned by `sample_dims`, 44 `batch_dims`, and `event_dims`. To understand the semantics of these 45 dimensions, consider when two of the three are fixed and the remaining 46 is varied: 47 - `sample_dims`: indexes independent draws from identical 48 parameterizations of the `Distribution`. 49 - `batch_dims`: indexes independent draws from non-identical 50 parameterizations of the `Distribution`. 51 - `event_dims`: indexes event coordinates from one sample. 52 53 The `sample`, `batch`, and `event` dimensions constitute the entirety of a 54 `Distribution` `Tensor`'s shape. 55 56 The dimensions are always in `sample`, `batch`, `event` order. 57 58 #### Purpose 59 60 This class partitions `Tensor` notions of `shape`, `ndims`, and `dims` into 61 `Distribution` notions of `sample,` `batch,` and `event` dimensions. That 62 is, it computes any of: 63 64 ``` 65 sample_shape batch_shape event_shape 66 sample_dims batch_dims event_dims 67 sample_ndims batch_ndims event_ndims 68 ``` 69 70 for a given `Tensor`, e.g., the result of 71 `Distribution.sample(sample_shape=...)`. 72 73 For a given `Tensor`, this class computes the above table using minimal 74 information: `batch_ndims` and `event_ndims`. 75 76 #### Examples 77 78 We show examples of distribution shape semantics. 79 80 - Sample dimensions: 81 Computing summary statistics, i.e., the average is a reduction over sample 82 dimensions. 83 84 ```python 85 sample_dims = [0] 86 tf.reduce_mean(Normal(loc=1.3, scale=1.).sample_n(1000), 87 axis=sample_dims) # ~= 1.3 88 ``` 89 90 - Batch dimensions: 91 Monte Carlo estimation of a marginal probability: 92 Average over batch dimensions where batch dimensions are associated with 93 random draws from a prior. 94 E.g., suppose we want to find the Monte Carlo estimate of the marginal 95 distribution of a `Normal` with a random `Laplace` location: 96 97 ``` 98 P(X=x) = integral P(X=x|y) P(Y=y) dy 99 ~= 1/n sum_{i=1}^n P(X=x|y_i), y_i ~iid Laplace(0,1) 100 = tf.reduce_mean(Normal(loc=Laplace(0., 1.).sample_n(n=1000), 101 scale=tf.ones(1000)).prob(x), 102 axis=batch_dims) 103 ``` 104 105 The `Laplace` distribution generates a `Tensor` of shape `[1000]`. When 106 fed to a `Normal`, this is interpreted as 1000 different locations, i.e., 107 1000 non-identical Normals. Therefore a single call to `prob(x)` yields 108 1000 probabilities, one for every location. The average over this batch 109 yields the marginal. 110 111 - Event dimensions: 112 Computing the determinant of the Jacobian of a function of a random 113 variable involves a reduction over event dimensions. 114 E.g., Jacobian of the transform `Y = g(X) = exp(X)`: 115 116 ```python 117 tf.div(1., tf.reduce_prod(x, event_dims)) 118 ``` 119 120 We show examples using this class. 121 122 Write `S, B, E` for `sample_shape`, `batch_shape`, and `event_shape`. 123 124 ```python 125 # 150 iid samples from one multivariate Normal with two degrees of freedom. 126 mu = [0., 0] 127 sigma = [[1., 0], 128 [0, 1]] 129 mvn = MultivariateNormal(mu, sigma) 130 rand_mvn = mvn.sample(sample_shape=[3, 50]) 131 shaper = DistributionShape(batch_ndims=0, event_ndims=1) 132 S, B, E = shaper.get_shape(rand_mvn) 133 # S = [3, 50] 134 # B = [] 135 # E = [2] 136 137 # 12 iid samples from one Wishart with 2x2 events. 138 sigma = [[1., 0], 139 [2, 1]] 140 wishart = Wishart(df=5, scale=sigma) 141 rand_wishart = wishart.sample(sample_shape=[3, 4]) 142 shaper = DistributionShape(batch_ndims=0, event_ndims=2) 143 S, B, E = shaper.get_shape(rand_wishart) 144 # S = [3, 4] 145 # B = [] 146 # E = [2, 2] 147 148 # 100 iid samples from two, non-identical trivariate Normal distributions. 149 mu = ... # shape(2, 3) 150 sigma = ... # shape(2, 3, 3) 151 X = MultivariateNormal(mu, sigma).sample(shape=[4, 25]) 152 # S = [4, 25] 153 # B = [2] 154 # E = [3] 155 ``` 156 157 #### Argument Validation 158 159 When `validate_args=False`, checks that cannot be done during 160 graph construction are performed at graph execution. This may result in a 161 performance degradation because data must be switched from GPU to CPU. 162 163 For example, when `validate_args=False` and `event_ndims` is a 164 non-constant `Tensor`, it is checked to be a non-negative integer at graph 165 execution. (Same for `batch_ndims`). Constant `Tensor`s and non-`Tensor` 166 arguments are always checked for correctness since this can be done for 167 "free," i.e., during graph construction. 168 """ 169 170 @deprecation.deprecated( 171 "2018-10-01", 172 "The TensorFlow Distributions library has moved to " 173 "TensorFlow Probability " 174 "(https://github.com/tensorflow/probability). You " 175 "should update all references to use `tfp.distributions` " 176 "instead of `tf.contrib.distributions`.", 177 warn_once=True) 178 def __init__(self, 179 batch_ndims=None, 180 event_ndims=None, 181 validate_args=False, 182 name="DistributionShape"): 183 """Construct `DistributionShape` with fixed `batch_ndims`, `event_ndims`. 184 185 `batch_ndims` and `event_ndims` are fixed throughout the lifetime of a 186 `Distribution`. They may only be known at graph execution. 187 188 If both `batch_ndims` and `event_ndims` are python scalars (rather than 189 either being a `Tensor`), functions in this class automatically perform 190 sanity checks during graph construction. 191 192 Args: 193 batch_ndims: `Tensor`. Number of `dims` (`rank`) of the batch portion of 194 indexes of a `Tensor`. A "batch" is a non-identical distribution, i.e, 195 Normal with different parameters. 196 event_ndims: `Tensor`. Number of `dims` (`rank`) of the event portion of 197 indexes of a `Tensor`. An "event" is what is sampled from a 198 distribution, i.e., a trivariate Normal has an event shape of [3] and a 199 4 dimensional Wishart has an event shape of [4, 4]. 200 validate_args: Python `bool`, default `False`. When `True`, 201 non-`tf.constant` `Tensor` arguments are checked for correctness. 202 (`tf.constant` arguments are always checked.) 203 name: Python `str`. The name prepended to Ops created by this class. 204 205 Raises: 206 ValueError: if either `batch_ndims` or `event_ndims` are: `None`, 207 negative, not `int32`. 208 """ 209 if batch_ndims is None: raise ValueError("batch_ndims cannot be None") 210 if event_ndims is None: raise ValueError("event_ndims cannot be None") 211 self._batch_ndims = batch_ndims 212 self._event_ndims = event_ndims 213 self._validate_args = validate_args 214 with ops.name_scope(name): 215 self._name = name 216 with ops.name_scope("init"): 217 self._batch_ndims = self._assert_non_negative_int32_scalar( 218 ops.convert_to_tensor( 219 batch_ndims, name="batch_ndims")) 220 self._batch_ndims_static, self._batch_ndims_is_0 = ( 221 self._introspect_ndims(self._batch_ndims)) 222 self._event_ndims = self._assert_non_negative_int32_scalar( 223 ops.convert_to_tensor( 224 event_ndims, name="event_ndims")) 225 self._event_ndims_static, self._event_ndims_is_0 = ( 226 self._introspect_ndims(self._event_ndims)) 227 228 @property 229 def name(self): 230 """Name given to ops created by this class.""" 231 return self._name 232 233 @property 234 def batch_ndims(self): 235 """Returns number of dimensions corresponding to non-identical draws.""" 236 return self._batch_ndims 237 238 @property 239 def event_ndims(self): 240 """Returns number of dimensions needed to index a sample's coordinates.""" 241 return self._event_ndims 242 243 @property 244 def validate_args(self): 245 """Returns True if graph-runtime `Tensor` checks are enabled.""" 246 return self._validate_args 247 248 def get_ndims(self, x, name="get_ndims"): 249 """Get `Tensor` number of dimensions (rank). 250 251 Args: 252 x: `Tensor`. 253 name: Python `str`. The name to give this op. 254 255 Returns: 256 ndims: Scalar number of dimensions associated with a `Tensor`. 257 """ 258 with self._name_scope(name, values=[x]): 259 x = ops.convert_to_tensor(x, name="x") 260 ndims = x.get_shape().ndims 261 if ndims is None: 262 return array_ops.rank(x, name="ndims") 263 return ops.convert_to_tensor(ndims, dtype=dtypes.int32, name="ndims") 264 265 def get_sample_ndims(self, x, name="get_sample_ndims"): 266 """Returns number of dimensions corresponding to iid draws ("sample"). 267 268 Args: 269 x: `Tensor`. 270 name: Python `str`. The name to give this op. 271 272 Returns: 273 sample_ndims: `Tensor` (0D, `int32`). 274 275 Raises: 276 ValueError: if `sample_ndims` is calculated to be negative. 277 """ 278 with self._name_scope(name, values=[x]): 279 ndims = self.get_ndims(x, name=name) 280 if self._is_all_constant_helper(ndims, self.batch_ndims, 281 self.event_ndims): 282 ndims = tensor_util.constant_value(ndims) 283 sample_ndims = (ndims - self._batch_ndims_static - 284 self._event_ndims_static) 285 if sample_ndims < 0: 286 raise ValueError( 287 "expected batch_ndims(%d) + event_ndims(%d) <= ndims(%d)" % 288 (self._batch_ndims_static, self._event_ndims_static, ndims)) 289 return ops.convert_to_tensor(sample_ndims, name="sample_ndims") 290 else: 291 with ops.name_scope(name="sample_ndims"): 292 sample_ndims = ndims - self.batch_ndims - self.event_ndims 293 if self.validate_args: 294 sample_ndims = control_flow_ops.with_dependencies( 295 [check_ops.assert_non_negative(sample_ndims)], sample_ndims) 296 return sample_ndims 297 298 def get_dims(self, x, name="get_dims"): 299 """Returns dimensions indexing `sample_shape`, `batch_shape`, `event_shape`. 300 301 Example: 302 303 ```python 304 x = ... # Tensor with shape [4, 3, 2, 1] 305 sample_dims, batch_dims, event_dims = _DistributionShape( 306 batch_ndims=2, event_ndims=1).get_dims(x) 307 # sample_dims == [0] 308 # batch_dims == [1, 2] 309 # event_dims == [3] 310 # Note that these are not the shape parts, but rather indexes into shape. 311 ``` 312 313 Args: 314 x: `Tensor`. 315 name: Python `str`. The name to give this op. 316 317 Returns: 318 sample_dims: `Tensor` (1D, `int32`). 319 batch_dims: `Tensor` (1D, `int32`). 320 event_dims: `Tensor` (1D, `int32`). 321 """ 322 with self._name_scope(name, values=[x]): 323 def make_dims(start_sum, size, name): 324 """Closure to make dims range.""" 325 start_sum = start_sum if start_sum else [ 326 array_ops.zeros([], dtype=dtypes.int32, name="zero")] 327 if self._is_all_constant_helper(size, *start_sum): 328 start = sum(tensor_util.constant_value(s) for s in start_sum) 329 stop = start + tensor_util.constant_value(size) 330 return ops.convert_to_tensor( 331 list(range(start, stop)), dtype=dtypes.int32, name=name) 332 else: 333 start = sum(start_sum) 334 return math_ops.range(start, start + size) 335 sample_ndims = self.get_sample_ndims(x, name=name) 336 return (make_dims([], sample_ndims, name="sample_dims"), 337 make_dims([sample_ndims], self.batch_ndims, name="batch_dims"), 338 make_dims([sample_ndims, self.batch_ndims], 339 self.event_ndims, name="event_dims")) 340 341 def get_shape(self, x, name="get_shape"): 342 """Returns `Tensor`'s shape partitioned into `sample`, `batch`, `event`. 343 344 Args: 345 x: `Tensor`. 346 name: Python `str`. The name to give this op. 347 348 Returns: 349 sample_shape: `Tensor` (1D, `int32`). 350 batch_shape: `Tensor` (1D, `int32`). 351 event_shape: `Tensor` (1D, `int32`). 352 """ 353 with self._name_scope(name, values=[x]): 354 x = ops.convert_to_tensor(x, name="x") 355 def slice_shape(start_sum, size, name): 356 """Closure to slice out shape.""" 357 start_sum = start_sum if start_sum else [ 358 array_ops.zeros([], dtype=dtypes.int32, name="zero")] 359 if (x.get_shape().ndims is not None and 360 self._is_all_constant_helper(size, *start_sum)): 361 start = sum(tensor_util.constant_value(s) for s in start_sum) 362 stop = start + tensor_util.constant_value(size) 363 slice_ = x.get_shape()[start:stop].as_list() 364 if all(s is not None for s in slice_): 365 return ops.convert_to_tensor(slice_, dtype=dtypes.int32, name=name) 366 return array_ops.slice(array_ops.shape(x), [sum(start_sum)], [size]) 367 sample_ndims = self.get_sample_ndims(x, name=name) 368 return (slice_shape([], sample_ndims, 369 name="sample_shape"), 370 slice_shape([sample_ndims], self.batch_ndims, 371 name="batch_shape"), 372 slice_shape([sample_ndims, self.batch_ndims], self.event_ndims, 373 name="event_shape")) 374 375 # TODO(jvdillon): Make remove expand_batch_dim and make expand_batch_dim=False 376 # the default behavior. 377 def make_batch_of_event_sample_matrices( 378 self, x, expand_batch_dim=True, 379 name="make_batch_of_event_sample_matrices"): 380 """Reshapes/transposes `Distribution` `Tensor` from S+B+E to B_+E_+S_. 381 382 Where: 383 - `B_ = B if B or not expand_batch_dim else [1]`, 384 - `E_ = E if E else [1]`, 385 - `S_ = [tf.reduce_prod(S)]`. 386 387 Args: 388 x: `Tensor`. 389 expand_batch_dim: Python `bool`. If `True` the batch dims will be expanded 390 such that `batch_ndims >= 1`. 391 name: Python `str`. The name to give this op. 392 393 Returns: 394 x: `Tensor`. Input transposed/reshaped to `B_+E_+S_`. 395 sample_shape: `Tensor` (1D, `int32`). 396 """ 397 with self._name_scope(name, values=[x]): 398 x = ops.convert_to_tensor(x, name="x") 399 # x.shape: S+B+E 400 sample_shape, batch_shape, event_shape = self.get_shape(x) 401 event_shape = distribution_util.pick_vector( 402 self._event_ndims_is_0, [1], event_shape) 403 if expand_batch_dim: 404 batch_shape = distribution_util.pick_vector( 405 self._batch_ndims_is_0, [1], batch_shape) 406 new_shape = array_ops.concat([[-1], batch_shape, event_shape], 0) 407 x = array_ops.reshape(x, shape=new_shape) 408 # x.shape: [prod(S)]+B_+E_ 409 x = distribution_util.rotate_transpose(x, shift=-1) 410 # x.shape: B_+E_+[prod(S)] 411 return x, sample_shape 412 413 # TODO(jvdillon): Make remove expand_batch_dim and make expand_batch_dim=False 414 # the default behavior. 415 def undo_make_batch_of_event_sample_matrices( 416 self, x, sample_shape, expand_batch_dim=True, 417 name="undo_make_batch_of_event_sample_matrices"): 418 """Reshapes/transposes `Distribution` `Tensor` from B_+E_+S_ to S+B+E. 419 420 Where: 421 - `B_ = B if B or not expand_batch_dim else [1]`, 422 - `E_ = E if E else [1]`, 423 - `S_ = [tf.reduce_prod(S)]`. 424 425 This function "reverses" `make_batch_of_event_sample_matrices`. 426 427 Args: 428 x: `Tensor` of shape `B_+E_+S_`. 429 sample_shape: `Tensor` (1D, `int32`). 430 expand_batch_dim: Python `bool`. If `True` the batch dims will be expanded 431 such that `batch_ndims>=1`. 432 name: Python `str`. The name to give this op. 433 434 Returns: 435 x: `Tensor`. Input transposed/reshaped to `S+B+E`. 436 """ 437 with self._name_scope(name, values=[x, sample_shape]): 438 x = ops.convert_to_tensor(x, name="x") 439 # x.shape: _B+_E+[prod(S)] 440 sample_shape = ops.convert_to_tensor(sample_shape, name="sample_shape") 441 x = distribution_util.rotate_transpose(x, shift=1) 442 # x.shape: [prod(S)]+_B+_E 443 if self._is_all_constant_helper(self.batch_ndims, self.event_ndims): 444 if self._batch_ndims_is_0 or self._event_ndims_is_0: 445 squeeze_dims = [] 446 if self._event_ndims_is_0: 447 squeeze_dims += [-1] 448 if self._batch_ndims_is_0 and expand_batch_dim: 449 squeeze_dims += [1] 450 if squeeze_dims: 451 x = array_ops.squeeze(x, axis=squeeze_dims) 452 # x.shape: [prod(S)]+B+E 453 _, batch_shape, event_shape = self.get_shape(x) 454 else: 455 s = (x.get_shape().as_list() if x.get_shape().is_fully_defined() 456 else array_ops.shape(x)) 457 batch_shape = s[1:1+self.batch_ndims] 458 # Since sample_dims=1 and is left-most, we add 1 to the number of 459 # batch_ndims to get the event start dim. 460 event_start = array_ops.where( 461 math_ops.logical_and(expand_batch_dim, self._batch_ndims_is_0), 462 2, 1 + self.batch_ndims) 463 event_shape = s[event_start:event_start+self.event_ndims] 464 new_shape = array_ops.concat([sample_shape, batch_shape, event_shape], 0) 465 x = array_ops.reshape(x, shape=new_shape) 466 # x.shape: S+B+E 467 return x 468 469 @contextlib.contextmanager 470 def _name_scope(self, name=None, values=None): 471 """Helper function to standardize op scope.""" 472 with ops.name_scope(self.name): 473 with ops.name_scope(name, values=( 474 (values or []) + [self.batch_ndims, self.event_ndims])) as scope: 475 yield scope 476 477 def _is_all_constant_helper(self, *args): 478 """Helper which returns True if all inputs are constant_value.""" 479 return all(tensor_util.constant_value(x) is not None for x in args) 480 481 def _assert_non_negative_int32_scalar(self, x): 482 """Helper which ensures that input is a non-negative, int32, scalar.""" 483 x = ops.convert_to_tensor(x, name="x") 484 if x.dtype.base_dtype != dtypes.int32.base_dtype: 485 raise TypeError("%s.dtype=%s is not %s" % (x.name, x.dtype, dtypes.int32)) 486 x_value_static = tensor_util.constant_value(x) 487 if x.get_shape().ndims is not None and x_value_static is not None: 488 if x.get_shape().ndims != 0: 489 raise ValueError("%s.ndims=%d is not 0 (scalar)" % 490 (x.name, x.get_shape().ndims)) 491 if x_value_static < 0: 492 raise ValueError("%s.value=%d cannot be negative" % 493 (x.name, x_value_static)) 494 return x 495 if self.validate_args: 496 x = control_flow_ops.with_dependencies([ 497 check_ops.assert_rank(x, 0), 498 check_ops.assert_non_negative(x)], x) 499 return x 500 501 def _introspect_ndims(self, ndims): 502 """Helper to establish some properties of input ndims args.""" 503 if self._is_all_constant_helper(ndims): 504 return (tensor_util.constant_value(ndims), 505 tensor_util.constant_value(ndims) == 0) 506 return None, math_ops.equal(ndims, 0) 507