• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15import numpy as np
16
17import mindspore.ops.composite as C
18from mindspore import Tensor, Parameter
19from mindspore import context
20from mindspore.common import dtype as mstype
21from mindspore.common.parameter import ParameterTuple
22from mindspore.nn import Cell
23from mindspore.ops import operations as P
24
# Module-level side effect: every Cell in this file is compiled in GRAPH_MODE.
context.set_context(mode=context.GRAPH_MODE)


# Shared GradOperation instances reused by the tests below.
#   get_by_list=True -> called as op(net, weights); differentiates w.r.t. weights
#   get_all=True     -> gradients w.r.t. the network inputs (per MindSpore docs)
#   sens_param=True  -> the caller appends the sensitivity (output gradient)
#                       as an extra trailing argument, e.g. (x, y, sens)
grad_by_list = C.GradOperation(get_by_list=True)
grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)
grad_by_list_with_sens = C.GradOperation(get_by_list=True, sens_param=True)
grad_all = C.GradOperation(get_all=True)
grad_with_sens = C.GradOperation(sens_param=True)
33
34
def test_net_vargs_expand():
    """Run grad_all_with_sens on a two-input net, passing (x, y, sens) positionally.

    A weight list is also handed to the gradient op. grad_all_with_sens is built
    without get_by_list, so that list is presumably unused; it is kept as-is.
    """

    class AddNet(Cell):
        def __init__(self):
            super(AddNet, self).__init__()
            self.w = Parameter(Tensor(np.ones((3, 4, 5), np.float32)),
                               name="w2", requires_grad=True)

        def construct(self, x, y):
            return x + y

    # Three independent N(0, 1) draws, taken in the order x, y, sens.
    x, y, sens = (Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
                  for _ in range(3))
    add_net = AddNet()
    _ = grad_all_with_sens(add_net, add_net.trainable_params())(x, y, sens)
50
51
class VarNet(Cell):
    """Wrap ``net`` as ``net(*args) * w + b``.

    ``b`` and ``w`` are trainable float32 Parameters of shape (3, 4, 5) filled
    with ones. All positional inputs are forwarded unchanged to ``net``.
    """

    def __init__(self, net):
        super(VarNet, self).__init__()
        # b is registered before w, exactly as in the original definition.
        self.b = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32),
                           name="b", requires_grad=True)
        self.w = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32),
                           name="w", requires_grad=True)
        self.net = net

    def construct(self, *args):
        inner = self.net(*args)
        return inner * self.w + self.b
63
64
class SecondNet(Cell):
    """Return ``args[0] + args[1] + b2``.

    ``b2`` is a trainable (3, 4, 5) float32 bias of ones. Any positional
    inputs beyond the first two are ignored.
    """

    def __init__(self):
        super(SecondNet, self).__init__()
        self.b2 = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32),
                            name="b2", requires_grad=True)

    def construct(self, *args):
        return args[0] + args[1] + self.b2
74
75
class Bprop(Cell):
    """Apply a preconfigured GradOperation to ``func``.

    Args:
        func: the Cell to differentiate.
        wrt_params: when truthy, gradients are requested w.r.t. ``params``
            (passed to ``grad_op`` as its second argument).
        params: iterable of Parameters. It is wrapped in a ParameterTuple only
            when ``wrt_params`` is truthy and ``params`` is non-empty; otherwise
            ``self.params`` stays None.
        grad_op: a GradOperation instance. Its flags are expected to match
            ``wrt_params`` and whether ``sens`` is given (e.g. sens_param=True
            when a sens is supplied).
        sens: optional sensitivity. A Tensor is stored as-is; any other value
            is converted to a float32 Tensor. When given, it is appended as the
            last argument of every gradient call.
    """

    def __init__(self, func, wrt_params, params, grad_op, sens=None):
        # auto_prefix=False: presumably keeps wrapped parameter names unprefixed.
        super(Bprop, self).__init__(auto_prefix=False)
        self.func = func
        self.wrt_params = wrt_params
        self.params = None
        if self.wrt_params and params:
            self.params = ParameterTuple(params)
        self.grad = grad_op
        self.with_sens = False
        self.sens = sens
        if sens is not None:
            self.sens = sens if isinstance(sens, Tensor) else Tensor(sens, dtype=mstype.float32)
            self.with_sens = True

    def construct(self, *inputs):
        # Branches depend only on flags fixed in __init__. The explicit
        # if/elif/else chain is kept as-is (hence the pylint waiver).
        # pylint: disable=no-else-return
        if self.wrt_params:
            if self.with_sens:
                return self.grad(self.func, self.params)(*inputs, self.sens)
            else:
                return self.grad(self.func, self.params)(*inputs)
        elif self.with_sens:
            return self.grad(self.func)(*inputs, self.sens)
        else:
            return self.grad(self.func)(*inputs)
102
103
def test_all_var_args_grad_with_sens():
    """Test grad_by_list_with_sens when every input, including sens, arrives via *inputs."""

    class GradNet(Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.weights = ParameterTuple(net.trainable_params())
            self.net = net

        def construct(self, *inputs):
            # inputs == (x, y, sens): sens is the trailing var arg.
            return grad_by_list_with_sens(self.net, self.weights)(*inputs)

    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    net = VarNet(SecondNet())
    grad_net = GradNet(net)
    _ = grad_net(x, y, sens)
122
123
def test_grad_list_var_args():
    """Take weight gradients (grad_by_list) of VarNet with inputs forwarded via *inputs."""

    class GradNet(Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.weights = ParameterTuple(net.trainable_params())
            self.net = net

        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    x, y = (Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) for _ in range(2))
    grad_cell = GradNet(VarNet(SecondNet()))
    _ = grad_cell(x, y)
139
140
def test_grad_all_var_args():
    """Take input gradients (grad_all) of VarNet with inputs forwarded via *inputs."""

    class GradNet(Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.weights = ParameterTuple(net.trainable_params())
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    x, y = (Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) for _ in range(2))
    grad_cell = GradNet(VarNet(SecondNet()))
    _ = grad_cell(x, y)
156
157
def test_grad_all_var_args_with_sens():
    """Take input gradients with a caller-supplied sens, all passed through *inputs."""

    class GradNet(Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.weights = ParameterTuple(net.trainable_params())
            self.net = net

        def construct(self, *inputs):
            return grad_all_with_sens(self.net)(*inputs)

    # x, y and the trailing sens are all (3, 4, 5) float32 tensors of ones.
    x, y, sens = (Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) for _ in range(3))
    grad_cell = GradNet(VarNet(SecondNet()))
    _ = grad_cell(x, y, sens)
174
175
def test_grad_var_args_with_sens():
    """Apply grad_with_sens (no get_all / get_by_list) to VarNet via *inputs."""

    class GradNet(Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.weights = ParameterTuple(net.trainable_params())
            self.net = net

        def construct(self, *inputs):
            return grad_with_sens(self.net)(*inputs)

    # x, y and the trailing sens are all (3, 4, 5) float32 tensors of ones.
    x, y, sens = (Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) for _ in range(3))
    grad_cell = GradNet(VarNet(SecondNet()))
    _ = grad_cell(x, y, sens)
192
193
def test_grad_with_param_sens():
    """Test grad_by_list with sens_param=True where sens is a non-trainable Parameter."""

    class GradNet(Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.weights = ParameterTuple(net.trainable_params())
            self.net = net
            # sens is held on the cell as a Parameter (requires_grad=False)
            # instead of being passed in by the caller.
            self.sens = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32),
                                  name='sens', requires_grad=False)
            self.grad = C.GradOperation(get_by_list=True, sens_param=True)

        def construct(self, x, y):
            return self.grad(self.net, self.weights)(x, y, self.sens)

    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    net = SecondNet()
    grad_net = GradNet(net)
    _ = grad_net(x, y)
213
214
def test_var_args_grad():
    """Take grad_by_list_with_sens through nested *args forwarding with explicit x, y, sens."""

    # Local variants shadow the module-level VarNet / SecondNet. This VarNet
    # only adds a bias ``b`` (no multiplicative weight).
    class VarNet(Cell):
        def __init__(self, net):
            super(VarNet, self).__init__()
            self.b = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32),
                               name="b", requires_grad=True)
            self.net = net

        def construct(self, *args):
            inner = self.net(*args)
            return inner + self.b

    class SecondNet(Cell):
        def __init__(self):
            super(SecondNet, self).__init__()
            self.b2 = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32),
                                name="b2", requires_grad=True)

        def construct(self, *args):
            return args[0] + args[1] + self.b2

    class GradNet(Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, x, y, sens):
            return grad_by_list_with_sens(self.net, self.weights)(x, y, sens)

    x, y, sens = (Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) for _ in range(3))
    grad_cell = GradNet(VarNet(SecondNet()))
    _ = grad_cell(x, y, sens)
251
252
def test_var_args_positional():
    """Test grad_all when only the inner graph (SecondNet) takes *args."""

    class VarNet(Cell):
        def __init__(self, net):
            super(VarNet, self).__init__()
            self.net = net

        def construct(self, x, y):
            # Fixed positional signature outside; var-args call inside.
            return self.net(x, y) * x

    class SecondNet(Cell):
        def __init__(self):
            super(SecondNet, self).__init__()

        def construct(self, *args):
            return args[0] + args[1]

    class GradNet(Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            # NOTE(review): weights are not used by grad_all below; kept as in
            # the sibling tests.
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, x, y):
            return grad_all(self.net)(x, y)

    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    net = VarNet(SecondNet())
    grad_net = GradNet(net)
    _ = grad_net(x, y)
285
286
def test_grad_within_if_else():
    """Take weight gradients through Bprop, whose construct selects the call via if/else."""

    class GradNet(Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.weights = ParameterTuple(net.trainable_params())
            self.net = net
            grad_op = C.GradOperation(get_all=False, get_by_list=True, sens_param=True)
            sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
            # wrt_params=True with a bound sens: Bprop appends sens itself.
            self.grad = Bprop(self.net, True, self.weights, grad_op, sens)

        def construct(self, *inputs):
            return self.grad(*inputs)

    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
    net = VarNet(SecondNet())
    grad_net = GradNet(net)
    out = grad_net(x, y)
    print("test_grad_within_if_else out=", out)
306
307
def test_grad_for_concat():
    """Back-propagate through a variadic Concat cell fed five inputs.

    The inputs have shapes (2, k, 1) with differing k and are concatenated on
    axis 1. The numpy forward result is reused as the sens.
    """

    class GradNet(Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.weights = ParameterTuple(net.trainable_params())
            self.net = net
            grad_op = C.GradOperation(get_all=True, get_by_list=False, sens_param=True)
            # No sens is bound here, so the caller passes it as the last input.
            self.grad = Bprop(self.net, False, self.weights, grad_op)

        def construct(self, *inputs):
            return self.grad(*inputs)

    class Concat(Cell):
        def __init__(self, axis):
            super().__init__()
            self.concat = P.Concat(axis=axis)

        def construct(self, *input1):
            # input1 is the tuple of all positional inputs, handed to Concat whole.
            return self.concat(input1)

    class ConcatFactory:
        """Generate random numpy inputs and run the MindSpore gradient graph on them."""

        def __init__(self, input_shape, axis, dtype=np.float32):
            self.inputs_np = [np.random.randn(*s).astype(dtype) for s in input_shape]
            self.axis = axis
            self.out_numpy = np.concatenate(self.inputs_np, axis=self.axis)
            self.out_grad_np = self.out_numpy

        def grad_mindspore_impl(self):
            inputs = [Tensor(i) for i in self.inputs_np]
            net = Concat(axis=self.axis)
            grad_net = GradNet(net)
            grad_net.set_train()
            # The sens tensor goes last, after the tensors being concatenated.
            _ = grad_net(*inputs, Tensor(self.out_grad_np))

        def grad_cmp(self):
            # NOTE(review): despite the name, this only runs the backward graph;
            # nothing is compared against a reference result.
            self.grad_mindspore_impl()

    fact = ConcatFactory(input_shape=(
        (2, 184320, 1), (2, 46080, 1), (2, 11520, 1), (2, 2880, 1), (2, 720, 1)), axis=1)
    fact.grad_cmp()
353