# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

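"""Tests for backward hooks and custom Cell bprop methods in PyNative mode."""
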
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype

from mindspore import Tensor
from mindspore import context
from mindspore import ParameterTuple
from mindspore.nn import Momentum
from mindspore.nn import WithLossCell
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal

context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")


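# GradOperation with get_all=True returns the gradients w.r.t. every network input.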
grad_all = C.GradOperation(get_all=True)


def weight_variable():
    """Truncated-normal weight initializer (sigma=0.02)."""
    return TruncatedNormal(0.02)


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Conv2d layer with truncated-normal weight init and no bias."""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    """Dense (fully connected) layer with truncated-normal weight and bias init."""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


class test_custom_hook_function_base:
    def __init__(self):
        pass

    def test_custom_hook_function(self, hook_function, cell_hook_function):
        return hook_function, cell_hook_function


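# Cell-level backward hook. The shapes asserted below match conv2 of LeNet-5
# with batch size 32: grad_input[0] is the gradient w.r.t. conv2's output
# (32, 16, 10, 10) and grad_output[0] the gradient w.r.t. its input (32, 6, 14, 14).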
def cell_hook_function_print_grad(cell_id, grad_input, grad_output):
    assert grad_output[0].asnumpy().shape == (32, 6, 14, 14)
    assert grad_input[0].asnumpy().shape == (32, 16, 10, 10)


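# Tensor-level hook for P.HookBackward: it receives the gradient of the tensor
# it wraps, here conv1's ReLU activation of shape (32, 6, 28, 28).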
def custom_hook_function_print_and_save_grad(grad_out):
    assert grad_out[0].asnumpy().shape == (32, 6, 28, 28)


class LeNet5(nn.Cell):
    def __init__(self, hook_function, cell_hook_function, num_class=10):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.batch_size = 32
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        # The shapes asserted in cell_hook_function_print_grad are those of
        # conv2, so the cell hook is registered there.
        self.conv2.register_backward_hook(cell_hook_function)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()
        self.hook = P.HookBackward(hook_function)

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
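        # hook_function is attached here, so it sees the gradient of this
        # (32, 6, 28, 28) activation during the backward pass.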
        x = self.hook(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.reshape(x, (self.batch_size, -1))
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


class GradWrap(nn.Cell):
    """Computes gradients of the wrapped network w.r.t. its trainable parameters."""
    def __init__(self, network):
        super(GradWrap, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))

    def construct(self, x, label):
        weights = self.weights
        return C.GradOperation(get_by_list=True)(self.network, weights)(x, label)


class test_custom_cell_base:
    def __init__(self):
        pass

    def test_custom_cell_function(self, cell):
        return cell


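# Cell with a user-defined bprop. For out = 2 * x + y autodiff would yield
# (2 * dout, dout); this bprop instead returns (dout, y) so the test can verify
# that the custom bprop, not autodiff, is executed.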
class MulAdd(nn.Cell):
    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        assert x.asnumpy() == 1.0
        assert y.asnumpy() == 2.0
        assert out.asnumpy() == 4.0
        assert dout.asnumpy() == 1.0
        return dout, y


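# Custom bprop returning a scalar zero gradient, matching the scalar input
# used in the test below.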
class Ms_Cell(nn.Cell):
    def __init__(self):
        super(Ms_Cell, self).__init__()
        self.relu = P.ReLU()

    def construct(self, x):
        return self.relu(x)

    def bprop(self, x, out, dout):
        dout = Tensor(np.float32(0.0))
        assert dout.shape == ()
        return dout


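# Custom bprop returning a (5, 5) gradient for a scalar input. The shape
# mismatch should raise a RuntimeError during the backward pass.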
class Ms_Cell_Change_Shape(nn.Cell):
    def __init__(self):
        super(Ms_Cell_Change_Shape, self).__init__()
        self.relu = P.ReLU()

    def construct(self, x):
        return self.relu(x)

    def bprop(self, x, out, dout):
        dout = Tensor(np.ones([5, 5]).astype(np.float32))
        assert dout.shape == (5, 5)
        return dout


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_lenet_train_hook_function_print_and_save_grad():
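    """One LeNet-5 training step in PyNative mode; the hook asserts fire during backprop."""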
    hook = test_custom_hook_function_base()
    function = hook.test_custom_hook_function(custom_hook_function_print_and_save_grad,
                                              cell_hook_function_print_grad)
    net = LeNet5(hook_function=function[0], cell_hook_function=function[1])
    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
    net_with_criterion = WithLossCell(net, criterion)
    train_network = GradWrap(net_with_criterion)
    train_network.set_train()

    input_data = Tensor(np.ones([net.batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([net.batch_size, net.num_class]).astype(np.float32))
    output = net(input_data)
    criterion(output, label)
    grads = train_network(input_data, label)
    success = optimizer(grads)
    assert success


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_custom_bprop_and_Cell_MulAdd():
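    """grad_all on MulAdd returns (dout, y) == (1.0, 2.0) from the custom bprop."""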
    custom_cell = test_custom_cell_base()
    mul_add = custom_cell.test_custom_cell_function(MulAdd())
    mul_add.bprop_debug = True
    grad_all(mul_add)(Tensor(1, mstype.float32), Tensor(2, mstype.float32))
    assert grad_all(mul_add)(Tensor(1, mstype.float32), Tensor(2, mstype.float32)) == \
           (Tensor(1.0, mstype.float32), Tensor(2.0, mstype.float32))


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_custom_bprop_and_Cell_Ms_Cell_Change_Shape():
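    """A bprop whose gradient shape differs from the input raises a RuntimeError."""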
    custom_cell = test_custom_cell_base()
    ms_Cell = custom_cell.test_custom_cell_function(Ms_Cell_Change_Shape())
    ms_Cell.bprop_debug = True
    with pytest.raises(RuntimeError) as ex:
        grad_all(ms_Cell)(Tensor(1, mstype.float32))
    assert "Shapes of input and parameter are different, input index" in str(ex.value)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_custom_bprop_and_Cell_Ms_Cell():
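    """Ms_Cell.bprop returns a scalar zero gradient, so grad_all yields (0.0,)."""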
    custom_cell = test_custom_cell_base()
    ms_Cell = custom_cell.test_custom_cell_function(Ms_Cell())
    ms_Cell.bprop_debug = True
    assert grad_all(ms_Cell)(Tensor(1, mstype.float32)) == (Tensor(0.0, mstype.float32),)