# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test mindspore grammar constraints
1. function must have return statement
2. raise statement can not be used
"""
# pylint: disable=R1705, R1710, W0223
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore import dtype as mstype

# Every test in this module runs with the context set to graph mode.
context.set_context(mode=context.GRAPH_MODE)


def test_missing_return():
    """Expect TypeError when one branch of 'construct' ends without a return."""
    class NetMissReturn(nn.Cell):
        def __init__(self):
            super(NetMissReturn, self).__init__()

        def construct(self, x, y, z):
            if x == 1:
                return 10
            elif x == 20:
                if y == 1:
                    return 3
                elif y == 2:
                    for i in range(z):
                        return i + z
                    i = 0
                    while i < z:
                        return i + z
                    def g(u):
                        return x + u
                    # here method 'construct' misses a return statement
                    g(y)
                else:
                    return 7
            else:
                return 5

    net = NetMissReturn()
    x = Tensor(0, mstype.int32)
    y = Tensor(5, mstype.int32)
    z = Tensor(2, mstype.int32)
    with pytest.raises(TypeError) as er:
        net(x, y, z)
    # Checked after the 'with' block: statements following the raising call
    # inside the block would never execute.
    assert "Function must has 'return' statement, but missing in bound method 'construct'" in str(er.value)


def test_nest_function_missing_return():
    """Expect TypeError when the nested function 'g' has no return statement."""
    class NetNestFuncMissReturn(nn.Cell):
        def __init__(self):
            super(NetNestFuncMissReturn, self).__init__()

        def construct(self, x, y, z):
            if x == 1:
                return 10
            elif x == 20:
                if y == 1:
                    return 3
                elif y == 2:
                    for i in range(z):
                        return i + z
                    i = 0
                    while i < z:
                        return i + z
                    def g(u):
                        x += u
                    # nested function 'g' misses a return statement
                    return g(y)
                else:
                    return 7
            else:
                return 5

    net = NetNestFuncMissReturn()
    x = Tensor(0, mstype.int32)
    y = Tensor(5, mstype.int32)
    z = Tensor(2, mstype.int32)
    with pytest.raises(TypeError) as er:
        net(x, y, z)
    assert "Function must has 'return' statement, but missing in function 'g'" in str(er.value)


def test_raise_in_method():
    """Expect RuntimeError when 'construct' contains a 'raise' statement."""
    class NetRaiseInMethod(nn.Cell):
        def __init__(self):
            super(NetRaiseInMethod, self).__init__()

        def construct(self, x, y, z):
            if x == 1:
                return 10
            elif x == 20:
                # add not support grammar 'raise' here
                raise ValueError('Illegal case')
            else:
                return y + z

    net = NetRaiseInMethod()
    x = Tensor(0, mstype.int32)
    y = Tensor(5, mstype.int32)
    z = Tensor(2, mstype.int32)
    with pytest.raises(RuntimeError) as er:
        net(x, y, z)
    assert "Unsupported statement 'Raise'." in str(er.value)


def test_raise_in_nested_function():
    """Expect RuntimeError when a function nested in 'construct' uses 'raise'."""
    class NetNestRaise(nn.Cell):
        def __init__(self):
            super(NetNestRaise, self).__init__()

        def construct(self, x, y, z):
            if x == 1:
                return 10
            elif x == 20:
                def nest_fn(u):
                    if u > 0:
                        # add not support grammar 'raise' here
                        raise ValueError('Illegal case')
                    return u + z + 1
                return nest_fn(y)
            else:
                return y + z

    net = NetNestRaise()
    x = Tensor(0, mstype.int32)
    y = Tensor(5, mstype.int32)
    z = Tensor(2, mstype.int32)
    with pytest.raises(RuntimeError) as er:
        net(x, y, z)
    assert "Unsupported statement 'Raise'." in str(er.value)


def test_nest_branch_with_return():
    """Run a net whose if/else branches both return; no exception is expected."""
    class NetBranchWithReturn(nn.Cell):
        def __init__(self):
            super(NetBranchWithReturn, self).__init__()

        def construct(self, x, y, z):
            if x == 1:
                return 10
            else:
                return 5

    net = NetBranchWithReturn()
    x = Tensor(0, mstype.int32)
    y = Tensor(5, mstype.int32)
    z = Tensor(2, mstype.int32)
    net(x, y, z)


def test_any_with_no_return():
    """Expect TypeError when 'construct' returns only inside an 'if' on any()."""
    class NetAnyNoReturn(nn.Cell):
        def __init__(self):
            super(NetAnyNoReturn, self).__init__()

        def construct(self, inp):
            result = inp.any()
            if result:
                return 6

    np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_)
    tensor = Tensor(np_input)
    net = NetAnyNoReturn()
    with pytest.raises(TypeError) as er:
        net(tensor)
    assert "Function must has 'return' statement, but missing in bound method 'construct'" in str(er.value)


def test_missing_construct():
    """Calling a Cell that defines no 'construct' method is asserted to give None."""
    class NetMissConstruct(nn.Cell):
        def __init__(self):
            super(NetMissConstruct, self).__init__()

        # Deliberately not named 'construct', so the Cell has no construct of its own.
        def construct1(self, inp):
            return 5

    np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_)
    tensor = Tensor(np_input)
    net = NetMissConstruct()
    assert net(tensor) is None