• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""ROC"""
16import numpy as np
17from mindspore._checkparam import Validator as validator
18from .metric import Metric, rearrange_inputs
19
20
class ROC(Metric):
    """
    Calculates the ROC curve. It is suitable for solving binary classification and multi classification problems.
    In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.

    Args:
        class_num (int): Integer with the number of classes. For the problem of binary classification, it is not
            necessary to provide this argument. Default: None.
        pos_label (int): Determine the integer of positive class. For binary problems, `None` is translated
            to 1. For multiclass problems, this argument should not be set, as it is iteratively changed in the
            range [0,num_classes-1]. Default: None.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import nn, Tensor
        >>>
        >>> # 1) binary classification example
        >>> x = Tensor(np.array([3, 1, 4, 2]))
        >>> y = Tensor(np.array([0, 1, 2, 3]))
        >>> metric = nn.ROC(pos_label=2)
        >>> metric.clear()
        >>> metric.update(x, y)
        >>> fpr, tpr, thresholds = metric.eval()
        >>> print(fpr)
        [0. 0. 0.33333333 0.6666667 1.]
        >>> print(tpr)
        [0. 1. 1. 1. 1.]
        >>> print(thresholds)
        [5 4 3 2 1]
        >>>
        >>> # 2) multiclass classification example
        >>> x = Tensor(np.array([[0.28, 0.55, 0.15, 0.05], [0.10, 0.20, 0.05, 0.05], [0.20, 0.05, 0.15, 0.05],
        ...                     [0.05, 0.05, 0.05, 0.75]]))
        >>> y = Tensor(np.array([0, 1, 2, 3]))
        >>> metric = nn.ROC(class_num=4)
        >>> metric.clear()
        >>> metric.update(x, y)
        >>> fpr, tpr, thresholds = metric.eval()
        >>> print(fpr)
        [array([0., 0., 0.33333333, 0.66666667, 1.]), array([0., 0.33333333, 0.33333333, 1.]),
        array([0., 0.33333333, 1.]), array([0., 0., 1.])]
        >>> print(tpr)
        [array([0., 1., 1., 1., 1.]), array([0., 0., 1., 1.]), array([0., 1., 1.]), array([0., 1., 1.])]
        >>> print(thresholds)
        [array([1.28, 0.28, 0.2, 0.1, 0.05]), array([1.55, 0.55, 0.2, 0.05]), array([1.15, 0.15, 0.05]),
        array([1.75, 0.75, 0.05])]
    """
    def __init__(self, class_num=None, pos_label=None):
        super().__init__()
        # `None` is accepted as-is; any other value must be an int.
        self.class_num = class_num if class_num is None else validator.check_value_type("class_num", class_num, [int])
        self.pos_label = pos_label if pos_label is None else validator.check_value_type("pos_label", pos_label, [int])
        self.clear()

    def clear(self):
        """Clear the internal evaluation result."""
        self.y_pred = 0
        self.y = 0
        self.sample_weights = None
        self._is_update = False

    def _precision_recall_curve_update(self, y_pred, y, class_num, pos_label):
        """
        Validate the inputs and bring them into the layout used by `_roc_eval`.

        Args:
            y_pred (numpy.ndarray): Predictions. Either the same number of dimensions as `y` (binary case) or
                one more, with the classes on axis 1 (multiclass case).
            y (numpy.ndarray): Targets.
            class_num (Union[int, None]): Number of classes. Must equal ``y_pred.shape[1]`` in the multiclass case.
            pos_label (Union[int, None]): Positive class label. Must be `None` in the multiclass case.

        Returns:
            A tuple ``(y_pred, y, class_num, pos_label)``. In the binary case both arrays are flattened,
            `class_num` is 1 and `pos_label` defaults to 1. In the multiclass case `y_pred` has shape
            ``(-1, class_num)`` and `y` is flattened.

        Raises:
            ValueError: If the dimensions of `y_pred` and `y` are incompatible, or if `class_num` / `pos_label`
                do not match the detected case.
        """
        pred_ndim = len(y_pred.shape)
        target_ndim = len(y.shape)
        if not (pred_ndim == target_ndim or pred_ndim == target_ndim + 1):
            raise ValueError("y_pred and y must have the same number of dimensions, or one additional dimension for"
                             " y_pred.")

        # single class evaluation
        if pred_ndim == target_ndim:
            if class_num is not None and class_num != 1:
                raise ValueError('The y_pred and y should have the same shape, '
                                 'but the number of classes is different from 1.')
            class_num = 1
            if pos_label is None:
                pos_label = 1
            y_pred = y_pred.flatten()
            y = y.flatten()

        # multi class evaluation
        else:
            if pos_label is not None:
                raise ValueError('Argument `pos_label` should be `None` when running multiclass precision recall '
                                 'curve, but got {}.'.format(pos_label))
            if class_num != y_pred.shape[1]:
                raise ValueError('Argument `class_num` was set to {}, but detected {} number of classes from '
                                 'predictions.'.format(class_num, y_pred.shape[1]))
            # Move the class axis (axis 1) to the end and collapse every other axis, so that row i of y_pred
            # lines up with element i of the flattened y. NumPy's ``ndarray.transpose(0, 1)`` is the identity
            # permutation on 2-D input (unlike an axis swap), so chaining it with ``reshape(class_num, -1)``
            # would scramble the scores whenever the number of samples differs from the number of classes.
            y_pred = np.moveaxis(y_pred, 1, -1).reshape(-1, class_num)
            y = y.flatten()

        return y_pred, y, class_num, pos_label

    @rearrange_inputs
    def update(self, *inputs):
        """
        Update state with predictions and targets.

        Each call replaces the previously stored predictions and targets; the validated `class_num` and
        `pos_label` are written back to the instance as well.

        Args:
            inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray.
                In most cases (not strictly), y_pred is a list of floating numbers in range :math:`[0, 1]`
                and the shape is :math:`(N, C)`, where :math:`N` is the number of cases and :math:`C`
                is the number of categories. y contains values of integers.

        Raises:
            ValueError: If the number of inputs is not 2, or if the inputs fail the shape / argument checks.
        """
        if len(inputs) != 2:
            raise ValueError('ROC need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
        # `_convert_data` is inherited; presumably it converts Tensor/list input to numpy.ndarray.
        y_pred = self._convert_data(inputs[0])
        y = self._convert_data(inputs[1])

        y_pred, y, class_num, pos_label = self._precision_recall_curve_update(y_pred, y, self.class_num, self.pos_label)

        self.y_pred = y_pred
        self.y = y
        self.class_num = class_num
        self.pos_label = pos_label
        self._is_update = True

    def _roc_eval(self, y_pred, y, class_num, pos_label, sample_weights=None):
        """
        Computes the ROC curve.

        Args:
            y_pred (numpy.ndarray): Predictions; 1-D when `class_num` is 1, otherwise indexed as
                ``y_pred[:, c]`` for each class `c`.
            y (numpy.ndarray): Targets.
            class_num (int): Number of classes; 1 selects the binary computation.
            pos_label (int): Label treated as positive in the binary computation.
            sample_weights: Forwarded unchanged to the curve computation. Default: None.

        Returns:
            A tuple ``(fpr, tpr, thresholds)``; each entry is an np.array in the binary case and a list with
            one np.array per class otherwise.

        Raises:
            ValueError: If the targets contain no negative or no positive samples.
        """
        if class_num == 1:
            # `_binary_clf_curve` is inherited; presumably it returns cumulative false/true positive counts
            # together with the thresholds at which they were taken.
            fps, tps, thresholds = self._binary_clf_curve(y_pred, y, sample_weights=sample_weights,
                                                          pos_label=pos_label)
            # Prepend a starting point with zero counts and a threshold one above the first threshold.
            tps = np.squeeze(np.hstack([np.zeros(1, dtype=tps.dtype), tps]))
            fps = np.squeeze(np.hstack([np.zeros(1, dtype=fps.dtype), fps]))
            thresholds = np.hstack([thresholds[0][None] + 1, thresholds])

            # The last entry is used as the total count for normalisation, so it must be positive.
            if fps[-1] <= 0:
                raise ValueError("No negative samples in y, false positive value should be meaningless.")
            fpr = fps / fps[-1]

            if tps[-1] <= 0:
                raise ValueError("No positive samples in y, true positive value should be meaningless.")
            tpr = tps / tps[-1]

            return fpr, tpr, thresholds

        # One-vs-the-rest: evaluate every class column as a binary problem with that class as positive.
        fpr, tpr, thresholds = [], [], []
        for c in range(class_num):
            preds_c = y_pred[:, c]
            res = self.roc(preds_c, y, class_num=1, pos_label=c, sample_weights=sample_weights)
            fpr.append(res[0])
            tpr.append(res[1])
            thresholds.append(res[2])

        return fpr, tpr, thresholds

    def roc(self, y_pred, y, class_num=None, pos_label=None, sample_weights=None):
        """
        Validate the given predictions and targets, then compute their ROC curve.

        Args:
            y_pred (numpy.ndarray): Predictions.
            y (numpy.ndarray): Targets.
            class_num (Union[int, None]): Number of classes. Default: None.
            pos_label (Union[int, None]): Positive class label. Default: None.
            sample_weights: Forwarded unchanged to the curve computation. Default: None.

        Returns:
            A tuple ``(fpr, tpr, thresholds)``.
        """
        y_pred, y, class_num, pos_label = self._precision_recall_curve_update(y_pred, y, class_num, pos_label)

        return self._roc_eval(y_pred, y, class_num, pos_label, sample_weights)

    def eval(self):
        """
        Computes the ROC curve.

        Returns:
            A tuple, composed of `fpr`, `tpr`, and `thresholds`.

            - **fpr** (np.array) - np.array with false positive rates. If multiclass, this is a list of such np.array,
              one for each class.
            - **tpr** (np.array) - np.array with true positive rates. If multiclass, this is a list of such np.array,
              one for each class.
            - **thresholds** (np.array) - thresholds used for computing false- and true positive rates.

        Raises:
            RuntimeError: If the update method is not called first, an error will be reported.
            ValueError: If the stored targets contain no negative or no positive samples.

        """
        if not self._is_update:
            raise RuntimeError('Call the update method before calling eval.')

        y_pred = np.squeeze(np.vstack(self.y_pred))
        y = np.squeeze(np.vstack(self.y))

        return self._roc_eval(y_pred, y, self.class_num, self.pos_label)
197