# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test cell """
import copy
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common.api import _cell_graph_executor


class ModA(nn.Cell):
    """ Cell that owns a single Parameter called "weight". """

    def __init__(self, tensor):
        super().__init__()
        self.weight = Parameter(tensor, name="weight")

    def construct(self, *inputs):
        pass


class ModB(nn.Cell):
    """ Second single-Parameter cell, structurally the same as ModA. """

    def __init__(self, tensor):
        super().__init__()
        self.weight = Parameter(tensor, name="weight")

    def construct(self, *inputs):
        pass


class ModC(nn.Cell):
    """ Cell composed of one ModA child and one ModB child. """

    def __init__(self, ta, tb):
        super().__init__()
        self.mod1 = ModA(ta)
        self.mod2 = ModB(tb)

    def construct(self, *inputs):
        pass


class Net(nn.Cell):
    """ Net definition """
    # Number of parameter names the tests expect from parameters_dict().
    name_len = 4
    # Number of direct children the tests expect from cells().
    cells_num = 3

    def __init__(self, ta, tb):
        super().__init__()
        self.mod1 = ModA(ta)
        self.mod2 = ModB(tb)
        self.mod3 = ModC(ta, tb)

    def construct(self, *inputs):
        pass


class Net2(nn.Cell):
    """ Same layout as Net, but constructed with auto_prefix=False. """

    def __init__(self, ta, tb):
        super().__init__(auto_prefix=False)
        self.mod1 = ModA(ta)
        self.mod2 = ModB(tb)
        self.mod3 = ModC(ta, tb)

    def construct(self, *inputs):
        pass


class ConvNet(nn.Cell):
    """ ConvNet definition """
    image_h = 224
    image_w = 224
    output_ch = 64

    def __init__(self, num_classes=10):
        super().__init__()
        channels = ConvNet.output_ch
        self.conv1 = nn.Conv2d(3, channels, kernel_size=7, stride=2, pad_mode="pad", padding=3)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
        self.flatten = nn.Flatten()
        # Dense input size is h * w * channels divided by 4 * 4.
        dense_in = int(ConvNet.image_h * ConvNet.image_w * channels / (4 * 4))
        self.fc = nn.Dense(dense_in, num_classes)

    def construct(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.flatten(out)
        out = self.fc(out)
        return out


def test_basic():
    """ parameters_dict() keys carry the attribute path of the owning sub-cell. """
    net = Net(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))
    param_names = list(net.parameters_dict().keys())
    assert len(param_names) == net.name_len
    expected = ("mod1.weight", "mod2.weight", "mod3.mod1.weight", "mod3.mod2.weight")
    for idx, want in enumerate(expected):
        assert param_names[idx] == want


def test_parameter_name():
    """ test_parameter_name """
    net = Net(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))
    # Keep only the non-empty names yielded by parameters_and_names().
    param_names = [entry[0] for entry in net.parameters_and_names() if entry[0]]
    expected = ("mod1.weight", "mod2.weight", "mod3.mod1.weight", "mod3.mod2.weight")
    for idx, want in enumerate(expected):
        assert param_names[idx] == want


def test_cell_name():
    """ test_cell_name """
    net = Net(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))
    net.insert_child_to_cell('modNone', None)
    # Keep only the non-empty names yielded by cells_and_names().
    cell_names = [entry[0] for entry in net.cells_and_names() if entry[0]]
    expected = ("mod1", "mod2", "mod3", "mod3.mod1", "mod3.mod2")
    for idx, want in enumerate(expected):
        assert cell_names[idx] == want


def test_cells():
    """ cells() yields as many children as Net.cells_num. """
    net = Net(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))
    children = list(net.cells())
    assert len(children) == net.cells_num


def test_exceptions():
    """ test_exceptions """
    data = Tensor(np.ones([2, 3]))

    class ModError(nn.Cell):
        """ Assigns a Parameter before the base __init__ has run. """

        def __init__(self, tensor):
            self.weight = Parameter(tensor, name="weight")
            super().__init__()

        def construct(self, *inputs):
            pass

    with pytest.raises(AttributeError):
        ModError(data)

    class ModError1(nn.Cell):
        """ Rebinds an attribute from Parameter to None and then to a Cell. """

        def __init__(self, tensor):
            super().__init__()
            self.weight = Parameter(tensor, name="weight")
            self.weight = None
            self.weight = ModA(tensor)

        def construct(self, *inputs):
            pass

    with pytest.raises(TypeError):
        ModError1(data)

    class ModError2(nn.Cell):
        """ Rebinds an attribute from Cell to None and then to a Tensor. """

        def __init__(self, tensor):
            super().__init__()
            self.mod = ModA(tensor)
            self.mod = None
            self.mod = tensor

        def construct(self, *inputs):
            pass

    with pytest.raises(TypeError):
        ModError2(data)

    bare_cell = nn.Cell()
    assert bare_cell.construct() is None


def test_cell_copy():
    """ A ConvNet instance can be passed through copy.deepcopy. """
    copy.deepcopy(ConvNet())


def test_del():
    """ test_del """
    net = Net(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))
    assert len(list(net.parameters_dict().keys())) == net.name_len

    # Removing a child cell drops its parameter from the dict.
    del net.mod1
    assert len(list(net.parameters_dict().keys())) == net.name_len - 1

    # mod1 has already been deleted, so this access is expected to fail.
    with pytest.raises(AttributeError):
        del net.mod1.weight

    del net.mod2.weight
    assert len(list(net.parameters_dict().keys())) == net.name_len - 2

    # Net never defined an attribute called "mod".
    with pytest.raises(AttributeError):
        del net.mod


def test_add_attr():
    """ test_add_attr """
    ta = Tensor(np.ones([2, 3]))
    tb = Tensor(np.ones([1, 4]))
    param = Parameter(ta, name="weight")
    cell = nn.Cell()
    cell.insert_param_to_cell('weight', param)

    with pytest.raises(TypeError):
        cell.insert_child_to_cell("network", param)

    with pytest.raises(KeyError):
        cell.insert_param_to_cell('', param)
    with pytest.raises(KeyError):
        cell.insert_param_to_cell('a.b', param)
    cell.insert_param_to_cell('weight', param)
    with pytest.raises(KeyError):
        cell.insert_child_to_cell('', ModA(ta))
    with pytest.raises(KeyError):
        cell.insert_child_to_cell('a.b', ModB(tb))

    with pytest.raises(TypeError):
        cell.insert_child_to_cell('buffer', tb)
    with pytest.raises(TypeError):
        cell.insert_param_to_cell('w', ta)
    with pytest.raises(TypeError):
        cell.insert_child_to_cell('m', param)

    class ModAddCellError(nn.Cell):
        """ Assigns a child cell before the base __init__ has run. """

        def __init__(self, tensor):
            self.mod = ModA(tensor)
            super().__init__()

        def construct(self, *inputs):
            pass

    with pytest.raises(AttributeError):
        ModAddCellError(ta)


def test_train_eval():
    """ set_train() toggles the training flag, which is off on a new cell. """
    cell = nn.Cell()
    assert not cell.training
    cell.set_train()
    assert cell.training
    cell.set_train(False)
    assert not cell.training


def test_stop_update_name():
    """ Checks the first three parameters_dict() keys of Net2 (auto_prefix=False). """
    net = Net2(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))
    param_names = list(net.parameters_dict().keys())
    expected = ("weight", "mod1.weight", "mod2.weight")
    for idx, want in enumerate(expected):
        assert param_names[idx] == want


class ModelName(nn.Cell):
    """ Cell with two parameters named "weight" and two created with name=None. """

    def __init__(self, tensor):
        super().__init__()
        # Assignment order (w2, w1, w3, w4) is kept as in the original test.
        self.w2 = Parameter(tensor, name="weight")
        self.w1 = Parameter(tensor, name="weight")
        self.w3 = Parameter(tensor, name=None)
        self.w4 = Parameter(tensor, name=None)

    def construct(self, *inputs):
        pass


def test_cell_names():
    """ Compiling ModelName is expected to raise ValueError. """
    model = ModelName(Tensor(np.ones([2, 3])))
    with pytest.raises(ValueError):
        _cell_graph_executor.compile(model)