• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Metric base class."""
16from abc import ABCMeta, abstractmethod
17import functools
18import numpy as np
19from mindspore.common.tensor import Tensor
20
# Evaluation modes accepted by `EvaluationBase`; any other value is rejected in its constructor.
_eval_types = set(('classification', 'multilabel'))
22
23
24def rearrange_inputs(func):
25    """
26    This decorator is used to rearrange the inputs according to its `_indexes` attributes
27    which is specified by the `set_indexes` method.
28
29    Examples:
30        >>> class RearrangeInputsExample:
31        ...     def __init__(self):
32        ...         self._indexes = None
33        ...
34        ...     @property
35        ...     def indexes(self):
36        ...         return getattr(self, '_indexes', None)
37        ...
38        ...     def set_indexes(self, indexes):
39        ...         self._indexes = indexes
40        ...         return self
41        ...
42        ...     @rearrange_inputs
43        ...     def update(self, *inputs):
44        ...         return inputs
45        >>>
46        >>> rearrange_inputs_example = RearrangeInputsExample().set_indexes([1, 0])
47        >>> outs = rearrange_inputs_example.update(5, 9)
48        >>> print(outs)
49        (9, 5)
50
51    Args:
52        func (Callable): A candidate function to be wrapped whose input will be rearranged.
53
54    Returns:
55        Callable, used to exchange metadata between functions.
56    """
57    @functools.wraps(func)
58    def wrapper(self, *inputs):
59        indexes = self.indexes
60        inputs = inputs if not indexes else [inputs[i] for i in indexes]
61        return func(self, *inputs)
62    return wrapper
63
64
65class Metric(metaclass=ABCMeta):
66    """
67    Base class of metric.
68
69    Note:
70        For examples of subclasses, please refer to the definition of class `MAE`, `Recall` etc.
71    """
72    def __init__(self):
73        self._indexes = None
74
75    def _convert_data(self, data):
76        """
77        Convert data type to numpy array.
78
79        Args:
80            data (Object): Input data.
81
82        Returns:
83            Ndarray, data with `np.ndarray` type.
84        """
85        if isinstance(data, Tensor):
86            data = data.asnumpy()
87        elif isinstance(data, list):
88            data = np.array(data)
89        elif isinstance(data, np.ndarray):
90            pass
91        else:
92            raise TypeError('The input data type must be a tensor, list or numpy.ndarray')
93        return data
94
95    def _check_onehot_data(self, data):
96        """
97        Whether input data is one-hot encoding.
98
99        Args:
100            data (numpy.array): Input data.
101
102        Returns:
103            bool, return true, if input data is one-hot encoding.
104        """
105        if data.ndim > 1 and np.equal(data ** 2, data).all():
106            shp = (data.shape[0],) + data.shape[2:]
107            if np.equal(np.ones(shp), data.sum(axis=1)).all():
108                return True
109        return False
110
111    def _binary_clf_curve(self, preds, target, sample_weights=None, pos_label=1):
112        """Calculate True Positives and False Positives per binary classification threshold."""
113        if sample_weights is not None and not isinstance(sample_weights, np.ndarray):
114            sample_weights = np.array(sample_weights)
115
116        if preds.ndim > target.ndim:
117            preds = preds[:, 0]
118        desc_score_indices = np.argsort(-preds)
119
120        preds = preds[desc_score_indices]
121        target = target[desc_score_indices]
122
123        if sample_weights is not None:
124            weight = sample_weights[desc_score_indices]
125        else:
126            weight = 1.
127
128        distinct_value_indices = np.where(preds[1:] - preds[:-1])[0]
129        threshold_idxs = np.pad(distinct_value_indices, (0, 1), constant_values=target.shape[0] - 1)
130        target = np.array(target == pos_label).astype(np.int64)
131        tps = np.cumsum(target * weight, axis=0)[threshold_idxs]
132
133        if sample_weights is not None:
134            fps = np.cumsum((1 - target) * weight, axis=0)[threshold_idxs]
135        else:
136            fps = 1 + threshold_idxs - tps
137
138        return fps, tps, preds[threshold_idxs]
139
140    @property
141    def indexes(self):
142        """The `_indexes` is a private attribute, and you can retrieve it by `self.indexes`.
143        """
144        return getattr(self, '_indexes', None)
145
146    def set_indexes(self, indexes):
147        """
148        The `_indexes` is a private attribute and you can modify it by this function.
149        This allows you to determine the order of logits and labels to be calculated in the
150        inputs, specially when you call the method `update` within this metrics.
151
152        Note:
153            It has been applied in subclass of Metric, eg. `Accuracy`, `BleuScore`, `ConfusionMatrix`,
154            `CosineSimilarity`, `MAE`, and `MSE`.
155
156        Args:
157            indexes (List(int)): The order of logits and labels to be rearranged.
158
159        Outputs:
160            :class:`Metric`, its original Class instance.
161
162        Examples:
163            >>> import numpy as np
164            >>> from mindspore import nn, Tensor
165            >>>
166            >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
167            >>> y = Tensor(np.array([1, 0, 1]))
168            >>> y2 = Tensor(np.array([0, 0, 1]))
169            >>> metric = nn.Accuracy('classification').set_indexes([0, 2])
170            >>> metric.clear()
171            >>> metric.update(x, y, y2)
172            >>> accuracy = metric.eval()
173            >>> print(accuracy)
174            0.3333333333333333
175        """
176        if not isinstance(indexes, list) or not all(isinstance(i, int) for i in indexes):
177            raise ValueError("The indexes should be a list and all its elements should be int")
178        self._indexes = indexes
179        return self
180
181    def __call__(self, *inputs):
182        """
183        Evaluate input data once.
184
185        Args:
186            inputs (tuple): The first item is a predict array, the second item is a target array.
187
188        Returns:
189            Float, compute result.
190        """
191        self.clear()
192        self.update(*inputs)
193        return self.eval()
194
195    @abstractmethod
196    def clear(self):
197        """
198        An interface describes the behavior of clearing the internal evaluation result.
199
200        Note:
201            All subclasses must override this interface.
202        """
203        raise NotImplementedError('Must define clear function to use this base class')
204
205    @abstractmethod
206    def eval(self):
207        """
208        An interface describes the behavior of computing the evaluation result.
209
210        Note:
211            All subclasses must override this interface.
212        """
213        raise NotImplementedError('Must define eval function to use this base class')
214
215    @abstractmethod
216    def update(self, *inputs):
217        """
218        An interface describes the behavior of updating the internal evaluation result.
219
220        Note:
221            All subclasses must override this interface.
222
223        Args:
224            inputs: A variable-length input argument list.
225        """
226        raise NotImplementedError('Must define update function to use this base class')
227
228
229class EvaluationBase(Metric):
230    """
231    Base class of evaluation.
232
233    Note:
234        Please refer to the definition of class `Accuracy`.
235
236    Args:
237        eval_type (str): Type of evaluation must be in {'classification', 'multilabel'}.
238
239    Raises:
240        TypeError: If the input type is not classification or multilabel.
241    """
242    def __init__(self, eval_type):
243        super(EvaluationBase, self).__init__()
244        if eval_type not in _eval_types:
245            raise TypeError('Type must be in {}, but got {}'.format(_eval_types, eval_type))
246        self._type = eval_type
247
248    def _check_shape(self, y_pred, y):
249        """
250        Checks the shapes of y_pred and y.
251
252        Args:
253            y_pred (Tensor): Predict array.
254            y (Tensor): Target array.
255        """
256        if self._type == 'classification':
257            if y_pred.ndim != y.ndim + 1:
258                raise ValueError('Classification case, dims of y_pred equal dims of y add 1, '
259                                 'but got y_pred: {} dims and y: {} dims'.format(y_pred.ndim, y.ndim))
260            if y.shape != (y_pred.shape[0],) + y_pred.shape[2:]:
261                raise ValueError('Classification case, y_pred shape and y shape can not match. '
262                                 'got y_pred shape is {} and y shape is {}'.format(y_pred.shape, y.shape))
263        else:
264            if y_pred.ndim != y.ndim:
265                raise ValueError('{} case, dims of y_pred must be equal to dims of y, but got y_pred: {} '
266                                 'dims and y: {} dims.'.format(self._type, y_pred.ndim, y.ndim))
267            if y_pred.shape != y.shape:
268                raise ValueError('{} case, y_pred shape must be equal to y shape, but got y_pred: {} and y: {}'.
269                                 format(self._type, y_pred.shape, y.shape))
270
271    def _check_value(self, y_pred, y):
272        """
273        Checks the values of y_pred and y.
274
275        Args:
276            y_pred (Tensor): Predict array.
277            y (Tensor): Target array.
278        """
279        if self._type != 'classification' and not (np.equal(y_pred ** 2, y_pred).all() and np.equal(y ** 2, y).all()):
280            raise ValueError('For multilabel case, input value must be 1 or 0.')
281
282    def clear(self):
283        """
284        A interface describes the behavior of clearing the internal evaluation result.
285
286        Note:
287            All subclasses must override this interface.
288        """
289        raise NotImplementedError
290
291    def update(self, *inputs):
292        """
293        A interface describes the behavior of updating the internal evaluation result.
294
295        Note:
296            All subclasses must override this interface.
297
298        Args:
299            inputs: The first item is a predicted array and the second item is a target array.
300        """
301        raise NotImplementedError
302
303    def eval(self):
304        """
305        A interface describes the behavior of computing the evaluation result.
306
307        Note:
308            All subclasses must override this interface.
309        """
310        raise NotImplementedError
311