# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_multigraph_sink

Control-flow tests for ``@ms_function``-decorated functions: a simple ``if``,
two sequential ``if`` statements (``if_by_if``), nested ``if`` (``if_in_if``),
and the matching ``while`` variants, including two sequential loops nested in
an outer loop (``while_by_while_in_while``). Every test calls its function with
the module constants ``(c1, c2, c3)`` = ``(2, 14, 1)`` and compares the result
with a hand-computed single-element int32 tensor.
"""
import mindspore.context as context
from mindspore.common import dtype as mstype
from mindspore.common import ms_function
from mindspore.common.tensor import Tensor


def setup_module(module):
    """Configure the MindSpore context once for this module.

    Presumably invoked by pytest as its xunit-style module setup hook.
    Sets the execution mode to ``PYNATIVE_MODE`` and the device target to
    ``"Ascend"``; the ``module`` argument is not used.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")


# Single-element int32 tensor constants shared by the functions and tests.
c1 = Tensor([2], mstype.int32)   # x argument in every test
c2 = Tensor([14], mstype.int32)  # y argument in every test; loop bound below
c3 = Tensor([1], mstype.int32)   # z argument in every test
c4 = Tensor([0], mstype.int32)   # zero: initial value for accumulators/counters
c5 = Tensor([14], mstype.int32)  # bound of the second loop in while_by_while

# The function names describe the shape of each body: "by" means two
# constructs in sequence, "in" means one construct nested in the other.
# NOTE(review): documentation inside the @ms_function bodies is kept as `#`
# comments rather than docstrings so no statements are added to them.


@ms_function
def simple_if(x, y, z):
    # One if/else followed by a statement after the branches join.
    # ``z`` is unused.
    # With (2, 14, 1): 2 < 14 -> x = 3, then x = 3 + 3 = 6.
    if x < y:
        x = x + 1
    else:
        x = x + 2
    x = x + 3
    return x


@ms_function
def if_by_if(x, y, z):
    # Two independent ifs, one after the other. ``z`` is unused.
    # With (2, 14, 1): x = 3, then 14 > 3 -> x = 5, then x = 8.
    if x < y:
        x = x + 1
    if y > x:
        x = x + 2
    x = x + 3
    return x


@ms_function
def if_in_if(x, y, z):
    # An if nested inside another if; ``z`` is overwritten before it is read.
    # With (2, 14, 1): z = 0 -> 2, out = 2, x = 5, out = 2 + 5 = 7.
    out = c4
    if x < y:
        z = c4 + c4
        if z < y:
            z = z + 2
            out = out + z
        x = x + 3
    out = out + x
    return out


@ms_function
def simple_while(x, y, z):
    # A single loop with a statement after it. ``z`` is unused.
    # With (2, 14, 1): bound y = 18, x counts up to 18, then x = 21.
    y = y + 4
    while x < y:
        x = x + 1
    x = x + 3
    return x


@ms_function
def while_by_while(x, y, z):
    # Two sequential loops: x counts up to y, then z counts up to c5 and
    # x is bumped on each of those steps too.
    # With (2, 14, 1): x -> 14; z goes 1 -> 14 (13 steps) so x -> 27; x = 28.
    while x < y:
        x = x + 1
    while z < c5:
        z = z + 1
        x = x + 1
    x = x + 1
    return x


@ms_function
def while_in_while(x, y, z):
    # A loop nested in a loop; ``z`` is reset to 0 on every outer iteration
    # and the inner loop accumulates each new value of z into ``out``.
    # With (2, 14, 1): each of the 12 outer iterations adds 1 + ... + 14 = 105
    # and x ends at 14, so out = 12 * 105 + 14 = 1274.
    out = c4
    while x < y:
        z = c4 + c4
        while z < y:
            z = z + 1
            out = out + z
        x = x + 1
    out = out + x
    return out


def test_simple_if():
    """simple_if(2, 14, 1) is expected to return [6]."""
    output = simple_if(c1, c2, c3)
    expect = Tensor([6], mstype.int32)
    # NOTE(review): `==` on two Tensors presumably yields an element-wise
    # result; with a single element its truth value decides the assert.
    assert output == expect


def test_if_by_if():
    """if_by_if(2, 14, 1) is expected to return [8]."""
    output = if_by_if(c1, c2, c3)
    expect = Tensor([8], mstype.int32)
    assert output == expect


def test_if_in_if():
    """if_in_if(2, 14, 1) is expected to return [7]."""
    output = if_in_if(c1, c2, c3)
    expect = Tensor([7], mstype.int32)
    assert output == expect


def test_simple_while():
    """simple_while(2, 14, 1) is expected to return [21]."""
    output = simple_while(c1, c2, c3)
    expect = Tensor([21], mstype.int32)
    assert output == expect


def test_while_by_while():
    """while_by_while(2, 14, 1) is expected to return [28]."""
    output = while_by_while(c1, c2, c3)
    expect = Tensor([28], mstype.int32)
    assert output == expect


def test_while_in_while():
    """while_in_while(2, 14, 1) is expected to return [1274]."""
    output = while_in_while(c1, c2, c3)
    expect = Tensor([1274], mstype.int32)
    assert output == expect


@ms_function
def while_by_while_in_while(x, y, z):
    # Two sequential inner loops inside an outer loop. ``y`` and ``z`` are
    # reset to 0 on every outer iteration and both count up to c2 (14); only
    # their final values are added to ``out``.
    # With x = 2: each of the 12 outer iterations adds 14 + 14 = 28 and x ends
    # at 14, so out = 12 * 28 + 14 = 350.
    out = c4
    while x < c2:
        y = c4 + c4
        while y < c2:
            y = y + 1
        out = out + y
        z = c4 + c4
        while z < c2:
            z = z + 1
        out = out + z
        x = x + 1
    out = out + x
    return out


def test_while_by_while_in_while():
    """while_by_while_in_while(2, 14, 1) is expected to return [350]."""
    output = while_by_while_in_while(c1, c2, c3)
    expect = Tensor([350], mstype.int32)
    assert output == expect