# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ROC"""
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import Metric, rearrange_inputs


class ROC(Metric):
    """
    Calculates the ROC curve. It is suitable for solving binary classification and multi classification problems.
    In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.

    Args:
        class_num (int): Integer with the number of classes. For the problem of binary classification, it is not
            necessary to provide this argument. Default: None.
        pos_label (int): Determine the integer of positive class. For binary problems, it is translated to 1 when
            left as None. For multiclass problems, this argument should not be set, as it is iteratively changed
            in the range [0, class_num-1]. Default: None.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import nn, Tensor
        >>>
        >>> # 1) binary classification example
        >>> x = Tensor(np.array([3, 1, 4, 2]))
        >>> y = Tensor(np.array([0, 1, 2, 3]))
        >>> metric = nn.ROC(pos_label=2)
        >>> metric.clear()
        >>> metric.update(x, y)
        >>> fpr, tpr, thresholds = metric.eval()
        >>> print(fpr)
        [0. 0. 0.33333333 0.66666667 1.]
        >>> print(tpr)
        [0. 1. 1. 1. 1.]
        >>> print(thresholds)
        [5 4 3 2 1]
        >>>
        >>> # 2) multiclass classification example
        >>> x = Tensor(np.array([[0.28, 0.55, 0.15, 0.05], [0.10, 0.20, 0.05, 0.05], [0.20, 0.05, 0.15, 0.05],
        ...                     [0.05, 0.05, 0.05, 0.75]]))
        >>> y = Tensor(np.array([0, 1, 2, 3]))
        >>> metric = nn.ROC(class_num=4)
        >>> metric.clear()
        >>> metric.update(x, y)
        >>> fpr, tpr, thresholds = metric.eval()
        >>> print(fpr)
        [array([0., 0., 0.33333333, 0.66666667, 1.]), array([0., 0.33333333, 0.33333333, 1.]),
        array([0., 0.33333333, 1.]), array([0., 0., 1.])]
        >>> print(tpr)
        [array([0., 1., 1., 1., 1.]), array([0., 0., 1., 1.]), array([0., 1., 1.]), array([0., 1., 1.])]
        >>> print(thresholds)
        [array([1.28, 0.28, 0.2, 0.1, 0.05]), array([1.55, 0.55, 0.2, 0.05]), array([1.15, 0.15, 0.05]),
        array([1.75, 0.75, 0.05])]
    """
    def __init__(self, class_num=None, pos_label=None):
        super().__init__()
        # Both arguments are optional; when given they must be ints (check_value_type returns the value).
        self.class_num = class_num if class_num is None else validator.check_value_type("class_num", class_num, [int])
        self.pos_label = pos_label if pos_label is None else validator.check_value_type("pos_label", pos_label, [int])
        self.clear()

    def clear(self):
        """Clear the internal evaluation result."""
        self.y_pred = 0
        self.y = 0
        # Reset here, but not currently forwarded to `_roc_eval` by `eval()`.
        self.sample_weights = None
        self._is_update = False

    def _precision_recall_curve_update(self, y_pred, y, class_num, pos_label):
        """
        Validate predictions/targets and bring them into the layout expected by `_roc_eval`.

        Args:
            y_pred (numpy.ndarray): Predicted scores. Either the same number of dimensions as `y` (single class
                evaluation) or exactly one extra dimension holding the classes at axis 1 (multiclass evaluation).
            y (numpy.ndarray): Target labels.
            class_num (int): Expected number of classes, or None.
            pos_label (int): Positive label for single class evaluation, or None.

        Returns:
            A tuple `(y_pred, y, class_num, pos_label)`. For single class evaluation `y_pred` and `y` are flattened,
            `class_num` is 1 and a missing `pos_label` becomes 1. For multiclass evaluation `y_pred` has shape
            `(num_samples, class_num)` whose rows line up with the flattened `y`.

        Raises:
            ValueError: If the ranks of `y_pred` and `y` are incompatible, or if `class_num` / `pos_label`
                contradict the detected evaluation mode.
        """
        if not (len(y_pred.shape) == len(y.shape) or len(y_pred.shape) == len(y.shape) + 1):
            raise ValueError("y_pred and y must have the same number of dimensions, or one additional dimension for"
                             " y_pred.")

        # single class evaluation
        if len(y_pred.shape) == len(y.shape):
            if class_num is not None and class_num != 1:
                raise ValueError('The y_pred and y should have the same shape, '
                                 'but the number of classes is different from 1.')
            class_num = 1
            if pos_label is None:
                pos_label = 1
            y_pred = y_pred.flatten()
            y = y.flatten()

        # multi class evaluation
        elif len(y_pred.shape) == len(y.shape) + 1:
            if pos_label is not None:
                raise ValueError('Argument `pos_label` should be `None` when running multiclass precision recall '
                                 'curve, but got {}.'.format(pos_label))
            if class_num != y_pred.shape[1]:
                raise ValueError('Argument `class_num` was set to {}, but detected {} number of classes from '
                                 'predictions.'.format(class_num, y_pred.shape[1]))
            # Move the class axis (1) last and flatten all other axes, so row i scores the same sample as
            # y.flatten()[i]. (numpy's `ndarray.transpose(0, 1)` is an axis permutation, not a PyTorch-style
            # dim swap, so the transpose/reshape/transpose idiom scrambles scores when num_samples != class_num.)
            y_pred = np.moveaxis(y_pred, 1, -1).reshape(-1, class_num)
            y = y.flatten()

        return y_pred, y, class_num, pos_label

    @rearrange_inputs
    def update(self, *inputs):
        """
        Update state with predictions and targets.

        Args:
            inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray.
                In most cases (not strictly), y_pred is a list of floating numbers in range :math:`[0, 1]`
                and the shape is :math:`(N, C)`, where :math:`N` is the number of cases and :math:`C`
                is the number of categories. y contains values of integers.

        Raises:
            ValueError: If the number of inputs is not 2, or the shapes of `y_pred` and `y` are inconsistent
                with `class_num` / `pos_label`.
        """
        if len(inputs) != 2:
            raise ValueError('ROC need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
        y_pred = self._convert_data(inputs[0])
        y = self._convert_data(inputs[1])

        y_pred, y, class_num, pos_label = self._precision_recall_curve_update(y_pred, y, self.class_num, self.pos_label)

        # The stored data is replaced (not accumulated) on each call; the resolved class_num/pos_label
        # are written back so that eval() uses the same settings.
        self.y_pred = y_pred
        self.y = y
        self.class_num = class_num
        self.pos_label = pos_label
        self._is_update = True

    def _roc_eval(self, y_pred, y, class_num, pos_label, sample_weights=None):
        """
        Compute the ROC curve from already validated predictions and targets.

        When `class_num == 1` a single curve is computed for `pos_label`. Otherwise one curve per class is
        computed (one-vs-rest) by scoring column `c` of `y_pred` against `pos_label=c`.

        Returns:
            A tuple `(fpr, tpr, thresholds)`. Each element is a numpy.ndarray for a single curve, or a list of
            numpy.ndarray (one per class) in the multiclass case.

        Raises:
            ValueError: If `y` has no negative or no positive samples for the evaluated class.
        """
        if class_num == 1:
            fps, tps, thresholds = self._binary_clf_curve(y_pred, y, sample_weights=sample_weights,
                                                          pos_label=pos_label)
            # Prepend a zero count so the curve starts at (0, 0), with a matching threshold one above the
            # first returned threshold (presumably the largest, as in the class examples).
            tps = np.squeeze(np.hstack([np.zeros(1, dtype=tps.dtype), tps]))
            fps = np.squeeze(np.hstack([np.zeros(1, dtype=fps.dtype), fps]))
            thresholds = np.hstack([thresholds[0][None] + 1, thresholds])

            # The last cumulative count is the total number of negatives / positives.
            if fps[-1] <= 0:
                raise ValueError("No negative samples in y, false positive value should be meaningless.")
            fpr = fps / fps[-1]

            if tps[-1] <= 0:
                raise ValueError("No positive samples in y, true positive value should be meaningless.")
            tpr = tps / tps[-1]

            return fpr, tpr, thresholds

        fpr, tpr, thresholds = [], [], []
        for c in range(class_num):
            preds_c = y_pred[:, c]
            res = self.roc(preds_c, y, class_num=1, pos_label=c, sample_weights=sample_weights)
            fpr.append(res[0])
            tpr.append(res[1])
            thresholds.append(res[2])

        return fpr, tpr, thresholds

    def roc(self, y_pred, y, class_num=None, pos_label=None, sample_weights=None):
        """
        Validate/reshape `y_pred` and `y`, then compute the ROC curve.

        Args:
            y_pred (numpy.ndarray): Predicted scores.
            y (numpy.ndarray): Target labels.
            class_num (int): Number of classes, or None. Default: None.
            pos_label (int): Positive label for single class evaluation, or None. Default: None.
            sample_weights: Optional sample weights forwarded to the binary curve computation. Default: None.

        Returns:
            A tuple `(fpr, tpr, thresholds)`, as described in `eval`.
        """
        y_pred, y, class_num, pos_label = self._precision_recall_curve_update(y_pred, y, class_num, pos_label)

        return self._roc_eval(y_pred, y, class_num, pos_label, sample_weights)

    def eval(self):
        """
        Computes the ROC curve.

        Returns:
            A tuple, composed of `fpr`, `tpr`, and `thresholds`.

            - **fpr** (np.array) - np.array with false positive rates. If multiclass, this is a list of such np.array,
              one for each class.
            - **tpr** (np.array) - np.array with true positive rates. If multiclass, this is a list of such np.array,
              one for each class.
            - **thresholds** (np.array) - thresholds used for computing false- and true positive rates. If multiclass,
              this is a list of such np.array, one for each class.

        Raises:
            RuntimeError: If the update method is not called first, an error will be reported.
            ValueError: If the stored targets have no positive or no negative samples for an evaluated class.

        """
        if self._is_update is False:
            raise RuntimeError('Call the update method before calling eval.')

        y_pred = np.squeeze(np.vstack(self.y_pred))
        y = np.squeeze(np.vstack(self.y))

        return self._roc_eval(y_pred, y, self.class_num, self.pos_label)