# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Occlusion explainer."""

from typing import Tuple

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from .ablation import Ablation
from .perturbation import PerturbationAttribution
from .replacement import Constant
from ...._utils import abs_max, deprecated_error


def _generate_patches(array, window_size: Tuple, strides: Tuple):
    """
    Generate patches from image w.r.t given window_size and strides.

    Args:
        array (numpy.ndarray): The array to slide the window over.
        window_size (tuple): Size of the window along each dimension of `array`.
        strides (tuple): Step of the window along each dimension of `array`, counted in elements.

    Returns:
        numpy.ndarray, all windows stacked along the first axis, of shape `(num_windows,) + window_size`.
    """
    # Byte strides for moving by one element inside a window.
    window_strides = array.strides
    # Byte strides for moving from one window to the next: taken from a view that keeps every `stride`-th element.
    slices = tuple(slice(None, None, stride) for stride in strides)
    indexing_strides = array[slices].strides
    # Number of window positions along each dimension; positions that would run past the edge are dropped.
    win_indices_shape = (np.array(array.shape) - np.array(window_size)) // np.array(strides) + 1

    patches_shape = tuple(win_indices_shape) + window_size
    strides_in_memory = indexing_strides + window_strides
    patches = np.lib.stride_tricks.as_strided(array, shape=patches_shape, strides=strides_in_memory, writeable=False)
    # Collapse the per-dimension window positions into a single leading axis.
    patches = patches.reshape((-1,) + window_size)
    return patches


@deprecated_error
class Occlusion(PerturbationAttribution):
    """
    Occlusion uses a sliding window to replace the pixels with a reference value (e.g. constant value), and computes
    the output difference w.r.t the original output. The output difference caused by perturbed pixels is assigned as
    feature importance to those pixels. For pixels involved in multiple sliding windows, the feature importance is the
    averaged differences from multiple sliding windows.

    For more details, please refer to the original paper via: `<https://arxiv.org/abs/1311.2901>`_.

    Args:
        network (Cell): The black-box model to be explained.
        activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For
            single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification
            tasks, `nn.Sigmoid` is usually applied. Users can also pass their own customized `activation_fn` as long
            as when combining this function with network, the final output is the probability of the input.
        perturbation_per_eval (int, optional): Number of perturbations for each inference during inferring the
            perturbed samples. Within the memory capacity, usually the larger this number is, the faster the
            explanation is obtained. Default: 32.

    Inputs:
        - **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
        - **targets** (Tensor, int) - The label of interest. It should be a 1D or 0D tensor, or an integer.
          If it is a 1D tensor, its length should be the same as `inputs`.

    Outputs:
        Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`, saliency maps.

    Raises:
        TypeError: Be raised for any argument or input type problem.
        ValueError: Be raised for any input value problem.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Example:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore.explainer.explanation import Occlusion
        >>> from mindspore import context
        >>>
        >>> context.set_context(mode=context.PYNATIVE_MODE)
        >>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py
        >>> net = LeNet5(10, num_channel=3)
        >>> # initialize Occlusion explainer with the pretrained model and activation function
        >>> activation_fn = ms.nn.Softmax()  # softmax layer is applied to transform logits to probabilities
        >>> occlusion = Occlusion(net, activation_fn=activation_fn)
        >>> input_x = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
        >>> label = ms.Tensor([1], ms.int32)
        >>> saliency = occlusion(input_x, label)
        >>> print(saliency.shape)
        (1, 1, 32, 32)
    """

    def __init__(self, network, activation_fn, perturbation_per_eval=32):
        super().__init__(network, activation_fn, perturbation_per_eval)

        self._ablation = Ablation(perturb_mode='Deletion')
        self._aggregation_fn = abs_max
        self._get_replacement = Constant(base_value=0.0)
        self._num_sample_per_dim = 32  # specify the number of perturbations each dimension.

    def __call__(self, inputs, targets):
        """Call function for 'Occlusion'."""
        self._verify_data(inputs, targets)

        inputs = inputs.asnumpy()
        # `np.int` was an alias of the builtin `int` and has been removed from NumPy (1.24+); use `int` directly.
        targets = targets.asnumpy() if isinstance(targets, ms.Tensor) else np.array([targets], int)

        batch_size = inputs.shape[0]
        window_size, strides = self._get_window_size_and_strides(inputs)

        full_network = nn.SequentialCell([self._network, self._activation_fn])

        # Output of the unperturbed inputs at the target label of each sample.
        original_outputs = full_network(ms.Tensor(inputs, ms.float32)).asnumpy()[np.arange(batch_size), targets]

        masks = Occlusion._generate_masks(inputs, window_size, strides)

        return self._perturbate(batch_size, full_network, (original_outputs, masks, inputs, targets))

    def _perturbate(self, batch_size, full_network, data):
        """
        Perform perturbations.

        Args:
            batch_size (int): Number of samples in `inputs`.
            full_network (Cell): The network followed by the activation function.
            data (tuple): `(original_outputs, masks, inputs, targets)`, where `masks` has one more axis than
                `inputs` and holds the perturbations of each sample along axis 1.

        Returns:
            Tensor, the averaged output differences after being passed through `self._aggregation_fn`.
        """
        original_outputs, masks, inputs, targets = data
        total_attribution = np.zeros_like(inputs)
        # Starts from one, so a pixel that no window covers does not divide by zero below.
        # NOTE(review): this also adds one to the window count of every covered pixel - confirm this is intended.
        weights = np.ones_like(inputs)
        num_perturbations = masks.shape[1]
        reference = self._get_replacement(inputs)

        count = 0
        while count < num_perturbations:
            # Evaluate at most `self._perturbation_per_eval` perturbations per sample in one forward pass.
            ith_masks = masks[:, count:min(count+self._perturbation_per_eval, num_perturbations)]
            actual_num_eval = ith_masks.shape[1]
            num_samples = batch_size * actual_num_eval
            occluded_inputs = self._ablation(inputs, reference, ith_masks)
            # Fold the perturbation axis into the batch axis so the network sees an ordinary batch.
            occluded_inputs = occluded_inputs.reshape((-1, *inputs.shape[1:]))
            targets_repeat = np.repeat(targets, repeats=actual_num_eval, axis=0)
            occluded_outputs = full_network(
                ms.Tensor(occluded_inputs, ms.float32)).asnumpy()[np.arange(num_samples), targets_repeat]
            original_outputs_repeat = np.repeat(original_outputs, repeats=actual_num_eval, axis=0)
            outputs_diff = original_outputs_repeat - occluded_outputs
            # Spread each output difference over the pixels of its mask, then sum over the perturbation axis.
            total_attribution += (
                outputs_diff.reshape(ith_masks.shape[:2] + (1,) * (len(masks.shape) - 2)) * ith_masks).sum(axis=1)
            weights += ith_masks.sum(axis=1)
            count += actual_num_eval
        attribution = self._aggregation_fn(ms.Tensor(total_attribution / weights, ms.float32))
        return attribution

    def _get_window_size_and_strides(self, inputs):
        """
        Return window_size and strides.

        If the spatial size of the input data is not larger than `self._num_sample_per_dim`, window_size and strides
        will be set to `(C, 3, 3)` and `(C, 1, 1)` respectively. Otherwise, the window_size and strides will be
        generated adaptively to match `self._num_sample_per_dim`.
        """
        window_size = tuple(
            [inputs.shape[1]]
            + [x // self._num_sample_per_dim if x > self._num_sample_per_dim else 3 for x in inputs.shape[2:]])
        strides = tuple(
            [inputs.shape[1]]
            + [x // self._num_sample_per_dim if x > self._num_sample_per_dim else 1 for x in inputs.shape[2:]])
        return window_size, strides

    @staticmethod
    def _generate_masks(inputs, window_size, strides):
        """
        Generate masks to perturb contiguous regions.

        Args:
            inputs (numpy.ndarray): The input data; only its shape is used.
            window_size (tuple): Size of the sliding window over a single sample.
            strides (tuple): Strides of the sliding window over a single sample.

        Returns:
            numpy.ndarray, boolean masks of shape `(N, num_perturbations) + inputs.shape[1:]`, where True marks
            the elements covered by the corresponding window.
        """
        total_dim = np.prod(inputs.shape[1:]).item()
        # Every element of a sample gets its flat index, so each patch lists the flat indices it covers.
        template = np.arange(total_dim).reshape(inputs.shape[1:])
        indices = _generate_patches(template, window_size, strides)
        num_perturbations = indices.shape[0]
        indices = indices.reshape(num_perturbations, -1)

        # `np.bool` was an alias of the builtin `bool` and was removed in NumPy 1.24; use `bool` directly.
        mask = np.zeros((num_perturbations, total_dim), dtype=bool)
        for i in range(num_perturbations):
            mask[i, indices[i]] = True
        mask = mask.reshape((num_perturbations,) + inputs.shape[1:])

        # The same set of masks is repeated for every sample in the batch.
        masks = np.tile(mask, reps=(inputs.shape[0],) + (1,) * len(mask.shape))
        return masks