# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""auc: area under a curve computed with the trapezoidal rule."""
import numpy as np


def auc(x, y, reorder=False):
    """
    Computes the AUC (Area Under the Curve) using the trapezoidal rule. This is a general function, given points on a
    curve. To compute the area under a ROC curve, pass the false positive rates as `x` and the true positive rates
    as `y`.

    Args:
        x (np.ndarray): x coordinates of the curve points, e.g. the false positive rates (fpr) of a ROC curve.
            Must be 1-D with shape :math:`(N)`, or a column vector with shape :math:`(N, 1)`.
        y (np.ndarray): y coordinates of the curve points, e.g. the true positive rates (tpr) of a ROC curve.
            Must have the same number of points as `x` and the same accepted shapes.
        reorder (bool): If True, the points are first sorted by `x` (ties broken by `y`) before integrating.
            If False, `x` must already be monotonic: either non-decreasing, or non-increasing, in which case the
            sign of the integral is flipped so the area is reported as for ascending `x`. Default: False.

    Returns:
        area (float): Compute result.

    Raises:
        TypeError: If `x` or `y` is not an np.ndarray, or is a 0-d array.
        ValueError: If `x` and `y` have different numbers of points, are neither 1-D nor column vectors,
            contain fewer than 2 points, or if `reorder` is False and `x` is not monotonic.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import nn
        >>>
        >>> y_pred = np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]])
        >>> y = np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]])
        >>> metric = nn.ROC(pos_label=2)
        >>> metric.clear()
        >>> metric.update(y_pred, y)
        >>> fpr, tpr, thre = metric.eval()
        >>> output = nn.auc(fpr, tpr)
        >>> print(output)
        0.5357142857142857
    """
    if not isinstance(x, np.ndarray) or not isinstance(y, np.ndarray):
        raise TypeError('The inputs must be np.ndarray, but got {}, {}'.format(type(x), type(y)))
    _check_consistent_length(x, y)
    x = _column_or_1d(x)
    y = _column_or_1d(y)

    if x.shape[0] < 2:
        raise ValueError('At least 2 points are needed to compute the AUC, but x.shape = {}.'.format(x.shape))

    direction = 1
    if reorder:
        # np.lexsort sorts by its LAST key first: primary key x, ties broken by y.
        order = np.lexsort((y, x))
        x, y = x[order], y[order]
    else:
        dx = np.diff(x)
        if np.any(dx < 0):
            if np.all(dx <= 0):
                # Monotonically non-increasing x: the integral comes out negative, so flip its sign.
                direction = -1
            else:
                raise ValueError("Reordering is not turned on, and the x array is not increasing:{}".format(x))

    # NumPy 2.0 renamed np.trapz to np.trapezoid (np.trapz is deprecated there); use whichever this NumPy provides.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    area = direction * trapezoid(y, x)
    if isinstance(area, np.memmap):
        # If the result comes back as an np.memmap (presumably when inputs are memory-mapped),
        # convert it to a plain scalar of the same dtype.
        area = area.dtype.type(area)
    return area


def _column_or_1d(y):
    """
    Ravel a 1-D array or an (N, 1) column array into a 1-D array.

    Raises:
        ValueError: If `y` has any other shape.
    """
    shape = np.shape(y)
    if len(shape) == 1 or (len(shape) == 2 and shape[1] == 1):
        return np.ravel(y)

    raise ValueError("Bad input shape {0}.".format(shape))


def _num_samples(x):
    """
    Return the number of samples (length along the first axis) in array-like `x`.

    Objects exposing `shape` report `shape[0]`; plain sequences such as lists and tuples report `len(x)`.

    Raises:
        TypeError: If `x` looks like an estimator (has a callable `fit`), is not array-like, or is 0-dimensional.
    """
    if hasattr(x, 'fit') and callable(x.fit):
        raise TypeError('Expected sequence or array-like, got estimator {}.'.format(x))
    if not hasattr(x, '__len__') and not hasattr(x, 'shape'):
        # Objects that only implement the array protocol are materialized so their shape can be read.
        if hasattr(x, '__array__'):
            x = np.asarray(x)
        else:
            raise TypeError("Expected sequence or array-like, got {}.".format(type(x)))
    if hasattr(x, 'shape'):
        if len(x.shape) == 0:
            raise TypeError("Singleton array {} cannot be considered as a valid collection.".format(x))
        return x.shape[0]

    # Plain sequences (list, tuple, ...) have neither `shape` nor `size`; their length is the sample count.
    return len(x)


def _check_consistent_length(*arrays):
    r"""
    Check that all given arrays have the same number of samples (first-dimension length).

    Args:
        *arrays: Array-like objects to compare. `None` entries are ignored.

    Raises:
        ValueError: If the non-None inputs do not all have the same number of samples.
        TypeError: Propagated from `_num_samples` for inputs that are not array-like.
    """
    lengths = [_num_samples(array) for array in arrays if array is not None]
    uniques = np.unique(lengths)
    if len(uniques) > 1:
        raise ValueError("Found input variables with inconsistent numbers of samples: {}."
                         .format([int(length) for length in lengths]))