# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Recall."""
import sys

import numpy as np

from mindspore._checkparam import Validator as validator
from .metric import EvaluationBase, rearrange_inputs


class Recall(EvaluationBase):
    r"""
    Computes recall for classification and multilabel data.

    Two running quantities are accumulated across calls to `update`: :math:`\text{true_positive}` and
    :math:`\text{false_negative}` (stored together as the count of actual positives). `eval` divides the
    former by the sum of both and returns the quotient; calling it repeatedly does not change the state.

    .. math::
        \text{recall} = \frac{\text{true_positive}}{\text{true_positive} + \text{false_negative}}

    Note:
        In the multi-label cases, the elements of :math:`y` and :math:`y_{pred}` must be 0 or 1.

    Args:
        eval_type (str): The kind of data the recall is computed over, either 'classification' or
            'multilabel'. Default: 'classification'.

    Examples:
        >>> import numpy as np
        >>> from mindspore import nn, Tensor
        >>>
        >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
        >>> y = Tensor(np.array([1, 0, 1]))
        >>> metric = nn.Recall('classification')
        >>> metric.clear()
        >>> metric.update(x, y)
        >>> recall = metric.eval()
        >>> print(recall)
        [1. 0.5]
    """

    def __init__(self, eval_type='classification'):
        super().__init__(eval_type)
        # Smallest positive normalized float: keeps the division in `eval` defined
        # when a denominator is zero without visibly shifting non-zero denominators.
        self.eps = sys.float_info.min
        self.clear()

    def clear(self):
        """Resets every accumulated statistic to its initial value."""
        self._class_num = 0
        if self._type != "multilabel":
            # Scalars here; they become arrays on the first `update` through `+=`.
            self._true_positives = 0
            self._actual_positives = 0
            return

        # Multilabel keeps growing arrays (extended by concatenation in `update`)
        # plus separate running sums that are only used for the averaged result.
        self._true_positives = np.empty(0)
        self._actual_positives = np.empty(0)
        self._true_positives_average = 0
        self._actual_positives_average = 0

    @rearrange_inputs
    def update(self, *inputs):
        """
        Accumulates the statistics of one batch given `y_pred` and `y`.

        Args:
            inputs: Input `y_pred` and `y`. `y_pred` and `y` are a `Tensor`, a list or an array.
                For 'classification' evaluation type, `y_pred` is in most cases (not strictly) a list
                of floating numbers in range :math:`[0, 1]`
                and the shape is :math:`(N, C)`, where :math:`N` is the number of cases and :math:`C`
                is the number of categories. Shape of `y` can be :math:`(N, C)` with values 0 and 1 if one-hot
                encoding is used or the shape is :math:`(N,)` with integer values if index of category is used.
                For 'multilabel' evaluation type, `y_pred` and `y` can only be one-hot encoding with
                values 0 or 1. Indices with 1 indicate positive category. The shape of `y_pred` and `y`
                are both :math:`(N, C)`.

        Raises:
            ValueError: If the number of input is not 2.
            ValueError: If the number of classes in `y_pred` differs from the one seen in a previous call.
            ValueError: If, for 'classification', `y` refers to more classes than `y_pred` provides.
        """
        if len(inputs) != 2:
            raise ValueError('The recall needs 2 inputs (y_pred, y), but got {}'.format(len(inputs)))

        y_pred = self._convert_data(inputs[0])
        y = self._convert_data(inputs[1])

        is_classification = self._type == 'classification'
        is_multilabel = self._type == "multilabel"

        # One-hot labels are collapsed to class indices before the shape/value checks.
        if is_classification and y_pred.ndim == y.ndim and self._check_onehot_data(y):
            y = y.argmax(axis=1)
        self._check_shape(y_pred, y)
        self._check_value(y_pred, y)

        # The first batch fixes the class count; later batches must agree with it.
        incoming_classes = y_pred.shape[1]
        if self._class_num == 0:
            self._class_num = incoming_classes
        elif incoming_classes != self._class_num:
            raise ValueError('The class number does not match, the last input data contains {} classes, '
                             'but the current data contains {} classes'.format(self._class_num, incoming_classes))

        class_num = self._class_num
        if is_classification:
            label_classes = y.max() + 1
            if label_classes > class_num:
                raise ValueError('y_pred contains {} classes less than y contains {} classes.'.
                                 format(class_num, label_classes))
            # Rows of the identity matrix serve as one-hot encodings for both the
            # labels and the arg-max predictions.
            one_hot = np.eye(class_num)
            y = one_hot[y.reshape(-1)]
            y_pred = one_hot[y_pred.argmax(axis=1).reshape(-1)]
        elif is_multilabel:
            # Put the class axis first so the axis-0 sums below run over classes.
            y_pred = y_pred.swapaxes(1, 0).reshape(class_num, -1)
            y = y.swapaxes(1, 0).reshape(class_num, -1)

        actual_positives = y.sum(axis=0)
        true_positives = (y * y_pred).sum(axis=0)

        if is_multilabel:
            self._true_positives_average += np.sum(true_positives / (actual_positives + self.eps))
            self._actual_positives_average += len(actual_positives)
            self._true_positives = np.concatenate((self._true_positives, true_positives), axis=0)
            self._actual_positives = np.concatenate((self._actual_positives, actual_positives), axis=0)
        else:
            self._true_positives += true_positives
            self._actual_positives += actual_positives

    def eval(self, average=False):
        """
        Computes the recall from the accumulated statistics.

        Args:
            average (bool): Whether to return the averaged recall instead of the individual values.
                Default value is False.

        Returns:
            The computed recall. With `average=False` it is the element-wise ratio of the accumulated
            true positives to the accumulated actual positives; with `average=True` it is a single
            averaged value.

        Raises:
            RuntimeError: If `update` has not been called since the last `clear`.
        """
        if self._class_num == 0:
            raise RuntimeError('The input number of samples can not be 0.')

        validator.check_value_type("average", average, [bool], self.__class__.__name__)

        recall = self._true_positives / (self._actual_positives + self.eps)
        if not average:
            return recall

        if self._type == "multilabel":
            # The multilabel average uses the dedicated running sums, not the element-wise ratio.
            recall = self._true_positives_average / (self._actual_positives_average + self.eps)
        return recall.mean()