• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""basic"""
16from mindspore import context
17from mindspore.ops import operations as P
18from mindspore.nn.cell import Cell
19from mindspore._checkparam import Validator as validator
20from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\
21    raise_not_implemented_util
22from ._utils.utils import CheckTuple, CheckTensor
23from ._utils.custom_ops import broadcast_to, exp_generic, log_generic
24
25
class Distribution(Cell):
    """
    Base class for all mathematical distributions.

    Args:
        seed (int): The seed is used in sampling. 0 is used if it is None.
        dtype (mindspore.dtype): The type of the event samples.
        name (str): The name of the distribution.
        param (dict): The parameters used to initialize the distribution. Entries named
            `self` or starting with `_` are ignored. Unless the dict contains a
            `distribution` entry, it must contain a `param_dict` entry.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Note:
        Derived class must override operations such as `_mean`, `_prob`,
        and `_log_prob`. Required arguments, such as `value` for `_prob`,
        must be passed in through `args` or `kwargs`. `dist_spec_args`, which specify
        a new distribution, are optional.

        `dist_spec_args` is unique for each type of distribution. For example, `mean` and `sd`
        are the `dist_spec_args` for a Normal distribution, while `rate` is the `dist_spec_args`
        for an Exponential distribution.

        For all functions, passing in `dist_spec_args` is optional.
        Function calls with the additional `dist_spec_args` passed in will evaluate the result with
        a new distribution specified by the `dist_spec_args`. However, it will not change the original distribution.
    """

    def __init__(self,
                 seed,
                 dtype,
                 name,
                 param):
        """
        Constructor of distribution class.

        Validates `name` and `seed`, records the public entries of `param`, computes the
        batch-shape attributes (for non-transformed distributions), and binds every public
        evaluation method (`prob`, `cdf`, `sd`, ...) either to the subclass's own
        implementation or to a fallback derived from a related function.
        """
        super(Distribution, self).__init__()
        if seed is None:
            seed = 0
        validator.check_value_type('name', name, [str], type(self).__name__)
        validator.check_non_negative_int(seed, 'seed', name)

        self._name = name
        self._seed = seed
        self._dtype = cast_type_for_device(dtype)

        # Keep only the public constructor arguments: drop `self` and any
        # underscore-prefixed entries.
        self._parameters = {k: v for k, v in param.items()
                            if k != 'self' and not k.startswith('_')}

        # If not a transformed distribution (no `distribution` entry), set the
        # batch-shape attributes; these require a `param_dict` entry.
        if 'distribution' not in self._parameters:
            self.parameter_type = set_param_type(
                self.parameters['param_dict'], dtype)
            self._batch_shape = self._calc_batch_shape()
            self._is_scalar_batch = self._check_is_scalar_batch()
            self._broadcast_shape = self._batch_shape

        # Set the function to call according to the derived class's attributes.
        # Order matters: the survival / log-cdf / log-survival setters rely on
        # `_call_cdf` and `_call_survival` having been bound already.
        self._set_prob()
        self._set_log_prob()
        self._set_sd()
        self._set_var()
        self._set_cdf()
        self._set_survival()
        self._set_log_cdf()
        self._set_log_survival()
        self._set_cross_entropy()

        self.context_mode = context.get_context('mode')
        self.device_target = context.get_context('device_target')
        self.checktuple = CheckTuple()
        self.checktensor = CheckTensor()
        self.broadcast = broadcast_to

        # ops needed for the base class
        self.cast_base = P.Cast()
        self.dtype_base = P.DType()
        self.exp_base = exp_generic
        self.fill_base = P.Fill()
        self.log_base = log_generic
        self.sametypeshape_base = P.SameTypeShape()
        self.sq_base = P.Square()
        self.sqrt_base = P.Sqrt()
        self.shape_base = P.Shape()

    @property
    def name(self):
        """The name of the distribution."""
        return self._name

    @property
    def dtype(self):
        """The event dtype after `cast_type_for_device` was applied."""
        return self._dtype

    @property
    def seed(self):
        """The sampling seed (0 when None was given)."""
        return self._seed

    @property
    def parameters(self):
        """The public constructor arguments recorded at initialization."""
        return self._parameters

    @property
    def is_scalar_batch(self):
        """Whether every non-None initialization parameter is a Python int or float."""
        return self._is_scalar_batch

    @property
    def batch_shape(self):
        """Broadcast shape of the initialization parameters (None if any is None)."""
        return self._batch_shape

    @property
    def broadcast_shape(self):
        """The broadcast shape; initialized to `batch_shape`."""
        return self._broadcast_shape

    def _reset_parameters(self):
        """Clear the registered default parameters and their names."""
        self.default_parameters = []
        self.parameter_names = []

    def _add_parameter(self, value, name):
        """
        Cast `value` to a tensor and add it to `self.default_parameters`.
        Add `name` into `self.parameter_names`.

        Args:
            value: The default value; None is stored as None (no cast).
            name (str): The parameter name, appended in the same position.

        Returns:
            The tensor stored for `value`, or None.
        """
        # initialize the attributes if they do not exist yet
        if not hasattr(self, 'default_parameters'):
            self.default_parameters = []
            self.parameter_names = []
        # cast value to a tensor if it is not None
        value_t = None if value is None else cast_to_tensor(
            value, self.parameter_type)
        self.default_parameters += [value_t,]
        self.parameter_names += [name,]
        return value_t

    def _check_param_type(self, *args):
        """
        Check the availability and validity of default parameters and `dist_spec_args`.
        `dist_spec_args` passed in must be tensors. If default parameters of the distribution
        are None, the parameters must be passed in through `args`.

        Each resolved argument is cast to `self.parameter_type`. With a single parameter
        the tensor itself is returned; otherwise a tuple of all parameters broadcast to
        their common shape is returned.
        """
        broadcast_shape = None
        broadcast_shape_tensor = None
        common_dtype = None
        out = []

        for arg, name, default in zip(args, self.parameter_names, self.default_parameters):
            # check if the argument is a Tensor
            if arg is not None:
                self.checktensor(arg, name)
            else:
                arg = default if default is not None else raise_none_error(
                    name)

            # broadcast if the number of args > 1
            if broadcast_shape is None:
                broadcast_shape = self.shape_base(arg)
                common_dtype = self.dtype_base(arg)
                broadcast_shape_tensor = self.fill_base(
                    common_dtype, broadcast_shape, 1.0)
            else:
                # adding a ones-tensor yields the broadcast shape of both operands
                broadcast_shape = self.shape_base(arg + broadcast_shape_tensor)
                broadcast_shape_tensor = self.fill_base(
                    common_dtype, broadcast_shape, 1.0)
                arg = self.broadcast(arg, broadcast_shape_tensor)
                # check that the arguments share dtype (and, after broadcasting, shape)
                self.sametypeshape_base(arg, broadcast_shape_tensor)

            arg = self.cast_base(arg, self.parameter_type)
            out.append(arg)

        if len(out) == 1:
            return out[0]

        # broadcast all args to broadcast_shape
        result = ()
        for arg in out:
            arg = self.broadcast(arg, broadcast_shape_tensor)
            result = result + (arg,)
        return result

    def _check_value(self, value, name):
        """
        Check availability of `value` as a Tensor.

        In context mode 0 (MindSpore's GRAPH_MODE) the checker's result is discarded
        and `value` is returned unchanged; otherwise the checker's result is returned.
        """
        if self.context_mode == 0:
            self.checktensor(value, name)
            return value
        return self.checktensor(value, name)

    def _check_is_scalar_batch(self):
        """
        Check if the parameters used during initialization are scalars.

        None entries are skipped; any non-None entry that is not an int or float
        makes the batch non-scalar.
        """
        param_dict = self.parameters['param_dict']
        for value in param_dict.values():
            if value is None:
                continue
            if not isinstance(value, (int, float)):
                return False
        return True

    def _calc_batch_shape(self):
        """
        Calculate the broadcast shape of the parameters used during initialization.

        Returns:
            The broadcast shape, or None if any parameter is None or `param_dict`
            is empty.
        """
        broadcast_shape_tensor = None
        param_dict = self.parameters['param_dict']
        for value in param_dict.values():
            if value is None:
                return None
            if broadcast_shape_tensor is None:
                broadcast_shape_tensor = cast_to_tensor(value)
            else:
                value = cast_to_tensor(value)
                broadcast_shape_tensor = (value + broadcast_shape_tensor)
        # An empty param_dict leaves nothing to take a shape from.
        if broadcast_shape_tensor is None:
            return None
        return broadcast_shape_tensor.shape

    def _set_prob(self):
        """
        Set probability function based on the availability of `_prob` and `_log_prob`.
        """
        if hasattr(self, '_prob'):
            self._call_prob = self._prob
        elif hasattr(self, '_log_prob'):
            self._call_prob = self._calc_prob_from_log_prob
        else:
            self._call_prob = self._raise_not_implemented_error('prob')

    def _set_sd(self):
        """
        Set standard deviation based on the availability of `_sd` and `_var`.
        """
        if hasattr(self, '_sd'):
            self._call_sd = self._sd
        elif hasattr(self, '_var'):
            self._call_sd = self._calc_sd_from_var
        else:
            self._call_sd = self._raise_not_implemented_error('sd')

    def _set_var(self):
        """
        Set variance based on the availability of `_sd` and `_var`.
        """
        if hasattr(self, '_var'):
            self._call_var = self._var
        elif hasattr(self, '_sd'):
            self._call_var = self._calc_var_from_sd
        else:
            self._call_var = self._raise_not_implemented_error('var')

    def _set_log_prob(self):
        """
        Set log probability based on the availability of `_prob` and `_log_prob`.
        """
        if hasattr(self, '_log_prob'):
            self._call_log_prob = self._log_prob
        elif hasattr(self, '_prob'):
            self._call_log_prob = self._calc_log_prob_from_prob
        else:
            self._call_log_prob = self._raise_not_implemented_error('log_prob')

    def _set_cdf(self):
        """
        Set cumulative distribution function (cdf) based on the availability of `_cdf`,
        `_log_cdf`, `_survival_function` and `_log_survival`, in that order of preference.
        """
        if hasattr(self, '_cdf'):
            self._call_cdf = self._cdf
        elif hasattr(self, '_log_cdf'):
            self._call_cdf = self._calc_cdf_from_log_cdf
        elif hasattr(self, '_survival_function'):
            self._call_cdf = self._calc_cdf_from_survival
        elif hasattr(self, '_log_survival'):
            self._call_cdf = self._calc_cdf_from_log_survival
        else:
            self._call_cdf = self._raise_not_implemented_error('cdf')

    def _set_survival(self):
        """
        Set survival function based on the availability of `_survival_function`,
        `_log_survival` and `_call_cdf`.

        Must run after `_set_cdf`, since the last fallback derives the survival
        function from `_call_cdf`.
        """
        if not (hasattr(self, '_survival_function') or hasattr(self, '_log_survival') or
                hasattr(self, '_cdf') or hasattr(self, '_log_cdf')):
            self._call_survival = self._raise_not_implemented_error(
                'survival_function')
        elif hasattr(self, '_survival_function'):
            self._call_survival = self._survival_function
        elif hasattr(self, '_log_survival'):
            self._call_survival = self._calc_survival_from_log_survival
        elif hasattr(self, '_call_cdf'):
            self._call_survival = self._calc_survival_from_call_cdf

    def _set_log_cdf(self):
        """
        Set log cdf based on the availability of `_log_cdf` and `_call_cdf`.

        Must run after `_set_cdf`.
        """
        if not (hasattr(self, '_log_cdf') or hasattr(self, '_cdf') or
                hasattr(self, '_survival_function') or hasattr(self, '_log_survival')):
            self._call_log_cdf = self._raise_not_implemented_error('log_cdf')
        elif hasattr(self, '_log_cdf'):
            self._call_log_cdf = self._log_cdf
        elif hasattr(self, '_call_cdf'):
            self._call_log_cdf = self._calc_log_cdf_from_call_cdf

    def _set_log_survival(self):
        """
        Set log survival based on the availability of `_log_survival` and `_call_survival`.

        Must run after `_set_survival`.
        """
        if not (hasattr(self, '_log_survival') or hasattr(self, '_survival_function') or
                hasattr(self, '_log_cdf') or hasattr(self, '_cdf')):
            # Report the function that is actually missing (was 'log_cdf').
            self._call_log_survival = self._raise_not_implemented_error(
                'log_survival')
        elif hasattr(self, '_log_survival'):
            self._call_log_survival = self._log_survival
        elif hasattr(self, '_call_survival'):
            self._call_log_survival = self._calc_log_survival_from_call_survival

    def _set_cross_entropy(self):
        """
        Set cross entropy based on the availability of `_cross_entropy`.

        Note that `_calc_cross_entropy` is not used as a fallback here.
        """
        if hasattr(self, '_cross_entropy'):
            self._call_cross_entropy = self._cross_entropy
        else:
            self._call_cross_entropy = self._raise_not_implemented_error(
                'cross_entropy')

    def _get_dist_args(self, *args, **kwargs):
        return raise_not_implemented_util('get_dist_args', self.name, *args, **kwargs)

    def get_dist_args(self, *args, **kwargs):
        """
        Check the availability and validity of default parameters and `dist_spec_args`.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            `dist_spec_args` must be passed in through list or dictionary. The order of `dist_spec_args`
            should follow the initialization order of default parameters through `_add_parameter`.
            If some `dist_spec_args` is None, the corresponding default parameter is returned.
        """
        return self._get_dist_args(*args, **kwargs)

    def _get_dist_type(self):
        return raise_not_implemented_util('get_dist_type', self.name)

    def get_dist_type(self):
        """
        Return the type of the distribution.
        """
        return self._get_dist_type()

    def _raise_not_implemented_error(self, func_name):
        """
        Build a placeholder callable for an unsupported function.

        The returned callable forwards `func_name`, the distribution name and all
        call arguments to `raise_not_implemented_util`.
        """
        name = self.name

        def raise_error(*args, **kwargs):
            return raise_not_implemented_util(func_name, name, *args, **kwargs)
        return raise_error

    def log_prob(self, value, *args, **kwargs):
        """
        Evaluate the log probability (pdf or pmf) at the given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its dist_spec_args through
            `args` or `kwargs`.
        """
        return self._call_log_prob(value, *args, **kwargs)

    def _calc_prob_from_log_prob(self, value, *args, **kwargs):
        r"""
        Evaluate prob from log probability.

        .. math::
            probability(x) = \exp(log\_prob(x))
        """
        return self.exp_base(self._log_prob(value, *args, **kwargs))

    def prob(self, value, *args, **kwargs):
        """
        Evaluate the probability (pdf or pmf) at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its dist_spec_args through
            `args` or `kwargs`.
        """
        return self._call_prob(value, *args, **kwargs)

    def _calc_log_prob_from_prob(self, value, *args, **kwargs):
        r"""
        Evaluate log probability from probability.

        .. math::
            log\_prob(x) = \log(prob(x))
        """
        return self.log_base(self._prob(value, *args, **kwargs))

    def cdf(self, value, *args, **kwargs):
        """
        Evaluate the cdf at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its dist_spec_args through
            `args` or `kwargs`.
        """
        return self._call_cdf(value, *args, **kwargs)

    def _calc_cdf_from_log_cdf(self, value, *args, **kwargs):
        r"""
        Evaluate cdf from log_cdf.

        .. math::
            cdf(x) = \exp(log\_cdf(x))
        """
        return self.exp_base(self._log_cdf(value, *args, **kwargs))

    def _calc_cdf_from_survival(self, value, *args, **kwargs):
        r"""
        Evaluate cdf from survival function.

        .. math::
            cdf(x) = 1 - survival\_function(x)
        """
        return 1.0 - self._survival_function(value, *args, **kwargs)

    def _calc_cdf_from_log_survival(self, value, *args, **kwargs):
        r"""
        Evaluate cdf from log survival function.

        .. math::
            cdf(x) = 1 - \exp(log\_survival(x))
        """
        return 1.0 - self.exp_base(self._log_survival(value, *args, **kwargs))

    def log_cdf(self, value, *args, **kwargs):
        """
        Evaluate the log cdf at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its dist_spec_args through
            `args` or `kwargs`.
        """
        return self._call_log_cdf(value, *args, **kwargs)

    def _calc_log_cdf_from_call_cdf(self, value, *args, **kwargs):
        r"""
        Evaluate log cdf from cdf.

        .. math::
            log\_cdf(x) = \log(cdf(x))
        """
        return self.log_base(self._call_cdf(value, *args, **kwargs))

    def survival_function(self, value, *args, **kwargs):
        """
        Evaluate the survival function at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its dist_spec_args through
            `args` or `kwargs`.
        """
        return self._call_survival(value, *args, **kwargs)

    def _calc_survival_from_call_cdf(self, value, *args, **kwargs):
        r"""
        Evaluate survival function from cdf.

        .. math::
            survival\_function(x) = 1 - cdf(x)
        """
        return 1.0 - self._call_cdf(value, *args, **kwargs)

    def _calc_survival_from_log_survival(self, value, *args, **kwargs):
        r"""
        Evaluate survival function from log survival function.

        .. math::
            survival\_function(x) = \exp(log\_survival(x))
        """
        return self.exp_base(self._log_survival(value, *args, **kwargs))

    def log_survival(self, value, *args, **kwargs):
        """
        Evaluate the log survival function at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its dist_spec_args through
            `args` or `kwargs`.
        """
        return self._call_log_survival(value, *args, **kwargs)

    def _calc_log_survival_from_call_survival(self, value, *args, **kwargs):
        r"""
        Evaluate log survival function from survival function.

        .. math::
            log\_survival(x) = \log(survival\_function(x))
        """
        return self.log_base(self._call_survival(value, *args, **kwargs))

    def _kl_loss(self, *args, **kwargs):
        return raise_not_implemented_util('kl_loss', self.name, *args, **kwargs)

    def kl_loss(self, dist, *args, **kwargs):
        """
        Evaluate the KL divergence, i.e. KL(a||b).

        Args:
            dist (str): type of the distribution.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            dist_spec_args of distribution b must be passed to the function through `args` or `kwargs`.
            Passing in dist_spec_args of distribution a is optional.
        """
        return self._kl_loss(dist, *args, **kwargs)

    def _mean(self, *args, **kwargs):
        return raise_not_implemented_util('mean', self.name, *args, **kwargs)

    def mean(self, *args, **kwargs):
        """
        Evaluate the mean.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            *args* or *kwargs*.
        """
        return self._mean(*args, **kwargs)

    def _mode(self, *args, **kwargs):
        return raise_not_implemented_util('mode', self.name, *args, **kwargs)

    def mode(self, *args, **kwargs):
        """
        Evaluate the mode.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            *args* or *kwargs*.
        """
        return self._mode(*args, **kwargs)

    def sd(self, *args, **kwargs):
        """
        Evaluate the standard deviation.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            *args* or *kwargs*.
        """
        return self._call_sd(*args, **kwargs)

    def var(self, *args, **kwargs):
        """
        Evaluate the variance.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            *args* or *kwargs*.
        """
        return self._call_var(*args, **kwargs)

    def _calc_sd_from_var(self, *args, **kwargs):
        r"""
        Evaluate standard deviation from variance.

        .. math::
            STD(x) = \sqrt{VAR(x)}
        """
        return self.sqrt_base(self._var(*args, **kwargs))

    def _calc_var_from_sd(self, *args, **kwargs):
        r"""
        Evaluate variance from standard deviation.

        .. math::
            VAR(x) = STD(x)^2
        """
        return self.sq_base(self._sd(*args, **kwargs))

    def _entropy(self, *args, **kwargs):
        return raise_not_implemented_util('entropy', self.name, *args, **kwargs)

    def entropy(self, *args, **kwargs):
        """
        Evaluate the entropy.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            *args* or *kwargs*.
        """
        return self._entropy(*args, **kwargs)

    def cross_entropy(self, dist, *args, **kwargs):
        """
        Evaluate the cross_entropy between distribution a and b.

        Args:
            dist (str): type of the distribution.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            dist_spec_args of distribution b must be passed to the function through `args` or `kwargs`.
            Passing in dist_spec_args of distribution a is optional.
        """
        return self._call_cross_entropy(dist, *args, **kwargs)

    def _calc_cross_entropy(self, dist, *args, **kwargs):
        r"""
        Evaluate cross_entropy from entropy and kl divergence.

        .. math::
            H(X, Y) = H(X) + KL(X||Y)
        """
        return self._entropy(*args, **kwargs) + self._kl_loss(dist, *args, **kwargs)

    def _sample(self, *args, **kwargs):
        return raise_not_implemented_util('sample', self.name, *args, **kwargs)

    def sample(self, *args, **kwargs):
        """
        Sampling function.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses,
                e.g. the `shape` (tuple) of the sample.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            *args* or *kwargs*.
        """
        return self._sample(*args, **kwargs)

    def construct(self, name, *args, **kwargs):
        """
        Override `construct` in Cell.

        Note:
            Names of supported functions include:
            'prob', 'log_prob', 'cdf', 'log_cdf', 'survival_function', 'log_survival',
            'var', 'sd', 'mode', 'mean', 'entropy', 'kl_loss', 'cross_entropy', 'sample',
            'get_dist_args', and 'get_dist_type'. Any other name is forwarded to
            `raise_not_implemented_util`.

        Args:
            name (str): The name of the function.
            *args (list): A list of positional arguments that the function needs.
            **kwargs (dict): A dictionary of keyword arguments that the function needs.
        """
        # A plain if-chain on the constant `name` is kept (rather than a dispatch
        # dict) because this method is compiled by MindSpore's graph mode.
        if name == 'log_prob':
            return self._call_log_prob(*args, **kwargs)
        if name == 'prob':
            return self._call_prob(*args, **kwargs)
        if name == 'cdf':
            return self._call_cdf(*args, **kwargs)
        if name == 'log_cdf':
            return self._call_log_cdf(*args, **kwargs)
        if name == 'survival_function':
            return self._call_survival(*args, **kwargs)
        if name == 'log_survival':
            return self._call_log_survival(*args, **kwargs)
        if name == 'kl_loss':
            return self._kl_loss(*args, **kwargs)
        if name == 'mean':
            return self._mean(*args, **kwargs)
        if name == 'mode':
            return self._mode(*args, **kwargs)
        if name == 'sd':
            return self._call_sd(*args, **kwargs)
        if name == 'var':
            return self._call_var(*args, **kwargs)
        if name == 'entropy':
            return self._entropy(*args, **kwargs)
        if name == 'cross_entropy':
            return self._call_cross_entropy(*args, **kwargs)
        if name == 'sample':
            return self._sample(*args, **kwargs)
        if name == 'get_dist_args':
            return self._get_dist_args(*args, **kwargs)
        if name == 'get_dist_type':
            return self._get_dist_type()
        return raise_not_implemented_util(name, self.name, *args, **kwargs)