• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Occlusion explainer."""
16
17from typing import Tuple
18
19import numpy as np
20
21import mindspore as ms
22import mindspore.nn as nn
23from .ablation import Ablation
24from .perturbation import PerturbationAttribution
25from .replacement import Constant
26from ...._utils import abs_max, deprecated_error
27
28
def _generate_patches(array, window_size: Tuple, strides: Tuple):
    """Generate patches from `array` with a sliding window.

    Every complete window of shape `window_size` that fits inside `array`, stepping by `strides` along each axis,
    is extracted. Patches are ordered row-major (C order) over the window positions.

    Args:
        array (numpy.ndarray): Array to cut into patches.
        window_size (Tuple): Window extent along every axis of `array`. Any sequence of ints is accepted.
        strides (Tuple): Step, in elements, between consecutive windows along every axis. Any sequence of
            positive ints is accepted.

    Returns:
        numpy.ndarray, patches of shape `(num_patches,) + tuple(window_size)`. When the window does not fit along
        some axis, `num_patches` is 0.

    Raises:
        ValueError: If `window_size` or `strides` does not have exactly one entry per axis of `array`, or if any
            stride is not positive.
    """
    window_size = tuple(int(size) for size in window_size)
    strides = tuple(int(step) for step in strides)
    if len(window_size) != array.ndim or len(strides) != array.ndim:
        raise ValueError(f"window_size {window_size} and strides {strides} must both have one entry per axis "
                         f"of the array (ndim={array.ndim}).")
    if any(step <= 0 for step in strides):
        raise ValueError(f"strides must be positive, but got {strides}.")

    # Byte strides for moving *within* a window are the array's own strides.
    element_strides = array.strides
    # Byte strides for moving *between* windows are the array's strides scaled by each step. NumPy computes
    # them when the array is sliced with those steps.
    step_slices = tuple(slice(None, None, step) for step in strides)
    window_step_strides = array[step_slices].strides
    # Count complete windows per axis. Clamp at 0: a window larger than its axis yields no patches instead of
    # an invalid negative dimension.
    num_windows = np.maximum(
        (np.array(array.shape) - np.array(window_size)) // np.array(strides) + 1, 0)

    patches = np.lib.stride_tricks.as_strided(
        array,
        shape=tuple(int(count) for count in num_windows) + window_size,
        strides=window_step_strides + element_strides,
        writeable=False)
    # Collapse the window-position axes into one. Because windows may overlap, this may copy the data.
    return patches.reshape((-1,) + window_size)
41
42
@deprecated_error
class Occlusion(PerturbationAttribution):
    """
    Occlusion uses a sliding window to replace the pixels with a reference value (e.g. constant value), and computes
    the output difference w.r.t the original output. The output difference caused by perturbed pixels is assigned as
    feature importance to those pixels. For pixels covered by multiple sliding windows, the differences from all
    covering windows are accumulated and divided by the number of covering windows plus one (the extra one keeps
    pixels that no window covers at zero instead of dividing by zero).

    For more details, please refer to the original paper via: `<https://arxiv.org/abs/1311.2901>`_.

    Args:
        network (Cell): The black-box model to be explained.
        activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For
            single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification
            tasks, `nn.Sigmoid` is usually applied. Users can also pass their own customized `activation_fn` as long
            as when combining this function with network, the final output is the probability of the input.
        perturbation_per_eval (int, optional): Number of perturbations for each inference during inferring the
            perturbed samples. Within the memory capacity, usually the larger this number is, the faster the
            explanation is obtained. Default: 32.

    Inputs:
        - **inputs** (Tensor) - The input data to be explained, a 4D tensor of shape :math:`(N, C, H, W)`.
        - **targets** (Tensor, int) - The label of interest. It should be a 1D or 0D tensor, or an integer.
          If it is a 1D tensor, its length should be the same as `inputs`. A 0D tensor or an integer is used as
          the label of interest for every sample.

    Outputs:
        Tensor, a 4D tensor of shape :math:`(N, 1, H, W)`, saliency maps.

    Raises:
        TypeError: Raised for any argument or input type problem.
        ValueError: Raised for any input value problem.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Example:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore.explainer.explanation import Occlusion
        >>> from mindspore import context
        >>>
        >>> context.set_context(mode=context.PYNATIVE_MODE)
        >>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py
        >>> net = LeNet5(10, num_channel=3)
        >>> # initialize Occlusion explainer with the pretrained model and activation function
        >>> activation_fn = ms.nn.Softmax() # softmax layer is applied to transform logits to probabilities
        >>> occlusion = Occlusion(net, activation_fn=activation_fn)
        >>> input_x = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
        >>> label = ms.Tensor([1], ms.int32)
        >>> saliency = occlusion(input_x, label)
        >>> print(saliency.shape)
        (1, 1, 32, 32)
    """

    def __init__(self, network, activation_fn, perturbation_per_eval=32):
        super().__init__(network, activation_fn, perturbation_per_eval)

        self._ablation = Ablation(perturb_mode='Deletion')
        self._aggregation_fn = abs_max
        self._get_replacement = Constant(base_value=0.0)
        # Spatial dimensions larger than this use windows sized so that at least this many fit per dimension;
        # see `_get_window_size_and_strides`.
        self._num_sample_per_dim = 32

    def __call__(self, inputs, targets):
        """
        Compute occlusion saliency maps of `inputs` w.r.t. the labels in `targets`.

        Args:
            inputs (Tensor): Input data; the class docs specify a 4D tensor of shape (N, C, H, W).
            targets (Tensor, int): Label(s) of interest. A 1D tensor gives one label per sample; a 0D tensor or an
                int is applied to every sample.

        Returns:
            Tensor, the result of `self._aggregation_fn` on the normalized attributions.
        """
        self._verify_data(inputs, targets)

        inputs = inputs.asnumpy()
        batch_size = inputs.shape[0]
        if isinstance(targets, ms.Tensor):
            targets = targets.asnumpy()
        else:
            # `np.int` was removed in NumPy 1.24; it was only an alias of the builtin `int`.
            targets = np.array([targets], int)
        # Expand to one label per sample so a scalar (int / 0D) target lines up with the sample-major
        # repetition done in `_perturbate`; a 1D target of length N passes through unchanged.
        targets = np.broadcast_to(np.ravel(targets), (batch_size,))

        window_size, strides = self._get_window_size_and_strides(inputs)

        full_network = nn.SequentialCell([self._network, self._activation_fn])

        # Probability of each sample's target label on the unperturbed input; shape (N,).
        original_outputs = full_network(ms.Tensor(inputs, ms.float32)).asnumpy()[np.arange(batch_size), targets]

        masks = Occlusion._generate_masks(inputs, window_size, strides)

        return self._perturbate(batch_size, full_network, (original_outputs, masks, inputs, targets))

    def _perturbate(self, batch_size, full_network, data):
        """
        Run the network on occluded copies of `inputs` and accumulate the drop in the target probability.

        Args:
            batch_size (int): Number of samples N in `inputs`.
            full_network (Cell): The explained network followed by the activation function.
            data (tuple): `(original_outputs, masks, inputs, targets)`. `original_outputs` and `targets` have shape
                (N,). `masks` is a boolean array of shape `(N, P) + inputs.shape[1:]` marking the pixels occluded
                by each of the P windows.

        Returns:
            Tensor, `self._aggregation_fn` applied to the accumulated differences divided by the per-pixel weights.
        """
        original_outputs, masks, inputs, targets = data
        total_attribution = np.zeros_like(inputs)
        # Weights start at one rather than zero so pixels covered by no window are not divided by zero. Each pixel
        # therefore ends up with (sum of differences) / (number of covering windows + 1).
        weights = np.ones_like(inputs)
        num_perturbations = masks.shape[1]
        reference = self._get_replacement(inputs)

        count = 0
        while count < num_perturbations:
            # Evaluate at most `perturbation_per_eval` windows per sample in a single network call.
            ith_masks = masks[:, count:min(count + self._perturbation_per_eval, num_perturbations)]
            actual_num_eval = ith_masks.shape[1]
            num_samples = batch_size * actual_num_eval
            occluded_inputs = self._ablation(inputs, reference, ith_masks)
            # Flatten into a plain batch. The code below assumes sample-major order, i.e. the ablation output is
            # laid out as (N, k, ...). NOTE(review): confirm against `Ablation`.
            occluded_inputs = occluded_inputs.reshape((-1, *inputs.shape[1:]))
            # Repeat each sample's label/output k times to match the sample-major flattened batch.
            targets_repeat = np.repeat(targets, repeats=actual_num_eval, axis=0)
            occluded_outputs = full_network(
                ms.Tensor(occluded_inputs, ms.float32)).asnumpy()[np.arange(num_samples), targets_repeat]
            original_outputs_repeat = np.repeat(original_outputs, repeats=actual_num_eval, axis=0)
            outputs_diff = original_outputs_repeat - occluded_outputs
            # Spread each window's scalar difference over the pixels that window occluded.
            total_attribution += (
                outputs_diff.reshape(ith_masks.shape[:2] + (1,) * (len(masks.shape) - 2)) * ith_masks).sum(axis=1)
            weights += ith_masks.sum(axis=1)
            count += actual_num_eval
        attribution = self._aggregation_fn(ms.Tensor(total_attribution / weights, ms.float32))
        return attribution

    def _get_window_size_and_strides(self, inputs):
        """
        Return the sliding-window size and strides used to occlude `inputs`.

        The window always spans all channels (`inputs.shape[1]`). The channel stride equals the channel count, so
        there is a single window position along that axis. For each spatial dimension of size `x`:

        - if `x > self._num_sample_per_dim`, both window and stride are `x // self._num_sample_per_dim`. This gives
          non-overlapping windows, at least `self._num_sample_per_dim` of them along that dimension;
        - otherwise (including `x == self._num_sample_per_dim`), the window is 3 and the stride is 1.

        Args:
            inputs (numpy.ndarray): Input batch; axis 1 is channels and axes 2 onward are spatial.

        Returns:
            tuple, `(window_size, strides)`, each a tuple with one entry per axis of `inputs.shape[1:]`.
        """
        num_samples = self._num_sample_per_dim
        num_channels = inputs.shape[1]
        spatial_shape = inputs.shape[2:]
        window_size = (num_channels,) + tuple(x // num_samples if x > num_samples else 3 for x in spatial_shape)
        strides = (num_channels,) + tuple(x // num_samples if x > num_samples else 1 for x in spatial_shape)
        return window_size, strides

    @staticmethod
    def _generate_masks(inputs, window_size, strides):
        """
        Build one boolean occlusion mask per sliding-window position.

        Args:
            inputs (numpy.ndarray): Input batch of shape `(N,) + sample_shape`.
            window_size (tuple): Window extent for each axis of `sample_shape`.
            strides (tuple): Window step for each axis of `sample_shape`.

        Returns:
            numpy.ndarray, boolean masks of shape `(N, P) + sample_shape`. P is the number of window positions, and
            mask `[n, p]` is True exactly on the elements covered by window `p`. The masks are identical for all `n`.

        Raises:
            ValueError: If no window fits into `sample_shape`.
        """
        sample_shape = inputs.shape[1:]
        total_dim = np.prod(sample_shape).item()
        # Cut windows out of an array of flat indices, so each patch lists the flat positions it covers.
        template = np.arange(total_dim).reshape(sample_shape)
        indices = _generate_patches(template, window_size, strides)
        num_perturbations = indices.shape[0]
        if num_perturbations == 0:
            raise ValueError(f"The occlusion window {tuple(window_size)} does not fit into inputs of shape "
                             f"{tuple(sample_shape)}.")
        indices = indices.reshape(num_perturbations, int(np.prod(indices.shape[1:])))

        # `np.bool` was removed in NumPy 1.24; it was only an alias of the builtin `bool`.
        mask = np.zeros((num_perturbations, total_dim), dtype=bool)
        # Row p is set True at every flat index covered by window p.
        mask[np.arange(num_perturbations)[:, np.newaxis], indices] = True
        mask = mask.reshape((num_perturbations,) + sample_shape)

        # Every sample in the batch is occluded with the same set of windows.
        return np.tile(mask, reps=(inputs.shape[0],) + (1,) * mask.ndim)
182