• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test_cell_bprop """
16import numpy as np
17import pytest
18
19import mindspore as ms
20import mindspore.common.dtype as mstype
21import mindspore.nn as nn
22from mindspore import Parameter, ParameterTuple
23from mindspore import context
24from mindspore.common.initializer import initializer
25from mindspore.common.tensor import Tensor
26from mindspore.ops import composite as C
27from mindspore.ops import operations as P
28
# All tests below run the cells in graph mode on the Ascend backend.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


# Shared gradient operator: returns gradients w.r.t. every input of a cell.
grad_all = C.GradOperation(get_all=True)
33
34
class MulAdd(nn.Cell):
    """Cell computing 2 * x + y with a deliberately wrong custom bprop."""

    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        # The user-defined bprop is wrong on purpose (true gradients would be
        # (2 * dout, dout)) so tests can distinguish it from the autodiff result.
        return 2 * dout, 2 * y
42
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add():
    """The deliberately wrong user-defined bprop of MulAdd must win over autodiff."""
    net = MulAdd()
    inp_x = Tensor(1, dtype=ms.int32)
    inp_y = Tensor(2, dtype=ms.int32)
    # MulAdd.bprop returns (2 * dout, 2 * y) -> (2 * 1, 2 * 2) = (2, 4).
    assert grad_all(net)(inp_x, inp_y) == (2, 4)
51
52
class InlineMulADD(nn.Cell):
    """Combines MulAdd (custom bprop) with extra terms handled by autodiff."""

    def __init__(self):
        super(InlineMulADD, self).__init__()
        self.mul_add = MulAdd()
        # Plain Python int; becomes a compile-time constant in the graph.
        self.param = 2

    def construct(self, x, y):
        return self.mul_add(x, y) + x + self.param * y
61
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_inline_mul_add():
    """Custom-bprop gradients of the inner MulAdd accumulate with autodiff terms."""
    net = InlineMulADD()
    inp_x = Tensor(1, dtype=ms.int32)
    inp_y = Tensor(2, dtype=ms.int32)
    # dx = 2 (MulAdd.bprop) + 1 (the "+ x" term) = 3
    # dy = 4 (MulAdd.bprop) + 2 (the "param * y" term) = 6
    assert grad_all(net)(inp_x, inp_y) == (3, 6)
70
71
class WithParameter(nn.Cell):
    """Cell whose custom bprop reads its own Parameters.

    Used to verify that taking input gradients of such a cell raises a
    RuntimeError (see test_with_param).
    """

    def __init__(self):
        super(WithParameter, self).__init__()
        self.param1 = Parameter(1, 'param1')
        self.param2 = Parameter(2, 'param2')

    def construct(self, x, y):
        return self.param1 * self.param2 * x + y

    def bprop(self, x, y, out, dout):
        # Wrong on purpose (see MulAdd); it also references parameters,
        # which a user-defined bprop is not allowed to do.
        return self.param1 * self.param2 * dout, 2 * y
84
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_with_param():
    """Differentiating a cell whose bprop uses Parameters must raise RuntimeError."""
    net = WithParameter()
    with pytest.raises(RuntimeError):
        grad_all(net)(1, 2)
92
93
class WithNoBprop(nn.Cell):
    """Same forward as MulAdd but with no custom bprop: plain autodiff applies."""

    def construct(self, x, y):
        return 2 * x + y
97
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_with_no_bprop():
    """Without a custom bprop, autodiff gives d(2x + y) = (2, 1)."""
    net = WithNoBprop()
    inp_x = Tensor(1, dtype=ms.int32)
    inp_y = Tensor(2, dtype=ms.int32)
    assert grad_all(net)(inp_x, inp_y) == (2, 1)
106
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_in_bprop_1():
    """grad_all may be called inside a custom bprop; inner cell uses autodiff."""
    class GradInBprop_1(nn.Cell):
        """Innermost cell: ReLU on x; y is unused, so its gradient is zero."""

        def __init__(self):
            super(GradInBprop_1, self).__init__()
            self.relu = P.ReLU()

        def construct(self, x, y):
            return self.relu(x)

    class GradInBprop_2(nn.Cell):
        """Middle cell whose bprop re-differentiates the inner cell."""

        def __init__(self):
            super(GradInBprop_2, self).__init__()
            self.f = GradInBprop_1()

        def construct(self, x, y):
            # Returns both the forward value and the inner cell's gradients,
            # so bprop can pick the x-gradient out of its own output.
            return self.f(x, y), grad_all(self.f)(x, y)

        def bprop(self, x, y, out, dout):
            grads = grad_all(self.f)(x, y)
            # out[1][0]: x-gradient computed in construct; grads[1]: (zero)
            # gradient for the unused y.
            return out[1][0], grads[1]

    class GradInBprop_3(nn.Cell):
        """Outer wrapper with no bprop of its own."""

        def __init__(self):
            super(GradInBprop_3, self).__init__()
            self.f = GradInBprop_2()

        def construct(self, x, y):
            return self.f(x, y)

    grad_in_bprop = GradInBprop_3()
    grads = grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)),
                                    Tensor(np.ones([2, 2]).astype(np.float32)))
    # For all-ones input the ReLU gradient is all ones; y never contributes.
    assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.zeros([2, 2]).astype(np.float32)).all()
144
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_in_bprop_2():
    """grad_all inside a custom bprop when the inner cell has its own bprop."""
    class GradInBprop_1(nn.Cell):
        """Inner cell: ReLU forward, fake bprop returning (x * y, y + x)."""

        def __init__(self):
            super(GradInBprop_1, self).__init__()
            self.relu = P.ReLU()

        def construct(self, x, y):
            return self.relu(x)

        def bprop(self, x, y, out, dout):
            # Intentionally not the true ReLU gradients.
            return x * y, y + x

    class GradInBprop_2(nn.Cell):
        """Middle cell whose bprop re-differentiates the inner cell."""

        def __init__(self):
            super(GradInBprop_2, self).__init__()
            self.f = GradInBprop_1()

        def construct(self, x, y):
            return self.f(x, y), grad_all(self.f)(x, y)

        def bprop(self, x, y, out, dout):
            grads = grad_all(self.f)(x, y)
            # out[1][0]: inner x-gradient from construct; grads[1]: inner
            # y-gradient from the fake bprop.
            return out[1][0], grads[1]

    class GradInBprop_3(nn.Cell):
        """Outer wrapper with no bprop of its own."""

        def __init__(self):
            super(GradInBprop_3, self).__init__()
            self.f = GradInBprop_2()

        def construct(self, x, y):
            return self.f(x, y)

    grad_in_bprop = GradInBprop_3()
    grads = grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)),
                                    Tensor(np.ones([2, 2]).astype(np.float32)))
    # With all-ones inputs the inner bprop yields x * y = ones, y + x = twos.
    assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.array([[2, 2], [2, 2]]).astype(np.float32)).all()
185
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_in_bprop_3():
    """grad_all inside a bprop, plus a custom bprop on the outermost cell."""
    class GradInBprop_1(nn.Cell):
        """Innermost cell: ReLU on x; y is unused."""

        def __init__(self):
            super(GradInBprop_1, self).__init__()
            self.relu = P.ReLU()

        def construct(self, x, y):
            return self.relu(x)

    class GradInBprop_2(nn.Cell):
        """Middle cell whose bprop re-differentiates the inner cell."""

        def __init__(self):
            super(GradInBprop_2, self).__init__()
            self.f = GradInBprop_1()

        def construct(self, x, y):
            return self.f(x, y), grad_all(self.f)(x, y)

        def bprop(self, x, y, out, dout):
            grads = grad_all(self.f)(x, y)
            return out[1][0], grads[1]

    class GradInBprop_3(nn.Cell):
        """Outer cell with its own bprop built from x, y, out and dout."""

        def __init__(self):
            super(GradInBprop_3, self).__init__()
            self.f = GradInBprop_2()

        def construct(self, x, y):
            return self.f(x, y)

        def bprop(self, x, y, out, dout):
            # With all-ones inputs: x + y + y + out[0] = 4s,
            # x + x + y + y + dout[0] = 5s (matching the asserts below).
            return x + y + y + out[0], x + x + y + y + dout[0]

    grad_in_bprop = GradInBprop_3()
    grads = grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)),
                                    Tensor(np.ones([2, 2]).astype(np.float32)))
    assert (grads[0].asnumpy() == np.array([[4, 4], [4, 4]]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.array([[5, 5], [5, 5]]).astype(np.float32)).all()
226
227
class OneInputBprop(nn.Cell):
    """Single-input cell: forward is ReLU, custom bprop returns 5 * x."""

    def __init__(self):
        super().__init__()
        self.op = P.ReLU()

    def construct(self, x):
        return self.op(x)

    def bprop(self, x, out, dout):
        # A single-input bprop must still return a tuple, hence the comma.
        return (5 * x,)
238
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_one_input_bprop():
    """Custom single-input bprop (5 * x) must override the ReLU gradient."""
    net = OneInputBprop()
    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
    grad = grad_all(net)(input1)
    # Compare against an expected array of the same [2, 2] shape instead of
    # relying on NumPy broadcasting a 1-D [5, 5] against the 2-D result.
    expected = np.full([2, 2], 5, dtype=np.float32)
    assert (grad[0].asnumpy() == expected).all()
247
248
class TwoInput(nn.Cell):
    """Elementwise product of two inputs; no custom bprop."""

    def construct(self, x, y):
        return x * y
252
253
class InlineBpropTwoInput(nn.Cell):
    """Cell whose bprop returns double the autodiff gradients of TwoInput."""

    def __init__(self):
        super().__init__()
        self.f = TwoInput()

    def construct(self, x, y):
        return self.f(x, y), grad_all(self.f)(x, y)

    def bprop(self, x, y, out, dout):
        # Re-differentiate self.f inside bprop and scale: d(x*y) = (y, x).
        grads = grad_all(self.f)(x, y)
        return grads[0] * 2, grads[1] * 2
265
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_inline_bprop_two_input():
    """bprop doubling the real gradients: d(x*y) = (y, x) = ones -> twos."""
    net = InlineBpropTwoInput()
    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
    input2 = Tensor(np.ones([2, 2]).astype(np.float32))
    grads = grad_all(net)(input1, input2)
    assert len(grads) == 2
    # Shape-explicit expected value; the original compared a 1-D [2, 2]
    # array against the 2-D result and passed only via broadcasting.
    expected = np.full([2, 2], 2, dtype=np.float32)
    assert (grads[0].asnumpy() == expected).all()
    assert (grads[1].asnumpy() == expected).all()
277
278
class TwoInputBprop(nn.Cell):
    """Multiply cell with a custom bprop returning (5 * x, 8 * y)."""

    def __init__(self):
        super().__init__()
        self.op = P.Mul()

    def construct(self, x, y):
        return self.op(x, y)

    def bprop(self, x, y, out, dout):
        # Intentionally not the true gradients of multiplication.
        return 5 * x, 8 * y
289
290
class TwoInputWithParameter(nn.Cell):
    """Multiply cell that first adds a trainable parameter to x; no custom bprop."""

    def __init__(self):
        super().__init__()
        self.op = P.Mul()
        # 2x2 parameter initialized to all ones.
        self.inputdata = Parameter(initializer(1, (2, 2), mstype.float32), name="global_step")

    def construct(self, x, y):
        x = self.inputdata + x
        return self.op(x, y)
300
301
class TwoInputWithOnlyInitParameterBprop(nn.Cell):
    """Cell that owns a parameter unused in construct, plus a custom bprop."""

    def __init__(self):
        super().__init__()
        self.op = P.Mul()
        # Declared but never read in construct or bprop.
        self.inputdata = Parameter(initializer(1, (2, 2), mstype.float32), name="global_step")

    def construct(self, x, y):
        return self.op(x, y)

    def bprop(self, x, y, out, dout):
        # Intentionally not the true gradients of multiplication.
        return 5 * x, 8 * y
313
314
class InlineMutilTwoInputParameterCell(nn.Cell):
    """Sums four sub-cells so their gradients accumulate on x and y.

    NOTE(review): "Mutil" looks like a typo for "Multi"; kept as-is because
    the class name is part of the test's public surface.
    """

    def __init__(self):
        super().__init__()
        self.f1 = TwoInputBprop()
        self.f2 = TwoInput()
        self.f3 = TwoInputWithParameter()
        self.f4 = TwoInputWithOnlyInitParameterBprop()

    def construct(self, x, y):
        output = self.f1(x, y) + self.f2(x, y) + self.f3(x, y) + self.f4(x, y)
        return output
326
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_inline_bprop_multi_input():
    """Gradients accumulate across four sub-cells mixing custom and autodiff bprops."""
    net = InlineMutilTwoInputParameterCell()
    net.init_parameters_data()
    ones = np.ones([2, 2]).astype(np.float32)
    grads = grad_all(net)(Tensor(ones), Tensor(ones))
    assert len(grads) == 2
    # dx: 5 (f1 bprop) + 1 (f2: y) + 1 (f3: y) + 5 (f4 bprop) = 12
    assert (grads[0].asnumpy() == np.array([[12, 12], [12, 12]]).astype(np.float32)).all()
    # dy: 8 (f1 bprop) + 1 (f2: x) + 2 (f3: x + param) + 8 (f4 bprop) = 19
    assert (grads[1].asnumpy() == np.array([[19, 19], [19, 19]]).astype(np.float32)).all()
339
340
class MulAddWithParam(nn.Cell):
    """Feeds its own parameter as the first input of MulAdd (custom bprop)."""

    def __init__(self):
        super(MulAddWithParam, self).__init__()
        self.mul_add = MulAdd()
        # Shape (1, 2) parameter used as MulAdd's "x" argument.
        self.param = Parameter(Tensor(np.array([[3, 2]], np.float32)), 'param')

    def construct(self, x):
        return self.mul_add(self.param, x)
349
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_refkey_bprop():
    """Gradients w.r.t. both inputs and parameters through a custom bprop."""
    grad_by_list = C.GradOperation(get_all=True, get_by_list=True)
    class GradWrap(nn.Cell):
        """Returns (input gradients, parameter gradients) of the wrapped net."""
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network
            # Collect the trainable parameters of the wrapped network.
            self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))
        def construct(self, x):
            weights = self.weights
            grads = grad_by_list(self.network, weights)(x)
            return grads
    network = GradWrap(MulAddWithParam())
    input_data = Tensor(np.array([2, 2], np.float32))
    grads = network(input_data)
    # MulAdd.bprop(param, input) returns (2 * dout, 2 * input): the input
    # gradient is 2 * [2, 2] = [4, 4], the parameter gradient is 2 * ones.
    assert (grads[0][0].asnumpy() == np.array([4, 4]).astype(np.float32)).all()
    assert (grads[1][0].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
369
370
class MulAddWithWrongOutputNum(nn.Cell):
    """bprop returns one gradient for two inputs; check_bprop should reject it."""

    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        # Wrong on purpose: only one gradient where two are required.
        return (2 * dout,)
377
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add_with_wrong_output_num():
    """A bprop returning fewer gradients than inputs must raise TypeError.

    Requires check_bprop=True so the framework validates the bprop output.
    """
    context.set_context(check_bprop=True)
    try:
        mul_add = MulAddWithWrongOutputNum()
        with pytest.raises(TypeError):
            grad_all(mul_add)(1, 2)
    finally:
        # Restore the default so the strict bprop check does not leak into
        # unrelated tests that run later (check_bprop defaults to False —
        # TODO confirm against the installed MindSpore version).
        context.set_context(check_bprop=False)
386
387
class MulAddWithWrongOutputType(nn.Cell):
    """bprop returns a Python int where a gradient matching y is expected."""

    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        # Wrong on purpose: second gradient is a plain scalar, not a Tensor.
        return 2 * dout, 2
394
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add_with_wrong_output_type():
    """A bprop returning a scalar instead of a Tensor gradient must raise TypeError.

    Requires check_bprop=True so the framework validates the bprop output.
    """
    context.set_context(check_bprop=True)
    try:
        mul_add = MulAddWithWrongOutputType()
        with pytest.raises(TypeError):
            grad_all(mul_add)(1, Tensor(np.ones([2, 2])))
    finally:
        # Restore the default so the strict bprop check does not leak into
        # unrelated tests that run later (check_bprop defaults to False —
        # TODO confirm against the installed MindSpore version).
        context.set_context(check_bprop=False)
403
404
class MulAddWithWrongOutputShape(nn.Cell):
    """bprop returns a gradient whose shape does not match the input's."""

    def __init__(self):
        super(MulAddWithWrongOutputShape, self).__init__()
        # Shape [2], deliberately different from the [2, 2] input used in the test.
        self.ones = Tensor(np.ones([2,]))

    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        # Wrong on purpose: scalar for x, mis-shaped tensor for y.
        return 2, self.ones
415
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add_with_wrong_output_shape():
    """A bprop returning a mis-shaped gradient must raise TypeError.

    Requires check_bprop=True so the framework validates the bprop output.
    """
    context.set_context(check_bprop=True)
    try:
        mul_add = MulAddWithWrongOutputShape()
        with pytest.raises(TypeError):
            grad_all(mul_add)(1, Tensor(np.ones([2, 2])))
    finally:
        # Restore the default so the strict bprop check does not leak into
        # unrelated tests that run later (check_bprop defaults to False —
        # TODO confirm against the installed MindSpore version).
        context.set_context(check_bprop=False)
424