• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""auc"""
16import numpy as np
17
18
def auc(x, y, reorder=False):
    """
    Computes the AUC (Area Under the Curve) using the trapezoidal rule. This is a general function that works on
    any set of points on a curve; a typical use is computing the area under the ROC curve.

    Args:
        x (np.ndarray): x coordinates of the curve points, e.g. the false positive rates (fpr) of an ROC curve.
            Must be 1-D, or 2-D with a single column. The shape :math:`(N)`.
        y (np.ndarray): y coordinates of the curve points, e.g. the true positive rates (tpr) of an ROC curve.
            Must contain the same number of samples as `x`. The shape :math:`(N)`.
        reorder (bool): If True, the points are sorted by `x` (ties broken by `y`) before integrating.
            If False, `x` must already be monotonically non-decreasing or non-increasing. Default: False.

    Returns:
        area (float): Compute result.

    Raises:
        TypeError: If `x` or `y` is not a np.ndarray, or is a 0-dimensional array.
        ValueError: If `x` and `y` have different numbers of samples, are neither 1-D nor single-column 2-D,
            contain fewer than 2 points, or if `reorder` is False and `x` is not monotonic.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import nn
        >>>
        >>> y_pred = np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]])
        >>> y = np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]])
        >>> metric = nn.ROC(pos_label=2)
        >>> metric.clear()
        >>> metric.update(y_pred, y)
        >>> fpr, tpr, thre = metric.eval()
        >>> output = auc(fpr, tpr)
        >>> print(output)
        0.5357142857142857
    """
    if not isinstance(x, np.ndarray) or not isinstance(y, np.ndarray):
        raise TypeError('The inputs must be np.ndarray, but got {}, {}'.format(type(x), type(y)))
    _check_consistent_length(x, y)
    x = _column_or_1d(x)
    y = _column_or_1d(y)

    if x.shape[0] < 2:
        raise ValueError('At least 2 points are needed to compute the AUC, but x.shape = {}.'.format(x.shape))

    direction = 1
    if reorder:
        # lexsort uses the last key as the primary one: sort by x, break ties by y.
        order = np.lexsort((y, x))
        x, y = x[order], y[order]
    else:
        dx = np.diff(x)
        if np.any(dx < 0):
            # A monotonically non-increasing x is accepted; flip the sign so the area stays positive.
            if np.all(dx <= 0):
                direction = -1
            else:
                raise ValueError("Reordering is not turned on, and the x array is not increasing:{}".format(x))

    # ``np.trapz`` is deprecated since NumPy 2.0 in favour of ``np.trapezoid``; use whichever is available.
    trapezoid = getattr(np, "trapezoid", None)
    if trapezoid is None:
        trapezoid = np.trapz
    area = direction * trapezoid(y, x)
    # Presumably guards memory-mapped inputs, where the reduction can yield a 0-d np.memmap rather than a
    # scalar; convert it to a plain scalar of the same dtype.
    if isinstance(area, np.memmap):
        area = area.dtype.type(area)
    return area
77
78
79def _column_or_1d(y):
80    """
81     Ravel column or 1D numpy array, otherwise raise a ValueError.
82    """
83    shape = np.shape(y)
84    if len(shape) == 1 or (len(shape) == 2 and shape[1] == 1):
85        return np.ravel(y)
86
87    raise ValueError("Bad input shape {0}.".format(shape))
88
89
90def _num_samples(x):
91    """Return the number of samples in array-like x."""
92    if hasattr(x, 'fit') and callable(x.fit):
93        raise TypeError('Expected sequence or array-like, got estimator {}.'.format(x))
94    if not hasattr(x, '__len__') and not hasattr(x, 'shape'):
95        if hasattr(x, '__array__'):
96            x = np.asarray(x)
97        else:
98            raise TypeError("Expected sequence or array-like, got {}." .format(type(x)))
99    if hasattr(x, 'shape'):
100        if x.ndim == 0:
101            raise TypeError("Singleton array {} cannot be considered as a valid collection.".format(x))
102        res = x.shape[0]
103    else:
104        res = x.size
105
106    return res
107
108
109def _check_consistent_length(*arrays):
110    r"""
111    Check that all arrays have consistent first dimensions. Check whether all objects in arrays have the same shape
112    or length.
113
114    Args:
115        - **(*arrays)** - (Union[tuple, list]): list or tuple of input objects. Objects that will be checked for
116            consistent length.
117    """
118
119    lengths = [_num_samples(array) for array in arrays if array is not None]
120    uniques = np.unique(lengths)
121    if len(uniques) > 1:
122        raise ValueError("Found input variables with inconsistent numbers of samples: {}."
123                         .format([int(length) for length in lengths]))
124