# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common.parameter import Parameter, ParameterTuple

grad_all = C.GradOperation(get_all=True)
grad_by_list = C.GradOperation(get_by_list=True)
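# Usage note (an aside): grad_all(fn)(*inputs) returns the gradients with
# respect to every input of fn, while grad_by_list(fn, params)(*inputs)
# returns the gradients with respect to the given ParameterTuple instead.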

class CropAndResizeNet(nn.Cell):
    def __init__(self, crop_size):
        super(CropAndResizeNet, self).__init__()
        self.crop_and_resize = P.CropAndResize()
        self.crop_size = crop_size

    def construct(self, x, boxes, box_indices):
        return self.crop_and_resize(x, boxes, box_indices, self.crop_size)

    def bprop(self, x, boxes, box_indices, out, dout):
        return x, boxes, box_indices

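# A minimal illustration (a sketch added for clarity, not one of the original
# tests): a Cell-level 'bprop' replaces the automatically derived gradient
# entirely, so differentiating CropAndResizeNet yields whatever bprop returns,
# which here is simply the forward inputs.
def _example_custom_bprop_short_circuit():
    net = CropAndResizeNet((10, 10))
    x = Tensor(np.ones((128, 3, 12, 12)).astype(np.float32))
    boxes = Tensor(np.ones((128, 4)).astype(np.float32))
    box_indices = Tensor(np.ones((128,)).astype(np.int32))
    # Each returned "gradient" equals the corresponding forward input.
    return grad_all(net)(x, boxes, box_indices)
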

class TestUserDefinedBpropNet(nn.Cell):
    def __init__(self, in_channel, out_channel):
        super(TestUserDefinedBpropNet, self).__init__()
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=2, stride=1, has_bias=False,
                              weight_init='ones', pad_mode='same')
        self.crop = CropAndResizeNet((10, 10))
        self.boxes = Tensor(np.ones((128, 4)).astype(np.float32))
        self.box_indices = Tensor(np.ones((128,)).astype(np.int32))

    def construct(self, x):
        x = self.relu(x)
        x = self.conv(x)
        x = self.crop(x, self.boxes, self.box_indices)
        return x


class TestUserDefinedBpropGradNet(nn.Cell):
    def __init__(self, net):
        super(TestUserDefinedBpropGradNet, self).__init__()
        self.net = net

    def construct(self, x):
        return grad_all(self.net)(x)


def test_user_defined_bprop():
    context.set_context(mode=context.GRAPH_MODE)
    net = TestUserDefinedBpropNet(3, 10)
    grad_net = TestUserDefinedBpropGradNet(net)
    x = Tensor(np.ones((128, 3, 12, 12)).astype(np.float32))
    grad_net(x)


class TwoInputBPropOperator(nn.Cell):
    def __init__(self):
        super().__init__()
        self.op = P.Mul()
        self.add = P.Add()

    def construct(self, x, y):
        return self.op(x, y)

    def bprop(self, x, y, out, dout):
        return self.add(5, x), self.add(y, 9)

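# Note (an aside): a user-defined bprop does not have to return true
# gradients. Here TwoInputBPropOperator's bprop returns x + 5 and y + 9,
# which grad_all will hand back verbatim as the "gradients" of x and y.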

class BPropOperatorNet(nn.Cell):
    def __init__(self, mul_size):
        super().__init__()
        mul_np = np.full(mul_size, 0.1, dtype=np.float32)
        floordiv_np = np.full(mul_size, 0.1, dtype=np.float32)
        self.mul_weight = Parameter(Tensor(mul_np), name="mul_weight")
        self.floordiv_weight = Parameter(Tensor(floordiv_np), name="floordiv_weight")
        self.mul = TwoInputBPropOperator()
        self.floor_div = P.FloorDiv()
        self.bn = nn.BatchNorm1d(num_features=96)

    def construct(self, inputs):
        x = self.mul(inputs, self.mul_weight)
        x = self.floor_div(x, self.floordiv_weight)
        x = self.bn(x)
        return x


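# The "_with_u" in the name below refers to the U (side-effect) monad:
# nn.BatchNorm1d updates its moving statistics in place, so the compiler
# threads a monad through the graph, and the user-defined bprop must
# interact correctly with it.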
def test_user_defined_bprop_with_u():
    net = BPropOperatorNet(mul_size=(128, 96))
    grad_net = TestUserDefinedBpropGradNet(net)
    x = Tensor(np.random.randn(128, 96).astype(np.float32))
    grad_net(x)


class SinNet(nn.Cell):
    def __init__(self):
        super(SinNet, self).__init__()
        self.sin = ops.Sin()

    def construct(self, x):
        out = self.sin(x)
        return out


class SinGrad(nn.Cell):
    def __init__(self, network):
        super(SinGrad, self).__init__()
        self.grad = ops.GradOperation()
        self.network = network

    def construct(self, x):
        gout = self.grad(self.network)(x)
        return gout


class SinGradSec(nn.Cell):
    def __init__(self, network):
        super(SinGradSec, self).__init__()
        self.grad = ops.GradOperation()
        self.network = network

    def construct(self, x):
        gout = self.grad(self.network)(x)
        return gout


def test_second_grad_with_j_primitive():
    context.set_context(mode=context.GRAPH_MODE)
    net = SinNet()
    first_grad = SinGrad(net)
    second_grad = SinGradSec(first_grad)
    x = Tensor(np.array([1.0], dtype=np.float32))
    second_grad(x)

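
# A value-check sketch (an illustration, assuming a backend is available):
# the second derivative of sin(x) is -sin(x), so the double application of
# GradOperation above should agree with -sin at the test point.
def _example_second_grad_value():
    second_grad = SinGradSec(SinGrad(SinNet()))
    x = Tensor(np.array([1.0], dtype=np.float32))
    out = second_grad(x)
    # d^2/dx^2 sin(x) = -sin(x)
    assert np.allclose(out.asnumpy(), -np.sin(1.0), atol=1e-5)
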

# A CNode used as a free variable (FV) undergoes MapMorphism after the
# MapMorphism of its call-site CNode.
def test_ad_fv_cnode_order():
    context.set_context(mode=context.GRAPH_MODE)
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        # CNode xay has not yet been through MapMorphism when CNode
        # second_level() is being processed, because BackPropagateFv runs
        # MapMorphism starting from the output node, left to right.
        def construct(self, x, y):
            def first_level():
                xay = x + y

                def second_level():
                    return xay

                return second_level() + xay
            return first_level()

    input_x = Tensor(np.array([1.0], dtype=np.float32))
    input_y = Tensor(np.array([2.0], dtype=np.float32))

    net = Net()
    net.add_flags_recursive(defer_inline=True)
    grad_net = grad_all(net)
    grad_net(input_x, input_y)

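# Value sketch (an aside): first_level() computes second_level() + xay
# = 2 * (x + y), so grad_all should return a gradient of 2.0 for both
# x and y.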

# The true and false branches of the switch use different numbers of
# parameters.
def test_if_branch_with_different_params():
    context.set_context(mode=context.GRAPH_MODE)
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.weight1 = Parameter(Tensor(np.array([1.0], dtype=np.float32)), name="weight1")
            self.weight2 = Parameter(Tensor(np.array([2.0], dtype=np.float32)), name="weight2")

        def construct(self, idx, end, x):
            out = x
            if idx < end:
                out = out + self.weight1 * self.weight2
            else:
                out = out + self.weight1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, idx, end, x):
            return grad_by_list(self.net, self.weights)(idx, end, x)

    idx = Tensor(np.array((0), dtype=np.int32))
    end = Tensor(np.array((3), dtype=np.int32))
    x = Tensor(np.array([2.0], dtype=np.float32))

    net = Net()
    grad_net = GradNet(net)
    grad_net(idx, end, x)

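# Value sketch (an aside): with idx = 0 < end = 3 the true branch runs, so
# out = x + weight1 * weight2; grad_by_list should therefore return
# d(out)/d(weight1) = weight2 = 2.0 and d(out)/d(weight2) = weight1 = 1.0.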

# Lift free variables only within the scope of lift_top_func_graph, rather
# than across all func_graphs in the manager. Otherwise "Illegal AnfNode for
# evaluating" may be reported, because weight1 in Net may refer to the old
# parameter instead of the replicated one.
def test_limit_lift_fv_scope():
    context.set_context(mode=context.GRAPH_MODE)
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.weight1 = Parameter(Tensor(np.array([1.0], dtype=np.float32)), name="weight1")

        def construct(self, x, y):
            def inner_add(a, b):
                return a + b

            out = inner_add(x, y) + self.weight1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, x, y):
            def inner_grad_add(a, b):
                return a + b

            d_weight = grad_by_list(self.net, self.weights)(x, y)[0]
            d_out = inner_grad_add(d_weight, y)
            return d_out

    x = Tensor(np.array([2.0], dtype=np.float32))
    y = Tensor(np.array([2.0], dtype=np.float32))

    net = Net()
    net.add_flags_recursive(defer_inline=True)
    grad_net = GradNet(net)
    grad_net.add_flags_recursive(defer_inline=True)
    grad_net(x, y)


def test_same_primal_used_by_multi_j():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            return x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.grad = ops.GradOperation()

        def construct(self, x):
            out = self.net(x)
            gout = self.grad(self.net)(x)
            gout1 = self.grad(self.net)(x)
            return out, gout, gout1

    x = Tensor(np.array([1.0], dtype=np.float32))
    net = Net()
    grad = GradNet(net)
    grad(x)

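# Note (an aside): the identity net above is differentiated twice, so both
# gout and gout1 should be 1.0; the point of the test is that the same
# primal graph can be consumed by multiple J (grad) applications.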

def test_same_primal_used_by_multi_j_with_monad1():
    class AdamNet(nn.Cell):
        def __init__(self, var, m, v):
            super(AdamNet, self).__init__()
            self.apply_adam = P.Adam()
            self.var = Parameter(var, name="var")
            self.m = Parameter(m, name="m")
            self.v = Parameter(v, name="v")

        def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
            self.apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
            return self.var

    class AdamGradNet(nn.Cell):
        def __init__(self, network):
            super(AdamGradNet, self).__init__()
            self.grad_fn = ops.GradOperation(sens_param=True)
            self.sens = [Tensor(np.ones([3, 3, 3]).astype(np.float32)), Tensor(np.ones([3, 3, 3]).astype(np.float32))]
            self.network = network

        def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
            out = self.network(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
            gout1 = self.grad_fn(self.network)(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[0])
            gout2 = self.grad_fn(self.network)(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[1])
            return out, gout1, gout2

    var = Tensor(np.ones([3, 3, 3]).astype(np.float32))
    m = Tensor(np.ones([3, 3, 3]).astype(np.float32))
    v = Tensor(np.ones([3, 3, 3]).astype(np.float32))
    beta1_power = Tensor(np.array([0.9], dtype=np.float32))
    beta2_power = Tensor(np.array([0.999], dtype=np.float32))
    lr = Tensor(np.array([0.001], dtype=np.float32))
    beta1 = Tensor(np.array([0.9], dtype=np.float32))
    beta2 = Tensor(np.array([0.999], dtype=np.float32))
    epsilon = Tensor(np.array([1e-8], dtype=np.float32))
    grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
    net = AdamNet(var, m, v)
    grad_net = AdamGradNet(net)
    grad_net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)

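# Note on sens_param=True (an aside): the returned gradient function takes
# one extra trailing argument, the sensitivity (initial gradient) to
# propagate backward, which must match the network output in shape and
# dtype. self.sens[0]/self.sens[1] play that role above and below.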

def test_same_primal_used_by_multi_j_with_monad2():
    class AdamNet(nn.Cell):
        def __init__(self, var, m, v):
            super(AdamNet, self).__init__()
            self.apply_adam = P.Adam()
            self.var = Parameter(var, name="var")
            self.m = Parameter(m, name="m")
            self.v = Parameter(v, name="v")

        def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
            self.apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
            return self.var

    class AdamGradNet(nn.Cell):
        def __init__(self, network):
            super(AdamGradNet, self).__init__()
            self.grad = ops.GradOperation(sens_param=True)
            self.sens = [Tensor(np.ones([3, 3, 3]).astype(np.float32)), Tensor(np.ones([3, 3, 3]).astype(np.float32))]
            self.network = network

        def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
            out = self.network(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
            grad_fn = self.grad(self.network)
            gout1 = grad_fn(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[0])
            gout2 = grad_fn(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[1])
            return out, gout1, gout2

    var = Tensor(np.ones([3, 3, 3]).astype(np.float32))
    m = Tensor(np.ones([3, 3, 3]).astype(np.float32))
    v = Tensor(np.ones([3, 3, 3]).astype(np.float32))
    beta1_power = Tensor(np.array([0.9], dtype=np.float32))
    beta2_power = Tensor(np.array([0.999], dtype=np.float32))
    lr = Tensor(np.array([0.001], dtype=np.float32))
    beta1 = Tensor(np.array([0.9], dtype=np.float32))
    beta2 = Tensor(np.array([0.999], dtype=np.float32))
    epsilon = Tensor(np.array([1e-8], dtype=np.float32))
    grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
    net = AdamNet(var, m, v)
    grad_net = AdamGradNet(net)
    grad_net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)

def test_grad_args_type_error1():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
        def construct(self, x, y):
            out = self.matmul(x, y)
            return out

    class GradNetWrtX(nn.Cell):
        def __init__(self, net):
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = ops.GradOperation(get_all=2)
        def construct(self, x, y):
            gradient_function = self.grad_op(self.net)
            return gradient_function(x, y)

    x = Tensor(np.array([2.0], dtype=np.float32))
    y = Tensor(np.array([2.0], dtype=np.float32))
    # pytest.raises fails the test if the expected error is not raised,
    # unlike a bare try/except.
    with pytest.raises(TypeError) as e:
        GradNetWrtX(Net())(x, y)
    assert "For 'GradOperation', the 'get_all' should be bool, but got" in str(e.value)


def test_grad_args_type_error2():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
        def construct(self, x, y):
            out = self.matmul(x, y)
            return out

    class GradNetWrtX(nn.Cell):
        def __init__(self, net):
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = ops.GradOperation(get_by_list=2)
        def construct(self, x, y):
            gradient_function = self.grad_op(self.net)
            return gradient_function(x, y)

    x = Tensor(np.array([2.0], dtype=np.float32))
    y = Tensor(np.array([2.0], dtype=np.float32))
    with pytest.raises(TypeError) as e:
        GradNetWrtX(Net())(x, y)
    assert "For 'GradOperation', the 'get_by_list' should be bool, but got" in str(e.value)


def test_grad_args_type_error3():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
        def construct(self, x, y):
            out = self.matmul(x, y)
            return out

    class GradNetWrtX(nn.Cell):
        def __init__(self, net):
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = ops.GradOperation(sens_param=2)
        def construct(self, x, y):
            gradient_function = self.grad_op(self.net)
            return gradient_function(x, y)

    x = Tensor(np.array([2.0], dtype=np.float32))
    y = Tensor(np.array([2.0], dtype=np.float32))
    with pytest.raises(TypeError) as e:
        GradNetWrtX(Net())(x, y)
    assert "For 'GradOperation', the 'sens_param' should be bool, but got" in str(e.value)


def test_grad_net_is_none():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.add = P.Add()
        def construct(self, x, y):
            out = self.add(x, y)
            return out

    class GradNetWrtX(nn.Cell):
        def __init__(self, net):
            super(GradNetWrtX, self).__init__()
            self.net = P.Add()
            self.grad_op = ops.GradOperation()
        def construct(self, x, y):
            gradient_function = self.grad_op(None)
            return gradient_function(x, y)

    x = Tensor(np.array([2.0], dtype=np.float32))
    y = Tensor(np.array([2.0], dtype=np.float32))
    with pytest.raises(Exception) as e:
        GradNetWrtX(Net())(x, y)
    assert "'GradOperation' arg0 must be a 'Function' or 'Cell', but got" in str(e.value)


def test_grad_missing_net():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.add = P.Add()
        def construct(self, x, y):
            out = self.add(x, y)
            return out

    class GradNetWrtX(nn.Cell):
        def __init__(self, net):
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = ops.GradOperation()
        def construct(self, x, y):
            gradient_function = self.grad_op()
            return gradient_function(x, y)

    x = Tensor(np.array([2.0], dtype=np.float32))
    y = Tensor(np.array([2.0], dtype=np.float32))
    with pytest.raises(Exception) as e:
        GradNetWrtX(Net())(x, y)
    assert "'GradOperation' requires a forward network or function as an input, while the input is empty." in str(e.value)


def test_user_defined_bprop_inputs_size_error():
    class BpropUserDefinedNet(nn.Cell):
        def __init__(self):
            super(BpropUserDefinedNet, self).__init__()
            self.zeros_like = P.ZerosLike()

        def construct(self, x, y):
            return x + y

        def bprop(self, out):
            return self.zeros_like(out), self.zeros_like(out)

    class BpropUserDefinedGradNet(nn.Cell):
        def __init__(self, net):
            super(BpropUserDefinedGradNet, self).__init__()
            self.net = net

        def construct(self, x, y):
            return grad_all(self.net)(x, y)

    net = BpropUserDefinedNet()
    grad_net = BpropUserDefinedGradNet(net)
    x = Tensor(np.array([2.0], dtype=np.float32))
    y = Tensor(np.array([2.0], dtype=np.float32))
    with pytest.raises(Exception) as e:
        grad_net(x, y)
    assert "The function 'bprop' of Primitive or Cell requires at least 2 params 'out' and 'dout', but got only" \
           in str(e.value)


def test_user_defined_bprop_net_has_parameter():
    class BpropUserDefinedNet(nn.Cell):
        def __init__(self):
            super(BpropUserDefinedNet, self).__init__()
            self.zeros_like = P.ZerosLike()
            self.x = Parameter(Tensor(np.array([2.0], dtype=np.float32)), name="x")

        def construct(self, y):
            return self.x + y

        def bprop(self, y, out, dout):
            return (self.zeros_like(out),)

    class BpropUserDefinedGradNet(nn.Cell):
        def __init__(self, net):
            super(BpropUserDefinedGradNet, self).__init__()
            self.net = net

        def construct(self, y):
            return grad_all(self.net)(y)

    net = BpropUserDefinedNet()
    grad_net = BpropUserDefinedGradNet(net)
    y = Tensor(np.array([2.0], dtype=np.float32))
    with pytest.raises(Exception) as e:
        grad_net(y)
    assert "The Cell with user defined 'bprop' function in scope" in str(e.value)
    assert "does not support Parameter data type." in str(e.value)


def test_user_defined_bprop_inputs_size_error1():
    class BpropUserDefinedNet(nn.Cell):
        def __init__(self):
            super(BpropUserDefinedNet, self).__init__()
            self.zeros_like = P.ZerosLike()

        def construct(self, x, y):
            return x + y

        def bprop(self, x, y, out):
            return self.zeros_like(out), self.zeros_like(out)

    class BpropUserDefinedGradNet(nn.Cell):
        def __init__(self, net):
            super(BpropUserDefinedGradNet, self).__init__()
            self.net = net

        def construct(self, x, y):
            return grad_all(self.net)(x, y)

    net = BpropUserDefinedNet()
    grad_net = BpropUserDefinedGradNet(net)
    x = Tensor(np.array([2.0], dtype=np.float32))
    y = Tensor(np.array([2.0], dtype=np.float32))
    with pytest.raises(TypeError) as e:
        grad_net(x, y)
    assert "The params of function 'bprop' of Primitive or Cell requires the forward inputs as well as the 'out' " \
           "and 'dout'." in str(e.value)


def test_grad_hook():
    def var_hook_function(grad_out):
        assert grad_out[0].asnumpy().shape == (32, 120)

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.add = P.Add()
            self.hook = P.HookBackward(var_hook_function)
        def construct(self, x, y):
            x = self.hook(x)
            out = self.add(x, y)
            return out

    class GradNetWrtX(nn.Cell):
        def __init__(self, net):
            super(GradNetWrtX, self).__init__()
            self.net = net
            self.grad_op = ops.GradOperation()
        def construct(self, x, y):
            gradient_function = self.grad_op(self.net)
            return gradient_function(x, y)

    x = Tensor(np.array([2.0], dtype=np.float32))
    y = Tensor(np.array([2.0], dtype=np.float32))
    with pytest.raises(Exception) as e:
        GradNetWrtX(Net())(x, y)
    assert "The Primitive 'HookBackward' is not supported in graph mode, which is only supported in pynative " \
           "mode." in str(e.value)