# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import context, Tensor, ParameterTuple
from mindspore.common.initializer import TruncatedNormal
from mindspore.nn import WithLossCell, Momentum
from mindspore.ops import composite as C

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

# Flags flipped by the hooks and custom bprops below so the tests can verify they ran.
cell_hook_done = False
var_hook_done = False
cell_bprop_done = False


grad_all = C.GradOperation(get_all=True)


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """weight initializer for conv layer"""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    """weight initializer for fc layer"""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


def weight_variable():
    """weight initializer"""
    return TruncatedNormal(0.02)


def cell_hook_function(cell_id, grad_input, grad_output):
    """Backward hook registered on conv2.

    grad_input matches the shape of conv2's forward output (32, 16, 10, 10);
    grad_output matches the shape of conv2's forward input (32, 6, 14, 14).
    """
    print(cell_id)
    global cell_hook_done
    cell_hook_done = True
    assert grad_output[0].asnumpy().shape == (32, 6, 14, 14)
    assert grad_input[0].asnumpy().shape == (32, 16, 10, 10)


def var_hook_function(grad_out):
    """Variable hook attached after fc1; the gradient there is (batch_size, 120)."""
    print("grad:", grad_out)
    global var_hook_done
    var_hook_done = True
    assert grad_out[0].asnumpy().shape == (32, 120)


class Block(nn.Cell):
    """ReLU block with a custom bprop that records it was executed."""
    def __init__(self):
        super(Block, self).__init__()
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.relu(x)
        return x

    def bprop(self, x, out, dout):
        global cell_bprop_done
        cell_bprop_done = True
        # Custom gradient computed in numpy: out * dout.
        grad = out.asnumpy() * dout.asnumpy()
        grad = Tensor(grad)
        return (grad,)


class LeNet5(nn.Cell):
    """
    Lenet network
    Args:
        num_class (int): Num classes. Default: 10.

    Returns:
        Tensor, output tensor

    Examples:
        >>> LeNet5(num_class=10)
    """
    def __init__(self, num_class=10):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.batch_size = 32
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        # Cell-level backward hook on conv2.
        self.conv2.register_backward_hook(cell_hook_function)
        self.block = Block()
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()
        # Tensor-level hook inserted after fc1 in construct.
        self.hook = P.HookBackward(var_hook_function)

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.block(x)
        x = self.max_pool2d(x)
        x = self.reshape(x, (self.batch_size, -1))
        x = self.fc1(x)
        x = self.hook(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


class GradWrap(nn.Cell):
    """GradWrap definition: computes gradients w.r.t. the trainable parameters."""
    def __init__(self, network):
        super(GradWrap, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))

    def construct(self, x, label):
        weights = self.weights
        return C.GradOperation(get_by_list=True)(self.network, weights)(x, label)


def test_hook():
    net = LeNet5()
    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
    net_with_criterion = WithLossCell(net, criterion)
    train_network = GradWrap(net_with_criterion)
    train_network.set_train()

    input_data = Tensor(np.ones([net.batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([net.batch_size, net.num_class]).astype(np.float32))
    output = net(input_data)
    loss_output = criterion(output, label)
    # The backward pass fires the cell hook, the variable hook and Block's custom bprop.
    grads = train_network(input_data, label)
    optimizer(grads)
    assert cell_hook_done
    assert var_hook_done
    assert cell_bprop_done
    print(loss_output.asnumpy())


# Flag flipped by MulAdd.bprop so test_custom_bprop can verify the custom bprop ran.
bprop_debug = False


class MulAdd(nn.Cell):
    def __init__(self):
        super(MulAdd, self).__init__()

    def construct(self, x, y):
        return 2 * x * x + y * y

    def bprop(self, x, y, out, dout):
        global bprop_debug
        bprop_debug = True
        # Custom gradients; the test only checks that this bprop was invoked.
        return dout, 2 * y


def test_custom_bprop():
    mul_add = MulAdd()
    # Enable the cell's bprop debug switch (distinct from the module-level flag above).
    mul_add.bprop_debug = True
    x = Tensor(np.array([1, 2, 3]).astype(np.int32))
    y = Tensor(np.array([2, 3, 4]).astype(np.int32))
    grad_all(mul_add)(x, y)
    assert bprop_debug


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()

    def construct(self, x, y):
        return 2 * x * x + y * y


def test_grad_all():
    net = Net()
    x = Tensor(np.array([1, 2, 3]).astype(np.int32))
    y = Tensor(np.array([2, 3, 4]).astype(np.int32))
    res = grad_all(net)(x, y)
    print(res)


def test_check_input():
    net = Net()
    # Raw numpy arrays (not Tensors) are rejected by the cell call with a TypeError.
    x = np.array([1, 2, 3])
    y = np.array([2, 3, 4])
    with pytest.raises(TypeError):
        net(x, y)
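# ----------------------------------------------------------------------------
# A minimal standalone sketch of the same HookBackward pattern on a tiny cell,
# kept separate from the tests above. The names `TinySketchNet` and
# `tiny_hook_sketch` are introduced here purely for illustration; the sketch
# reuses only APIs already imported in this file (P.Mul, P.HookBackward,
# grad_all, Tensor) and assumes the PyNative context set at the top.
class TinySketchNet(nn.Cell):
    def __init__(self, hook_fn):
        super(TinySketchNet, self).__init__()
        self.mul = P.Mul()
        self.hook = P.HookBackward(hook_fn)

    def construct(self, x, y):
        z = self.mul(x, y)
        # The hook observes the gradient flowing back through z during backprop.
        z = self.hook(z)
        return z


def tiny_hook_sketch():
    captured = []
    net = TinySketchNet(lambda grad: captured.append(grad))
    x = Tensor(np.array([1.0, 2.0, 3.0]).astype(np.float32))
    y = Tensor(np.array([4.0, 5.0, 6.0]).astype(np.float32))
    grad_all(net)(x, y)
    print("hook captured:", captured)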