• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16test mindspore grammar constraints
171. function must have return statement
182. raise statement can not be used
19"""
20# pylint: disable=R1705, R1710, W0223
21import numpy as np
22import pytest
23
24import mindspore.nn as nn
25from mindspore import Tensor
26from mindspore import context
27from mindspore import dtype as mstype
28
# Every test in this module runs in graph mode; the mode is set once, at import time.
context.set_context(mode=context.GRAPH_MODE)
30
def test_missing_return():
    """A Cell whose ``construct`` can end without ``return`` must be rejected.

    In the ``x == 20`` / ``y == 2`` branch of ``construct``, the only
    ``return`` statements sit inside loop bodies, and the branch ends with a
    bare call ``g(y)``. Calling the cell is therefore expected to raise
    ``TypeError`` naming the bound method ``construct``.
    """
    class NetMissReturn(nn.Cell):
        def __init__(self):
            super(NetMissReturn, self).__init__()

        def construct(self, x, y, z):
            if x == 1:
                return 10
            elif x == 20:
                if y == 1:
                    return 3
                elif y == 2:
                    # A return inside a loop body does not cover the path where
                    # the loop body never runs, so neither loop below makes this
                    # branch return on every path.
                    for i in range(z):
                        return i + z
                    i = 0
                    while i < z:
                        return i + z
                    def g(u):
                        return x + u
                    # here method 'construct' misses a return statement
                    g(y)
                else:
                    return 7
            else:
                return 5

    net = NetMissReturn()
    x = Tensor(0, mstype.int32)
    y = Tensor(5, mstype.int32)
    z = Tensor(2, mstype.int32)
    # x == 0 would take the final `else` branch, so the error is presumably
    # reported while graph mode parses `construct`, not by executing the
    # faulty branch.
    with pytest.raises(TypeError) as er:
        net(x, y, z)
    assert "Function must has 'return' statement, but missing in bound method 'construct'" in str(er.value)
64
65
def test_nest_function_missing_return():
    """A nested function without ``return`` must be rejected.

    ``construct`` itself returns on every path. The helper ``g`` that it
    calls has no ``return`` statement, so calling the cell is expected to
    raise ``TypeError`` naming function ``g``.
    """
    class NetNestFuncMissReturn(nn.Cell):
        def __init__(self):
            super(NetNestFuncMissReturn, self).__init__()

        def construct(self, x, y, z):
            if x == 1:
                return 10
            elif x == 20:
                if y == 1:
                    return 3
                elif y == 2:
                    for i in range(z):
                        return i + z
                    i = 0
                    while i < z:
                        return i + z
                    def g(u):
                        x += u
                        # nested function 'g' misses a return statement
                        # NOTE(review): under Python scoping, `x += u` makes `x`
                        # local to `g`; the test only relies on `g` having no
                        # return statement.
                    return g(y)
                else:
                    return 7
            else:
                return 5

    net = NetNestFuncMissReturn()
    x = Tensor(0, mstype.int32)
    y = Tensor(5, mstype.int32)
    z = Tensor(2, mstype.int32)
    # The error message is expected to name the nested function, not `construct`.
    with pytest.raises(TypeError) as er:
        net(x, y, z)
    assert "Function must has 'return' statement, but missing in function 'g'" in str(er.value)
99
100
def test_raise_in_method():
    """A ``raise`` statement directly inside ``construct`` must be reported.

    Calling the cell is expected to fail with ``RuntimeError`` whose message
    mentions the unsupported ``Raise`` statement.
    """
    class NetRaiseInMethod(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, x, y, z):
            if x == 1:
                return 10
            elif x == 20:
                # 'raise' is placed here on purpose: this grammar is not supported
                raise ValueError('Illegal case')
            else:
                return y + z

    cell = NetRaiseInMethod()
    x, y, z = (Tensor(value, mstype.int32) for value in (0, 5, 2))
    with pytest.raises(RuntimeError) as exc_info:
        cell(x, y, z)
    expected = "Unsupported statement 'Raise'."
    assert expected in str(exc_info.value)
122
123
def test_raise_in_nested_function():
    """A ``raise`` inside a function nested in ``construct`` must be reported.

    The helper ``nest_fn`` defined inside ``construct`` contains ``raise``.
    Calling the cell is expected to fail with ``RuntimeError`` mentioning the
    unsupported ``Raise`` statement.
    """
    class NetNestRaise(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, x, y, z):
            if x == 1:
                return 10
            elif x == 20:
                def nest_fn(u):
                    if u > 0:
                        # 'raise' is placed here on purpose: this grammar is not supported
                        raise ValueError('Illegal case')
                    return u + z + 1
                return nest_fn(y)
            else:
                return y + z

    cell = NetNestRaise()
    x, y, z = (Tensor(value, mstype.int32) for value in (0, 5, 2))
    with pytest.raises(RuntimeError) as exc_info:
        cell(x, y, z)
    expected = "Unsupported statement 'Raise'."
    assert expected in str(exc_info.value)
149
150
def test_nest_branch_with_return():
    """A ``construct`` whose every branch returns can be called without error."""
    class NetBranchWithReturn(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, x, y, z):
            if x == 1:
                return 10
            else:
                return 5

    cell = NetBranchWithReturn()
    call_args = [Tensor(value, mstype.int32) for value in (0, 5, 2)]
    # Only checks that the call succeeds; the result is not inspected.
    cell(*call_args)
167
168
def test_any_with_no_return():
    """``construct`` that returns only when ``inp.any()`` is truthy must be rejected.

    When the condition is falsy, ``construct`` ends without a ``return``, so
    calling the cell is expected to raise ``TypeError`` naming the bound
    method ``construct``.
    """
    class NetAnyNoReturn(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, inp):
            result = inp.any()
            if result:
                return 6
            # no 'return' on the path where `result` is falsy

    bool_array = np.arange(24).reshape(2, 3, 4).astype(np.bool_)
    tensor_in = Tensor(bool_array)
    cell = NetAnyNoReturn()
    with pytest.raises(TypeError) as exc_info:
        cell(tensor_in)
    expected = "Function must has 'return' statement, but missing in bound method 'construct'"
    assert expected in str(exc_info.value)
185
186
def test_missing_construct():
    """Calling a Cell that defines no ``construct`` method is expected to return None.

    The subclass only defines an unrelated ``construct1`` method.
    """
    class NetMissConstruct(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct1(self, inp):
            return 5

    bool_array = np.arange(24).reshape(2, 3, 4).astype(np.bool_)
    tensor_in = Tensor(bool_array)
    cell = NetMissConstruct()
    outcome = cell(tensor_in)
    assert outcome is None
199