# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils for MindExplain"""

__all__ = [
    'ForwardProbe',
    'abs_max',
    'calc_auc',
    'calc_correlation',
    'deprecated_error',
    'format_tensor_to_ndarray',
    'generate_one_hot',
    'rank_pixels',
    'resize',
    'retrieve_layer_by_name',
    'retrieve_layer',
    'unify_inputs',
    'unify_targets'
]

from typing import Tuple, Union

import numpy as np
from PIL import Image

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops.operations as op

_Array = np.ndarray
_Module = nn.Cell
_Tensor = ms.Tensor


class DeprecatedError(RuntimeError):
    """Error raised when a deprecated 'mindspore.explainer' API is used."""

    def __init__(self):
        super().__init__("'mindspore.explainer' is deprecated from version 1.5 and "
                         "will be removed in a future version, use MindSpore XAI "
                         "https://gitee.com/mindspore/xai instead.")


def deprecated_error(func_or_cls):
    """
    Decorator marking a function or class as no longer usable.

    The error is raised as soon as the decorator is applied (i.e. at decoration
    time), not when the decorated object is later called.

    Args:
        func_or_cls: The decorated function or class; it is discarded.

    Raises:
        DeprecatedError: Always.
    """
    del func_or_cls
    raise DeprecatedError()


def abs_max(gradients):
    """
    Transform gradients to saliency through abs then take max along channels.

    Args:
        gradients (_Tensor): Gradients which will be transformed to saliency map.
            Channels are taken to be axis 1.

    Returns:
        _Tensor, saliency map integrated from gradients, with axis 1 kept as
        size 1 (``keep_dims=True``).
    """
    gradients = op.Abs()(gradients)
    saliency = op.ReduceMax(keep_dims=True)(gradients, axis=1)
    return saliency


def generate_one_hot(indices, depth):
    r"""
    Simple wrap of OneHot operation, the on_value and off_value are fixed to 1.0
    and 0.0.

    Args:
        indices (Tensor): Class indices to encode.
        depth (int): Number of classes (length of the one-hot axis).

    Returns:
        Tensor, float32 one-hot encoding of ``indices``.
    """
    on_value = ms.Tensor(1.0, ms.float32)
    off_value = ms.Tensor(0.0, ms.float32)
    weights = op.OneHot()(indices, depth, on_value, off_value)
    return weights


def unify_inputs(inputs) -> tuple:
    """
    Unify inputs of explainer into a tuple.

    Args:
        inputs (tuple, Tensor, numpy.ndarray): Explainer inputs. A tuple is
            returned unchanged; a Tensor is wrapped in a 1-tuple; an ndarray is
            converted to a Tensor and wrapped in a 1-tuple.

    Returns:
        tuple, the unified inputs.

    Raises:
        TypeError: If ``inputs`` is of any other type.
    """
    if isinstance(inputs, tuple):
        return inputs
    if isinstance(inputs, ms.Tensor):
        inputs = (inputs,)
    elif isinstance(inputs, np.ndarray):
        inputs = (ms.Tensor(inputs),)
    else:
        raise TypeError(
            'inputs must be one of [tuple, ms.Tensor or np.ndarray], '
            'but get {}'.format(type(inputs)))
    return inputs


def unify_targets(targets) -> ms.Tensor:
    """
    Unify target labels of explainer into a Tensor.

    Args:
        targets (int, list, Tensor): Target labels. A Tensor is returned
            unchanged; a list becomes an int32 Tensor; a single int becomes a
            1-element int32 Tensor.

    Returns:
        Tensor, the unified targets.

    Raises:
        TypeError: If ``targets`` is of any other type.
    """
    if isinstance(targets, ms.Tensor):
        return targets
    # Each supported type returns directly; previously the list branch fell
    # through into the int check's `else` and wrongly raised TypeError.
    if isinstance(targets, list):
        return ms.Tensor(targets, dtype=ms.int32)
    if isinstance(targets, int):
        return ms.Tensor([targets], dtype=ms.int32)
    raise TypeError(
        'targets must be one of [int, list or ms.Tensor], '
        'but get {}'.format(type(targets)))


def retrieve_layer_by_name(model: _Module, layer_name: str):
    """
    Retrieve the layer in the model by the given layer_name.

    Args:
        model (Cell): Model which contains the target layer.
        layer_name (str): Name of target layer, as reported by
            ``model.cells_and_names()``. An empty string returns ``model`` itself.

    Returns:
        Cell, the target layer.

    Raises:
        TypeError: If layer_name is not a str.
        ValueError: If module with given layer_name is not found in the model.
    """
    if not isinstance(layer_name, str):
        raise TypeError('layer_name should be type of str, but receive {}.'
                        .format(type(layer_name)))

    if not layer_name:
        return model

    for name, cell in model.cells_and_names():
        if name == layer_name:
            return cell

    raise ValueError(
        'Cannot match {}, please provide target layer '
        'in the given model.'.format(layer_name))


def retrieve_layer(model: _Module, target_layer: Union[str, _Module] = ''):
    """
    Retrieve the layer in the model.

    'target' can be either a layer name or a Cell object. Given the layer name,
    the method will search through the model and return the matched layer. If a
    Cell object is provided, it will check whether the given layer exists
    in the model (by identity). If target layer is not found in the model,
    ValueError will be raised.

    Args:
        model (Cell): Model which contains the target layer.
        target_layer (str, Cell): Name of target layer or the target layer instance.

    Returns:
        Cell, the target layer.

    Raises:
        TypeError: If target_layer is neither a str nor a Cell.
        ValueError: If module with given layer_name is not found in the model.
    """
    if isinstance(target_layer, str):
        target_layer = retrieve_layer_by_name(model, target_layer)
        return target_layer

    if isinstance(target_layer, _Module):
        for _, cell in model.cells_and_names():
            if target_layer is cell:
                return target_layer
        raise ValueError(
            'Model not contain cell {}, fail to probe.'.format(target_layer)
        )
    raise TypeError('layer_name must have type of str or ms.nn.Cell, '
                    'but receive {}'.format(type(target_layer)))


class ForwardProbe:
    """
    Probe to capture output of specific layer in a given model.

    Used as a context manager: on entry the target layer's ``construct`` is
    replaced by a wrapper that records its output; on exit the original
    ``construct`` is restored and the recorded output is cleared, so ``value``
    must be read inside the ``with`` block. Exceptions raised inside the block
    are not suppressed.

    Args:
        target_layer (Cell): The target layer instance. A layer name is not
            accepted here; resolve it to a Cell first (e.g. ``retrieve_layer``).
    """

    def __init__(self, target_layer: _Module):
        self._target_layer = target_layer
        self._original_construct = self._target_layer.construct
        self._intermediate_tensor = None

    @property
    def value(self):
        """Obtain the intermediate tensor (None before a forward pass or after exit)."""
        return self._intermediate_tensor

    def __enter__(self):
        self._target_layer.construct = self._new_construct
        return self

    def __exit__(self, *_):
        self._target_layer.construct = self._original_construct
        self._intermediate_tensor = None
        return False

    def _new_construct(self, *inputs):
        """Run the original construct and record its output."""
        outputs = self._original_construct(*inputs)
        self._intermediate_tensor = outputs
        return outputs


def format_tensor_to_ndarray(x: Union[ms.Tensor, np.ndarray]) -> np.ndarray:
    """
    Unify Tensor and numpy.array to numpy.array.

    Raises:
        TypeError: If ``x`` is neither a Tensor nor an ndarray.
    """
    if isinstance(x, ms.Tensor):
        x = x.asnumpy()

    if not isinstance(x, np.ndarray):
        raise TypeError('input should be one of [ms.Tensor or np.ndarray],'
                        ' but receive {}'.format(type(x)))
    return x


def calc_correlation(x: Union[ms.Tensor, np.ndarray],
                     y: Union[ms.Tensor, np.ndarray]) -> float:
    """
    Calculate Pearson correlation coefficient between two vectors.

    Returns 0.0 when either vector is entirely zero.

    Raises:
        TypeError: If an input is neither a Tensor nor an ndarray.
        ValueError: If either input has more than one dimension.
    """
    x = format_tensor_to_ndarray(x)
    y = format_tensor_to_ndarray(y)

    if len(x.shape) > 1 or len(y.shape) > 1:
        raise ValueError('"calc_correlation" only support 1-dim vectors currently, but get shape {} and {}.'
                         .format(x.shape, y.shape))

    if np.all(x == 0) or np.all(y == 0):
        # `np.float` (an alias of builtin float) was removed in NumPy 1.24.
        return 0.0
    faithfulness = np.corrcoef(x, y)[0, 1]
    return faithfulness


def calc_auc(x: _Array) -> _Array:
    """
    Calculate the Area under Curve.

    Computes ``(x.sum() - x[0] - x[-1]) / len(x)``. For a 4-D input the last
    two axes are first averaged away.

    Args:
        x (numpy.ndarray): Curve values.

    Returns:
        The computed area.
    """
    # take mean for multiple patches if the model is fully convolutional model
    if len(x.shape) == 4:
        # Reduce both trailing axes in one call; the former nested
        # `np.mean(np.mean(x, axis=2), axis=3)` indexed axis 3 of an already
        # 3-D array and raised AxisError.
        x = np.mean(x, axis=(2, 3))
    auc = (x.sum() - x[0] - x[-1]) / len(x)
    return auc


def rank_pixels(inputs: _Array, descending: bool = True) -> _Array:
    """
    Generate rank order for every pixel in a batch of 1D or 2D arrays.

    The first axis of ``inputs`` is treated as the batch axis; ranks are
    computed independently per batch item over all remaining elements. The
    rank order starts from 0 to (num_pixel-1). If descending is True, the
    rank order will be generated in a descending order, otherwise in
    ascending order.

    Args:
        inputs (numpy.ndarray): 2D or 3D array.
        descending (bool): Rank largest values first. Default: True.

    Returns:
        numpy.ndarray, integer ranks with the same shape as ``inputs``.

    Raises:
        ValueError: If ``inputs`` is not 2D or 3D.
    """
    if len(inputs.shape) < 2 or len(inputs.shape) > 3:
        raise ValueError('Only support 2D or 3D inputs currently.')

    batch_size = inputs.shape[0]
    flatten_saliency = inputs.reshape(batch_size, -1)
    factor = -1 if descending else 1
    sorted_arg = np.argsort(factor * flatten_saliency, axis=1)
    flatten_rank = np.zeros_like(sorted_arg)
    arange = np.arange(flatten_saliency.shape[1])
    # Invert each permutation: the element at sorted position k gets rank k.
    for i in range(batch_size):
        flatten_rank[i][sorted_arg[i]] = arange
    rank_map = flatten_rank.reshape(inputs.shape)
    return rank_map


def resize(inputs: _Tensor, size: Tuple[int, int], mode: str) -> _Tensor:
    """
    Resize the intermediate layer _attribution to the same size as inputs.

    Args:
        inputs (Tensor): The 4-D input tensor to be resized, channels on axis 1.
            NOTE(review): the 'bilinear' mode appears to assume a single
            channel with values in [0, 1] (it scales by 255 into uint8 and keeps
            only channel 0) -- confirm with callers.
        size (tuple[int]): The targeted size ``(height, width)`` to resize to.
        mode (str): The resize mode. Options: 'nearest_neighbor', 'bilinear'.

    Returns:
        Tensor, the resized tensor.

    Raises:
        ValueError: the resize mode is not in ['nearest_neighbor', 'bilinear'].
    """
    h, w = size
    if mode == 'nearest_neighbor':
        resize_nn = op.ResizeNearestNeighbor((h, w))
        outputs = resize_nn(inputs)

    elif mode == 'bilinear':
        # PIL operates on HWC images, so move the channel axis last.
        inputs_np = inputs.asnumpy()
        inputs_np = np.transpose(inputs_np, [0, 2, 3, 1])
        array_lst = []
        for inp in inputs_np:
            # Replicate the channel into an RGB uint8 image for PIL.
            array = (np.repeat(inp, 3, axis=2) * 255).astype(np.uint8)
            image = Image.fromarray(array)
            # PIL's resize expects (width, height) -- the reverse of `size`.
            image = image.resize((w, h), resample=Image.BILINEAR)
            array = np.asarray(image).astype(np.float32) / 255
            array_lst.append(array[:, :, 0:1])

        resized_np = np.transpose(array_lst, [0, 3, 1, 2])
        outputs = ms.Tensor(resized_np, inputs.dtype)
    else:
        raise ValueError('Unsupported resize mode {}.'.format(mode))

    return outputs