# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Container cells: SequentialCell and CellList."""
from collections import OrderedDict
from abc import abstractmethod
from ..cell import Cell

__all__ = ['SequentialCell', 'CellList']


def _valid_index(cell_num, index, op_name=None):
    """
    Validate an index into a container holding `cell_num` cells.

    Negative indices are accepted Python-style and normalized to their non-negative equivalent.

    Args:
        cell_num (int): Number of cells in the container.
        index (int): Index to validate; must lie in ``[-cell_num, cell_num)``.
        op_name (str, optional): Name of the calling container, used to prefix error messages.

    Returns:
        int, the equivalent index in ``[0, cell_num)``.

    Raises:
        TypeError: If `index` is not an int.
        IndexError: If `index` is out of range (always the case when `cell_num` is 0).
    """
    msg_prefix = f"For '{op_name}', the" if op_name else "The"
    if not isinstance(index, int):
        raise TypeError(f"{msg_prefix} type of index should be int type, but got {type(index).__name__}.")
    if not -cell_num <= index < cell_num:
        raise IndexError(f"{msg_prefix} value of index should be a number in range [{-cell_num}, {cell_num}), "
                         f"but got {index}.")
    # The range check above guarantees cell_num > 0 here, so the modulo is safe.
    return index % cell_num


def _valid_cell(cell, op_name=None):
    """
    Check that `cell` is an instance of `Cell` (or of a subclass of it).

    Args:
        cell: Object to check.
        op_name (str, optional): Name of the calling container, used to prefix the error message.

    Returns:
        bool, always True when it returns, so callers may write ``if _valid_cell(...)``.

    Raises:
        TypeError: If `cell` is not a `Cell`.
    """
    if isinstance(cell, Cell):
        return True
    # Without an op_name the message must not begin with a stray space.
    msg_prefix = f"For '{op_name}', each" if op_name else "Each"
    raise TypeError(f'{msg_prefix} cell should be subclass of Cell. '
                    f'Please check your code')


def _get_prefix_and_index(cells):
    """
    Infer the parameter-name prefix used by a SequentialCell or CellList.

    The first parameter of the first child that owns any parameters is inspected: the position of that
    child's key among the dot-separated components of the parameter name is located, and the components
    before it form the prefix. If a second child with parameters exists, its first parameter name is used to
    confirm the position when the key occurs more than once.

    Args:
        cells (OrderedDict): Mapping of child key to child cell.

    Returns:
        tuple(str, int): ``(prefix, index)``. `prefix` is empty or ends with ``"."``; `index` is the position
        of the child key within the dot-split parameter name. ``("", 0)`` when `cells` is empty, no child owns
        parameters, or the key does not appear in the parameter name.
    """
    prefix = ""
    index = 0
    if not cells:
        return prefix, index

    first_param, first_key = None, None
    second_param, second_key = None, None
    for key, cell in cells.items():
        try:
            _, param = next(cell.parameters_and_names())
        except StopIteration:
            # This child owns no parameters, so it says nothing about the prefix.
            continue
        if first_param is None:
            first_param, first_key = param, key
            continue
        second_param, second_key = param, key
        break

    if first_param is None:
        return prefix, index

    split_names = first_param.name.split(".")
    second_names = second_param.name.split(".") if second_param is not None else None
    for idx, name in enumerate(split_names):
        if name == first_key:
            prefix = ".".join(split_names[:idx])
            prefix = prefix + "." if prefix else prefix
            index = idx
            # Stop once the second child's key sits at the same depth. The length guard prevents an
            # IndexError when the second parameter name has fewer components than the first.
            if second_names is not None and idx < len(second_names) and second_names[idx] == second_key:
                break
    return prefix, index


class _CellListBase:
    """
    Marker interface for cells that may be treated as a list of cells.

    A sequential cell is normally run through its `construct` method, but some scenarios need to iterate
    over or index the contained cells directly instead. An instance of this interface is flagged (via
    ``__cell_as_list__``) so that, when it is accessed by iterator or subscript, it is interpreted as a list
    of cells.

    NOTE: this class does not use ``ABCMeta`` itself, so ``@abstractmethod`` here documents the required
    methods rather than enforcing them at instantiation.
    """
    def __init__(self):
        """Initialize _CellListBase."""
        # Marker flag; presumably read outside this file (e.g. by the compiler) — verify.
        self.__cell_as_list__ = True

    @abstractmethod
    def __len__(self):
        pass

    @abstractmethod
    def __getitem__(self, index):
        pass

    def construct(self):
        raise NotImplementedError


class SequentialCell(Cell):
    """
    Sequential cell container.

    A list of Cells will be added to it in the order they are passed in the constructor.
    Alternatively, an ordered dict of cells, or the cells themselves as separate positional
    arguments, can also be passed in.

    Args:
        args (list, OrderedDict, Cell): A single list of Cells (registered under keys "0", "1", ...),
            a single OrderedDict mapping names to Cells (registered under those names), or one or more
            Cells passed positionally (registered under their position).

    Inputs:
        - **x** (Tensor) - Tensor with shape according to the first Cell in the sequence.

    Outputs:
        Tensor, the output Tensor with shape depending on the input `x` and defined sequence of Cells.

    Raises:
        TypeError: If a single argument is given that is not a Cell, a list or an OrderedDict.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> conv = nn.Conv2d(3, 2, 3, pad_mode='valid', weight_init="ones")
        >>> relu = nn.ReLU()
        >>> seq = nn.SequentialCell([conv, relu])
        >>> x = Tensor(np.ones([1, 3, 4, 4]), dtype=mindspore.float32)
        >>> output = seq(x)
        >>> print(output)
        [[[[27. 27.]
           [27. 27.]]
          [[27. 27.]
           [27. 27.]]]]
    """
    def __init__(self, *args):
        """Initialize SequentialCell."""
        super().__init__()
        # Kept parallel to self._cells: True when a child's key is its position (re-numbered by __delitem__),
        # False when the key is a user-supplied OrderedDict name (kept as-is).
        self._is_dynamic_name = []
        if len(args) == 1 and not isinstance(args[0], Cell):
            cells = args[0]
            if isinstance(cells, list):
                named_cells = [(str(index), cell, True) for index, cell in enumerate(cells)]
            elif isinstance(cells, OrderedDict):
                named_cells = [(name, cell, False) for name, cell in cells.items()]
            else:
                raise TypeError(f"For '{self.__class__.__name__}', the 'args[0]' must be list or orderedDict, "
                                f"but got {type(cells).__name__}")
        else:
            # Zero or more cells passed positionally, including a single lone Cell.
            named_cells = [(str(index), cell, True) for index, cell in enumerate(args)]
        for name, cell, is_dynamic in named_cells:
            self.insert_child_to_cell(name, cell)
            cell.update_parameters_name(name + ".")
            self._is_dynamic_name.append(is_dynamic)
        self.cell_list = list(self._cells.values())

    def __getitem__(self, index):
        """
        Return the cell at int `index`, or a new SequentialCell for a slice.

        For a slice, the new container is built from an OrderedDict of the selected (key, cell) pairs. The
        cells are shared with this container, and the new container's constructor calls
        `update_parameters_name` on them.

        Raises:
            TypeError: If `index` is neither an int nor a slice.
            IndexError: If an int `index` is out of range.
        """
        if isinstance(index, slice):
            return self.__class__(
                OrderedDict(list(self._cells.items())[index]))
        index = _valid_index(len(self), index, self.__class__.__name__)
        return list(self._cells.values())[index]

    def __setitem__(self, index, cell):
        """
        Replace the cell at int `index`, keeping its existing key.

        Raises:
            TypeError: If `cell` is not a Cell or `index` is not an int.
            IndexError: If `index` is out of range.
        """
        cls_name = self.__class__.__name__
        if _valid_cell(cell, cls_name):
            prefix, _ = _get_prefix_and_index(self._cells)
            index = _valid_index(len(self), index, cls_name)
            key = list(self._cells.keys())[index]
            self._cells[key] = cell
            cell.update_parameters_name(prefix + key + ".")
            self.cell_list = list(self._cells.values())

    def __delitem__(self, index):
        """
        Delete the cell(s) at `index` (an int or a slice).

        Afterwards, cells registered under positional keys are re-keyed to their new position, and their
        parameter names are rewritten as ``<prefix><new position>.<rest>``, where ``<rest>`` is everything
        after the component at the container's key position. Cells registered under user-supplied
        OrderedDict names keep their key and parameter names.

        Raises:
            TypeError: If `index` is neither an int nor a slice.
            IndexError: If an int `index` is out of range.
        """
        cls_name = self.__class__.__name__
        if isinstance(index, int):
            index = _valid_index(len(self), index, cls_name)
            key = list(self._cells.keys())[index]
            del self._cells[key]
            del self._is_dynamic_name[index]
        elif isinstance(index, slice):
            keys = list(self._cells.keys())[index]
            for key in keys:
                del self._cells[key]
            del self._is_dynamic_name[index]
        else:
            raise TypeError(f"For '{cls_name}', the type of index should be int type or slice type, "
                            f"but got {type(index).__name__}")
        prefix, key_index = _get_prefix_and_index(self._cells)
        temp_dict = OrderedDict()
        for idx, key in enumerate(self._cells.keys()):
            cell = self._cells[key]
            if self._is_dynamic_name[idx]:
                for _, param in cell.parameters_and_names():
                    param.name = prefix + str(idx) + "." + ".".join(param.name.split(".")[key_index + 1:])
                temp_dict[str(idx)] = cell
            else:
                temp_dict[key] = cell
        self._cells = temp_dict
        self.cell_list = list(self._cells.values())

    def __len__(self):
        return len(self._cells)

    def set_grad(self, flag=True):
        """Set `requires_grad` on this container and call `set_grad(flag)` on every child cell."""
        self.requires_grad = flag
        for cell in self._cells.values():
            cell.set_grad(flag)

    def append(self, cell):
        """Appends a given cell to the end of the list.

        Examples:
            >>> conv = nn.Conv2d(3, 2, 3, pad_mode='valid', weight_init="ones")
            >>> bn = nn.BatchNorm2d(2)
            >>> relu = nn.ReLU()
            >>> seq = nn.SequentialCell([conv, bn])
            >>> seq.append(relu)
            >>> x = Tensor(np.ones([1, 3, 4, 4]), dtype=mindspore.float32)
            >>> output = seq(x)
            >>> print(output)
            [[[[26.999863 26.999863]
               [26.999863 26.999863]]
              [[26.999863 26.999863]
               [26.999863 26.999863]]]]
        """
        if _valid_cell(cell, self.__class__.__name__):
            prefix, _ = _get_prefix_and_index(self._cells)
            cell.update_parameters_name(prefix + str(len(self)) + ".")
            self._is_dynamic_name.append(True)
            self._cells[str(len(self))] = cell
            self.cell_list = list(self._cells.values())

    def construct(self, input_data):
        for cell in self.cell_list:
            input_data = cell(input_data)
        return input_data


class CellList(_CellListBase, Cell):
    """
    Holds Cells in a list.

    CellList can be used like a regular Python list, and supports
    '__getitem__', '__setitem__', '__delitem__', '__len__', '__iter__' and '__iadd__',
    but the cells it contains are properly registered, and will be visible by all Cell methods.

    Args:
        args (list, optional): A list of Cells to initialize the CellList with. Only used when exactly one
            positional argument is given.
        auto_prefix (bool, optional): Keyword-only. Passed to `Cell.__init__`; when False, this container
            does not rename the parameters of the cells it holds. Default: True.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> conv = nn.Conv2d(100, 20, 3)
        >>> bn = nn.BatchNorm2d(20)
        >>> relu = nn.ReLU()
        >>> cell_ls = nn.CellList([bn])
        >>> cell_ls.insert(0, conv)
        >>> cell_ls.append(relu)
        >>> print(cell_ls)
        CellList<
          (0): Conv2d<input_channels=100, output_channels=20, kernel_size=(3, 3),stride=(1, 1), pad_mode=same,
          padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
          (1): BatchNorm2d<num_features=20, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=1.gamma,
          shape=(20,), dtype=Float32, requires_grad=True), beta=Parameter (name=1.beta, shape=(20,), dtype=Float32,
          requires_grad=True), moving_mean=Parameter (name=1.moving_mean, shape=(20,), dtype=Float32,
          requires_grad=False), moving_variance=Parameter (name=1.moving_variance, shape=(20,), dtype=Float32,
          requires_grad=False)>
          (2): ReLU<>
          >
    """
    def __init__(self, *args, **kwargs):
        """Initialize CellList."""
        auto_prefix = kwargs.get("auto_prefix", True)
        _CellListBase.__init__(self)
        Cell.__init__(self, auto_prefix)
        if len(args) == 1:
            self.extend(args[0])

    def __getitem__(self, index):
        """
        Return the cell at int `index`, or a new CellList of the selected cells for a slice.

        Raises:
            TypeError: If `index` is neither an int nor a slice.
            IndexError: If an int `index` is out of range.
        """
        cls_name = self.__class__.__name__
        if isinstance(index, slice):
            return self.__class__(list(self._cells.values())[index])
        if isinstance(index, int):
            index = _valid_index(len(self), index, cls_name)
            return self._cells[str(index)]
        raise TypeError(f"For '{cls_name}', the type of index should be int type or slice type, "
                        f"but got {type(index).__name__}")

    def __setitem__(self, index, cell):
        """
        Replace the cell at int `index`.

        Raises:
            TypeError: If `cell` is not a Cell or `index` is not an int.
            IndexError: If `index` is out of range.
        """
        cls_name = self.__class__.__name__
        # Validate the cell for every index type (an invalid cell is reported before an invalid index).
        _valid_cell(cell, cls_name)
        if not isinstance(index, int):
            raise TypeError(f"For '{cls_name}', the type of index should be int type, "
                            f"but got {type(index).__name__}")
        index = _valid_index(len(self), index, cls_name)
        if self._auto_prefix:
            prefix, _ = _get_prefix_and_index(self._cells)
            cell.update_parameters_name(prefix + str(index) + ".")
        self._cells[str(index)] = cell

    def __delitem__(self, index):
        """
        Delete the cell(s) at `index` (an int or a slice) and re-key the rest as "0", "1", ...

        When auto_prefix is enabled, parameter names of the remaining cells are rewritten as
        ``<prefix><new position>.<rest>``, where ``<rest>`` is everything after the component at the
        container's key position.

        Raises:
            TypeError: If `index` is neither an int nor a slice.
            IndexError: If an int `index` is out of range.
        """
        cls_name = self.__class__.__name__
        if isinstance(index, int):
            index = _valid_index(len(self), index, cls_name)
            del self._cells[str(index)]
        elif isinstance(index, slice):
            keys = list(self._cells.keys())[index]
            for key in keys:
                del self._cells[key]
        else:
            raise TypeError(f"For '{cls_name}', the type of index should be int type or slice type, "
                            f"but got {type(index).__name__}")
        # Re-key the remaining cells contiguously from 0.
        prefix, key_index = _get_prefix_and_index(self._cells)
        temp_dict = OrderedDict()
        for idx, cell in enumerate(self._cells.values()):
            if self._auto_prefix:
                for _, param in cell.parameters_and_names():
                    param.name = prefix + str(idx) + "." + ".".join(param.name.split(".")[key_index + 1:])
            temp_dict[str(idx)] = cell
        self._cells = temp_dict

    def __len__(self):
        return len(self._cells)

    def __iter__(self):
        return iter(self._cells.values())

    def __iadd__(self, cells):
        self.extend(cells)
        return self

    def insert(self, index, cell):
        """
        Insert a given cell before a given index in the list, shifting later cells up by one.

        Args:
            index (int): Position to insert at. Negative values count from the end. ``len(self)`` appends
                at the end, which also allows inserting into an empty CellList.
            cell (Cell): Cell to insert.

        Raises:
            TypeError: If `index` is not an int or `cell` is not a Cell.
            IndexError: If `index` is outside ``[-len(self), len(self)]``.
        """
        cls_name = self.__class__.__name__
        length = len(self)
        if isinstance(index, int) and index == length:
            # Inserting at the end is an append; _valid_index rejects it (and rejects everything when empty).
            idx = length
        else:
            idx = _valid_index(length, index, cls_name)
        _valid_cell(cell, cls_name)
        prefix, key_index = _get_prefix_and_index(self._cells)
        # Move cells [idx, length) one slot right, starting from the end so no slot is overwritten early.
        for pos in range(length, idx, -1):
            moved_cell = self._cells[str(pos - 1)]
            if self._auto_prefix:
                for _, param in moved_cell.parameters_and_names():
                    param.name = prefix + str(pos) + "." + ".".join(param.name.split(".")[key_index + 1:])
            self._cells[str(pos)] = moved_cell
        self._cells[str(idx)] = cell
        if self._auto_prefix:
            cell.update_parameters_name(prefix + str(idx) + ".")

    def extend(self, cells):
        """
        Appends cells from a Python list to the end of the list.

        Cells are validated and appended one at a time, so cells preceding an invalid one remain appended.

        Args:
            cells (list): Cells to append.

        Returns:
            CellList, this container.

        Raises:
            TypeError: If `cells` is not a list, or any element is not a Cell.
        """
        cls_name = self.__class__.__name__
        if not isinstance(cells, list):
            raise TypeError(f"For '{cls_name}', the new cells wanted to append "
                            f"should be instance of list.")
        prefix, _ = _get_prefix_and_index(self._cells)
        for cell in cells:
            if _valid_cell(cell, cls_name):
                if self._auto_prefix:
                    cell.update_parameters_name(prefix + str(len(self)) + ".")
                self._cells[str(len(self))] = cell
        return self

    def append(self, cell):
        """
        Appends a given cell to the end of the list.

        Raises:
            TypeError: If `cell` is not a Cell.
        """
        if _valid_cell(cell, self.__class__.__name__):
            if self._auto_prefix:
                prefix, _ = _get_prefix_and_index(self._cells)
                cell.update_parameters_name(prefix + str(len(self)) + ".")
            self._cells[str(len(self))] = cell

    def set_grad(self, flag=True):
        """Set `requires_grad` on this container and call `set_grad(flag)` on every child cell."""
        self.requires_grad = flag
        for cell in self._cells.values():
            cell.set_grad(flag)

    def construct(self, *inputs):
        raise NotImplementedError