• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test control ops """
16import numpy as np
17import pytest
18
19from mindspore import dtype as ms
20from mindspore import Tensor
21from mindspore import context
22from mindspore import nn
23from mindspore.common.parameter import Parameter, ParameterTuple
24from mindspore.ops import composite as C
25from mindspore.ops import operations as P
26
27
# Gradient helpers shared by every test in this file:
#   grad_by_list - differentiates a network w.r.t. a ParameterTuple of weights.
#   grad_all     - differentiates a network w.r.t. all of its inputs.
grad_by_list = C.GradOperation(get_by_list=True)
grad_all = C.GradOperation(get_all=True)
30
31
def test_while_grad():
    """Smoke test: gradients through a while loop that mutates a tensor slice."""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()

        def construct(self, idx, end, x):
            # Overwrite part of slice `idx` with that slice's max, idx = 0..end-1.
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                idx = idx + 1
            return x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    while_net = MyWhileNet()
    net = GradNet(while_net)
    graph_output = net(idx, end, x)

    # The input is random, so no exact gradient value can be pinned down;
    # running the grad graph and getting one gradient per input is the pass
    # criterion.  (The previous `assert graph_output == 0` compared the tuple
    # returned by grad_all to a scalar and could never hold.)
    assert isinstance(graph_output, tuple)
    assert len(graph_output) == 3
64
65
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_with_const_param_grad():
    """Gradient of a while loop whose body depends only on the loop value."""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.mul = P.Mul()
            self.add = P.Add()

        def construct(self, x, y):
            # x <- x*x + 1 until x catches up with y.
            while x < y:
                x = self.add(self.mul(x, x), 1)
            return x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    grad_net = GradNet(MyWhileNet())
    start = Tensor([1.1], dtype=ms.float32)
    limit = Tensor([8.0], dtype=ms.float32)
    grads = grad_net(start, limit)
    expected = (np.array([1.14433983e+02], dtype=np.float32),
                np.array([0], dtype=np.float32))
    for got, want in zip(grads, expected):
        assert np.allclose(got.asnumpy(), want, 0.0001, 0.0001)
101
102
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_with_variable_grad():
    """Gradient of a while loop whose body also reads the loop bound."""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.mul = P.Mul()
            self.add = P.Add()

        def construct(self, x, y):
            while x < y:
                squared = self.mul(x, x)
                x = self.add(squared, y)
            return x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    grad_net = GradNet(MyWhileNet())
    lower = Tensor([1.1], dtype=ms.float32)
    upper = Tensor([8.0], dtype=ms.float32)
    dx, dy = grad_net(lower, upper)
    assert np.allclose(dx.asnumpy(), np.array([2.20000005e+00], dtype=np.float32), 0.0001, 0.0001)
    assert np.allclose(dy.asnumpy(), np.array([1.00000000e+00], dtype=np.float32), 0.0001, 0.0001)
138
139
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_with_param_forward():
    """Forward pass of a while loop that accumulates a Parameter each step."""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                # Overwrite part of slice `idx` with its max, then accumulate.
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                out = out + x + self.param
                idx = idx + 1
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    net = MyWhileNet()
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # The network computes in float32, so compare against a float32 reference
    # (was np.int32 — inconsistent with every other expectation in this file).
    expect = np.array([[[6, 8], [10, 12]], [[19, 22], [25, 28]]], dtype=np.float32)
    assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001)
171
172
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_endless_case():
    """endless case when optimization"""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # NOTE(review): self.max and self.param are not used by construct;
            # kept as-is since they may be part of the optimization scenario
            # this test exercises — confirm before removing.
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, idx, end, x):
            accum = self.zero
            while idx < end:
                accum = accum + x[idx, :, :]
                idx = idx + 1
            return accum

    start = Tensor(np.array(0), dtype=ms.int32)
    stop = Tensor(np.array(2), dtype=ms.int32)
    data = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    result = MyWhileNet()(start, stop, data)
    expect = np.array([[[4, 6], [8, 10]],
                       [[4, 6], [8, 10]]]).astype(np.float32)
    assert np.allclose(result.asnumpy(), expect, 0.0001, 0.0001)
205
206
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_with_param_grad():
    """Parameter gradient through a while loop that also mutates its input."""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                out = out + x + self.param
                idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE)
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # `param` is added once per iteration over 2 iterations, so its gradient
    # is 2 everywhere.  Compare in float32 (was np.int32 — inconsistent with
    # the rest of the file and with the network's compute dtype).
    expect = np.array([[[2, 2], [2, 2]], [[2, 2], [2, 2]]], dtype=np.float32)
    assert np.allclose(graph_output[0].asnumpy(), expect, 0.0001, 0.0001)
247
248
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_with_param_forward_with_const_branch():
    """Forward while loop containing an if whose condition is constant."""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                if 2 > 1:  # constant-true: only the first branch ever runs
                    out = out + self.param
                else:
                    out = out + idx + self.param
                idx = idx + 1
            return out

    first = Tensor(np.array(0), dtype=ms.int32)
    last = Tensor(np.array(4), dtype=ms.int32)
    data = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    result = MyWhileNet()(first, last, data)
    expect = np.array([[[0, 4], [8, 12]],
                       [[16, 20], [24, 28]]]).astype(np.float32)
    assert np.allclose(result.asnumpy(), expect, 0.0001, 0.0001)
284
285
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_opt_endless():
    """endless during optimization case"""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.addn = P.AddN()

        def construct(self, idx, end, x):
            tripled = self.addn((x, x, x))
            out = tripled
            while idx < end:
                out = self.addn((out, tripled))
                idx = idx + 1
            return self.addn((out, x))

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    start = Tensor(np.array(0), dtype=ms.int32)
    stop = Tensor(np.array(4), dtype=ms.int32)
    data = Tensor(np.ones([2, 2, 2]).astype(np.float32) * 3, dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    grads = GradNet(MyWhileNet())(start, stop, data)

    # The two integer loop bounds get zero gradient; the data input sees
    # 3 (initial) + 4 * 3 (loop) + 1 (final add) = 16 accumulations.
    expect_x = np.array([[[16, 16], [16, 16]],
                         [[16, 16], [16, 16]]]).astype(np.float32)
    assert np.allclose(grads[0].asnumpy(), 0, 0.0001, 0.0001)
    assert np.allclose(grads[1].asnumpy(), 0, 0.0001, 0.0001)
    assert np.allclose(grads[2].asnumpy(), expect_x, 0.0001, 0.0001)
335
336
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_no_while_call():
    """A constant-condition if with no while loop around it."""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            out = self.zero
            if 2 > 1:  # constant-true: always takes the first branch
                out = out + self.param
            else:
                out = out + idx + self.param
            return out

    lo = Tensor(np.array(0), dtype=ms.int32)
    hi = Tensor(np.array(4), dtype=ms.int32)
    data = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    result = MyWhileNet()(lo, hi, data)
    # out == zero + param, i.e. the 0..7 ramp.
    expect = np.arange(8).reshape(2, 2, 2).astype(np.float32)
    assert np.allclose(result.asnumpy(), expect, 0.0001, 0.0001)
370
371
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_with_param_grad_with_const_branch():
    """Parameter gradient through a while loop holding a constant-true if."""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            acc = self.zero
            while idx < end:
                if 2 > 1:  # constant-true condition
                    acc = acc + self.param
                else:
                    acc = acc + idx + self.param
                idx = idx + 1
            return acc

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    lo = Tensor(np.array(0), dtype=ms.int32)
    hi = Tensor(np.array(4), dtype=ms.int32)
    data = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    grads = GradNet(MyWhileNet())(lo, hi, data)

    # `param` is added once per iteration over 4 iterations.
    expect = np.full((2, 2, 2), 4).astype(np.float32)
    assert np.allclose(grads[0].asnumpy(), expect, 0.0001, 0.0001)
416
417
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_while_with_param_grad_with_const_branch():
    """Parameter gradient: constant-branch while loop nested inside a for loop."""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            acc = self.zero
            for _ in range(0, 2):
                idx = self.start  # restart the inner while on every outer pass
                while idx < end:
                    if 2 > 1:  # constant-true condition
                        acc = acc + self.param
                    else:
                        acc = acc + idx + self.param
                    idx = idx + 1
            return acc

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    lo = Tensor(np.array(0), dtype=ms.int32)
    hi = Tensor(np.array(4), dtype=ms.int32)
    data = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    grads = GradNet(MyWhileNet())(lo, hi, data)

    # 2 outer passes x 4 inner additions of `param`.
    expect = np.full((2, 2, 2), 8).astype(np.float32)
    assert np.allclose(grads[0].asnumpy(), expect, 0.0001, 0.0001)
465
466
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_for_while_with_param_grad_basic():
    """Basic parameter gradient for a while loop nested in a for loop."""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            total = self.zero
            for _ in range(0, 2):
                idx = self.start  # reset the counter for each outer pass
                while idx < end:
                    total = total + self.param
                    idx = idx + 1
            return total

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    lo = Tensor(np.array(0), dtype=ms.int32)
    hi = Tensor(np.array(4), dtype=ms.int32)
    data = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    grads = GradNet(MyWhileNet())(lo, hi, data)
    # `param` is accumulated 2 * 4 = 8 times.
    expect = np.full((2, 2, 2), 8).astype(np.float32)
    assert np.allclose(grads[0].asnumpy(), expect, 0.0001, 0.0001)
510
511
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_for_while_with_param_grad_normal():
    """Like the basic for/while case, but the accumulator starts from the input."""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            running = x
            for _ in range(0, 2):
                idx = self.start  # reset the counter for each outer pass
                while idx < end:
                    running = running + self.param
                    idx = idx + 1
            return running

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    lo = Tensor(np.array(0), dtype=ms.int32)
    hi = Tensor(np.array(4), dtype=ms.int32)
    data = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    grads = GradNet(MyWhileNet())(lo, hi, data)
    # Starting from x does not change d(out)/d(param): still 2 * 4 = 8.
    expect = np.full((2, 2, 2), 8).astype(np.float32)
    assert np.allclose(grads[0].asnumpy(), expect, 0.0001, 0.0001)
555
556
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_with_param_basic_grad():
    """Basic parameter gradient: the while body adds the Parameter each step."""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            total = self.zero
            while idx < end:
                total = total + self.param
                idx = idx + 1
            return total + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    lo = Tensor(np.array(0), dtype=ms.int32)
    hi = Tensor(np.array(3), dtype=ms.int32)
    data = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    grads = GradNet(MyWhileNet())(lo, hi, data)
    # 3 additions inside the loop plus one outside -> gradient 4 everywhere.
    expect = np.full((2, 2, 2), 4).astype(np.float32)
    assert np.allclose(grads[0].asnumpy(), expect, 0.0001, 0.0001)
597
598
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_with_param_basic_grad_mul():
    """Parameter gradient when the while body multiplies by the Parameter."""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.ones(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            prod = self.zero
            while idx < end:
                prod = prod * self.param
                idx = idx + 1
            return prod + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    lo = Tensor(np.array(0), dtype=ms.int32)
    hi = Tensor(np.array(3), dtype=ms.int32)
    data = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    grads = GradNet(MyWhileNet())(lo, hi, data)
    # out = p^3 + p elementwise, so d(out)/dp = 3p^2 + 1 at p = 0..7.
    expect = np.array([[[1, 4], [13, 28]],
                       [[49, 76], [109, 148]]]).astype(np.float32)
    assert np.allclose(grads[0].asnumpy(), expect, 0.0001, 0.0001)
639
640
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_with_param_basic_grad_two():
    """Gradients w.r.t. two Parameters accumulated in the same while loop."""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            total = self.zero
            while idx < end:
                total = total + self.param + self.weight
                idx = idx + 1
            return total + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    lo = Tensor(np.array(0), dtype=ms.int32)
    hi = Tensor(np.array(3), dtype=ms.int32)
    data = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    grads = GradNet(MyWhileNet())(lo, hi, data)

    # `param` gets 3 loop additions plus 1 from the return; `weight` only 3.
    expected = [np.full((2, 2, 2), 4).astype(np.float32),
                np.full((2, 2, 2), 3).astype(np.float32)]
    for got, want in zip(grads, expected):
        assert np.allclose(got.asnumpy(), want, 0.0001, 0.0001)
686
687
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_with_param_basic_grad_three():
    """Gradients w.r.t. three Parameters accumulated in one while loop."""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
            self.key = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="key")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            total = self.zero
            while idx < end:
                total = total + self.param + self.weight + self.key
                idx = idx + 1
            return total + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    lo = Tensor(np.array(0), dtype=ms.int32)
    hi = Tensor(np.array(3), dtype=ms.int32)
    data = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    grads = GradNet(MyWhileNet())(lo, hi, data)
    # Only `param` picks up the extra contribution from the return expression.
    expected = [np.full((2, 2, 2), 4).astype(np.float32),
                np.full((2, 2, 2), 3).astype(np.float32),
                np.full((2, 2, 2), 3).astype(np.float32)]
    for got, want in zip(grads, expected):
        assert np.allclose(got.asnumpy(), want, 0.0001, 0.0001)
736
737
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_if_with_param_grad():
    """Parameter gradient through a while loop with a data-dependent if."""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                # Double the parameter step until `out` has caught up with x.
                if self.max(out) < self.max(x):
                    out = out + self.param * 2
                else:
                    out = out + self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    start = Tensor(np.array(0), dtype=ms.int32)
    stop = Tensor(np.array(3), dtype=ms.int32)
    ones = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
    context.set_context(mode=context.GRAPH_MODE)
    grads = GradNet(MyWhileNet())(start, stop, ones)
    expect = np.full((2, 2, 2), 5).astype(np.float32)
    assert np.allclose(grads[0].asnumpy(), expect, 0.0001, 0.0001)
780
781
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_with_param_grad_not_enter_while():
    """Parameter gradient when the while condition is false from the start."""

    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(2, ms.float32), name="weight")
            self.zero = Tensor(0, ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param * 3
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    start = Tensor(np.array(3), dtype=ms.int32)
    stop = Tensor(np.array(0), dtype=ms.int32)
    data = Tensor(2, dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    grads = GradNet(MyWhileNet())(start, stop, data)

    # The loop never runs (3 < 0 is false), so only the trailing
    # `+ self.param` contributes a gradient of 1.
    assert np.allclose(grads[0].asnumpy(), 1, 0.0001, 0.0001)
820
821
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_with_param_if_by_if_forward():
    """Forward pass through two consecutive tensor-condition if statements."""

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param
            else:
                out = out + x
            if a == b:
                out = out + x * 3 + self.param
            else:
                out = out + x * 2
            return out

    lo = Tensor(np.array(0), dtype=ms.int32)
    hi = Tensor(np.array(4), dtype=ms.int32)
    ones = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    result = MyIfByIfNet()(lo, hi, ones)
    # a < b and a != b, so out = (ones + param) + ones * 2 = param + 3.
    expect = np.array([[[3, 4], [5, 6]],
                       [[7, 8], [9, 10]]]).astype(np.float32)
    assert np.allclose(result.asnumpy(), expect, 0.0001, 0.0001)
857
858
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_with_param_if_by_if_grad_inputs():
    """Input gradients through two independent if statements.

    With a=0 and b=0: only the second condition (a == b) holds, so
    out = x * 3 + param * 3 and the gradients w.r.t. (a, b, x) are
    (0, 0, 3).
    """
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 4
            if a == b:
                out = out + x * 3 + self.param * 3
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            # Gradients w.r.t. every input of the wrapped net.
            return grad_all(self.net)(*inputs)

    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    graph_output = net(idx, end, x)
    expect1 = Tensor(np.array(0), dtype=ms.int32)
    expect2 = Tensor(np.array(0), dtype=ms.int32)
    expect3 = np.array([[[3, 3], [3, 3]],
                        [[3, 3], [3, 3]]]).astype(np.float32)
    assert np.allclose(graph_output[0].asnumpy(), expect1.asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), expect2.asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[2].asnumpy(), expect3, 0.0001, 0.0001)
902
903
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_with_param_if_by_if_grad_parameter():
    """Parameter gradient through two independent if statements.

    With a=0 and b=2: only the first condition (a < b) holds, so
    out = x + param * 2 and d(out)/d(param) = 2 everywhere.
    """
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 2
            if a == b:
                out = out + x * 3 + self.param
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            # Gradient w.r.t. the wrapped net's parameters.
            return grad_by_list(self.net, self.weights)(*inputs)

    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    graph_output = net(idx, end, x)

    expect = np.array([[[2, 2], [2, 2]],
                       [[2, 2], [2, 2]]]).astype(np.float32)
    assert np.allclose(graph_output[0].asnumpy(), expect, 0.0001, 0.0001)
945
946
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_with_param_if_by_if_grad_param_excute_null():
    """Parameter gradient when the only if branch is never taken.

    With a=4 and b=0 the condition a < b is false, so the output is the
    constant ``self.zero`` and d(out)/d(param) is all zeros.
    (The "excute" typo in the name is kept so selection by test name
    keeps working.)
    """
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 2
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            # Gradient w.r.t. the wrapped net's parameters.
            return grad_by_list(self.net, self.weights)(*inputs)

    idx = Tensor(np.array(4), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    graph_output = net(idx, end, x)

    expect = np.array([[[0, 0], [0, 0]],
                       [[0, 0], [0, 0]]]).astype(np.float32)
    assert np.allclose(graph_output[0].asnumpy(), expect, 0.0001, 0.0001)
986
987
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_if_by_if_return_inside_grad():
    """Parameter gradient through early returns inside if branches.

    With a=1 and b=0 neither ``a < b`` nor ``a == b`` holds, so the
    function falls through to ``out + self.param * 3`` and
    d(out)/d(param) = 3 everywhere.
    """
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                return out + x + self.param
            if a == b:
                return out + self.param * 2
            return out + self.param * 3

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            # Gradient w.r.t. the wrapped net's parameters.
            return grad_by_list(self.net, self.weights)(*inputs)

    idx = Tensor(np.array(1), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    graph_output = net(idx, end, x)

    expect = np.array([[[3, 3], [3, 3]],
                       [[3, 3], [3, 3]]]).astype(np.float32)
    assert np.allclose(graph_output[0].asnumpy(), expect, 0.0001, 0.0001)
1029
1030
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_by_if_forward():
    """Forward-only check of three sequential if/else branches in graph mode."""
    class IfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add_op = P.Add()
            self.sub_op = P.Sub()
            self.mul_op = P.Mul()
            self.div_op = P.RealDiv()

        def construct(self, a, b, x):
            # Branch 1: add when a < b, otherwise subtract.
            if a < b:
                a = self.add_op(a, b)
            else:
                a = self.sub_op(a, b)
            # Branch 2: multiply when a == x, otherwise real division.
            if a == x:
                a = self.mul_op(a, b)
            else:
                a = self.div_op(a, b)
            # Branch 3: b becomes a+b or a+x depending on b == x.
            if b == x:
                b = self.add_op(a, b)
            else:
                b = self.add_op(a, x)
            a = a * b
            out = a + b + x
            return out

    context.set_context(mode=context.GRAPH_MODE)
    a_in = Tensor(np.array(2), dtype=ms.float32)
    b_in = Tensor(np.array(3), dtype=ms.float32)
    x_in = Tensor(np.array(4), dtype=ms.float32)
    result = IfByIfNet()(a_in, b_in, x_in)
    # a: 2 -> 5 -> 5/3; b: 5/3 + 4; out = a*b + b + x = 19.111...
    assert np.allclose(result.asnumpy(), 19.11111, 0.0001, 0.0001)
1071
1072
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_if_by_if_forward_control_tuple_switch():
    """tuple_get from  switch op will generate new switch inside to eliminate tuple_get"""

    class Branch3Net(nn.Cell):
        # Innermost stage: one if/else, then returns a 3-tuple so the
        # callers receive a tuple out of a branched cell (the pattern
        # this test exercises).
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()  # NOTE(review): unused by construct
            self.mul = P.Mul()  # NOTE(review): unused by construct
            self.div = P.RealDiv()  # NOTE(review): unused by construct

        def construct(self, a, b, x):
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            return a, b, x

    class Branch2Net(nn.Cell):
        # Middle stage: its own if/else, then forwards Branch3Net's tuple.
        def __init__(self):
            super().__init__()
            self.add = P.Add()  # NOTE(review): unused by construct
            self.sub = P.Sub()  # NOTE(review): unused by construct
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch3Net()

        def construct(self, a, b, x):
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            return self.net(a, b, x)

    class MyIfByIfNet(nn.Cell):
        # Outer net: branches, then unpacks the tuple built through two
        # levels of branched sub-cells (tuple_getitem from switch output).
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()  # NOTE(review): unused by construct
            self.div = P.RealDiv()  # NOTE(review): unused by construct
            self.net = Branch2Net()

        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            a, b, x = self.net(a, b, x)
            a = a * b
            out = a + b + x
            return out

    # a=2 < b=3: a -> 5; a != x: a -> 5/3; b != x: b -> 5/3 + 0;
    # out = (5/3)**2 + 5/3 + 0 = 40/9 ~= 4.444444
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = if_net
    graph_output = net(idx, end, x)
    expect = 4.444444
    assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001)
1140
1141
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_if_by_if_forward_control_inside_net():
    # Like the tuple-switch test above but the innermost cell computes the
    # final scalar itself, so control flow is fully nested inside sub-nets.
    class Branch3Net(nn.Cell):
        # Innermost stage: one if/else, then the final arithmetic.
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()  # NOTE(review): unused by construct
            self.mul = P.Mul()  # NOTE(review): unused by construct
            self.div = P.RealDiv()  # NOTE(review): unused by construct

        def construct(self, a, b, x):
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            a = a * b
            out = a + b + x
            return out

    class Branch2Net(nn.Cell):
        # Middle stage: its own if/else, then delegates to Branch3Net.
        def __init__(self):
            super().__init__()
            self.add = P.Add()  # NOTE(review): unused by construct
            self.sub = P.Sub()  # NOTE(review): unused by construct
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch3Net()

        def construct(self, a, b, x):
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            return self.net(a, b, x)

    class MyIfByIfNet(nn.Cell):
        # Outer net: first branch, then the nested sub-net chain.
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()  # NOTE(review): unused by construct
            self.div = P.RealDiv()  # NOTE(review): unused by construct
            self.net = Branch2Net()

        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            out = self.net(a, b, x)
            return out

    # Same arithmetic as the tuple-switch variant: out = 40/9 ~= 4.444444.
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = if_net
    graph_output = net(idx, end, x)
    expect = 4.444444
    assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001)
1207
1208
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_by_if_forward_use_namespace():
    # Exercises operator instantiation via the `P` namespace *inline inside
    # construct* (e.g. `P.Add()(a, b)`) across three if/else branches —
    # the inline call form is the subject of this test.
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()  # NOTE(review): attributes below are unused;
            self.sub = P.Sub()  # construct deliberately builds ops inline
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if a < b:
                a = P.Add()(a, b)
            else:
                a = P.Sub()(a, b)
            if a == x:
                a = P.Mul()(a, b)
            else:
                a = P.RealDiv()(a, b)
            if b == x:
                b = P.Add()(a, b)
            else:
                b = P.Add()(a, x)
            a = a * b
            out = a + b + x
            return out

    # a=2 < b=3: a -> 5; a != x: a -> 5/3; b != x: b -> 5/3;
    # out = (5/3)**2 + 5/3 + 0 = 40/9 ~= 4.444444
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = if_net
    graph_output = net(idx, end, x)
    expect = 4.444444
    assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001)
1249
1250
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_by_if_forward_use_global_op():
    # Exercises operators created as *construct-local variables* (rather
    # than Cell attributes) and then called inside branches.
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()  # NOTE(review): attributes below are unused;
            self.sub = P.Sub()  # construct deliberately builds local ops
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            # Ops instantiated per-call, inside the traced function.
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if a < b:
                a = add(a, b)
            else:
                a = sub(a, b)
            if a == x:
                a = mul(a, b)
            else:
                a = div(a, b)
            if b == x:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    # Same arithmetic as test_if_by_if_forward_use_namespace: 40/9.
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = if_net
    graph_output = net(idx, end, x)

    expect = 4.444444
    assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001)
1296
1297
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_with_if_by_if_forward():
    """An if/else nested inside a bounded for loop, forward only."""
    class ForIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()

        def construct(self, a, b, x):
            # Four iterations; each picks one arm of the inner branch.
            for _ in range(0, 4):
                if a < b:
                    a = self.add(a, b)
                else:
                    b = self.sub(b, x)
            a = a * b
            out = a + b + x
            return out

    context.set_context(mode=context.GRAPH_MODE)
    a_in = Tensor(np.array(2), dtype=ms.float32)
    b_in = Tensor(np.array(3), dtype=ms.float32)
    x_in = Tensor(np.array(0), dtype=ms.float32)
    result = ForIfNet()(a_in, b_in, x_in)
    # iter 1: a=5; iters 2-4: b stays 3 (x=0); out = 5*3 + 3 + 0 = 18.
    assert np.allclose(result.asnumpy(), 18.0, 0.0001, 0.0001)
1330
1331
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_with_if_by_if_forward_namespace():
    # Same loop/branch shape as test_for_with_if_by_if_forward, but the ops
    # are instantiated inline via the `P` namespace inside the loop body —
    # that inline form is what this test exercises.
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()  # NOTE(review): attributes unused; construct
            self.sub = P.Sub()  # builds ops inline on purpose
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            for _ in range(0, 6):
                if a < b:
                    a = P.Add()(a, b)
                else:
                    b = P.Sub()(b, x)
            a = a * b
            out = a + b + x
            return out

    # iter 1: a=5; iters 2-6: b stays 3 (x=0); out = 5*3 + 3 + 0 = 18.
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = if_net
    graph_output = net(idx, end, x)

    expect = 18.0
    assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001)
1366
1367
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_by_if_forward_const_branch_inner():
    # The middle condition `2 > 1` is a Python constant — the compiler can
    # resolve it at graph-build time, while the outer two branches stay
    # data-dependent. The literal constant condition is the test subject.
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()  # NOTE(review): attributes unused; construct
            self.sub = P.Sub()  # builds local ops on purpose
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if a < b:
                a = add(a, b)
            else:
                a = sub(a, b)
            if 2 > 1:  # constant condition: always the mul arm
                a = mul(a, b)
            else:
                a = div(a, b)
            if b == x:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    # a=2<3: a=5; const branch: a=15; b!=x: b=15; out = 225 + 15 + 0 = 240.
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = if_net
    graph_output = net(idx, end, x)

    expect = 240.0
    assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001)
1413
1414
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_by_if_forward_all_const_branch():
    # Every condition is a Python constant (`2 < 12`, `2 > 1`, `2 == 1`),
    # so all three branches can be folded at graph-build time; the dead
    # arms (sub, div, add(a, b)) should be eliminated without error.
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()  # NOTE(review): attributes unused; construct
            self.sub = P.Sub()  # builds local ops on purpose
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if 2 < 12:  # constant true: add arm
                a = add(a, b)
            else:
                a = sub(a, b)
            if 2 > 1:  # constant true: mul arm
                a = mul(a, b)
            else:
                a = div(a, b)
            if 2 == 1:  # constant false: add(a, x) arm
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    # a=2+3=5; a=5*3=15; b=15+0=15; out = 225 + 15 + 0 = 240.
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    if_net = MyIfByIfNet()
    net = if_net
    graph_output = net(idx, end, x)

    expect = 240.0
    assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001)
1460
1461
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_if_const_grad():
    """Grad of a simple add net while the wrapper runs constant if logic.

    The a/b arithmetic in the wrapper's construct uses only Python
    constants, so the branch must fold away without breaking the grad
    computation. The test only checks that the call succeeds.
    """
    class AddNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()

        def construct(self, *inputs):
            return self.add(*inputs)

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            # Compile-time-constant control flow, deliberately dead code.
            a = 1
            b = 2
            if a > 0:
                b = 1
            a += b
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    grad_net = GradNet(AddNet())
    lhs = Tensor(np.array(0), dtype=ms.int32)
    rhs = Tensor(np.array(1), dtype=ms.int32)
    grad_net(lhs, rhs)
1496
1497
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_if_by_if_const_grad():
    # Three consecutive ifs on Python constants inside the grad wrapper;
    # all must fold at compile time without disturbing the gradient of the
    # wrapped add net. Only successful execution is checked (no assert).
    class MyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()

        def construct(self, *inputs):
            out = self.add(*inputs)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            # Constant control flow, deliberately dead w.r.t. the output.
            a = 1
            b = 2
            if a > 0:
                b = 1
            if a < 0:
                b = 0
            if a == 0:
                b = 3
            a += b
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    my_net = MyNet()
    net = GradNet(my_net)
    a = Tensor(np.array(0), dtype=ms.int32)
    b = Tensor(np.array(1), dtype=ms.int32)
    net(a, b)
1536
1537
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_while_const_grad():
    # A while loop over Python constants (never entered, since a == 1)
    # inside the grad wrapper; it must be resolved at compile time without
    # breaking the gradient. Only successful execution is checked.
    class MyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()

        def construct(self, *inputs):
            out = self.add(*inputs)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            # Constant loop, deliberately dead w.r.t. the output.
            a = 1
            while a > 1:
                a = a - 1
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    my_net = MyNet()
    net = GradNet(my_net)
    a = Tensor(np.array(0), dtype=ms.int32)
    b = Tensor(np.array(1), dtype=ms.int32)
    net(a, b)
1570
1571
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_if_by_while_const_grad():
    # A constant if followed by a constant (never-entered) while inside the
    # grad wrapper; both must fold at compile time without breaking the
    # gradient. Only successful execution is checked (no assert).
    class MyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()

        def construct(self, *inputs):
            out = self.add(*inputs)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            # Constant control flow, deliberately dead w.r.t. the output.
            a = 1
            b = 2
            if a > 0:
                b = 0
            while a > 1:
                a = a - 1
            a += b
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    my_net = MyNet()
    net = GradNet(my_net)
    a = Tensor(np.array(0), dtype=ms.int32)
    b = Tensor(np.array(1), dtype=ms.int32)
    net(a, b)
1608