# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Metric base class."""
from abc import ABCMeta, abstractmethod
import functools
import numpy as np
from mindspore.common.tensor import Tensor

_eval_types = {'classification', 'multilabel'}


def rearrange_inputs(func):
    """
    Decorate a method so that its positional inputs are reordered before the call.

    The order is read from the `indexes` attribute of the object the method is bound to
    (normally configured through `set_indexes`). When `indexes` is `None` or empty, the
    inputs are forwarded untouched.

    Examples:
        >>> class RearrangeInputsExample:
        ...     def __init__(self):
        ...         self._indexes = None
        ...
        ...     @property
        ...     def indexes(self):
        ...         return getattr(self, '_indexes', None)
        ...
        ...     def set_indexes(self, indexes):
        ...         self._indexes = indexes
        ...         return self
        ...
        ...     @rearrange_inputs
        ...     def update(self, *inputs):
        ...         return inputs
        >>>
        >>> rearrange_inputs_example = RearrangeInputsExample().set_indexes([1, 0])
        >>> outs = rearrange_inputs_example.update(5, 9)
        >>> print(outs)
        (9, 5)

    Args:
        func (Callable): The method to wrap; it receives the rearranged inputs.

    Returns:
        Callable, the wrapping function, carrying the metadata of `func`.
    """
    @functools.wraps(func)
    def wrapper(self, *inputs):
        order = self.indexes
        if order:
            # Pick the inputs in the configured order; an index may be skipped or repeated.
            inputs = [inputs[position] for position in order]
        return func(self, *inputs)
    return wrapper


class Metric(metaclass=ABCMeta):
    """
    Abstract base class of all metrics.

    Subclasses implement `clear`, `update` and `eval`; calling an instance runs the three
    in that order on the given inputs.

    Note:
        For examples of subclasses, please refer to the definition of class `MAE`, `Recall` etc.
    """
    def __init__(self):
        # Optional input order consumed by `rearrange_inputs`; None means "keep the given order".
        self._indexes = None

    def _convert_data(self, data):
        """
        Turn the supported input containers into a numpy array.

        Args:
            data (Union[Tensor, list, numpy.ndarray]): Input data.

        Returns:
            Ndarray, data with `np.ndarray` type.

        Raises:
            TypeError: If `data` is neither a Tensor, a list nor a numpy.ndarray.
        """
        if isinstance(data, Tensor):
            return data.asnumpy()
        if isinstance(data, list):
            return np.array(data)
        if isinstance(data, np.ndarray):
            return data
        raise TypeError('The input data type must be a tensor, list or numpy.ndarray')

    def _check_onehot_data(self, data):
        """
        Tell whether the input data is one-hot encoded along axis 1.

        Args:
            data (numpy.array): Input data.

        Returns:
            bool, True if every value equals its own square and the values along axis 1
            sum to one everywhere, False otherwise.
        """
        # Values must be idempotent under squaring (i.e. 0 or 1) and there must be an axis 1.
        if data.ndim <= 1 or not np.equal(data ** 2, data).all():
            return False
        summed_shape = (data.shape[0],) + data.shape[2:]
        return bool(np.equal(np.ones(summed_shape), data.sum(axis=1)).all())

    def _binary_clf_curve(self, preds, target, sample_weights=None, pos_label=1):
        """
        Calculate True Positives and False Positives per binary classification threshold.

        Args:
            preds (numpy.ndarray): Predicted scores. When it has more dimensions than `target`,
                only column 0 is used.
            target (numpy.ndarray): Ground-truth labels.
            sample_weights (Union[None, list, numpy.ndarray]): Optional per-sample weights.
                Default: None.
            pos_label (int): Label value counted as positive. Default: 1.

        Returns:
            Tuple (fps, tps, thresholds): the cumulative false positives, the cumulative true
            positives and the score at each distinct threshold, ordered by descending score.
        """
        weighted = sample_weights is not None
        if weighted and not isinstance(sample_weights, np.ndarray):
            sample_weights = np.array(sample_weights)

        if preds.ndim > target.ndim:
            preds = preds[:, 0]

        # Sort everything by descending score.
        order = np.argsort(-preds)
        preds = preds[order]
        target = target[order]
        weight = sample_weights[order] if weighted else 1.

        # Positions where the sorted score changes, plus the last sample, mark the thresholds.
        change_points = np.where(preds[1:] - preds[:-1])[0]
        threshold_idxs = np.pad(change_points, (0, 1), constant_values=target.shape[0] - 1)
        target = np.array(target == pos_label).astype(np.int64)
        tps = np.cumsum(target * weight, axis=0)[threshold_idxs]

        if weighted:
            fps = np.cumsum((1 - target) * weight, axis=0)[threshold_idxs]
        else:
            # Unweighted: samples seen so far minus the true positives among them.
            fps = 1 + threshold_idxs - tps

        return fps, tps, preds[threshold_idxs]

    @property
    def indexes(self):
        """Read-only view of the private `_indexes` attribute; None when it was never set.
        """
        return getattr(self, '_indexes', None)

    def set_indexes(self, indexes):
        """
        Store the order in which the inputs of `update` should be rearranged.

        The `_indexes` is a private attribute and this function is the way to modify it.
        It lets you choose which of the inputs are used as logits and labels, and in which
        order, when the method `update` of this metric is called.

        Note:
            It has been applied in subclass of Metric, eg. `Accuracy`, `BleuScore`, `ConfusionMatrix`,
            `CosineSimilarity`, `MAE`, and `MSE`.

        Args:
            indexes (List(int)): The order of logits and labels to be rearranged.

        Returns:
            :class:`Metric`, the instance itself, so that calls can be chained.

        Raises:
            ValueError: If `indexes` is not a list or any of its elements is not an int.

        Examples:
            >>> import numpy as np
            >>> from mindspore import nn, Tensor
            >>>
            >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
            >>> y = Tensor(np.array([1, 0, 1]))
            >>> y2 = Tensor(np.array([0, 0, 1]))
            >>> metric = nn.Accuracy('classification').set_indexes([0, 2])
            >>> metric.clear()
            >>> metric.update(x, y, y2)
            >>> accuracy = metric.eval()
            >>> print(accuracy)
            0.3333333333333333
        """
        is_int_list = isinstance(indexes, list) and all(isinstance(item, int) for item in indexes)
        if not is_int_list:
            raise ValueError("The indexes should be a list and all its elements should be int")
        self._indexes = indexes
        return self

    def __call__(self, *inputs):
        """
        Evaluate input data once: clear the state, update it with `inputs`, then evaluate.

        Args:
            inputs (tuple): The first item is a predict array, the second item is a target array.

        Returns:
            Float, compute result.
        """
        self.clear()
        self.update(*inputs)
        return self.eval()

    @abstractmethod
    def clear(self):
        """
        An interface describes the behavior of clearing the internal evaluation result.

        Note:
            All subclasses must override this interface.
        """
        raise NotImplementedError('Must define clear function to use this base class')

    @abstractmethod
    def eval(self):
        """
        An interface describes the behavior of computing the evaluation result.

        Note:
            All subclasses must override this interface.
        """
        raise NotImplementedError('Must define eval function to use this base class')

    @abstractmethod
    def update(self, *inputs):
        """
        An interface describes the behavior of updating the internal evaluation result.

        Note:
            All subclasses must override this interface.

        Args:
            inputs: A variable-length input argument list.
        """
        raise NotImplementedError('Must define update function to use this base class')


class EvaluationBase(Metric):
    """
    Base class of evaluation.

    Note:
        Please refer to the definition of class `Accuracy`.

    Args:
        eval_type (str): Type of evaluation must be in {'classification', 'multilabel'}.

    Raises:
        TypeError: If the input type is not classification or multilabel.
    """
    def __init__(self, eval_type):
        super().__init__()
        if eval_type not in _eval_types:
            raise TypeError('Type must be in {}, but got {}'.format(_eval_types, eval_type))
        self._type = eval_type

    def _check_shape(self, y_pred, y):
        """
        Check that the shapes of y_pred and y are compatible for the evaluation type.

        For 'classification', `y_pred` must have exactly one more dimension than `y`, and
        dropping axis 1 of `y_pred` must give the shape of `y`. Otherwise both must have
        the same number of dimensions and the same shape.

        Args:
            y_pred (Tensor): Predict array.
            y (Tensor): Target array.

        Raises:
            ValueError: If the dimensions or the shapes do not match as described above.
        """
        if self._type == 'classification':
            if y_pred.ndim != y.ndim + 1:
                raise ValueError('Classification case, dims of y_pred equal dims of y add 1, '
                                 'but got y_pred: {} dims and y: {} dims'.format(y_pred.ndim, y.ndim))
            expected_shape = (y_pred.shape[0],) + y_pred.shape[2:]
            if y.shape != expected_shape:
                raise ValueError('Classification case, y_pred shape and y shape can not match. '
                                 'got y_pred shape is {} and y shape is {}'.format(y_pred.shape, y.shape))
            return

        if y_pred.ndim != y.ndim:
            raise ValueError('{} case, dims of y_pred must be equal to dims of y, but got y_pred: {} '
                             'dims and y: {} dims.'.format(self._type, y_pred.ndim, y.ndim))
        if y_pred.shape != y.shape:
            raise ValueError('{} case, y_pred shape must be equal to y shape, but got y_pred: {} and y: {}'.
                             format(self._type, y_pred.shape, y.shape))

    def _check_value(self, y_pred, y):
        """
        Check the values of y_pred and y.

        Nothing is checked for 'classification'. For any other evaluation type every value
        of both arrays must equal its own square.

        Args:
            y_pred (Tensor): Predict array.
            y (Tensor): Target array.

        Raises:
            ValueError: If the type is not 'classification' and a value fails the check.
        """
        if self._type == 'classification':
            return
        pred_is_binary = np.equal(y_pred ** 2, y_pred).all()
        if not (pred_is_binary and np.equal(y ** 2, y).all()):
            raise ValueError('For multilabel case, input value must be 1 or 0.')

    def clear(self):
        """
        An interface describes the behavior of clearing the internal evaluation result.

        Note:
            All subclasses must override this interface.
        """
        raise NotImplementedError

    def update(self, *inputs):
        """
        An interface describes the behavior of updating the internal evaluation result.

        Note:
            All subclasses must override this interface.

        Args:
            inputs: The first item is a predicted array and the second item is a target array.
        """
        raise NotImplementedError

    def eval(self):
        """
        An interface describes the behavior of computing the evaluation result.

        Note:
            All subclasses must override this interface.
        """
        raise NotImplementedError