# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""basic"""
from mindspore import context
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn.cell import Cell
from mindspore.ops.primitive import constexpr
from mindspore.ops.operations import _inner_ops as inner
from mindspore import _checkparam as validator
from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\
    raise_not_implemented_util
from ._utils.utils import CheckTuple, CheckTensor
from ._utils.custom_ops import broadcast_to, exp_generic, log_generic


class Distribution(Cell):
    """
    Base class for all mathematical distributions.

    Args:
        seed (int): The seed is used in sampling. 0 is used if it is None.
        dtype (mindspore.dtype): The type of the event samples.
        name (str): The name of the distribution.
        param (dict): The parameters used to initialize the distribution.

    Note:
        Derived class must override operations such as `_mean`, `_prob`,
        and `_log_prob`. Required arguments, such as `value` for `_prob`,
        must be passed in through `args` or `kwargs`. `dist_spec_args` which specifies
        a new distribution are optional.

        `dist_spec_args` is unique for each type of distribution. For example, `mean` and `sd`
        are the `dist_spec_args` for a Normal distribution, while `rate` is the `dist_spec_args`
        for an Exponential distribution.

        For all functions, passing in `dist_spec_args` is optional.
        Function calls with the additional `dist_spec_args` passed in will evaluate the result with
        a new distribution specified by the `dist_spec_args`. However, it will not change the original distribution.

    Supported Platforms:
        ``Ascend`` ``GPU``
    """

    def __init__(self,
                 seed,
                 dtype,
                 name,
                 param):
        """
        Constructor of distribution class.

        Validates `name` and `seed`, stores the public entries of `param`, and binds
        every `_call_*` dispatcher according to which `_xxx` methods the derived
        class defines.
        """
        super(Distribution, self).__init__()
        if seed is None:
            seed = 0
        validator.check_value_type('name', name, [str], type(self).__name__)
        validator.check_non_negative_int(seed, 'seed', name)

        self._name = name
        self._seed = seed
        self._dtype = cast_type_for_device(dtype)
        self._parameters = {}
        self.default_parameters = []
        self.parameter_names = []

        # parsing parameters: keep everything except `self` and private (underscore-prefixed) entries
        for key, value in param.items():
            if key != 'self' and not key.startswith('_'):
                self._parameters[key] = value

        # if not a transformed distribution, set the following attribute
        if 'distribution' not in self.parameters:
            self.parameter_type = set_param_type(
                self.parameters.get('param_dict', {}), dtype)
            self._batch_shape = self._calc_batch_shape()
            self._is_scalar_batch = self._check_is_scalar_batch()
            self._broadcast_shape = self._batch_shape

        # set the function to call according to the derived class's attributes.
        # Order matters: `_set_survival` and `_set_log_cdf` look for `_call_cdf`, and
        # `_set_log_survival` looks for `_call_survival`, so those must be bound first.
        self._set_prob()
        self._set_log_prob()
        self._set_sd()
        self._set_var()
        self._set_cdf()
        self._set_survival()
        self._set_log_cdf()
        self._set_log_survival()
        self._set_cross_entropy()

        self.context_mode = context.get_context('mode')
        self.device_target = context.get_context('device_target')
        self.checktuple = CheckTuple()

        @constexpr(check=False)
        def _check_tensor(x, name):
            CheckTensor()(x, name)
            return x
        # we use constexpr to force CheckTensor to run only once in pynative mode
        # (context mode 0 uses CheckTensor directly; any other mode uses the constexpr wrapper)
        self.checktensor = CheckTensor() if self.context_mode == 0 else _check_tensor
        self.broadcast = broadcast_to

        # ops needed for the base class
        self.cast_base = P.Cast()
        self.dtype_base = P.DType()
        self.sametypeshape_base = inner.SameTypeShape()
        self.sq_base = P.Square()
        self.sqrt_base = P.Sqrt()
        self.shape_base = P.Shape()
        # Ascend uses the custom exp/log implementations from `_utils.custom_ops`
        if self.device_target != "Ascend":
            self.log_base = P.Log()
            self.exp_base = P.Exp()
        else:
            self.exp_base = exp_generic
            self.log_base = log_generic

    @property
    def name(self):
        """The name of the distribution passed to the constructor."""
        return self._name

    @property
    def dtype(self):
        """The dtype of the distribution, after `cast_type_for_device` was applied."""
        return self._dtype

    @property
    def seed(self):
        """The seed passed to the constructor (0 if None was given)."""
        return self._seed

    @property
    def parameters(self):
        """The dict of public constructor parameters stored at initialization."""
        return self._parameters

    @property
    def is_scalar_batch(self):
        """Whether all non-None initial parameters are Python int/float scalars."""
        return self._is_scalar_batch

    @property
    def batch_shape(self):
        """The broadcast shape of the initial parameters, or None if it could not be computed."""
        return self._batch_shape

    @property
    def broadcast_shape(self):
        """The broadcast shape; initialized to the batch shape in the constructor."""
        return self._broadcast_shape

    def _reset_parameters(self):
        """
        Clear the registered default parameters and their names.
        """
        self.default_parameters = []
        self.parameter_names = []

    def _add_parameter(self, value, name):
        """
        Cast `value` to a tensor and add it to `self.default_parameters`.
        Add `name` into `self.parameter_names`.

        Returns the cast tensor, or None when `value` is None.
        """
        # initialize the attributes if they do not exist yet
        if not hasattr(self, 'default_parameters'):
            self.default_parameters = []
            self.parameter_names = []
        # cast value to a tensor if it is not None
        value_t = None if value is None else cast_to_tensor(value, self.parameter_type)
        self.default_parameters.append(value_t)
        self.parameter_names.append(name)
        return value_t

    def _check_param_type(self, *args):
        """
        Check the availability and validity of default parameters and `dist_spec_args`.
        `dist_spec_args` passed in must be tensors. If default parameters of the distribution
        are None, the parameters must be passed in through `args`.

        Returns a single tensor when exactly one parameter is processed; otherwise a tuple
        of tensors that have all been broadcast to the common shape.
        """
        broadcast_shape = None
        broadcast_shape_tensor = None
        common_dtype = None
        out = []

        for arg, name, default in zip(args, self.parameter_names, self.default_parameters):
            # check if the argument is a Tensor
            if arg is not None:
                self.checktensor(arg, name)
            else:
                arg = default if default is not None else raise_none_error(name)

            # broadcast if the number of args > 1
            if broadcast_shape is None:
                broadcast_shape = self.shape_base(arg)
                common_dtype = self.dtype_base(arg)
                broadcast_shape_tensor = F.fill(
                    common_dtype, broadcast_shape, 1.0)
            else:
                # adding the tensor of ones yields a result whose shape is the broadcast shape so far
                broadcast_shape = self.shape_base(arg + broadcast_shape_tensor)
                broadcast_shape_tensor = F.fill(
                    common_dtype, broadcast_shape, 1.0)
                arg = self.broadcast(arg, broadcast_shape_tensor)
                # check if the arguments have the same dtype
                self.sametypeshape_base(arg, broadcast_shape_tensor)

            arg = self.cast_base(arg, self.parameter_type)
            out.append(arg)

        if len(out) == 1:
            return out[0]

        # broadcast all args to broadcast_shape
        result = ()
        for arg in out:
            arg = self.broadcast(arg, broadcast_shape_tensor)
            result = result + (arg,)
        return result

    def _check_value(self, value, name):
        """
        Check availability of `value` as a Tensor.
        """
        self.checktensor(value, name)
        return value

    def _check_is_scalar_batch(self):
        """
        Check if the parameters used during initialization are scalars.

        None entries are ignored. A missing or empty `param_dict` counts as scalar.
        """
        # fall back to an empty dict, matching the `param_dict` lookup in `__init__`
        param_dict = self.parameters.get('param_dict') or {}
        return all(value is None or isinstance(value, (int, float))
                   for value in param_dict.values())

    def _calc_batch_shape(self):
        """
        Calculate the broadcast shape of the parameters used during initialization.

        Returns None when any parameter is None, or when there are no parameters at all.
        """
        # fall back to an empty dict, matching the `param_dict` lookup in `__init__`
        param_dict = self.parameters.get('param_dict') or {}
        broadcast_shape_tensor = None
        for value in param_dict.values():
            if value is None:
                return None
            value_t = cast_to_tensor(value)
            if broadcast_shape_tensor is None:
                broadcast_shape_tensor = value_t
            else:
                # the shape of the sum is the broadcast shape of everything seen so far
                broadcast_shape_tensor = (value_t + broadcast_shape_tensor)
        if broadcast_shape_tensor is None:
            # no parameters were given, so there is nothing to take a shape from
            return None
        return broadcast_shape_tensor.shape

    def _set_prob(self):
        """
        Set probability function based on the availability of `_prob` and `_log_prob`.
        """
        if hasattr(self, '_prob'):
            self._call_prob = self._prob
        elif hasattr(self, '_log_prob'):
            self._call_prob = self._calc_prob_from_log_prob
        else:
            self._call_prob = self._raise_not_implemented_error('prob')

    def _set_sd(self):
        """
        Set standard deviation based on the availability of `_sd` and `_var`.
        """
        if hasattr(self, '_sd'):
            self._call_sd = self._sd
        elif hasattr(self, '_var'):
            self._call_sd = self._calc_sd_from_var
        else:
            self._call_sd = self._raise_not_implemented_error('sd')

    def _set_var(self):
        """
        Set variance based on the availability of `_sd` and `_var`.
        """
        if hasattr(self, '_var'):
            self._call_var = self._var
        elif hasattr(self, '_sd'):
            self._call_var = self._calc_var_from_sd
        else:
            self._call_var = self._raise_not_implemented_error('var')

    def _set_log_prob(self):
        """
        Set log probability based on the availability of `_prob` and `_log_prob`.
        """
        if hasattr(self, '_log_prob'):
            self._call_log_prob = self._log_prob
        elif hasattr(self, '_prob'):
            self._call_log_prob = self._calc_log_prob_from_prob
        else:
            self._call_log_prob = self._raise_not_implemented_error('log_prob')

    def _set_cdf(self):
        """
        Set cumulative distribution function (cdf) based on the availability of `_cdf`, `_log_cdf`,
        `_survival_function` and `_log_survival`.
        """
        if hasattr(self, '_cdf'):
            self._call_cdf = self._cdf
        elif hasattr(self, '_log_cdf'):
            self._call_cdf = self._calc_cdf_from_log_cdf
        elif hasattr(self, '_survival_function'):
            self._call_cdf = self._calc_cdf_from_survival
        elif hasattr(self, '_log_survival'):
            self._call_cdf = self._calc_cdf_from_log_survival
        else:
            self._call_cdf = self._raise_not_implemented_error('cdf')

    def _set_survival(self):
        """
        Set survival function based on the availability of `_survival_function`, `_log_survival`
        and `_call_cdf`.
        """
        if not (hasattr(self, '_survival_function') or hasattr(self, '_log_survival') or
                hasattr(self, '_cdf') or hasattr(self, '_log_cdf')):
            self._call_survival = self._raise_not_implemented_error(
                'survival_function')
        elif hasattr(self, '_survival_function'):
            self._call_survival = self._survival_function
        elif hasattr(self, '_log_survival'):
            self._call_survival = self._calc_survival_from_log_survival
        elif hasattr(self, '_call_cdf'):
            self._call_survival = self._calc_survival_from_call_cdf

    def _set_log_cdf(self):
        """
        Set log cdf based on the availability of `_log_cdf` and `_call_cdf`.
        """
        if not (hasattr(self, '_log_cdf') or hasattr(self, '_cdf') or
                hasattr(self, '_survival_function') or hasattr(self, '_log_survival')):
            self._call_log_cdf = self._raise_not_implemented_error('log_cdf')
        elif hasattr(self, '_log_cdf'):
            self._call_log_cdf = self._log_cdf
        elif hasattr(self, '_call_cdf'):
            self._call_log_cdf = self._calc_log_cdf_from_call_cdf

    def _set_log_survival(self):
        """
        Set log survival based on the availability of `_log_survival` and `_call_survival`.
        """
        if not (hasattr(self, '_log_survival') or hasattr(self, '_survival_function') or
                hasattr(self, '_log_cdf') or hasattr(self, '_cdf')):
            # report the function that is actually missing
            self._call_log_survival = self._raise_not_implemented_error(
                'log_survival')
        elif hasattr(self, '_log_survival'):
            self._call_log_survival = self._log_survival
        elif hasattr(self, '_call_survival'):
            self._call_log_survival = self._calc_log_survival_from_call_survival

    def _set_cross_entropy(self):
        """
        Set cross entropy based on the availability of `_cross_entropy`.
        """
        # NOTE: `_calc_cross_entropy` is intentionally not used as a fallback here;
        # a distribution without `_cross_entropy` reports it as not implemented.
        if hasattr(self, '_cross_entropy'):
            self._call_cross_entropy = self._cross_entropy
        else:
            self._call_cross_entropy = self._raise_not_implemented_error(
                'cross_entropy')

    def _get_dist_args(self, *args, **kwargs):
        """
        Default hook for `get_dist_args`; delegates to `raise_not_implemented_util`.
        """
        return raise_not_implemented_util('get_dist_args', self.name, *args, **kwargs)

    def get_dist_args(self, *args, **kwargs):
        """
        Check the availability and validity of default parameters and `dist_spec_args`.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            `dist_spec_args` must be passed in through list or dictionary. The order of `dist_spec_args`
            should follow the initialization order of default parameters through `_add_parameter`.
            If some `dist_spec_args` is None, the corresponding default parameter is returned.

        Return:
            list[Tensor], the list of parameters.
        """
        return self._get_dist_args(*args, **kwargs)

    def _get_dist_type(self):
        """
        Default hook for `get_dist_type`; delegates to `raise_not_implemented_util`.
        """
        return raise_not_implemented_util('get_dist_type', self.name)

    def get_dist_type(self):
        """
        Return the type of the distribution.

        Return:
            string, the name of distribution.
        """
        return self._get_dist_type()

    def _raise_not_implemented_error(self, func_name):
        """
        Build a callable that reports `func_name` as not implemented for this distribution.

        The returned closure accepts any arguments and forwards them, together with
        `func_name` and the distribution name, to `raise_not_implemented_util`.
        """
        name = self.name

        def raise_error(*args, **kwargs):
            return raise_not_implemented_util(func_name, name, *args, **kwargs)
        return raise_error

    def log_prob(self, value, *args, **kwargs):
        """
        Evaluate the log probability(pdf or pmf) at the given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its `dist_spec_args` through
            `args` or `kwargs`.

        Return:
            Tensor, the value of log probability.
        """
        return self._call_log_prob(value, *args, **kwargs)

    def _calc_prob_from_log_prob(self, value, *args, **kwargs):
        r"""
        Evaluate prob from log probability.

        .. math::
            prob(x) = \exp(log_prob(x))
        """
        return self.exp_base(self._log_prob(value, *args, **kwargs))

    def prob(self, value, *args, **kwargs):
        """
        Evaluate the probability (pdf or pmf) at given value. For a discrete distribution,
        it is a probability mass function, while for a continuous distribution, it is probability density function.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its `dist_spec_args` through
            `args` or `kwargs`.

        Return:
            Tensor, the value of probability.
        """
        return self._call_prob(value, *args, **kwargs)

    def _calc_log_prob_from_prob(self, value, *args, **kwargs):
        r"""
        Evaluate log probability from probability.

        .. math::
            log_prob(x) = \log(prob(x))
        """
        return self.log_base(self._prob(value, *args, **kwargs))

    def cdf(self, value, *args, **kwargs):
        """
        Evaluate the cumulative distribution function(cdf) at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its `dist_spec_args` through
            `args` or `kwargs`.

        Return:
            Tensor, the cdf of the distribution.
        """
        return self._call_cdf(value, *args, **kwargs)

    def _calc_cdf_from_log_cdf(self, value, *args, **kwargs):
        r"""
        Evaluate cdf from log_cdf.

        .. math::
            cdf(x) = \exp(log_cdf(x))
        """
        return self.exp_base(self._log_cdf(value, *args, **kwargs))

    def _calc_cdf_from_survival(self, value, *args, **kwargs):
        r"""
        Evaluate cdf from survival function.

        .. math::
            cdf(x) = 1 - (survival_function(x))
        """
        return 1.0 - self._survival_function(value, *args, **kwargs)

    def _calc_cdf_from_log_survival(self, value, *args, **kwargs):
        r"""
        Evaluate cdf from log survival function.

        .. math::
            cdf(x) = 1 - (\exp(log_survival(x)))
        """
        return 1.0 - self.exp_base(self._log_survival(value, *args, **kwargs))

    def log_cdf(self, value, *args, **kwargs):
        """
        Evaluate the log of the cumulative distribution function(cdf) at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its `dist_spec_args` through
            `args` or `kwargs`.

        Return:
            Tensor, the log cdf of the distribution.
        """
        return self._call_log_cdf(value, *args, **kwargs)

    def _calc_log_cdf_from_call_cdf(self, value, *args, **kwargs):
        r"""
        Evaluate log cdf from cdf.

        .. math::
            log_cdf(x) = \log(cdf(x))
        """
        return self.log_base(self._call_cdf(value, *args, **kwargs))

    def survival_function(self, value, *args, **kwargs):
        """
        Evaluate the survival function at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its `dist_spec_args` through
            `args` or `kwargs`.

        Return:
            Tensor, the survival function of the distribution.
        """
        return self._call_survival(value, *args, **kwargs)

    def _calc_survival_from_call_cdf(self, value, *args, **kwargs):
        r"""
        Evaluate survival function from cdf.

        .. math::
            survival_function(x) = 1 - (cdf(x))
        """
        return 1.0 - self._call_cdf(value, *args, **kwargs)

    def _calc_survival_from_log_survival(self, value, *args, **kwargs):
        r"""
        Evaluate survival function from log survival function.

        .. math::
            survival_function(x) = \exp(log_survival(x))
        """
        return self.exp_base(self._log_survival(value, *args, **kwargs))

    def log_survival(self, value, *args, **kwargs):
        """
        Evaluate the log survival function at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its `dist_spec_args` through
            `args` or `kwargs`.

        Return:
            Tensor, the log survival function of the distribution.
        """
        return self._call_log_survival(value, *args, **kwargs)

    def _calc_log_survival_from_call_survival(self, value, *args, **kwargs):
        r"""
        Evaluate log survival function from survival function.

        .. math::
            log_survival(x) = \log(survival_function(x))
        """
        return self.log_base(self._call_survival(value, *args, **kwargs))

    def _kl_loss(self, *args, **kwargs):
        """
        Default hook for `kl_loss`; delegates to `raise_not_implemented_util`.
        """
        return raise_not_implemented_util('kl_loss', self.name, *args, **kwargs)

    def kl_loss(self, dist, *args, **kwargs):
        """
        Evaluate the KL divergence, i.e. KL(a||b).

        Args:
            dist (str): type of the distribution.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            `dist_spec_args` of distribution b must be passed to the function through `args` or `kwargs`.
            Passing in `dist_spec_args` of distribution a is optional.

        Return:
            Tensor, the kl loss function of the distribution.
        """
        return self._kl_loss(dist, *args, **kwargs)

    def _mean(self, *args, **kwargs):
        """
        Default hook for `mean`; delegates to `raise_not_implemented_util`.
        """
        return raise_not_implemented_util('mean', self.name, *args, **kwargs)

    def mean(self, *args, **kwargs):
        """
        Evaluate the mean.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            `args` or `kwargs`.

        Return:
            Tensor, the mean of the distribution.
        """
        return self._mean(*args, **kwargs)

    def _mode(self, *args, **kwargs):
        """
        Default hook for `mode`; delegates to `raise_not_implemented_util`.
        """
        return raise_not_implemented_util('mode', self.name, *args, **kwargs)

    def mode(self, *args, **kwargs):
        """
        Evaluate the mode.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            `args` or `kwargs`.

        Return:
            Tensor, the mode of the distribution.
        """
        return self._mode(*args, **kwargs)

    def sd(self, *args, **kwargs):
        """
        Evaluate the standard deviation.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            `args` or `kwargs`.

        Return:
            Tensor, the standard deviation of the distribution.
        """
        return self._call_sd(*args, **kwargs)

    def var(self, *args, **kwargs):
        """
        Evaluate the variance.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            `args` or `kwargs`.

        Return:
            Tensor, the variance of the distribution.
        """
        return self._call_var(*args, **kwargs)

    def _calc_sd_from_var(self, *args, **kwargs):
        r"""
        Evaluate standard deviation from variance.

        .. math::
            STD(x) = \sqrt(VAR(x))
        """
        return self.sqrt_base(self._var(*args, **kwargs))

    def _calc_var_from_sd(self, *args, **kwargs):
        r"""
        Evaluate variance from standard deviation.

        .. math::
            VAR(x) = STD(x) ^ 2
        """
        return self.sq_base(self._sd(*args, **kwargs))

    def _entropy(self, *args, **kwargs):
        """
        Default hook for `entropy`; delegates to `raise_not_implemented_util`.
        """
        return raise_not_implemented_util('entropy', self.name, *args, **kwargs)

    def entropy(self, *args, **kwargs):
        """
        Evaluate the entropy.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            `args` or `kwargs`.

        Return:
            Tensor, the entropy of the distribution.
        """
        return self._entropy(*args, **kwargs)

    def cross_entropy(self, dist, *args, **kwargs):
        """
        Evaluate the cross_entropy between distribution a and b.

        Args:
            dist (str): type of the distribution.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            `dist_spec_args` of distribution b must be passed to the function through `args` or `kwargs`.
            Passing in `dist_spec_args` of distribution a is optional.

        Return:
            Tensor, the cross_entropy of two distributions.
        """
        return self._call_cross_entropy(dist, *args, **kwargs)

    def _calc_cross_entropy(self, dist, *args, **kwargs):
        r"""
        Evaluate cross_entropy from entropy and kl divergence.

        .. math::
            H(X, Y) = H(X) + KL(X||Y)
        """
        return self._entropy(*args, **kwargs) + self._kl_loss(dist, *args, **kwargs)

    def _sample(self, *args, **kwargs):
        """
        Default hook for `sample`; delegates to `raise_not_implemented_util`.
        """
        return raise_not_implemented_util('sample', self.name, *args, **kwargs)

    def sample(self, *args, **kwargs):
        """
        Sampling function.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            `args` or `kwargs`.

        Return:
            Tensor, the sample generated from the distribution.
        """
        return self._sample(*args, **kwargs)

    def construct(self, name, *args, **kwargs):
        """
        Override `construct` in Cell.

        Note:
            Names of supported functions include:
            'prob', 'log_prob', 'cdf', 'log_cdf', 'survival_function', 'log_survival',
            'var', 'sd', 'mode', 'mean', 'entropy', 'kl_loss', 'cross_entropy', 'sample',
            'get_dist_args', and 'get_dist_type'.

        Args:
            name (str): The name of the function.
            *args (list): A list of positional arguments that the function needs.
            **kwargs (dict): A dictionary of keyword arguments that the function needs.

        Return:
            Tensor, the value of corresponding computation method.
        """

        # Dispatch on the function name; any other name falls through to
        # `raise_not_implemented_util` at the end.
        if name == 'log_prob':
            return self._call_log_prob(*args, **kwargs)
        if name == 'prob':
            return self._call_prob(*args, **kwargs)
        if name == 'cdf':
            return self._call_cdf(*args, **kwargs)
        if name == 'log_cdf':
            return self._call_log_cdf(*args, **kwargs)
        if name == 'survival_function':
            return self._call_survival(*args, **kwargs)
        if name == 'log_survival':
            return self._call_log_survival(*args, **kwargs)
        if name == 'kl_loss':
            return self._kl_loss(*args, **kwargs)
        if name == 'mean':
            return self._mean(*args, **kwargs)
        if name == 'mode':
            return self._mode(*args, **kwargs)
        if name == 'sd':
            return self._call_sd(*args, **kwargs)
        if name == 'var':
            return self._call_var(*args, **kwargs)
        if name == 'entropy':
            return self._entropy(*args, **kwargs)
        if name == 'cross_entropy':
            return self._call_cross_entropy(*args, **kwargs)
        if name == 'sample':
            return self._sample(*args, **kwargs)
        if name == 'get_dist_args':
            return self._get_dist_args(*args, **kwargs)
        if name == 'get_dist_type':
            return self._get_dist_type()
        return raise_not_implemented_util(name, self.name, *args, **kwargs)