# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Robustness."""

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore.train._utils import check_value_type
from mindspore import log
from .metric import LabelSensitiveMetric
from ...explanation._attribution._perturbation.replacement import RandomPerturb
from ..._utils import deprecated_error


@deprecated_error
class Robustness(LabelSensitiveMetric):
    """
    Robustness perturbs the inputs by adding random noise and chooses the maximum sensitivity as evaluation score from
    the perturbations.

    Args:
        num_labels (int): Number of classes in the dataset.
        activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For
            single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification
            tasks, `nn.Sigmoid` is usually applied. Users can also pass their own customized `activation_fn` as long
            as when combining this function with network, the final output is the probability of the input.

    Raises:
        TypeError: Raised for any argument type problem.

    Supported Platforms:
        ``Ascend`` ``GPU``
    """

    def __init__(self, num_labels, activation_fn):
        super().__init__(num_labels)
        check_value_type("activation_fn", activation_fn, nn.Cell)
        self._perturb = RandomPerturb()
        self._num_perturbations = 10  # number of perturbations used in evaluation
        self._threshold = 0.1  # max L2-distance allowed between original and perturbed model outputs
        self._activation_fn = activation_fn

    def evaluate(self, explainer, inputs, targets, saliency=None):
        """
        Evaluate robustness on single sample.

        Note:
            Currently only single sample (:math:`N=1`) at each call is supported.

        Args:
            explainer (Explanation): The explainer to be evaluated, see `mindspore.explainer.explanation`.
            inputs (Tensor): A data sample, a 4D tensor of shape :math:`(N, C, H, W)`.
            targets (Tensor, int): The label of interest. It should be a 1D or 0D tensor, or an integer.
                If `targets` is a 1D tensor, its length should be the same as `inputs`.
            saliency (Tensor, optional): The saliency map to be evaluated, a 4D tensor of shape :math:`(N, 1, H, W)`.
                If it is None, the parsed `explainer` will generate the saliency map with `inputs` and `targets` and
                continue the evaluation. Default: None.

        Returns:
            numpy.ndarray, 1D array of shape :math:`(N,)`, result of robustness evaluated on `explainer`. Scores lie
            in [0, 1] (higher means more robust); a sample whose saliency map has zero norm gets NaN.

        Raises:
            ValueError: If batch_size is larger than 1.

        Examples:
            >>> import numpy as np
            >>> import mindspore as ms
            >>> from mindspore import nn
            >>> from mindspore.explainer.explanation import Gradient
            >>> from mindspore.explainer.benchmark import Robustness
            >>> from mindspore import context
            >>>
            >>> context.set_context(mode=context.PYNATIVE_MODE)
            >>> # Initialize a Robustness benchmarker passing num_labels of the dataset.
            >>> num_labels = 10
            >>> activation_fn = nn.Softmax()
            >>> robustness = Robustness(num_labels, activation_fn)
            >>>
            >>> # The detail of LeNet5 is shown in model_zoo.official.cv.lenet.src.lenet.py
            >>> net = LeNet5(10, num_channel=3)
            >>> # prepare your explainer to be evaluated, e.g., Gradient.
            >>> gradient = Gradient(net)
            >>> input_x = ms.Tensor(np.random.rand(1, 3, 32, 32), ms.float32)
            >>> target_label = ms.Tensor([0], ms.int32)
            >>> # robustness is a Robustness instance
            >>> res = robustness.evaluate(gradient, input_x, target_label)
            >>> print(res.shape)
            (1,)
        """

        self._check_evaluate_param(explainer, inputs, targets, saliency)
        if inputs.shape[0] > 1:
            raise ValueError('Robustness only support a sample each time, but receive {}'.format(inputs.shape[0]))

        if isinstance(targets, int):
            targets = ms.Tensor([targets], ms.int32)
        if saliency is None:
            saliency = explainer(inputs, targets)
        saliency = saliency.asnumpy()

        # Per-sample L2 norm of the saliency map, reduced over every non-batch axis.
        reduce_axes = tuple(range(1, len(saliency.shape)))
        norm = np.sqrt(np.sum(np.square(saliency), axis=reduce_axes))
        if (norm == 0).any():
            log.warning('Get saliency norm equals 0, robustness return NaN for zero-norm saliency currently.')
            norm[norm == 0] = np.nan

        # The network followed by the activation yields probabilities; the same composite network is used both for
        # the reference outputs and for validating perturbations, so the two are directly comparable.
        full_network = nn.SequentialCell([explainer.network, self._activation_fn])
        original_outputs = full_network(inputs).asnumpy()
        sensitivities = []
        inputs = inputs.asnumpy()
        for _ in range(self._num_perturbations):
            perturbations = []
            for j, sample in enumerate(inputs):
                perturbation_on_single_sample = self._perturb_with_threshold(full_network,
                                                                             np.expand_dims(sample, axis=0),
                                                                             original_outputs[j])
                perturbations.append(perturbation_on_single_sample)
            perturbations = np.vstack(perturbations)
            perturbed_saliency = explainer(ms.Tensor(perturbations, ms.float32), targets).asnumpy()
            # Sensitivity: L2 distance between the saliency of the perturbed input and the original saliency.
            sensitivity = np.sqrt(np.sum((perturbed_saliency - saliency) ** 2, axis=reduce_axes))
            sensitivities.append(sensitivity)
        sensitivities = np.stack(sensitivities, axis=-1)
        # Worst-case (maximum) sensitivity over all perturbations, relative to the saliency magnitude.
        sensitivity = np.max(sensitivities, axis=1) / norm
        # Map the non-negative relative sensitivity to a score in [0, 1]; exp(-x) == 1 / exp(x) without overflowing.
        return np.exp(-sensitivity)

    def _perturb_with_threshold(self, network: nn.Cell, sample: np.ndarray, original_output: np.ndarray) -> np.ndarray:
        """
        Generate a perturbation of `sample` whose model output stays close to the original output.

        Random perturbations are drawn until the L2-distance between `original_output` and the output of `network`
        on the perturbed sample is no greater than `self._threshold`, or until `max_attempt_time` attempts have been
        made; in the latter case the last drawn perturbation is returned.

        Args:
            network (Cell): The network to query. `evaluate` passes the explainer network already followed by
                `self._activation_fn`, so its output is directly comparable with `original_output`.
            sample (np.ndarray): A single input sample with a leading batch axis of size 1.
            original_output (np.ndarray): Output of `network` on the unperturbed sample (batch axis removed).

        Returns:
            np.ndarray, the selected perturbed sample.
        """
        # the maximum time attempt to get a perturbation with perturb_error lower than self._threshold
        max_attempt_time = 3
        perturbation = None
        for _ in range(max_attempt_time):
            perturbation = self._perturb(sample)
            # Query the network on the *perturbed* sample, and do not re-apply the activation: `network` already
            # ends with it, exactly like the network that produced `original_output`.
            perturbation_output = network(ms.Tensor(perturbation, ms.float32)).asnumpy()
            perturb_error = np.linalg.norm(original_output - perturbation_output)
            if perturb_error <= self._threshold:
                return perturbation
        return perturbation