• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15""" test_cont_break """
16import numpy as np
17
18from mindspore import Tensor, Model, context
19from mindspore.nn import Cell
20from ...ut_filter import non_graph_engine
21
22
def run_test(netclass, count, *, check_values=False):
    """Run a net both as plain Python and as a compiled graph and compare.

    For each of ``count`` random ``(2, 3)`` float32 inputs, ``construct`` is
    called directly on a NumPy array (Python semantics) and the same input is
    fed as a ``Tensor`` through ``Model.predict`` with the context in
    ``GRAPH_MODE``. The two outputs must have the same shape.

    Args:
        netclass: A ``Cell`` subclass constructible with no arguments.
        count: Number of random inputs to try.
        check_values: If True, also require the outputs to agree to 3 decimal
            places. Off by default because the UT in CI use a fake backend,
            so only the output shapes are meaningful there.
    """
    context.set_context(mode=context.GRAPH_MODE)
    net = netclass()
    model = Model(net)
    for _ in range(count):
        input_np = np.random.randn(2, 3).astype(np.float32)
        input_ms = Tensor(input_np)
        # construct() may use in-place ops such as ``x *= 3``; hand the Python
        # run its own copy so it can never modify the array behind input_ms.
        output_np = net.construct(input_np.copy())  # run python
        output_ms = model.predict(input_ms).asnumpy()  # run graph
        assert np.shape(output_np) == np.shape(output_ms)
        if check_values:
            np.testing.assert_array_almost_equal(output_np, output_ms, decimal=3)
35
36
# pylint: disable=unnecessary-pass
class for_loop_with_break(Cell):
    """A ``for`` loop that leaves early via ``break`` once ``i > 5``.

    In plain Python, ``construct`` doubles ``x`` for ``i = 0..5``. At ``i = 6``
    it triples ``x`` and breaks, so the result is ``x * 2**6 * 3``. The trailing
    ``pass`` is deliberate (hence the pylint disable) and stays in the
    compiled body.
    """

    def __init__(self):
        super().__init__()

    def construct(self, x):
        # Statement layout is the thing under test for the graph compiler;
        # keep break / pass exactly where they are.
        for i in range(8):
            if i > 5:
                x *= 3
                break
            x = x * 2
            pass
        return x
50
51
@non_graph_engine
def test_for_loop_with_break():
    """Check ``break`` inside a ``for`` loop: graph output shape matches Python."""
    run_test(for_loop_with_break, 10)
55
56
class for_loop_with_continue(Cell):
    """A ``for`` loop that skips the tail of its body via ``continue``.

    In plain Python, ``construct`` doubles ``x`` for ``i = 0..5``. For
    ``i = 6, 7`` it triples ``x`` and continues, so the result is
    ``x * 2**6 * 3**2``.
    """

    def __init__(self):
        super().__init__()

    def construct(self, x):
        # Keep the continue-based shape: it is what the graph compiler is tested on.
        for i in range(8):
            if i > 5:
                x *= 3
                continue
            x = x * 2
        return x
68
69
@non_graph_engine
def test_for_loop_with_continue():
    """Check ``continue`` inside a ``for`` loop: graph output shape matches Python."""
    run_test(for_loop_with_continue, 10)
73
# pylint: disable=unnecessary-pass
class for_loop_with_cont_break(Cell):
    """A ``for`` loop that uses both ``continue`` and ``break``.

    In plain Python:
    - For ``i = 0..2`` the body rebinds ``i`` and continues, leaving ``x``
      unchanged. ``for`` reassigns ``i`` on the next iteration, so the
      rebinding has no lasting effect.
    - For ``i = 3..5`` it doubles ``x``.
    - At ``i = 6`` it triples ``x`` and breaks.

    The result is therefore ``x * 2**3 * 3``.
    """

    def __init__(self):
        super().__init__()

    def construct(self, x):
        # The assignment to the loop variable, continue, break and pass are all
        # intentional parts of the syntax exercised in graph mode.
        for i in range(8):
            if i < 3:
                i *= 2
                continue
            if i > 5:
                x *= 3
                break
            x = x * 2
            pass
        return x
91
92
@non_graph_engine
def test_for_loop_with_cont_break():
    """Check mixed ``continue``/``break`` in a ``for`` loop: graph shape matches Python."""
    run_test(for_loop_with_cont_break, 10)
96
97
class for_nested_loop_with_break(Cell):
    """Nested ``for`` loops where only the inner loop ``break``s.

    In plain Python, each of the 3 outer iterations scales ``x`` by 1.5 for
    ``j = 0..3``, then doubles it at ``j = 4`` and breaks the inner loop only.
    The result is ``x * (1.5**4 * 2)**3``.
    """

    def __init__(self):
        super().__init__()

    def construct(self, x):
        # The outer loop variable is unused; the loop exists to test that an
        # inner break does not terminate the outer loop in graph mode.
        for i in range(3):
            for j in range(5):
                if j > 3:
                    x *= 2
                    break
                x = x * 1.5
        return x
110
111
@non_graph_engine
def test_for_nested_loop_with_break():
    """Check ``break`` in an inner ``for`` loop: graph output shape matches Python."""
    run_test(for_nested_loop_with_break, 10)
115
116
class while_with_break(Cell):
    """A ``while`` loop that exits via ``break`` before its condition fails.

    In plain Python, ``x`` is scaled by 1.5 for ``i = 0..3``. At ``i = 4`` it is
    doubled and the loop breaks, so the result is ``x * 1.5**4 * 2``.
    """

    def __init__(self):
        super().__init__()

    def construct(self, x):
        i = 0
        while i < 5:
            if i > 3:
                x *= 2
                break  # leaves the loop with i still at 4
            x = x * 1.5
            i += 1
        return x
130
131
@non_graph_engine
def test_while_with_break():
    """Check ``break`` inside a ``while`` loop: graph output shape matches Python."""
    run_test(while_with_break, 10)
135
136
class while_with_continue(Cell):
    """A ``while`` loop whose last iteration takes a ``continue`` path.

    In plain Python, ``x`` is scaled by 1.5 for ``i = 0..3``. At ``i = 4`` it is
    doubled and ``i`` is advanced before ``continue``, which then ends the
    loop. The result is ``x * 1.5**4 * 2``.
    """

    def __init__(self):
        super().__init__()

    def construct(self, x):
        i = 0
        while i < 5:
            if i > 3:
                x *= 2
                # The counter must advance before continue, or the loop would never end.
                i += 1
                continue
            x = x * 1.5
            i += 1
        return x
151
152
@non_graph_engine
def test_while_with_continue():
    """Check ``continue`` inside a ``while`` loop: graph output shape matches Python."""
    run_test(while_with_continue, 10)
156
157
class while_for_nested(Cell):
    """A ``for`` loop with ``break`` nested inside a ``while`` loop with ``continue``.

    In plain Python, ``x`` is scaled by 1.5 for ``i = 0..3``. At ``i = 4``, the
    inner ``for`` doubles ``x`` for ``j = 0, 1`` and breaks at ``j = 2``. Then
    ``i`` advances and ``continue`` ends the outer loop. The result is
    ``x * 1.5**4 * 2**2``.
    """

    def __init__(self):
        super().__init__()

    def construct(self, x):
        i = 0
        while i < 5:
            if i > 3:
                for j in range(3):
                    if j > 1:
                        break  # exits only the inner for-loop
                    x *= 2
                i += 1
                continue
            x = x * 1.5
            i += 1
        return x
175
176
@non_graph_engine
def test_while_for_nested():
    """Check a ``for``/``break`` nested in a ``while``/``continue``: graph shape matches Python."""
    run_test(while_for_nested, 10)
180
181
class pass_branch(Cell):
    """A ``while`` loop containing an ``if`` branch whose body is only ``pass``.

    In plain Python, ``x`` is scaled by 1.5 for ``i = 0..3``, and the ``pass``
    branch at ``i = 4`` leaves it unchanged. The result is ``x * 1.5**4``.
    """

    def __init__(self):
        super().__init__()

    def construct(self, x):
        i = 0
        while i < 5:
            # The empty (pass-only) branch is the construct under test.
            if i > 3:
                pass
            else:
                x = x * 1.5
            i += 1
        return x
195
196
@non_graph_engine
def test_pass_branch():
    """Check a ``pass``-only ``if`` branch inside ``while``: graph shape matches Python."""
    run_test(pass_branch, 10)
200