• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""basic"""
16from mindspore import context
17from mindspore.ops import operations as P
18from mindspore.ops import functional as F
19from mindspore.nn.cell import Cell
20from mindspore.ops.primitive import constexpr
21from mindspore.ops.operations import _inner_ops as inner
22from mindspore import _checkparam as validator
23from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\
24    raise_not_implemented_util
25from ._utils.utils import CheckTuple, CheckTensor
26from ._utils.custom_ops import broadcast_to, exp_generic, log_generic
27
28
class Distribution(Cell):
    """
    Base class for all mathematical distributions.

    Args:
        seed (int): The seed is used in sampling. 0 is used if it is None.
        dtype (mindspore.dtype): The type of the event samples.
        name (str): The name of the distribution.
        param (dict): The parameters used to initialize the distribution.

    Note:
        Derived class must override operations such as `_mean`, `_prob`,
        and `_log_prob`. Required arguments, such as `value` for `_prob`,
        must be passed in through `args` or `kwargs`. `dist_spec_args`, which specify
        a new distribution, are optional.

        `dist_spec_args` is unique for each type of distribution. For example, `mean` and `sd`
        are the `dist_spec_args` for a Normal distribution, while `rate` is the `dist_spec_args`
        for an Exponential distribution.

        For all functions, passing in `dist_spec_args` is optional.
        Function calls with the additional `dist_spec_args` passed in will evaluate the result with
        a new distribution specified by the `dist_spec_args`. However, it will not change the original distribution.

    Supported Platforms:
        ``Ascend`` ``GPU``
    """

    def __init__(self,
                 seed,
                 dtype,
                 name,
                 param):
        """
        Constructor of distribution class.

        Validates `name` and `seed`, stores the public entries of `param` in
        `self.parameters`, computes batch information for non-transformed
        distributions, binds the `_call_*` dispatch functions according to the
        methods the derived class defines, and creates the operators shared by
        all derived classes.

        Args:
            seed (int): Sampling seed; None is replaced by 0.
            dtype (mindspore.dtype): Type of the event samples.
            name (str): Name of the distribution.
            param (dict): Initialization parameters, typically the derived
                constructor's locals; `self` and underscore-prefixed keys are skipped.
        """
        super(Distribution, self).__init__()
        if seed is None:
            seed = 0
        validator.check_value_type('name', name, [str], type(self).__name__)
        validator.check_non_negative_int(seed, 'seed', name)

        self._name = name
        self._seed = seed
        self._dtype = cast_type_for_device(dtype)
        self._parameters = {}
        self.default_parameters = []
        self.parameter_names = []

        # parsing parameters: keep only public keys, dropping `self` and private entries
        for k in param.keys():
            if not(k == 'self' or k.startswith('_')):
                self._parameters[k] = param[k]

        # if not a transformed distribution, set the following attribute
        if 'distribution' not in self.parameters.keys():
            self.parameter_type = set_param_type(
                self.parameters.get('param_dict', {}), dtype)
            self._batch_shape = self._calc_batch_shape()
            self._is_scalar_batch = self._check_is_scalar_batch()
            self._broadcast_shape = self._batch_shape

        # set the function to call according to the derived class's attributes.
        # Order matters: the survival/log setters fall back on `_call_cdf` and
        # `_call_survival`, which must already be bound.
        self._set_prob()
        self._set_log_prob()
        self._set_sd()
        self._set_var()
        self._set_cdf()
        self._set_survival()
        self._set_log_cdf()
        self._set_log_survival()
        self._set_cross_entropy()

        self.context_mode = context.get_context('mode')
        self.device_target = context.get_context('device_target')
        self.checktuple = CheckTuple()

        @constexpr(check=False)
        def _check_tensor(x, name):
            CheckTensor()(x, name)
            return x
        # we use constexpr to force CheckTensor to run only once in pynative mode;
        # context mode 0 (graph mode) uses the CheckTensor cell directly
        self.checktensor = CheckTensor() if self.context_mode == 0 else _check_tensor
        self.broadcast = broadcast_to

        # ops needed for the base class
        self.cast_base = P.Cast()
        self.dtype_base = P.DType()
        self.sametypeshape_base = inner.SameTypeShape()
        self.sq_base = P.Square()
        self.sqrt_base = P.Sqrt()
        self.shape_base = P.Shape()
        if self.device_target != "Ascend":
            self.log_base = P.Log()
            self.exp_base = P.Exp()
        else:
            # on Ascend the generic implementations from ._utils.custom_ops are used instead
            self.exp_base = exp_generic
            self.log_base = log_generic

    @property
    def name(self):
        """str: The name of the distribution."""
        return self._name

    @property
    def dtype(self):
        """The event dtype after `cast_type_for_device` was applied."""
        return self._dtype

    @property
    def seed(self):
        """int: The sampling seed (0 if None was given)."""
        return self._seed

    @property
    def parameters(self):
        """dict: Public initialization parameters captured in `__init__`."""
        return self._parameters

    @property
    def is_scalar_batch(self):
        """bool: Whether every non-None initialization parameter is a Python int/float."""
        return self._is_scalar_batch

    @property
    def batch_shape(self):
        """Broadcast shape of the initialization parameters, or None if unknown."""
        return self._batch_shape

    @property
    def broadcast_shape(self):
        """Broadcast shape used by the distribution (initialized to `batch_shape`)."""
        return self._broadcast_shape

    def _reset_parameters(self):
        """Clear the registered default parameters and their names."""
        self.default_parameters = []
        self.parameter_names = []

    def _add_parameter(self, value, name):
        """
        Register a default parameter.

        `value` is cast to a tensor of `self.parameter_type` (unless it is None)
        and appended to `self.default_parameters`; `name` is appended to
        `self.parameter_names`.

        Returns:
            The cast tensor, or None if `value` is None.
        """
        # initialize the attributes if they do not exist yet
        if not hasattr(self, 'default_parameters'):
            self.default_parameters = []
            self.parameter_names = []
        # cast value to a tensor if it is not None
        value_t = None if value is None else cast_to_tensor(value, self.parameter_type)
        self.default_parameters.append(value_t)
        self.parameter_names.append(name)
        return value_t

    def _check_param_type(self, *args):
        """
        Check the availability and validity of default parameters and `dist_spec_args`.

        `dist_spec_args` passed in must be tensors. If default parameters of the distribution
        are None, the parameters must be passed in through `args`. Arguments are matched
        positionally against `self.parameter_names` / `self.default_parameters`.

        Returns:
            A single tensor when there is one parameter; otherwise a tuple of tensors
            broadcast to their common shape. All results are cast to `self.parameter_type`.
        """
        broadcast_shape = None
        broadcast_shape_tensor = None
        common_dtype = None
        out = []

        for arg, name, default in zip(args, self.parameter_names, self.default_parameters):
            # check if the argument is a Tensor
            if arg is not None:
                self.checktensor(arg, name)
            else:
                # fall back on the default; raise_none_error reports a missing parameter
                arg = default if default is not None else raise_none_error(name)

            # broadcast if the number of args > 1
            if broadcast_shape is None:
                broadcast_shape = self.shape_base(arg)
                common_dtype = self.dtype_base(arg)
                broadcast_shape_tensor = F.fill(
                    common_dtype, broadcast_shape, 1.0)
            else:
                # adding to a ones-tensor yields the running broadcast shape
                broadcast_shape = self.shape_base(arg + broadcast_shape_tensor)
                broadcast_shape_tensor = F.fill(
                    common_dtype, broadcast_shape, 1.0)
                arg = self.broadcast(arg, broadcast_shape_tensor)
                # check if the arguments have the same dtype
                self.sametypeshape_base(arg, broadcast_shape_tensor)

            arg = self.cast_base(arg, self.parameter_type)
            out.append(arg)

        if len(out) == 1:
            return out[0]

        # broadcast all args to broadcast_shape
        result = ()
        for arg in out:
            arg = self.broadcast(arg, broadcast_shape_tensor)
            result = result + (arg,)
        return result

    def _check_value(self, value, name):
        """
        Check availability of `value` as a Tensor.

        Returns:
            `value`, unchanged.
        """
        self.checktensor(value, name)
        return value

    def _check_is_scalar_batch(self):
        """
        Check if the parameters used during initialization are scalars.

        Returns:
            True if every non-None value in `param_dict` is a Python int or float
            (including when `param_dict` is empty or absent), False otherwise.
        """
        # default to {} to match the lookup in __init__; a missing param_dict
        # previously crashed on None.values()
        param_dict = self.parameters.get('param_dict', {})
        for value in param_dict.values():
            if value is None:
                continue
            if not isinstance(value, (int, float)):
                return False
        return True

    def _calc_batch_shape(self):
        """
        Calculate the broadcast shape of the parameters used during initialization.

        Returns:
            The shape obtained by broadcasting every value in `param_dict` together,
            or None if any value is None or there are no parameters at all.
        """
        broadcast_shape_tensor = None
        param_dict = self.parameters.get('param_dict', {})
        for value in param_dict.values():
            if value is None:
                # parameter left for call time: batch shape cannot be known yet
                return None
            if broadcast_shape_tensor is None:
                broadcast_shape_tensor = cast_to_tensor(value)
            else:
                value = cast_to_tensor(value)
                # element-wise addition broadcasts the two shapes together
                broadcast_shape_tensor = (value + broadcast_shape_tensor)
        if broadcast_shape_tensor is None:
            # no parameters to broadcast; avoid dereferencing None.shape
            return None
        return broadcast_shape_tensor.shape

    def _set_prob(self):
        """
        Set probability function based on the availability of `_prob` and `_log_prob`.
        """
        if hasattr(self, '_prob'):
            self._call_prob = self._prob
        elif hasattr(self, '_log_prob'):
            self._call_prob = self._calc_prob_from_log_prob
        else:
            self._call_prob = self._raise_not_implemented_error('prob')

    def _set_sd(self):
        """
        Set standard deviation based on the availability of `_sd` and `_var`.
        """
        if hasattr(self, '_sd'):
            self._call_sd = self._sd
        elif hasattr(self, '_var'):
            self._call_sd = self._calc_sd_from_var
        else:
            self._call_sd = self._raise_not_implemented_error('sd')

    def _set_var(self):
        """
        Set variance based on the availability of `_sd` and `_var`.
        """
        if hasattr(self, '_var'):
            self._call_var = self._var
        elif hasattr(self, '_sd'):
            self._call_var = self._calc_var_from_sd
        else:
            self._call_var = self._raise_not_implemented_error('var')

    def _set_log_prob(self):
        """
        Set log probability based on the availability of `_prob` and `_log_prob`.
        """
        if hasattr(self, '_log_prob'):
            self._call_log_prob = self._log_prob
        elif hasattr(self, '_prob'):
            self._call_log_prob = self._calc_log_prob_from_prob
        else:
            self._call_log_prob = self._raise_not_implemented_error('log_prob')

    def _set_cdf(self):
        """
        Set cumulative distribution function (cdf) based on the availability of `_cdf`, `_log_cdf`,
        `_survival_function` and `_log_survival`, in that order of preference.
        """
        if hasattr(self, '_cdf'):
            self._call_cdf = self._cdf
        elif hasattr(self, '_log_cdf'):
            self._call_cdf = self._calc_cdf_from_log_cdf
        elif hasattr(self, '_survival_function'):
            self._call_cdf = self._calc_cdf_from_survival
        elif hasattr(self, '_log_survival'):
            self._call_cdf = self._calc_cdf_from_log_survival
        else:
            self._call_cdf = self._raise_not_implemented_error('cdf')

    def _set_survival(self):
        """
        Set survival function based on the availability of `_survival_function`, `_log_survival`
        and `_call_cdf`.

        Must run after `_set_cdf`, since the last fallback uses `_call_cdf`.
        """
        if not (hasattr(self, '_survival_function') or hasattr(self, '_log_survival') or
                hasattr(self, '_cdf') or hasattr(self, '_log_cdf')):
            self._call_survival = self._raise_not_implemented_error(
                'survival_function')
        elif hasattr(self, '_survival_function'):
            self._call_survival = self._survival_function
        elif hasattr(self, '_log_survival'):
            self._call_survival = self._calc_survival_from_log_survival
        elif hasattr(self, '_call_cdf'):
            self._call_survival = self._calc_survival_from_call_cdf

    def _set_log_cdf(self):
        """
        Set log cdf based on the availability of `_log_cdf` and `_call_cdf`.

        Must run after `_set_cdf`, since the fallback uses `_call_cdf`.
        """
        if not (hasattr(self, '_log_cdf') or hasattr(self, '_cdf') or
                hasattr(self, '_survival_function') or hasattr(self, '_log_survival')):
            self._call_log_cdf = self._raise_not_implemented_error('log_cdf')
        elif hasattr(self, '_log_cdf'):
            self._call_log_cdf = self._log_cdf
        elif hasattr(self, '_call_cdf'):
            self._call_log_cdf = self._calc_log_cdf_from_call_cdf

    def _set_log_survival(self):
        """
        Set log survival based on the availability of `_log_survival` and `_call_survival`.

        Must run after `_set_survival`, since the fallback uses `_call_survival`.
        """
        if not (hasattr(self, '_log_survival') or hasattr(self, '_survival_function') or
                hasattr(self, '_log_cdf') or hasattr(self, '_cdf')):
            # report the function the caller actually invoked (was wrongly 'log_cdf')
            self._call_log_survival = self._raise_not_implemented_error(
                'log_survival')
        elif hasattr(self, '_log_survival'):
            self._call_log_survival = self._log_survival
        elif hasattr(self, '_call_survival'):
            self._call_log_survival = self._calc_log_survival_from_call_survival

    def _set_cross_entropy(self):
        """
        Set cross entropy based on the availability of `_cross_entropy`.
        """
        if hasattr(self, '_cross_entropy'):
            self._call_cross_entropy = self._cross_entropy
        else:
            self._call_cross_entropy = self._raise_not_implemented_error(
                'cross_entropy')

    def _get_dist_args(self, *args, **kwargs):
        return raise_not_implemented_util('get_dist_args', self.name, *args, **kwargs)

    def get_dist_args(self, *args, **kwargs):
        """
        Check the availability and validity of default parameters and `dist_spec_args`.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            `dist_spec_args` must be passed in through list or dictionary. The order of `dist_spec_args`
            should follow the initialization order of default parameters through `_add_parameter`.
            If some `dist_spec_args` is None, the corresponding default parameter is returned.

        Return:
            list[Tensor], the list of parameters.
        """
        return self._get_dist_args(*args, **kwargs)

    def _get_dist_type(self):
        return raise_not_implemented_util('get_dist_type', self.name)

    def get_dist_type(self):
        """
        Return the type of the distribution.

        Return:
            string, the name of distribution.
        """
        return self._get_dist_type()

    def _raise_not_implemented_error(self, func_name):
        """
        Build a stand-in for an unsupported function.

        Returns:
            A callable that forwards `func_name`, the distribution name and its own
            arguments to `raise_not_implemented_util`.
        """
        name = self.name

        def raise_error(*args, **kwargs):
            return raise_not_implemented_util(func_name, name, *args, **kwargs)
        return raise_error

    def log_prob(self, value, *args, **kwargs):
        """
        Evaluate the log probability(pdf or pmf) at the given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its `dist_spec_args` through
            `args` or `kwargs`.

        Return:
            Tensor, the value of log probability.
        """
        return self._call_log_prob(value, *args, **kwargs)

    def _calc_prob_from_log_prob(self, value, *args, **kwargs):
        r"""
        Evaluate prob from log probability.

        .. math::
            probability(x) = \exp(log\_prob(x))
        """
        return self.exp_base(self._log_prob(value, *args, **kwargs))

    def prob(self, value, *args, **kwargs):
        """
        Evaluate the probability (pdf or pmf) at given value. For a discrete distribution,
        it is a probability mass function, while for a continuous distribution, it is probability density function.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its `dist_spec_args` through
            `args` or `kwargs`.

        Return:
            Tensor, the value of probability.
        """
        return self._call_prob(value, *args, **kwargs)

    def _calc_log_prob_from_prob(self, value, *args, **kwargs):
        r"""
        Evaluate log probability from probability.

        .. math::
            log\_prob(x) = \log(prob(x))
        """
        return self.log_base(self._prob(value, *args, **kwargs))

    def cdf(self, value, *args, **kwargs):
        """
        Evaluate the cumulative distribution function(cdf) at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its `dist_spec_args` through
            `args` or `kwargs`.

        Return:
            Tensor, the cdf of the distribution.
        """
        return self._call_cdf(value, *args, **kwargs)

    def _calc_cdf_from_log_cdf(self, value, *args, **kwargs):
        r"""
        Evaluate cdf from log_cdf.

        .. math::
            cdf(x) = \exp(log\_cdf(x))
        """
        return self.exp_base(self._log_cdf(value, *args, **kwargs))

    def _calc_cdf_from_survival(self, value, *args, **kwargs):
        r"""
        Evaluate cdf from survival function.

        .. math::
            cdf(x) = 1 - survival\_function(x)
        """
        return 1.0 - self._survival_function(value, *args, **kwargs)

    def _calc_cdf_from_log_survival(self, value, *args, **kwargs):
        r"""
        Evaluate cdf from log survival function.

        .. math::
            cdf(x) = 1 - \exp(log\_survival(x))
        """
        return 1.0 - self.exp_base(self._log_survival(value, *args, **kwargs))

    def log_cdf(self, value, *args, **kwargs):
        """
        Evaluate the log of the cumulative distribution function(cdf) at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its `dist_spec_args` through
            `args` or `kwargs`.

        Return:
            Tensor, the log cdf of the distribution.
        """
        return self._call_log_cdf(value, *args, **kwargs)

    def _calc_log_cdf_from_call_cdf(self, value, *args, **kwargs):
        r"""
        Evaluate log cdf from cdf.

        .. math::
            log\_cdf(x) = \log(cdf(x))
        """
        return self.log_base(self._call_cdf(value, *args, **kwargs))

    def survival_function(self, value, *args, **kwargs):
        """
        Evaluate the survival function at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its `dist_spec_args` through
            `args` or `kwargs`.

        Return:
            Tensor, the survival function of the distribution.
        """
        return self._call_survival(value, *args, **kwargs)

    def _calc_survival_from_call_cdf(self, value, *args, **kwargs):
        r"""
        Evaluate survival function from cdf.

        .. math::
            survival\_function(x) = 1 - cdf(x)
        """
        return 1.0 - self._call_cdf(value, *args, **kwargs)

    def _calc_survival_from_log_survival(self, value, *args, **kwargs):
        r"""
        Evaluate survival function from log survival function.

        .. math::
            survival\_function(x) = \exp(log\_survival(x))
        """
        return self.exp_base(self._log_survival(value, *args, **kwargs))

    def log_survival(self, value, *args, **kwargs):
        """
        Evaluate the log survival function at given value.

        Args:
            value (Tensor): value to be evaluated.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its `dist_spec_args` through
            `args` or `kwargs`.

        Return:
            Tensor, the log survival function of the distribution.
        """
        return self._call_log_survival(value, *args, **kwargs)

    def _calc_log_survival_from_call_survival(self, value, *args, **kwargs):
        r"""
        Evaluate log survival function from survival function.

        .. math::
            log\_survival(x) = \log(survival\_function(x))
        """
        return self.log_base(self._call_survival(value, *args, **kwargs))

    def _kl_loss(self, *args, **kwargs):
        return raise_not_implemented_util('kl_loss', self.name, *args, **kwargs)

    def kl_loss(self, dist, *args, **kwargs):
        """
        Evaluate the KL divergence, i.e. KL(a||b).

        Args:
            dist (str): type of the distribution.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            `dist_spec_args` of distribution b must be passed to the function through `args` or `kwargs`.
            Passing in `dist_spec_args` of distribution a is optional.

        Return:
            Tensor, the kl loss function of the distribution.
        """
        return self._kl_loss(dist, *args, **kwargs)

    def _mean(self, *args, **kwargs):
        return raise_not_implemented_util('mean', self.name, *args, **kwargs)

    def mean(self, *args, **kwargs):
        """
        Evaluate the mean.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            `args` or `kwargs`.

        Return:
            Tensor, the mean of the distribution.
        """
        return self._mean(*args, **kwargs)

    def _mode(self, *args, **kwargs):
        return raise_not_implemented_util('mode', self.name, *args, **kwargs)

    def mode(self, *args, **kwargs):
        """
        Evaluate the mode.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            `args` or `kwargs`.

        Return:
            Tensor, the mode of the distribution.
        """
        return self._mode(*args, **kwargs)

    def sd(self, *args, **kwargs):
        """
        Evaluate the standard deviation.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            `args` or `kwargs`.

        Return:
            Tensor, the standard deviation of the distribution.
        """
        return self._call_sd(*args, **kwargs)

    def var(self, *args, **kwargs):
        """
        Evaluate the variance.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            `args` or `kwargs`.

        Return:
            Tensor, the variance of the distribution.
        """
        return self._call_var(*args, **kwargs)

    def _calc_sd_from_var(self, *args, **kwargs):
        r"""
        Evaluate standard deviation from variance.

        .. math::
            STD(x) = \sqrt{VAR(x)}
        """
        return self.sqrt_base(self._var(*args, **kwargs))

    def _calc_var_from_sd(self, *args, **kwargs):
        r"""
        Evaluate variance from standard deviation.

        .. math::
            VAR(x) = STD(x)^2
        """
        return self.sq_base(self._sd(*args, **kwargs))

    def _entropy(self, *args, **kwargs):
        return raise_not_implemented_util('entropy', self.name, *args, **kwargs)

    def entropy(self, *args, **kwargs):
        """
        Evaluate the entropy.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            `args` or `kwargs`.

        Return:
            Tensor, the entropy of the distribution.
        """
        return self._entropy(*args, **kwargs)

    def cross_entropy(self, dist, *args, **kwargs):
        """
        Evaluate the cross_entropy between distribution a and b.

        Args:
            dist (str): type of the distribution.
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            `dist_spec_args` of distribution b must be passed to the function through `args` or `kwargs`.
            Passing in `dist_spec_args` of distribution a is optional.

        Return:
            Tensor, the cross_entropy of two distributions.
        """
        return self._call_cross_entropy(dist, *args, **kwargs)

    def _calc_cross_entropy(self, dist, *args, **kwargs):
        r"""
        Evaluate cross_entropy from entropy and kl divergence.

        .. math::
            H(X, Y) = H(X) + KL(X||Y)
        """
        return self._entropy(*args, **kwargs) + self._kl_loss(dist, *args, **kwargs)

    def _sample(self, *args, **kwargs):
        return raise_not_implemented_util('sample', self.name, *args, **kwargs)

    def sample(self, *args, **kwargs):
        """
        Sampling function.

        Args:
            *args (list): the list of positional arguments forwarded to subclasses.
            **kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.

        Note:
            A distribution can be optionally passed to the function by passing its *dist_spec_args* through
            `args` or `kwargs`.

        Return:
            Tensor, the sample generated from the distribution.
        """
        return self._sample(*args, **kwargs)

    def construct(self, name, *args, **kwargs):
        """
        Override `construct` in Cell.

        Note:
            Names of supported functions include:
            'prob', 'log_prob', 'cdf', 'log_cdf', 'survival_function', 'log_survival',
            'var', 'sd', 'mode', 'mean', 'entropy', 'kl_loss', 'cross_entropy', 'sample',
            'get_dist_args', and 'get_dist_type'.

        Args:
            name (str): The name of the function.
            *args (list): A list of positional arguments that the function needs.
            **kwargs (dict): A dictionary of keyword arguments that the function needs.

        Return:
            Tensor, the value of corresponding computation method.
        """
        # plain if-chain (rather than a dict dispatch) keeps the dispatch simple for
        # graph compilation of `construct`
        if name == 'log_prob':
            return self._call_log_prob(*args, **kwargs)
        if name == 'prob':
            return self._call_prob(*args, **kwargs)
        if name == 'cdf':
            return self._call_cdf(*args, **kwargs)
        if name == 'log_cdf':
            return self._call_log_cdf(*args, **kwargs)
        if name == 'survival_function':
            return self._call_survival(*args, **kwargs)
        if name == 'log_survival':
            return self._call_log_survival(*args, **kwargs)
        if name == 'kl_loss':
            return self._kl_loss(*args, **kwargs)
        if name == 'mean':
            return self._mean(*args, **kwargs)
        if name == 'mode':
            return self._mode(*args, **kwargs)
        if name == 'sd':
            return self._call_sd(*args, **kwargs)
        if name == 'var':
            return self._call_var(*args, **kwargs)
        if name == 'entropy':
            return self._entropy(*args, **kwargs)
        if name == 'cross_entropy':
            return self._call_cross_entropy(*args, **kwargs)
        if name == 'sample':
            return self._sample(*args, **kwargs)
        if name == 'get_dist_args':
            return self._get_dist_args(*args, **kwargs)
        if name == 'get_dist_type':
            return self._get_dist_type()
        return raise_not_implemented_util(name, self.name, *args, **kwargs)