# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_cont_break

Checks that networks using ``break`` / ``continue`` / ``pass`` inside ``for``
and ``while`` loops give the same output shape when executed as plain Python
and when compiled in graph mode.
"""
import numpy as np

from mindspore import Tensor, Model, context
from mindspore.nn import Cell
from ...ut_filter import non_graph_engine


def run_test(netclass, count, shape=(2, 3), seed=0):
    """Run ``netclass`` as plain Python and as a compiled graph and compare.

    The global MindSpore context is switched to ``GRAPH_MODE``. Then, for each
    of ``count`` random float32 inputs:

    * ``net.construct`` is called directly on a NumPy array (Python path);
    * ``Model(net).predict`` is called on the equivalent ``Tensor`` (graph
      path).

    Only the output shapes are compared (see the note in the loop).

    Args:
        netclass: zero-argument callable (a ``Cell`` subclass) that builds the
            network under test.
        count: number of random inputs to try.
        shape: shape of each random input; ``(2, 3)`` matches the historical
            behaviour.
        seed: seed for the input generator so that a failure can be reproduced;
            pass ``None`` to draw fresh OS entropy on every call.

    Raises:
        AssertionError: if the Python and graph outputs differ in shape.
    """
    context.set_context(mode=context.GRAPH_MODE)
    net = netclass()
    model = Model(net)
    # Private generator: seeded runs are repeatable and do not disturb the
    # global NumPy RNG state used by other tests.
    rng = np.random.RandomState(seed)
    for _ in range(count):
        input_np = rng.randn(*shape).astype(np.float32)
        input_ms = Tensor(input_np)
        # Give the Python path its own copy. Networks in this file use in-place
        # ops such as ``x *= 3``, and the graph input must not depend on
        # whether Tensor() copies or shares the NumPy buffer.
        output_np = net.construct(input_np.copy())  # run python
        output_ms = model.predict(input_ms)  # run graph
        shape_np = np.shape(output_np)
        shape_ms = np.shape(output_ms.asnumpy())
        assert shape_np == shape_ms, (
            f"{netclass.__name__}: python output shape {shape_np} "
            f"!= graph output shape {shape_ms}"
        )
        # Disable equal assert because UT in CI use fake backend.
        # np.testing.assert_array_almost_equal(output_np, output_ms.asnumpy(), decimal=3)
# pylint: disable=unnecessary-pass
class for_loop_with_break(Cell):
    """``for`` loop that leaves through ``break`` on a late iteration.

    Python-path result: iterations i=0..5 double x (x * 64); at i=6 x is
    tripled and the loop breaks, so ``construct`` returns x * 192.
    """
    def __init__(self):
        super().__init__()

    def construct(self, x):
        for i in range(8):
            if i > 5:
                # x was already rebound by ``x = x * 2`` on earlier
                # iterations, so this in-place multiply acts on an
                # intermediate, not on the caller's input.
                x *= 3
                break
            x = x * 2
            # Deliberate no-op after the assignment (hence the pylint disable);
            # presumably here to exercise parsing of ``pass`` in a loop body.
            pass
        return x


# NOTE(review): ``non_graph_engine`` is imported from ut_filter; it presumably
# restricts the environments these tests run in -- confirm there.
@non_graph_engine
def test_for_loop_with_break():
    """Python vs. graph execution of ``for_loop_with_break`` on 10 inputs."""
    run_test(for_loop_with_break, 10)


class for_loop_with_continue(Cell):
    """``for`` loop whose last iterations end in ``continue``.

    Python-path result: i=0..5 double x (x * 64); i=6 and i=7 each triple x
    and skip the doubling, so ``construct`` returns x * 576.
    """
    def __init__(self):
        super().__init__()

    def construct(self, x):
        for i in range(8):
            if i > 5:
                x *= 3
                continue
            x = x * 2
        return x


@non_graph_engine
def test_for_loop_with_continue():
    """Python vs. graph execution of ``for_loop_with_continue`` on 10 inputs."""
    run_test(for_loop_with_continue, 10)

# pylint: disable=unnecessary-pass
class for_loop_with_cont_break(Cell):
    """``for`` loop mixing an early ``continue`` branch and a late ``break``.

    Python-path result: i=0..2 are skipped, i=3..5 double x (x * 8), and at
    i=6 x is tripled before ``break``, so ``construct`` returns x * 24.
    """
    def __init__(self):
        super().__init__()

    def construct(self, x):
        for i in range(8):
            if i < 3:
                # The doubled ``i`` is discarded by the immediate ``continue``;
                # the next iteration takes its value from range(8) again.
                i *= 2
                continue
            if i > 5:
                x *= 3
                break
            # x *= 2
            x = x * 2
            # Deliberate no-op (see the pylint disable above this class).
            pass
        return x


@non_graph_engine
def test_for_loop_with_cont_break():
    """Python vs. graph execution of ``for_loop_with_cont_break`` on 10 inputs."""
    run_test(for_loop_with_cont_break, 10)


class for_nested_loop_with_break(Cell):
    """``break`` in an inner ``for`` loop nested inside an outer ``for`` loop.

    Each outer pass multiplies x by 1.5 four times (j=0..3) and then by 2
    before breaking at j=4, i.e. by 10.125. With three outer passes
    ``construct`` returns x * 10.125 ** 3 on the Python path.
    """
    def __init__(self):
        super().__init__()

    def construct(self, x):
        # The outer index ``i`` is intentionally unused; only the repetition
        # matters.
        for i in range(3):
            for j in range(5):
                if j > 3:
                    x *= 2
                    break
                x = x * 1.5
        return x


@non_graph_engine
def test_for_nested_loop_with_break():
    """Python vs. graph execution of ``for_nested_loop_with_break`` on 10 inputs."""
    run_test(for_nested_loop_with_break, 10)


class while_with_break(Cell):
    """``while`` loop that always leaves through ``break``.

    i=0..3 multiply x by 1.5; at i=4 x is doubled and the loop breaks, so the
    ``i < 5`` condition never becomes false. Python-path result: x * 10.125.
    """
    def __init__(self):
        super().__init__()

    def construct(self, x):
        i = 0
        while i < 5:
            if i > 3:
                x *= 2
                break
            x = x * 1.5
            i += 1
        return x


@non_graph_engine
def test_while_with_break():
    """Python vs. graph execution of ``while_with_break`` on 10 inputs."""
    run_test(while_with_break, 10)


class while_with_continue(Cell):
    """``while`` loop whose final iteration ends in ``continue``.

    i=0..3 multiply x by 1.5; at i=4 x is doubled, the counter advances to 5,
    and ``continue`` re-tests the condition, which now ends the loop.
    Python-path result: x * 10.125.
    """
    def __init__(self):
        super().__init__()

    def construct(self, x):
        i = 0
        while i < 5:
            if i > 3:
                x *= 2
                # The counter must advance before ``continue`` or the loop
                # would never terminate.
                i += 1
                continue
            x = x * 1.5
            i += 1
        return x


@non_graph_engine
def test_while_with_continue():
    """Python vs. graph execution of ``while_with_continue`` on 10 inputs."""
    run_test(while_with_continue, 10)


class while_for_nested(Cell):
    """``for`` loop with ``break`` nested in a ``while`` branch ending in ``continue``.

    i=0..3 multiply x by 1.5. At i=4 the inner loop doubles x for j=0 and j=1
    and breaks at j=2, then the counter advances and ``continue`` ends the
    outer loop. Python-path result: x * 1.5 ** 4 * 4 = x * 20.25.
    """
    def __init__(self):
        super().__init__()

    def construct(self, x):
        i = 0
        while i < 5:
            if i > 3:
                for j in range(3):
                    if j > 1:
                        break
                    x *= 2
                i += 1
                continue
            x = x * 1.5
            i += 1
        return x


@non_graph_engine
def test_while_for_nested():
    """Python vs. graph execution of ``while_for_nested`` on 10 inputs."""
    run_test(while_for_nested, 10)


class pass_branch(Cell):
    """``while`` loop whose ``if`` branch is a bare ``pass``.

    i=0..3 take the ``else`` branch and multiply x by 1.5; i=4 hits ``pass``.
    Python-path result: x * 5.0625.
    """
    def __init__(self):
        super().__init__()

    def construct(self, x):
        i = 0
        while i < 5:
            if i > 3:
                pass
            else:
                x = x * 1.5
            i += 1
        return x


@non_graph_engine
def test_pass_branch():
    """Python vs. graph execution of ``pass_branch`` on 10 inputs."""
    run_test(pass_branch, 10)