# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Deterministic distribution class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.util import deprecation

__all__ = [
    "Deterministic",
    "VectorDeterministic",
]


@six.add_metaclass(abc.ABCMeta)
class _BaseDeterministic(distribution.Distribution):
  """Base class for Deterministic distributions."""

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(self,
               loc,
               atol=None,
               rtol=None,
               is_vector=False,
               validate_args=False,
               allow_nan_stats=True,
               name="_BaseDeterministic"):
    """Initialize a batch of `_BaseDeterministic` distributions.

    The `atol` and `rtol` parameters allow for some slack in `pmf`, `cdf`
    computations, e.g. due to floating-point error.

    ```
    pmf(x; loc)
      = 1, if Abs(x - loc) <= atol + rtol * Abs(loc),
      = 0, otherwise.
    ```

    Args:
      loc: Numeric `Tensor`.  The point (or batch of points) on which this
        distribution is supported.
      atol:  Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
        shape.  The absolute tolerance for comparing closeness to `loc`.
        Default is `0`.
      rtol:  Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
        shape.  The relative tolerance for comparing closeness to `loc`.
        Default is `0`.
      is_vector:  Python `bool`.  If `True`, this is for `VectorDeterministic`,
        else `Deterministic`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError:  If `is_vector` and `validate_args` are both `True` and `loc`
        is statically known to be a scalar (rank 0).  When the rank of `loc`
        is not known statically, a runtime assertion is attached instead.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[loc, atol, rtol]) as name:
      loc = ops.convert_to_tensor(loc, name="loc")
      if is_vector and validate_args:
        msg = "Argument loc must be at least rank 1."
        if loc.get_shape().ndims is not None:
          # Rank is known at graph-construction time: fail eagerly.
          if loc.get_shape().ndims < 1:
            raise ValueError(msg)
        else:
          # Rank only known at run time: gate `loc` on a rank assertion.
          loc = control_flow_ops.with_dependencies(
              [check_ops.assert_rank_at_least(loc, 1, message=msg)], loc)
      self._loc = loc

      super(_BaseDeterministic, self).__init__(
          dtype=self._loc.dtype,
          reparameterization_type=distribution.NOT_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          graph_parents=[self._loc],
          name=name)

      self._atol = self._get_tol(atol)
      self._rtol = self._get_tol(rtol)
      # Total tolerance used by `_prob`/`_cdf`: atol + rtol * |loc|.
      # When `rtol` was not supplied its term is zero, so skip it and avoid
      # broadcasting the slack up to the (possibly large) shape of `loc`.
      if rtol is None:
        self._slack = self.atol
      else:
        self._slack = self.atol + self.rtol * math_ops.abs(self.loc)

  def _get_tol(self, tol):
    """Converts a tolerance to a `Tensor` of `loc`'s dtype (default `0`).

    When `validate_args` is set, the returned tensor carries a runtime
    assertion that it is non-negative.
    """
    if tol is None:
      return ops.convert_to_tensor(0, dtype=self.loc.dtype)

    tol = ops.convert_to_tensor(tol, dtype=self.loc.dtype)
    if self.validate_args:
      tol = control_flow_ops.with_dependencies([
          check_ops.assert_non_negative(
              tol, message="Argument 'tol' must be non-negative")
      ], tol)
    return tol

  @property
  def loc(self):
    """Point (or batch of points) at which this distribution is supported."""
    return self._loc

  @property
  def atol(self):
    """Absolute tolerance for comparing points to `self.loc`."""
    return self._atol

  @property
  def rtol(self):
    """Relative tolerance for comparing points to `self.loc`."""
    return self._rtol

  def _entropy(self):
    return array_ops.zeros(self.batch_shape_tensor(), dtype=self.dtype)

  def _mean(self):
    return array_ops.identity(self.loc)

  def _variance(self):
    return array_ops.zeros_like(self.loc)

  def _mode(self):
    return self.mean()

  def _sample_n(self, n, seed=None):  # pylint: disable=unused-argument
    # Every sample equals `loc`, so sampling is a tile of `loc` along a new
    # leading axis of size `n`.  Use Python-level multiples when both `n` and
    # the rank of `loc` are known statically, which preserves static shapes.
    n_static = tensor_util.constant_value(ops.convert_to_tensor(n))
    if n_static is not None and self.loc.get_shape().ndims is not None:
      ones = [1] * self.loc.get_shape().ndims
      multiples = [n_static] + ones
    else:
      ones = array_ops.ones_like(array_ops.shape(self.loc))
      multiples = array_ops.concat(([n], ones), axis=0)

    return array_ops.tile(self.loc[array_ops.newaxis, ...], multiples=multiples)


class Deterministic(_BaseDeterministic):
  """Scalar `Deterministic` distribution on the real line.

  The scalar `Deterministic` distribution is parameterized by a [batch] point
  `loc` on the real line.  The distribution is supported at this point only,
  and corresponds to a random variable that is constant, equal to `loc`.

  See [Degenerate rv](https://en.wikipedia.org/wiki/Degenerate_distribution).

  #### Mathematical Details

  The probability mass function (pmf) and cumulative distribution function
  (cdf) are

  ```none
  pmf(x; loc) = 1, if x == loc, else 0
  cdf(x; loc) = 1, if x >= loc, else 0
  ```

  #### Examples

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Initialize a single Deterministic supported at zero.
  constant = tfd.Deterministic(0.)
  constant.prob(0.)
  ==> 1.
  constant.prob(2.)
  ==> 0.

  # Initialize a [2, 2] batch of scalar constants.
  loc = [[0., 1.], [2., 3.]]
  x = [[0., 1.1], [1.99, 3.]]
  constant = tfd.Deterministic(loc)
  constant.prob(x)
  ==> [[1., 0.], [0., 1.]]
  ```

  """

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(self,
               loc,
               atol=None,
               rtol=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Deterministic"):
    """Initialize a scalar `Deterministic` distribution.

    The `atol` and `rtol` parameters allow for some slack in `pmf`, `cdf`
    computations, e.g. due to floating-point error.

    ```
    pmf(x; loc)
      = 1, if Abs(x - loc) <= atol + rtol * Abs(loc),
      = 0, otherwise.
    ```

    Args:
      loc: Numeric `Tensor` of shape `[B1, ..., Bb]`, with `b >= 0`.
        The point (or batch of points) on which this distribution is supported.
      atol:  Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
        shape.  The absolute tolerance for comparing closeness to `loc`.
        Default is `0`.
      rtol:  Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
        shape.  The relative tolerance for comparing closeness to `loc`.
        Default is `0`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    super(Deterministic, self).__init__(
        loc,
        atol=atol,
        rtol=rtol,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)

  def _batch_shape_tensor(self):
    return array_ops.shape(self.loc)

  def _batch_shape(self):
    return self.loc.get_shape()

  def _event_shape_tensor(self):
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.scalar()

  def _prob(self, x):
    # 1 where |x - loc| is within the slack, else 0.
    return math_ops.cast(
        math_ops.abs(x - self.loc) <= self._slack, dtype=self.dtype)

  def _cdf(self, x):
    # Step function at `loc`, shifted left by the slack.
    return math_ops.cast(x >= self.loc - self._slack, dtype=self.dtype)


class VectorDeterministic(_BaseDeterministic):
  """Vector `Deterministic` distribution on `R^k`.

  The `VectorDeterministic` distribution is parameterized by a [batch] point
  `loc in R^k`.  The distribution is supported at this point only,
  and corresponds to a random variable that is constant, equal to `loc`.

  See [Degenerate rv](https://en.wikipedia.org/wiki/Degenerate_distribution).

  #### Mathematical Details

  The probability mass function (pmf) is

  ```none
  pmf(x; loc)
    = 1, if All[Abs(x - loc) <= atol + rtol * Abs(loc)],
    = 0, otherwise.
  ```

  #### Examples

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Initialize a single VectorDeterministic supported at [0., 2.] in R^2.
  constant = tfd.VectorDeterministic([0., 2.])
  constant.prob([0., 2.])
  ==> 1.
  constant.prob([0., 3.])
  ==> 0.

  # Initialize a [3] batch of constants on R^2.
  loc = [[0., 1.], [2., 3.], [4., 5.]]
  constant = tfd.VectorDeterministic(loc)
  constant.prob([[0., 1.], [1.9, 3.], [3.99, 5.]])
  ==> [1., 0., 0.]
  ```

  """

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(self,
               loc,
               atol=None,
               rtol=None,
               validate_args=False,
               allow_nan_stats=True,
               name="VectorDeterministic"):
    """Initialize a `VectorDeterministic` distribution on `R^k`, for `k >= 0`.

    Note that there is only one point in `R^0`, the "point" `[]`.  So if `k = 0`
    then `self.prob([]) == 1`.

    The `atol` and `rtol` parameters allow for some slack in `pmf`
    computations, e.g. due to floating-point error.

    ```
    pmf(x; loc)
      = 1, if All[Abs(x - loc) <= atol + rtol * Abs(loc)],
      = 0, otherwise
    ```

    Args:
      loc: Numeric `Tensor` of shape `[B1, ..., Bb, k]`, with `b >= 0`,
        `k >= 0`.  The point (or batch of points) on which this distribution
        is supported.
      atol:  Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
        shape.  The absolute tolerance for comparing closeness to `loc`.
        Default is `0`.
      rtol:  Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
        shape.  The relative tolerance for comparing closeness to `loc`.
        Default is `0`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    super(VectorDeterministic, self).__init__(
        loc,
        atol=atol,
        rtol=rtol,
        is_vector=True,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)

  def _batch_shape_tensor(self):
    return array_ops.shape(self.loc)[:-1]

  def _batch_shape(self):
    return self.loc.get_shape()[:-1]

  def _event_shape_tensor(self):
    # Slice with `[-1:]` (not `[-1]`) so the result is the 1-D shape vector
    # `[k]`, matching the static `_event_shape` below.
    return array_ops.shape(self.loc)[-1:]

  def _event_shape(self):
    return self.loc.get_shape()[-1:]

  def _prob(self, x):
    if self.validate_args:
      is_vector_check = check_ops.assert_rank_at_least(x, 1)
      # Build the event-size check under the rank check so the lookup of the
      # last dimension of `x` is only evaluated once `x` is known to have
      # rank >= 1.
      with ops.control_dependencies([is_vector_check]):
        right_vec_space_check = check_ops.assert_equal(
            self.event_shape_tensor(),
            array_ops.shape(x)[-1:],
            message=
            "Argument 'x' not defined in the same space R^k as this distribution")
      with ops.control_dependencies([right_vec_space_check]):
        x = array_ops.identity(x)
    # 1 only if every component of `x` is within the slack of `loc`.
    return math_ops.cast(
        math_ops.reduce_all(math_ops.abs(x - self.loc) <= self._slack, axis=-1),
        dtype=self.dtype)