• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from mindspore import Tensor, jit, ops, mutable, nn, lazy_inline, export, load, context
16from mindspore.common import dtype as mstype
17from mindspore.common.parameter import Parameter
18from mindspore.nn import Cell, GraphCell
19import mindspore.ops.operations as P
20import numpy as np
21import pytest
22
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_single_if():
    """
    Feature: Control flow inline.
    Description: Inline switch node into kernel graph for a single if that
        conditionally updates a Parameter.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5, mstype.int32), name='a')
    param_b = Parameter(Tensor(4, mstype.int32), name='b')

    @jit
    def foo(x, y, param_a, param_b):
        # param_b is mutated only on the true branch; both outputs read it after.
        if param_a > param_b:
            param_b += 1
        return x + param_b, y + param_b

    x = Tensor(2, mstype.int32)
    # Call twice so the compiled (inlined) graph is exercised on reuse.
    ret1 = foo(x, x, param_a, param_b)
    ret2 = foo(x, x, param_a, param_b)
    assert ret1 == (Tensor(7, mstype.int32), Tensor(7, mstype.int32))
    assert ret2
47
48
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_parameter():
    """
    Feature: Control flow inline.
    Description: Control flow if where each branch returns a Parameter directly.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(5))

    @jit
    def foo(x, param_a, param_b):
        if x < 3:
            return param_a
        return param_b

    ret1 = foo(Tensor(1), param_a, param_b)
    assert ret1
70
71
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_param_untail_call():
    """
    Feature: Control flow inline.
    Description: Control flow if whose branch result feeds further computation,
        so the branch call is not a tail call.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(6))

    @jit
    def foo(x, param_a, param_b):
        if x < 3:
            z = param_a
        else:
            z = param_b
        # The selected Parameter is consumed by a chain of ops after the
        # branch join, making the switch output a non-tail use.
        z = z + 1
        z = z - 2
        z = z * 3
        z = z / 4
        return z

    ret1 = foo(Tensor(1), param_a, param_b)
    assert ret1
99
100
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_valuenode():
    """
    Feature: Control flow inline.
    Description: Control flow if where each branch returns a constant
        (value node) rather than a computed tensor.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x):
        if x < 3:
            return 1
        return 2

    ret1 = foo(Tensor(1))
    assert ret1
120
121
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_input():
    """
    Feature: Control flow inline.
    Description: Control flow if where each branch returns one of the graph
        inputs unchanged.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y, z):
        if x < 3:
            return y
        return z

    ret1 = foo(Tensor(1), Tensor(2), Tensor(3))
    assert ret1
141
142
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_value_node_output_in_single_branch():
    """
    Feature: Control flow inline.
    Description: Inline switch node into kernel graph when one branch returns
        a constant Tensor (value node) as part of its output tuple.
    Expectation: Not throw exception.
    """

    @jit
    def BranchReturnTensor(x, y):
        x = x + Tensor(2, mstype.int32)
        y = x + y
        if x < 5:
            # Only this branch emits a constant Tensor in the output tuple.
            return y, Tensor(2, mstype.int32)
        return x, y

    x = Tensor(2, mstype.int32)
    # Repeated calls check graph reuse with the value-node output.
    ret1 = BranchReturnTensor(x, x)
    ret2 = BranchReturnTensor(x, x)
    ret3 = BranchReturnTensor(x, x)
    assert ret1
    assert ret2
    assert ret3
169
170
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_diff_ref_count_in_branch():
    """
    Feature: Control flow inline.
    Description: Inline switch node into kernel graph when the two branches
        use the switch inputs a different number of times (different ref counts).
    Expectation: Not throw exception.
    """

    @jit
    def BranchDiffRefCount(x, y):
        x = x + Tensor(2, mstype.int32)
        y = x + y
        if x < 5:
            # True branch: x and y each used once.
            x = x + 3
            y = x + y
        else:
            # False branch: x and y each used three times, so the memory
            # reference counts differ between branches.
            x = x + 3
            x = x + 4
            x = x + 5
            y = x + y
            y = x + y
            y = x + y
        return x, y

    x = Tensor(2, mstype.int32)
    ret1 = BranchDiffRefCount(x, x)
    # Second call drives execution down the other branch (x + 2 >= 5).
    x = Tensor(4, mstype.int32)
    ret2 = BranchDiffRefCount(x, x)
    assert ret1
    assert ret2
204
205
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_kernel_backoff():
    """
    Feature: Control flow inline.
    Description: Inline switch node into kernel graph when a branch contains a
        kernel (reshape with a mutable shape) that may back off to another device.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y, shape):
        x = x + Tensor(2, mstype.int32)
        if y < 5:
            # reshape with a mutable (runtime) shape argument.
            z = ops.reshape(x, shape)
        else:
            z = x
        return z + 1

    x = Tensor([2, 2, 2, 2, 2, 2], mstype.int32)
    y = Tensor(2, mstype.int32)
    ret1 = foo(x, y, mutable((2, 3)))
    ret2 = foo(x, y, mutable((2, 3)))
    ret3 = foo(x, y, mutable((2, 3)))
    assert ret1[0][0]
    assert ret2[0][0]
    assert ret3[0][0]
234
235
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_update_parameter():
    """
    Feature: Control flow inline.
    Description: Control flow if where both branches update the same Parameter
        and the updated Parameter is returned.
    Expectation: Not throw exception.
    """

    param_a = Parameter(Tensor(5))

    @jit
    def foo(x, param_a):
        x = x + param_a
        if x < 3:
            param_a = param_a + 2
        else:
            param_a = param_a + x
        return param_a

    # Repeated calls: each run observes the Parameter value left by the last.
    ret1 = foo(Tensor(1), param_a)
    ret2 = foo(Tensor(1), param_a)
    ret3 = foo(Tensor(1), param_a)
    assert ret1
    assert ret2
    assert ret3
264
265
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_update_and_return_parameter():
    """
    Feature: Control flow inline.
    Description: Control flow if where branches both update Parameters and
        return them (true branch returns early from inside the if).
    Expectation: Not throw exception.
    """

    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(5))

    @jit
    def foo(x, param_a, param_b):
        x = x + param_a
        if x < 3:
            param_a = param_a + 2
            param_b = param_b - param_a
            # Early return mixes a constant with an updated Parameter.
            return Tensor(2), param_b
        param_a = param_a + x
        param_b = param_b + param_a
        return param_a, param_b

    ret1 = foo(Tensor(1), param_a, param_b)
    ret2 = foo(Tensor(1), param_a, param_b)
    ret3 = foo(Tensor(1), param_a, param_b)
    assert ret1
    assert ret2
    assert ret3
297
298
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_switch_input_in_branch():
    """
    Feature: Control flow inline.
    Description: Control flow if where the true branch returns the value that
        also fed the switch condition (x) alongside an updated Parameter.
    Expectation: Not throw exception.
    """

    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(5))

    @jit
    def foo(x, param_a, param_b):
        x = x + param_a
        if x < 3:
            param_a = param_a + 2
            param_b = param_b - param_a
            # x is both the condition input and a branch output here.
            return x, param_b
        param_a = param_a + x
        param_b = param_b + param_a
        return param_a, param_b

    ret1 = foo(Tensor(1), param_a, param_b)
    ret2 = foo(Tensor(1), param_a, param_b)
    ret3 = foo(Tensor(1), param_a, param_b)
    assert ret1
    assert ret2
    assert ret3
330
331
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_switch_input():
    """
    Feature: Control flow inline.
    Description: Control flow if where the final return (after the join)
        includes the switch input x, a Parameter, and a constant.
    Expectation: Not throw exception.
    """

    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(5))

    @jit
    def foo(x, param_a, param_b):
        x = x + param_a
        if x < 3:
            param_a = param_a + 2
            param_b = param_b - param_a
        else:
            param_a = param_a + x
            param_b = param_b + param_a
        # Output mixes switch input, updated Parameter, and a value node.
        return x, param_b, 3

    ret1 = foo(Tensor(1), param_a, param_b)
    ret2 = foo(Tensor(1), param_a, param_b)
    ret3 = foo(Tensor(1), param_a, param_b)
    assert ret1
    assert ret2
    assert ret3
363
364
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tuple_args_to_dynamic_tuple_para():
    """
    Feature: Control flow inline.
    Description: Control flow if whose branches manipulate a tuple (the shape
        of y) and return an element of the resulting tuple.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y):
        y_shape = ops.shape(y)
        if x < 3:
            # Tuple repetition, not elementwise arithmetic.
            y_shape = y_shape * 2
        else:
            y_shape = y_shape * 3
        return y_shape[0]

    ret1 = foo(Tensor(1), Tensor([[6, 6, 6], [6, 6, 6]]))
    ret2 = foo(Tensor(1), Tensor([[6, 6, 6], [6, 6, 6]]))
    ret3 = foo(Tensor(1), Tensor([[6, 6, 6], [6, 6, 6]]))
    assert ret1
    assert ret2
    assert ret3
391
392
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tuple_input_to_switch():
    """
    Feature: Control flow inline.
    Description: Control flow if fed with a tuple derived from a dynamic-shape
        op (unique + reshape), returning the tuple after branch-specific repetition.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y, dst_shape):
        # unique produces a dynamic-shape tensor before the branch.
        y, _ = ops.unique(y)
        y = ops.reshape(y, dst_shape)
        y_shape = ops.shape(y)
        if x < 3:
            y_shape = y_shape * 2
        else:
            y_shape = y_shape * 3
        return y_shape

    ret1 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]), mutable((2, 3)))
    ret2 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]), mutable((2, 3)))
    ret3 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]), mutable((2, 3)))
    assert ret1[0]
    assert ret2[0]
    assert ret3[0]
421
422
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_tuple_input_to_switch():
    """
    Feature: Control flow inline.
    Description: Control flow if whose input is a dynamic-length tuple
        (mutable(..., dynamic_len=True)).
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, dyn_tuple):
        if x < 3:
            dyn_tuple = dyn_tuple * 2
        else:
            dyn_tuple = dyn_tuple * 3
        return dyn_tuple

    ret1 = foo(Tensor(1), mutable((2, 3), dynamic_len=True))
    ret2 = foo(Tensor(1), mutable((2, 3), dynamic_len=True))
    ret3 = foo(Tensor(1), mutable((2, 3), dynamic_len=True))
    assert ret1
    assert ret2
    assert ret3
448
449
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_condition():
    """
    Feature: Control flow inline.
    Description: Control flow if where both branches return the boolean
        condition tensor itself as part of the output.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, cond):
        if cond:
            x = x * 2
            return x, cond
        x = x * 3
        return x, cond

    ret1 = foo(Tensor(1), Tensor(True))
    ret2 = foo(Tensor(1), Tensor(True))
    ret3 = foo(Tensor(1), Tensor(True))
    assert ret1
    assert ret2
    assert ret3
475
476
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_include_other_output():
    """
    Feature: Control flow inline.
    Description: Control flow if whose final output also includes a value (y)
        computed entirely outside the branches.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y):
        # y is computed before, and independently of, the branch on x.
        y = y + 2
        y = y * 3
        y = y / 4
        y = y - 5
        y = y * y
        if x < 5:
            x = x * 2
        else:
            x = x + 2
        return x, y

    ret1 = foo(Tensor(1), Tensor(2))
    ret2 = foo(Tensor(1), Tensor(2))
    ret3 = foo(Tensor(1), Tensor(2))
    assert ret1
    assert ret2
    assert ret3
507
508
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_output_include_refnode_with_dynamic_shape():
    """
    Feature: Control flow inline.
    Description: Control flow if whose branch output comes from view-like ops
        (expand_dims/flatten) applied to a dynamic-shape tensor.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y, dst_shape):
        # unique makes y dynamic-shape before entering the branch.
        y, _ = ops.unique(y)
        y = ops.reshape(y, dst_shape)
        if x < 3:
            y = ops.expand_dims(y, 1)
            y = ops.flatten(y)
        return y

    ret1 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36], [6, 18, 36]]), mutable((2, 3)))
    ret2 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]), mutable((2, 3)))
    ret3 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]), mutable((2, 3)))
    assert ret1[0][0]
    assert ret2[0][0]
    assert ret3[0][0]
535
536
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_output_include_refnode_true():
    """
    Feature: Control flow inline.
    Description: Control flow if taking the true branch, whose output passes
        through view-like ops (expand_dims/flatten) plus an add.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y):
        if x < 3:
            y = ops.expand_dims(y, 1)
            y = ops.flatten(y)
            y = y + Tensor([[6, 12], [18, 24], [30, 36]])
        return y

    ret1 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    ret2 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    ret3 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    assert ret1.shape
    assert ret2.shape
    assert ret3.shape
562
563
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_output_include_refnode_false():
    """
    Feature: Control flow inline.
    Description: Control flow if taking the false branch (x = 1, condition
        x > 3), where only the untaken branch contains view-like ops.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y):
        if x > 3:
            y = ops.expand_dims(y, 1)
            y = ops.flatten(y)
            y = y + Tensor([[6, 12], [18, 24], [30, 36]])
        else:
            z = y + Tensor([[36, 30], [24, 18], [12, 6]])
            y = y + Tensor([[36, 30], [24, 18], [12, 36]])
            y = z + y
        return y * 2

    ret1 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    ret2 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    ret3 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    assert ret1.shape
    assert ret2.shape
    assert ret3.shape
593
594
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_output_include_refnode_output_ref():
    """
    Feature: Control flow inline.
    Description: Control flow if where the true branch's output is directly
        the result of view-like ops (no trailing add inside the branch).
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y):
        if x > 3:
            # Branch output is the ref-node result itself.
            y = ops.expand_dims(y, 1)
            y = ops.flatten(y)
        else:
            z = y + Tensor([[36, 30], [24, 18], [12, 6]])
            y = y + Tensor([[36, 30], [24, 18], [12, 36]])
            y = z + y
        return y * 2

    ret1 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    ret2 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    ret3 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    assert ret1.shape
    assert ret2.shape
    assert ret3.shape
623
624
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_output_include_refnode_twice():
    """
    Feature: Control flow inline.
    Description: Control flow if where one branch applies two different
        view-like ops (flatten and reshape) to the same intermediate tensor.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y):
        if x > 3:
            y = ops.expand_dims(y, 1)
            # Two ref-node consumers of the same expand_dims result.
            z1 = ops.flatten(y)
            z2 = ops.reshape(y, (3, 2))
            z3 = z2 * 2
            z4 = z2 * 3
            y = z1 + z2 + z3 + z4
        else:
            z = y + Tensor([[36, 30], [24, 18], [12, 6]])
            y = y + Tensor([[36, 30], [24, 18], [12, 36]])
            y = z + y
        return y * 2

    ret1 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    ret2 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    ret3 = foo(Tensor(1), Tensor([[6, 12], [18, 24], [30, 36]]))
    assert ret1.shape
    assert ret2.shape
    assert ret3.shape
657
658
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_include_dynamic_shape():
    """
    Feature: Control flow inline.
    Description: Control flow if over a dynamic-shape tensor (from unique),
        exercised with inputs whose unique element counts differ per call.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y):
        y, _ = ops.unique(y)
        if x < 3:
            y = y * 2
        else:
            z1 = y / 6
            z2 = y * 2
            z3 = y - Tensor([[6, 12, 18], [24, 30, 36]])
            z4 = y + Tensor([[1, 2, 3], [4, 5, 6]])
            y = z1 + z2 + z3 + z4
        return y

    # Different duplicate patterns -> different dynamic shapes after unique.
    ret1 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36], [6, 18, 36]]))
    ret2 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36], [12, 18, 30], [18, 24, 36]]))
    ret3 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]))
    assert ret1[0]
    assert ret2[0]
    assert ret3[0]
689
690
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_control_arrow_from_switch_to_gather():
    """
    Feature: Control flow inline.
    Description: Control flow if where the fall-through path computes a value
        (x) that is not returned, exercising the control arrow from the switch
        to the gather node.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(5))

    @jit
    def foo(x, param_a, param_b):
        x = x + param_a
        if x < 3:
            param_a = param_a + 2
            param_b = param_b - param_a
            return Tensor(2), param_b
        # x is recomputed but never returned on this path.
        x = x + param_a
        return param_a, param_b

    ret1 = foo(Tensor(1), param_a, param_b)
    ret2 = foo(Tensor(1), param_a, param_b)
    ret3 = foo(Tensor(1), param_a, param_b)
    assert ret1
    assert ret2
    assert ret3
720
721
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_only_u_input():
    """
    Feature: Control flow inline.
    Description: Control flow if where the true branch contains only a
        side-effect op (print), i.e. only a monad (U) input, while the false
        branch also reshapes y.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y):
        x = x + 1
        if x < 3:
            ops.print("this is true")
        else:
            y = ops.reshape(y, (4, 1))
            ops.print("this is false")
        return ops.shape(y)

    ret1 = foo(Tensor(1), Tensor([[1, 2], [3, 4]]))
    assert ret1
745
746
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_u_input_and_input():
    """
    Feature: Control flow inline.
    Description: Control flow if with both a monad (U) input and a data input
        to the branch.
    Expectation: Not throw exception.
    """
    # NOTE(review): this body is byte-identical to test_branch_only_u_input;
    # presumably it was meant to differ (e.g. use y in the true branch) to
    # cover the "U input and data input" case — confirm with the author.

    @jit
    def foo(x, y):
        x = x + 1
        if x < 3:
            ops.print("this is true")
        else:
            y = ops.reshape(y, (4, 1))
            ops.print("this is false")
        return ops.shape(y)

    ret1 = foo(Tensor(1), Tensor([[1, 2], [3, 4]]))
    assert ret1
770
771
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_output_real_tuple():
    """
    Feature: Control flow inline.
    Description: Control flow if whose branches each output a real tuple
        (a shape), with the true branch's shape being dynamic (after unique).
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y):
        if x < 3:
            y, _ = ops.unique(y)
            y = ops.expand_dims(y, 1)
            y = ops.flatten(y)
            z = ops.shape(y)
        else:
            z = ops.shape(y)
        return z

    # First call takes the true (dynamic-shape) branch, second the false one.
    ret1 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36], [6, 18, 36]]))
    ret2 = foo(Tensor(5), Tensor([[6, 12, 18], [24, 30, 36]]))
    assert ret1
    assert ret2
799
800
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_output_dynamic_tuple():
    """
    Feature: Control flow inline.
    Description: Control flow if whose branch output is a dynamic-length
        tuple (shape of a tensor reshaped with a dynamic-length shape arg).
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y, shape):
        if y < 5:
            z = ops.reshape(x, shape)
            out = ops.shape(z)
        else:
            out = ops.shape(x)
        return out

    x = Tensor([2, 2, 2, 2, 2, 2], mstype.int32)
    y = Tensor(2, mstype.int32)
    ret1 = foo(x, y, mutable((2, 3), dynamic_len=True))
    assert ret1[0]
825
826
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_if():
    """
    Feature: Control flow inline.
    Description: Inline switch nodes for two sequential ifs, the second
        depending on Parameters updated by the first.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5, mstype.int32), name='a')
    param_b = Parameter(Tensor(4, mstype.int32), name='b')

    @jit
    def foo(x, y, param_a, param_b):
        if param_a > param_b:
            param_b += 1
        # Second if reads the param_b possibly updated above.
        if param_a + param_b > 10:
            param_a += 3
        return x + param_b, y + param_b

    x = Tensor(2, mstype.int32)
    ret1 = foo(x, x, param_a, param_b)
    ret2 = foo(x, x, param_a, param_b)
    assert ret1 == (Tensor(7, mstype.int32), Tensor(7, mstype.int32))
    assert ret2
853
854
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_in_if():
    """
    Feature: Control flow inline.
    Description: Inline switch nodes for a nested if inside the body of
        another if.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5, mstype.int32), name='a')
    param_b = Parameter(Tensor(4, mstype.int32), name='b')

    @jit
    def foo(x, y, param_a, param_b):
        if param_a > param_b:
            param_b += 1
            # Inner if only reached via the outer true branch.
            if param_a + param_b > 10:
                param_a += 3
        return x + param_b, y + param_b

    x = Tensor(2, mstype.int32)
    ret1 = foo(x, x, param_a, param_b)
    ret2 = foo(x, x, param_a, param_b)
    assert ret1 == (Tensor(7, mstype.int32), Tensor(7, mstype.int32))
    assert ret2
881
882
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_output_ref_of_parameter():
    """
    Feature: Control flow inline.
    Description: Inline switch node where one branch returns a plain compute
        result and the other returns the result of assigning to a Parameter
        (a ref output).
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5, mstype.int32), name='a')

    @jit
    def foo(x, y, param_a):
        if x > y:
            out = ops.addn([x, x, param_a])
        else:
            # assign returns a ref to param_a.
            out = ops.assign(param_a, x)
        return out

    x = Tensor(2, mstype.int32)
    y = Tensor(1, mstype.int32)
    # First call takes the false branch (x == x), second the true one (x > y).
    ret1 = foo(x, x, param_a)
    ret2 = foo(x, y, param_a)
    assert ret1
    assert ret2
909
910
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_gather_switch_gather_output():
    """
    Feature: Control flow inline.
    Description: Two sequential ifs on the same condition, where the second
        if's output may override the first gather's output with an assign ref.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5, mstype.int32), name='a')

    @jit
    def foo(x, y, param_a):
        if x > y:
            out = param_a
        else:
            out = ops.addn([x, x, x])
        if x > y:
            out = ops.assign(param_a, x)
        return out

    # x == y, so both conditions are false.
    x = Tensor(1, mstype.int32)
    y = Tensor(1, mstype.int32)
    ret1 = foo(x, y, param_a)
    assert ret1
937
938
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_in_if_directly():
    """
    Feature: Control flow inline.
    Description: Inline switch nodes for an if nested directly at the start of
        another if's body, with computation after the inner if.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5, mstype.int32), name='a')
    param_b = Parameter(Tensor(4, mstype.int32), name='b')

    @jit
    def foo(x, y, param_a, param_b):
        x = x + 2
        if param_a > param_b:
            if x > y:
                x += 3
            x = x + param_a
        y = x + y
        return y

    x = Tensor(2, mstype.int32)
    ret1 = foo(x, x, param_a, param_b)
    ret2 = foo(x, x, param_a, param_b)
    assert ret1
    assert ret2
967
968
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_lazy_inline():
    """
    Feature: Switch inline with lazy inline.
    Description: All inline in single graph: a @lazy_inline block wrapped in a
        grad net whose outer construct contains a control-flow if.
    Expectation: Run successfully and the memory usage is reduced.
    """
    class Grad(Cell):
        # Wraps a net with GradOperation and differentiates it w.r.t. x.
        def __init__(self, net):
            super(Grad, self).__init__()
            self.grad = ops.GradOperation()
            self.net = net

        def construct(self, x):
            grad_net = self.grad(self.net)
            return grad_net(x)

    class Block(Cell):
        # Inner compute unit: batch matmul plus a broadcast-added Parameter.
        def __init__(self):
            super(Block, self).__init__()
            self.batch_matmul = P.BatchMatMul()
            self.expand_dims = P.ExpandDims()
            self.y = Parameter(Tensor(np.ones((8)).astype(np.float32)))

        def construct(self, x):
            z1 = self.batch_matmul(x, x)
            z2 = self.expand_dims(self.y, 1)
            return z1 + z2

    class BaseBlock(Cell):
        # @lazy_inline keeps this cell as a reusable subgraph until inlining.
        @lazy_inline
        def __init__(self):
            super(BaseBlock, self).__init__()
            self.block = Block()

        def construct(self, x):
            return self.block(x)

    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.blocks = nn.CellList()
            b = BaseBlock()
            self.blocks.append(b)

        def construct(self, x):
            out = x
            for i in range(1):
                out = self.blocks[i](out)
            return out

    class GradNet(Cell):
        # Outer net: gradient of Net combined with a control-flow if on y.
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.grad_net = Grad(net)
            self.a = Parameter(Tensor(np.ones((8)).astype(np.float32)))
            self.b = Parameter(Tensor(np.ones((8)).astype(np.float32)))

        def construct(self, x, y):
            out = self.grad_net(x)
            if y > 3:
                return out * 2, self.a
            return out, self.b

    x = Tensor(np.ones((8, 8)).astype(np.float32))
    y = Tensor(6)
    net = Net()
    grad_net = GradNet(net)
    # Run twice to exercise compiled-graph reuse.
    grad_net(x, y)
    grad_net(x, y)
1041
1042
class TupleParaNet(Cell):
    """Net with a tuple input: returns first + last element when the tuple
    has at least two elements, otherwise the single element.
    """

    def __init__(self):
        super(TupleParaNet, self).__init__()
        self.add = ops.Add()

    def construct(self, paralist):
        # Bug fix: the original computed `len(list)` — `len` of the builtin
        # type object instead of the input tuple — which raises a TypeError.
        length = len(paralist)
        if length >= 2:
            x1 = paralist[0]
            x2 = paralist[length - 1]
            return self.add(x1, x2)
        return paralist[0]
1054
1055
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tuple_parameter():
    """
    Feature: Control flow inline.
    Description: Tuple parameter: export with a dynamic-length tuple input,
        then load the MindIR and run with a fixed-length tuple.
    Expectation: Not throw exception.
    """
    context.set_context(mode=context.GRAPH_MODE, jit_config={"jit_level": "O0"})
    net = TupleParaNet()
    # Export against a dynamic-length tuple so the MindIR accepts any length.
    input_2_ele = mutable((2, 3), dynamic_len=True)
    export(net, input_2_ele, file_name="test.mindir", file_format="MINDIR")
    # Load and run with a concrete 3-element tuple.
    input_3_ele = mutable((2, 2, 3), dynamic_len=False)
    y = load("test.mindir")
    mindir_load = GraphCell(y)
    print(mindir_load(input_3_ele))
1074
1075
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_call_same_graph():
    """
    Feature: Control flow inline.
    Description: Two call nodes call the same graph: the if body inside an
        unrolled loop produces multiple calls to one subgraph.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5, mstype.float32), name='a')

    @jit
    def foo(x, y, param_a):
        out = Tensor(1, mstype.float32)
        # The loop is unrolled at compile time, so the if body becomes the
        # same subgraph called from two sites.
        for i in range(0, 2):
            if x + i < y:
                out += param_a
                break
        return out

    x = Tensor(2, mstype.int32)
    ret = foo(x, x, param_a)
    assert ret
1100