# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
This test is used to monitor inversion attack method of MindArmour.
"""
import numpy as np
import pytest

import mindspore.context as context
from mindspore.nn import Cell, MSELoss
from mindspore.ops import operations as P
from mindspore.ops.composite import GradOperation
from mindspore import Tensor


class GradWrapWithLoss(Cell):
    """Cell returning the gradient of a two-input network w.r.t. its first input."""

    def __init__(self, network):
        super(GradWrapWithLoss, self).__init__()
        # get_all=True requests gradients for every network input;
        # sens_param=False means no external sensitivity value is passed in.
        self._grad_all = GradOperation(get_all=True, sens_param=False)
        self._network = network

    def construct(self, inputs, labels):
        """Return the gradient of `network(inputs, labels)` w.r.t. `inputs`."""
        gout = self._grad_all(self._network)(inputs, labels)
        # Only the gradient for the first positional input is kept.
        return gout[0]


class AddNet(Cell):
    """Minimal network that adds its input to itself."""

    def __init__(self):
        super(AddNet, self).__init__()
        self._add = P.Add()

    def construct(self, inputs):
        """Return `inputs + inputs`."""
        out = self._add(inputs, inputs)
        return out


class InversionLoss(Cell):
    """
    Weighted sum of three MSE-based terms used by the inversion attack.

    Args:
        network (Cell): Network applied to the candidate input data.
        weights: Indexable holding three coefficients, applied in order to
            the feature loss, the neighbour-difference loss and the
            input-magnitude loss.
    """

    def __init__(self, network, weights):
        super(InversionLoss, self).__init__()
        self._network = network
        self._mse_loss = MSELoss()
        self._weights = weights
        self._get_shape = P.Shape()
        self._zeros = P.ZerosLike()
        # Read once at construction time; selects the loss_2 code path below.
        self._device_target = context.get_context("device_target")

    def construct(self, input_data, target_features):
        """
        Compute the combined loss.

        Args:
            input_data (Tensor): Candidate input; indexed below as a
                4-dimensional tensor (axes 2 and 3 are the ones compared).
            target_features (Tensor): Features the network output is
                compared against.

        Returns:
            Tensor, `loss_1 * w[0] + loss_2 * w[1] + loss_3 * w[2]`.
        """
        output = self._network(input_data)
        # Feature loss, normalised by the MSE of the target against zeros.
        feature_loss = self._mse_loss(output, target_features)
        feature_norm = self._mse_loss(target_features, self._zeros(target_features))
        loss_1 = feature_loss / feature_norm

        data_shape = self._get_shape(input_data)
        if self._device_target == 'CPU':
            # NOTE(review): presumably the CPU backend does not support the
            # slice assignment used in the other branch - confirm.
            # Split axes 2 and 3 into single slices and accumulate the MSE
            # between each slice and its predecessor.
            split_op_1 = P.Split(2, data_shape[2])
            split_op_2 = P.Split(3, data_shape[3])
            data_split_1 = split_op_1(input_data)
            data_split_2 = split_op_2(input_data)
            loss_2 = 0
            for i in range(1, data_shape[2]):
                loss_2 += self._mse_loss(data_split_1[i], data_split_1[i - 1])
            for j in range(1, data_shape[3]):
                loss_2 += self._mse_loss(data_split_2[j], data_split_2[j - 1])
        else:
            # Build copies of the input shifted by one position along axis 2
            # and axis 3; the last slice of each copy stays zero.
            data_copy_1 = self._zeros(input_data)
            data_copy_2 = self._zeros(input_data)
            data_copy_1[:, :, :(data_shape[2] - 1), :] = input_data[:, :, 1:, :]
            # The slice on axis 3 must be bounded by the size of axis 3;
            # using data_shape[2] here only works when both axes are equal.
            data_copy_2[:, :, :, :(data_shape[3] - 1)] = input_data[:, :, :, 1:]
            loss_2 = self._mse_loss(input_data, data_copy_1) + self._mse_loss(input_data, data_copy_2)
        # MSE of the input against zeros.
        loss_3 = self._mse_loss(input_data, self._zeros(input_data))
        loss = loss_1*self._weights[0] + loss_2*self._weights[1] + loss_3*self._weights[2]
        return loss


class ImageInversionAttack:
    """
    Helper that evaluates the inversion-loss gradient for random inputs.

    Args:
        network (Cell): Network whose output is matched to target features.
        input_shape (tuple): Shape of one input sample, without the leading
            sample-count axis.
        loss_weights: Three coefficients forwarded to `InversionLoss`.
    """

    def __init__(self, network, input_shape, loss_weights=(1, 0.2, 5)):
        self._network = network
        self._loss = InversionLoss(self._network, loss_weights)
        self._input_shape = input_shape

    def generate(self, target_features):
        """
        Return the gradient of the inversion loss w.r.t. a random input.

        Args:
            target_features (numpy.ndarray): Target features; the size of
                axis 0 is used as the number of random samples to draw.

        Returns:
            numpy.ndarray, gradient with respect to the random input.
        """
        img_num = target_features.shape[0]
        test_input = np.random.random((img_num,) + self._input_shape).astype(np.float32)
        loss_net = self._loss
        loss_grad = GradWrapWithLoss(loss_net)
        x_grad = loss_grad(Tensor(test_input), Tensor(target_features)).asnumpy()
        return x_grad


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_loss_grad_graph():
    """Check in graph mode that the inversion-loss gradient is not all zeros."""
    context.set_context(mode=context.GRAPH_MODE)
    net = AddNet()
    target_features = np.random.random((1, 32, 32)).astype(np.float32)
    inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32))
    grads = inversion_attack.generate(target_features)
    assert np.any(grads != 0), 'grad result can not be all zeros'