1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Base classes for probability distributions.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import abc 22import contextlib 23import types 24 25import numpy as np 26import six 27 28from tensorflow.python.eager import context 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import ops 31from tensorflow.python.framework import tensor_shape 32from tensorflow.python.framework import tensor_util 33from tensorflow.python.ops import array_ops 34from tensorflow.python.ops import math_ops 35from tensorflow.python.ops.distributions import kullback_leibler 36from tensorflow.python.ops.distributions import util 37from tensorflow.python.util import deprecation 38from tensorflow.python.util import tf_inspect 39from tensorflow.python.util.tf_export import tf_export 40 41 42__all__ = [ 43 "ReparameterizationType", 44 "FULLY_REPARAMETERIZED", 45 "NOT_REPARAMETERIZED", 46 "Distribution", 47] 48 49_DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = [ 50 "batch_shape", 51 "batch_shape_tensor", 52 "cdf", 53 "covariance", 54 "cross_entropy", 55 "entropy", 56 "event_shape", 57 "event_shape_tensor", 58 "kl_divergence", 59 "log_cdf", 60 "log_prob", 61 "log_survival_function", 62 "mean", 63 "mode", 64 "prob", 65 "sample", 66 "stddev", 67 "survival_function", 68 "variance", 69] 70 71 72@six.add_metaclass(abc.ABCMeta) 73class _BaseDistribution(object): 74 """Abstract base class needed for resolving subclass hierarchy.""" 75 pass 76 77 78def _copy_fn(fn): 79 """Create a deep copy of fn. 80 81 Args: 82 fn: a callable 83 84 Returns: 85 A `FunctionType`: a deep copy of fn. 86 87 Raises: 88 TypeError: if `fn` is not a callable. 89 """ 90 if not callable(fn): 91 raise TypeError("fn is not callable: %s" % fn) 92 # The blessed way to copy a function. copy.deepcopy fails to create a 93 # non-reference copy. Since: 94 # types.FunctionType == type(lambda: None), 95 # and the docstring for the function type states: 96 # 97 # function(code, globals[, name[, argdefs[, closure]]]) 98 # 99 # Create a function object from a code object and a dictionary. 100 # ... 101 # 102 # Here we can use this to create a new function with the old function's 103 # code, globals, closure, etc. 104 return types.FunctionType( 105 code=fn.__code__, globals=fn.__globals__, 106 name=fn.__name__, argdefs=fn.__defaults__, 107 closure=fn.__closure__) 108 109 110def _update_docstring(old_str, append_str): 111 """Update old_str by inserting append_str just before the "Args:" section.""" 112 old_str = old_str or "" 113 old_str_lines = old_str.split("\n") 114 115 # Step 0: Prepend spaces to all lines of append_str. This is 116 # necessary for correct markdown generation. 117 append_str = "\n".join(" %s" % line for line in append_str.split("\n")) 118 119 # Step 1: Find mention of "Args": 120 has_args_ix = [ 121 ix for ix, line in enumerate(old_str_lines) 122 if line.strip().lower() == "args:"] 123 if has_args_ix: 124 final_args_ix = has_args_ix[-1] 125 return ("\n".join(old_str_lines[:final_args_ix]) 126 + "\n\n" + append_str + "\n\n" 127 + "\n".join(old_str_lines[final_args_ix:])) 128 else: 129 return old_str + "\n\n" + append_str 130 131 132def _convert_to_tensor(value, name=None, preferred_dtype=None): 133 """Converts to tensor avoiding an eager bug that loses float precision.""" 134 # TODO(b/116672045): Remove this function. 135 if (context.executing_eagerly() and preferred_dtype is not None and 136 (preferred_dtype.is_integer or preferred_dtype.is_bool)): 137 v = ops.convert_to_tensor(value, name=name) 138 if v.dtype.is_floating: 139 return v 140 return ops.convert_to_tensor( 141 value, name=name, preferred_dtype=preferred_dtype) 142 143 144class _DistributionMeta(abc.ABCMeta): 145 146 def __new__(mcs, classname, baseclasses, attrs): 147 """Control the creation of subclasses of the Distribution class. 148 149 The main purpose of this method is to properly propagate docstrings 150 from private Distribution methods, like `_log_prob`, into their 151 public wrappers as inherited by the Distribution base class 152 (e.g. `log_prob`). 153 154 Args: 155 classname: The name of the subclass being created. 156 baseclasses: A tuple of parent classes. 157 attrs: A dict mapping new attributes to their values. 158 159 Returns: 160 The class object. 161 162 Raises: 163 TypeError: If `Distribution` is not a subclass of `BaseDistribution`, or 164 the new class is derived via multiple inheritance and the first 165 parent class is not a subclass of `BaseDistribution`. 166 AttributeError: If `Distribution` does not implement e.g. `log_prob`. 167 ValueError: If a `Distribution` public method lacks a docstring. 168 """ 169 if not baseclasses: # Nothing to be done for Distribution 170 raise TypeError("Expected non-empty baseclass. Does Distribution " 171 "not subclass _BaseDistribution?") 172 which_base = [ 173 base for base in baseclasses 174 if base == _BaseDistribution or issubclass(base, Distribution)] 175 base = which_base[0] 176 if base == _BaseDistribution: # Nothing to be done for Distribution 177 return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs) 178 if not issubclass(base, Distribution): 179 raise TypeError("First parent class declared for %s must be " 180 "Distribution, but saw '%s'" % (classname, base.__name__)) 181 for attr in _DISTRIBUTION_PUBLIC_METHOD_WRAPPERS: 182 special_attr = "_%s" % attr 183 class_attr_value = attrs.get(attr, None) 184 if attr in attrs: 185 # The method is being overridden, do not update its docstring 186 continue 187 base_attr_value = getattr(base, attr, None) 188 if not base_attr_value: 189 raise AttributeError( 190 "Internal error: expected base class '%s' to implement method '%s'" 191 % (base.__name__, attr)) 192 class_special_attr_value = attrs.get(special_attr, None) 193 if class_special_attr_value is None: 194 # No _special method available, no need to update the docstring. 195 continue 196 class_special_attr_docstring = tf_inspect.getdoc(class_special_attr_value) 197 if not class_special_attr_docstring: 198 # No docstring to append. 199 continue 200 class_attr_value = _copy_fn(base_attr_value) 201 class_attr_docstring = tf_inspect.getdoc(base_attr_value) 202 if class_attr_docstring is None: 203 raise ValueError( 204 "Expected base class fn to contain a docstring: %s.%s" 205 % (base.__name__, attr)) 206 class_attr_value.__doc__ = _update_docstring( 207 class_attr_value.__doc__, 208 ("Additional documentation from `%s`:\n\n%s" 209 % (classname, class_special_attr_docstring))) 210 attrs[attr] = class_attr_value 211 212 return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs) 213 214 215@tf_export(v1=["distributions.ReparameterizationType"]) 216class ReparameterizationType(object): 217 """Instances of this class represent how sampling is reparameterized. 218 219 Two static instances exist in the distributions library, signifying 220 one of two possible properties for samples from a distribution: 221 222 `FULLY_REPARAMETERIZED`: Samples from the distribution are fully 223 reparameterized, and straight-through gradients are supported. 224 225 `NOT_REPARAMETERIZED`: Samples from the distribution are not fully 226 reparameterized, and straight-through gradients are either partially 227 unsupported or are not supported at all. In this case, for purposes of 228 e.g. RL or variational inference, it is generally safest to wrap the 229 sample results in a `stop_gradients` call and use policy 230 gradients / surrogate loss instead. 231 """ 232 233 @deprecation.deprecated( 234 "2019-01-01", 235 "The TensorFlow Distributions library has moved to " 236 "TensorFlow Probability " 237 "(https://github.com/tensorflow/probability). You " 238 "should update all references to use `tfp.distributions` " 239 "instead of `tf.distributions`.", 240 warn_once=True) 241 def __init__(self, rep_type): 242 self._rep_type = rep_type 243 244 def __repr__(self): 245 return "<Reparameteriation Type: %s>" % self._rep_type 246 247 def __eq__(self, other): 248 """Determine if this `ReparameterizationType` is equal to another. 249 250 Since RepaparameterizationType instances are constant static global 251 instances, equality checks if two instances' id() values are equal. 252 253 Args: 254 other: Object to compare against. 255 256 Returns: 257 `self is other`. 258 """ 259 return self is other 260 261 262# Fully reparameterized distribution: samples from a fully 263# reparameterized distribution support straight-through gradients with 264# respect to all parameters. 265FULLY_REPARAMETERIZED = ReparameterizationType("FULLY_REPARAMETERIZED") 266tf_export(v1=["distributions.FULLY_REPARAMETERIZED"]).export_constant( 267 __name__, "FULLY_REPARAMETERIZED") 268 269 270# Not reparameterized distribution: samples from a non- 271# reparameterized distribution do not support straight-through gradients for 272# at least some of the parameters. 273NOT_REPARAMETERIZED = ReparameterizationType("NOT_REPARAMETERIZED") 274tf_export(v1=["distributions.NOT_REPARAMETERIZED"]).export_constant( 275 __name__, "NOT_REPARAMETERIZED") 276 277 278@six.add_metaclass(_DistributionMeta) 279@tf_export(v1=["distributions.Distribution"]) 280class Distribution(_BaseDistribution): 281 """A generic probability distribution base class. 282 283 `Distribution` is a base class for constructing and organizing properties 284 (e.g., mean, variance) of random variables (e.g, Bernoulli, Gaussian). 285 286 #### Subclassing 287 288 Subclasses are expected to implement a leading-underscore version of the 289 same-named function. The argument signature should be identical except for 290 the omission of `name="..."`. For example, to enable `log_prob(value, 291 name="log_prob")` a subclass should implement `_log_prob(value)`. 292 293 Subclasses can append to public-level docstrings by providing 294 docstrings for their method specializations. For example: 295 296 ```python 297 @util.AppendDocstring("Some other details.") 298 def _log_prob(self, value): 299 ... 300 ``` 301 302 would add the string "Some other details." to the `log_prob` function 303 docstring. This is implemented as a simple decorator to avoid python 304 linter complaining about missing Args/Returns/Raises sections in the 305 partial docstrings. 306 307 #### Broadcasting, batching, and shapes 308 309 All distributions support batches of independent distributions of that type. 310 The batch shape is determined by broadcasting together the parameters. 311 312 The shape of arguments to `__init__`, `cdf`, `log_cdf`, `prob`, and 313 `log_prob` reflect this broadcasting, as does the return value of `sample` and 314 `sample_n`. 315 316 `sample_n_shape = [n] + batch_shape + event_shape`, where `sample_n_shape` is 317 the shape of the `Tensor` returned from `sample_n`, `n` is the number of 318 samples, `batch_shape` defines how many independent distributions there are, 319 and `event_shape` defines the shape of samples from each of those independent 320 distributions. Samples are independent along the `batch_shape` dimensions, but 321 not necessarily so along the `event_shape` dimensions (depending on the 322 particulars of the underlying distribution). 323 324 Using the `Uniform` distribution as an example: 325 326 ```python 327 minval = 3.0 328 maxval = [[4.0, 6.0], 329 [10.0, 12.0]] 330 331 # Broadcasting: 332 # This instance represents 4 Uniform distributions. Each has a lower bound at 333 # 3.0 as the `minval` parameter was broadcasted to match `maxval`'s shape. 334 u = Uniform(minval, maxval) 335 336 # `event_shape` is `TensorShape([])`. 337 event_shape = u.event_shape 338 # `event_shape_t` is a `Tensor` which will evaluate to []. 339 event_shape_t = u.event_shape_tensor() 340 341 # Sampling returns a sample per distribution. `samples` has shape 342 # [5, 2, 2], which is [n] + batch_shape + event_shape, where n=5, 343 # batch_shape=[2, 2], and event_shape=[]. 344 samples = u.sample_n(5) 345 346 # The broadcasting holds across methods. Here we use `cdf` as an example. The 347 # same holds for `log_cdf` and the likelihood functions. 348 349 # `cum_prob` has shape [2, 2] as the `value` argument was broadcasted to the 350 # shape of the `Uniform` instance. 351 cum_prob_broadcast = u.cdf(4.0) 352 353 # `cum_prob`'s shape is [2, 2], one per distribution. No broadcasting 354 # occurred. 355 cum_prob_per_dist = u.cdf([[4.0, 5.0], 356 [6.0, 7.0]]) 357 358 # INVALID as the `value` argument is not broadcastable to the distribution's 359 # shape. 360 cum_prob_invalid = u.cdf([4.0, 5.0, 6.0]) 361 ``` 362 363 #### Shapes 364 365 There are three important concepts associated with TensorFlow Distributions 366 shapes: 367 - Event shape describes the shape of a single draw from the distribution; 368 it may be dependent across dimensions. For scalar distributions, the event 369 shape is `[]`. For a 5-dimensional MultivariateNormal, the event shape is 370 `[5]`. 371 - Batch shape describes independent, not identically distributed draws, aka a 372 "collection" or "bunch" of distributions. 373 - Sample shape describes independent, identically distributed draws of batches 374 from the distribution family. 375 376 The event shape and the batch shape are properties of a Distribution object, 377 whereas the sample shape is associated with a specific call to `sample` or 378 `log_prob`. 379 380 For detailed usage examples of TensorFlow Distributions shapes, see 381 [this tutorial]( 382 https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Understanding_TensorFlow_Distributions_Shapes.ipynb) 383 384 #### Parameter values leading to undefined statistics or distributions. 385 386 Some distributions do not have well-defined statistics for all initialization 387 parameter values. For example, the beta distribution is parameterized by 388 positive real numbers `concentration1` and `concentration0`, and does not have 389 well-defined mode if `concentration1 < 1` or `concentration0 < 1`. 390 391 The user is given the option of raising an exception or returning `NaN`. 392 393 ```python 394 a = tf.exp(tf.matmul(logits, weights_a)) 395 b = tf.exp(tf.matmul(logits, weights_b)) 396 397 # Will raise exception if ANY batch member has a < 1 or b < 1. 398 dist = distributions.beta(a, b, allow_nan_stats=False) 399 mode = dist.mode().eval() 400 401 # Will return NaN for batch members with either a < 1 or b < 1. 402 dist = distributions.beta(a, b, allow_nan_stats=True) # Default behavior 403 mode = dist.mode().eval() 404 ``` 405 406 In all cases, an exception is raised if *invalid* parameters are passed, e.g. 407 408 ```python 409 # Will raise an exception if any Op is run. 410 negative_a = -1.0 * a # beta distribution by definition has a > 0. 411 dist = distributions.beta(negative_a, b, allow_nan_stats=True) 412 dist.mean().eval() 413 ``` 414 415 """ 416 417 @deprecation.deprecated( 418 "2019-01-01", 419 "The TensorFlow Distributions library has moved to " 420 "TensorFlow Probability " 421 "(https://github.com/tensorflow/probability). You " 422 "should update all references to use `tfp.distributions` " 423 "instead of `tf.distributions`.", 424 warn_once=True) 425 def __init__(self, 426 dtype, 427 reparameterization_type, 428 validate_args, 429 allow_nan_stats, 430 parameters=None, 431 graph_parents=None, 432 name=None): 433 """Constructs the `Distribution`. 434 435 **This is a private method for subclass use.** 436 437 Args: 438 dtype: The type of the event samples. `None` implies no type-enforcement. 439 reparameterization_type: Instance of `ReparameterizationType`. 440 If `distributions.FULLY_REPARAMETERIZED`, this 441 `Distribution` can be reparameterized in terms of some standard 442 distribution with a function whose Jacobian is constant for the support 443 of the standard distribution. If `distributions.NOT_REPARAMETERIZED`, 444 then no such reparameterization is available. 445 validate_args: Python `bool`, default `False`. When `True` distribution 446 parameters are checked for validity despite possibly degrading runtime 447 performance. When `False` invalid inputs may silently render incorrect 448 outputs. 449 allow_nan_stats: Python `bool`, default `True`. When `True`, statistics 450 (e.g., mean, mode, variance) use the value "`NaN`" to indicate the 451 result is undefined. When `False`, an exception is raised if one or 452 more of the statistic's batch members are undefined. 453 parameters: Python `dict` of parameters used to instantiate this 454 `Distribution`. 455 graph_parents: Python `list` of graph prerequisites of this 456 `Distribution`. 457 name: Python `str` name prefixed to Ops created by this class. Default: 458 subclass name. 459 460 Raises: 461 ValueError: if any member of graph_parents is `None` or not a `Tensor`. 462 """ 463 graph_parents = [] if graph_parents is None else graph_parents 464 for i, t in enumerate(graph_parents): 465 if t is None or not tensor_util.is_tensor(t): 466 raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t)) 467 if not name or name[-1] != "/": # `name` is not a name scope 468 non_unique_name = name or type(self).__name__ 469 with ops.name_scope(non_unique_name) as name: 470 pass 471 self._dtype = dtype 472 self._reparameterization_type = reparameterization_type 473 self._allow_nan_stats = allow_nan_stats 474 self._validate_args = validate_args 475 self._parameters = parameters or {} 476 self._graph_parents = graph_parents 477 self._name = name 478 479 @property 480 def _parameters(self): 481 return self._parameter_dict 482 483 @_parameters.setter 484 def _parameters(self, value): 485 """Intercept assignments to self._parameters to avoid reference cycles. 486 487 Parameters are often created using locals(), so we need to clean out any 488 references to `self` before assigning it to an attribute. 489 490 Args: 491 value: A dictionary of parameters to assign to the `_parameters` property. 492 """ 493 if "self" in value: 494 del value["self"] 495 self._parameter_dict = value 496 497 @classmethod 498 def param_shapes(cls, sample_shape, name="DistributionParamShapes"): 499 """Shapes of parameters given the desired shape of a call to `sample()`. 500 501 This is a class method that describes what key/value arguments are required 502 to instantiate the given `Distribution` so that a particular shape is 503 returned for that instance's call to `sample()`. 504 505 Subclasses should override class method `_param_shapes`. 506 507 Args: 508 sample_shape: `Tensor` or python list/tuple. Desired shape of a call to 509 `sample()`. 510 name: name to prepend ops with. 511 512 Returns: 513 `dict` of parameter name to `Tensor` shapes. 514 """ 515 with ops.name_scope(name, values=[sample_shape]): 516 return cls._param_shapes(sample_shape) 517 518 @classmethod 519 def param_static_shapes(cls, sample_shape): 520 """param_shapes with static (i.e. `TensorShape`) shapes. 521 522 This is a class method that describes what key/value arguments are required 523 to instantiate the given `Distribution` so that a particular shape is 524 returned for that instance's call to `sample()`. Assumes that the sample's 525 shape is known statically. 526 527 Subclasses should override class method `_param_shapes` to return 528 constant-valued tensors when constant values are fed. 529 530 Args: 531 sample_shape: `TensorShape` or python list/tuple. Desired shape of a call 532 to `sample()`. 533 534 Returns: 535 `dict` of parameter name to `TensorShape`. 536 537 Raises: 538 ValueError: if `sample_shape` is a `TensorShape` and is not fully defined. 539 """ 540 if isinstance(sample_shape, tensor_shape.TensorShape): 541 if not sample_shape.is_fully_defined(): 542 raise ValueError("TensorShape sample_shape must be fully defined") 543 sample_shape = sample_shape.as_list() 544 545 params = cls.param_shapes(sample_shape) 546 547 static_params = {} 548 for name, shape in params.items(): 549 static_shape = tensor_util.constant_value(shape) 550 if static_shape is None: 551 raise ValueError( 552 "sample_shape must be a fully-defined TensorShape or list/tuple") 553 static_params[name] = tensor_shape.TensorShape(static_shape) 554 555 return static_params 556 557 @staticmethod 558 def _param_shapes(sample_shape): 559 raise NotImplementedError("_param_shapes not implemented") 560 561 @property 562 def name(self): 563 """Name prepended to all ops created by this `Distribution`.""" 564 return self._name 565 566 @property 567 def dtype(self): 568 """The `DType` of `Tensor`s handled by this `Distribution`.""" 569 return self._dtype 570 571 @property 572 def parameters(self): 573 """Dictionary of parameters used to instantiate this `Distribution`.""" 574 # Remove "self", "__class__", or other special variables. These can appear 575 # if the subclass used: 576 # `parameters = dict(locals())`. 577 return {k: v for k, v in self._parameters.items() 578 if not k.startswith("__") and k != "self"} 579 580 @property 581 def reparameterization_type(self): 582 """Describes how samples from the distribution are reparameterized. 583 584 Currently this is one of the static instances 585 `distributions.FULLY_REPARAMETERIZED` 586 or `distributions.NOT_REPARAMETERIZED`. 587 588 Returns: 589 An instance of `ReparameterizationType`. 590 """ 591 return self._reparameterization_type 592 593 @property 594 def allow_nan_stats(self): 595 """Python `bool` describing behavior when a stat is undefined. 596 597 Stats return +/- infinity when it makes sense. E.g., the variance of a 598 Cauchy distribution is infinity. However, sometimes the statistic is 599 undefined, e.g., if a distribution's pdf does not achieve a maximum within 600 the support of the distribution, the mode is undefined. If the mean is 601 undefined, then by definition the variance is undefined. E.g. the mean for 602 Student's T for df = 1 is undefined (no clear way to say it is either + or - 603 infinity), so the variance = E[(X - mean)**2] is also undefined. 604 605 Returns: 606 allow_nan_stats: Python `bool`. 607 """ 608 return self._allow_nan_stats 609 610 @property 611 def validate_args(self): 612 """Python `bool` indicating possibly expensive checks are enabled.""" 613 return self._validate_args 614 615 def copy(self, **override_parameters_kwargs): 616 """Creates a deep copy of the distribution. 617 618 Note: the copy distribution may continue to depend on the original 619 initialization arguments. 620 621 Args: 622 **override_parameters_kwargs: String/value dictionary of initialization 623 arguments to override with new values. 624 625 Returns: 626 distribution: A new instance of `type(self)` initialized from the union 627 of self.parameters and override_parameters_kwargs, i.e., 628 `dict(self.parameters, **override_parameters_kwargs)`. 629 """ 630 parameters = dict(self.parameters, **override_parameters_kwargs) 631 return type(self)(**parameters) 632 633 def _batch_shape_tensor(self): 634 raise NotImplementedError( 635 "batch_shape_tensor is not implemented: {}".format(type(self).__name__)) 636 637 def batch_shape_tensor(self, name="batch_shape_tensor"): 638 """Shape of a single sample from a single event index as a 1-D `Tensor`. 639 640 The batch dimensions are indexes into independent, non-identical 641 parameterizations of this distribution. 642 643 Args: 644 name: name to give to the op 645 646 Returns: 647 batch_shape: `Tensor`. 648 """ 649 with self._name_scope(name): 650 if self.batch_shape.is_fully_defined(): 651 return ops.convert_to_tensor(self.batch_shape.as_list(), 652 dtype=dtypes.int32, 653 name="batch_shape") 654 return self._batch_shape_tensor() 655 656 def _batch_shape(self): 657 return tensor_shape.TensorShape(None) 658 659 @property 660 def batch_shape(self): 661 """Shape of a single sample from a single event index as a `TensorShape`. 662 663 May be partially defined or unknown. 664 665 The batch dimensions are indexes into independent, non-identical 666 parameterizations of this distribution. 667 668 Returns: 669 batch_shape: `TensorShape`, possibly unknown. 670 """ 671 return tensor_shape.as_shape(self._batch_shape()) 672 673 def _event_shape_tensor(self): 674 raise NotImplementedError( 675 "event_shape_tensor is not implemented: {}".format(type(self).__name__)) 676 677 def event_shape_tensor(self, name="event_shape_tensor"): 678 """Shape of a single sample from a single batch as a 1-D int32 `Tensor`. 679 680 Args: 681 name: name to give to the op 682 683 Returns: 684 event_shape: `Tensor`. 685 """ 686 with self._name_scope(name): 687 if self.event_shape.is_fully_defined(): 688 return ops.convert_to_tensor(self.event_shape.as_list(), 689 dtype=dtypes.int32, 690 name="event_shape") 691 return self._event_shape_tensor() 692 693 def _event_shape(self): 694 return tensor_shape.TensorShape(None) 695 696 @property 697 def event_shape(self): 698 """Shape of a single sample from a single batch as a `TensorShape`. 699 700 May be partially defined or unknown. 701 702 Returns: 703 event_shape: `TensorShape`, possibly unknown. 704 """ 705 return tensor_shape.as_shape(self._event_shape()) 706 707 def is_scalar_event(self, name="is_scalar_event"): 708 """Indicates that `event_shape == []`. 709 710 Args: 711 name: Python `str` prepended to names of ops created by this function. 712 713 Returns: 714 is_scalar_event: `bool` scalar `Tensor`. 715 """ 716 with self._name_scope(name): 717 return ops.convert_to_tensor( 718 self._is_scalar_helper(self.event_shape, self.event_shape_tensor), 719 name="is_scalar_event") 720 721 def is_scalar_batch(self, name="is_scalar_batch"): 722 """Indicates that `batch_shape == []`. 723 724 Args: 725 name: Python `str` prepended to names of ops created by this function. 726 727 Returns: 728 is_scalar_batch: `bool` scalar `Tensor`. 729 """ 730 with self._name_scope(name): 731 return ops.convert_to_tensor( 732 self._is_scalar_helper(self.batch_shape, self.batch_shape_tensor), 733 name="is_scalar_batch") 734 735 def _sample_n(self, n, seed=None): 736 raise NotImplementedError("sample_n is not implemented: {}".format( 737 type(self).__name__)) 738 739 def _call_sample_n(self, sample_shape, seed, name, **kwargs): 740 with self._name_scope(name, values=[sample_shape]): 741 sample_shape = ops.convert_to_tensor( 742 sample_shape, dtype=dtypes.int32, name="sample_shape") 743 sample_shape, n = self._expand_sample_shape_to_vector( 744 sample_shape, "sample_shape") 745 samples = self._sample_n(n, seed, **kwargs) 746 batch_event_shape = array_ops.shape(samples)[1:] 747 final_shape = array_ops.concat([sample_shape, batch_event_shape], 0) 748 samples = array_ops.reshape(samples, final_shape) 749 samples = self._set_sample_static_shape(samples, sample_shape) 750 return samples 751 752 def sample(self, sample_shape=(), seed=None, name="sample"): 753 """Generate samples of the specified shape. 754 755 Note that a call to `sample()` without arguments will generate a single 756 sample. 757 758 Args: 759 sample_shape: 0D or 1D `int32` `Tensor`. Shape of the generated samples. 760 seed: Python integer seed for RNG 761 name: name to give to the op. 762 763 Returns: 764 samples: a `Tensor` with prepended dimensions `sample_shape`. 765 """ 766 return self._call_sample_n(sample_shape, seed, name) 767 768 def _log_prob(self, value): 769 raise NotImplementedError("log_prob is not implemented: {}".format( 770 type(self).__name__)) 771 772 def _call_log_prob(self, value, name, **kwargs): 773 with self._name_scope(name, values=[value]): 774 value = _convert_to_tensor( 775 value, name="value", preferred_dtype=self.dtype) 776 try: 777 return self._log_prob(value, **kwargs) 778 except NotImplementedError as original_exception: 779 try: 780 return math_ops.log(self._prob(value, **kwargs)) 781 except NotImplementedError: 782 raise original_exception 783 784 def log_prob(self, value, name="log_prob"): 785 """Log probability density/mass function. 786 787 Args: 788 value: `float` or `double` `Tensor`. 789 name: Python `str` prepended to names of ops created by this function. 790 791 Returns: 792 log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with 793 values of type `self.dtype`. 794 """ 795 return self._call_log_prob(value, name) 796 797 def _prob(self, value): 798 raise NotImplementedError("prob is not implemented: {}".format( 799 type(self).__name__)) 800 801 def _call_prob(self, value, name, **kwargs): 802 with self._name_scope(name, values=[value]): 803 value = _convert_to_tensor( 804 value, name="value", preferred_dtype=self.dtype) 805 try: 806 return self._prob(value, **kwargs) 807 except NotImplementedError as original_exception: 808 try: 809 return math_ops.exp(self._log_prob(value, **kwargs)) 810 except NotImplementedError: 811 raise original_exception 812 813 def prob(self, value, name="prob"): 814 """Probability density/mass function. 815 816 Args: 817 value: `float` or `double` `Tensor`. 818 name: Python `str` prepended to names of ops created by this function. 819 820 Returns: 821 prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with 822 values of type `self.dtype`. 823 """ 824 return self._call_prob(value, name) 825 826 def _log_cdf(self, value): 827 raise NotImplementedError("log_cdf is not implemented: {}".format( 828 type(self).__name__)) 829 830 def _call_log_cdf(self, value, name, **kwargs): 831 with self._name_scope(name, values=[value]): 832 value = _convert_to_tensor( 833 value, name="value", preferred_dtype=self.dtype) 834 try: 835 return self._log_cdf(value, **kwargs) 836 except NotImplementedError as original_exception: 837 try: 838 return math_ops.log(self._cdf(value, **kwargs)) 839 except NotImplementedError: 840 raise original_exception 841 842 def log_cdf(self, value, name="log_cdf"): 843 """Log cumulative distribution function. 844 845 Given random variable `X`, the cumulative distribution function `cdf` is: 846 847 ```none 848 log_cdf(x) := Log[ P[X <= x] ] 849 ``` 850 851 Often, a numerical approximation can be used for `log_cdf(x)` that yields 852 a more accurate answer than simply taking the logarithm of the `cdf` when 853 `x << -1`. 854 855 Args: 856 value: `float` or `double` `Tensor`. 857 name: Python `str` prepended to names of ops created by this function. 858 859 Returns: 860 logcdf: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with 861 values of type `self.dtype`. 862 """ 863 return self._call_log_cdf(value, name) 864 865 def _cdf(self, value): 866 raise NotImplementedError("cdf is not implemented: {}".format( 867 type(self).__name__)) 868 869 def _call_cdf(self, value, name, **kwargs): 870 with self._name_scope(name, values=[value]): 871 value = _convert_to_tensor( 872 value, name="value", preferred_dtype=self.dtype) 873 try: 874 return self._cdf(value, **kwargs) 875 except NotImplementedError as original_exception: 876 try: 877 return math_ops.exp(self._log_cdf(value, **kwargs)) 878 except NotImplementedError: 879 raise original_exception 880 881 def cdf(self, value, name="cdf"): 882 """Cumulative distribution function. 883 884 Given random variable `X`, the cumulative distribution function `cdf` is: 885 886 ```none 887 cdf(x) := P[X <= x] 888 ``` 889 890 Args: 891 value: `float` or `double` `Tensor`. 892 name: Python `str` prepended to names of ops created by this function. 893 894 Returns: 895 cdf: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with 896 values of type `self.dtype`. 897 """ 898 return self._call_cdf(value, name) 899 900 def _log_survival_function(self, value): 901 raise NotImplementedError( 902 "log_survival_function is not implemented: {}".format( 903 type(self).__name__)) 904 905 def _call_log_survival_function(self, value, name, **kwargs): 906 with self._name_scope(name, values=[value]): 907 value = _convert_to_tensor( 908 value, name="value", preferred_dtype=self.dtype) 909 try: 910 return self._log_survival_function(value, **kwargs) 911 except NotImplementedError as original_exception: 912 try: 913 return math_ops.log1p(-self.cdf(value, **kwargs)) 914 except NotImplementedError: 915 raise original_exception 916 917 def log_survival_function(self, value, name="log_survival_function"): 918 """Log survival function. 919 920 Given random variable `X`, the survival function is defined: 921 922 ```none 923 log_survival_function(x) = Log[ P[X > x] ] 924 = Log[ 1 - P[X <= x] ] 925 = Log[ 1 - cdf(x) ] 926 ``` 927 928 Typically, different numerical approximations can be used for the log 929 survival function, which are more accurate than `1 - cdf(x)` when `x >> 1`. 930 931 Args: 932 value: `float` or `double` `Tensor`. 933 name: Python `str` prepended to names of ops created by this function. 934 935 Returns: 936 `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type 937 `self.dtype`. 938 """ 939 return self._call_log_survival_function(value, name) 940 941 def _survival_function(self, value): 942 raise NotImplementedError("survival_function is not implemented: {}".format( 943 type(self).__name__)) 944 945 def _call_survival_function(self, value, name, **kwargs): 946 with self._name_scope(name, values=[value]): 947 value = _convert_to_tensor( 948 value, name="value", preferred_dtype=self.dtype) 949 try: 950 return self._survival_function(value, **kwargs) 951 except NotImplementedError as original_exception: 952 try: 953 return 1. - self.cdf(value, **kwargs) 954 except NotImplementedError: 955 raise original_exception 956 957 def survival_function(self, value, name="survival_function"): 958 """Survival function. 959 960 Given random variable `X`, the survival function is defined: 961 962 ```none 963 survival_function(x) = P[X > x] 964 = 1 - P[X <= x] 965 = 1 - cdf(x). 966 ``` 967 968 Args: 969 value: `float` or `double` `Tensor`. 970 name: Python `str` prepended to names of ops created by this function. 971 972 Returns: 973 `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type 974 `self.dtype`. 975 """ 976 return self._call_survival_function(value, name) 977 978 def _entropy(self): 979 raise NotImplementedError("entropy is not implemented: {}".format( 980 type(self).__name__)) 981 982 def entropy(self, name="entropy"): 983 """Shannon entropy in nats.""" 984 with self._name_scope(name): 985 return self._entropy() 986 987 def _mean(self): 988 raise NotImplementedError("mean is not implemented: {}".format( 989 type(self).__name__)) 990 991 def mean(self, name="mean"): 992 """Mean.""" 993 with self._name_scope(name): 994 return self._mean() 995 996 def _quantile(self, value): 997 raise NotImplementedError("quantile is not implemented: {}".format( 998 type(self).__name__)) 999 1000 def _call_quantile(self, value, name, **kwargs): 1001 with self._name_scope(name, values=[value]): 1002 value = _convert_to_tensor( 1003 value, name="value", preferred_dtype=self.dtype) 1004 return self._quantile(value, **kwargs) 1005 1006 def quantile(self, value, name="quantile"): 1007 """Quantile function. Aka "inverse cdf" or "percent point function". 1008 1009 Given random variable `X` and `p in [0, 1]`, the `quantile` is: 1010 1011 ```none 1012 quantile(p) := x such that P[X <= x] == p 1013 ``` 1014 1015 Args: 1016 value: `float` or `double` `Tensor`. 1017 name: Python `str` prepended to names of ops created by this function. 1018 1019 Returns: 1020 quantile: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with 1021 values of type `self.dtype`. 1022 """ 1023 return self._call_quantile(value, name) 1024 1025 def _variance(self): 1026 raise NotImplementedError("variance is not implemented: {}".format( 1027 type(self).__name__)) 1028 1029 def variance(self, name="variance"): 1030 """Variance. 1031 1032 Variance is defined as, 1033 1034 ```none 1035 Var = E[(X - E[X])**2] 1036 ``` 1037 1038 where `X` is the random variable associated with this distribution, `E` 1039 denotes expectation, and `Var.shape = batch_shape + event_shape`. 1040 1041 Args: 1042 name: Python `str` prepended to names of ops created by this function. 1043 1044 Returns: 1045 variance: Floating-point `Tensor` with shape identical to 1046 `batch_shape + event_shape`, i.e., the same shape as `self.mean()`. 1047 """ 1048 with self._name_scope(name): 1049 try: 1050 return self._variance() 1051 except NotImplementedError as original_exception: 1052 try: 1053 return math_ops.square(self._stddev()) 1054 except NotImplementedError: 1055 raise original_exception 1056 1057 def _stddev(self): 1058 raise NotImplementedError("stddev is not implemented: {}".format( 1059 type(self).__name__)) 1060 1061 def stddev(self, name="stddev"): 1062 """Standard deviation. 1063 1064 Standard deviation is defined as, 1065 1066 ```none 1067 stddev = E[(X - E[X])**2]**0.5 1068 ``` 1069 1070 where `X` is the random variable associated with this distribution, `E` 1071 denotes expectation, and `stddev.shape = batch_shape + event_shape`. 1072 1073 Args: 1074 name: Python `str` prepended to names of ops created by this function. 1075 1076 Returns: 1077 stddev: Floating-point `Tensor` with shape identical to 1078 `batch_shape + event_shape`, i.e., the same shape as `self.mean()`. 1079 """ 1080 1081 with self._name_scope(name): 1082 try: 1083 return self._stddev() 1084 except NotImplementedError as original_exception: 1085 try: 1086 return math_ops.sqrt(self._variance()) 1087 except NotImplementedError: 1088 raise original_exception 1089 1090 def _covariance(self): 1091 raise NotImplementedError("covariance is not implemented: {}".format( 1092 type(self).__name__)) 1093 1094 def covariance(self, name="covariance"): 1095 """Covariance. 1096 1097 Covariance is (possibly) defined only for non-scalar-event distributions. 1098 1099 For example, for a length-`k`, vector-valued distribution, it is calculated 1100 as, 1101 1102 ```none 1103 Cov[i, j] = Covariance(X_i, X_j) = E[(X_i - E[X_i]) (X_j - E[X_j])] 1104 ``` 1105 1106 where `Cov` is a (batch of) `k x k` matrix, `0 <= (i, j) < k`, and `E` 1107 denotes expectation. 1108 1109 Alternatively, for non-vector, multivariate distributions (e.g., 1110 matrix-valued, Wishart), `Covariance` shall return a (batch of) matrices 1111 under some vectorization of the events, i.e., 1112 1113 ```none 1114 Cov[i, j] = Covariance(Vec(X)_i, Vec(X)_j) = [as above] 1115 ``` 1116 1117 where `Cov` is a (batch of) `k' x k'` matrices, 1118 `0 <= (i, j) < k' = reduce_prod(event_shape)`, and `Vec` is some function 1119 mapping indices of this distribution's event dimensions to indices of a 1120 length-`k'` vector. 1121 1122 Args: 1123 name: Python `str` prepended to names of ops created by this function. 1124 1125 Returns: 1126 covariance: Floating-point `Tensor` with shape `[B1, ..., Bn, k', k']` 1127 where the first `n` dimensions are batch coordinates and 1128 `k' = reduce_prod(self.event_shape)`. 1129 """ 1130 with self._name_scope(name): 1131 return self._covariance() 1132 1133 def _mode(self): 1134 raise NotImplementedError("mode is not implemented: {}".format( 1135 type(self).__name__)) 1136 1137 def mode(self, name="mode"): 1138 """Mode.""" 1139 with self._name_scope(name): 1140 return self._mode() 1141 1142 def _cross_entropy(self, other): 1143 return kullback_leibler.cross_entropy( 1144 self, other, allow_nan_stats=self.allow_nan_stats) 1145 1146 def cross_entropy(self, other, name="cross_entropy"): 1147 """Computes the (Shannon) cross entropy. 1148 1149 Denote this distribution (`self`) by `P` and the `other` distribution by 1150 `Q`. Assuming `P, Q` are absolutely continuous with respect to 1151 one another and permit densities `p(x) dr(x)` and `q(x) dr(x)`, (Shanon) 1152 cross entropy is defined as: 1153 1154 ```none 1155 H[P, Q] = E_p[-log q(X)] = -int_F p(x) log q(x) dr(x) 1156 ``` 1157 1158 where `F` denotes the support of the random variable `X ~ P`. 1159 1160 Args: 1161 other: `tfp.distributions.Distribution` instance. 1162 name: Python `str` prepended to names of ops created by this function. 1163 1164 Returns: 1165 cross_entropy: `self.dtype` `Tensor` with shape `[B1, ..., Bn]` 1166 representing `n` different calculations of (Shanon) cross entropy. 1167 """ 1168 with self._name_scope(name): 1169 return self._cross_entropy(other) 1170 1171 def _kl_divergence(self, other): 1172 return kullback_leibler.kl_divergence( 1173 self, other, allow_nan_stats=self.allow_nan_stats) 1174 1175 def kl_divergence(self, other, name="kl_divergence"): 1176 """Computes the Kullback--Leibler divergence. 1177 1178 Denote this distribution (`self`) by `p` and the `other` distribution by 1179 `q`. Assuming `p, q` are absolutely continuous with respect to reference 1180 measure `r`, the KL divergence is defined as: 1181 1182 ```none 1183 KL[p, q] = E_p[log(p(X)/q(X))] 1184 = -int_F p(x) log q(x) dr(x) + int_F p(x) log p(x) dr(x) 1185 = H[p, q] - H[p] 1186 ``` 1187 1188 where `F` denotes the support of the random variable `X ~ p`, `H[., .]` 1189 denotes (Shanon) cross entropy, and `H[.]` denotes (Shanon) entropy. 1190 1191 Args: 1192 other: `tfp.distributions.Distribution` instance. 1193 name: Python `str` prepended to names of ops created by this function. 1194 1195 Returns: 1196 kl_divergence: `self.dtype` `Tensor` with shape `[B1, ..., Bn]` 1197 representing `n` different calculations of the Kullback-Leibler 1198 divergence. 1199 """ 1200 with self._name_scope(name): 1201 return self._kl_divergence(other) 1202 1203 def __str__(self): 1204 return ("tfp.distributions.{type_name}(" 1205 "\"{self_name}\"" 1206 "{maybe_batch_shape}" 1207 "{maybe_event_shape}" 1208 ", dtype={dtype})".format( 1209 type_name=type(self).__name__, 1210 self_name=self.name, 1211 maybe_batch_shape=(", batch_shape={}".format(self.batch_shape) 1212 if self.batch_shape.ndims is not None 1213 else ""), 1214 maybe_event_shape=(", event_shape={}".format(self.event_shape) 1215 if self.event_shape.ndims is not None 1216 else ""), 1217 dtype=self.dtype.name)) 1218 1219 def __repr__(self): 1220 return ("<tfp.distributions.{type_name} " 1221 "'{self_name}'" 1222 " batch_shape={batch_shape}" 1223 " event_shape={event_shape}" 1224 " dtype={dtype}>".format( 1225 type_name=type(self).__name__, 1226 self_name=self.name, 1227 batch_shape=self.batch_shape, 1228 event_shape=self.event_shape, 1229 dtype=self.dtype.name)) 1230 1231 @contextlib.contextmanager 1232 def _name_scope(self, name=None, values=None): 1233 """Helper function to standardize op scope.""" 1234 with ops.name_scope(self.name): 1235 with ops.name_scope(name, values=( 1236 ([] if values is None else values) + self._graph_parents)) as scope: 1237 yield scope 1238 1239 def _expand_sample_shape_to_vector(self, x, name): 1240 """Helper to `sample` which ensures input is 1D.""" 1241 x_static_val = tensor_util.constant_value(x) 1242 if x_static_val is None: 1243 prod = math_ops.reduce_prod(x) 1244 else: 1245 prod = np.prod(x_static_val, dtype=x.dtype.as_numpy_dtype()) 1246 1247 ndims = x.get_shape().ndims # != sample_ndims 1248 if ndims is None: 1249 # Maybe expand_dims. 1250 ndims = array_ops.rank(x) 1251 expanded_shape = util.pick_vector( 1252 math_ops.equal(ndims, 0), 1253 np.array([1], dtype=np.int32), array_ops.shape(x)) 1254 x = array_ops.reshape(x, expanded_shape) 1255 elif ndims == 0: 1256 # Definitely expand_dims. 1257 if x_static_val is not None: 1258 x = ops.convert_to_tensor( 1259 np.array([x_static_val], dtype=x.dtype.as_numpy_dtype()), 1260 name=name) 1261 else: 1262 x = array_ops.reshape(x, [1]) 1263 elif ndims != 1: 1264 raise ValueError("Input is neither scalar nor vector.") 1265 1266 return x, prod 1267 1268 def _set_sample_static_shape(self, x, sample_shape): 1269 """Helper to `sample`; sets static shape info.""" 1270 # Set shape hints. 1271 sample_shape = tensor_shape.TensorShape( 1272 tensor_util.constant_value(sample_shape)) 1273 1274 ndims = x.get_shape().ndims 1275 sample_ndims = sample_shape.ndims 1276 batch_ndims = self.batch_shape.ndims 1277 event_ndims = self.event_shape.ndims 1278 1279 # Infer rank(x). 1280 if (ndims is None and 1281 sample_ndims is not None and 1282 batch_ndims is not None and 1283 event_ndims is not None): 1284 ndims = sample_ndims + batch_ndims + event_ndims 1285 x.set_shape([None] * ndims) 1286 1287 # Infer sample shape. 1288 if ndims is not None and sample_ndims is not None: 1289 shape = sample_shape.concatenate([None]*(ndims - sample_ndims)) 1290 x.set_shape(x.get_shape().merge_with(shape)) 1291 1292 # Infer event shape. 1293 if ndims is not None and event_ndims is not None: 1294 shape = tensor_shape.TensorShape( 1295 [None]*(ndims - event_ndims)).concatenate(self.event_shape) 1296 x.set_shape(x.get_shape().merge_with(shape)) 1297 1298 # Infer batch shape. 1299 if batch_ndims is not None: 1300 if ndims is not None: 1301 if sample_ndims is None and event_ndims is not None: 1302 sample_ndims = ndims - batch_ndims - event_ndims 1303 elif event_ndims is None and sample_ndims is not None: 1304 event_ndims = ndims - batch_ndims - sample_ndims 1305 if sample_ndims is not None and event_ndims is not None: 1306 shape = tensor_shape.TensorShape([None]*sample_ndims).concatenate( 1307 self.batch_shape).concatenate([None]*event_ndims) 1308 x.set_shape(x.get_shape().merge_with(shape)) 1309 1310 return x 1311 1312 def _is_scalar_helper(self, static_shape, dynamic_shape_fn): 1313 """Implementation for `is_scalar_batch` and `is_scalar_event`.""" 1314 if static_shape.ndims is not None: 1315 return static_shape.ndims == 0 1316 shape = dynamic_shape_fn() 1317 if (shape.get_shape().ndims is not None and 1318 shape.get_shape().dims[0].value is not None): 1319 # If the static_shape is correctly written then we should never execute 1320 # this branch. We keep it just in case there's some unimagined corner 1321 # case. 1322 return shape.get_shape().as_list() == [0] 1323 return math_ops.equal(array_ops.shape(shape)[0], 0) 1324