# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_multigraph_sink

Graph-mode (Ascend) tests for control flow inside ``ms_function``: a single
``if``/``while``, two of them in sequence ("by"), and nested forms ("in").
Each test calls one compiled function with the module-level int32 tensors
below and compares the result with a hand-computed single-element tensor.
"""
import pytest

import mindspore.context as context
from mindspore.common import dtype as mstype
from mindspore.common import ms_function
from mindspore.common.tensor import Tensor


def setup_module():
    """pytest module-level hook: run every test below in graph mode on Ascend."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


# Shared single-element int32 operands. The tests pass c1 (2), c2 (14) and
# c3 (1) as arguments; inside the compiled functions c4 (0) is used as a zero
# seed and c2 / c5 (both 14) serve as fixed loop bounds.
c1 = Tensor([2], mstype.int32)
c2 = Tensor([14], mstype.int32)
c3 = Tensor([1], mstype.int32)
c4 = Tensor([0], mstype.int32)
c5 = Tensor([14], mstype.int32)


# The @ms_function helpers below are described with comments rather than
# docstrings so the function bodies handed to the graph compiler stay exactly
# as written; their statement layout is the control-flow shape under test.

# Single if/else followed by a join. For (x=2, y=14) the true branch runs:
# 2 + 1 + 3 = 6.
@ms_function
def simple_if(x, y):
    if x < y:
        x = x + 1
    else:
        x = x + 2
    x = x + 3
    return x


# Two independent ifs in sequence. For (2, 14): 2 -> 3 (first if)
# -> 5 (second if, since 14 > 3) -> 8.
@ms_function
def if_by_if(x, y):
    if x < y:
        x = x + 1
    if y > x:
        x = x + 2
    x = x + 3
    return x


# An if nested inside an if. z is reset to 0 (c4 + c4) inside the outer
# branch, so the incoming z is not used on that path.
# For (2, 14, 1): out = 0 + 2 (z) + 5 (x) = 7.
@ms_function
def if_in_if(x, y, z):
    out = c4
    if x < y:
        z = c4 + c4
        if z < y:
            z = z + 2
        out = out + z
    x = x + 3
    out = out + x
    return out


# Single while loop. y is raised by 4 first, so for (2, 14) x counts up to 18
# and the result is 18 + 3 = 21.
@ms_function
def simple_while(x, y):
    y = y + 4
    while x < y:
        x = x + 1
    x = x + 3
    return x


# Two while loops in sequence. The second loop is bounded by c5 (14), not y,
# and advances x once per step of z. For (2, 14, 1): the first loop brings x
# to 14, the second adds 13 (z: 1 -> 14), then +1 -> 28.
@ms_function
def while_by_while(x, y, z):
    while x < y:
        x = x + 1
    while z < c5:
        z = z + 1
        x = x + 1
    x = x + 1
    return x


# A while nested inside a while. The inner loop restarts z at 0 on every outer
# iteration (the incoming z is overwritten) and accumulates each z into out.
# For (2, 14, 1): 12 outer iterations * (1 + ... + 14 = 105) = 1260, plus the
# final x (14) -> 1274.
@ms_function
def while_in_while(x, y, z):
    out = c4
    while x < y:
        z = c4 + c4
        while z < y:
            z = z + 1
            out = out + z
        x = x + 1
    out = out + x
    return out


# Two sequential while loops inside an outer while, all bounded by c2 (14).
# The parameters y and z are overwritten before use; only their final values
# (14 each) are added to out once per outer iteration.
# For x = 2: 12 * (14 + 14) = 336, plus the final x (14) -> 350.
@ms_function
def while_by_while_in_while(x, y, z):
    out = c4
    while x < c2:
        y = c4 + c4
        while y < c2:
            y = y + 1
        out = out + y
        z = c4 + c4
        while z < c2:
            z = z + 1
        out = out + z
        x = x + 1
    out = out + x
    return out


# Three levels of nested while loops, all bounded by c2 (14). Per outer
# iteration, the middle loop runs 14 times and each time adds the innermost
# loop's final z (14); then the middle loop's final y (14) is added:
# 14 * 14 + 14 = 210. For x = 2: 12 * 210 = 2520, plus the final x (14) -> 2534.
@ms_function
def while_in_while_in_while(x, y, z):
    out = c4
    while x < c2:
        y = c4 + c4
        while y < c2:
            y = y + 1
            z = c4 + c4
            while z < c2:
                z = z + 1
            out = out + z
        out = out + y
        x = x + 1
    out = out + x
    return out


@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_simple_if():
    """simple_if(2, 14) takes the true branch and returns 6."""
    output = simple_if(c1, c2)
    expect = Tensor([6], mstype.int32)
    assert output == expect


# NOTE(review): unlike every other test in this module, this one carries no
# level/platform/env markers, so marker-based test selection will not pick it
# up. Confirm whether that omission is intentional.
def test_if_by_if():
    """if_by_if(2, 14) runs both ifs and returns 8."""
    output = if_by_if(c1, c2)
    expect = Tensor([8], mstype.int32)
    assert output == expect


@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_if_in_if():
    """if_in_if(2, 14, 1) takes both nested branches and returns 7."""
    output = if_in_if(c1, c2, c3)
    expect = Tensor([7], mstype.int32)
    assert output == expect


@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_simple_while():
    """simple_while(2, 14) counts x up to 18, then adds 3: 21."""
    output = simple_while(c1, c2)
    expect = Tensor([21], mstype.int32)
    assert output == expect


@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_while_by_while():
    """while_by_while(2, 14, 1) runs two sequential loops and returns 28."""
    output = while_by_while(c1, c2, c3)
    expect = Tensor([28], mstype.int32)
    assert output == expect


@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_while_in_while():
    """while_in_while(2, 14, 1) runs a doubly nested loop and returns 1274."""
    output = while_in_while(c1, c2, c3)
    expect = Tensor([1274], mstype.int32)
    assert output == expect


@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_while_by_while_in_while():
    """while_by_while_in_while(2, 14, 1) returns 350."""
    output = while_by_while_in_while(c1, c2, c3)
    expect = Tensor([350], mstype.int32)
    assert output == expect


@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_while_in_while_in_while():
    """while_in_while_in_while(2, 14, 1) runs a triply nested loop: 2534."""
    output = while_in_while_in_while(c1, c2, c3)
    expect = Tensor([2534], mstype.int32)
    assert output == expect