# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Perplexity"""
import math
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import Metric, rearrange_inputs


class Perplexity(Metric):
    r"""
    Computes perplexity. Perplexity is a measure of how well a probability distribution or a model predicts a
    sample. A low perplexity indicates that the model predicts the sample well. The function is shown as follows:

    .. math::
        PP(W)=P(w_{1}w_{2}...w_{N})^{-\frac{1}{N}}=\sqrt[N]{\frac{1}{P(w_{1}w_{2}...w_{N})}}

    Args:
        ignore_label (int): Index of an invalid label to be ignored when counting. If set to `None`, all
            entries are counted. Default: None.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Note:
        The method `update` must be called with the form `update(preds, labels)`.

    Examples:
        >>> import numpy as np
        >>> from mindspore import nn, Tensor
        >>>
        >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
        >>> y = Tensor(np.array([1, 0, 1]))
        >>> metric = nn.Perplexity(ignore_label=None)
        >>> metric.clear()
        >>> metric.update(x, y)
        >>> perplexity = metric.eval()
        >>> print(perplexity)
        2.231443166940565
    """

    def __init__(self, ignore_label=None):
        super(Perplexity, self).__init__()

        if ignore_label is None:
            self.ignore_label = ignore_label
        else:
            self.ignore_label = validator.check_value_type("ignore_label", ignore_label, [int])
        self.clear()

    def clear(self):
        """Clears the internal evaluation result."""
        self._sum_metric = 0.0
        self._num_inst = 0

    @rearrange_inputs
    def update(self, *inputs):
        """
        Updates the internal evaluation result with :math:`preds` and :math:`labels`.

        Args:
            inputs: Input `preds` and `labels`. `preds` and `labels` are Tensor, list or numpy.ndarray.
                `preds` is the predicted values and `labels` is the label of the data.
                `preds` has shape :math:`(N, C)` and `labels` has shape :math:`(N,)`.

        Raises:
            ValueError: If the number of the inputs is not 2.
            RuntimeError: If preds and labels have different lengths.
            RuntimeError: If the shape of labels does not match the shape of preds.
        """
        if len(inputs) != 2:
            raise ValueError('The perplexity needs 2 inputs (preds, labels), but got {}.'.format(len(inputs)))

        preds = [self._convert_data(inputs[0])]
        labels = [self._convert_data(inputs[1])]

        if len(preds) != len(labels):
            raise RuntimeError('The preds and labels should have the same length, but the length of preds is {}, '
                               'the length of labels is {}.'.format(len(preds), len(labels)))

        loss = 0.
        num = 0
        for label, pred in zip(labels, preds):
            if label.size != pred.size / pred.shape[-1]:
                raise RuntimeError("shape mismatch: label shape should be equal to pred shape, but got label shape "
                                   "is {}, pred shape is {}.".format(label.shape, pred.shape))
            label = label.reshape((label.size,))
            label_expand = label.astype(int)
            label_expand = np.expand_dims(label_expand, axis=1)
            # Gather the predicted probability of the target class for each sample.
            first_indices = np.arange(label_expand.shape[0])[:, None]
            pred = np.squeeze(pred[first_indices, label_expand])
            if self.ignore_label is not None:
                # Exclude ignored entries from the count and force their probability to 1,
                # so they contribute log(1) = 0 to the loss.
                ignore = (label == self.ignore_label).astype(pred.dtype)
                num -= np.sum(ignore)
                pred = pred * (1 - ignore) + ignore
            # Clamp probabilities to avoid log(0).
            loss -= np.sum(np.log(np.maximum(1e-10, pred)))
            num += pred.size
        self._sum_metric += loss
        self._num_inst += num

    def eval(self):
        r"""
        Returns the current evaluation result.

        Returns:
            float, the computed result.

        Raises:
            RuntimeError: If the sample size is 0.
        """
        if self._num_inst == 0:
            raise RuntimeError('The perplexity can not be calculated, because the number of samples is 0.')

        # Perplexity is the exponential of the mean negative log-likelihood.
        return math.exp(self._sum_metric / self._num_inst)