# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Base class of all probability distributions."""
from mindspore import context
from mindspore.ops import operations as P
from mindspore.nn.cell import Cell
from mindspore._checkparam import Validator as validator
from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\
    raise_not_implemented_util
from ._utils.utils import CheckTuple, CheckTensor
from ._utils.custom_ops import broadcast_to, exp_generic, log_generic


class Distribution(Cell):
    """
    Base class for all mathematical distributions.

    Args:
        seed (int): The seed used in sampling. 0 is used if it is None.
        dtype (mindspore.dtype): The type of the event samples.
        name (str): The name of the distribution.
        param (dict): The parameters used to initialize the distribution.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Note:
        Derived classes must override operations such as `_mean`, `_prob`,
        and `_log_prob`. Required arguments, such as `value` for `_prob`,
        must be passed in through `args` or `kwargs`. `dist_spec_args`, which specify
        a new distribution, are optional.

        `dist_spec_args` is unique for each type of distribution. For example, `mean` and `sd`
        are the `dist_spec_args` for a Normal distribution, while `rate` is the `dist_spec_args`
        for an Exponential distribution.

        For all functions, passing in `dist_spec_args` is optional.
        Function calls with the additional `dist_spec_args` passed in will evaluate the result with
        a new distribution specified by the `dist_spec_args`. However, it will not change the original distribution.
    """

    def __init__(self,
                 seed,
                 dtype,
                 name,
                 param):
        """
        Constructor of distribution class.

        Validates `name` and `seed`, records the public entries of `param`, and binds every
        `_call_*` dispatcher either to the derived class's override or to a fallback computed
        from a related function (e.g. `prob` from `_log_prob`).

        Args:
            seed (int): Sampling seed. None is treated as 0; must be a non-negative int.
            dtype (mindspore.dtype): The type of the event samples.
            name (str): The name of the distribution.
            param (dict): Initialization parameters. Entries whose key is 'self' or starts
                with '_' are dropped (presumably because derived classes pass their constructor
                locals -- confirm against subclasses).
        """
        super(Distribution, self).__init__()
        if seed is None:
            seed = 0
        validator.check_value_type('name', name, [str], type(self).__name__)
        validator.check_non_negative_int(seed, 'seed', name)

        self._name = name
        self._seed = seed
        self._dtype = cast_type_for_device(dtype)
        self._parameters = {}

        # parsing parameters: keep only public entries of `param`
        for k in param.keys():
            if not(k == 'self' or k.startswith('_')):
                self._parameters[k] = param[k]

        # if not a transformed distribution, set the following attribute
        if 'distribution' not in self.parameters.keys():
            self.parameter_type = set_param_type(
                self.parameters['param_dict'], dtype)
            self._batch_shape = self._calc_batch_shape()
            self._is_scalar_batch = self._check_is_scalar_batch()
            self._broadcast_shape = self._batch_shape

        # set the function to call according to the derived class's attributes.
        # The order matters: the survival / log_cdf / log_survival setters rely on
        # `_call_cdf` and `_call_survival` having been bound already.
        self._set_prob()
        self._set_log_prob()
        self._set_sd()
        self._set_var()
        self._set_cdf()
        self._set_survival()
        self._set_log_cdf()
        self._set_log_survival()
        self._set_cross_entropy()

        self.context_mode = context.get_context('mode')
        self.device_target = context.get_context('device_target')
        self.checktuple = CheckTuple()
        self.checktensor = CheckTensor()
        self.broadcast = broadcast_to

        # ops needed for the base class
        self.cast_base = P.Cast()
        self.dtype_base = P.DType()
        self.exp_base = exp_generic
        self.fill_base = P.Fill()
        self.log_base = log_generic
        self.sametypeshape_base = P.SameTypeShape()
        self.sq_base = P.Square()
        self.sqrt_base = P.Sqrt()
        self.shape_base = P.Shape()

    @property
    def name(self):
        """The name of the distribution."""
        return self._name

    @property
    def dtype(self):
        """The event dtype, as adjusted by `cast_type_for_device`."""
        return self._dtype

    @property
    def seed(self):
        """The sampling seed (0 if None was given)."""
        return self._seed

    @property
    def parameters(self):
        """The public initialization parameters recorded by the constructor."""
        return self._parameters

    @property
    def is_scalar_batch(self):
        """Whether every non-None initialization parameter is a Python int or float."""
        return self._is_scalar_batch

    @property
    def batch_shape(self):
        """Broadcast shape of the initialization parameters, or None if any of them is None."""
        return self._batch_shape

    @property
    def broadcast_shape(self):
        """The broadcast shape; initialized to `batch_shape`."""
        return self._broadcast_shape

    def _reset_parameters(self):
        """Clear the registered default parameters and their names."""
        self.default_parameters = []
        self.parameter_names = []

    def _add_parameter(self, value, name):
        """
        Cast `value` to a tensor and add it to `self.default_parameters`.
        Add `name` to `self.parameter_names`.

        Args:
            value: The default value of the parameter, or None.
            name (str): The name of the parameter.

        Returns:
            `value` cast to a tensor of `self.parameter_type`, or None if `value` is None.
        """
        # initialize the attributes if they do not exist yet
        if not hasattr(self, 'default_parameters'):
            self.default_parameters = []
            self.parameter_names = []
        # cast value to a tensor if it is not None
        value_t = None if value is None else cast_to_tensor(
            value, self.parameter_type)
        self.default_parameters += [value_t,]
        self.parameter_names += [name,]
        return value_t

    def _check_param_type(self, *args):
        """
        Check the availability and validity of default parameters and `dist_spec_args`.
        `dist_spec_args` passed in must be tensors. If default parameters of the distribution
        are None, the parameters must be passed in through `args`.

        Args:
            *args: One entry per registered parameter, in `_add_parameter` order. A None entry
                falls back to the corresponding default parameter.

        Returns:
            The single parameter cast to `self.parameter_type` when there is exactly one;
            otherwise a tuple of all parameters, cast and broadcast to a common shape.

        Raises:
            Whatever `raise_none_error` raises when both an argument and its default are None.
        """
        broadcast_shape = None
        broadcast_shape_tensor = None
        common_dtype = None
        out = []

        for arg, name, default in zip(args, self.parameter_names, self.default_parameters):
            # check if the argument is a Tensor
            if arg is not None:
                self.checktensor(arg, name)
            else:
                arg = default if default is not None else raise_none_error(
                    name)

            # broadcast if the number of args > 1
            if broadcast_shape is None:
                broadcast_shape = self.shape_base(arg)
                common_dtype = self.dtype_base(arg)
                broadcast_shape_tensor = self.fill_base(
                    common_dtype, broadcast_shape, 1.0)
            else:
                broadcast_shape = self.shape_base(arg + broadcast_shape_tensor)
                broadcast_shape_tensor = self.fill_base(
                    common_dtype, broadcast_shape, 1.0)
                arg = self.broadcast(arg, broadcast_shape_tensor)
                # check that this argument has the same dtype (and shape) as the previous ones
                self.sametypeshape_base(arg, broadcast_shape_tensor)

            arg = self.cast_base(arg, self.parameter_type)
            out.append(arg)

        if len(out) == 1:
            return out[0]

        # broadcast all args to broadcast_shape
        result = ()
        for arg in out:
            arg = self.broadcast(arg, broadcast_shape_tensor)
            result = result + (arg,)
        return result

    def _check_value(self, value, name):
        """
        Check availability of `value` as a Tensor.

        Returns:
            `value` itself when `self.context_mode == 0`, otherwise the result of
            `self.checktensor(value, name)`.
        """
        # NOTE(review): mode 0 is presumably context.GRAPH_MODE, where the checker's
        # return value is not used and `value` is returned directly -- confirm.
        if self.context_mode == 0:
            self.checktensor(value, name)
            return value
        return self.checktensor(value, name)

    def _check_is_scalar_batch(self):
        """
        Check if the parameters used during initialization are scalars.

        Returns:
            bool: True if every non-None entry of `param_dict` is an int or float.
        """
        param_dict = self.parameters['param_dict']
        for value in param_dict.values():
            if value is None:
                continue
            if not isinstance(value, (int, float)):
                return False
        return True

    def _calc_batch_shape(self):
        """
        Calculate the broadcast shape of the parameters used during initialization.

        Returns:
            The shape of the sum of all parameters (i.e. their broadcast shape), or None
            if any parameter is None.
        """
        broadcast_shape_tensor = None
        param_dict = self.parameters['param_dict']
        for value in param_dict.values():
            if value is None:
                return None
            if broadcast_shape_tensor is None:
                broadcast_shape_tensor = cast_to_tensor(value)
            else:
                value = cast_to_tensor(value)
                broadcast_shape_tensor = (value + broadcast_shape_tensor)
        return broadcast_shape_tensor.shape

    def _set_prob(self):
        """
        Set probability function based on the availability of `_prob` and `_log_prob`.
        """
        if hasattr(self, '_prob'):
            self._call_prob = self._prob
        elif hasattr(self, '_log_prob'):
            self._call_prob = self._calc_prob_from_log_prob
        else:
            self._call_prob = self._raise_not_implemented_error('prob')

    def _set_sd(self):
        """
        Set standard deviation based on the availability of `_sd` and `_var`.
        """
        if hasattr(self, '_sd'):
            self._call_sd = self._sd
        elif hasattr(self, '_var'):
            self._call_sd = self._calc_sd_from_var
        else:
            self._call_sd = self._raise_not_implemented_error('sd')

    def _set_var(self):
        """
        Set variance based on the availability of `_sd` and `_var`.
        """
        if hasattr(self, '_var'):
            self._call_var = self._var
        elif hasattr(self, '_sd'):
            self._call_var = self._calc_var_from_sd
        else:
            self._call_var = self._raise_not_implemented_error('var')

    def _set_log_prob(self):
        """
        Set log probability based on the availability of `_prob` and `_log_prob`.
        """
        if hasattr(self, '_log_prob'):
            self._call_log_prob = self._log_prob
        elif hasattr(self, '_prob'):
            self._call_log_prob = self._calc_log_prob_from_prob
        else:
            self._call_log_prob = self._raise_not_implemented_error('log_prob')

    def _set_cdf(self):
        """
        Set cumulative distribution function (cdf) based on the availability of `_cdf`,
        `_log_cdf`, `_survival_function` and `_log_survival`, in that order of preference.
        """
        if hasattr(self, '_cdf'):
            self._call_cdf = self._cdf
        elif hasattr(self, '_log_cdf'):
            self._call_cdf = self._calc_cdf_from_log_cdf
        elif hasattr(self, '_survival_function'):
            self._call_cdf = self._calc_cdf_from_survival
        elif hasattr(self, '_log_survival'):
            self._call_cdf = self._calc_cdf_from_log_survival
        else:
            self._call_cdf = self._raise_not_implemented_error('cdf')

    def _set_survival(self):
        """
        Set survival function based on the availability of `_survival_function`,
        `_log_survival` and `_call_cdf`.

        Must run after `_set_cdf`, since the last fallback uses `_call_cdf`.
        """
        if not (hasattr(self, '_survival_function') or hasattr(self, '_log_survival') or
                hasattr(self, '_cdf') or hasattr(self, '_log_cdf')):
            self._call_survival = self._raise_not_implemented_error(
                'survival_function')
        elif hasattr(self, '_survival_function'):
            self._call_survival = self._survival_function
        elif hasattr(self, '_log_survival'):
            self._call_survival = self._calc_survival_from_log_survival
        elif hasattr(self, '_call_cdf'):
            self._call_survival = self._calc_survival_from_call_cdf

    def _set_log_cdf(self):
        """
        Set log cdf based on the availability of `_log_cdf` and `_call_cdf`.

        Raises lazily (when called) if none of `_log_cdf`, `_cdf`, `_survival_function`
        or `_log_survival` is defined. Must run after `_set_cdf`.
        """
        if not (hasattr(self, '_log_cdf') or hasattr(self, '_cdf') or
                hasattr(self, '_survival_function') or hasattr(self, '_log_survival')):
            self._call_log_cdf = self._raise_not_implemented_error('log_cdf')
        elif hasattr(self, '_log_cdf'):
            self._call_log_cdf = self._log_cdf
        elif hasattr(self, '_call_cdf'):
            self._call_log_cdf = self._calc_log_cdf_from_call_cdf

    def _set_log_survival(self):
        """
        Set log survival based on the availability of `_log_survival` and `_call_survival`.

        Raises lazily (when called) if none of `_log_survival`, `_survival_function`,
        `_log_cdf` or `_cdf` is defined. Must run after `_set_survival`.
        """
        if not (hasattr(self, '_log_survival') or hasattr(self, '_survival_function') or
                hasattr(self, '_log_cdf') or hasattr(self, '_cdf')):
            # report the function the caller actually asked for
            self._call_log_survival = self._raise_not_implemented_error(
                'log_survival')
        elif hasattr(self, '_log_survival'):
            self._call_log_survival = self._log_survival
        elif hasattr(self, '_call_survival'):
            self._call_log_survival = self._calc_log_survival_from_call_survival

    def _set_cross_entropy(self):
        """
        Set cross entropy based on the availability of `_cross_entropy`.
        """
        if hasattr(self, '_cross_entropy'):
            self._call_cross_entropy = self._cross_entropy
        else:
            self._call_cross_entropy = self._raise_not_implemented_error(
                'cross_entropy')

    def _get_dist_args(self, *args, **kwargs):
        return raise_not_implemented_util('get_dist_args', self.name, *args, **kwargs)

    def get_dist_args(self, *args, **kwargs):
        """
        Check the availability and validity of default parameters and `dist_spec_args`.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            `dist_spec_args` must be passed in through list or dictionary. The order of `dist_spec_args`
            should follow the initialization order of default parameters through `_add_parameter`.
            If some `dist_spec_args` are None, the corresponding default parameters are returned.
        """
        return self._get_dist_args(*args, **kwargs)

    def _get_dist_type(self):
        return raise_not_implemented_util('get_dist_type', self.name)

    def get_dist_type(self):
        """
        Return the type of the distribution.
        """
        return self._get_dist_type()

    def _raise_not_implemented_error(self, func_name):
        """
        Build a placeholder callable for an unimplemented function.

        Args:
            func_name (str): Name of the unimplemented function, used in the error.

        Returns:
            A function that forwards to `raise_not_implemented_util(func_name, self.name, ...)`.
        """
        # capture the name now so the closure does not hold onto `self`
        name = self.name

        def raise_error(*args, **kwargs):
            return raise_not_implemented_util(func_name, name, *args, **kwargs)
        return raise_error

    def log_prob(self, value, *args, **kwargs):
        """
        Evaluate the log probability (pdf or pmf) at the given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its dist_spec_args through
            `args` or `kwargs`.
        """
        return self._call_log_prob(value, *args, **kwargs)

    def _calc_prob_from_log_prob(self, value, *args, **kwargs):
        r"""
        Evaluate prob from log probability.

        .. math::
            prob(x) = \exp(log\_prob(x))
        """
        return self.exp_base(self._log_prob(value, *args, **kwargs))

    def prob(self, value, *args, **kwargs):
        """
        Evaluate the probability (pdf or pmf) at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its dist_spec_args through
            `args` or `kwargs`.
        """
        return self._call_prob(value, *args, **kwargs)

    def _calc_log_prob_from_prob(self, value, *args, **kwargs):
        r"""
        Evaluate log probability from probability.

        .. math::
            log\_prob(x) = \log(prob(x))
        """
        return self.log_base(self._prob(value, *args, **kwargs))

    def cdf(self, value, *args, **kwargs):
        """
        Evaluate the cdf at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its dist_spec_args through
            `args` or `kwargs`.
        """
        return self._call_cdf(value, *args, **kwargs)

    def _calc_cdf_from_log_cdf(self, value, *args, **kwargs):
        r"""
        Evaluate cdf from log_cdf.

        .. math::
            cdf(x) = \exp(log\_cdf(x))
        """
        return self.exp_base(self._log_cdf(value, *args, **kwargs))

    def _calc_cdf_from_survival(self, value, *args, **kwargs):
        r"""
        Evaluate cdf from survival function.

        .. math::
            cdf(x) = 1 - survival\_function(x)
        """
        return 1.0 - self._survival_function(value, *args, **kwargs)

    def _calc_cdf_from_log_survival(self, value, *args, **kwargs):
        r"""
        Evaluate cdf from log survival function.

        .. math::
            cdf(x) = 1 - \exp(log\_survival(x))
        """
        return 1.0 - self.exp_base(self._log_survival(value, *args, **kwargs))

    def log_cdf(self, value, *args, **kwargs):
        """
        Evaluate the log cdf at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its dist_spec_args through
            `args` or `kwargs`.
        """
        return self._call_log_cdf(value, *args, **kwargs)

    def _calc_log_cdf_from_call_cdf(self, value, *args, **kwargs):
        r"""
        Evaluate log cdf from cdf.

        .. math::
            log\_cdf(x) = \log(cdf(x))
        """
        return self.log_base(self._call_cdf(value, *args, **kwargs))

    def survival_function(self, value, *args, **kwargs):
        """
        Evaluate the survival function at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its dist_spec_args through
            `args` or `kwargs`.
        """
        return self._call_survival(value, *args, **kwargs)

    def _calc_survival_from_call_cdf(self, value, *args, **kwargs):
        r"""
        Evaluate survival function from cdf.

        .. math::
            survival\_function(x) = 1 - cdf(x)
        """
        return 1.0 - self._call_cdf(value, *args, **kwargs)

    def _calc_survival_from_log_survival(self, value, *args, **kwargs):
        r"""
        Evaluate survival function from log survival function.

        .. math::
            survival\_function(x) = \exp(log\_survival(x))
        """
        return self.exp_base(self._log_survival(value, *args, **kwargs))

    def log_survival(self, value, *args, **kwargs):
        """
        Evaluate the log survival function at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its dist_spec_args through
            `args` or `kwargs`.
        """
        return self._call_log_survival(value, *args, **kwargs)

    def _calc_log_survival_from_call_survival(self, value, *args, **kwargs):
        r"""
        Evaluate log survival function from survival function.

        .. math::
            log\_survival(x) = \log(survival\_function(x))
        """
        return self.log_base(self._call_survival(value, *args, **kwargs))

    def _kl_loss(self, *args, **kwargs):
        return raise_not_implemented_util('kl_loss', self.name, *args, **kwargs)

    def kl_loss(self, dist, *args, **kwargs):
        """
        Evaluate the KL divergence, i.e. KL(a||b).

        Args:
            dist (str): type of the distribution.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            dist_spec_args of distribution b must be passed to the function through `args` or `kwargs`.
            Passing in dist_spec_args of distribution a is optional.
        """
        return self._kl_loss(dist, *args, **kwargs)

    def _mean(self, *args, **kwargs):
        return raise_not_implemented_util('mean', self.name, *args, **kwargs)

    def mean(self, *args, **kwargs):
        """
        Evaluate the mean.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            *args* or *kwargs*.
        """
        return self._mean(*args, **kwargs)

    def _mode(self, *args, **kwargs):
        return raise_not_implemented_util('mode', self.name, *args, **kwargs)

    def mode(self, *args, **kwargs):
        """
        Evaluate the mode.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            *args* or *kwargs*.
        """
        return self._mode(*args, **kwargs)

    def sd(self, *args, **kwargs):
        """
        Evaluate the standard deviation.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            *args* or *kwargs*.
        """
        return self._call_sd(*args, **kwargs)

    def var(self, *args, **kwargs):
        """
        Evaluate the variance.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            *args* or *kwargs*.
        """
        return self._call_var(*args, **kwargs)

    def _calc_sd_from_var(self, *args, **kwargs):
        r"""
        Evaluate standard deviation from variance.

        .. math::
            STD(x) = \sqrt{VAR(x)}
        """
        return self.sqrt_base(self._var(*args, **kwargs))

    def _calc_var_from_sd(self, *args, **kwargs):
        r"""
        Evaluate variance from standard deviation.

        .. math::
            VAR(x) = STD(x) ^ 2
        """
        return self.sq_base(self._sd(*args, **kwargs))

    def _entropy(self, *args, **kwargs):
        return raise_not_implemented_util('entropy', self.name, *args, **kwargs)

    def entropy(self, *args, **kwargs):
        """
        Evaluate the entropy.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            *args* or *kwargs*.
        """
        return self._entropy(*args, **kwargs)

    def cross_entropy(self, dist, *args, **kwargs):
        """
        Evaluate the cross_entropy between distribution a and b.

        Args:
            dist (str): type of the distribution.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            dist_spec_args of distribution b must be passed to the function through `args` or `kwargs`.
            Passing in dist_spec_args of distribution a is optional.
        """
        return self._call_cross_entropy(dist, *args, **kwargs)

    def _calc_cross_entropy(self, dist, *args, **kwargs):
        r"""
        Evaluate cross_entropy from entropy and kl divergence.

        .. math::
            H(X, Y) = H(X) + KL(X||Y)
        """
        # NOTE(review): `_set_cross_entropy` never binds this helper, and it forwards the
        # same `args` (which start with distribution b's parameters) to `_entropy`, which
        # presumably expects distribution a's parameters -- confirm before wiring it up.
        return self._entropy(*args, **kwargs) + self._kl_loss(dist, *args, **kwargs)

    def _sample(self, *args, **kwargs):
        return raise_not_implemented_util('sample', self.name, *args, **kwargs)

    def sample(self, *args, **kwargs):
        """
        Sampling function.

        Args:
            shape (tuple): shape of the sample.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            *args* or *kwargs*.
        """
        return self._sample(*args, **kwargs)

    def construct(self, name, *args, **kwargs):
        """
        Override `construct` in Cell.

        Note:
            Names of supported functions include:
            'prob', 'log_prob', 'cdf', 'log_cdf', 'survival_function', 'log_survival',
            'var', 'sd', 'mode', 'mean', 'entropy', 'kl_loss', 'cross_entropy', 'sample',
            'get_dist_args', and 'get_dist_type'.

        Args:
            name (str): The name of the function.
            *args (list): A list of positional arguments that the function needs.
            **kwargs (dict): A dictionary of keyword arguments that the function needs.
        """
        # An explicit if-chain (rather than a dict dispatch) is kept deliberately:
        # `construct` may be compiled in graph mode.
        if name == 'log_prob':
            return self._call_log_prob(*args, **kwargs)
        if name == 'prob':
            return self._call_prob(*args, **kwargs)
        if name == 'cdf':
            return self._call_cdf(*args, **kwargs)
        if name == 'log_cdf':
            return self._call_log_cdf(*args, **kwargs)
        if name == 'survival_function':
            return self._call_survival(*args, **kwargs)
        if name == 'log_survival':
            return self._call_log_survival(*args, **kwargs)
        if name == 'kl_loss':
            return self._kl_loss(*args, **kwargs)
        if name == 'mean':
            return self._mean(*args, **kwargs)
        if name == 'mode':
            return self._mode(*args, **kwargs)
        if name == 'sd':
            return self._call_sd(*args, **kwargs)
        if name == 'var':
            return self._call_var(*args, **kwargs)
        if name == 'entropy':
            return self._entropy(*args, **kwargs)
        if name == 'cross_entropy':
            return self._call_cross_entropy(*args, **kwargs)
        if name == 'sample':
            return self._sample(*args, **kwargs)
        if name == 'get_dist_args':
            return self._get_dist_args(*args, **kwargs)
        if name == 'get_dist_type':
            return self._get_dist_type()
        return raise_not_implemented_util(name, self.name, *args, **kwargs)