• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A helper class for inferring Distribution shape."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import contextlib
21
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import tensor_util
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import check_ops
27from tensorflow.python.ops import control_flow_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops.distributions import util as distribution_util
30from tensorflow.python.util import deprecation
31
32
class _DistributionShape(object):
  """Manage and manipulate `Distribution` shape.

  #### Terminology

  Recall that a `Tensor` has:
    - `shape`: size of `Tensor` dimensions,
    - `ndims`: size of `shape`; number of `Tensor` dimensions,
    - `dims`: indexes into `shape`; useful for transpose, reduce.

  `Tensor`s sampled from a `Distribution` can be partitioned by `sample_dims`,
  `batch_dims`, and `event_dims`. To understand the semantics of these
  dimensions, consider when two of the three are fixed and the remaining
  is varied:
    - `sample_dims`: indexes independent draws from identical
                     parameterizations of the `Distribution`.
    - `batch_dims`:  indexes independent draws from non-identical
                     parameterizations of the `Distribution`.
    - `event_dims`:  indexes event coordinates from one sample.

  The `sample`, `batch`, and `event` dimensions constitute the entirety of a
  `Distribution` `Tensor`'s shape.

  The dimensions are always in `sample`, `batch`, `event` order.

  #### Purpose

  This class partitions `Tensor` notions of `shape`, `ndims`, and `dims` into
  `Distribution` notions of `sample`, `batch`, and `event` dimensions. That
  is, it computes any of:

  ```
  sample_shape     batch_shape     event_shape
  sample_dims      batch_dims      event_dims
  sample_ndims     batch_ndims     event_ndims
  ```

  for a given `Tensor`, e.g., the result of
  `Distribution.sample(sample_shape=...)`.

  For a given `Tensor`, this class computes the above table using minimal
  information: `batch_ndims` and `event_ndims`.

  #### Examples

  We show examples of distribution shape semantics.

    - Sample dimensions:
      Computing summary statistics, e.g., the average, is a reduction over
      sample dimensions.

      ```python
      sample_dims = [0]
      tf.reduce_mean(Normal(loc=1.3, scale=1.).sample_n(1000),
                     axis=sample_dims)  # ~= 1.3
      ```

    - Batch dimensions:
      Monte Carlo estimation of a marginal probability:
      Average over batch dimensions where batch dimensions are associated with
      random draws from a prior.
      E.g., suppose we want to find the Monte Carlo estimate of the marginal
      distribution of a `Normal` with a random `Laplace` location:

      ```
      P(X=x) = integral P(X=x|y) P(Y=y) dy
            ~= 1/n sum_{i=1}^n P(X=x|y_i),   y_i ~iid Laplace(0,1)
             = tf.reduce_mean(Normal(loc=Laplace(0., 1.).sample_n(n=1000),
                                     scale=tf.ones(1000)).prob(x),
                              axis=batch_dims)
      ```

      The `Laplace` distribution generates a `Tensor` of shape `[1000]`. When
      fed to a `Normal`, this is interpreted as 1000 different locations, i.e.,
      1000 non-identical Normals. Therefore a single call to `prob(x)` yields
      1000 probabilities, one for every location. The average over this batch
      yields the marginal.

    - Event dimensions:
      Computing the determinant of the Jacobian of a function of a random
      variable involves a reduction over event dimensions.
      E.g., Jacobian of the transform `Y = g(X) = exp(X)`:

      ```python
      tf.div(1., tf.reduce_prod(x, event_dims))
      ```

  We show examples using this class.

  Write `S, B, E` for `sample_shape`, `batch_shape`, and `event_shape`.

  ```python
  # 150 iid samples from one multivariate Normal with two degrees of freedom.
  mu = [0., 0]
  sigma = [[1., 0],
           [0,  1]]
  mvn = MultivariateNormal(mu, sigma)
  rand_mvn = mvn.sample(sample_shape=[3, 50])
  shaper = DistributionShape(batch_ndims=0, event_ndims=1)
  S, B, E = shaper.get_shape(rand_mvn)
  # S = [3, 50]
  # B = []
  # E = [2]

  # 12 iid samples from one Wishart with 2x2 events.
  sigma = [[1., 0],
           [2,  1]]
  wishart = Wishart(df=5, scale=sigma)
  rand_wishart = wishart.sample(sample_shape=[3, 4])
  shaper = DistributionShape(batch_ndims=0, event_ndims=2)
  S, B, E = shaper.get_shape(rand_wishart)
  # S = [3, 4]
  # B = []
  # E = [2, 2]

  # 100 iid samples from two, non-identical trivariate Normal distributions.
  mu    = ...  # shape(2, 3)
  sigma = ...  # shape(2, 3, 3)
  X = MultivariateNormal(mu, sigma).sample(sample_shape=[4, 25])
  # S = [4, 25]
  # B = [2]
  # E = [3]
  ```

  #### Argument Validation

  When `validate_args=True`, checks that cannot be done during
  graph construction are performed at graph execution. This may result in a
  performance degradation because data must be switched from GPU to CPU.

  For example, when `validate_args=True` and `event_ndims` is a
  non-constant `Tensor`, it is checked to be a non-negative integer at graph
  execution. (Same for `batch_ndims`). Constant `Tensor`s and non-`Tensor`
  arguments are always checked for correctness since this can be done for
  "free," i.e., during graph construction.
  """

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(self,
               batch_ndims=None,
               event_ndims=None,
               validate_args=False,
               name="DistributionShape"):
    """Construct `DistributionShape` with fixed `batch_ndims`, `event_ndims`.

    `batch_ndims` and `event_ndims` are fixed throughout the lifetime of a
    `Distribution`. They may only be known at graph execution.

    If both `batch_ndims` and `event_ndims` are python scalars (rather than
    either being a `Tensor`), functions in this class automatically perform
    sanity checks during graph construction.

    Args:
      batch_ndims: `Tensor`. Number of `dims` (`rank`) of the batch portion of
        indexes of a `Tensor`. A "batch" is a non-identical distribution, e.g.,
        a Normal with different parameters.
      event_ndims: `Tensor`. Number of `dims` (`rank`) of the event portion of
        indexes of a `Tensor`. An "event" is what is sampled from a
        distribution, e.g., a trivariate Normal has an event shape of [3] and a
        4 dimensional Wishart has an event shape of [4, 4].
      validate_args: Python `bool`, default `False`. When `True`,
        non-`tf.constant` `Tensor` arguments are checked for correctness.
        (`tf.constant` arguments are always checked.)
      name: Python `str`. The name prepended to Ops created by this class.

    Raises:
      ValueError: if either `batch_ndims` or `event_ndims` is `None`, or is
        statically known to be negative or not a scalar.
      TypeError: if either `batch_ndims` or `event_ndims` is not `int32`.
    """
    if batch_ndims is None: raise ValueError("batch_ndims cannot be None")
    if event_ndims is None: raise ValueError("event_ndims cannot be None")
    # Must be set before the checks below, which consult `validate_args`.
    self._validate_args = validate_args
    with ops.name_scope(name):
      self._name = name
      with ops.name_scope("init"):
        self._batch_ndims = self._assert_non_negative_int32_scalar(
            ops.convert_to_tensor(
                batch_ndims, name="batch_ndims"))
        self._batch_ndims_static, self._batch_ndims_is_0 = (
            self._introspect_ndims(self._batch_ndims))
        self._event_ndims = self._assert_non_negative_int32_scalar(
            ops.convert_to_tensor(
                event_ndims, name="event_ndims"))
        self._event_ndims_static, self._event_ndims_is_0 = (
            self._introspect_ndims(self._event_ndims))

  @property
  def name(self):
    """Name given to ops created by this class."""
    return self._name

  @property
  def batch_ndims(self):
    """Returns number of dimensions corresponding to non-identical draws."""
    return self._batch_ndims

  @property
  def event_ndims(self):
    """Returns number of dimensions needed to index a sample's coordinates."""
    return self._event_ndims

  @property
  def validate_args(self):
    """Returns True if graph-runtime `Tensor` checks are enabled."""
    return self._validate_args

  def get_ndims(self, x, name="get_ndims"):
    """Get `Tensor` number of dimensions (rank).

    Args:
      x: `Tensor`.
      name: Python `str`. The name to give this op.

    Returns:
      ndims: Scalar number of dimensions associated with a `Tensor`. A
        constant when the static rank is known, otherwise `tf.rank(x)`.
    """
    with self._name_scope(name, values=[x]):
      x = ops.convert_to_tensor(x, name="x")
      ndims = x.get_shape().ndims
      if ndims is None:
        return array_ops.rank(x, name="ndims")
      return ops.convert_to_tensor(ndims, dtype=dtypes.int32, name="ndims")

  def get_sample_ndims(self, x, name="get_sample_ndims"):
    """Returns number of dimensions corresponding to iid draws ("sample").

    Args:
      x: `Tensor`.
      name: Python `str`. The name to give this op.

    Returns:
      sample_ndims: `Tensor` (0D, `int32`).

    Raises:
      ValueError: if `sample_ndims` is calculated to be negative.
    """
    with self._name_scope(name, values=[x]):
      ndims = self.get_ndims(x, name=name)
      if self._is_all_constant_helper(ndims, self.batch_ndims,
                                      self.event_ndims):
        ndims = tensor_util.constant_value(ndims)
        sample_ndims = (ndims - self._batch_ndims_static -
                        self._event_ndims_static)
        if sample_ndims < 0:
          raise ValueError(
              "expected batch_ndims(%d) + event_ndims(%d) <= ndims(%d)" %
              (self._batch_ndims_static, self._event_ndims_static, ndims))
        return ops.convert_to_tensor(sample_ndims, name="sample_ndims")
      else:
        with ops.name_scope(name="sample_ndims"):
          sample_ndims = ndims - self.batch_ndims - self.event_ndims
          # A negative result can only be detected at graph execution here.
          if self.validate_args:
            sample_ndims = control_flow_ops.with_dependencies(
                [check_ops.assert_non_negative(sample_ndims)], sample_ndims)
        return sample_ndims

  def get_dims(self, x, name="get_dims"):
    """Returns dimensions indexing `sample_shape`, `batch_shape`, `event_shape`.

    Example:

    ```python
    x = ...  # Tensor with shape [4, 3, 2, 1]
    sample_dims, batch_dims, event_dims = _DistributionShape(
      batch_ndims=2, event_ndims=1).get_dims(x)
    # sample_dims == [0]
    # batch_dims == [1, 2]
    # event_dims == [3]
    # Note that these are not the shape parts, but rather indexes into shape.
    ```

    Args:
      x: `Tensor`.
      name: Python `str`. The name to give this op.

    Returns:
      sample_dims: `Tensor` (1D, `int32`).
      batch_dims: `Tensor` (1D, `int32`).
      event_dims: `Tensor` (1D, `int32`).
    """
    with self._name_scope(name, values=[x]):
      def make_dims(start_sum, size, name):
        """Closure to make the range `[sum(start_sum), sum(start_sum)+size)`."""
        start_sum = start_sum if start_sum else [
            array_ops.zeros([], dtype=dtypes.int32, name="zero")]
        if self._is_all_constant_helper(size, *start_sum):
          start = sum(tensor_util.constant_value(s) for s in start_sum)
          stop = start + tensor_util.constant_value(size)
          return ops.convert_to_tensor(
              list(range(start, stop)), dtype=dtypes.int32, name=name)
        else:
          start = sum(start_sum)
          return math_ops.range(start, start + size)
      sample_ndims = self.get_sample_ndims(x, name=name)
      return (make_dims([], sample_ndims, name="sample_dims"),
              make_dims([sample_ndims], self.batch_ndims, name="batch_dims"),
              make_dims([sample_ndims, self.batch_ndims],
                        self.event_ndims, name="event_dims"))

  def get_shape(self, x, name="get_shape"):
    """Returns `Tensor`'s shape partitioned into `sample`, `batch`, `event`.

    Args:
      x: `Tensor`.
      name: Python `str`. The name to give this op.

    Returns:
      sample_shape: `Tensor` (1D, `int32`).
      batch_shape: `Tensor` (1D, `int32`).
      event_shape: `Tensor` (1D, `int32`).
    """
    with self._name_scope(name, values=[x]):
      x = ops.convert_to_tensor(x, name="x")
      def slice_shape(start_sum, size, name):
        """Closure to slice out `size` dims of `x`'s shape at `sum(start_sum)`."""
        start_sum = start_sum if start_sum else [
            array_ops.zeros([], dtype=dtypes.int32, name="zero")]
        if (x.get_shape().ndims is not None and
            self._is_all_constant_helper(size, *start_sum)):
          start = sum(tensor_util.constant_value(s) for s in start_sum)
          stop = start + tensor_util.constant_value(size)
          slice_ = x.get_shape()[start:stop].as_list()
          # Only emit a constant when every sliced dimension is known.
          if all(s is not None for s in slice_):
            return ops.convert_to_tensor(slice_, dtype=dtypes.int32, name=name)
        return array_ops.slice(array_ops.shape(x), [sum(start_sum)], [size])
      sample_ndims = self.get_sample_ndims(x, name=name)
      return (slice_shape([], sample_ndims,
                          name="sample_shape"),
              slice_shape([sample_ndims], self.batch_ndims,
                          name="batch_shape"),
              slice_shape([sample_ndims, self.batch_ndims], self.event_ndims,
                          name="event_shape"))

  # TODO(jvdillon): Make remove expand_batch_dim and make expand_batch_dim=False
  # the default behavior.
  def make_batch_of_event_sample_matrices(
      self, x, expand_batch_dim=True,
      name="make_batch_of_event_sample_matrices"):
    """Reshapes/transposes `Distribution` `Tensor` from S+B+E to B_+E_+S_.

    Where:
      - `B_ = B if B or not expand_batch_dim else [1]`,
      - `E_ = E if E else [1]`,
      - `S_ = [tf.reduce_prod(S)]`.

    Args:
      x: `Tensor`.
      expand_batch_dim: Python `bool`. If `True` the batch dims will be expanded
        such that `batch_ndims >= 1`.
      name: Python `str`. The name to give this op.

    Returns:
      x: `Tensor`. Input transposed/reshaped to `B_+E_+S_`.
      sample_shape: `Tensor` (1D, `int32`).
    """
    with self._name_scope(name, values=[x]):
      x = ops.convert_to_tensor(x, name="x")
      # x.shape: S+B+E
      sample_shape, batch_shape, event_shape = self.get_shape(x)
      event_shape = distribution_util.pick_vector(
          self._event_ndims_is_0, [1], event_shape)
      if expand_batch_dim:
        batch_shape = distribution_util.pick_vector(
            self._batch_ndims_is_0, [1], batch_shape)
      new_shape = array_ops.concat([[-1], batch_shape, event_shape], 0)
      x = array_ops.reshape(x, shape=new_shape)
      # x.shape: [prod(S)]+B_+E_
      x = distribution_util.rotate_transpose(x, shift=-1)
      # x.shape: B_+E_+[prod(S)]
      return x, sample_shape

  # TODO(jvdillon): Make remove expand_batch_dim and make expand_batch_dim=False
  # the default behavior.
  def undo_make_batch_of_event_sample_matrices(
      self, x, sample_shape, expand_batch_dim=True,
      name="undo_make_batch_of_event_sample_matrices"):
    """Reshapes/transposes `Distribution` `Tensor` from B_+E_+S_ to S+B+E.

    Where:
      - `B_ = B if B or not expand_batch_dim else [1]`,
      - `E_ = E if E else [1]`,
      - `S_ = [tf.reduce_prod(S)]`.

    This function "reverses" `make_batch_of_event_sample_matrices`.

    Args:
      x: `Tensor` of shape `B_+E_+S_`.
      sample_shape: `Tensor` (1D, `int32`).
      expand_batch_dim: Python `bool`. If `True` the batch dims will be expanded
        such that `batch_ndims>=1`.
      name: Python `str`. The name to give this op.

    Returns:
      x: `Tensor`. Input transposed/reshaped to `S+B+E`.
    """
    with self._name_scope(name, values=[x, sample_shape]):
      x = ops.convert_to_tensor(x, name="x")
      # x.shape: B_+E_+[prod(S)]
      sample_shape = ops.convert_to_tensor(sample_shape, name="sample_shape")
      x = distribution_util.rotate_transpose(x, shift=1)
      # x.shape: [prod(S)]+B_+E_
      if self._is_all_constant_helper(self.batch_ndims, self.event_ndims):
        if self._batch_ndims_is_0 or self._event_ndims_is_0:
          squeeze_dims = []
          if self._event_ndims_is_0:
            squeeze_dims += [-1]
          if self._batch_ndims_is_0 and expand_batch_dim:
            squeeze_dims += [1]
          if squeeze_dims:
            x = array_ops.squeeze(x, axis=squeeze_dims)
            # x.shape: [prod(S)]+B+E
        _, batch_shape, event_shape = self.get_shape(x)
      else:
        # At least one of `batch_ndims`, `event_ndims` is only known at graph
        # execution, and both are stored as `Tensor`s, so the slice bounds
        # below are `Tensor`s. Those can index the dynamic shape `Tensor` but
        # not a Python list, so the static shape is not used here even when it
        # is fully defined.
        s = array_ops.shape(x)
        batch_shape = s[1:1+self.batch_ndims]
        # Since sample_dims=1 and is left-most, we add 1 to the number of
        # batch_ndims to get the event start dim. If a singleton batch dim was
        # inserted (expand_batch_dim and batch_ndims == 0), events start at 2.
        event_start = array_ops.where(
            math_ops.logical_and(expand_batch_dim, self._batch_ndims_is_0),
            2, 1 + self.batch_ndims)
        event_shape = s[event_start:event_start+self.event_ndims]
      new_shape = array_ops.concat([sample_shape, batch_shape, event_shape], 0)
      x = array_ops.reshape(x, shape=new_shape)
      # x.shape: S+B+E
      return x

  @contextlib.contextmanager
  def _name_scope(self, name=None, values=None):
    """Helper function to standardize op scope.

    Opens `self.name`, then `name` nested inside it; `batch_ndims` and
    `event_ndims` are always appended to `values`.
    """
    with ops.name_scope(self.name):
      with ops.name_scope(name, values=(
          (values or []) + [self.batch_ndims, self.event_ndims])) as scope:
        yield scope

  def _is_all_constant_helper(self, *args):
    """Helper which returns True if all inputs are constant_value."""
    return all(tensor_util.constant_value(x) is not None for x in args)

  def _assert_non_negative_int32_scalar(self, x):
    """Helper which ensures that input is a non-negative, int32, scalar.

    Checks decidable during graph construction (dtype, static rank, constant
    value) always raise immediately. Checks that depend on values only known at
    graph execution are attached as control dependencies when `validate_args`.

    Args:
      x: `Tensor`-convertible value.

    Returns:
      x: `Tensor`, possibly wrapped with runtime assertions.

    Raises:
      TypeError: if `x.dtype` is not `int32`.
      ValueError: if `x` is statically known to be non-scalar or negative.
    """
    x = ops.convert_to_tensor(x, name="x")
    if x.dtype.base_dtype != dtypes.int32.base_dtype:
      raise TypeError("%s.dtype=%s is not %s" % (x.name, x.dtype, dtypes.int32))
    x_ndims_static = x.get_shape().ndims
    # The static rank is often known even when the value is not, so check it
    # independently of the value; this costs nothing at graph execution.
    if x_ndims_static is not None and x_ndims_static != 0:
      raise ValueError("%s.ndims=%d is not 0 (scalar)" %
                       (x.name, x_ndims_static))
    x_value_static = tensor_util.constant_value(x)
    if x_ndims_static is not None and x_value_static is not None:
      if x_value_static < 0:
        raise ValueError("%s.value=%d cannot be negative" %
                         (x.name, x_value_static))
      return x
    if self.validate_args:
      x = control_flow_ops.with_dependencies([
          check_ops.assert_rank(x, 0),
          check_ops.assert_non_negative(x)], x)
    return x

  def _introspect_ndims(self, ndims):
    """Helper to establish some properties of input ndims args.

    Args:
      ndims: Scalar `int32` `Tensor`.

    Returns:
      static_value: The constant value of `ndims`, or `None` if not constant.
      is_0: Python/numpy `bool` when `ndims` is constant, otherwise a boolean
        `Tensor` computing `ndims == 0`.
    """
    if self._is_all_constant_helper(ndims):
      return (tensor_util.constant_value(ndims),
              tensor_util.constant_value(ndims) == 0)
    return None, math_ops.equal(ndims, 0)
507