• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test_cell_bprop """
16import numpy as np
17import pytest
18
19import mindspore as ms
20import mindspore.common.dtype as mstype
21import mindspore.nn as nn
22from mindspore import Parameter, ParameterTuple
23from mindspore import context, mutable
24from mindspore.common.initializer import initializer
25from mindspore.common.tensor import Tensor
26from mindspore.ops import composite as C
27from mindspore.ops import operations as P
28from mindspore import ops
29from mindspore._extends import cell_attr_register
30
# Run the tests in this file in graph mode unless a test switches modes itself.
context.set_context(mode=context.GRAPH_MODE)
# Shared gradient operator used by most tests; built with get_all=True.
grad_all = C.GradOperation(get_all=True)
33
34
class MulAdd(nn.Cell):
    """Cell computing ``2 * x + y`` with a hand-written bprop."""

    def construct(self, x, y):
        doubled = 2 * x
        return doubled + y

    def bprop(self, x, y, out, dout):
        # The user-defined bprop is wrong on purpose, so its result can be
        # told apart from what automatic differentiation would produce.
        grad_x = 2 * dout
        grad_y = 2 * y
        return grad_x, grad_y
42
43
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add():
    """
    Feature: Custom cell bprop
    Description: Take the gradients of all inputs of MulAdd, which defines its own bprop.
    Expectation: The gradients compare equal to (2, 4).
    """
    net = MulAdd()
    lhs = Tensor(1, dtype=ms.int32)
    rhs = Tensor(2, dtype=ms.int32)
    grads = grad_all(net)(lhs, rhs)
    assert grads == (2, 4)
52
53
class InlineMulADD(nn.Cell):
    """Cell adding ``x + param * y`` to the output of an inner MulAdd cell."""

    def __init__(self):
        super().__init__()
        self.mul_add = MulAdd()
        self.param = 2

    def construct(self, x, y):
        inner = self.mul_add(x, y)
        return inner + x + self.param * y
62
63
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_inline_mul_add():
    """
    Feature: Custom cell bprop
    Description: Take the gradients of all inputs of InlineMulADD, which wraps a cell with a custom bprop.
    Expectation: The gradients compare equal to (3, 6).
    """
    net = InlineMulADD()
    lhs = Tensor(1, dtype=ms.int32)
    rhs = Tensor(2, dtype=ms.int32)
    grads = grad_all(net)(lhs, rhs)
    assert grads == (3, 6)
72
73
class WithParameter(nn.Cell):
    """Cell whose forward pass and user-defined bprop both read two Parameters."""

    def __init__(self):
        super().__init__()
        self.param1 = Parameter(1, 'param1')
        self.param2 = Parameter(2, 'param2')

    def construct(self, x, y):
        scale = self.param1 * self.param2
        return scale * x + y

    def bprop(self, x, y, out, dout):
        # The user-defined bprop is wrong on purpose, so its result can be
        # told apart from what automatic differentiation would produce.
        scale = self.param1 * self.param2
        return scale * dout, 2 * y
86
87
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_with_param():
    """
    Feature: Custom cell bprop
    Description: Take the gradients of WithParameter, whose bprop reads Parameters.
    Expectation: RuntimeError is raised.
    """
    net = WithParameter()
    with pytest.raises(RuntimeError):
        grad_all(net)(mutable(1), 2)
95
96
class WithNoBprop(nn.Cell):
    """Cell computing ``2 * x + y`` without any user-defined bprop."""

    def construct(self, x, y):
        doubled = 2 * x
        return doubled + y
100
101
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_with_no_bprop():
    """
    Feature: Custom cell bprop
    Description: Take the gradients of all inputs of a cell that defines no bprop.
    Expectation: The gradients compare equal to (2, 1).
    """
    net = WithNoBprop()
    lhs = Tensor(1, dtype=ms.int32)
    rhs = Tensor(2, dtype=ms.int32)
    grads = grad_all(net)(lhs, rhs)
    assert grads == (2, 1)
110
111
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_in_bprop_1():
    """
    Feature: Custom cell bprop
    Description: Call GradOperation inside both construct and bprop of a cell that is wrapped by an outer cell.
    Expectation: The gradient of the first input is all ones and the gradient of the second input is all zeros.
    """
    class GradInBprop_1(nn.Cell):
        """Applies ReLU to x; y is accepted but unused."""

        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()

        def construct(self, x, y):
            return self.relu(x)

    class GradInBprop_2(nn.Cell):
        """Returns the inner output together with the inner gradients."""

        def __init__(self):
            super().__init__()
            self.f = GradInBprop_1()

        def construct(self, x, y):
            forward = self.f(x, y)
            inner_grads = grad_all(self.f)(x, y)
            return forward, inner_grads

        def bprop(self, x, y, out, dout):
            inner_grads = grad_all(self.f)(x, y)
            return out[1][0], inner_grads[1]

    class GradInBprop_3(nn.Cell):
        """Thin wrapper forwarding to GradInBprop_2."""

        def __init__(self):
            super().__init__()
            self.f = GradInBprop_2()

        def construct(self, x, y):
            return self.f(x, y)

    net = GradInBprop_3()
    first = Tensor(np.ones([2, 2]).astype(np.float32))
    second = Tensor(np.ones([2, 2]).astype(np.float32))
    grads = grad_all(net)(first, second)
    expect_dx = np.ones([2, 2]).astype(np.float32)
    expect_dy = np.zeros([2, 2]).astype(np.float32)
    assert (grads[0].asnumpy() == expect_dx).all()
    assert (grads[1].asnumpy() == expect_dy).all()
149
150
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_in_bprop_2():
    """
    Feature: Custom cell bprop
    Description: Call GradOperation inside construct and bprop when the innermost cell also defines a bprop.
    Expectation: The gradient of the first input is all ones and the gradient of the second input is all twos.
    """
    class GradInBprop_1(nn.Cell):
        """Applies ReLU to x; its bprop returns (x * y, y + x)."""

        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()

        def construct(self, x, y):
            return self.relu(x)

        def bprop(self, x, y, out, dout):
            grad_x = x * y
            grad_y = y + x
            return grad_x, grad_y

    class GradInBprop_2(nn.Cell):
        """Returns the inner output together with the inner gradients."""

        def __init__(self):
            super().__init__()
            self.f = GradInBprop_1()

        def construct(self, x, y):
            forward = self.f(x, y)
            inner_grads = grad_all(self.f)(x, y)
            return forward, inner_grads

        def bprop(self, x, y, out, dout):
            inner_grads = grad_all(self.f)(x, y)
            return out[1][0], inner_grads[1]

    class GradInBprop_3(nn.Cell):
        """Thin wrapper forwarding to GradInBprop_2."""

        def __init__(self):
            super().__init__()
            self.f = GradInBprop_2()

        def construct(self, x, y):
            return self.f(x, y)

    net = GradInBprop_3()
    first = Tensor(np.ones([2, 2]).astype(np.float32))
    second = Tensor(np.ones([2, 2]).astype(np.float32))
    grads = grad_all(net)(first, second)
    expect_dx = np.ones([2, 2]).astype(np.float32)
    expect_dy = np.array([[2, 2], [2, 2]]).astype(np.float32)
    assert (grads[0].asnumpy() == expect_dx).all()
    assert (grads[1].asnumpy() == expect_dy).all()
191
192
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_in_bprop_3():
    """
    Feature: Custom cell bprop
    Description: Call GradOperation inside construct and bprop when the outermost cell defines a bprop as well.
    Expectation: The gradient of the first input is all fours and the gradient of the second input is all fives.
    """
    class GradInBprop_1(nn.Cell):
        """Applies ReLU to x; y is accepted but unused."""

        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()

        def construct(self, x, y):
            return self.relu(x)

    class GradInBprop_2(nn.Cell):
        """Returns the inner output together with the inner gradients."""

        def __init__(self):
            super().__init__()
            self.f = GradInBprop_1()

        def construct(self, x, y):
            forward = self.f(x, y)
            inner_grads = grad_all(self.f)(x, y)
            return forward, inner_grads

        def bprop(self, x, y, out, dout):
            inner_grads = grad_all(self.f)(x, y)
            return out[1][0], inner_grads[1]

    class GradInBprop_3(nn.Cell):
        """Wrapper around GradInBprop_2 that overrides the gradient with its own bprop."""

        def __init__(self):
            super().__init__()
            self.f = GradInBprop_2()

        def construct(self, x, y):
            return self.f(x, y)

        def bprop(self, x, y, out, dout):
            grad_x = x + y + y + out[0]
            grad_y = x + x + y + y + dout[0]
            return grad_x, grad_y

    net = GradInBprop_3()
    first = Tensor(np.ones([2, 2]).astype(np.float32))
    second = Tensor(np.ones([2, 2]).astype(np.float32))
    grads = grad_all(net)(first, second)
    expect_dx = np.array([[4, 4], [4, 4]]).astype(np.float32)
    expect_dy = np.array([[5, 5], [5, 5]]).astype(np.float32)
    assert (grads[0].asnumpy() == expect_dx).all()
    assert (grads[1].asnumpy() == expect_dy).all()
233
234
class OneInputBprop(nn.Cell):
    """Single-input ReLU cell whose bprop returns ``5 * x`` as the gradient."""

    def __init__(self):
        super().__init__()
        self.op = P.ReLU()

    def construct(self, x):
        activated = self.op(x)
        return activated

    def bprop(self, x, out, dout):
        grad_x = 5 * x
        return (grad_x,)
245
246
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_one_input_bprop():
    """
    Feature: Custom cell bprop
    Description: Take the gradient of a single-input cell whose bprop returns 5 * x.
    Expectation: The gradient is a 2x2 array filled with 5.
    """
    net = OneInputBprop()
    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
    grad = grad_all(net)(input1)
    # Compare against the full 2x2 expectation. A shape-(2,) expectation would
    # only match through broadcasting and could not detect a wrong-shaped gradient.
    expected = 5 * np.ones([2, 2]).astype(np.float32)
    assert np.array_equal(grad[0].asnumpy(), expected)
255
256
class TwoInput(nn.Cell):
    """Cell returning the product ``x * y``, with no user-defined bprop."""

    def construct(self, x, y):
        product = x * y
        return product
260
261
class InlineBpropTwoInput(nn.Cell):
    """Wraps TwoInput; its bprop returns the inner gradients multiplied by 2."""

    def __init__(self):
        super().__init__()
        self.f = TwoInput()

    def construct(self, x, y):
        forward = self.f(x, y)
        inner_grads = grad_all(self.f)(x, y)
        return forward, inner_grads

    def bprop(self, x, y, out, dout):
        inner_grads = grad_all(self.f)(x, y)
        grad_x = inner_grads[0] * 2
        grad_y = inner_grads[1] * 2
        return grad_x, grad_y
273
274
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_inline_bprop_two_input():
    """
    Feature: Custom cell bprop
    Description: Take the gradients of a two-input cell whose bprop doubles the gradients of its inner cell.
    Expectation: Exactly two gradients are returned, each a 2x2 array filled with 2.
    """
    net = InlineBpropTwoInput()
    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
    input2 = Tensor(np.ones([2, 2]).astype(np.float32))
    grads = grad_all(net)(input1, input2)
    # Check the count first so a short result fails here instead of as an IndexError.
    assert len(grads) == 2
    # Compare against the full 2x2 expectation. A shape-(2,) expectation would
    # only match through broadcasting and could not detect a wrong-shaped gradient.
    expected = 2 * np.ones([2, 2]).astype(np.float32)
    assert np.array_equal(grads[0].asnumpy(), expected)
    assert np.array_equal(grads[1].asnumpy(), expected)
286
287
class TwoInputBprop(nn.Cell):
    """Mul cell whose bprop returns ``(5 * x, 8 * y)`` as the gradients."""

    def __init__(self):
        super().__init__()
        self.op = P.Mul()

    def construct(self, x, y):
        product = self.op(x, y)
        return product

    def bprop(self, x, y, out, dout):
        grad_x = 5 * x
        grad_y = 8 * y
        return grad_x, grad_y
298
299
class TwoInputWithParameter(nn.Cell):
    """Mul cell that adds a Parameter to x before multiplying; no user-defined bprop."""

    def __init__(self):
        super().__init__()
        self.op = P.Mul()
        self.inputdata = Parameter(initializer(1, (2, 2), mstype.float32), name="global_step")

    def construct(self, x, y):
        shifted = self.inputdata + x
        return self.op(shifted, y)
309
310
class TwoInputWithOnlyInitParameterBprop(nn.Cell):
    """Mul cell that creates a Parameter in __init__ but never reads it; bprop returns ``(5 * x, 8 * y)``."""

    def __init__(self):
        super().__init__()
        self.op = P.Mul()
        self.inputdata = Parameter(initializer(1, (2, 2), mstype.float32), name="global_step")

    def construct(self, x, y):
        product = self.op(x, y)
        return product

    def bprop(self, x, y, out, dout):
        grad_x = 5 * x
        grad_y = 8 * y
        return grad_x, grad_y
322
323
class InlineMutilTwoInputParameterCell(nn.Cell):
    """Sums the outputs of four two-input cells, with and without bprop and Parameters."""

    def __init__(self):
        super().__init__()
        self.f1 = TwoInputBprop()
        self.f2 = TwoInput()
        self.f3 = TwoInputWithParameter()
        self.f4 = TwoInputWithOnlyInitParameterBprop()

    def construct(self, x, y):
        # Accumulate left to right, in the same order as f1 + f2 + f3 + f4.
        total = self.f1(x, y)
        total = total + self.f2(x, y)
        total = total + self.f3(x, y)
        total = total + self.f4(x, y)
        return total
335
336
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_inline_bprop_multi_input():
    """
    Feature: Custom cell bprop
    Description: Take the gradients of a cell that sums four sub-cells, some with a custom bprop.
    Expectation: Two gradients are returned, filled with 12 and 19 respectively.
    """
    net = InlineMutilTwoInputParameterCell()
    lhs = Tensor(np.ones([2, 2]).astype(np.float32))
    rhs = Tensor(np.ones([2, 2]).astype(np.float32))
    net.init_parameters_data()
    grads = grad_all(net)(lhs, rhs)
    expect_dx = np.array([[12, 12], [12, 12]]).astype(np.float32)
    expect_dy = np.array([[19, 19], [19, 19]]).astype(np.float32)
    assert (grads[0].asnumpy() == expect_dx).all()
    assert (grads[1].asnumpy() == expect_dy).all()
    assert len(grads) == 2
349
350
class MulAddWithParam(nn.Cell):
    """Feeds a Parameter as the first input of an inner MulAdd cell."""

    def __init__(self):
        super().__init__()
        self.mul_add = MulAdd()
        self.param = Parameter(Tensor(np.array([[3, 2]], np.float32)), 'param')

    def construct(self, x):
        result = self.mul_add(self.param, x)
        return result
359
360
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_refkey_bprop():
    """
    Feature: Custom cell bprop
    Description: Take input and weight gradients of a cell that passes a Parameter into a cell with a custom bprop.
    Expectation: grads[0][0] matches [4, 4] and grads[1][0] matches [2, 2].
    """
    grad_by_list = C.GradOperation(get_all=True, get_by_list=True)

    class GradWrap(nn.Cell):
        """Wraps a network and returns its gradients for inputs and trainable weights."""

        def __init__(self, network):
            super().__init__()
            self.network = network
            self.weights = ParameterTuple(p for p in network.get_parameters() if p.requires_grad)

        def construct(self, x):
            return grad_by_list(self.network, self.weights)(x)

    wrapped = GradWrap(MulAddWithParam())
    input_data = Tensor(np.array([2, 2], np.float32))
    grads = wrapped(input_data)
    expect_first = np.array([4, 4]).astype(np.float32)
    expect_second = np.array([2, 2]).astype(np.float32)
    assert (grads[0][0].asnumpy() == expect_first).all()
    assert (grads[1][0].asnumpy() == expect_second).all()
383
384
class MulAddWithWrongOutputNum(nn.Cell):
    """Two-input cell whose bprop returns only one gradient."""

    def construct(self, x, y):
        doubled = 2 * x
        return doubled + y

    def bprop(self, x, y, out, dout):
        # Only one gradient is returned although construct takes two inputs.
        only_grad = 2 * dout
        return (only_grad,)
391
392
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add_with_wrong_output_num():
    """
    Feature: Custom cell bprop
    Description: With check_bprop enabled, differentiate a two-input cell whose bprop returns one gradient.
    Expectation: ValueError is raised.
    """
    # Remember the current setting so this test does not leak check_bprop=True
    # into tests that run afterwards.
    previous_check_bprop = context.get_context("check_bprop")
    context.set_context(check_bprop=True)
    try:
        mul_add = MulAddWithWrongOutputNum()
        with pytest.raises(ValueError):
            grad_all(mul_add)(mutable(1), 2)
    finally:
        context.set_context(check_bprop=previous_check_bprop)
401
402
class MulAddWithWrongOutputType(nn.Cell):
    """Two-input cell whose bprop returns the constant 2 as the second gradient."""

    def construct(self, x, y):
        doubled = 2 * x
        return doubled + y

    def bprop(self, x, y, out, dout):
        grad_x = 2 * dout
        return grad_x, 2
409
410
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add_with_wrong_output_type():
    """
    Feature: Custom cell bprop
    Description: With check_bprop enabled, differentiate a cell whose bprop returns the constant 2 for a Tensor input.
    Expectation: TypeError is raised.
    """
    # Remember the current setting so this test does not leak check_bprop=True
    # into tests that run afterwards.
    previous_check_bprop = context.get_context("check_bprop")
    context.set_context(check_bprop=True)
    try:
        mul_add = MulAddWithWrongOutputType()
        with pytest.raises(TypeError):
            grad_all(mul_add)(1, Tensor(np.ones([2, 2])))
    finally:
        context.set_context(check_bprop=previous_check_bprop)
419
420
class MulAddWithWrongOutputShape(nn.Cell):
    """Two-input cell whose bprop returns a fixed shape-(2,) tensor as the second gradient."""

    def __init__(self):
        super().__init__()
        self.ones = Tensor(np.ones([2,]))

    def construct(self, x, y):
        doubled = 2 * x
        return doubled + y

    def bprop(self, x, y, out, dout):
        grad_y = self.ones
        return 2, grad_y
431
432
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add_with_wrong_output_shape():
    """
    Feature: Custom cell bprop
    Description: With check_bprop enabled, differentiate a cell whose bprop returns a shape-(2,) tensor
                 for a 2x2 input.
    Expectation: ValueError is raised.
    """
    # Remember the current setting so this test does not leak check_bprop=True
    # into tests that run afterwards.
    previous_check_bprop = context.get_context("check_bprop")
    context.set_context(check_bprop=True)
    try:
        mul_add = MulAddWithWrongOutputShape()
        with pytest.raises(ValueError):
            grad_all(mul_add)(1, Tensor(np.ones([2, 2])))
    finally:
        context.set_context(check_bprop=previous_check_bprop)
441
442
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_forward_with_parameter():
    """
    Feature: Custom cell bprop
    Description: Differentiate a cell with a custom bprop whose forward pass multiplies by a Parameter.
    Expectation: The input gradients match the values returned by bprop (x + x and y + y).
    """

    class Net(nn.Cell):
        """MatMul cell that scales x by Parameter z; bprop returns (x + x, y + y)."""

        def __init__(self):
            super().__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            scaled = x * self.z
            return self.matmul(scaled, y)

        def bprop(self, x, y, out, dout):
            return x + x, y + y

    class GradNet(nn.Cell):
        """Returns the gradients of the wrapped net for all inputs."""

        def __init__(self, net):
            super().__init__()
            self.net = net

        def construct(self, x, y):
            return grad_all(self.net)(x, y)

    x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
    y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
    grads = GradNet(Net())(x, y)
    expect_dx = np.array([[1.0, 1.2, 0.8],
                          [2.4, 2.6, 2.2]], dtype=np.float32)
    expect_dy = np.array([[0.02, 0.6, 2.2],
                          [0.2, 0.4, 2.6],
                          [4.2, 2.4, 6.6]], dtype=np.float32)
    assert np.allclose(grads[0].asnumpy(), expect_dx)
    assert np.allclose(grads[1].asnumpy(), expect_dy)
488
489
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_forward_with_parameter_in_sub_cell():
    """
    Feature: Custom cell bprop
    Description: Differentiate a wrapper cell whose sub-cell has a custom bprop and uses a Parameter in its
                 forward pass.
    Expectation: The input gradients match the values returned by the sub-cell bprop (x + x and y + y).
    """

    class Net1(nn.Cell):
        """MatMul cell that scales x by Parameter z; bprop returns (x + x, y + y)."""

        def __init__(self):
            super().__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            scaled = x * self.z
            return self.matmul(scaled, y)

        def bprop(self, x, y, out, dout):
            return x + x, y + y

    class Net(nn.Cell):
        """Thin wrapper forwarding to Net1."""

        def __init__(self):
            super().__init__()
            self.net = Net1()

        def construct(self, x, y):
            return self.net(x, y)

    class GradNet(nn.Cell):
        """Returns the gradients of the wrapped net for all inputs."""

        def __init__(self, net):
            super().__init__()
            self.net = net

        def construct(self, x, y):
            return grad_all(self.net)(x, y)

    x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
    y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
    grads = GradNet(Net())(x, y)
    expect_dx = np.array([[1.0, 1.2, 0.8],
                          [2.4, 2.6, 2.2]], dtype=np.float32)
    expect_dy = np.array([[0.02, 0.6, 2.2],
                          [0.2, 0.4, 2.6],
                          [4.2, 2.4, 6.6]], dtype=np.float32)
    assert np.allclose(grads[0].asnumpy(), expect_dx)
    assert np.allclose(grads[1].asnumpy(), expect_dy)
543
544
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_forward_with_parameter_in_sub_cell_get_by_list():
    """
    Feature: Custom cell bprop
    Description: Take input and Parameter gradients of a wrapper cell whose sub-cell has a custom bprop and
                 uses a Parameter in its forward pass.
    Expectation: The input gradients match x + x and y + y, and the Parameter gradient is zero.
    """

    class Net1(nn.Cell):
        """MatMul cell that scales x by Parameter z; bprop returns (x + x, y + y)."""

        def __init__(self):
            super().__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            scaled = x * self.z
            return self.matmul(scaled, y)

        def bprop(self, x, y, out, dout):
            return x + x, y + y

    class Net(nn.Cell):
        """Thin wrapper forwarding to Net1."""

        def __init__(self):
            super().__init__()
            self.net = Net1()

        def construct(self, x, y):
            return self.net(x, y)

    class GradNet(nn.Cell):
        """Returns the gradients of the wrapped net for all inputs and its trainable Parameters."""

        def __init__(self, net):
            super().__init__()
            self.net = net
            self.params = ParameterTuple(net.trainable_params())
            self.grad_op = C.GradOperation(get_by_list=True, get_all=True)

        def construct(self, x, y):
            return self.grad_op(self.net, self.params)(x, y)

    x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
    y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
    grads = GradNet(Net())(x, y)
    expect_dx = np.array([[1.0, 1.2, 0.8],
                          [2.4, 2.6, 2.2]], dtype=np.float32)
    expect_dy = np.array([[0.02, 0.6, 2.2],
                          [0.2, 0.4, 2.6],
                          [4.2, 2.4, 6.6]], dtype=np.float32)
    expect_dz = np.array([0.0], dtype=np.float32)
    assert np.allclose(grads[0][0].asnumpy(), expect_dx)
    assert np.allclose(grads[0][1].asnumpy(), expect_dy)
    assert np.allclose(grads[1][0].asnumpy(), expect_dz)
602
603
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_forward_with_parameter():
    """
    Feature: Custom cell bprop
    Description: Get the gradients of inputs in PyNative mode when the forward net uses a Parameter.
    Expectation: Get the correct gradients.
    """

    class Net(nn.Cell):
        """MatMul cell that scales x by Parameter z; bprop returns (x + x, y + y)."""

        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return out

        def bprop(self, x, y, out, dout):
            dx = x + x
            dy = y + y
            return dx, dy

    class GradNet(nn.Cell):
        """Returns the gradients of the wrapped net for all inputs."""

        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, x, y):
            grad_f = grad_all(self.net)
            return grad_f(x, y)

    context.set_context(mode=context.PYNATIVE_MODE)
    try:
        x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
        y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
        out = GradNet(Net())(x, y)
        expect_dx = np.array([[1.0, 1.2, 0.8],
                              [2.4, 2.6, 2.2]]).astype(np.float32)
        expect_dy = np.array([[0.02, 0.6, 2.2],
                              [0.2, 0.4, 2.6],
                              [4.2, 2.4, 6.6]]).astype(np.float32)
        assert np.allclose(out[0].asnumpy(), expect_dx)
        assert np.allclose(out[1].asnumpy(), expect_dy)
    finally:
        # Restore graph mode even when an assertion fails, so later tests
        # are not silently run in PyNative mode.
        context.set_context(mode=context.GRAPH_MODE)
651
652
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_forward_with_parameter_in_sub_cell():
    """
    Feature: Custom cell bprop
    Description: Get the gradients of inputs in PyNative mode when the forward net uses a Parameter in the sub-cell.
    Expectation: Get the correct gradients.
    """

    class Net(nn.Cell):
        """Thin wrapper forwarding to Net1."""

        def __init__(self):
            super(Net, self).__init__()
            self.net = Net1()

        def construct(self, x, y):
            return self.net(x, y)

    class Net1(nn.Cell):
        """MatMul cell that scales x by Parameter z; bprop returns (x + x, y + y)."""

        def __init__(self):
            super(Net1, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return out

        def bprop(self, x, y, out, dout):
            dx = x + x
            dy = y + y
            return dx, dy

    class GradNet(nn.Cell):
        """Returns the gradients of the wrapped net for all inputs."""

        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, x, y):
            grad_f = grad_all(self.net)
            return grad_f(x, y)

    context.set_context(mode=context.PYNATIVE_MODE)
    try:
        x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
        y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
        out = GradNet(Net())(x, y)
        expect_dx = np.array([[1.0, 1.2, 0.8],
                              [2.4, 2.6, 2.2]]).astype(np.float32)
        expect_dy = np.array([[0.02, 0.6, 2.2],
                              [0.2, 0.4, 2.6],
                              [4.2, 2.4, 6.6]]).astype(np.float32)
        assert np.allclose(out[0].asnumpy(), expect_dx)
        assert np.allclose(out[1].asnumpy(), expect_dy)
    finally:
        # Restore graph mode even when an assertion fails, so later tests
        # are not silently run in PyNative mode.
        context.set_context(mode=context.GRAPH_MODE)
708
709
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_forward_with_parameter_in_sub_cell_get_by_list():
    """
    Feature: Custom cell bprop
    Description: Get the gradients of inputs and Parameters in PyNative mode when the forward net uses a
                 Parameter in the sub-cell.
    Expectation: Get the correct gradients.
    """

    class Net(nn.Cell):
        """Thin wrapper forwarding to Net1."""

        def __init__(self):
            super(Net, self).__init__()
            self.net = Net1()

        def construct(self, x, y):
            return self.net(x, y)

    class Net1(nn.Cell):
        """MatMul cell that scales x by Parameter z; bprop returns (x + x, y + y)."""

        def __init__(self):
            super(Net1, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return out

        def bprop(self, x, y, out, dout):
            dx = x + x
            dy = y + y
            return dx, dy

    class GradNet(nn.Cell):
        """Returns the gradients of the wrapped net for all inputs and its trainable Parameters."""

        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.params = ParameterTuple(net.trainable_params())
            self.grad_op = C.GradOperation(get_by_list=True, get_all=True)

        def construct(self, x, y):
            grad_f = self.grad_op(self.net, self.params)
            return grad_f(x, y)

    context.set_context(mode=context.PYNATIVE_MODE)
    try:
        x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
        y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
        out = GradNet(Net())(x, y)
        expect_dx = np.array([[1.0, 1.2, 0.8],
                              [2.4, 2.6, 2.2]]).astype(np.float32)
        expect_dy = np.array([[0.02, 0.6, 2.2],
                              [0.2, 0.4, 2.6],
                              [4.2, 2.4, 6.6]]).astype(np.float32)
        expect_dz = np.array([0.0]).astype(np.float32)
        assert np.allclose(out[0][0].asnumpy(), expect_dx)
        assert np.allclose(out[0][1].asnumpy(), expect_dy)
        assert np.allclose(out[1][0].asnumpy(), expect_dz)
    finally:
        # Restore graph mode even when an assertion fails, so later tests
        # are not silently run in PyNative mode.
        context.set_context(mode=context.GRAPH_MODE)
769
770
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dde_self_define_cell_output_not_use():
    """
    Feature: Custom cell bprop
    Description: The second forward output is read only by bprop, so dde must not erase it.
    Expectation: The gradient compares equal to Tensor([3]).
    """

    class SelfDefineCell(ms.nn.Cell):
        """Returns (x + 1, x + 2); bprop returns the second output as the gradient."""

        def construct(self, x):
            return x + 1, x + 2

        def bprop(self, x, out, dout):
            return (out[1],)

    class ForwardNet(ms.nn.Cell):
        """Calls SelfDefineCell and returns only its first output."""

        def __init__(self):
            super().__init__()
            self.self_defined_cell = SelfDefineCell()

        def construct(self, x):
            # The second output is intentionally left unused in the forward pass.
            first_out, _ = self.self_defined_cell(x)
            return first_out

    class TestNet(ms.nn.Cell):
        """Returns the gradients of ForwardNet for all inputs."""

        def __init__(self):
            super().__init__()
            self.forward_net = ForwardNet()
            self.grad_op = ops.GradOperation(get_all=True)

        def construct(self, x):
            return self.grad_op(self.forward_net)(x)

    grad_net = TestNet()
    grads = grad_net(ms.Tensor([1]))
    assert grads[0] == ms.Tensor([3])
812
813
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_bprop_defined_in_cell_attr_register():
    """
    Feature: Custom cell bprop
    Description: Take the gradient of a cell with a custom bprop whose __init__ is decorated with
                 @cell_attr_register.
    Expectation: The gradient compares equal to 4.
    """

    class Net(nn.Cell):
        """Computes (x * z) * y with Parameter z; bprop returns (y, x)."""

        @cell_attr_register
        def __init__(self):
            super().__init__()
            self.z = Parameter(Tensor(2, mstype.float32), name='z')

        def construct(self, x, y):
            scaled = x * self.z
            return scaled * y

        def bprop(self, x, y, out, dout):
            return y, x

    net = Net()
    lhs = Tensor(3, mstype.float32)
    rhs = Tensor(4, mstype.float32)
    grad = ops.grad(net)(lhs, rhs)
    assert grad == 4
842