• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test_multigraph_sink """
16import pytest
17
18import mindspore.context as context
19from mindspore.common import dtype as mstype
20from mindspore.common import ms_function
21from mindspore.common.tensor import Tensor
22
23
def setup_module():
    """Configure the MindSpore context before the tests in this module run.

    Selects graph mode and the "Ascend" device target.
    """
    context.set_context(
        mode=context.GRAPH_MODE,
        device_target="Ascend",
    )
26
27
# Shared single-element int32 tensors used as inputs and bounds by the
# control-flow functions and tests below.
c1 = Tensor([2], dtype=mstype.int32)
c2 = Tensor([14], dtype=mstype.int32)
c3 = Tensor([1], dtype=mstype.int32)
c4 = Tensor([0], dtype=mstype.int32)
# Same value as c2, but kept as a separate tensor object.
c5 = Tensor([14], dtype=mstype.int32)
33
34
@ms_function
def simple_if(x, y):
    """Single if/else followed by an unconditional add.

    Adds 1 to ``x`` when ``x < y``, otherwise adds 2; then adds 3 and
    returns the result.
    """
    if x < y:
        x += 1
    else:
        x += 2
    x += 3
    return x
43
44
@ms_function
def if_by_if(x, y):
    """Two consecutive, independent ``if`` statements.

    Adds 1 to ``x`` when ``x < y``; then, using the possibly updated ``x``,
    adds 2 when ``y > x``; finally adds 3 and returns the result.
    """
    if x < y:
        x += 1
    if y > x:
        x += 2
    x += 3
    return x
53
54
@ms_function
def if_in_if(x, y, z):
    """An ``if`` nested inside another ``if``.

    Starts an accumulator at the module-level ``c4``. When ``x < y``, ``z``
    is reset to ``c4 + c4``; if that ``z`` is below ``y`` it is bumped by 2
    and added to the accumulator, and ``x`` is increased by 3 either way.
    The final ``x`` is then added to the accumulator, which is returned.

    The incoming value of ``z`` is never read: it is only used after being
    reassigned inside the outer branch.
    """
    out = c4
    if x < y:
        z = c4 + c4
        if z < y:
            z += 2
            out += z
        x += 3
    out += x
    return out
66
67
@ms_function
def simple_while(x, y):
    """Single ``while`` loop.

    Raises the bound ``y`` by 4, increments ``x`` by 1 while ``x < y``,
    then adds 3 to ``x`` and returns it.
    """
    y += 4
    while x < y:
        x += 1
    x += 3
    return x
75
76
@ms_function
def while_by_while(x, y, z):
    """Two consecutive ``while`` loops.

    The first loop increments ``x`` while ``x < y``. The second increments
    both ``z`` and ``x`` while ``z`` is below the module-level ``c5``.
    One more is added to ``x`` before it is returned.
    """
    while x < y:
        x += 1
    while z < c5:
        z += 1
        x += 1
    x += 1
    return x
86
87
@ms_function
def while_in_while(x, y, z):
    """A ``while`` loop nested inside another ``while`` loop.

    For every outer iteration (``x < y``), ``z`` restarts from ``c4 + c4``
    and the inner loop increments it while ``z < y``, adding each new ``z``
    to the accumulator. ``x`` advances by 1 per outer iteration. After the
    outer loop the final ``x`` is added to the accumulator, which is
    returned.

    The incoming value of ``z`` is never read; it is reassigned before use.
    """
    out = c4
    while x < y:
        z = c4 + c4
        while z < y:
            z += 1
            out += z
        x += 1
    out += x
    return out
99
100
@ms_function
def while_by_while_in_while(x, y, z):
    """Two sibling ``while`` loops nested inside an outer ``while`` loop.

    While ``x`` is below the module-level ``c2``: ``y`` restarts from
    ``c4 + c4`` and is incremented until it reaches ``c2``, then added to
    the accumulator; the same is done with ``z``; then ``x`` advances by 1.
    After the outer loop the final ``x`` is added to the accumulator, which
    is returned.

    The incoming values of ``y`` and ``z`` are never read; both are
    reassigned before use.
    """
    out = c4
    while x < c2:
        # First inner loop: run y up to c2 and accumulate its final value.
        y = c4 + c4
        while y < c2:
            y += 1
        out += y
        # Second inner loop: same pattern with z.
        z = c4 + c4
        while z < c2:
            z += 1
        out += z
        x += 1
    out += x
    return out
116
117
@ms_function
def while_in_while_in_while(x, y, z):
    """Three levels of nested ``while`` loops.

    While ``x`` is below the module-level ``c2``: ``y`` restarts from
    ``c4 + c4``; for each increment of ``y`` (while ``y < c2``), ``z``
    restarts from ``c4 + c4``, is incremented until it reaches ``c2``, and
    its final value is added to the accumulator. After the middle loop the
    final ``y`` is added too, and ``x`` advances by 1. After the outer loop
    the final ``x`` is added to the accumulator, which is returned.

    The incoming values of ``y`` and ``z`` are never read; both are
    reassigned before use.
    """
    out = c4
    while x < c2:
        y = c4 + c4
        while y < c2:
            y += 1
            # Innermost loop: run z up to c2, then accumulate it.
            z = c4 + c4
            while z < c2:
                z += 1
            out += z
        out += y
        x += 1
    out += x
    return out
133
134
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_simple_if():
    """simple_if(2, 14) takes the ``x < y`` branch: 2 + 1 + 3 == 6."""
    expected = Tensor([6], mstype.int32)
    actual = simple_if(c1, c2)
    assert actual == expected
143
144
def test_if_by_if():
    """if_by_if(2, 14) takes both branches: 2 + 1 + 2 + 3 == 8.

    NOTE(review): unlike the other tests in this module, this one carries
    no pytest markers -- confirm whether that is intentional.
    """
    expected = Tensor([8], mstype.int32)
    actual = if_by_if(c1, c2)
    assert actual == expected
149
150
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_if_in_if():
    """if_in_if(2, 14, 1): inner z becomes 2, x becomes 5, so 2 + 5 == 7."""
    expected = Tensor([7], mstype.int32)
    actual = if_in_if(c1, c2, c3)
    assert actual == expected
159
160
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_simple_while():
    """simple_while(2, 14): x climbs to 14 + 4 == 18, then 18 + 3 == 21."""
    expected = Tensor([21], mstype.int32)
    actual = simple_while(c1, c2)
    assert actual == expected
169
170
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_while_by_while():
    """while_by_while(2, 14, 1): x reaches 14, gains 13 more, plus 1 == 28."""
    expected = Tensor([28], mstype.int32)
    actual = while_by_while(c1, c2, c3)
    assert actual == expected
179
180
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_while_in_while():
    """while_in_while(2, 14, 1): 12 outer passes of sum(1..14) plus 14.

    12 * 105 + 14 == 1274.
    """
    expected = Tensor([1274], mstype.int32)
    actual = while_in_while(c1, c2, c3)
    assert actual == expected
189
190
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_while_by_while_in_while():
    """while_by_while_in_while(2, 14, 1): 12 outer passes adding 14 twice.

    12 * (14 + 14) + 14 == 350.
    """
    expected = Tensor([350], mstype.int32)
    actual = while_by_while_in_while(c1, c2, c3)
    assert actual == expected
199
200
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_while_in_while_in_while():
    """while_in_while_in_while(2, 14, 1): 12 outer passes of 14 * 14 + 14.

    12 * (196 + 14) + 14 == 2534.
    """
    expected = Tensor([2534], mstype.int32)
    actual = while_in_while_in_while(c1, c2, c3)
    assert actual == expected
209