# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""PyNative-mode tests for backward hooks and user-defined ``bprop`` methods."""

import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype

from mindspore import Tensor
from mindspore import context
from mindspore import ParameterTuple
from mindspore.nn import Momentum
from mindspore.nn import WithLossCell
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal

# Every test in this module runs in PyNative mode on an Ascend device.
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")


# Gradient operator that returns gradients w.r.t. all inputs of the wrapped cell.
grad_all = C.GradOperation(get_all=True)


def weight_variable():
    """Return the ``TruncatedNormal(0.02)`` initializer used for every layer."""
    return TruncatedNormal(0.02)


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Build a bias-free, ``"valid"``-padded Conv2d with truncated-normal weights."""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    """Build a Dense layer whose weight and bias both use truncated-normal init."""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


class test_custom_hook_function_base():
    """Pass-through holder the tests use to hand hook callables around."""

    def __init__(self):
        pass

    def test_custom_hook_function(self, hook_function, cell_hook_function):
        """Return ``(hook_function, cell_hook_function)`` unchanged."""
        return hook_function, cell_hook_function


def cell_hook_function_print_grad(cell_id, grad_input, grad_output):
    """Cell backward hook for ``LeNet5.conv2``: check the gradient shapes.

    With a batch of 32 and 32x32 input, conv2 maps (32, 6, 14, 14) to
    (32, 16, 10, 10). ``grad_output`` is expected to carry conv2's input
    shape and ``grad_input`` its output shape. ``cell_id`` is part of the
    hook signature but unused. Despite the name, nothing is printed.
    """
    assert grad_output[0].asnumpy().shape == (32, 6, 14, 14)
    assert grad_input[0].asnumpy().shape == (32, 16, 10, 10)


def custom_hook_function_print_and_save_grad(grad_out):
    """HookBackward callback placed after conv1 + ReLU: check the gradient shape.

    Despite the name, it neither prints nor saves; it only asserts the shape.
    """
    assert grad_out[0].asnumpy().shape == (32, 6, 28, 28)


class LeNet5(nn.Cell):
    """LeNet-5 with a ``HookBackward`` op after conv1 and a cell hook on conv2.

    Args:
        hook_function: callback given to ``P.HookBackward``; it observes the
            gradient of the conv1 + ReLU activation, shape (32, 6, 28, 28).
        cell_hook_function: backward hook registered on ``conv2``.
        num_class: number of output logits.
    """

    def __init__(self, hook_function, cell_hook_function, num_class=10):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.batch_size = 32
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        # The paired hook asserts shapes (32, 6, 14, 14) / (32, 16, 10, 10).
        # Those are conv2's input/output shapes; conv1 maps (32, 1, 32, 32) to
        # (32, 6, 28, 28), so registering on conv1 would break the assertions.
        self.conv2.register_backward_hook(cell_hook_function)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()
        self.hook = P.HookBackward(hook_function)

    def construct(self, x):
        """Forward pass; shapes below assume a (32, 1, 32, 32) input."""
        x = self.conv1(x)            # (32, 6, 28, 28)
        x = self.relu(x)
        x = self.hook(x)             # identity forward; calls hook_function in backward
        x = self.max_pool2d(x)       # (32, 6, 14, 14)
        x = self.conv2(x)            # (32, 16, 10, 10)
        x = self.relu(x)
        x = self.max_pool2d(x)       # (32, 16, 5, 5)
        x = self.reshape(x, (self.batch_size, -1))
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


class GradWrap(nn.Cell):
    """Return the gradients of ``network`` w.r.t. its trainable parameters."""

    def __init__(self, network):
        super(GradWrap, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))

    def construct(self, x, label):
        weights = self.weights
        return C.GradOperation(get_by_list=True)(self.network, weights)(x, label)


class test_custom_cell_base():
    """Pass-through holder the tests use to hand cells around."""

    def __init__(self):
        pass

    def test_custom_cell_function(self, cell):
        """Return ``cell`` unchanged."""
        return cell


class MulAdd(nn.Cell):
    """Compute ``2 * x + y`` with a deliberately non-standard custom bprop.

    The bprop returns ``(dout, y)`` rather than the true gradients
    ``(2 * dout, dout)``, so the test can tell the custom rule was used.
    """

    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        # Check that the bprop receives the exact values the test feeds in.
        assert x.asnumpy() == 1.0
        assert y.asnumpy() == 2.0
        assert out.asnumpy() == 4.0
        assert dout.asnumpy() == 1.0
        return dout, y


class Ms_Cell(nn.Cell):
    """ReLU cell whose bprop discards ``dout`` and returns a scalar zero."""

    def __init__(self):
        super(Ms_Cell, self).__init__()
        self.relu = P.ReLU()

    def construct(self, x):
        return self.relu(x)

    def bprop(self, x, out, dout):
        dout = Tensor(np.float32(0.0))
        assert dout.shape == ()
        return dout


class Ms_Cell_Change_Shape(nn.Cell):
    """ReLU cell whose bprop returns a (5, 5) gradient regardless of input shape.

    The test feeds a scalar, so the framework's shape check is expected to reject it.
    """

    def __init__(self):
        super(Ms_Cell_Change_Shape, self).__init__()
        self.relu = P.ReLU()

    def construct(self, x):
        return self.relu(x)

    def bprop(self, x, out, dout):
        dout = Tensor(np.ones([5, 5]).astype(np.float32))
        assert dout.shape == (5, 5)
        return dout


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_lenet_train_hook_function_print_and_save_grad():
    """Run one LeNet5 training step with both hook kinds attached."""
    hook = test_custom_hook_function_base()
    function = hook.test_custom_hook_function(custom_hook_function_print_and_save_grad,
                                              cell_hook_function_print_grad)
    net = LeNet5(hook_function=function[0], cell_hook_function=function[1])
    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
    net_with_criterion = WithLossCell(net, criterion)
    train_network = GradWrap(net_with_criterion)
    train_network.set_train()

    input_data = Tensor(np.ones([net.batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([net.batch_size, net.num_class]).astype(np.float32))
    # Plain forward pass as a smoke check. input_data is already a Tensor,
    # so it is passed directly instead of being re-wrapped.
    output = net(input_data)
    criterion(output, label)
    grads = train_network(input_data, label)
    success = optimizer(grads)
    assert success


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_custom_bprop_and_Cell_MulAdd():
    """grad_all on MulAdd must return exactly what its custom bprop returns."""
    custom_cell = test_custom_cell_base()
    mul_add = custom_cell.test_custom_cell_function(MulAdd())
    mul_add.bprop_debug = True
    expected = (Tensor(1.0, mstype.float32), Tensor(2.0, mstype.float32))
    # Differentiate twice and check both results, so a repeated invocation on
    # the same cell is verified too rather than having its result discarded.
    for _ in range(2):
        assert grad_all(mul_add)(Tensor(1, mstype.float32), Tensor(2, mstype.float32)) == expected


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_custom_bprop_and_Cell_Ms_Cell_Change_Shape():
    """A bprop whose gradient shape differs from the input must raise RuntimeError."""
    custom_cell = test_custom_cell_base()
    ms_cell = custom_cell.test_custom_cell_function(Ms_Cell_Change_Shape())
    ms_cell.bprop_debug = True
    with pytest.raises(RuntimeError) as ex:
        grad_all(ms_cell)(Tensor(1, mstype.float32))
    assert "Shapes of input and parameter are different, input index" in str(ex.value)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_custom_bprop_and_Cell_Ms_Cell():
    """A bprop returning a single scalar zero yields a one-element gradient tuple."""
    custom_cell = test_custom_cell_base()
    ms_cell = custom_cell.test_custom_cell_function(Ms_Cell())
    ms_cell.bprop_debug = True
    assert grad_all(ms_cell)(Tensor(1, mstype.float32)) == (Tensor(0.0, mstype.float32),)