• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Utils for MindExplain"""
16
# Public names of this module, i.e. what ``from <module> import *`` exposes.
# ``DeprecatedError`` is intentionally not listed here.
__all__ = [
    'ForwardProbe',
    'abs_max',
    'calc_auc',
    'calc_correlation',
    'deprecated_error',
    'format_tensor_to_ndarray',
    'generate_one_hot',
    'rank_pixels',
    'resize',
    'retrieve_layer_by_name',
    'retrieve_layer',
    'unify_inputs',
    'unify_targets',
]
32
33from typing import Tuple, Union
34
35import numpy as np
36from PIL import Image
37
38import mindspore as ms
39import mindspore.nn as nn
40import mindspore.ops.operations as op
41
# Short aliases used in the type annotations of this module:
# numpy arrays, MindSpore cells (layers/models) and MindSpore tensors.
_Array, _Module, _Tensor = np.ndarray, nn.Cell, ms.Tensor
45
46
class DeprecatedError(RuntimeError):
    """Error signalling that ``mindspore.explainer`` is deprecated in favour of MindSpore XAI."""

    _MESSAGE = ("'mindspore.explainer' is deprecated from version 1.5 and "
                "will be removed in a future version, use MindSpore XAI "
                "https://gitee.com/mindspore/xai instead.")

    def __init__(self):
        super().__init__(self._MESSAGE)
52
53
def deprecated_error(func_or_cls):
    """
    Reject the given function or class by raising ``DeprecatedError``.

    When applied as a decorator, the error is raised as soon as the decorated
    definition is evaluated, so the decorated object never becomes usable.

    Args:
        func_or_cls: The function or class being decorated. It is discarded.

    Raises:
        DeprecatedError: Always.
    """
    del func_or_cls  # intentionally unused; only its decoration matters
    error = DeprecatedError()
    raise error
57
58
def abs_max(gradients):
    """
    Turn gradients into a saliency map.

    The element-wise absolute value is taken first, then a max reduction is
    applied over axis 1 with that axis kept (size 1) in the result.

    Args:
        gradients (_Tensor): Gradients to be transformed into a saliency map.
            Presumably NCHW, so that axis 1 is the channel axis (TODO confirm).

    Returns:
        _Tensor, saliency map with axis 1 reduced to size 1.
    """
    magnitude = op.Abs()(gradients)
    return op.ReduceMax(keep_dims=True)(magnitude, axis=1)
72
73
def generate_one_hot(indices, depth):
    r"""
    Simple wrapper of the OneHot operation with fixed on/off values.

    The on_value is fixed to 1.0 and the off_value to 0.0, both float32.

    Args:
        indices: Indices to encode.
        depth (int): Size of the one-hot dimension.

    Returns:
        _Tensor, the one-hot weights. Presumably float32, since OneHot's output
        dtype follows on_value.
    """
    on_off = (ms.Tensor(1.0, ms.float32), ms.Tensor(0.0, ms.float32))
    return op.OneHot()(indices, depth, *on_off)
83
84
def unify_inputs(inputs) -> tuple:
    """
    Normalize explainer inputs into a tuple.

    A tuple is returned unchanged. A ``ms.Tensor`` is wrapped in a 1-tuple. A
    ``np.ndarray`` is converted to ``ms.Tensor`` and then wrapped.

    Raises:
        TypeError: If ``inputs`` is none of tuple, ms.Tensor or np.ndarray.
    """
    if isinstance(inputs, tuple):
        return inputs
    if isinstance(inputs, ms.Tensor):
        return (inputs,)
    if isinstance(inputs, np.ndarray):
        return (ms.Tensor(inputs),)
    raise TypeError(
        'inputs must be one of [tuple, ms.Tensor or np.ndarray], '
        'but get {}'.format(type(inputs)))
98
99
def unify_targets(targets) -> ms.Tensor:
    """
    Unify target labels of explainer into an int32 ``ms.Tensor``.

    Args:
        targets (int, list, ms.Tensor): Target label(s).
            - A ``ms.Tensor`` is returned unchanged.
            - A list is converted to an int32 tensor.
            - A single int becomes a 1-element int32 tensor.

    Returns:
        ms.Tensor, the target labels.

    Raises:
        TypeError: If ``targets`` is not an int, list or ms.Tensor.
    """
    if isinstance(targets, ms.Tensor):
        return targets
    if isinstance(targets, list):
        # Return immediately: the converted list must not fall through to the
        # int check below, which would reject it as an unsupported type.
        return ms.Tensor(targets, dtype=ms.int32)
    if isinstance(targets, int):
        return ms.Tensor([targets], dtype=ms.int32)
    raise TypeError(
        'targets must be one of [int, list or ms.Tensor], '
        'but get {}'.format(type(targets)))
113
114
def retrieve_layer_by_name(model: _Module, layer_name: str):
    """
    Retrieve the layer in the model by the given layer_name.

    Args:
        model (Cell): Model which contains the target layer.
        layer_name (str): Name of target layer, compared against the names
            yielded by ``model.cells_and_names()``. An empty string returns
            ``model`` itself.

    Returns:
        Cell, the first layer whose name equals ``layer_name``.

    Raises:
        TypeError: If layer_name is not a str.
        ValueError: If module with given layer_name is not found in the model.
    """
    if not isinstance(layer_name, str):
        raise TypeError('layer_name should be type of str, but receive {}.'
                        .format(type(layer_name)))

    if not layer_name:
        return model

    for name, cell in model.cells_and_names():
        if name == layer_name:
            return cell

    raise ValueError(
        'Cannot match {}, please provide target layer '
        'in the given model.'.format(layer_name))
147
148
def retrieve_layer(model: _Module, target_layer: Union[str, _Module] = ''):
    """
    Retrieve the layer in the model.

    ``target_layer`` can be either a layer name or a Cell object.
    - Given a layer name, the method searches through the model and returns the
      matched layer. An empty name returns ``model`` itself.
    - Given a Cell object, it checks that this exact object (by identity) is
      one of the cells of the model.

    Args:
        model (Cell): Model which contains the target layer.
        target_layer (str, Cell): Name of target layer or the target layer instance.

    Returns:
        Cell, the target layer.

    Raises:
        TypeError: If target_layer is neither a str nor a Cell.
        ValueError: If the target layer is not found in the model.
    """
    if isinstance(target_layer, str):
        return retrieve_layer_by_name(model, target_layer)

    if isinstance(target_layer, _Module):
        # Identity, not equality: the probe must attach to this very instance.
        if any(cell is target_layer for _, cell in model.cells_and_names()):
            return target_layer
        raise ValueError(
            'Model not contain cell {}, fail to probe.'.format(target_layer)
        )

    raise TypeError('target_layer must have type of str or ms.nn.Cell, '
                    'but receive {}'.format(type(target_layer)))
182
183
class ForwardProbe:
    """
    Context manager capturing the output of one layer during forward passes.

    While the ``with`` block is active:
    - the layer's ``construct`` is replaced by a wrapper;
    - the wrapper calls the original ``construct`` and records its output.

    On exit, the original ``construct`` is put back and the recorded output is
    cleared. Read ``value`` inside the ``with`` block.

    Args:
        target_layer (Cell): The layer instance to probe. A layer name is not
            accepted here, because the object's ``construct`` is accessed
            directly. Resolve names first, e.g. with ``retrieve_layer``.
    """

    def __init__(self, target_layer: _Module):
        self._target_layer = target_layer
        self._original_construct = target_layer.construct
        self._intermediate_tensor = None

    @property
    def value(self):
        """The latest output captured from the target layer, or None."""
        return self._intermediate_tensor

    def __enter__(self):
        self._target_layer.construct = self._new_construct
        return self

    def __exit__(self, *_):
        layer = self._target_layer
        layer.construct = self._original_construct
        self._intermediate_tensor = None
        # Returning False lets any exception raised inside the block propagate.
        return False

    def _new_construct(self, *inputs):
        """Run the original construct and remember what it returned."""
        captured = self._original_construct(*inputs)
        self._intermediate_tensor = captured
        return captured
215
216
def format_tensor_to_ndarray(x: Union[ms.Tensor, np.ndarray]) -> np.ndarray:
    """
    Convert a ``ms.Tensor`` to ``np.ndarray``; return a ``np.ndarray`` as is.

    Raises:
        TypeError: If ``x`` is neither a ms.Tensor nor a np.ndarray.
    """
    array = x.asnumpy() if isinstance(x, ms.Tensor) else x
    if isinstance(array, np.ndarray):
        return array
    raise TypeError('input should be one of [ms.Tensor or np.ndarray],'
                    ' but receive {}'.format(type(array)))
226
227
def calc_correlation(x: Union[ms.Tensor, np.ndarray],
                     y: Union[ms.Tensor, np.ndarray]) -> float:
    """
    Calculate Pearson correlation coefficient between two 1-D vectors.

    Args:
        x (ms.Tensor, np.ndarray): First vector.
        y (ms.Tensor, np.ndarray): Second vector.

    Returns:
        float, the Pearson correlation coefficient. 0.0 is returned when either
        vector is all zeros, for which the coefficient is undefined.

    Raises:
        TypeError: If x or y is neither a ms.Tensor nor a np.ndarray.
        ValueError: If x or y has more than one dimension.
    """
    x = format_tensor_to_ndarray(x)
    y = format_tensor_to_ndarray(y)

    if len(x.shape) > 1 or len(y.shape) > 1:
        raise ValueError('"calc_correlation" only support 1-dim vectors currently, but get shape {} and {}.'
                         .format(x.shape, y.shape))

    # An all-zero vector has zero variance, so np.corrcoef would produce NaN.
    if np.all(x == 0) or np.all(y == 0):
        # ``np.float`` (an alias of the builtin float) was removed in NumPy 1.24.
        return 0.0
    faithfulness = np.corrcoef(x, y)[0, 1]
    return faithfulness
242
243
def calc_auc(x: _Array) -> _Array:
    """
    Calculate the Area under Curve.

    Computes ``(x.sum() - x[0] - x[-1]) / len(x)``: the sum of all points
    except the first and last, divided by the number of points.

    Args:
        x (np.ndarray): Curve values. A 4-D input is first averaged over axes
            2 and 3 (multiple patches of a fully convolutional model).

    Returns:
        The AUC value; a scalar for 1-D input.
        NOTE(review): if ``x`` is still multi-dimensional, ``x.sum()`` sums every
        element while ``x[0]`` / ``x[-1]`` are sub-arrays, so the result is an
        array. Confirm this is intended.
    """
    # take mean for multiple patches if the model is fully convolutional model
    if len(x.shape) == 4:
        # Reduce both trailing axes in one call. Chaining mean(axis=2) and then
        # mean(axis=3) fails, because only 3 axes remain after the first mean.
        x = np.mean(x, axis=(2, 3))
    auc = (x.sum() - x[0] - x[-1]) / len(x)
    return auc
251
252
def rank_pixels(inputs: _Array, descending: bool = True) -> _Array:
    """
    Generate a rank order for every pixel of each sample in a 2D or 3D array.

    Axis 0 is treated as the batch axis. All remaining axes of a sample are
    flattened and ranked together, with ranks from 0 to (num_pixel - 1). Rank 0
    goes to the largest value when ``descending`` is True, otherwise to the
    smallest.

    Args:
        inputs (np.ndarray): 2D or 3D array whose axis 0 is the batch axis.
        descending (bool): Rank in descending order if True, ascending otherwise.

    Returns:
        np.ndarray, integer ranks with the same shape as ``inputs``.

    Raises:
        ValueError: If ``inputs`` is not 2D or 3D.
    """
    if not 2 <= len(inputs.shape) <= 3:
        raise ValueError('Only support 2D or 3D inputs currently.')

    num_samples = inputs.shape[0]
    flat = inputs.reshape(num_samples, -1)
    sign = -1 if descending else 1
    order = np.argsort(sign * flat, axis=1)
    # Each row of ``order`` is a permutation. Its argsort is the inverse
    # permutation, i.e. the sorted position that each pixel ends up in.
    ranks = np.argsort(order, axis=1)
    return ranks.reshape(inputs.shape)
274
275
def resize(inputs: _Tensor, size: Tuple[int, int], mode: str) -> _Tensor:
    """
    Resize the intermediate layer _attribution to the same size as inputs.

    Args:
        inputs (Tensor): The input tensor to be resized, laid out as
            (N, C, H, W).
        size (tuple[int]): The targeted size resize to, as (height, width).
        mode (str): The resize mode. Options: 'nearest_neighbor', 'bilinear'.
            'bilinear' round-trips through 8-bit PIL images. It assumes C == 1
            and values in [0, 1]; out-of-range values are not clipped.

    Returns:
        Tensor, the resized tensor, laid out as (N, C, height, width).

    Raises:
        ValueError: the resize mode is not in ['nearest_neighbor', 'bilinear'].
    """
    h, w = size
    if mode == 'nearest_neighbor':
        resize_nn = op.ResizeNearestNeighbor((h, w))
        outputs = resize_nn(inputs)

    elif mode == 'bilinear':
        inputs_np = inputs.asnumpy()
        inputs_np = np.transpose(inputs_np, [0, 2, 3, 1])  # NCHW -> NHWC
        array_lst = []
        for inp in inputs_np:
            # PIL needs an 8-bit RGB image: replicate the (single) channel three
            # times and quantize [0, 1] floats to [0, 255].
            array = (np.repeat(inp, 3, axis=2) * 255).astype(np.uint8)
            image = Image.fromarray(array)
            # PIL expects (width, height), whereas ``size`` is (height, width).
            image = image.resize((w, h), resample=Image.BILINEAR)
            array = np.asarray(image).astype(np.float32) / 255
            array_lst.append(array[:, :, 0:1])

        resized_np = np.transpose(array_lst, [0, 3, 1, 2])  # NHWC -> NCHW
        outputs = ms.Tensor(resized_np, inputs.dtype)
    else:
        raise ValueError('Unsupported resize mode {}.'.format(mode))

    return outputs
313