# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import context, Tensor, ParameterTuple
from mindspore.common.initializer import TruncatedNormal
from mindspore.nn import WithLossCell, Momentum
from mindspore.ops import composite as C

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
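
# Flags flipped by the hooks below; each test asserts on them to confirm its hook fired.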
cell_hook_done = False
var_hook_done = False
cell_bprop_done = False


grad_all = C.GradOperation(get_all=True)


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Conv layer with truncated-normal weight initialization."""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    """Fully connected layer with truncated-normal weight and bias initialization."""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


def weight_variable():
    """Return a TruncatedNormal initializer with sigma 0.02."""
    return TruncatedNormal(0.02)


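# Cell backward hook registered on conv2. In this hook signature, grad_input
# carries the gradient w.r.t. conv2's output (32, 16, 10, 10) and grad_output
# the gradient w.r.t. its input (32, 6, 14, 14), as the shape assertions reflect.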
def cell_hook_function(cell_id, grad_input, grad_output):
    print(cell_id)
    global cell_hook_done
    cell_hook_done = True
    assert grad_output[0].asnumpy().shape == (32, 6, 14, 14)
    assert grad_input[0].asnumpy().shape == (32, 16, 10, 10)


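# Variable hook attached via P.HookBackward after fc1; it receives the gradient
# flowing into fc1's output, whose shape is (batch_size, 120).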
def var_hook_function(grad_out):
    print("grad:", grad_out)
    global var_hook_done
    var_hook_done = True
    assert grad_out[0].asnumpy().shape == (32, 120)


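# Cell with a user-defined bprop: when a Cell defines bprop, PyNative mode uses
# it in place of the automatically derived gradient for that cell.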
class Block(nn.Cell):
    def __init__(self):
        super(Block, self).__init__()
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.relu(x)
        return x

    def bprop(self, x, out, dout):
        global cell_bprop_done
        cell_bprop_done = True
        grad = out.asnumpy() * dout.asnumpy()
        grad = Tensor(grad)
        return (grad,)


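# LeNet-5 with three hook points: a backward hook on conv2, the custom-bprop
# Block after conv2, and a HookBackward variable hook on the output of fc1.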
class LeNet5(nn.Cell):
    """
    LeNet-5 network.

    Args:
        num_class (int): Number of classes. Default: 10.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> LeNet5(num_class=10)
    """
    def __init__(self, num_class=10):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.batch_size = 32
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.conv2.register_backward_hook(cell_hook_function)
        self.block = Block()
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()
        self.hook = P.HookBackward(var_hook_function)

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.block(x)
        x = self.max_pool2d(x)
        x = self.reshape(x, (self.batch_size, -1))
        x = self.fc1(x)
        x = self.hook(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


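# GradOperation(get_by_list=True) differentiates the network with respect to the
# given parameter list, returning one gradient per parameter in `weights`.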
class GradWrap(nn.Cell):
    """Wrap a network and return gradients of its trainable parameters."""
    def __init__(self, network):
        super(GradWrap, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))

    def construct(self, x, label):
        weights = self.weights
        return C.GradOperation(get_by_list=True)(self.network, weights)(x, label)


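# End-to-end check: one forward/backward pass through LeNet-5 must trigger the
# cell hook, the variable hook, and Block's custom bprop.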
def test_hook():
    net = LeNet5()
    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
    net_with_criterion = WithLossCell(net, criterion)
    train_network = GradWrap(net_with_criterion)
    train_network.set_train()

    input_data = Tensor(np.ones([net.batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([net.batch_size, net.num_class]).astype(np.float32))
    output = net(input_data)
    loss_output = criterion(output, label)
    grads = train_network(input_data, label)
    optimizer(grads)
    assert cell_hook_done
    assert var_hook_done
    assert cell_bprop_done
    print(loss_output.asnumpy())


bprop_debug = False


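# Cell whose bprop deliberately returns hand-written values (dout, 2 * y) rather
# than the analytic gradients of 2*x*x + y*y; the test only checks that the
# custom bprop actually ran.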
class MulAdd(nn.Cell):
    def __init__(self):
        super(MulAdd, self).__init__()

    def construct(self, x, y):
        return 2 * x * x + y * y

    def bprop(self, x, y, out, dout):
        global bprop_debug
        bprop_debug = True
        return dout, 2 * y


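# Setting bprop_debug on the cell turns on MindSpore's debug checking for the
# cell's custom bprop.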
def test_custom_bprop():
    mul_add = MulAdd()
    mul_add.bprop_debug = True
    x = Tensor(np.array([1, 2, 3]).astype(np.int32))
    y = Tensor(np.array([2, 3, 4]).astype(np.int32))
    grad_all(mul_add)(x, y)
    assert bprop_debug


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()

    def construct(self, x, y):
        return 2 * x * x + y * y


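# grad_all (GradOperation(get_all=True)) returns the gradients of the network
# output with respect to every input, here d/dx and d/dy of 2*x*x + y*y.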
def test_grad_all():
    net = Net()
    x = Tensor(np.array([1, 2, 3]).astype(np.int32))
    y = Tensor(np.array([2, 3, 4]).astype(np.int32))
    res = grad_all(net)(x, y)
    print(res)


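# Cell inputs must be MindSpore Tensors; passing raw numpy arrays should raise
# a TypeError at call time.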
def test_check_input():
    net = Net()
    x = np.array([1, 2, 3])
    y = np.array([2, 3, 4])
    with pytest.raises(TypeError):
        net(x, y)