# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unit tests for nn.Cell: parameter/cell naming, attribute handling, train flag."""
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor, Parameter
from ...ut_filter import non_graph_engine


class ModA(nn.Cell):
    """Leaf cell that owns one parameter called "weight"."""

    def __init__(self, tensor):
        super().__init__()
        self.weight = Parameter(tensor, name="weight")

    def construct(self, *inputs):
        # Intentionally empty: these tests never run a forward pass on it.
        pass


class ModB(nn.Cell):
    """Second leaf cell type, also owning one parameter called "weight"."""

    def __init__(self, tensor):
        super().__init__()
        self.weight = Parameter(tensor, name="weight")

    def construct(self, *inputs):
        # Intentionally empty: these tests never run a forward pass on it.
        pass


class ModC(nn.Cell):
    """Composite cell holding a ModA as ``mod1`` and a ModB as ``mod2``."""

    def __init__(self, ta, tb):
        super().__init__()
        self.mod1 = ModA(ta)
        self.mod2 = ModB(tb)

    def construct(self, *inputs):
        pass


class Net(nn.Cell):
    """Two leaf cells plus one nested composite, using default cell options."""
    # Parameter count that test_basic / test_del compare against.
    name_len = 4
    # Direct child count that test_cells compares against.
    cells_num = 3

    def __init__(self, ta, tb):
        super().__init__()
        self.mod1 = ModA(ta)
        self.mod2 = ModB(tb)
        self.mod3 = ModC(ta, tb)

    def construct(self, *inputs):
        pass


class Net2(nn.Cell):
    """Same layout as Net, but the base cell is created with auto_prefix=False."""

    def __init__(self, ta, tb):
        super().__init__(auto_prefix=False)
        self.mod1 = ModA(ta)
        self.mod2 = ModB(tb)
        self.mod3 = ModC(ta, tb)

    def construct(self, *inputs):
        pass


class ConvNet(nn.Cell):
    """Conv -> BatchNorm -> ReLU -> MaxPool -> Flatten -> Dense network."""
    image_h = 224
    image_w = 224
    output_ch = 64

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, ConvNet.output_ch, kernel_size=7, stride=2, pad_mode='pad', padding=3)
        self.bn1 = nn.BatchNorm2d(ConvNet.output_ch)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
        self.flatten = nn.Flatten()
        # NOTE(review): the 4 * 4 divisor presumably reflects the two stride-2
        # stages (conv1 and maxpool) applied to each spatial axis - confirm.
        dense_in = ConvNet.image_h * ConvNet.image_w * ConvNet.output_ch // (4 * 4)
        self.fc = nn.Dense(dense_in, num_classes)

    def construct(self, x):
        """Run ``x`` through every layer in declaration order and return the result."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.flatten(out)
        return self.fc(out)


def test_basic():
    """parameters_dict() keys follow the attribute path of each parameter."""
    first = Tensor(np.ones([2, 3]))
    second = Tensor(np.ones([1, 4]))
    net = Net(first, second)
    param_names = list(net.parameters_dict())
    assert len(param_names) == net.name_len
    assert param_names == [
        "mod1.weight",
        "mod2.weight",
        "mod3.mod1.weight",
        "mod3.mod2.weight",
    ]


def test_parameter_name():
    """parameters_and_names() yields the dotted parameter names in order."""
    first = Tensor(np.ones([2, 3]))
    second = Tensor(np.ones([1, 4]))
    net = Net(first, second)
    # Entries with an empty/falsy name are skipped, as in the indexed checks below.
    param_names = [entry[0] for entry in net.parameters_and_names() if entry[0]]
    assert param_names[:4] == [
        "mod1.weight",
        "mod2.weight",
        "mod3.mod1.weight",
        "mod3.mod2.weight",
    ]


def test_cell_name():
    """cells_and_names() lists the named sub-cells after a None child is inserted."""
    first = Tensor(np.ones([2, 3]))
    second = Tensor(np.ones([1, 4]))
    net = Net(first, second)
    net.insert_child_to_cell('modNone', None)
    # Entries with an empty/falsy name are skipped before comparing.
    cell_names = [entry[0] for entry in net.cells_and_names() if entry[0]]
    assert cell_names[:5] == [
        "mod1",
        "mod2",
        "mod3",
        "mod3.mod1",
        "mod3.mod2",
    ]


def test_cells():
    """cells() returns as many children as Net.cells_num declares."""
    first = Tensor(np.ones([2, 3]))
    second = Tensor(np.ones([1, 4]))
    net = Net(first, second)
    children = list(net.cells())
    assert len(children) == net.cells_num


def test_exceptions():
    """Invalid attribute assignments on a Cell raise the expected exceptions."""
    tensor = Tensor(np.ones([2, 3]))

    class ModError(nn.Cell):
        """Assigns a Parameter before the base Cell has been initialised."""

        def __init__(self, tensor):
            # Order matters: the assignment deliberately precedes super().__init__().
            self.weight = Parameter(tensor, name="weight")
            super().__init__()

        def construct(self, *inputs):
            pass

    with pytest.raises(AttributeError):
        ModError(tensor)

    class ModError1(nn.Cell):
        """Rebinds a parameter attribute to None and then to a Cell."""

        def __init__(self, tensor):
            super().__init__()
            self.weight = Parameter(tensor, name="weight")
            self.weight = None
            self.weight = ModA(tensor)

        def construct(self, *inputs):
            pass

    with pytest.raises(TypeError):
        ModError1(tensor)

    class ModError2(nn.Cell):
        """Rebinds a child-cell attribute to None and then to a Tensor."""

        def __init__(self, tensor):
            super().__init__()
            self.mod = ModA(tensor)
            self.mod = None
            self.mod = tensor

        def construct(self, *inputs):
            pass

    with pytest.raises(TypeError):
        ModError2(tensor)

    # A bare Cell's construct() is expected to return None.
    assert nn.Cell().construct() is None


def test_del():
    """Deleting cells/parameters shrinks parameters_dict(); bad deletes raise."""
    first = Tensor(np.ones([2, 3]))
    second = Tensor(np.ones([1, 4]))
    net = Net(first, second)

    def param_count():
        return len(net.parameters_dict())

    assert param_count() == net.name_len

    del net.mod1
    assert param_count() == net.name_len - 1

    # mod1 has just been removed, so reaching through it must fail.
    with pytest.raises(AttributeError):
        del net.mod1.weight

    del net.mod2.weight
    assert param_count() == net.name_len - 2

    # "mod" was never an attribute of the net.
    with pytest.raises(AttributeError):
        del net.mod


def test_add_attr():
    """insert_param_to_cell / insert_child_to_cell validate names and value types."""
    first = Tensor(np.ones([2, 3]))
    second = Tensor(np.ones([1, 4]))
    param = Parameter(first, name="weight")
    cell = nn.Cell()
    cell.insert_param_to_cell('weight', param)

    # A Parameter is not accepted as a child cell.
    with pytest.raises(TypeError):
        cell.insert_child_to_cell("network", param)

    # Empty or dotted parameter names are rejected.
    with pytest.raises(KeyError):
        cell.insert_param_to_cell('', param)
    with pytest.raises(KeyError):
        cell.insert_param_to_cell('a.b', param)
    # Inserting under the already-used name 'weight' is expected not to raise.
    cell.insert_param_to_cell('weight', param)
    # Empty or dotted child names are rejected.
    with pytest.raises(KeyError):
        cell.insert_child_to_cell('', ModA(first))
    with pytest.raises(KeyError):
        cell.insert_child_to_cell('a.b', ModB(second))

    # Value-type mismatches for both insertion APIs.
    with pytest.raises(TypeError):
        cell.insert_child_to_cell('buffer', second)
    with pytest.raises(TypeError):
        cell.insert_param_to_cell('w', first)
    with pytest.raises(TypeError):
        cell.insert_child_to_cell('m', param)

    class ModAddCellError(nn.Cell):
        """Assigns a child cell before the base Cell has been initialised."""

        def __init__(self, tensor):
            # Order matters: the assignment deliberately precedes super().__init__().
            self.mod = ModA(tensor)
            super().__init__()

        def construct(self, *inputs):
            pass

    with pytest.raises(AttributeError):
        ModAddCellError(first)


def test_train_eval():
    """set_train() toggles the cell's ``training`` flag, which starts falsy."""
    cell = nn.Cell()
    assert not cell.training
    cell.set_train()
    assert cell.training
    cell.set_train(False)
    assert not cell.training


def test_stop_update_name():
    """Leading parameter names of a Net2, which is built with auto_prefix=False."""
    first = Tensor(np.ones([2, 3]))
    second = Tensor(np.ones([1, 4]))
    net = Net2(first, second)
    param_names = list(net.parameters_dict())
    assert param_names[:3] == ["weight", "mod1.weight", "mod2.weight"]


@non_graph_engine
def test_net_call():
    """Building ConvNet, creating its input and calling construct() raises ValueError.

    All three steps stay inside the ``pytest.raises`` scope, so the test does
    not pin down which of them is the one that raises.
    """
    with pytest.raises(ValueError):
        conv_net = ConvNet()
        image = Tensor(
            np.random.randint(0, 255, [1, 3, conv_net.image_h, conv_net.image_w]).astype(np.float32))
        conv_net.construct(image)