• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test_multigraph_sink """
16import mindspore.context as context
17from mindspore.common import dtype as mstype
18from mindspore.common import ms_function
19from mindspore.common.tensor import Tensor
20
21
def setup_module(module):
    """Configure the MindSpore context once, before this module's tests run.

    Selects PyNative mode with device_target "Ascend". The ``module``
    argument (supplied by the test runner's setup_module hook) is not used.
    """
    target = "Ascend"
    context.set_context(mode=context.PYNATIVE_MODE, device_target=target)
24
25
# Single-element int32 tensors shared by the graphs and tests below.
# c4 (0) seeds accumulators and counters; c2 and c5 both hold 14 but are
# distinct Tensor objects (c5 bounds the second loop of while_by_while).
c1, c2, c3, c4, c5 = (
    Tensor([2], mstype.int32),
    Tensor([14], mstype.int32),
    Tensor([1], mstype.int32),
    Tensor([0], mstype.int32),
    Tensor([14], mstype.int32),
)
31
32
# Graph with a single if/else followed by one unconditional statement:
# x < y -> x + 1, otherwise x + 2; then + 3 on either path.
# z is never read. NOTE(review): presumably kept so every graph in this
# module shares the (x, y, z) signature.
@ms_function
def simple_if(x, y, z):
    if x < y:
        x = x + 1
    else:
        x = x + 2
    # Join point after the branch; executed on both paths.
    x = x + 3
    return x
41
42
# Graph with two sequential, independent ifs (neither has an else).
# The second condition reads the x produced by the first, then + 3 is
# added unconditionally. z is never read.
@ms_function
def if_by_if(x, y, z):
    if x < y:
        x = x + 1
    # Depends on the possibly-updated x from the branch above.
    if y > x:
        x = x + 2
    x = x + 3
    return x
51
52
# Graph with an if nested inside an if. `out` starts from the module
# constant c4. Inside the outer branch, z is recomputed as c4 + c4, so the
# z argument is never read. The inner branch adds the bumped z to out.
# The final x is added to out on every path.
@ms_function
def if_in_if(x, y, z):
    out = c4
    if x < y:
        z = c4 + c4
        if z < y:
            z = z + 2
            out = out + z
        x = x + 3
    out = out + x
    return out
64
65
# Graph with a single while loop. The bound y is first raised by 4, then x
# is incremented until it is no longer below y, and finally + 3 is added.
# z is never read.
@ms_function
def simple_while(x, y, z):
    y = y + 4
    while x < y:
        x = x + 1
    x = x + 3
    return x
73
74
# Graph with two sequential while loops. The first advances x up to y.
# The second counts z up to the module constant c5 and adds one to x per
# iteration. A final + 1 is applied after both loops.
@ms_function
def while_by_while(x, y, z):
    while x < y:
        x = x + 1
    # Bounded by the module constant c5, not by an argument.
    while z < c5:
        z = z + 1
        x = x + 1
    x = x + 1
    return x
84
85
# Graph with a while nested inside a while. Each outer iteration resets z
# to c4 + c4 (so the z argument is never read). It then adds every
# intermediate z of the inner loop (bounded by y) to out. The final x is
# added to out once, after the outer loop ends.
@ms_function
def while_in_while(x, y, z):
    out = c4
    while x < y:
        z = c4 + c4
        while z < y:
            z = z + 1
            # Running sum of each inner-loop value of z.
            out = out + z
        x = x + 1
    out = out + x
    return out
97
98
def test_simple_if():
    """simple_if(2, 14, 1): 2 < 14, so 2 + 1, then + 3 -> 6."""
    result = simple_if(c1, c2, c3)
    assert result == Tensor([6], mstype.int32)
103
104
def test_if_by_if():
    """if_by_if(2, 14, 1): 2 -> 3 (2 < 14) -> 5 (14 > 3) -> 8 (+ 3)."""
    result = if_by_if(c1, c2, c3)
    assert result == Tensor([8], mstype.int32)
109
110
def test_if_in_if():
    """if_in_if(2, 14, 1): inner branch gives out = 2; x becomes 5; 2 + 5 = 7."""
    result = if_in_if(c1, c2, c3)
    assert result == Tensor([7], mstype.int32)
115
116
def test_simple_while():
    """simple_while(2, 14, 1): bound becomes 18, x climbs to 18, then + 3 = 21."""
    result = simple_while(c1, c2, c3)
    assert result == Tensor([21], mstype.int32)
121
122
def test_while_by_while():
    """while_by_while(2, 14, 1): x -> 14, then 13 more steps (z: 1 -> 14) -> 27, + 1 = 28."""
    result = while_by_while(c1, c2, c3)
    assert result == Tensor([28], mstype.int32)
127
128
def test_while_in_while():
    """while_in_while(2, 14, 1): 12 outer passes x (1 + ... + 14) = 1260, plus final x 14 = 1274."""
    result = while_in_while(c1, c2, c3)
    assert result == Tensor([1274], mstype.int32)
133
134
# Graph with two sequential while loops nested inside an outer while.
# Every loop is bounded by the module constant c2 rather than the y
# argument. y and z are reset from c4 inside the outer loop, so only the
# x argument is actually read. Each outer iteration adds the final y and
# the final z of its inner loops (not their running sums) to out.
@ms_function
def while_by_while_in_while(x, y, z):
    out = c4
    while x < c2:
        y = c4 + c4
        while y < c2:
            y = y + 1
        # Adds y's value after its inner loop has finished.
        out = out + y
        z = c4 + c4
        while z < c2:
            z = z + 1
        out = out + z
        x = x + 1
    out = out + x
    return out
150
151
def test_while_by_while_in_while():
    """while_by_while_in_while(2, 14, 1): 12 outer passes x (14 + 14) = 336, plus final x 14 = 350."""
    result = while_by_while_in_while(c1, c2, c3)
    assert result == Tensor([350], mstype.int32)
156