1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Base classes for probability distributions.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import abc 22import contextlib 23import types 24 25import numpy as np 26import six 27 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import tensor_shape 31from tensorflow.python.framework import tensor_util 32from tensorflow.python.ops import array_ops 33from tensorflow.python.ops import math_ops 34from tensorflow.python.ops.distributions import kullback_leibler 35from tensorflow.python.ops.distributions import util 36from tensorflow.python.util import tf_inspect 37from tensorflow.python.util.tf_export import tf_export 38 39 40__all__ = [ 41 "ReparameterizationType", 42 "FULLY_REPARAMETERIZED", 43 "NOT_REPARAMETERIZED", 44 "Distribution", 45] 46 47_DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = [ 48 "batch_shape", 49 "batch_shape_tensor", 50 "cdf", 51 "covariance", 52 "cross_entropy", 53 "entropy", 54 "event_shape", 55 "event_shape_tensor", 56 "kl_divergence", 57 "log_cdf", 58 "log_prob", 59 "log_survival_function", 60 "mean", 61 "mode", 62 "prob", 63 "sample", 64 "stddev", 65 "survival_function", 66 "variance", 67] 68 69 70@six.add_metaclass(abc.ABCMeta) 71class _BaseDistribution(object): 72 """Abstract base class needed for resolving subclass hierarchy.""" 73 pass 74 75 76def _copy_fn(fn): 77 """Create a deep copy of fn. 78 79 Args: 80 fn: a callable 81 82 Returns: 83 A `FunctionType`: a deep copy of fn. 84 85 Raises: 86 TypeError: if `fn` is not a callable. 87 """ 88 if not callable(fn): 89 raise TypeError("fn is not callable: %s" % fn) 90 # The blessed way to copy a function. copy.deepcopy fails to create a 91 # non-reference copy. Since: 92 # types.FunctionType == type(lambda: None), 93 # and the docstring for the function type states: 94 # 95 # function(code, globals[, name[, argdefs[, closure]]]) 96 # 97 # Create a function object from a code object and a dictionary. 98 # ... 99 # 100 # Here we can use this to create a new function with the old function's 101 # code, globals, closure, etc. 102 return types.FunctionType( 103 code=fn.__code__, globals=fn.__globals__, 104 name=fn.__name__, argdefs=fn.__defaults__, 105 closure=fn.__closure__) 106 107 108def _update_docstring(old_str, append_str): 109 """Update old_str by inserting append_str just before the "Args:" section.""" 110 old_str = old_str or "" 111 old_str_lines = old_str.split("\n") 112 113 # Step 0: Prepend spaces to all lines of append_str. This is 114 # necessary for correct markdown generation. 115 append_str = "\n".join(" %s" % line for line in append_str.split("\n")) 116 117 # Step 1: Find mention of "Args": 118 has_args_ix = [ 119 ix for ix, line in enumerate(old_str_lines) 120 if line.strip().lower() == "args:"] 121 if has_args_ix: 122 final_args_ix = has_args_ix[-1] 123 return ("\n".join(old_str_lines[:final_args_ix]) 124 + "\n\n" + append_str + "\n\n" 125 + "\n".join(old_str_lines[final_args_ix:])) 126 else: 127 return old_str + "\n\n" + append_str 128 129 130class _DistributionMeta(abc.ABCMeta): 131 132 def __new__(mcs, classname, baseclasses, attrs): 133 """Control the creation of subclasses of the Distribution class. 134 135 The main purpose of this method is to properly propagate docstrings 136 from private Distribution methods, like `_log_prob`, into their 137 public wrappers as inherited by the Distribution base class 138 (e.g. `log_prob`). 139 140 Args: 141 classname: The name of the subclass being created. 142 baseclasses: A tuple of parent classes. 143 attrs: A dict mapping new attributes to their values. 144 145 Returns: 146 The class object. 147 148 Raises: 149 TypeError: If `Distribution` is not a subclass of `BaseDistribution`, or 150 the new class is derived via multiple inheritance and the first 151 parent class is not a subclass of `BaseDistribution`. 152 AttributeError: If `Distribution` does not implement e.g. `log_prob`. 153 ValueError: If a `Distribution` public method lacks a docstring. 154 """ 155 if not baseclasses: # Nothing to be done for Distribution 156 raise TypeError("Expected non-empty baseclass. Does Distribution " 157 "not subclass _BaseDistribution?") 158 which_base = [ 159 base for base in baseclasses 160 if base == _BaseDistribution or issubclass(base, Distribution)] 161 base = which_base[0] 162 if base == _BaseDistribution: # Nothing to be done for Distribution 163 return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs) 164 if not issubclass(base, Distribution): 165 raise TypeError("First parent class declared for %s must be " 166 "Distribution, but saw '%s'" % (classname, base.__name__)) 167 for attr in _DISTRIBUTION_PUBLIC_METHOD_WRAPPERS: 168 special_attr = "_%s" % attr 169 class_attr_value = attrs.get(attr, None) 170 if attr in attrs: 171 # The method is being overridden, do not update its docstring 172 continue 173 base_attr_value = getattr(base, attr, None) 174 if not base_attr_value: 175 raise AttributeError( 176 "Internal error: expected base class '%s' to implement method '%s'" 177 % (base.__name__, attr)) 178 class_special_attr_value = attrs.get(special_attr, None) 179 if class_special_attr_value is None: 180 # No _special method available, no need to update the docstring. 181 continue 182 class_special_attr_docstring = tf_inspect.getdoc(class_special_attr_value) 183 if not class_special_attr_docstring: 184 # No docstring to append. 185 continue 186 class_attr_value = _copy_fn(base_attr_value) 187 class_attr_docstring = tf_inspect.getdoc(base_attr_value) 188 if class_attr_docstring is None: 189 raise ValueError( 190 "Expected base class fn to contain a docstring: %s.%s" 191 % (base.__name__, attr)) 192 class_attr_value.__doc__ = _update_docstring( 193 class_attr_value.__doc__, 194 ("Additional documentation from `%s`:\n\n%s" 195 % (classname, class_special_attr_docstring))) 196 attrs[attr] = class_attr_value 197 198 return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs) 199 200 201@tf_export("distributions.ReparameterizationType") 202class ReparameterizationType(object): 203 """Instances of this class represent how sampling is reparameterized. 204 205 Two static instances exist in the distributions library, signifying 206 one of two possible properties for samples from a distribution: 207 208 `FULLY_REPARAMETERIZED`: Samples from the distribution are fully 209 reparameterized, and straight-through gradients are supported. 210 211 `NOT_REPARAMETERIZED`: Samples from the distribution are not fully 212 reparameterized, and straight-through gradients are either partially 213 unsupported or are not supported at all. In this case, for purposes of 214 e.g. RL or variational inference, it is generally safest to wrap the 215 sample results in a `stop_gradients` call and instead use policy 216 gradients / surrogate loss instead. 217 """ 218 219 def __init__(self, rep_type): 220 self._rep_type = rep_type 221 222 def __repr__(self): 223 return "<Reparameteriation Type: %s>" % self._rep_type 224 225 def __eq__(self, other): 226 """Determine if this `ReparameterizationType` is equal to another. 227 228 Since RepaparameterizationType instances are constant static global 229 instances, equality checks if two instances' id() values are equal. 230 231 Args: 232 other: Object to compare against. 233 234 Returns: 235 `self is other`. 236 """ 237 return self is other 238 239 240# Fully reparameterized distribution: samples from a fully 241# reparameterized distribution support straight-through gradients with 242# respect to all parameters. 243FULLY_REPARAMETERIZED = ReparameterizationType("FULLY_REPARAMETERIZED") 244tf_export("distributions.FULLY_REPARAMETERIZED").export_constant( 245 __name__, "FULLY_REPARAMETERIZED") 246 247 248# Not reparameterized distribution: samples from a non- 249# reparameterized distribution do not support straight-through gradients for 250# at least some of the parameters. 251NOT_REPARAMETERIZED = ReparameterizationType("NOT_REPARAMETERIZED") 252tf_export("distributions.NOT_REPARAMETERIZED").export_constant( 253 __name__, "NOT_REPARAMETERIZED") 254 255 256@six.add_metaclass(_DistributionMeta) 257@tf_export("distributions.Distribution") 258class Distribution(_BaseDistribution): 259 """A generic probability distribution base class. 260 261 `Distribution` is a base class for constructing and organizing properties 262 (e.g., mean, variance) of random variables (e.g, Bernoulli, Gaussian). 263 264 #### Subclassing 265 266 Subclasses are expected to implement a leading-underscore version of the 267 same-named function. The argument signature should be identical except for 268 the omission of `name="..."`. For example, to enable `log_prob(value, 269 name="log_prob")` a subclass should implement `_log_prob(value)`. 270 271 Subclasses can append to public-level docstrings by providing 272 docstrings for their method specializations. For example: 273 274 ```python 275 @util.AppendDocstring("Some other details.") 276 def _log_prob(self, value): 277 ... 278 ``` 279 280 would add the string "Some other details." to the `log_prob` function 281 docstring. This is implemented as a simple decorator to avoid python 282 linter complaining about missing Args/Returns/Raises sections in the 283 partial docstrings. 284 285 #### Broadcasting, batching, and shapes 286 287 All distributions support batches of independent distributions of that type. 288 The batch shape is determined by broadcasting together the parameters. 289 290 The shape of arguments to `__init__`, `cdf`, `log_cdf`, `prob`, and 291 `log_prob` reflect this broadcasting, as does the return value of `sample` and 292 `sample_n`. 293 294 `sample_n_shape = [n] + batch_shape + event_shape`, where `sample_n_shape` is 295 the shape of the `Tensor` returned from `sample_n`, `n` is the number of 296 samples, `batch_shape` defines how many independent distributions there are, 297 and `event_shape` defines the shape of samples from each of those independent 298 distributions. Samples are independent along the `batch_shape` dimensions, but 299 not necessarily so along the `event_shape` dimensions (depending on the 300 particulars of the underlying distribution). 301 302 Using the `Uniform` distribution as an example: 303 304 ```python 305 minval = 3.0 306 maxval = [[4.0, 6.0], 307 [10.0, 12.0]] 308 309 # Broadcasting: 310 # This instance represents 4 Uniform distributions. Each has a lower bound at 311 # 3.0 as the `minval` parameter was broadcasted to match `maxval`'s shape. 312 u = Uniform(minval, maxval) 313 314 # `event_shape` is `TensorShape([])`. 315 event_shape = u.event_shape 316 # `event_shape_t` is a `Tensor` which will evaluate to []. 317 event_shape_t = u.event_shape_tensor() 318 319 # Sampling returns a sample per distribution. `samples` has shape 320 # [5, 2, 2], which is [n] + batch_shape + event_shape, where n=5, 321 # batch_shape=[2, 2], and event_shape=[]. 322 samples = u.sample_n(5) 323 324 # The broadcasting holds across methods. Here we use `cdf` as an example. The 325 # same holds for `log_cdf` and the likelihood functions. 326 327 # `cum_prob` has shape [2, 2] as the `value` argument was broadcasted to the 328 # shape of the `Uniform` instance. 329 cum_prob_broadcast = u.cdf(4.0) 330 331 # `cum_prob`'s shape is [2, 2], one per distribution. No broadcasting 332 # occurred. 333 cum_prob_per_dist = u.cdf([[4.0, 5.0], 334 [6.0, 7.0]]) 335 336 # INVALID as the `value` argument is not broadcastable to the distribution's 337 # shape. 338 cum_prob_invalid = u.cdf([4.0, 5.0, 6.0]) 339 ``` 340 341 #### Parameter values leading to undefined statistics or distributions. 342 343 Some distributions do not have well-defined statistics for all initialization 344 parameter values. For example, the beta distribution is parameterized by 345 positive real numbers `concentration1` and `concentration0`, and does not have 346 well-defined mode if `concentration1 < 1` or `concentration0 < 1`. 347 348 The user is given the option of raising an exception or returning `NaN`. 349 350 ```python 351 a = tf.exp(tf.matmul(logits, weights_a)) 352 b = tf.exp(tf.matmul(logits, weights_b)) 353 354 # Will raise exception if ANY batch member has a < 1 or b < 1. 355 dist = distributions.beta(a, b, allow_nan_stats=False) 356 mode = dist.mode().eval() 357 358 # Will return NaN for batch members with either a < 1 or b < 1. 359 dist = distributions.beta(a, b, allow_nan_stats=True) # Default behavior 360 mode = dist.mode().eval() 361 ``` 362 363 In all cases, an exception is raised if *invalid* parameters are passed, e.g. 364 365 ```python 366 # Will raise an exception if any Op is run. 367 negative_a = -1.0 * a # beta distribution by definition has a > 0. 368 dist = distributions.beta(negative_a, b, allow_nan_stats=True) 369 dist.mean().eval() 370 ``` 371 372 """ 373 374 def __init__(self, 375 dtype, 376 reparameterization_type, 377 validate_args, 378 allow_nan_stats, 379 parameters=None, 380 graph_parents=None, 381 name=None): 382 """Constructs the `Distribution`. 383 384 **This is a private method for subclass use.** 385 386 Args: 387 dtype: The type of the event samples. `None` implies no type-enforcement. 388 reparameterization_type: Instance of `ReparameterizationType`. 389 If `distributions.FULLY_REPARAMETERIZED`, this 390 `Distribution` can be reparameterized in terms of some standard 391 distribution with a function whose Jacobian is constant for the support 392 of the standard distribution. If `distributions.NOT_REPARAMETERIZED`, 393 then no such reparameterization is available. 394 validate_args: Python `bool`, default `False`. When `True` distribution 395 parameters are checked for validity despite possibly degrading runtime 396 performance. When `False` invalid inputs may silently render incorrect 397 outputs. 398 allow_nan_stats: Python `bool`, default `True`. When `True`, statistics 399 (e.g., mean, mode, variance) use the value "`NaN`" to indicate the 400 result is undefined. When `False`, an exception is raised if one or 401 more of the statistic's batch members are undefined. 402 parameters: Python `dict` of parameters used to instantiate this 403 `Distribution`. 404 graph_parents: Python `list` of graph prerequisites of this 405 `Distribution`. 406 name: Python `str` name prefixed to Ops created by this class. Default: 407 subclass name. 408 409 Raises: 410 ValueError: if any member of graph_parents is `None` or not a `Tensor`. 411 """ 412 graph_parents = [] if graph_parents is None else graph_parents 413 for i, t in enumerate(graph_parents): 414 if t is None or not tensor_util.is_tensor(t): 415 raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t)) 416 self._dtype = dtype 417 self._reparameterization_type = reparameterization_type 418 self._allow_nan_stats = allow_nan_stats 419 self._validate_args = validate_args 420 self._parameters = parameters or {} 421 self._graph_parents = graph_parents 422 self._name = name or type(self).__name__ 423 424 @classmethod 425 def param_shapes(cls, sample_shape, name="DistributionParamShapes"): 426 """Shapes of parameters given the desired shape of a call to `sample()`. 427 428 This is a class method that describes what key/value arguments are required 429 to instantiate the given `Distribution` so that a particular shape is 430 returned for that instance's call to `sample()`. 431 432 Subclasses should override class method `_param_shapes`. 433 434 Args: 435 sample_shape: `Tensor` or python list/tuple. Desired shape of a call to 436 `sample()`. 437 name: name to prepend ops with. 438 439 Returns: 440 `dict` of parameter name to `Tensor` shapes. 441 """ 442 with ops.name_scope(name, values=[sample_shape]): 443 return cls._param_shapes(sample_shape) 444 445 @classmethod 446 def param_static_shapes(cls, sample_shape): 447 """param_shapes with static (i.e. `TensorShape`) shapes. 448 449 This is a class method that describes what key/value arguments are required 450 to instantiate the given `Distribution` so that a particular shape is 451 returned for that instance's call to `sample()`. Assumes that the sample's 452 shape is known statically. 453 454 Subclasses should override class method `_param_shapes` to return 455 constant-valued tensors when constant values are fed. 456 457 Args: 458 sample_shape: `TensorShape` or python list/tuple. Desired shape of a call 459 to `sample()`. 460 461 Returns: 462 `dict` of parameter name to `TensorShape`. 463 464 Raises: 465 ValueError: if `sample_shape` is a `TensorShape` and is not fully defined. 466 """ 467 if isinstance(sample_shape, tensor_shape.TensorShape): 468 if not sample_shape.is_fully_defined(): 469 raise ValueError("TensorShape sample_shape must be fully defined") 470 sample_shape = sample_shape.as_list() 471 472 params = cls.param_shapes(sample_shape) 473 474 static_params = {} 475 for name, shape in params.items(): 476 static_shape = tensor_util.constant_value(shape) 477 if static_shape is None: 478 raise ValueError( 479 "sample_shape must be a fully-defined TensorShape or list/tuple") 480 static_params[name] = tensor_shape.TensorShape(static_shape) 481 482 return static_params 483 484 @staticmethod 485 def _param_shapes(sample_shape): 486 raise NotImplementedError("_param_shapes not implemented") 487 488 @property 489 def name(self): 490 """Name prepended to all ops created by this `Distribution`.""" 491 return self._name 492 493 @property 494 def dtype(self): 495 """The `DType` of `Tensor`s handled by this `Distribution`.""" 496 return self._dtype 497 498 @property 499 def parameters(self): 500 """Dictionary of parameters used to instantiate this `Distribution`.""" 501 # Remove "self", "__class__", or other special variables. These can appear 502 # if the subclass used `parameters = locals()`. 503 return dict((k, v) for k, v in self._parameters.items() 504 if not k.startswith("__") and k != "self") 505 506 @property 507 def reparameterization_type(self): 508 """Describes how samples from the distribution are reparameterized. 509 510 Currently this is one of the static instances 511 `distributions.FULLY_REPARAMETERIZED` 512 or `distributions.NOT_REPARAMETERIZED`. 513 514 Returns: 515 An instance of `ReparameterizationType`. 516 """ 517 return self._reparameterization_type 518 519 @property 520 def allow_nan_stats(self): 521 """Python `bool` describing behavior when a stat is undefined. 522 523 Stats return +/- infinity when it makes sense. E.g., the variance of a 524 Cauchy distribution is infinity. However, sometimes the statistic is 525 undefined, e.g., if a distribution's pdf does not achieve a maximum within 526 the support of the distribution, the mode is undefined. If the mean is 527 undefined, then by definition the variance is undefined. E.g. the mean for 528 Student's T for df = 1 is undefined (no clear way to say it is either + or - 529 infinity), so the variance = E[(X - mean)**2] is also undefined. 530 531 Returns: 532 allow_nan_stats: Python `bool`. 533 """ 534 return self._allow_nan_stats 535 536 @property 537 def validate_args(self): 538 """Python `bool` indicating possibly expensive checks are enabled.""" 539 return self._validate_args 540 541 def copy(self, **override_parameters_kwargs): 542 """Creates a deep copy of the distribution. 543 544 Note: the copy distribution may continue to depend on the original 545 initialization arguments. 546 547 Args: 548 **override_parameters_kwargs: String/value dictionary of initialization 549 arguments to override with new values. 550 551 Returns: 552 distribution: A new instance of `type(self)` initialized from the union 553 of self.parameters and override_parameters_kwargs, i.e., 554 `dict(self.parameters, **override_parameters_kwargs)`. 555 """ 556 parameters = dict(self.parameters, **override_parameters_kwargs) 557 return type(self)(**parameters) 558 559 def _batch_shape_tensor(self): 560 raise NotImplementedError("batch_shape_tensor is not implemented") 561 562 def batch_shape_tensor(self, name="batch_shape_tensor"): 563 """Shape of a single sample from a single event index as a 1-D `Tensor`. 564 565 The batch dimensions are indexes into independent, non-identical 566 parameterizations of this distribution. 567 568 Args: 569 name: name to give to the op 570 571 Returns: 572 batch_shape: `Tensor`. 573 """ 574 with self._name_scope(name): 575 if self.batch_shape.is_fully_defined(): 576 return ops.convert_to_tensor(self.batch_shape.as_list(), 577 dtype=dtypes.int32, 578 name="batch_shape") 579 return self._batch_shape_tensor() 580 581 def _batch_shape(self): 582 return tensor_shape.TensorShape(None) 583 584 @property 585 def batch_shape(self): 586 """Shape of a single sample from a single event index as a `TensorShape`. 587 588 May be partially defined or unknown. 589 590 The batch dimensions are indexes into independent, non-identical 591 parameterizations of this distribution. 592 593 Returns: 594 batch_shape: `TensorShape`, possibly unknown. 595 """ 596 return self._batch_shape() 597 598 def _event_shape_tensor(self): 599 raise NotImplementedError("event_shape_tensor is not implemented") 600 601 def event_shape_tensor(self, name="event_shape_tensor"): 602 """Shape of a single sample from a single batch as a 1-D int32 `Tensor`. 603 604 Args: 605 name: name to give to the op 606 607 Returns: 608 event_shape: `Tensor`. 609 """ 610 with self._name_scope(name): 611 if self.event_shape.is_fully_defined(): 612 return ops.convert_to_tensor(self.event_shape.as_list(), 613 dtype=dtypes.int32, 614 name="event_shape") 615 return self._event_shape_tensor() 616 617 def _event_shape(self): 618 return tensor_shape.TensorShape(None) 619 620 @property 621 def event_shape(self): 622 """Shape of a single sample from a single batch as a `TensorShape`. 623 624 May be partially defined or unknown. 625 626 Returns: 627 event_shape: `TensorShape`, possibly unknown. 628 """ 629 return self._event_shape() 630 631 def is_scalar_event(self, name="is_scalar_event"): 632 """Indicates that `event_shape == []`. 633 634 Args: 635 name: Python `str` prepended to names of ops created by this function. 636 637 Returns: 638 is_scalar_event: `bool` scalar `Tensor`. 639 """ 640 with self._name_scope(name): 641 return ops.convert_to_tensor( 642 self._is_scalar_helper(self.event_shape, self.event_shape_tensor), 643 name="is_scalar_event") 644 645 def is_scalar_batch(self, name="is_scalar_batch"): 646 """Indicates that `batch_shape == []`. 647 648 Args: 649 name: Python `str` prepended to names of ops created by this function. 650 651 Returns: 652 is_scalar_batch: `bool` scalar `Tensor`. 653 """ 654 with self._name_scope(name): 655 return ops.convert_to_tensor( 656 self._is_scalar_helper(self.batch_shape, self.batch_shape_tensor), 657 name="is_scalar_batch") 658 659 def _sample_n(self, n, seed=None): 660 raise NotImplementedError("sample_n is not implemented") 661 662 def _call_sample_n(self, sample_shape, seed, name, **kwargs): 663 with self._name_scope(name, values=[sample_shape]): 664 sample_shape = ops.convert_to_tensor( 665 sample_shape, dtype=dtypes.int32, name="sample_shape") 666 sample_shape, n = self._expand_sample_shape_to_vector( 667 sample_shape, "sample_shape") 668 samples = self._sample_n(n, seed, **kwargs) 669 batch_event_shape = array_ops.shape(samples)[1:] 670 final_shape = array_ops.concat([sample_shape, batch_event_shape], 0) 671 samples = array_ops.reshape(samples, final_shape) 672 samples = self._set_sample_static_shape(samples, sample_shape) 673 return samples 674 675 def sample(self, sample_shape=(), seed=None, name="sample"): 676 """Generate samples of the specified shape. 677 678 Note that a call to `sample()` without arguments will generate a single 679 sample. 680 681 Args: 682 sample_shape: 0D or 1D `int32` `Tensor`. Shape of the generated samples. 683 seed: Python integer seed for RNG 684 name: name to give to the op. 685 686 Returns: 687 samples: a `Tensor` with prepended dimensions `sample_shape`. 688 """ 689 return self._call_sample_n(sample_shape, seed, name) 690 691 def _log_prob(self, value): 692 raise NotImplementedError("log_prob is not implemented") 693 694 def _call_log_prob(self, value, name, **kwargs): 695 with self._name_scope(name, values=[value]): 696 value = ops.convert_to_tensor(value, name="value") 697 try: 698 return self._log_prob(value, **kwargs) 699 except NotImplementedError as original_exception: 700 try: 701 return math_ops.log(self._prob(value, **kwargs)) 702 except NotImplementedError: 703 raise original_exception 704 705 def log_prob(self, value, name="log_prob"): 706 """Log probability density/mass function. 707 708 Args: 709 value: `float` or `double` `Tensor`. 710 name: Python `str` prepended to names of ops created by this function. 711 712 Returns: 713 log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with 714 values of type `self.dtype`. 715 """ 716 return self._call_log_prob(value, name) 717 718 def _prob(self, value): 719 raise NotImplementedError("prob is not implemented") 720 721 def _call_prob(self, value, name, **kwargs): 722 with self._name_scope(name, values=[value]): 723 value = ops.convert_to_tensor(value, name="value") 724 try: 725 return self._prob(value, **kwargs) 726 except NotImplementedError as original_exception: 727 try: 728 return math_ops.exp(self._log_prob(value, **kwargs)) 729 except NotImplementedError: 730 raise original_exception 731 732 def prob(self, value, name="prob"): 733 """Probability density/mass function. 734 735 Args: 736 value: `float` or `double` `Tensor`. 737 name: Python `str` prepended to names of ops created by this function. 738 739 Returns: 740 prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with 741 values of type `self.dtype`. 742 """ 743 return self._call_prob(value, name) 744 745 def _log_cdf(self, value): 746 raise NotImplementedError("log_cdf is not implemented") 747 748 def _call_log_cdf(self, value, name, **kwargs): 749 with self._name_scope(name, values=[value]): 750 value = ops.convert_to_tensor(value, name="value") 751 try: 752 return self._log_cdf(value, **kwargs) 753 except NotImplementedError as original_exception: 754 try: 755 return math_ops.log(self._cdf(value, **kwargs)) 756 except NotImplementedError: 757 raise original_exception 758 759 def log_cdf(self, value, name="log_cdf"): 760 """Log cumulative distribution function. 761 762 Given random variable `X`, the cumulative distribution function `cdf` is: 763 764 ```none 765 log_cdf(x) := Log[ P[X <= x] ] 766 ``` 767 768 Often, a numerical approximation can be used for `log_cdf(x)` that yields 769 a more accurate answer than simply taking the logarithm of the `cdf` when 770 `x << -1`. 771 772 Args: 773 value: `float` or `double` `Tensor`. 774 name: Python `str` prepended to names of ops created by this function. 775 776 Returns: 777 logcdf: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with 778 values of type `self.dtype`. 779 """ 780 return self._call_log_cdf(value, name) 781 782 def _cdf(self, value): 783 raise NotImplementedError("cdf is not implemented") 784 785 def _call_cdf(self, value, name, **kwargs): 786 with self._name_scope(name, values=[value]): 787 value = ops.convert_to_tensor(value, name="value") 788 try: 789 return self._cdf(value, **kwargs) 790 except NotImplementedError as original_exception: 791 try: 792 return math_ops.exp(self._log_cdf(value, **kwargs)) 793 except NotImplementedError: 794 raise original_exception 795 796 def cdf(self, value, name="cdf"): 797 """Cumulative distribution function. 798 799 Given random variable `X`, the cumulative distribution function `cdf` is: 800 801 ```none 802 cdf(x) := P[X <= x] 803 ``` 804 805 Args: 806 value: `float` or `double` `Tensor`. 807 name: Python `str` prepended to names of ops created by this function. 808 809 Returns: 810 cdf: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with 811 values of type `self.dtype`. 812 """ 813 return self._call_cdf(value, name) 814 815 def _log_survival_function(self, value): 816 raise NotImplementedError("log_survival_function is not implemented") 817 818 def _call_log_survival_function(self, value, name, **kwargs): 819 with self._name_scope(name, values=[value]): 820 value = ops.convert_to_tensor(value, name="value") 821 try: 822 return self._log_survival_function(value, **kwargs) 823 except NotImplementedError as original_exception: 824 try: 825 return math_ops.log1p(-self.cdf(value, **kwargs)) 826 except NotImplementedError: 827 raise original_exception 828 829 def log_survival_function(self, value, name="log_survival_function"): 830 """Log survival function. 831 832 Given random variable `X`, the survival function is defined: 833 834 ```none 835 log_survival_function(x) = Log[ P[X > x] ] 836 = Log[ 1 - P[X <= x] ] 837 = Log[ 1 - cdf(x) ] 838 ``` 839 840 Typically, different numerical approximations can be used for the log 841 survival function, which are more accurate than `1 - cdf(x)` when `x >> 1`. 842 843 Args: 844 value: `float` or `double` `Tensor`. 845 name: Python `str` prepended to names of ops created by this function. 846 847 Returns: 848 `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type 849 `self.dtype`. 850 """ 851 return self._call_log_survival_function(value, name) 852 853 def _survival_function(self, value): 854 raise NotImplementedError("survival_function is not implemented") 855 856 def _call_survival_function(self, value, name, **kwargs): 857 with self._name_scope(name, values=[value]): 858 value = ops.convert_to_tensor(value, name="value") 859 try: 860 return self._survival_function(value, **kwargs) 861 except NotImplementedError as original_exception: 862 try: 863 return 1. - self.cdf(value, **kwargs) 864 except NotImplementedError: 865 raise original_exception 866 867 def survival_function(self, value, name="survival_function"): 868 """Survival function. 869 870 Given random variable `X`, the survival function is defined: 871 872 ```none 873 survival_function(x) = P[X > x] 874 = 1 - P[X <= x] 875 = 1 - cdf(x). 876 ``` 877 878 Args: 879 value: `float` or `double` `Tensor`. 880 name: Python `str` prepended to names of ops created by this function. 881 882 Returns: 883 `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type 884 `self.dtype`. 885 """ 886 return self._call_survival_function(value, name) 887 888 def _entropy(self): 889 raise NotImplementedError("entropy is not implemented") 890 891 def entropy(self, name="entropy"): 892 """Shannon entropy in nats.""" 893 with self._name_scope(name): 894 return self._entropy() 895 896 def _mean(self): 897 raise NotImplementedError("mean is not implemented") 898 899 def mean(self, name="mean"): 900 """Mean.""" 901 with self._name_scope(name): 902 return self._mean() 903 904 def _quantile(self, value): 905 raise NotImplementedError("quantile is not implemented") 906 907 def _call_quantile(self, value, name, **kwargs): 908 with self._name_scope(name, values=[value]): 909 value = ops.convert_to_tensor(value, name="value") 910 try: 911 return self._quantile(value, **kwargs) 912 except NotImplementedError as original_exception: 913 raise original_exception 914 915 def quantile(self, value, name="quantile"): 916 """Quantile function. Aka "inverse cdf" or "percent point function". 917 918 Given random variable `X` and `p in [0, 1]`, the `quantile` is: 919 920 ```none 921 quantile(p) := x such that P[X <= x] == p 922 ``` 923 924 Args: 925 value: `float` or `double` `Tensor`. 926 name: Python `str` prepended to names of ops created by this function. 927 928 Returns: 929 quantile: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with 930 values of type `self.dtype`. 931 """ 932 return self._call_quantile(value, name) 933 934 def _variance(self): 935 raise NotImplementedError("variance is not implemented") 936 937 def variance(self, name="variance"): 938 """Variance. 939 940 Variance is defined as, 941 942 ```none 943 Var = E[(X - E[X])**2] 944 ``` 945 946 where `X` is the random variable associated with this distribution, `E` 947 denotes expectation, and `Var.shape = batch_shape + event_shape`. 948 949 Args: 950 name: Python `str` prepended to names of ops created by this function. 951 952 Returns: 953 variance: Floating-point `Tensor` with shape identical to 954 `batch_shape + event_shape`, i.e., the same shape as `self.mean()`. 955 """ 956 with self._name_scope(name): 957 try: 958 return self._variance() 959 except NotImplementedError as original_exception: 960 try: 961 return math_ops.square(self._stddev()) 962 except NotImplementedError: 963 raise original_exception 964 965 def _stddev(self): 966 raise NotImplementedError("stddev is not implemented") 967 968 def stddev(self, name="stddev"): 969 """Standard deviation. 970 971 Standard deviation is defined as, 972 973 ```none 974 stddev = E[(X - E[X])**2]**0.5 975 ``` 976 977 where `X` is the random variable associated with this distribution, `E` 978 denotes expectation, and `stddev.shape = batch_shape + event_shape`. 979 980 Args: 981 name: Python `str` prepended to names of ops created by this function. 982 983 Returns: 984 stddev: Floating-point `Tensor` with shape identical to 985 `batch_shape + event_shape`, i.e., the same shape as `self.mean()`. 986 """ 987 988 with self._name_scope(name): 989 try: 990 return self._stddev() 991 except NotImplementedError as original_exception: 992 try: 993 return math_ops.sqrt(self._variance()) 994 except NotImplementedError: 995 raise original_exception 996 997 def _covariance(self): 998 raise NotImplementedError("covariance is not implemented") 999 1000 def covariance(self, name="covariance"): 1001 """Covariance. 1002 1003 Covariance is (possibly) defined only for non-scalar-event distributions. 1004 1005 For example, for a length-`k`, vector-valued distribution, it is calculated 1006 as, 1007 1008 ```none 1009 Cov[i, j] = Covariance(X_i, X_j) = E[(X_i - E[X_i]) (X_j - E[X_j])] 1010 ``` 1011 1012 where `Cov` is a (batch of) `k x k` matrix, `0 <= (i, j) < k`, and `E` 1013 denotes expectation. 1014 1015 Alternatively, for non-vector, multivariate distributions (e.g., 1016 matrix-valued, Wishart), `Covariance` shall return a (batch of) matrices 1017 under some vectorization of the events, i.e., 1018 1019 ```none 1020 Cov[i, j] = Covariance(Vec(X)_i, Vec(X)_j) = [as above] 1021 ``` 1022 1023 where `Cov` is a (batch of) `k' x k'` matrices, 1024 `0 <= (i, j) < k' = reduce_prod(event_shape)`, and `Vec` is some function 1025 mapping indices of this distribution's event dimensions to indices of a 1026 length-`k'` vector. 1027 1028 Args: 1029 name: Python `str` prepended to names of ops created by this function. 1030 1031 Returns: 1032 covariance: Floating-point `Tensor` with shape `[B1, ..., Bn, k', k']` 1033 where the first `n` dimensions are batch coordinates and 1034 `k' = reduce_prod(self.event_shape)`. 1035 """ 1036 with self._name_scope(name): 1037 return self._covariance() 1038 1039 def _mode(self): 1040 raise NotImplementedError("mode is not implemented") 1041 1042 def mode(self, name="mode"): 1043 """Mode.""" 1044 with self._name_scope(name): 1045 return self._mode() 1046 1047 def _cross_entropy(self, other): 1048 return kullback_leibler.cross_entropy( 1049 self, other, allow_nan_stats=self.allow_nan_stats) 1050 1051 def cross_entropy(self, other, name="cross_entropy"): 1052 """Computes the (Shannon) cross entropy. 1053 1054 Denote this distribution (`self`) by `P` and the `other` distribution by 1055 `Q`. Assuming `P, Q` are absolutely continuous with respect to 1056 one another and permit densities `p(x) dr(x)` and `q(x) dr(x)`, (Shanon) 1057 cross entropy is defined as: 1058 1059 ```none 1060 H[P, Q] = E_p[-log q(X)] = -int_F p(x) log q(x) dr(x) 1061 ``` 1062 1063 where `F` denotes the support of the random variable `X ~ P`. 1064 1065 Args: 1066 other: `tf.distributions.Distribution` instance. 1067 name: Python `str` prepended to names of ops created by this function. 1068 1069 Returns: 1070 cross_entropy: `self.dtype` `Tensor` with shape `[B1, ..., Bn]` 1071 representing `n` different calculations of (Shanon) cross entropy. 1072 """ 1073 with self._name_scope(name): 1074 return self._cross_entropy(other) 1075 1076 def _kl_divergence(self, other): 1077 return kullback_leibler.kl_divergence( 1078 self, other, allow_nan_stats=self.allow_nan_stats) 1079 1080 def kl_divergence(self, other, name="kl_divergence"): 1081 """Computes the Kullback--Leibler divergence. 1082 1083 Denote this distribution (`self`) by `p` and the `other` distribution by 1084 `q`. Assuming `p, q` are absolutely continuous with respect to reference 1085 measure `r`, the KL divergence is defined as: 1086 1087 ```none 1088 KL[p, q] = E_p[log(p(X)/q(X))] 1089 = -int_F p(x) log q(x) dr(x) + int_F p(x) log p(x) dr(x) 1090 = H[p, q] - H[p] 1091 ``` 1092 1093 where `F` denotes the support of the random variable `X ~ p`, `H[., .]` 1094 denotes (Shanon) cross entropy, and `H[.]` denotes (Shanon) entropy. 1095 1096 Args: 1097 other: `tf.distributions.Distribution` instance. 1098 name: Python `str` prepended to names of ops created by this function. 1099 1100 Returns: 1101 kl_divergence: `self.dtype` `Tensor` with shape `[B1, ..., Bn]` 1102 representing `n` different calculations of the Kullback-Leibler 1103 divergence. 1104 """ 1105 with self._name_scope(name): 1106 return self._kl_divergence(other) 1107 1108 @contextlib.contextmanager 1109 def _name_scope(self, name=None, values=None): 1110 """Helper function to standardize op scope.""" 1111 with ops.name_scope(self.name): 1112 with ops.name_scope(name, values=( 1113 ([] if values is None else values) + self._graph_parents)) as scope: 1114 yield scope 1115 1116 def _expand_sample_shape_to_vector(self, x, name): 1117 """Helper to `sample` which ensures input is 1D.""" 1118 x_static_val = tensor_util.constant_value(x) 1119 if x_static_val is None: 1120 prod = math_ops.reduce_prod(x) 1121 else: 1122 prod = np.prod(x_static_val, dtype=x.dtype.as_numpy_dtype()) 1123 1124 ndims = x.get_shape().ndims # != sample_ndims 1125 if ndims is None: 1126 # Maybe expand_dims. 1127 ndims = array_ops.rank(x) 1128 expanded_shape = util.pick_vector( 1129 math_ops.equal(ndims, 0), 1130 np.array([1], dtype=np.int32), array_ops.shape(x)) 1131 x = array_ops.reshape(x, expanded_shape) 1132 elif ndims == 0: 1133 # Definitely expand_dims. 1134 if x_static_val is not None: 1135 x = ops.convert_to_tensor( 1136 np.array([x_static_val], dtype=x.dtype.as_numpy_dtype()), 1137 name=name) 1138 else: 1139 x = array_ops.reshape(x, [1]) 1140 elif ndims != 1: 1141 raise ValueError("Input is neither scalar nor vector.") 1142 1143 return x, prod 1144 1145 def _set_sample_static_shape(self, x, sample_shape): 1146 """Helper to `sample`; sets static shape info.""" 1147 # Set shape hints. 1148 sample_shape = tensor_shape.TensorShape( 1149 tensor_util.constant_value(sample_shape)) 1150 1151 ndims = x.get_shape().ndims 1152 sample_ndims = sample_shape.ndims 1153 batch_ndims = self.batch_shape.ndims 1154 event_ndims = self.event_shape.ndims 1155 1156 # Infer rank(x). 1157 if (ndims is None and 1158 sample_ndims is not None and 1159 batch_ndims is not None and 1160 event_ndims is not None): 1161 ndims = sample_ndims + batch_ndims + event_ndims 1162 x.set_shape([None] * ndims) 1163 1164 # Infer sample shape. 1165 if ndims is not None and sample_ndims is not None: 1166 shape = sample_shape.concatenate([None]*(ndims - sample_ndims)) 1167 x.set_shape(x.get_shape().merge_with(shape)) 1168 1169 # Infer event shape. 1170 if ndims is not None and event_ndims is not None: 1171 shape = tensor_shape.TensorShape( 1172 [None]*(ndims - event_ndims)).concatenate(self.event_shape) 1173 x.set_shape(x.get_shape().merge_with(shape)) 1174 1175 # Infer batch shape. 1176 if batch_ndims is not None: 1177 if ndims is not None: 1178 if sample_ndims is None and event_ndims is not None: 1179 sample_ndims = ndims - batch_ndims - event_ndims 1180 elif event_ndims is None and sample_ndims is not None: 1181 event_ndims = ndims - batch_ndims - sample_ndims 1182 if sample_ndims is not None and event_ndims is not None: 1183 shape = tensor_shape.TensorShape([None]*sample_ndims).concatenate( 1184 self.batch_shape).concatenate([None]*event_ndims) 1185 x.set_shape(x.get_shape().merge_with(shape)) 1186 1187 return x 1188 1189 def _is_scalar_helper(self, static_shape, dynamic_shape_fn): 1190 """Implementation for `is_scalar_batch` and `is_scalar_event`.""" 1191 if static_shape.ndims is not None: 1192 return static_shape.ndims == 0 1193 shape = dynamic_shape_fn() 1194 if (shape.get_shape().ndims is not None and 1195 shape.get_shape()[0].value is not None): 1196 # If the static_shape is correctly written then we should never execute 1197 # this branch. We keep it just in case there's some unimagined corner 1198 # case. 1199 return shape.get_shape().as_list() == [0] 1200 return math_ops.equal(array_ops.shape(shape)[0], 0) 1201