# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_cont_break """
import numpy as np
import pytest

from mindspore import Tensor, Model, context
from mindspore.nn import Cell


def run_test(netclass, count, dev):
    """Compare a Cell's plain-Python result with its graph-mode result.

    For each of ``count`` random float32 inputs of shape (2, 3),
    ``netclass().construct`` is executed directly as Python on a NumPy array,
    and the same network is run through ``Model.predict`` in GRAPH_MODE. The
    two outputs must agree to 3 decimal places.

    Args:
        netclass: ``Cell`` subclass constructible with no arguments.
        count: number of random inputs to check.
        dev: device target passed to ``context.set_context`` (e.g. ``'CPU'``).

    Raises:
        AssertionError: if the Python and graph outputs differ for any input.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target=dev)
    net = netclass()
    model = Model(net)
    for step in range(count):
        input_np = np.random.randn(2, 3).astype(np.float32)
        input_ms = Tensor(input_np)
        # Give the Python reference its own copy. Several construct() bodies
        # use in-place ``x *= k``. Whether or not Tensor() shares memory with
        # input_np, the reference run must never be able to modify the data
        # the graph-mode run consumes.
        output_np = net.construct(input_np.copy())  # run python
        output_ms = model.predict(input_ms)  # run graph
        np.testing.assert_array_almost_equal(
            output_np, output_ms.asnumpy(), decimal=3,
            err_msg="{} mismatch at iteration {}".format(netclass.__name__, step))


class ForLoopWithBreak(Cell):
    """``for`` loop exited early by ``break`` (net effect: x * 2**6 * 3)."""

    def construct(self, x):
        for i in range(8):
            if i > 5:
                # First reached at i == 6; the loop ends here.
                x *= 3
                break
            x = x * 2
        return x


class ForLoopWithContinue(Cell):
    """``for`` loop whose last iterations skip the tail via ``continue``.

    Net effect: x * 2**6 * 3**2 (i = 0..5 double, i = 6, 7 triple).
    """

    def construct(self, x):
        for i in range(8):
            if i > 5:
                x *= 3
                continue
            x = x * 2
        return x


class ForLoopWithContBreak(Cell):
    """``for`` loop mixing ``continue`` and ``break`` (net effect: x * 2**3 * 3)."""

    def construct(self, x):
        for i in range(8):
            if i < 3:
                # Rebinding the loop variable must not change which values
                # range() yields next; i = 0..2 leave x unchanged.
                i *= 2
                continue
            if i > 5:
                x *= 3
                break
            x = x * 2
        return x


class ForNestedLoopWithBreak(Cell):
    """``break`` in an inner ``for`` loop exits only that inner loop.

    Each outer pass multiplies x by 1.5**4 * 2; three passes are run.
    """

    def construct(self, x):
        for _ in range(3):
            for j in range(5):
                if j > 3:
                    x *= 2
                    break
                x = x * 1.5
        return x


class WhileWithBreak(Cell):
    """``while`` loop exited early by ``break`` (net effect: x * 1.5**4 * 2)."""

    def construct(self, x):
        i = 0
        while i < 5:
            if i > 3:
                x *= 2
                break
            x = x * 1.5
            i += 1
        return x


class WhileWithContinue(Cell):
    """``while`` loop with ``continue`` (net effect: x * 1.5**4 * 2).

    The ``continue`` branch advances ``i`` itself so the loop still terminates.
    """

    def construct(self, x):
        i = 0
        while i < 5:
            if i > 3:
                x *= 2
                i += 1
                continue
            x = x * 1.5
            i += 1
        return x


class WhileForNested(Cell):
    """``for`` with ``break`` nested inside a ``while`` with ``continue``.

    Net effect: x * 1.5**4 * 2**2 (the inner loop doubles at j = 0, 1, then
    breaks at j = 2).
    """

    def construct(self, x):
        i = 0
        while i < 5:
            if i > 3:
                for j in range(3):
                    if j > 1:
                        break
                    x *= 2
                i += 1
                continue
            x = x * 1.5
            i += 1
        return x


class PassBranch(Cell):
    """``if`` branch containing only ``pass`` (net effect: x * 1.5**4)."""

    def construct(self, x):
        i = 0
        while i < 5:
            if i > 3:
                pass
            else:
                x = x * 1.5
            i += 1
        return x


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cont_break():
    """Check break/continue/pass control flow on CPU against plain Python."""
    count = 20
    dev = 'CPU'
    run_test(ForLoopWithBreak, count, dev)
    run_test(ForLoopWithContinue, count, dev)
    run_test(ForLoopWithContBreak, count, dev)
    run_test(ForNestedLoopWithBreak, count, dev)
    run_test(WhileWithBreak, count, dev)
    run_test(WhileWithContinue, count, dev)
    run_test(WhileForNested, count, dev)
    run_test(PassBranch, count, dev)