# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""Perplexity"""
16import math
17import numpy as np
18from mindspore._checkparam import Validator as validator
19from .metric import Metric, rearrange_inputs
20
21
class Perplexity(Metric):
    r"""
    Computes perplexity. Perplexity is a measure of how well a probability distribution or a model predicts a
    sample. A low perplexity indicates that the model predicts the sample well. The function is shown as follows:

    .. math::
        PP(W)=P(w_{1}w_{2}...w_{N})^{-\frac{1}{N}}=\sqrt[N]{\frac{1}{P(w_{1}w_{2}...w_{N})}}

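    Equivalently, for :math:`N` scored tokens, perplexity is the exponential of the average negative
    log-likelihood, which is the quantity the `eval` method computes:

    .. math::
        PP(W)=\exp\left(-\frac{1}{N}\sum_{i=1}^{N}\log P(w_{i})\right)
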
    Args:
        ignore_label (int): Index of an invalid label to be ignored when counting. If set to `None`, all
                            entries are counted. Default: None.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Note:
        The method `update` must be called with the form `update(preds, labels)`.

    Examples:
        >>> import numpy as np
        >>> from mindspore import nn, Tensor
        >>>
        >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
        >>> y = Tensor(np.array([1, 0, 1]))
        >>> metric = nn.Perplexity(ignore_label=None)
        >>> metric.clear()
        >>> metric.update(x, y)
        >>> perplexity = metric.eval()
        >>> print(perplexity)
        2.231443166940565
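        >>> # A follow-on sketch (not in the original docstring): with ignore_label=0,
        >>> # rows whose label is 0 are excluded, so only the probabilities 0.5 and 0.6
        >>> # are scored and the result is (0.5 * 0.6) ** -0.5.
        >>> metric = nn.Perplexity(ignore_label=0)
        >>> metric.clear()
        >>> metric.update(x, y)
        >>> print(round(metric.eval(), 4))
        1.8257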
52    """
53
54    def __init__(self, ignore_label=None):
55        super(Perplexity, self).__init__()
56
57        if ignore_label is None:
58            self.ignore_label = ignore_label
59        else:
60            self.ignore_label = validator.check_value_type("ignore_label", ignore_label, [int])
61        self.clear()
62
    def clear(self):
        """Clears the internal evaluation result."""
        self._sum_metric = 0.0
        self._num_inst = 0

    @rearrange_inputs
    def update(self, *inputs):
        """
        Updates the internal evaluation result with :math:`preds` and :math:`labels`.

        Args:
            inputs: Input `preds` and `labels`. `preds` and `labels` are Tensor, list or numpy.ndarray.
                    `preds` is the predicted values, `labels` is the label of the data.
                    The shape of `preds` is :math:`(N, C)` and the shape of `labels` is :math:`(N,)`.

        Raises:
            ValueError: If the number of the inputs is not 2.
            RuntimeError: If preds and labels have different lengths.
            RuntimeError: If the size of label does not match the first dimension of pred.
        """
        if len(inputs) != 2:
            raise ValueError('The perplexity needs 2 inputs (preds, labels), but got {}.'.format(len(inputs)))

        preds = [self._convert_data(inputs[0])]
        labels = [self._convert_data(inputs[1])]

        if len(preds) != len(labels):
            raise RuntimeError('The preds and labels should have the same length, but the length of preds is {}, '
                               'the length of labels is {}.'.format(len(preds), len(labels)))

        loss = 0.
        num = 0
        for label, pred in zip(labels, preds):
            if label.size != pred.size / pred.shape[-1]:
                raise RuntimeError("shape mismatch: the size of label should match the first dimension of pred, "
                                   "but got label shape is {}, pred shape is {}.".format(label.shape, pred.shape))
            label = label.reshape((label.size,))
            label_expand = label.astype(int)
            label_expand = np.expand_dims(label_expand, axis=1)
            # Gather, for each row, the predicted probability of its true class.
            first_indices = np.arange(label_expand.shape[0])[:, None]
            pred = np.squeeze(pred[first_indices, label_expand])
            if self.ignore_label is not None:
                # Replace ignored entries with 1 so that log(1) = 0 contributes
                # nothing to the loss, and remove them from the token count.
                ignore = (label == self.ignore_label).astype(pred.dtype)
                num -= np.sum(ignore)
                pred = pred * (1 - ignore) + ignore
            # Clamp probabilities at 1e-10 to avoid log(0).
            loss -= np.sum(np.log(np.maximum(1e-10, pred)))
            num += pred.size
        self._sum_metric += loss
        self._num_inst += num

    def eval(self):
        r"""
        Returns the current evaluation result.

        Returns:
            float, the computed result.

        Raises:
            RuntimeError: If the sample size is 0.
        """
        if self._num_inst == 0:
            raise RuntimeError('The perplexity cannot be calculated, because the number of samples is 0.')

        return math.exp(self._sum_metric / self._num_inst)
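# A minimal self-check sketch (not part of the original file). It is shown as a
# comment because the relative import above prevents running this module as a
# script; paste it into a separate script with mindspore and numpy installed.
# It verifies that Perplexity matches the closed-form
# exp(mean negative log-likelihood) on a small example.
#
#     import math
#     import numpy as np
#     from mindspore import nn, Tensor
#
#     x = np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]])
#     y = np.array([1, 0, 1])
#
#     metric = nn.Perplexity(ignore_label=None)
#     metric.clear()
#     metric.update(Tensor(x), Tensor(y))
#
#     # Direct computation: gather each row's true-class probability, then
#     # exponentiate the mean negative log-probability.
#     probs = x[np.arange(len(y)), y]
#     expected = math.exp(-np.mean(np.log(probs)))
#     assert math.isclose(metric.eval(), expected, rel_tol=1e-9)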