# Copyright 2020-2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_bprop """
import numpy as np
import pytest
import mindspore as ms
from mindspore import grad
import mindspore.nn as nn
from mindspore import context
from mindspore.common import Tensor
from mindspore.common.api import jit
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.ops import operations as P
from mindspore.ops import GradOperation
from tests.mindspore_test_framework.utils.bprop_util import bprop
from tests.st.pynative.utils import GradOfFirstInput, GradOfAllInputs, GradOfAllInputsAndParams


def setup_module():
    context.set_context(mode=context.PYNATIVE_MODE)


class Net(nn.Cell):
    """ Net definition """

    def __init__(self):
        super(Net, self).__init__()
        self.matmul = P.MatMul()
        self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

    @jit
    def construct(self, x, y):
        x = x * self.z
        out = self.matmul(x, y)
        return x, out

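# Minimal sketch (not one of the original tests): the gradients exercised by
# the bprop() helper below can also be obtained directly with GradOperation.
# Because Net.construct returns two outputs, the sensitivity `sens` is assumed
# to be a tuple matching the output shapes; all names here are illustrative.
def _example_grad_with_sens():
    net = Net()
    x = Tensor(np.ones([2, 3]).astype(np.float32))
    y = Tensor(np.ones([3, 2]).astype(np.float32))
    sens = (Tensor(np.ones([2, 3]).astype(np.float32)),
            Tensor(np.ones([2, 2]).astype(np.float32)))
    # get_all=True returns the gradients w.r.t. both x and y.
    return GradOperation(get_all=True, sens_param=True)(net)(x, y, sens)

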
def test_bprop_no_sens():
    grads = bprop(Net(), Tensor(np.ones([2, 3]).astype(np.float32)),
                  Tensor(np.ones([3, 2]).astype(np.float32)), wrt=['inputs'])
    print(grads)


def test_bprop_sens():
    grads = bprop(Net(), Tensor(np.ones([2, 3]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32)),
                  grads_wrt_outputs=(Tensor(np.ones([2, 3]).astype(np.float32)),
                                     Tensor(np.ones([2, 2]).astype(np.float32))), wrt=['inputs'])
    print(grads)


def test_bprop_first_only():
    grads = bprop(Net(), Tensor(np.ones([2, 3]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32)),
                  grads_wrt_outputs=(Tensor(np.ones([2, 3]).astype(np.float32)),
                                     Tensor(np.ones([2, 2]).astype(np.float32))))
    print(grads)


def test_bprop_wrt_params():
    net = Net()
    grads = bprop(net, Tensor(np.ones([2, 3]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32)),
                  grads_wrt_outputs=(Tensor(np.ones([2, 3]).astype(np.float32)),
                                     Tensor(np.ones([2, 2]).astype(np.float32))),
                  wrt=['params'],
                  params=net.trainable_params())
    print(grads)


def test_bprop_wrt_params_no_sens():
    net = Net()
    grads = bprop(net, Tensor(np.ones([2, 3]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32)),
                  wrt=['params'],
                  params=net.trainable_params())
    print(grads)


def test_bprop_wrt_inputs_and_params():
    net = Net()
    grads = bprop(net, Tensor(np.ones([2, 3]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32)),
                  grads_wrt_outputs=(Tensor(np.ones([2, 3]).astype(np.float32)),
                                     Tensor(np.ones([2, 2]).astype(np.float32))),
                  wrt=['inputs', 'params'],
                  params=net.trainable_params())
    print(grads)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_network_with_dict_output():
    """
    Feature: Test sens dict
    Description: Network output is a dict (with a Tensor key).
    Expectation: Success
    """

    class DicNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()

        def construct(self, x):
            y = self.relu(x)
            out = {Tensor(True): y}
            return out

    x = np.array([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]])
    ms_net = DicNet()
    # No sens: the default sensitivity is all ones.
    ms_grad = GradOfFirstInput(ms_net, False)
    grad_out = ms_grad(Tensor(x))
    assert np.allclose(np.ones_like(x), grad_out.asnumpy())

    # With sens: feed the forward output back as the sensitivity.
    out = ms_net(Tensor(x))
    ms_grad = GradOfFirstInput(ms_net, True)
    grad_out = ms_grad(Tensor(x), out)
    assert np.allclose(x, grad_out.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_jit_network_with_dict_output():
    """
    Feature: Test sens dict in jit
    Description: Network output is a dict inside a jit-compiled construct.
    Expectation: Success
    """

    class DicNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()

        @jit
        def construct(self, x):
            y = self.relu(x)
            out = {'a': y}
            return out

    x = np.array([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]])
    ms_net = DicNet()
    # No sens
    ms_grad = GradOfFirstInput(ms_net, False)
    grad_out = ms_grad(Tensor(x))
    assert np.allclose(np.ones_like(x), grad_out.asnumpy())

    # With sens
    ms_net = DicNet()
    out = ms_net(Tensor(x))
    ms_grad = GradOfFirstInput(ms_net, True)
    grad_out = ms_grad(Tensor(x), out)
    assert np.allclose(x, grad_out.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_synchronize():
    """
    Feature: Test pynative synchronize
    Description: Test the code for the synchronous branch.
    Expectation: Success
    """
    try:
        context.set_context(pynative_synchronize=True)

        # Cell object to be differentiated
        class MulNet(nn.Cell):
            def construct(self, x, y, z):
                return x * y * z

        x = Tensor([1, 2], ms.float32)
        y = Tensor([-2, 3], ms.float32)
        z = Tensor([0, 3], ms.float32)
        net = MulNet()
        # Mark the first input as dynamic-shape.
        net.set_inputs(Tensor(shape=[None], dtype=ms.float32), y, z)
        # d(x*y*z)/dy = x*z = [0, 6]; d(x*y*z)/dz = x*y = [-2, 6]
        output = grad(net, grad_position=(1, 2))(x, y, z)
        assert (output[0].asnumpy() == np.array([0, 6], dtype=np.float32)).all()
        assert (output[1].asnumpy() == np.array([-2, 6], dtype=np.float32)).all()
    finally:
        context.set_context(pynative_synchronize=False)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_multi_grad():
    """
    Feature: Test pynative multi grad
    Description: Test the code for PyNative multi grad.
    Expectation: Success
    """

    class ForwardNetMul(nn.Cell):
        def construct(self, x, y):
            a = x * x
            b = y * y
            return a * b

    class ForwardNetAdd(nn.Cell):
        def construct(self, x, y):
            a = x + x + x
            b = y + y
            return a * b

    mulnet = ForwardNetMul()
    addnet = ForwardNetAdd()
    x = Tensor(np.ones([32]), dtype=ms.float32)
    y = Tensor(np.ones([32]) * 2, dtype=ms.float32)
    sens = Tensor(np.ones([32]), dtype=ms.float32)
    # set_grad() asks the PyNative runtime to record the following forward
    # passes so that gradients can be taken afterwards.
    mulnet.set_grad()
    addnet.set_grad()
    mulnet(x, y)
    addnet(x, y)
    grad_mul = GradOfAllInputs(mulnet)
    grad_add = GradOfAllInputs(addnet)
    grad_mul(x, y, sens)
    grad_add(x, y, sens)

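
# Minimal cross-check sketch (not one of the original tests): the same kind of
# multi-grad result can be obtained with the functional mindspore.grad API.
# The function f below is illustrative; for these inputs
# d/dx (x*x*y*y) = 2*x*y*y and d/dy (x*x*y*y) = 2*y*x*x.
def _example_functional_multi_grad():
    def f(x, y):
        return x * x * y * y

    x = Tensor(np.ones([4]), dtype=ms.float32)
    y = Tensor(np.ones([4]) * 2, dtype=ms.float32)
    # grad_position=(0, 1) returns the gradients w.r.t. both positional inputs.
    dx, dy = grad(f, grad_position=(0, 1))(x, y)
    return dx, dy
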

class GradFactory:
    def __init__(self, net_me, get_all, get_by_list, sens_param, net_params=None,
                 default_para=False):
        self.net_me = net_me
        self.get_all = get_all
        self.get_by_list = get_by_list
        self.sens_param = sens_param
        self.net_params = net_params
        self.default_para = default_para

    def get_grad(self, ms_input):
        """Build a sensitivity matching the network output: random values when
        sens_param is set, ones otherwise."""
        output_grad_me = []
        out = self.net_me(*ms_input)
        if isinstance(out, tuple):
            for it in out:
                if self.sens_param:
                    grad_np = np.random.randn(*it.shape).astype(np.float32)
                else:
                    grad_np = np.ones(it.shape).astype(np.float32)
                output_grad_me.append(Tensor(grad_np))
            output_grad_me = tuple(output_grad_me)
        else:
            if self.sens_param:
                grad_np = np.random.randn(*out.shape).astype(np.float32)
            else:
                grad_np = np.ones(out.shape).astype(np.float32)
            output_grad_me = Tensor(grad_np)
        return output_grad_me

    def one_backnet_call_twice(self, first_ms_input, second_ms_input, loss=0.001):
        grad_input = self.get_grad(first_ms_input)
        if self.default_para:
            back_net = nn.ForwardValueAndGrad(self.net_me)
            back_net(*first_ms_input)
        else:
            if self.get_by_list:
                weight = self.net_params
            else:
                weight = None
            back_net = nn.ForwardValueAndGrad(self.net_me,
                                              weights=weight, get_all=self.get_all,
                                              get_by_list=self.get_by_list,
                                              sens_param=self.sens_param)
            if self.sens_param:
                back_net(*first_ms_input, grad_input[0])
            else:
                back_net(*first_ms_input)

        # second call, reusing the same back_net
        grad_input = self.get_grad(second_ms_input)
        if self.default_para:
            back_net(*second_ms_input)
        else:
            if self.sens_param:
                back_net(*second_ms_input, grad_input[0])
            else:
                back_net(*second_ms_input)

    def two_backnet_call_twice(self, first_ms_input, second_ms_input, loss=0.001):
        grad_input = self.get_grad(first_ms_input)
        if self.default_para:
            back_net = nn.ForwardValueAndGrad(self.net_me)
            back_net(*first_ms_input)
        else:
            if self.get_by_list:
                weight = self.net_params
            else:
                weight = None
            back_net = nn.ForwardValueAndGrad(self.net_me,
                                              weights=weight, get_all=self.get_all,
                                              get_by_list=self.get_by_list,
                                              sens_param=self.sens_param)
            if self.sens_param:
                back_net(*first_ms_input, grad_input[0])
            else:
                back_net(*first_ms_input)

        # second call, building a fresh back_net2
        grad_input = self.get_grad(second_ms_input)
        if self.default_para:
            back_net2 = nn.ForwardValueAndGrad(self.net_me)
            back_net2(*second_ms_input)
        else:
            back_net2 = nn.ForwardValueAndGrad(self.net_me,
                                               weights=weight, get_all=self.get_all,
                                               get_by_list=self.get_by_list,
                                               sens_param=self.sens_param)
            if self.sens_param:
                back_net2(*second_ms_input, grad_input[0])
            else:
                back_net2(*second_ms_input)

    def first_forward_second_backnet(self, first_ms_input, second_ms_input, loss=0.001):
        # Only the second input goes through the grad net here.
        grad_input = self.get_grad(second_ms_input)
        if self.default_para:
            back_net2 = nn.ForwardValueAndGrad(self.net_me)
            back_net2(*second_ms_input)
        else:
            if self.get_by_list:
                weight = self.net_params
            else:
                weight = None
            back_net2 = nn.ForwardValueAndGrad(self.net_me,
                                               weights=weight, get_all=self.get_all,
                                               get_by_list=self.get_by_list,
                                               sens_param=self.sens_param)
            if self.sens_param:
                back_net2(*second_ms_input, grad_input[0])
            else:
                back_net2(*second_ms_input)


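# Minimal usage sketch (not one of the original tests) of the
# nn.ForwardValueAndGrad pattern that GradFactory drives above, assuming a
# single-input cell with sens_param=True; Square and its inputs are
# illustrative names only.
def _example_forward_value_and_grad():
    class Square(nn.Cell):
        def construct(self, x):
            return x * x

    x = Tensor(np.array([1.0, 2.0], np.float32))
    sens = Tensor(np.ones([2], np.float32))
    fwd_and_grad = nn.ForwardValueAndGrad(Square(), get_all=True, sens_param=True)
    # Returns the forward value together with the gradients w.r.t. all inputs.
    value, grads = fwd_and_grad(x, sens)
    return value, grads

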
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_forward_value_and_grad_0():
    """
    Feature: Test pynative value and grad
    Description: Test the code for pynative value and grad.
    Expectation: Success
    """

    class Net0(nn.Cell):
        def __init__(self):
            super().__init__()
            self.para = Parameter(Tensor([2, 3, 4], ms.float32), name="para")

        def construct(self):
            x = self.para * self.para
            return x

    net_me = Net0()
    fact = GradFactory(net_me=net_me,
                       get_all=True,
                       get_by_list=True,
                       sens_param=False,
                       net_params=ParameterTuple(net_me.trainable_params()))

    first_input = ()
    second_input = ()
    fact.one_backnet_call_twice(first_input, second_input)
    fact.two_backnet_call_twice(first_input, second_input)
    fact.first_forward_second_backnet(first_input, second_input)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_forward_value_and_grad_1():
    """
    Feature: Test pynative value and grad
    Description: Test the code for pynative value and grad.
    Expectation: Success
    """

    class Net1(nn.Cell):
        def __init__(self):
            super().__init__()
            self.para = Parameter(Tensor([1], ms.float32), name="para")

        def construct(self, x):
            y = x + self.para
            return y

    net_me = Net1()
    fact = GradFactory(net_me=net_me,
                       get_all=False,
                       get_by_list=False,
                       sens_param=False,
                       default_para=True)

    input_1 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    first_input = (input_1,)

    input_1 = Tensor(np.random.randn(1, 2, 3, 4).astype(np.float32))
    second_input = (input_1,)
    fact.one_backnet_call_twice(first_input, second_input)
    fact.two_backnet_call_twice(first_input, second_input)
    fact.first_forward_second_backnet(first_input, second_input)


class CustomNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.p1 = Parameter(Tensor(np.array([1.0], np.float32)), name='p1')
        self.p2 = Parameter(Tensor(np.array([1.0], np.float32)), name='p2')
        self.p3 = Parameter(Tensor(np.array([1.0], np.float32)), name='p3')
        self.p1.requires_grad = False
        self.p2.requires_grad = False
        self.p3.requires_grad = True

    def construct(self, x):
        out = self.p1 * x
        out = out * self.p2
        out = out + self.p3
        return out


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_requires_grad():
    """
    Feature: Test pynative requires grad
    Description: Test the code for requires grad
    Expectation: Success
    """
    x = Tensor([1], ms.float32)
    net = CustomNet()
    output = GradOfAllInputsAndParams(net, sens_param=False)(x)
    # Only p3 has requires_grad=True, and d(out)/dp3 = 1.
    assert (output[1][0].asnumpy() == np.array([1.0], dtype=np.float32)).all()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_requires_grad_use_grad_operation():
    """
    Feature: Test pynative requires grad use grad operation
    Description: Test the code for requires grad
    Expectation: Success
    """

    # Cell object to be differentiated
    x = Tensor([1], ms.float32)
    net = CustomNet()
    output = GradOperation(get_all=True, get_by_list=True)(net, [net.p1, net.p2, net.p3])(x)
    # Parameters with requires_grad=False get zero gradients even when listed.
    assert (output[1][0].asnumpy() == np.array([0.0], dtype=np.float32)).all()
    assert (output[1][1].asnumpy() == np.array([0.0], dtype=np.float32)).all()
    assert (output[1][2].asnumpy() == np.array([1.0], dtype=np.float32)).all()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_requires_grad_without_params():
    """
    Feature: Test pynative requires grad without params
    Description: Test the code for requires grad
    Expectation: Success
    """

    # Cell object to be differentiated
    x = Tensor([1], ms.float32)
    net = CustomNet()
    output = GradOperation(get_all=True, get_by_list=True)(net)(x)
    assert (output[1][0].asnumpy() == np.array([0.0], dtype=np.float32)).all()
    assert (output[1][1].asnumpy() == np.array([0.0], dtype=np.float32)).all()
    assert (output[1][2].asnumpy() == np.array([1.0], dtype=np.float32)).all()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_requires_grad_case2():
    """
    Feature: Test pynative requires grad case2
    Description: Test the code for requires grad
    Expectation: Success
    """

    # Cell object to be differentiated
    x = Tensor([1], ms.float32)
    net = CustomNet()
    output = GradOperation(get_all=True, get_by_list=True)(net, [net.p1])(x)
    # Only p1 was requested, so there is exactly one (zero) parameter gradient.
    assert (output[1][0].asnumpy() == np.array([0.0], dtype=np.float32)).all()
    assert len(output[1]) == 1