• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test bprop disorder """
16import functools
17import numpy as np
18
19import mindspore.nn as nn
20import mindspore.context as context
21from mindspore import Tensor, Parameter
22from mindspore.common.parameter import ParameterTuple
23from mindspore.ops import composite as C
24from mindspore.ops import operations as P
25from ..ut_filter import non_graph_engine
26from ....mindspore_test_framework.mindspore_test import mindspore_test
27from ....mindspore_test_framework.pipeline.forward.compile_forward \
28    import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
29
# Importing this module switches the global MindSpore context to PyNative mode;
# test_exec below later sets GRAPH_MODE when it runs.
context.set_context(mode=context.PYNATIVE_MODE)
# Shared gradient operation used by GradNetWrap: differentiate with respect to
# an explicit list of weights (get_by_list) and accept the output sensitivity
# as an extra call argument (sens_param).
grad_by_list_with_sens = C.GradOperation(sens_param=True, get_by_list=True)
32
33
class DisOrderTest1(nn.Cell):
    """Four-weight network whose weights are used in a cyclic order.

    The forward pass is ``x * (s1*s2 + s2*s3 + s3*s4 + s4*s1)``, written with
    Python arithmetic operators. Each weight appears in exactly two products,
    and the products walk the weights as a ring (s1, s2, s3, s4, back to s1),
    so s1's two uses are the first and the last product. In this file the
    cell is wrapped by ``GradNetWrap`` to take gradients w.r.t. s1..s4.
    """

    def __init__(self):
        super(DisOrderTest1, self).__init__()
        # float32 tensor of shape [1] filled with ones; all four Parameters
        # below are built from this same tensor, so every weight starts at 1.0.
        weight = Tensor(np.ones([1], np.float32))
        # NOTE(review): registration order s1..s4 presumably determines the
        # order of net.get_parameters() (and thus of the returned gradients);
        # keep it stable — confirm against nn.Cell docs.
        self.s1 = Parameter(weight, name="s1")
        self.s2 = Parameter(weight, name="s2")
        self.s3 = Parameter(weight, name="s3")
        self.s4 = Parameter(weight, name="s4")
        # Not used by construct below (it uses `*` and `+`); DisOrderTest2
        # declares the same two primitives and does call them.
        self.mul = P.Mul()
        self.add = P.Add()

    def construct(self, x):
        # Summed left to right: ((s1*s2 + s2*s3) + s3*s4) + s4*s1, then scaled by x.
        return x * (self.s1 * self.s2 + self.s2 * self.s3 + self.s3 * self.s4 + self.s4 * self.s1)
49
50
class DisOrderTest2(nn.Cell):
    """Same computation as ``DisOrderTest1``, expressed with explicit primitives.

    The forward pass is
    ``Mul(x, Add(Add(Add(Mul(s1, s2), Mul(s2, s3)), Mul(s3, s4)), Mul(s4, s1)))``,
    i.e. ``x * (((s1*s2 + s2*s3) + s3*s4) + s4*s1)``. This is the same ring of
    products, summed in the same left-to-right order, as the operator form in
    ``DisOrderTest1``.
    """

    def __init__(self):
        super(DisOrderTest2, self).__init__()
        # float32 tensor of shape [1] filled with ones; all four Parameters
        # below are built from this same tensor, so every weight starts at 1.0.
        weight = Tensor(np.ones([1], np.float32))
        # NOTE(review): registration order s1..s4 presumably determines the
        # order of net.get_parameters() (and thus of the returned gradients);
        # keep it identical to DisOrderTest1 — confirm against nn.Cell docs.
        self.s1 = Parameter(weight, name="s1")
        self.s2 = Parameter(weight, name="s2")
        self.s3 = Parameter(weight, name="s3")
        self.s4 = Parameter(weight, name="s4")
        # Elementwise primitives used by construct instead of `*` / `+`.
        self.mul = P.Mul()
        self.add = P.Add()

    def construct(self, x):
        # Innermost first: s1*s2 + s2*s3, then + s3*s4, then + s4*s1, then * x.
        return self.mul(x, (self.add(self.add(self.add(self.mul(self.s1, self.s2), self.mul(self.s2, self.s3)),
                                              self.mul(self.s3, self.s4)), self.mul(self.s4, self.s1))))
67
68
class GradNetWrap(nn.Cell):
    """Cell that returns the gradients of ``net`` with respect to its weights.

    Args:
        net: the ``nn.Cell`` to differentiate. Its parameters are captured
            once, at construction time, into ``self.weights``.

    Inputs (to ``construct``):
        x: forward input forwarded to ``net``.
        sens: sensitivity fed into the backward pass
            (``grad_by_list_with_sens`` is built with ``sens_param=True``).

    Outputs:
        Whatever ``grad_by_list_with_sens(net, weights)`` returns: presumably
        one gradient per entry of ``self.weights``; confirm against the
        ``GradOperation`` docs.
    """

    def __init__(self, net):
        super().__init__()
        self.net = net
        # Differentiation targets: every parameter currently owned by `net`.
        self.weights = ParameterTuple(net.get_parameters())

    def construct(self, x, sens):
        # Build the gradient function for (net, weights), then apply it.
        grad_fn = grad_by_list_with_sens(self.net, self.weights)
        return grad_fn(x, sens)
79
80
# Each entry is (case_name, config): 'block' is the cell under test and
# 'desc_inputs' presumably feeds GradNetWrap.construct positionally as
# (x, sens) — both 1-element float32 tensors of ones.
test_case_ops = [
    ('DisOrderTest1', {
        'block': GradNetWrap(DisOrderTest1()),
        'desc_inputs': [Tensor(np.ones([1], np.float32)), Tensor(np.ones([1], np.float32))]}),
    ('DisOrderTest2', {
        'block': GradNetWrap(DisOrderTest2()),
        'desc_inputs': [Tensor(np.ones([1], np.float32)), Tensor(np.ones([1], np.float32))]}),
]

test_case_lists = [test_case_ops]
# Concatenate all case lists into one flat list. The [] initializer keeps this
# valid if test_case_lists is ever empty (reduce without one raises TypeError).
test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists, [])
# Use -k to select certain test cases, e.g.
# pytest <path/to/this_file>.py::test_exec -k DisOrderTest1
# NOTE(review): per-case -k selection relies on the mindspore_test framework.
94
95
96
@non_graph_engine
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_exec():
    """Hand the bprop-disorder cases to the compile-forward test pipeline.

    Switches the global MindSpore context to GRAPH_MODE, which is not restored
    afterwards, and returns ``test_exec_case``. The ``mindspore_test``
    decorator presumably runs
    ``pipeline_for_compile_forward_ge_graph_for_case_by_case_config`` over the
    returned cases; that framework behavior is not visible here, so confirm it.
    """
    context.set_context(mode=context.GRAPH_MODE)
    return test_exec_case
102