• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Base class for XAI metrics."""
16
17import copy
18import math
19from typing import Callable
20
21import numpy as np
22
23import mindspore as ms
24from mindspore import log as logger
25from mindspore.train._utils import check_value_type
26from ..._operators import Tensor
27from ..._utils import format_tensor_to_ndarray
28from ...explanation._attribution.attribution import Attribution
29
# Type alias used for annotating explainer parameters in the metric classes below.
_Explainer = Attribution
31
32
def verify_argument(inputs, arg_name):
    """Check that the parsed argument is a single-sample 4D Tensor.

    Raises:
        ValueError: If `inputs` is not 4D or holds more than one sample.
    """
    check_value_type(arg_name, inputs, Tensor)
    ndim = len(inputs.shape)
    if ndim != 4:
        raise ValueError('Argument {} must be a 4D Tensor.'.format(arg_name))
    batch_size = len(inputs)
    if batch_size > 1:
        raise ValueError('Support single data evaluation only, but got {}.'.format(batch_size))
40
41
def verify_targets(targets, num_labels):
    """Check that the parsed target is a single, in-range label index.

    Accepts a plain int, a 0D Tensor, or a length-1 1D Tensor.

    Raises:
        ValueError: If `targets` has an unsupported shape or lies outside
            the range [0, num_labels - 1].
    """
    check_value_type('targets', targets, (int, Tensor))

    if isinstance(targets, Tensor):
        ndim = len(targets.shape)
        if ndim > 1 or (ndim == 1 and len(targets) != 1):
            raise ValueError('Argument targets must be a 1D or 0D Tensor. If it is a 1D Tensor, '
                             'it should have the length = 1 as we only support single evaluation now.')
        # Collapse the tensor to a plain Python int for the range check below.
        targets = int(targets.asnumpy()[0]) if ndim == 1 else int(targets.asnumpy())
    if not 0 <= targets <= num_labels - 1:
        raise ValueError('Parsed targets exceed the label range.')
53
54
class AttributionMetric:
    """Super class of XAI metric class used in classification scenarios."""

    def __init__(self):
        # The explainer evaluated most recently; None until the first evaluation.
        self._explainer = None

    evaluate: Callable
    """
    This method evaluates the explainer on the given attribution and returns the evaluation results.
    Derived class should implement this method according to specific algorithms of the metric.
    """

    def _record_explainer(self, explainer: _Explainer):
        """Remember which explainer the current evaluation is using."""
        previous = self._explainer
        if previous is not None and previous is not explainer:
            # Warn that mixing explainers without a reset pollutes the results.
            logger.info('Provided explainer is not the same as previously evaluated one. Please reset the evaluated '
                        'results. Previous explainer: %s, current explainer: %s', previous, explainer)
        self._explainer = explainer
75
76
class LabelAgnosticMetric(AttributionMetric):
    """Super class add functions for label-agnostic metric."""

    def __init__(self):
        super().__init__()
        # Flat list of per-evaluation scores, pooled across all labels.
        self._global_results = []

    @property
    def performance(self) -> float:
        """
        Return the average evaluation result.

        Return:
            float, averaged result. If no result is aggregate in the global_results, 0.0 will be returned.
        """
        # Non-finite scores (NaN/inf) are excluded from the average.
        finite_results = [res for res in self._global_results if math.isfinite(res)]
        if not finite_results:
            return 0.
        return sum(finite_results) / len(finite_results)

    def aggregate(self, result):
        """Aggregate single evaluation result to global results."""
        if isinstance(result, float):
            self._global_results.append(result)
            return
        if isinstance(result, (ms.Tensor, np.ndarray)):
            # Flatten so multi-element results contribute one score each.
            flat = format_tensor_to_ndarray(result).reshape(-1)
            self._global_results.extend(float(res) for res in flat)
            return
        raise TypeError('result should have type of float, ms.Tensor or np.ndarray, but receive %s' % type(result))

    def get_results(self):
        """Return the global results."""
        # Shallow copy so callers cannot mutate the internal list.
        return list(self._global_results)

    def reset(self):
        """Reset global results."""
        self._global_results.clear()

    def _check_evaluate_param(self, explainer, inputs):
        """Check the evaluate parameters."""
        check_value_type('explainer', explainer, Attribution)
        self._record_explainer(explainer)
        verify_argument(inputs, 'inputs')
122
123
class LabelSensitiveMetric(AttributionMetric):
    """Super class add functions for label-sensitive metrics.

    Evaluation results are stored per label so that both per-class and
    overall performances can be reported.
    """

    def __init__(self, num_labels: int):
        super().__init__()
        LabelSensitiveMetric._verify_params(num_labels)
        self._num_labels = num_labels
        # One result list per label id, keyed 0 .. num_labels - 1.
        self._global_results = {i: [] for i in range(num_labels)}

    @property
    def num_labels(self):
        """Number of labels used in evaluation."""
        return self._num_labels

    @staticmethod
    def _verify_params(num_labels):
        """Checks whether num_labels is valid."""
        check_value_type("num_labels", num_labels, int)
        if num_labels < 1:
            raise ValueError("Argument num_labels must be parsed with a integer > 0.")

    def aggregate(self, result, targets):
        """Aggregates single result to global_results.

        Args:
            result (float, ms.Tensor, np.ndarray): Evaluation score(s).
            targets (int, Tensor): Label id(s) the score(s) belong to. When
                `result` holds multiple values, `targets` must either be a
                single int (shared by all values) or match `result` in length.

        Raises:
            ValueError: If a single float result comes with multiple targets,
                or result/targets lengths mismatch.
            TypeError: If `result` is of an unsupported type.
        """
        if isinstance(result, float):
            if isinstance(targets, int):
                self._global_results[targets].append(result)
            else:
                # reshape(-1) also accepts 0D target tensors.
                target_np = format_tensor_to_ndarray(targets).reshape(-1)
                if len(target_np) > 1:
                    raise ValueError("One result can not be aggregated to multiple targets.")
                # Bug fix: the single-tensor-target case previously validated the
                # target but silently dropped the result without recording it.
                self._global_results[int(target_np[0])].append(result)
        elif isinstance(result, (ms.Tensor, np.ndarray)):
            result_np = format_tensor_to_ndarray(result).reshape(-1)
            if isinstance(targets, int):
                # A single int target receives every value in the result.
                for res in result_np:
                    self._global_results[targets].append(float(res))
            else:
                target_np = format_tensor_to_ndarray(targets).reshape(-1)
                if len(target_np) != len(result_np):
                    raise ValueError("Length of result does not match with length of targets.")
                for tar, res in zip(target_np, result_np):
                    self._global_results[int(tar)].append(float(res))
        else:
            raise TypeError('Result should have type of float, ms.Tensor or np.ndarray, but receive %s' % type(result))

    def reset(self):
        """Resets global_result."""
        self._global_results = {i: [] for i in range(self._num_labels)}

    @property
    def class_performances(self):
        """
        Get the class performances by global result.

        Returns:
            (:class:`list`): a list of performances where each value is the average score of specific class.
        """
        results_on_labels = []
        for label_id in range(self._num_labels):
            # Average only the finite scores; a class with no finite score gets 0.
            sum_of_label, count_of_label = 0, 0
            for res in self._global_results[label_id]:
                if math.isfinite(res):
                    sum_of_label += res
                    count_of_label += 1
            results_on_labels.append(0. if count_of_label == 0 else sum_of_label / count_of_label)
        return results_on_labels

    @property
    def performance(self):
        """
        Get the performance by global result.

        Returns:
            (:class:`float`): mean performance.
        """
        result_sum, count = 0, 0
        for label_id in range(self._num_labels):
            for res in self._global_results[label_id]:
                if math.isfinite(res):
                    result_sum += res
                    count += 1
        return 0. if count == 0 else result_sum / count

    def get_results(self):
        """Global result of the metric can be return"""
        # Deep copy so callers cannot mutate the per-label lists.
        return copy.deepcopy(self._global_results)

    def _check_evaluate_param(self, explainer, inputs, targets, saliency):
        """Check the evaluate parameters."""
        check_value_type('explainer', explainer, Attribution)
        self._record_explainer(explainer)
        verify_argument(inputs, 'inputs')
        output = explainer.network(inputs)
        check_value_type("output of explainer model", output, Tensor)
        # Reuse the forward pass above instead of running the network a second time.
        output_dim = output.shape[1]
        if output_dim != self._num_labels:
            raise ValueError("The output dimension of black-box model in explainer does not match the dimension "
                             "of num_labels set in the __init__, please check explainer and num_labels again.")
        verify_targets(targets, self._num_labels)
        check_value_type('saliency', saliency, (Tensor, type(None)))
223