# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Base class for XAI metrics."""

import copy
import math
from typing import Callable

import numpy as np

import mindspore as ms
from mindspore import log as logger
from mindspore.train._utils import check_value_type
from ..._operators import Tensor
from ..._utils import format_tensor_to_ndarray
from ...explanation._attribution.attribution import Attribution

_Explainer = Attribution


def _finite_mean(values):
    """
    Average the finite entries of `values`.

    Non-finite entries (NaN, +inf, -inf) are skipped, so a failed evaluation recorded as NaN does not
    poison the aggregate. Entries are summed left to right, starting from 0.

    Args:
        values (Iterable[float]): Evaluation results to average.

    Returns:
        float, the mean of the finite entries, or 0.0 if there are none.
    """
    finite = [value for value in values if math.isfinite(value)]
    return sum(finite) / len(finite) if finite else 0.


def verify_argument(inputs, arg_name):
    """
    Verify the validity of the parsed arguments.

    Args:
        inputs (Tensor): The tensor to check.
        arg_name (str): Name of the argument, used in error messages.

    Raises:
        ValueError: If `inputs` is not 4D, or if `len(inputs) > 1` (only single-sample evaluation is supported).
        Any error raised by `check_value_type` if `inputs` is not a `Tensor`.
    """
    check_value_type(arg_name, inputs, Tensor)
    if len(inputs.shape) != 4:
        raise ValueError('Argument {} must be a 4D Tensor.'.format(arg_name))
    if len(inputs) > 1:
        raise ValueError('Support single data evaluation only, but got {}.'.format(len(inputs)))


def verify_targets(targets, num_labels):
    """
    Verify the validity of the parsed targets.

    Args:
        targets (int, Tensor): The target label. A Tensor must be 0D, or 1D of length 1.
        num_labels (int): Number of labels; valid targets lie in `[0, num_labels - 1]`.

    Raises:
        ValueError: If a Tensor target has the wrong shape, or if the target is outside the label range.
        Any error raised by `check_value_type` if `targets` is neither an int nor a `Tensor`.
    """
    check_value_type('targets', targets, (int, Tensor))

    if isinstance(targets, Tensor):
        if len(targets.shape) > 1 or (len(targets.shape) == 1 and len(targets) != 1):
            raise ValueError('Argument targets must be a 1D or 0D Tensor. If it is a 1D Tensor, '
                             'it should have the length = 1 as we only support single evaluation now.')
        # Reduce the (validated) single-element Tensor to a plain int for the range check below.
        targets = int(targets.asnumpy()[0]) if len(targets.shape) == 1 else int(targets.asnumpy())
    if targets > num_labels - 1 or targets < 0:
        raise ValueError('Parsed targets exceed the label range.')


class AttributionMetric:
    """Super class of XAI metric class used in classification scenarios."""

    def __init__(self):
        # The explainer evaluated most recently; tracked so that switching explainers can be logged.
        self._explainer = None

    evaluate: Callable
    """
    This method evaluates the explainer on the given attribution and returns the evaluation results.
    Derived class should implement this method according to specific algorithms of the metric.
    """

    def _record_explainer(self, explainer: _Explainer):
        """
        Record the explainer in current evaluation.

        If a different explainer object (by identity) than the previously recorded one is supplied, an
        info message is logged suggesting a reset. The new explainer then replaces the old one.
        """
        if self._explainer is None:
            self._explainer = explainer
        elif self._explainer is not explainer:
            logger.info('Provided explainer is not the same as previously evaluated one. Please reset the evaluated '
                        'results. Previous explainer: %s, current explainer: %s', self._explainer, explainer)
            self._explainer = explainer


class LabelAgnosticMetric(AttributionMetric):
    """Super class that adds result bookkeeping for label-agnostic metrics."""

    def __init__(self):
        super().__init__()
        # Flat list of every single-evaluation result aggregated so far.
        self._global_results = []

    @property
    def performance(self) -> float:
        """
        Return the average evaluation result.

        Non-finite results are ignored when averaging.

        Return:
            float, averaged result. If no finite result has been aggregated into the global results, 0.0 is
            returned.
        """
        return _finite_mean(self._global_results)

    def aggregate(self, result):
        """
        Aggregate single evaluation result to global results.

        Args:
            result (float, ms.Tensor, np.ndarray): A scalar result, or an array of results that is flattened
                and appended element-wise.

        Raises:
            TypeError: If `result` has an unsupported type.
        """
        if isinstance(result, float):
            self._global_results.append(result)
        elif isinstance(result, (ms.Tensor, np.ndarray)):
            result = format_tensor_to_ndarray(result)
            self._global_results.extend([float(res) for res in result.reshape(-1)])
        else:
            raise TypeError('result should have type of float, ms.Tensor or np.ndarray, but receive %s' % type(result))

    def get_results(self):
        """Return a shallow copy of the global results."""
        return self._global_results.copy()

    def reset(self):
        """Reset global results."""
        self._global_results.clear()

    def _check_evaluate_param(self, explainer, inputs):
        """Check the evaluate parameters and record the explainer."""
        check_value_type('explainer', explainer, Attribution)
        self._record_explainer(explainer)
        verify_argument(inputs, 'inputs')


class LabelSensitiveMetric(AttributionMetric):
    """Super class that adds per-label result bookkeeping for label-sensitive metrics."""

    def __init__(self, num_labels: int):
        super().__init__()
        LabelSensitiveMetric._verify_params(num_labels)
        self._num_labels = num_labels
        # Maps each label id in [0, num_labels) to the list of results aggregated for that label.
        self._global_results = {i: [] for i in range(num_labels)}

    @property
    def num_labels(self):
        """Number of labels used in evaluation."""
        return self._num_labels

    @staticmethod
    def _verify_params(num_labels):
        """Checks whether num_labels is a positive integer."""
        check_value_type("num_labels", num_labels, int)
        if num_labels < 1:
            raise ValueError("Argument num_labels must be parsed with a integer > 0.")

    def aggregate(self, result, targets):
        """
        Aggregates single result to global_results.

        Args:
            result (float, ms.Tensor, np.ndarray): A scalar result, or an array of results (flattened).
            targets (int, ms.Tensor, np.ndarray): Label(s) the result(s) belong to. An int assigns every
                result to that label; an array must pair one label with each result.

        Raises:
            ValueError: If the number of targets does not match the number of results.
            TypeError: If `result` has an unsupported type.
        """
        if isinstance(result, float):
            if isinstance(targets, int):
                self._global_results[targets].append(result)
            else:
                # reshape(-1) also accepts 0-d targets, for which len() would raise.
                target_np = format_tensor_to_ndarray(targets).reshape(-1)
                if target_np.size > 1:
                    raise ValueError("One result can not be aggreated to multiple targets.")
                if target_np.size == 0:
                    raise ValueError("Targets must contain exactly one label for a single result.")
                # Record the scalar result under its single target label (previously it was dropped).
                self._global_results[int(target_np[0])].append(result)
        elif isinstance(result, (ms.Tensor, np.ndarray)):
            result_np = format_tensor_to_ndarray(result).reshape(-1)
            if isinstance(targets, int):
                self._global_results[targets].extend(float(res) for res in result_np)
            else:
                target_np = format_tensor_to_ndarray(targets).reshape(-1)
                if len(target_np) != len(result_np):
                    raise ValueError("Length of result does not match with length of targets.")
                for tar, res in zip(target_np, result_np):
                    self._global_results[int(tar)].append(float(res))
        else:
            raise TypeError('Result should have type of float, ms.Tensor or np.ndarray, but receive %s' % type(result))

    def reset(self):
        """Resets global_result."""
        self._global_results = {i: [] for i in range(self._num_labels)}

    @property
    def class_performances(self):
        """
        Get the class performances by global result.

        Non-finite results are ignored; a label with no finite result scores 0.0.

        Returns:
            (:class:`list`): a list of performances where each value is the average score of specific class.
        """
        return [_finite_mean(self._global_results[label_id]) for label_id in range(self._num_labels)]

    @property
    def performance(self):
        """
        Get the performance by global result.

        Averages all finite results across every label (not the mean of the per-class means).

        Returns:
            (:class:`float`): mean performance, or 0.0 if there is no finite result.
        """
        return _finite_mean(res
                            for label_id in range(self._num_labels)
                            for res in self._global_results[label_id])

    def get_results(self):
        """Return a deep copy of the per-label global results."""
        return copy.deepcopy(self._global_results)

    def _check_evaluate_param(self, explainer, inputs, targets, saliency):
        """
        Check the evaluate parameters and record the explainer.

        Runs the explainer's network once on `inputs` to check that its output dimension matches
        `num_labels`.
        """
        check_value_type('explainer', explainer, Attribution)
        self._record_explainer(explainer)
        verify_argument(inputs, 'inputs')
        # Run the model once and reuse the output; a second forward pass just to read the shape is wasted work.
        output = explainer.network(inputs)
        check_value_type("output of explainer model", output, Tensor)
        # NOTE(review): assumes the network output is (batch, num_labels) — confirm for all explainers.
        output_dim = output.shape[1]
        if output_dim != self._num_labels:
            raise ValueError("The output dimension of of black-box model in explainer does not match the dimension "
                             "of num_labels set in the __init__, please check explainer and num_labels again.")
        verify_targets(targets, self._num_labels)
        check_value_type('saliency', saliency, (Tensor, type(None)))