• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15""" test_cont_break """
16import numpy as np
17import pytest
18
19from mindspore import Tensor, Model, context
20from mindspore.nn import Cell
21
22
def run_test(netclass, count, dev, decimal=3):
    """Check that a Cell gives the same result in plain Python and in graph mode.

    Switches MindSpore to GRAPH_MODE on ``dev`` and instantiates ``netclass``.
    Then, ``count`` times, it builds a random 2x3 float32 input and runs it
    through two paths:

    * the plain-Python ``construct`` method, on NumPy arrays;
    * the compiled graph, via ``Model.predict``.

    It asserts that the two outputs agree.

    Args:
        netclass: A ``Cell`` subclass constructible with no arguments.
        count: Number of random inputs to try.
        dev: Device target passed to ``context.set_context`` (e.g. ``'CPU'``).
        decimal: Number of decimal places compared by
            ``np.testing.assert_array_almost_equal``; defaults to 3.

    Raises:
        AssertionError: If the Python and graph outputs differ.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target=dev)
    net = netclass()
    model = Model(net)
    for _ in range(count):
        input_np = np.random.randn(2, 3).astype(np.float32)
        input_ms = Tensor(input_np)
        # Give the Python reference run its own copy. ``construct`` may apply
        # in-place ops such as ``x *= 3``, and those must never alter the array
        # that may back ``input_ms`` (whether or not Tensor shares the buffer).
        output_np = net.construct(input_np.copy())  # run python
        output_ms = model.predict(input_ms)  # run graph
        np.testing.assert_array_almost_equal(
            output_np, output_ms.asnumpy(), decimal=decimal)
33
34
class ForLoopWithBreak(Cell):
    """Fixture: ``for ... in range`` loop that exits early via ``break``."""
    def construct(self, x):
        # i = 0..5 take the doubling path. At i == 6, x is tripled and the loop
        # breaks, so i == 7 never runs. Net effect: x * 2**6 * 3.
        for i in range(8):
            if i > 5:
                # x was already rebound by `x = x * 2` on earlier iterations, so
                # this in-place op does not touch the caller's array.
                x *= 3
                break
            x = x * 2
        return x
43
44
class ForLoopWithContinue(Cell):
    """Fixture: ``for ... in range`` loop that skips the loop tail via ``continue``."""
    def construct(self, x):
        # i = 0..5 double x. i = 6 and 7 triple x and `continue` past the
        # doubling. Net effect: x * 2**6 * 3**2.
        for i in range(8):
            if i > 5:
                x *= 3
                continue
            x = x * 2
        return x
53
54
class ForLoopWithContBreak(Cell):
    """Fixture: one ``for`` loop that uses both ``continue`` and ``break``."""
    def construct(self, x):
        # i = 0..2 hit `continue` (x unchanged). i = 3..5 double x. i = 6
        # triples x and breaks. Net effect: x * 2**3 * 3.
        for i in range(8):
            if i < 3:
                # In Python this rebinding has no effect on iteration: the next
                # pass of `for` reassigns i.
                # NOTE(review): presumably kept on purpose to exercise
                # loop-variable assignment in graph mode; confirm.
                i *= 2
                continue
            if i > 5:
                x *= 3
                break
            x = x * 2
        return x
66
67
class ForNestedLoopWithBreak(Cell):
    """Fixture: ``break`` in an inner ``for`` loop nested inside an outer ``for``."""
    def construct(self, x):
        # The `break` only leaves the inner loop. The outer loop still runs all
        # 3 iterations, and each one multiplies x by 1.5**4 * 2.
        for _ in range(3):
            for j in range(5):
                if j > 3:
                    x *= 2
                    break
                x = x * 1.5
        return x
77
78
class WhileWithBreak(Cell):
    """Fixture: ``while`` loop that exits via ``break`` before its condition fails."""
    def construct(self, x):
        # i = 0..3 scale x by 1.5. At i == 4, x is doubled and the loop breaks,
        # so `i < 5` never becomes false. Net effect: x * 1.5**4 * 2.
        i = 0
        while i < 5:
            if i > 3:
                x *= 2
                break
            x = x * 1.5
            i += 1
        return x
89
90
class WhileWithContinue(Cell):
    """Fixture: ``while`` loop whose ``continue`` branch advances the counter itself."""
    def construct(self, x):
        # i = 0..3 scale x by 1.5. At i == 4, x is doubled, i becomes 5, and
        # `continue` re-tests `i < 5`, which ends the loop.
        # Net effect: x * 1.5**4 * 2.
        i = 0
        while i < 5:
            if i > 3:
                x *= 2
                # Must increment before `continue`, or the loop never ends.
                i += 1
                continue
            x = x * 1.5
            i += 1
        return x
102
103
class WhileForNested(Cell):
    """Fixture: ``for``/``break`` nested inside a ``while`` branch that ends in ``continue``."""
    def construct(self, x):
        # i = 0..3 scale x by 1.5. At i == 4, the inner loop doubles x for
        # j = 0 and 1, then breaks at j == 2. i becomes 5 and `continue` ends
        # the while loop. Net effect: x * 1.5**4 * 2**2.
        i = 0
        while i < 5:
            if i > 3:
                for j in range(3):
                    if j > 1:
                        break
                    x *= 2
                i += 1
                continue
            x = x * 1.5
            i += 1
        return x
118
119
class PassBranch(Cell):
    """Fixture: ``while`` loop containing an ``if`` branch whose body is only ``pass``."""
    def construct(self, x):
        # i = 0..3 scale x by 1.5. i == 4 takes the empty `pass` branch.
        # Net effect: x * 1.5**4.
        i = 0
        while i < 5:
            if i > 3:
                pass
            else:
                x = x * 1.5
            i += 1
        return x
130
131
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cont_break():
    """Compare plain-Python and CPU graph-mode results for each control-flow fixture.

    Every fixture is run through ``run_test`` with 20 random inputs, in the
    order listed below. The first mismatch raises and stops the test.
    """
    fixtures = (
        ForLoopWithBreak,
        ForLoopWithContinue,
        ForLoopWithContBreak,
        ForNestedLoopWithBreak,
        WhileWithBreak,
        WhileWithContinue,
        WhileForNested,
        PassBranch,
    )
    iterations = 20
    target = 'CPU'
    for fixture_cls in fixtures:
        run_test(fixture_cls, iterations, target)
146