• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test_framstruct """
16import numpy as np
17import mindspore as ms
18import mindspore.nn as nn
19from mindspore import context
20from mindspore.common import dtype as mstype
21from mindspore.common.parameter import Parameter, ParameterTuple
22from mindspore.common.api import ms_function
23from mindspore.ops import composite as C
24from mindspore.ops import operations as P
25from ..ut_filter import non_graph_engine
26from ....mindspore_test_framework.utils.check_gradient import (
27    check_jacobian, Tensor, NNGradChecker,
28    OperationGradChecker, check_gradient)
29
30context.set_context(mode=context.PYNATIVE_MODE)
31
32
33def setup_module(module):
34    context.set_context(mode=context.PYNATIVE_MODE)
35
36
37grad_all = C.GradOperation(get_all=True)
38grad_by_list = C.GradOperation(get_by_list=True)
39
40
41@ms_function
42def while_upper_bound(upper):
43    rval = 2
44    while rval < upper:
45        rval = rval * rval
46    return rval
47
48
49def test_while_upper_bound():
50    res = while_upper_bound(10)
51    assert res == 16
52
53
54@ms_function
55def while_lower_bound(lower):
56    """ t_while """
57    rval = lower
58    while rval < 100:
59        rval = rval * rval
60    return rval
61
62
63def test_while_lower_bound():
64    res = while_lower_bound(2)
65    assert res == 256
66
67
68@ms_function
69def dynamic_make_tuple(x, lower, upper):
70    out = ()
71    i = lower
72    while i < upper:
73        out = out + (x,)
74        i = i + 1
75    return out
76
77
78def test_dynamic_make_tuple():
79    assert dynamic_make_tuple(2, 1, 5) == (2, 2, 2, 2)
80
81
82def test_make_tuple():
83    # Statically recursively creating static type is valid in mindspore.
84    @ms_function
85    def make_tuple(x):
86        out = ()
87        for i in range(3):
88            out = out + (x,)
89        return out
90
91    res = make_tuple(5)
92    assert res == (5, 5, 5)
93
94
95@ms_function
96def add(x, y):
97    """ add """
98    return x + y
99
100
101def mul(x, y):
102    """ mul """
103    return x * y
104
105
106def add_mul(x, y):
107    """ add_mul """
108    return (x + y) * y
109
110
111def mainf(x, y):
112    """ mainf """
113    return grad_all(mul)(x, y)
114
115
116def grad_add_mul(x, y):
117    """ grad_add_mul """
118    return grad_all(add_mul)(x, y)
119
120
121@ms_function
122def sub(x, y):
123    """ sub """
124    return x - y
125
126
127# pylint: disable=using-constant-test
128@ms_function
129def if_always_true(x):
130    """ if_always_true """
131    if True:
132        return x
133    else:
134        return 0
135
136
137def test_add():
138    """ test_add """
139    res = add(2.5, 3)
140    assert res == 5.5
141
142
143def test_sub():
144    """ test_sub """
145    res = sub(3.5, 3)
146    assert res == 0.5
147
148
149@non_graph_engine
150def test_if_always_true():
151    """ test_if_always_true """
152    res = if_always_true(1)
153    assert res == 1
154
155
156@non_graph_engine
157def test_f():
158    """ test_f """
159    res = mainf(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32))
160    assert res == (2, 3)
161
162
163@non_graph_engine
164def test_grad_add_mul():
165    """ test_grad_add_mul """
166    res = grad_add_mul(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32))
167    assert res == (2, 7)
168
169
170def f(x):
171    if x > 0:
172        return f(x - 1)
173    return x
174
175
176@ms_function
177def list_subscript():
178    """ list_subscript """
179    x = [1, 2, 3]
180    return x[0] * x[1]
181
182
183def test_list_subscript():
184    """ test_list_subscript """
185    res = list_subscript()
186    assert res == 2
187
188
189@ms_function
190def ms_infer_for(xs, y):
191    """ ms_infer_for """
192    rval = y
193    for x in xs:
194        rval = rval + x
195    return rval
196
197
198def test_infer_for():
199    """ test_infer_for """
200    t = (1, 2, 3)
201    y = 4
202    res = ms_infer_for(t, y)
203    assert res == 10
204
205
206@ms_function
207def if_construct(a, b):
208    z = a
209    if a > b:
210        z = a + b
211    else:
212        z = a * b
213    if z > b:
214        return z - a
215    else:
216        return a - b
217
218
219def test_if_construct():
220    """ test_if_construct """
221    res = if_construct(3, 6)
222    assert res == 15
223
224
225@ms_function
226def if_scalar(a, b):
227    """ if_abstract """
228    if a:
229        return a
230    return b
231
232
233def test_if_scalar1():
234    """ test_if_abstract """
235    res = if_scalar(3, 6)
236    assert res == 3
237
238
239def test_if_scalar2():
240    """ test_if_abstract """
241    res = if_scalar(0, 6)
242    assert res == 6
243
244
245@ms_function
246def if_tensor(a, b):
247    c = a
248    if a < b:
249        c = a + a
250        if c < b:
251            c = a + c
252        else:
253            c = a + b
254    else:
255        c = b + b
256    out = c + c
257    return out
258
259
260def test_if_tensor():
261    res = if_tensor(Tensor(np.ones([1]).astype(np.int32)), Tensor(np.ones([1]).astype(np.int32)))
262    assert res == Tensor(np.ones([1]).astype(np.int32) * 4)
263
264
265def rec(x):
266    """ rec """
267    if x > 0:
268        return rec(x - 1)
269    return x
270
271
272def test_me_rec():
273    """ test_me_rec """
274    res = rec(10)
275    assert res == 0
276
277
278def t2_while(x, y):
279    out = y - x
280    i = 0
281    while i < 10:
282        out = mul(x, y)
283        i = i + 1
284    return out
285
286
287def test_while2():
288    res = t2_while(2, 3)
289    assert res == 6
290
291
292def if_test(a, b):
293    """ if_test """
294    if a > b:
295        return 3 * a
296    return 2 * b
297
298
299def grad_if(x, y):
300    """ grad_if """
301    return grad_all(if_test)(x, y)
302
303
304def test_grad_if():
305    """ test_grad_if """
306    assert grad_if(Tensor(5, dtype=ms.int32), Tensor(4, dtype=ms.int32)) == (3, 0)
307
308
309class ConvNet(nn.Cell):
310    def __init__(self):
311        super(ConvNet, self).__init__()
312        out_channel = 16
313        kernel_size = 3
314        self.conv = P.Conv2D(out_channel,
315                             kernel_size,
316                             mode=1,
317                             pad_mode="pad",
318                             pad=0,
319                             stride=1,
320                             dilation=2,
321                             group=1)
322        self.w = Parameter(Tensor(np.ones([16, 16, 3, 3]).astype(np.float32)), name='w')
323
324    def construct(self, x):
325        return self.conv(x, self.w)
326
327
328conv = ConvNet()
329c1 = Tensor([2], mstype.float32)
330c2 = Tensor([10], mstype.float32)
331c3 = Tensor([1], mstype.float32)
332
333
334@ms_function
335def t1_while(x, y, z):
336    out = x
337    i = c1
338    while i < c2:
339        out = out + conv(z)
340        i = i + c3
341    out = out + out
342    return out
343
344
345def test_while_net():
346    y = Tensor(np.ones([1, 3, 3, 4]).astype(np.float32))
347    x = Tensor(np.ones([1, 16, 12, 12]).astype(np.float32))
348    z = Tensor(np.ones([1, 16, 16, 16]).astype(np.float32))
349    res = t1_while(x, y, z)
350    assert np.all(res.asnumpy() == np.ones([1, 16, 12, 12]).astype(np.float32) * 2306.0)
351
352
353@ms_function
354def if_while(a, b, x, z):
355    c = a
356    i = c1
357    out = x
358    if a < b:
359        c = a + a
360        while i < c2:
361            out = out + conv(z)
362            i = i + c3
363    else:
364        c = b + b
365    out = c + c
366    return out
367
368
369def test_if_while():
370    x = Tensor(np.random.randn(1, 16, 12, 12).astype(np.float32))
371    z = Tensor(np.random.randn(1, 16, 16, 16).astype(np.float32))
372    res = if_while(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)), x, z)
373    assert np.all(res.asnumpy() == np.ones([64, 10]).astype(np.float32) * 4.0)
374
375
376def _while(x):
377    """ _while """
378    ret = x * x
379    i = 2
380    while i <= 3:
381        ret = ret * i
382        i = i + 1
383    return ret
384
385
386def grad_while(x):
387    """ grad_while """
388    return grad_all(_while)(x)
389
390
391def test_grad_while():
392    """ test_grad_while """
393    assert grad_while(Tensor(5, dtype=ms.int32)) == (60,)
394
395
396@ms_function
397def factorial(n):
398    """ factorial """
399    if n == 0:
400        return 1
401    return n * factorial(n - 1)
402
403
404def test_factorial():
405    res = factorial(3)
406    assert res == 6
407
408
409@ms_function
410def factorial2(n):
411    """ factorial """
412    if n != 0:
413        return n * factorial2(n - 1)
414    elif n == 1:
415        return 1 * factorial2(n - 1)
416    else:
417        return 1
418
419
420def test_factorial2():
421    res = factorial2(3)
422    assert res == 6
423
424
425@ms_function
426def foo(n):
427    if n <= 1:
428        if n == 1:
429            return foo(n - 1)
430        else:
431            return 1
432    else:
433        return foo(n - 1)
434
435
436def test_foo():
437    res = foo(5)
438    assert res == 1
439
440
441@ms_function
442def double_nested_loop(x):
443    i = 0
444    s = 0
445    while i < x:
446        j = 0
447        i = i + 1
448        while j < 3:
449            j = j + 1
450            s = s + j
451    return s
452
453
454def test_nested_loop():
455    res = double_nested_loop(3)
456    assert res == 18
457
458
459@ms_function
460def double_nested_loop2(x):
461    s = 0
462    for i in range(x):
463        for j in range(3):
464            s = s + j
465    return s
466
467
468def test_nested_loop2():
469    res = double_nested_loop(1)
470    assert res == 6
471
472
473def _for(x):
474    """ _for """
475    ret = x * x
476    for i in (2, 3):
477        ret = ret * i
478    return ret
479
480
481@ms_function
482def grad_for(x):
483    """ grad_for """
484    return grad_all(_for)(x)
485
486
487@ms_function
488def try_tail(x):
489    """ try_tail """
490    return C.tail(x)
491
492
493@non_graph_engine
494def test_tail():
495    """ test_tail """
496    try_tail((0, 1, 2, 3))
497
498
499@ms_function
500def zero_like_tensor(x):
501    """ zero_like_tensor """
502    return C.zeros_like(x)
503
504
505def test_zeros():
506    """ test_zeros """
507    x = Tensor(np.ones([2, 3]).astype(np.int32))
508    res = zero_like_tensor(x)
509    assert np.all(res.asnumpy() == np.zeros([2, 3]).astype(np.int32))
510
511
512@ms_function
513def arithmetic_simplify_01(x, y):
514    """ arithmetic_simplify_01 """
515    return C.zeros_like(x) * y
516
517
518def test_arithmetic_simplify_01():
519    """ test_arithmetic_simplify_01 """
520    x = Tensor(np.ones([2, 3]).astype(np.int32))
521    y = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
522    res = arithmetic_simplify_01(x, y)
523    expect = np.zeros([2, 3]).astype(np.int32)
524    assert np.all(res.asnumpy() == expect)
525
526
527@ms_function
528def arithmetic_simplify_02(x, y):
529    """ arithmetic_simplify_02 """
530    return C.ones_like(x) * y
531
532
533def test_arithmetic_simplify_02():
534    """ test_arithmetic_simplify_02 """
535    x = Tensor(np.ones([2, 3]).astype(np.int32))
536    y = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
537    res = arithmetic_simplify_02(x, y)
538    expect = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)
539    assert np.all(res.asnumpy() == expect)
540
541
542@ms_function
543def arithmetic_simplify_03(x, y):
544    """ arithmetic_simplify_03 """
545    return x * C.ones_like(y)
546
547
548def test_arithmetic_simplify_03():
549    """ test_arithmetic_simplify_03 """
550    x = Tensor(np.ones([2, 3]).astype(np.int32))
551    y = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
552    res = arithmetic_simplify_03(x, y)
553    expect = np.ones([2, 3]).astype(np.int32)
554    assert np.all(res.asnumpy() == expect)
555
556
557@ms_function
558def arithmetic_simplify_04(x):
559    """ arithmetic_simplify_04 """
560    return x + 0
561
562
563def test_arithmetic_simplify_04():
564    """ test_arithmetic_simplify_04 """
565    x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
566    res = arithmetic_simplify_04(x)
567    expect = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)
568    assert np.all(res.asnumpy() == expect)
569
570
571@ms_function
572def arithmetic_simplify_05(x):
573    """ arithmetic_simplify_05 """
574    return x * 1
575
576
577def test_arithmetic_simplify_05():
578    """ test_arithmetic_simplify_05 """
579    x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
580    res = arithmetic_simplify_05(x)
581    expect = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)
582    assert np.all(res.asnumpy() == expect)
583
584
585@ms_function
586def arithmetic_simplify_06(x):
587    """ arithmetic_simplify_06 """
588    return x * 2 * 5
589
590
591def test_arithmetic_simplify_06():
592    """ test_arithmetic_simplify_06 """
593    x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
594    res = arithmetic_simplify_06(x)
595    expect = np.array([[10, 20, 30], [40, 50, 60]]).astype(np.int32)
596    assert np.all(res.asnumpy() == expect)
597
598
599@ms_function
600def arithmetic_simplify_07(x):
601    """ arithmetic_simplify_07 """
602    return (x + 1) * 2 * 5
603
604
605def test_arithmetic_simplify_07():
606    """ test_arithmetic_simplify_07 """
607    x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
608    res = arithmetic_simplify_07(x)
609    expect = np.array([[20, 30, 40], [50, 60, 70]]).astype(np.int32)
610    assert np.all(res.asnumpy() == expect)
611
612
613@ms_function
614def arithmetic_simplify_08(x, y):
615    """ arithmetic_simplify_08 """
616    return 1 * x * 1 * 1 + 1 * 0 * 1 + 0 + y * 1
617
618
619def test_arithmetic_simplify_08():
620    """ test_arithmetic_simplify_08 """
621    x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
622    y = Tensor(np.ones([2, 3]).astype(np.int32))
623    res = arithmetic_simplify_08(x, y)
624    expect = np.array([[2, 3, 4], [5, 6, 7]]).astype(np.int32)
625    assert np.all(res.asnumpy() == expect)
626
627
628def test_GradCheckerPrimitive():
629    """ test_GradCheckerPrimitive """
630    matmul = P.MatMul()
631
632    def prim_f(x, y):
633        return matmul(x, y)
634
635    check_gradient(prim_f, Tensor(np.array([[0.65, 0.8, 0.8]], np.float32)),
636                   Tensor(np.array([[0.1], [0.2], [-.1]], np.float32)),
637                   grad_checker_class=OperationGradChecker, sampling_times=2)
638
639
640def test_NNGradChecker():
641    """ test_NNGradChecker """
642
643    class Net(nn.Cell):
644        """ Net definition """
645
646        def __init__(self):
647            super(Net, self).__init__()
648            self.dense = nn.Dense(10, 10)
649
650        def construct(self, x):
651            out = self.dense(x)
652            return out
653
654    check_gradient(Net(), Tensor(np.random.rand(1, 10).astype(np.float32)),
655                   delta=1e-3,
656                   max_error=1e-3,
657                   grad_checker_class=NNGradChecker, sampling_times=3)
658
659
660def test_OperationGradChecker():
661    """ test_OperationGradChecker """
662
663    class Net(nn.Cell):
664        """ Net definition """
665
666        def __init__(self):
667            super(Net, self).__init__()
668            self.matmul = P.MatMul()
669            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
670
671        def construct(self, x, y):
672            x = x * self.z
673            out = self.matmul(x, y)
674            return out
675
676    check_gradient(Net(), Tensor(np.array([[0.65, 0.8, 0.8]], np.float32)),
677                   Tensor(np.array([[0.1], [0.2], [-.1]], np.float32)), grad_checker_class=OperationGradChecker,
678                   input_selector=[1], sampling_times=2)
679
680
681def test_OperationJacobianChecker():
682    """ test_OperationJacobianChecker """
683
684    class Net(nn.Cell):
685        """ Net definition """
686
687        def __init__(self):
688            super(Net, self).__init__()
689            self.matmul = P.MatMul()
690            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
691
692        def construct(self, x, y):
693            x = x * self.z
694            out = self.matmul(x, y)
695            return x, out
696
697    check_jacobian(Net(), Tensor(np.array([[0.65, 0.8, 0.8], [0.1, 0.2, 0.3]], np.float32)),
698                   Tensor(np.array([[0.1, 0.3], [0.2, 0.2], [-.1, 0.4]], np.float32)),
699                   grad_checker_class=OperationGradChecker, input_selector=[0],
700                   output_selector=[0])
701
702
703def test_NNJacobianChecker():
704    """ test_NNJacobianChecker """
705
706    class Net(nn.Cell):
707        """ Net definition """
708
709        def __init__(self):
710            super(Net, self).__init__()
711            self.dense = nn.Dense(10, 10)
712
713        def construct(self, x):
714            out = self.dense(x)
715            return out, x
716
717    check_jacobian(Net(), Tensor(np.random.rand(1, 10).astype(np.float32)),
718                   delta=1e-3,
719                   max_error=1e-7,
720                   grad_checker_class=NNGradChecker,
721                   input_selector=[1],
722                   output_selector=[0])
723
724
725def multi_outputs(x, y):
726    z = x + y
727    return 2 * z, 2 * z
728
729
730@ms_function
731def while_sp(x, y, z):
732    out = x
733    i = c3
734    while i < c2:
735        out = mul(x, out)
736        i = i + c3
737    return out
738
739
740def test_while_sp():
741    y = Tensor(np.ones([1, 3]).astype(np.float32))
742    z = Tensor(np.ones([1, 3]).astype(np.float32))
743    x = Tensor(np.ones([1, 3]).astype(np.float32) * 2.0)
744    res = while_sp(x, y, z)
745    assert np.all(res.asnumpy() == np.ones([1, 3]).astype(np.float32) * 1024.0)
746
747
748def grad_refactor_simple_1(x, y):
749    """ add """
750    return x * x + 2 * y
751
752
753def test_grad_refactor_simple_1():
754    assert grad_all(grad_refactor_simple_1)(Tensor(2, dtype=ms.int32), Tensor(1, dtype=ms.int32)) == (4, 2)
755
756
757def grad_refactor_simple_2(x, y, z):
758    """ add """
759    return x * y + z + x * y * z + x + x * y
760
761
762def test_grad_refactor_simple_2():
763    x = Tensor(2, dtype=ms.int32)
764    y = Tensor(3, dtype=ms.int32)
765    z = Tensor(0, dtype=ms.int32)
766    assert grad_all(grad_refactor_simple_2)(x, y, z) == (7, 4, 7)
767
768
769def grad_refactor_1(a, b):
770    """ if_test """
771
772    def inner(x, y):
773        return x * y
774
775    return inner(a, b)
776
777
778def test_grad_refactor_1():
779    assert grad_all(grad_refactor_1)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (3, 2)
780
781
782def grad_refactor_2(a, b):
783    """ if_test """
784
785    def inner(x):
786        return x * b
787
788    return inner(b) * inner(a)
789
790
791def test_grad_refactor_2():
792    assert grad_all(grad_refactor_2)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (27, 54)
793
794
795def grad_refactor_3(a):
796    """ if_test """
797    if a > 3:
798        return 0
799    return 3 * a
800
801
802def grad_refactor_4(a):
803    """ if_test """
804    if a > 3:
805        return 3 * a
806    return 0
807
808
809def test_grad_refactor_4():
810    assert grad_all(grad_refactor_4)(Tensor(4, dtype=ms.int32)) == (3,)
811
812
813def grad_refactor_5(a):
814    """ if_test """
815    if a > 3:
816        return 1
817    return a
818
819
820def grad_refactor_6(a, b):
821    """ if_test """
822    if a > b:
823        return 3 * a + b
824    return 2 * b * a
825
826
827def test_grad_refactor_6():
828    assert grad_all(grad_refactor_6)(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32)) == (3, 1)
829
830
831def grad_refactor_while(x):
832    """ grad_refactor_while """
833    rval = x
834    while rval < 4:
835        rval = rval * rval
836    return rval
837
838
839def grad_refactor__while_1(x):
840    """ _while """
841    ret = x * x
842    i = 2
843    while i <= 3:
844        ret = ret * i
845        i = i + 1
846    return ret
847
848
849def test_grad_refactor_10():
850    """ test_grad_while """
851    assert grad_all(grad_refactor__while_1)(Tensor(5, dtype=ms.int32)) == (60,)
852
853
854def test_grad_refactor_11():
855    class Net(nn.Cell):
856        """ Net definition """
857
858        def __init__(self):
859            super(Net, self).__init__()
860
861        def construct(self, x, y):
862            return x * y * y
863
864    net = Net()
865    grad_all(net)(Tensor(np.ones([2]).astype(np.float32)), Tensor(np.ones([2]).astype(np.float32)))
866
867
868def test_grad_refactor_12():
869    class Net(nn.Cell):
870        """ Net definition """
871
872        def __init__(self):
873            super(Net, self).__init__()
874            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
875
876        def construct(self, x, y):
877            return x * self.z * y
878
879    net = Net()
880    grad_all(net)(Tensor(np.ones([2]).astype(np.float32)), Tensor(np.zeros([2]).astype(np.float32)))
881
882
883def test_grad_refactor_13():
884    class Net(nn.Cell):
885        """ Net definition """
886
887        def __init__(self):
888            super(Net, self).__init__()
889            self.z = Parameter(Tensor(np.ones([2]).astype(np.float32)), name='z')
890
891        def construct(self, x, y):
892            return x * self.z * y
893
894    net = Net()
895    weights = ParameterTuple(net.trainable_params())
896    grad_by_list(net, weights)(Tensor(np.ones([2]).astype(np.float32)), Tensor(np.zeros([2]).astype(np.float32)))
897
898
899def grad_refactor_14(a, b):
900    """ if_test """
901
902    def inner1(x):
903        return x * b
904
905    def inner2(x):
906        return a * b
907
908    def inner3(x):
909        if x > 2:
910            return a
911        return b
912
913    return inner1(b) + inner2(a) + inner3(a)
914
915
916# pylint: disable=using-constant-test
917class IfDeferInline(nn.Cell):
918    def __init__(self, mul_size):
919        super().__init__()
920        self.mul_weight = Tensor(np.full(mul_size, 0.6, dtype=np.float32))
921        self.mul = P.Mul()
922
923    def construct(self, inputs):
924        x = self.mul(inputs, self.mul_weight)
925        if True:
926            x = x
927        return x
928
929
930def test_grad_if_defer_inline():
931    """ test_grad_if_defer_inline """
932    network = IfDeferInline([128, 96])
933    network.add_flags(defer_inline=False)
934    inp = Tensor(np.ones([128, 96]).astype(np.float32))
935    grads = grad_all(network)(inp)
936    assert np.all(grads[0].asnumpy() == np.full([128, 96], 0.6, dtype=np.float32))
937
938
939def test_dict_const():
940    class Net(nn.Cell):
941        def __init__(self):
942            super(Net, self).__init__()
943            self.res = {'1': 10}
944
945        def construct(self):
946            return self.res
947
948    Net()()
949