• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Robustness."""
16
17import numpy as np
18
19import mindspore as ms
20import mindspore.nn as nn
21from mindspore.train._utils import check_value_type
22from mindspore import log
23from .metric import LabelSensitiveMetric
24from ...explanation._attribution._perturbation.replacement import RandomPerturb
25from ..._utils import deprecated_error
26
27
@deprecated_error
class Robustness(LabelSensitiveMetric):
    """
    Robustness perturbs the inputs by adding random noise and chooses the maximum sensitivity over the
    perturbations as the evaluation score.

    Args:
        num_labels (int): Number of classes in the dataset.
        activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For
            single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification
            tasks, `nn.Sigmoid` is usually applied. Users can also pass their own customized `activation_fn` as long
            as when combining this function with network, the final output is the probability of the input.

    Raises:
        TypeError: Raised for any argument type problem.

    Supported Platforms:
        ``Ascend`` ``GPU``
    """

    def __init__(self, num_labels, activation_fn):
        super().__init__(num_labels)
        check_value_type("activation_fn", activation_fn, nn.Cell)
        self._perturb = RandomPerturb()
        self._num_perturbations = 10  # number of perturbations used in evaluation
        self._threshold = 0.1  # max L2-distance between original and perturbed network outputs
        self._activation_fn = activation_fn

    def evaluate(self, explainer, inputs, targets, saliency=None):
        """
        Evaluate robustness on single sample.

        Note:
            Currently only single sample (:math:`N=1`) at each call is supported.

        Args:
            explainer (Explanation): The explainer to be evaluated, see `mindspore.explainer.explanation`.
            inputs (Tensor): A data sample, a 4D tensor of shape :math:`(N, C, H, W)`.
            targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer.
                If `targets` is a 1D tensor, its length should be the same as `inputs`.
            saliency (Tensor, optional): The saliency map to be evaluated, a 4D tensor of shape :math:`(N, 1, H, W)`.
                If it is None, the parsed `explainer` will generate the saliency map with `inputs` and `targets` and
                continue the evaluation. Default: None.

        Returns:
            numpy.ndarray, 1D array of shape :math:`(N,)`, the robustness score of `explainer`, computed as
            :math:`\\exp(-s)` where :math:`s` is the maximum saliency change over the perturbations divided by
            the L2 norm of the original saliency. Scores lie in (0, 1]; higher means more robust. A sample whose
            saliency map has zero norm gets NaN.

        Raises:
            ValueError: If batch_size is larger than 1.

        Examples:
            >>> import numpy as np
            >>> import mindspore as ms
            >>> from mindspore import nn
            >>> from mindspore.explainer.explanation import Gradient
            >>> from mindspore.explainer.benchmark import Robustness
            >>> from mindspore import context
            >>>
            >>> context.set_context(mode=context.PYNATIVE_MODE)
            >>> # Initialize a Robustness benchmarker passing num_labels of the dataset.
            >>> num_labels = 10
            >>> activation_fn = nn.Softmax()
            >>> robustness = Robustness(num_labels, activation_fn)
            >>>
            >>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py
            >>> net = LeNet5(10, num_channel=3)
            >>> # prepare your explainer to be evaluated, e.g., Gradient.
            >>> gradient = Gradient(net)
            >>> input_x = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
            >>> target_label = ms.Tensor([0], ms.int32)
            >>> # robustness is a Robustness instance
            >>> res = robustness.evaluate(gradient, input_x, target_label)
            >>> print(res.shape)
            (1,)
        """

        self._check_evaluate_param(explainer, inputs, targets, saliency)
        if inputs.shape[0] > 1:
            raise ValueError('Robustness only support a sample each time, but receive {}'.format(inputs.shape[0]))

        if isinstance(targets, int):
            targets = ms.Tensor([targets], ms.int32)
        if saliency is None:
            saliency = explainer(inputs, targets)
        saliency = saliency.asnumpy()

        # Every axis except the batch axis: each sample gets its own L2 norm / L2 distance.
        reduce_axes = tuple(range(1, saliency.ndim))

        norm = np.sqrt(np.sum(np.square(saliency), axis=reduce_axes))
        if (norm == 0).any():
            log.warning('Get saliency norm equals 0, robustness return NaN for zero-norm saliency currently.')
            # Dividing by NaN (rather than 0) propagates NaN into the score instead of inf/warnings.
            norm[norm == 0] = np.nan

        # The network followed by the activation, so outputs are probabilities.
        full_network = nn.SequentialCell([explainer.network, self._activation_fn])
        original_outputs = full_network(inputs).asnumpy()
        inputs = inputs.asnumpy()

        sensitivities = []
        for _ in range(self._num_perturbations):
            perturbed_inputs = np.vstack([
                self._perturb_with_threshold(full_network, np.expand_dims(sample, axis=0), original_outputs[j])
                for j, sample in enumerate(inputs)
            ])
            perturbed_saliency = explainer(ms.Tensor(perturbed_inputs, ms.float32), targets).asnumpy()
            sensitivities.append(np.sqrt(np.sum((perturbed_saliency - saliency) ** 2, axis=reduce_axes)))

        # Stacked shape is (N, num_perturbations); keep the worst case (largest change) per sample.
        sensitivity = np.max(np.stack(sensitivities, axis=-1), axis=1) / norm
        # Same value as 1 / exp(s) but without overflowing to inf for large s.
        return np.exp(-sensitivity)

    def _perturb_with_threshold(self, network: nn.Cell, sample: np.ndarray, original_output: np.ndarray) -> np.ndarray:
        """
        Generate a random perturbation of `sample` whose network output stays close to the original output.

        Perturbations are drawn with `self._perturb` until the L2-distance between `original_output` and the
        output of `network` on the perturbed sample is no greater than `self._threshold`, or until
        `max_attempt_time` attempts have been made, in which case the last perturbation is returned.

        Args:
            network (Cell): The network used to score samples. It must already end with `self._activation_fn`
                (as `full_network` built in `evaluate` does), so its outputs are directly comparable with
                `original_output`.
            sample (numpy.ndarray): A single input sample with a leading batch axis of size 1.
            original_output (numpy.ndarray): Output of `network` on the unperturbed sample.

        Returns:
            numpy.ndarray, the perturbed sample produced by `self._perturb`.
        """
        # the maximum number of attempts to get a perturbation with perturb_error lower than self._threshold
        max_attempt_time = 3
        perturbation = None
        for _ in range(max_attempt_time):
            perturbation = self._perturb(sample)
            # Score the *perturbed* sample (the original scored `sample`, making the check meaningless).
            # `network` already applies the activation, so it must not be applied a second time here.
            perturbation_output = network(ms.Tensor(perturbation, ms.float32)).asnumpy()
            perturb_error = np.linalg.norm(original_output - perturbation_output)
            if perturb_error <= self._threshold:
                return perturbation
        return perturbation
154