• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15import pytest
16import numpy as np
17from mindspore import context, nn, Tensor, Parameter, ParameterTuple
18from mindspore.common import dtype as mstype
19from mindspore.ops import composite as C
20
21
@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
    """Run every test in this module in PyNative mode on Ascend.

    The execution mode that was active before this module started is captured
    and restored on teardown. Without this, the module would always switch to
    GRAPH_MODE and so change the mode for tests that run afterwards.
    """
    previous_mode = context.get_context("mode")
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    yield
    context.set_context(mode=previous_mode)
27
28
class _Grad(nn.Cell):
    """
    Cell that applies the gradient operation ``grad`` to ``network``.

    Args:
        grad: gradient operation (e.g. ``C.GradOperation``); its ``sens_param``
            flag is mirrored onto this cell.
        network: the cell being differentiated.
        wrt_params: when True, the network's trainable parameters are passed
            to ``grad`` as well.
        real_inputs_count: when set and ``sens_param`` is on, only the first
            ``real_inputs_count`` inputs are fed to the network; every input
            after that is grouped into one tuple and passed as the sens value.
    """

    def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
        super().__init__()
        self.network = network
        self.grad = grad
        self.sens_param = self.grad.sens_param
        self.wrt_params = wrt_params
        self.real_inputs_count = real_inputs_count
        if self.wrt_params:
            self.params = ParameterTuple(self.network.trainable_params())

    def construct(self, *inputs):
        # Pick the gradient function once; the two branches differ only in
        # whether the trainable-parameter tuple is passed to the grad op.
        if self.wrt_params:
            grad_fn = self.grad(self.network, self.params)
        else:
            grad_fn = self.grad(self.network)

        split_sens = self.real_inputs_count is not None and self.sens_param is not False
        if not split_sens:
            return grad_fn(*inputs)

        count = self.real_inputs_count
        # Leading inputs go to the network; the remainder becomes a single
        # tuple-valued sens argument.
        return grad_fn(*inputs[:count], inputs[count:])
53
54
class GradOfFirstInput(_Grad):
    """
    Gradient wrapper that differentiates ``network`` w.r.t. its first input.

    Args:
        network: the cell to differentiate.
        sens_param: whether the caller supplies the sens (output gradient)
            as an extra trailing input.
        real_inputs_count: forwarded to ``_Grad``; see its documentation.
    """

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        # Default GradOperation (get_all=False) yields the first input's gradient.
        first_input_grad = C.GradOperation(sens_param=sens_param)
        super().__init__(grad=first_input_grad,
                         network=network,
                         real_inputs_count=real_inputs_count)
63
64
class GradOfAllInputs(_Grad):
    """
    Gradient wrapper that differentiates ``network`` w.r.t. all of its inputs.

    The gradient operation is built with ``get_all=True``, so calling this
    cell yields one gradient per network input.

    Args:
        network: the cell to differentiate.
        sens_param: whether the caller supplies the sens (output gradient)
            as an extra trailing input.
        real_inputs_count: forwarded to ``_Grad``; see its documentation.
    """

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        super().__init__(grad=C.GradOperation(get_all=True, sens_param=sens_param),
                         network=network, real_inputs_count=real_inputs_count)
73
74
def test_multi_grad():
    """Differentiate two independent networks one after the other.

    Both networks first run a forward pass with ``set_grad()`` enabled. Each
    is then differentiated w.r.t. all inputs, and the outputs and gradients
    are checked against their analytic values at x = 1, y = 2.
    """
    class ForwardNetMul(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, x, y):
            a = x * x
            b = y * y
            return a * b  # x^2 * y^2

    class ForwardNetAdd(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, x, y):
            a = x + x + x
            b = y + y
            return a * b  # 3x * 2y = 6xy

    mulnet = ForwardNetMul()
    addnet = ForwardNetAdd()
    x = Tensor(np.ones([32]), dtype=mstype.float32)
    y = Tensor(np.ones([32]) * 2, dtype=mstype.float32)
    sens = Tensor(np.ones([32]), dtype=mstype.float32)
    mulnet.set_grad()
    addnet.set_grad()
    out1 = mulnet(x, y)
    out2 = addnet(x, y)
    # x = 1, y = 2: x^2 * y^2 = 4 and 6xy = 12.
    assert np.allclose(out1.asnumpy(), 4.0)
    assert np.allclose(out2.asnumpy(), 12.0)

    grad_mul = GradOfAllInputs(mulnet)
    grad_add = GradOfAllInputs(addnet)
    # d(x^2 y^2)/dx = 2 x y^2 = 8, d(x^2 y^2)/dy = 2 x^2 y = 4 (sens is all ones).
    mul_dx, mul_dy = grad_mul(x, y, sens)
    assert np.allclose(mul_dx.asnumpy(), 8.0)
    assert np.allclose(mul_dy.asnumpy(), 4.0)
    # d(6xy)/dx = 6y = 12, d(6xy)/dy = 6x = 6.
    add_dx, add_dy = grad_add(x, y, sens)
    assert np.allclose(add_dx.asnumpy(), 12.0)
    assert np.allclose(add_dy.asnumpy(), 6.0)
106
107
def test_multi_same_grad():
    """Apply two different gradient wrappers to the same network.

    ``mulnet`` is differentiated both w.r.t. all inputs and w.r.t. its first
    input only. ``addnet`` only runs a forward pass. Values are checked
    analytically at x = y = 1.
    """
    class ForwardNetMul(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, x, y):
            a = x * x
            b = y * y
            return a * b  # x^2 * y^2

    class ForwardNetAdd(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, x, y):
            a = x * 3
            b = y * 2
            return a + b  # 3x + 2y

    mulnet = ForwardNetMul()
    addnet = ForwardNetAdd()
    x = Tensor(np.ones([32]), dtype=mstype.float32)
    y = Tensor(np.ones([32]), dtype=mstype.float32)
    sens = Tensor(np.ones([32]), dtype=mstype.float32)
    mulnet.set_grad()
    addnet.set_grad()
    out1 = mulnet(x, y)
    out2 = addnet(x, y)
    assert np.allclose(out1.asnumpy(), 1.0)
    assert np.allclose(out2.asnumpy(), 5.0)

    # Both wrappers target mulnet on purpose; the names reflect that.
    grad_mul_all = GradOfAllInputs(mulnet)
    grad_mul_first = GradOfFirstInput(mulnet)
    # d(x^2 y^2)/dx = 2 x y^2 = 2, d(x^2 y^2)/dy = 2 x^2 y = 2 at x = y = 1.
    dx, dy = grad_mul_all(x, y, sens)
    assert np.allclose(dx.asnumpy(), 2.0)
    assert np.allclose(dy.asnumpy(), 2.0)
    dx_first = grad_mul_first(x, y, sens)
    assert np.allclose(dx_first.asnumpy(), 2.0)
139
140
def test_net_inner_grad():
    """Differentiate an outer network and the inner network it calls.

    ``addnet`` wraps ``mulnet``. Gradients of both (w.r.t. all inputs) are
    taken after their forward passes and checked analytically at x = y = 1.
    """
    class ForwardNetMul(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, x, y):
            a = x * x
            b = y * y
            return a * b  # x^2 * y^2

    class ForwardNetAdd(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net

        def construct(self, x, y):
            a = x + x
            b = y + y
            res = self.net(a, b)  # (2x)^2 * (2y)^2 = 16 x^2 y^2
            return res

    mulnet = ForwardNetMul()
    addnet = ForwardNetAdd(mulnet)
    x = Tensor(np.ones([32]), dtype=mstype.float32)
    y = Tensor(np.ones([32]), dtype=mstype.float32)
    sens = Tensor(np.ones([32]), dtype=mstype.float32)
    mulnet.set_grad()
    addnet.set_grad()
    out1 = mulnet(x, y)
    out2 = addnet(x, y)
    assert np.allclose(out1.asnumpy(), 1.0)
    assert np.allclose(out2.asnumpy(), 16.0)

    grad_outer = GradOfAllInputs(addnet)
    grad_inner = GradOfAllInputs(mulnet)
    # d(16 x^2 y^2)/dx = 32 x y^2 = 32, and symmetrically for y.
    outer_dx, outer_dy = grad_outer(x, y, sens)
    assert np.allclose(outer_dx.asnumpy(), 32.0)
    assert np.allclose(outer_dy.asnumpy(), 32.0)
    # d(x^2 y^2)/dx = 2 x y^2 = 2, and symmetrically for y.
    inner_dx, inner_dy = grad_inner(x, y, sens)
    assert np.allclose(inner_dx.asnumpy(), 2.0)
    assert np.allclose(inner_dy.asnumpy(), 2.0)
174
175
def test_net_inner_first_run_grad():
    """Differentiate nested networks that hold Parameters.

    ``addnet`` (parameters z2, z3) wraps ``mulnet`` (parameter z1). The outer
    net is differentiated w.r.t. all inputs and the inner net w.r.t. its first
    input. Results are checked analytically at x = y = 1, z1 = 2, z2 = z3 = 1.
    """
    class ForwardNetMul(nn.Cell):
        def __init__(self):
            super().__init__()
            self.z1 = Parameter(Tensor(np.ones([32]) * 2, dtype=mstype.float32), name='z1')

        def construct(self, x, y):
            a = x * self.z1
            b = y * y
            return a * b  # z1 * x * y^2

    class ForwardNetAdd(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.z2 = Parameter(Tensor(np.ones([32]), dtype=mstype.float32), name='z2')
            # Each parameter needs its own name; this was a copy-pasted 'z2'.
            self.z3 = Parameter(Tensor(np.ones([32]), dtype=mstype.float32), name='z3')

        def construct(self, x, y):
            a = x + x * self.z3
            b = y + y * self.z2
            res = self.net(a, b)
            return res

    mulnet = ForwardNetMul()
    addnet = ForwardNetAdd(mulnet)
    x = Tensor(np.ones([32]), dtype=mstype.float32)
    y = Tensor(np.ones([32]), dtype=mstype.float32)
    sens = Tensor(np.ones([32]), dtype=mstype.float32)
    mulnet.set_grad()
    addnet.set_grad()
    out1 = mulnet(x, y)
    out2 = addnet(x, y)
    # mulnet: 2 * 1 * 1 = 2; addnet: a = 2x, b = 2y -> 2 * 2x * 4y^2 = 16 x y^2 = 16.
    assert np.allclose(out1.asnumpy(), 2.0)
    assert np.allclose(out2.asnumpy(), 16.0)

    grad_outer = GradOfAllInputs(addnet)
    grad_inner_first = GradOfFirstInput(mulnet)
    # d(16 x y^2)/dx = 16 y^2 = 16, d(16 x y^2)/dy = 32 x y = 32.
    outer_dx, outer_dy = grad_outer(x, y, sens)
    assert np.allclose(outer_dx.asnumpy(), 16.0)
    assert np.allclose(outer_dy.asnumpy(), 32.0)
    # d(z1 x y^2)/dx = z1 y^2 = 2.
    inner_dx = grad_inner_first(x, y, sens)
    assert np.allclose(inner_dx.asnumpy(), 2.0)
212