# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Categorical Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops.functional import stop_gradient
from mindspore._checkparam import Validator
import mindspore.nn as nn
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_prob, check_sum_equal_one, check_rank,\
                          check_distribution_name, raise_not_implemented_util
from ._utils.custom_ops import exp_generic, log_generic, broadcast_to


class Categorical(Distribution):
    """
    Create a categorical distribution parameterized by event probabilities.

    Args:
        probs (Tensor, list, numpy.ndarray): Event probabilities.
        seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
        dtype (mindspore.dtype): The type of the event samples. Default: mstype.int32.
        name (str): The name of the distribution. Default: Categorical.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Note:
        `probs` must have rank at least 1, and its values must be proper probabilities that sum to 1 along the last axis.

    Examples:
        >>> import mindspore
        >>> import mindspore.nn as nn
        >>> import mindspore.nn.probability.distribution as msd
        >>> from mindspore import Tensor
        >>> # To initialize a Categorical distribution of probs [0.2, 0.8]
        >>> ca1 = msd.Categorical(probs=[0.2, 0.8], dtype=mindspore.int32)
        >>> # A Categorical distribution can be initialized without arguments.
        >>> # In this case, `probs` must be passed in through arguments during function calls.
        >>> ca2 = msd.Categorical(dtype=mindspore.int32)
        >>> # Here are some tensors used below for testing
        >>> value = Tensor([1, 0], dtype=mindspore.int32)
        >>> probs_a = Tensor([0.5, 0.5], dtype=mindspore.float32)
        >>> probs_b = Tensor([0.35, 0.65], dtype=mindspore.float32)
        >>> # Private interfaces of probability functions corresponding to public interfaces, including
        >>> # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, are the same as follows.
        >>> # Args:
        >>> #     value (Tensor): the value to be evaluated.
        >>> #     probs (Tensor): event probabilities. Default: self.probs.
        >>> # Examples of `prob`.
        >>> # Similar calls can be made to other probability functions
        >>> # by replacing `prob` by the name of the function.
        >>> ans = ca1.prob(value)
        >>> print(ans.shape)
        (2,)
        >>> # Evaluate `prob` with respect to distribution b.
        >>> ans = ca1.prob(value, probs_b)
        >>> print(ans.shape)
        (2,)
        >>> # `probs` must be passed in during function calls.
        >>> ans = ca2.prob(value, probs_a)
        >>> print(ans.shape)
        (2,)
        >>> # Functions `mean`, `sd`, `var`, and `entropy` have the same arguments.
        >>> # Args:
        >>> #     probs (Tensor): event probabilities. Default: self.probs.
        >>> # Examples of `mean`. `sd`, `var`, and `entropy` are similar.
        >>> ans = ca1.mean() # return 0.8
        >>> print(ans.shape)
        (1,)
        >>> ans = ca1.mean(probs_b)
        >>> print(ans.shape)
        (1,)
        >>> # `probs` must be passed in during function calls.
        >>> ans = ca2.mean(probs_a)
        >>> print(ans.shape)
        (1,)
        >>> # Interfaces of `kl_loss` and `cross_entropy` are the same as follows:
        >>> # Args:
        >>> #     dist (str): the name of the distribution. Only 'Categorical' is supported.
        >>> #     probs_b (Tensor): event probabilities of distribution b.
        >>> #     probs (Tensor): event probabilities of distribution a. Default: self.probs.
        >>> # Examples of `kl_loss`. `cross_entropy` is similar.
        >>> ans = ca1.kl_loss('Categorical', probs_b)
        >>> print(ans.shape)
        ()
        >>> ans = ca1.kl_loss('Categorical', probs_b, probs_a)
        >>> print(ans.shape)
        ()
        >>> # An additional `probs` must be passed in.
        >>> ans = ca2.kl_loss('Categorical', probs_b, probs_a)
        >>> print(ans.shape)
        ()
    """

    def __init__(self,
                 probs=None,
                 seed=None,
                 dtype=mstype.int32,
                 name="Categorical"):
        param = dict(locals())
        param['param_dict'] = {'probs': probs}
        valid_dtype = mstype.uint_type + mstype.int_type + mstype.float_type
        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
        super(Categorical, self).__init__(seed, dtype, name, param)

        self._probs = self._add_parameter(probs, 'probs')
        if self.probs is not None:
            check_rank(self.probs)
            check_prob(self.probs)
            check_sum_equal_one(probs)

            # update is_scalar_batch and broadcast_shape
            # drop one dimension
            if self.probs.shape[:-1] == ():
                self._is_scalar_batch = True
            self._broadcast_shape = self._broadcast_shape[:-1]

        self.argmax = P.ArgMaxWithValue(axis=-1)
        self.broadcast = broadcast_to
        self.cast = P.Cast()
        self.clip_by_value = C.clip_by_value
        self.concat = P.Concat(-1)
        self.cumsum = P.CumSum()
        self.dtypeop = P.DType()
        self.exp = exp_generic
        self.expand_dim = P.ExpandDims()
        self.fill = P.Fill()
        self.gather = P.GatherNd()
        self.greater = P.Greater()
        self.issubclass = P.IsSubClass()
        self.less = P.Less()
        self.log = log_generic
        self.log_softmax = P.LogSoftmax()
        self.logicor = P.LogicalOr()
        self.logicand = P.LogicalAnd()
        self.multinomial = P.Multinomial(seed=self.seed)
        self.reshape = P.Reshape()
        self.reduce_sum = P.ReduceSum(keep_dims=True)
        self.select = P.Select()
        self.shape = P.Shape()
        self.softmax = P.Softmax()
        self.squeeze = P.Squeeze()
        self.squeeze_first_axis = P.Squeeze(0)
        self.squeeze_last_axis = P.Squeeze(-1)
        self.square = P.Square()
        self.transpose = P.Transpose()

        self.index_type = mstype.int32
        self.nan = np.nan

    @property
    def probs(self):
        """
        Return the probability after casting to dtype.
        """
        return self._probs

    def extend_repr(self):
        """Display instance object as string."""
        if self.is_scalar_batch:
            s = 'probs = {}'.format(self.probs)
        else:
            s = 'batch_shape = {}'.format(self._broadcast_shape)
        return s

    def _get_dist_type(self):
        return "Categorical"

    def _get_dist_args(self, probs=None):
        if probs is not None:
            self.checktensor(probs, 'probs')
        else:
            probs = self.probs
        return (probs,)

    def _mean(self, probs=None):
        r"""
        .. math::
            E[X] = \sum_{i=0}^{\text{num\_classes}-1} i \cdot p_i
        """
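        # Worked check, assuming the docstring example probs = [0.2, 0.8]:
        # E[X] = 0 * 0.2 + 1 * 0.8 = 0.8, which matches `ca1.mean()` above.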
        probs = self._check_param_type(probs)
        num_classes = self.shape(probs)[-1]
        index = nn.Range(0., num_classes, 1.)()
        return self.reduce_sum(index * probs, -1)

    def _mode(self, probs=None):
        probs = self._check_param_type(probs)
        index, _ = self.argmax(probs)
        mode = self.cast(index, self.dtype)
        return mode

    def _var(self, probs=None):
        r"""
        .. math::
            \mathrm{Var}(X) = E[X^{2}] - (E[X])^{2}
        """
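        # Worked check, assuming probs = [0.2, 0.8]:
        # E[X^2] = 0 * 0.2 + 1 * 0.8 = 0.8 and E[X] = 0.8,
        # so Var(X) = 0.8 - 0.8^2 = 0.16.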
        probs = self._check_param_type(probs)
        num_classes = self.shape(probs)[-1]
        index = nn.Range(0., num_classes, 1.)()
        return self.reduce_sum(self.square(index) * probs, -1) - \
               self.square(self.reduce_sum(index * probs, -1))

    def _entropy(self, probs=None):
        r"""
        Evaluate entropy.

        .. math::
            H(X) = -\sum_{i} p_i \log(p_i)
        """
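        # For example, with probs = [0.5, 0.5]:
        # H(X) = -(0.5 * log(0.5) + 0.5 * log(0.5)) = log(2) ≈ 0.693.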
        probs = self._check_param_type(probs)
        logits = self.log(probs)
        return self.squeeze(-self.reduce_sum(logits * probs, -1))

    def _kl_loss(self, dist, probs_b, probs=None):
        """
        Evaluate KL divergence between Categorical distributions.

        Args:
            dist (str): The type of the distributions. Should be "Categorical" in this case.
            probs_b (Tensor): Event probabilities of distribution b.
            probs (Tensor): Event probabilities of distribution a. Default: self.probs.
        """
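        # KL(a || b) = sum_i p_a(i) * (log p_a(i) - log p_b(i)).
        # Routing the logits through softmax/log_softmax renormalizes the
        # probabilities, which keeps the result stable when the inputs do not
        # sum to exactly 1 because of floating-point error.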
        check_distribution_name(dist, 'Categorical')
        probs_b = self._check_value(probs_b, 'probs_b')
        probs_b = self.cast(probs_b, self.parameter_type)
        probs_a = self._check_param_type(probs)
        logits_a = self.log(probs_a)
        logits_b = self.log(probs_b)
        return self.squeeze(self.reduce_sum(
            self.softmax(logits_a) * (self.log_softmax(logits_a) - self.log_softmax(logits_b)), -1))

    def _cross_entropy(self, dist, probs_b, probs=None):
        """
        Evaluate cross entropy between Categorical distributions.

        Args:
            dist (str): The type of the distributions. Should be "Categorical" in this case.
            probs_b (Tensor): Event probabilities of distribution b.
            probs (Tensor): Event probabilities of distribution a. Default: self.probs.
        """
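        # Uses the identity H(a, b) = H(a) + KL(a || b), composing the two
        # primitives defined above instead of re-deriving the sum directly.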
        check_distribution_name(dist, 'Categorical')
        return self._entropy(probs) + self._kl_loss(dist, probs_b, probs)

    def _log_prob(self, value, probs=None):
        r"""
        Evaluate log probability.

        Args:
            value (Tensor): The value to be evaluated.
            probs (Tensor): Event probabilities. Default: self.probs.
        """
        value = self._check_value(value, 'value')

        probs = self._check_param_type(probs)
        logits = self.log(probs)

        # find the right integer to compute index
        # here we simulate casting to int but still keeping float dtype
        value = self.cast(value, self.dtypeop(probs))

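        # A cast to int truncates toward zero, while P.Floor rounds down, so
        # the two disagree on the interval (-1, 0): floor gives -1 but
        # truncation gives 0. That interval is special-cased to 0 via select.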
        zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
        between_zero_neone = self.logicand(self.less(value, 0.),
                                           self.greater(value, -1.))
        value = self.select(between_zero_neone,
                            zeros,
                            P.Floor()(value))

        # handle the case when value is of shape () and probs is a scalar batch
        drop_dim = False
        if self.shape(value) == () and self.shape(probs)[:-1] == ():
            drop_dim = True
            # manually add one more dimension: () -> (1,)
            # drop this dimension before return
            value = self.expand_dim(value, -1)

        value = self.expand_dim(value, -1)

        broadcast_shape_tensor = logits * value
        broadcast_shape = self.shape(broadcast_shape_tensor)
        num_classes = broadcast_shape[-1]
        label_shape = broadcast_shape[:-1]

        # broadcasting logits and value
        # logit_pmf shape (num of labels, C)
        logits = self.broadcast(logits, broadcast_shape_tensor)
        value = self.broadcast(value, broadcast_shape_tensor)[..., :1]

        # flatten value to shape (number of labels, 1)
        # clip value to be in range from 0 to num_classes - 1 and cast into int32
        value = self.reshape(value, (-1, 1))
        out_of_bound = self.squeeze_last_axis(self.logicor(
            self.less(value, 0.0), self.less(num_classes - 1, value)))
        # deal with the case where there is only one class.
        value_clipped = self.clip_by_value(value, 0.0, num_classes - 1)
        value_clipped = self.cast(value_clipped, self.index_type)
        # create index from 0 ... NumOfLabels
        index = self.reshape(nn.Range(0, self.shape(value)[0], 1)(), (-1, 1))
        index = self.concat((index, value_clipped))
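        # each row of `index` is now a (row, class) pair, which is exactly the
        # coordinate format GatherNd needs to pick one entry per label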

        # index into logit_pmf, fill in out_of_bound places with nan
        # reshape into label shape N
        logits_pmf = self.gather(self.reshape(logits, (-1, num_classes)), index)
        nan = self.fill(self.dtypeop(logits_pmf), self.shape(logits_pmf), self.nan)
        logits_pmf = self.select(out_of_bound, nan, logits_pmf)
        ans = self.reshape(logits_pmf, label_shape)
        if drop_dim:
            return self.squeeze(ans)
        return ans

    def _cdf(self, value, probs=None):
        r"""
        Cumulative distribution function (cdf) of Categorical distributions.

        Args:
            value (Tensor): The value to be evaluated.
            probs (Tensor): Event probabilities. Default: self.probs.
        """
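        # For an integer class index k, cdf(k) = p_0 + ... + p_k. It is
        # computed by gathering from a cumulative sum over the class axis;
        # values below 0 yield a cdf of 0, and values above num_classes - 1
        # are clipped to the last class, whose cumulative probability is 1.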
        value = self._check_value(value, 'value')
        probs = self._check_param_type(probs)

        value = self.cast(value, self.dtypeop(probs))

        zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
        between_zero_neone = self.logicand(self.less(value, 0.), self.greater(value, -1.))
        value = self.select(between_zero_neone, zeros, P.Floor()(value))

        drop_dim = False
        if self.shape(value) == () and self.shape(probs)[:-1] == ():
            drop_dim = True
            value = self.expand_dim(value, -1)

        value = self.expand_dim(value, -1)

        broadcast_shape_tensor = probs * value
        broadcast_shape = self.shape(broadcast_shape_tensor)
        num_classes = broadcast_shape[-1]
        label_shape = broadcast_shape[:-1]

        probs = self.broadcast(probs, broadcast_shape_tensor)
        value = self.broadcast(value, broadcast_shape_tensor)[..., :1]

        # flatten value to shape (number of labels, 1)
        value = self.reshape(value, (-1, 1))

        # drop one dimension to match cdf
        # clip value to be in range from 0 to num_classes - 1 and cast into int32
        less_than_zero = self.squeeze_last_axis(self.less(value, 0.0))
        value_clipped = self.clip_by_value(value, 0.0, num_classes - 1)
        value_clipped = self.cast(value_clipped, self.index_type)

        index = self.reshape(nn.Range(0, self.shape(value)[0], 1)(), (-1, 1))
        index = self.concat((index, value_clipped))

        # reshape probs and fill less_than_zero places with 0
        probs = self.reshape(probs, (-1, num_classes))
        cdf = self.gather(self.cumsum(probs, 1), index)
        zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
        cdf = self.select(less_than_zero, zeros, cdf)
        cdf = self.reshape(cdf, label_shape)

        if drop_dim:
            return self.squeeze(cdf)
        return cdf

    def _sample(self, shape=(), probs=None):
        """
        Sampling.

        Args:
            shape (tuple): The shape of the sample. Default: ().
            probs (Tensor): Event probabilities. Default: self.probs.

        Returns:
            Tensor, with shape `shape + shape(probs)[:-1]`.
        """
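        # Strategy: flatten `probs` to 2-D, draw class indices with
        # P.Multinomial, then reshape to `shape + batch_shape`. The sampled
        # indices are detached from the graph, so no gradient flows through
        # the sampling step.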
        if self.device_target == 'Ascend':
            raise_not_implemented_util('On the Ascend backend, sample', self.name)
        shape = self.checktuple(shape, 'shape')
        probs = self._check_param_type(probs)
        num_classes = self.shape(probs)[-1]
        batch_shape = self.shape(probs)[:-1]

        sample_shape = shape + batch_shape
        drop_dim = False
        if sample_shape == ():
            drop_dim = True
            sample_shape = (1,)

        probs_2d = self.reshape(probs, (-1, num_classes))
        sample_tensor = self.fill(self.dtype, shape, 1.0)
        sample_tensor = self.reshape(sample_tensor, (-1, 1))
        num_sample = self.shape(sample_tensor)[0]
        samples = self.multinomial(probs_2d, num_sample)
        samples = self.squeeze(self.transpose(samples, (1, 0)))
        samples = self.cast(self.reshape(samples, sample_shape), self.dtype)
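        # apply stop_gradient before the drop_dim early return so that both
        # return paths yield detached samples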
        samples = stop_gradient(samples)
        if drop_dim:
            return self.squeeze_first_axis(samples)
        return samples