• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""container"""
16from collections import OrderedDict
17from abc import abstractmethod
18from ..cell import Cell
19
# Public API of this module: the two container cells. The "_"-prefixed helpers below are internal.
__all__ = ['SequentialCell', 'CellList']
21
22
23def _valid_index(cell_num, index, op_name=None):
24    """Internal function, used to detect the value and type of index."""
25    msg_prefix = f"For '{op_name}', the" if op_name else "The"
26    if not isinstance(index, int):
27        raise TypeError(f"{msg_prefix} type of index should be int type, but got {type(index).__name__}.")
28    if not -cell_num <= index < cell_num:
29        raise IndexError(f"{msg_prefix} value of index should be a number in range [{-cell_num}, {cell_num}), "
30                         f"but got {index}.")
31    return index % cell_num
32
33
def _valid_cell(cell, op_name=None):
    """
    Check that ``cell`` is an instance of ``Cell`` (or of one of its subclasses).

    Args:
        cell: The object to check.
        op_name (str, optional): Name used to prefix the error message.

    Returns:
        bool, always ``True`` when the check passes, so callers can write ``if _valid_cell(...):``.

    Raises:
        TypeError: If ``cell`` is not a ``Cell``.
    """
    if isinstance(cell, Cell):
        return True
    # Phrase the message so it reads correctly whether or not op_name is given, and name the
    # offending type so the caller can see what was passed.
    msg_prefix = f"For '{op_name}', each" if op_name else "Each"
    raise TypeError(f"{msg_prefix} cell should be subclass of Cell, but got {type(cell).__name__}. "
                    f"Please check your code.")
41
42
43def _get_prefix_and_index(cells):
44    """get prefix and index of parameter name in sequential cell or cell list."""
45    prefix = ""
46    index = 0
47    if not cells:
48        return prefix, index
49
50    cell_list = list(cells.items())
51    first_param, first_key = None, None
52    second_param, second_key = None, None
53    for key, cell in cell_list:
54        try:
55            _, param = next(cell.parameters_and_names())
56        except StopIteration:
57            continue
58        if first_param is None:
59            first_param = param
60            first_key = key
61            continue
62        second_param = param
63        second_key = key
64        break
65
66    if first_param is None:
67        return prefix, index
68
69    split_names = first_param.name.split(".")
70    for idx, name in enumerate(split_names):
71        if name == first_key:
72            prefix = ".".join(split_names[:idx])
73            prefix = prefix + "." if prefix else prefix
74            index = idx
75            if second_param is not None and second_param.name.split(".")[idx] == second_key:
76                break
77    return prefix, index
78
79
80class _CellListBase:
81    """
82    An interface for base the cell as list.
83
84    The sequential cell may be iterated using the construct method using for-in statement.
85    But there are some scenarios that the construct method built-in does not fit.
86    For convenience, we provide an interface that indicates the sequential
87    cell may be interpreted as list of cells, so it can be accessed using
88    iterator or subscript when a sequential cell instantiate is accessed
89    by iterator or subscript , it will be interpreted as a list of cells.
90    """
91    def __init__(self):
92        """Initialize _CellListBase."""
93        self.__cell_as_list__ = True
94
95    @abstractmethod
96    def __len__(self):
97        pass
98
99    @abstractmethod
100    def __getitem__(self, index):
101        pass
102
103    def construct(self):
104        raise NotImplementedError
105
106
class SequentialCell(Cell):
    """
    Sequential cell container.

    A list of Cells will be added to it in the order they are passed in the constructor.
    Alternatively, an ordered dict of cells can also be passed in.

    Args:
        args (list, OrderedDict, Cell): A single list of Cells, a single OrderedDict mapping names
            to Cells, or one or more Cells passed positionally.

    Inputs:
        - **x** (Tensor) - Tensor with shape according to the first Cell in the sequence.

    Outputs:
        Tensor, the output Tensor with shape depending on the input `x` and defined sequence of Cells.

    Raises:
        TypeError: If a single argument is passed that is not a Cell, list or OrderedDict.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> conv = nn.Conv2d(3, 2, 3, pad_mode='valid', weight_init="ones")
        >>> relu = nn.ReLU()
        >>> seq = nn.SequentialCell([conv, relu])
        >>> x = Tensor(np.ones([1, 3, 4, 4]), dtype=mindspore.float32)
        >>> output = seq(x)
        >>> print(output)
        [[[[27. 27.]
           [27. 27.]]
          [[27. 27.]
           [27. 27.]]]]
    """
    def __init__(self, *args):
        """Initialize SequentialCell."""
        super(SequentialCell, self).__init__()
        # One flag per child, in order: True when the child's key is its positional index (such keys
        # are renumbered after deletions), False when the key was supplied by the user.
        self._is_dynamic_name = []
        if len(args) == 1 and not isinstance(args[0], Cell):
            cells = args[0]
            if isinstance(cells, list):
                for index, cell in enumerate(cells):
                    self.insert_child_to_cell(str(index), cell)
                    cell.update_parameters_name(str(index) + ".")
                    self._is_dynamic_name.append(True)
            elif isinstance(cells, OrderedDict):
                for name, cell in cells.items():
                    self.insert_child_to_cell(name, cell)
                    cell.update_parameters_name(name + ".")
                    self._is_dynamic_name.append(False)
            else:
                raise TypeError(f"For '{self.__class__.__name__}', the 'args[0]' must be list or orderedDict, "
                                f"but got {type(cells).__name__}")
        else:
            # Cells passed positionally, including a single bare cell such as SequentialCell(conv).
            for index, cell in enumerate(args):
                self.insert_child_to_cell(str(index), cell)
                cell.update_parameters_name(str(index) + ".")
                self._is_dynamic_name.append(True)
        self.cell_list = list(self._cells.values())

    def __getitem__(self, index):
        """
        Return the cell at integer ``index``, or a new SequentialCell holding a slice of the cells.

        A slice is rebuilt from an OrderedDict, so its children keep their current keys and are
        treated as user-named (they are not renumbered on deletion).
        """
        if isinstance(index, slice):
            return self.__class__(
                OrderedDict(list(self._cells.items())[index]))
        index = _valid_index(len(self), index, self.__class__.__name__)
        return list(self._cells.values())[index]

    def __setitem__(self, index, cell):
        """Replace the cell at integer ``index`` with ``cell``, keeping the existing key."""
        cls_name = self.__class__.__name__
        if _valid_cell(cell, cls_name):
            prefix, _ = _get_prefix_and_index(self._cells)
            index = _valid_index(len(self), index, cls_name)
            key = list(self._cells.keys())[index]
            self._cells[key] = cell
            # Give the new cell's parameters the same naming scheme as its siblings.
            cell.update_parameters_name(prefix + key + ".")
            self.cell_list = list(self._cells.values())

    def __delitem__(self, index):
        """
        Delete the cell(s) selected by an int or slice ``index``.

        Afterwards, cells whose keys were positional are re-keyed to their new position and their
        parameter names are rewritten to match; user-named cells keep their keys.

        Raises:
            TypeError: If ``index`` is neither an int nor a slice.
        """
        cls_name = self.__class__.__name__
        if isinstance(index, int):
            index = _valid_index(len(self), index, cls_name)
            key = list(self._cells.keys())[index]
            del self._cells[key]
            del self._is_dynamic_name[index]
        elif isinstance(index, slice):
            keys = list(self._cells.keys())[index]
            for key in keys:
                del self._cells[key]
            del self._is_dynamic_name[index]
        else:
            raise TypeError(f"For '{cls_name}', the type of index should be int type or slice type, "
                            f"but got {type(index).__name__}")
        prefix, key_index = _get_prefix_and_index(self._cells)
        temp_dict = OrderedDict()
        for idx, key in enumerate(self._cells.keys()):
            cell = self._cells[key]
            if self._is_dynamic_name[idx]:
                # Replace the old positional segment (at key_index) with the new position.
                for _, param in cell.parameters_and_names():
                    param.name = prefix + str(idx) + "." + ".".join(param.name.split(".")[key_index+1:])
                temp_dict[str(idx)] = cell
            else:
                temp_dict[key] = cell
        self._cells = temp_dict
        self.cell_list = list(self._cells.values())

    def __len__(self):
        """Return the number of contained cells."""
        return len(self._cells)

    def set_grad(self, flag=True):
        """
        Set ``requires_grad`` on this container and on every contained cell.

        Args:
            flag (bool): Value propagated to ``requires_grad``. Default: True.

        Returns:
            SequentialCell, ``self``, so calls can be chained.
        """
        self.requires_grad = flag
        for cell in self._cells.values():
            cell.set_grad(flag)
        return self

    def append(self, cell):
        """Appends a given cell to the end of the list.

        Examples:
            >>> conv = nn.Conv2d(3, 2, 3, pad_mode='valid', weight_init="ones")
            >>> bn = nn.BatchNorm2d(2)
            >>> relu = nn.ReLU()
            >>> seq = nn.SequentialCell([conv, bn])
            >>> seq.append(relu)
            >>> x = Tensor(np.ones([1, 3, 4, 4]), dtype=mindspore.float32)
            >>> output = seq(x)
            >>> print(output)
            [[[[26.999863 26.999863]
               [26.999863 26.999863]]
              [[26.999863 26.999863]
               [26.999863 26.999863]]]]
        """
        if _valid_cell(cell, self.__class__.__name__):
            prefix, _ = _get_prefix_and_index(self._cells)
            cell.update_parameters_name(prefix + str(len(self)) + ".")
            self._is_dynamic_name.append(True)
            self._cells[str(len(self))] = cell
        self.cell_list = list(self._cells.values())

    def construct(self, input_data):
        """Feed ``input_data`` through every contained cell in order and return the final output."""
        for cell in self.cell_list:
            input_data = cell(input_data)
        return input_data
248
249
class CellList(_CellListBase, Cell):
    """
    Holds Cells in a list.

    CellList can be used like a regular Python list, support
    '__getitem__', '__setitem__', '__delitem__', '__len__', '__iter__' and '__iadd__',
    but cells it contains are properly registered, and will be visible by all Cell methods.

    Args:
        args (list, optional): List of subclass of Cell (a tuple or another CellList is also accepted).
        auto_prefix (bool, optional): Keyword argument. When True, the parameter names of the
            contained cells are prefixed with the cell's index in the list. Default: True.

    Raises:
        TypeError: If more than one positional argument is given, or if `args` is not a list,
            tuple or CellList of Cells.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> conv = nn.Conv2d(100, 20, 3)
        >>> bn = nn.BatchNorm2d(20)
        >>> relu = nn.ReLU()
        >>> cell_ls = nn.CellList([bn])
        >>> cell_ls.insert(0, conv)
        >>> cell_ls.append(relu)
        >>> print(cell_ls)
        CellList<
          (0): Conv2d<input_channels=100, output_channels=20, kernel_size=(3, 3),stride=(1, 1),  pad_mode=same,
          padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
          (1): BatchNorm2d<num_features=20, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=1.gamma,
          shape=(20,), dtype=Float32, requires_grad=True), beta=Parameter (name=1.beta, shape=(20,), dtype=Float32,
          requires_grad=True), moving_mean=Parameter (name=1.moving_mean, shape=(20,), dtype=Float32,
          requires_grad=False), moving_variance=Parameter (name=1.moving_variance, shape=(20,), dtype=Float32,
          requires_grad=False)>
          (2): ReLU<>
          >
    """
    def __init__(self, *args, **kwargs):
        """Initialize CellList."""
        auto_prefix = kwargs.get("auto_prefix", True)
        _CellListBase.__init__(self)
        Cell.__init__(self, auto_prefix)
        if len(args) > 1:
            # Extra positional arguments used to be ignored silently, yielding an empty list.
            raise TypeError(f"For '{self.__class__.__name__}', at most one positional argument "
                            f"(a list of cells) is accepted, but got {len(args)}.")
        if args:
            self.extend(args[0])

    def __getitem__(self, index):
        """
        Return the cell at integer ``index``, or a new CellList holding a slice of the cells.

        Raises:
            TypeError: If ``index`` is neither an int nor a slice.
        """
        cls_name = self.__class__.__name__
        if isinstance(index, slice):
            sub_cells = list(self._cells.values())[index]
            if self._auto_prefix:
                return self.__class__(sub_cells)
            # Keep auto_prefix disabled so slicing does not rename the shared cells' parameters.
            return self.__class__(sub_cells, auto_prefix=False)
        if isinstance(index, int):
            index = _valid_index(len(self), index, cls_name)
            return self._cells[str(index)]
        raise TypeError(f"For '{cls_name}', the type of index should be int type or slice type, "
                        f"but got {type(index).__name__}")

    def __setitem__(self, index, cell):
        """
        Replace the cell at integer ``index`` with ``cell``.

        Raises:
            TypeError: If ``index`` is not an int or ``cell`` is not a Cell.
            IndexError: If ``index`` is out of range.
        """
        cls_name = self.__class__.__name__
        if not isinstance(index, int):
            raise TypeError(f"For '{cls_name}', the type of index should be int type, "
                            f"but got {type(index).__name__}")
        index = _valid_index(len(self), index, cls_name)
        # Validate the cell for every int index, not only in the error path.
        _valid_cell(cell, cls_name)
        if self._auto_prefix:
            prefix, _ = _get_prefix_and_index(self._cells)
            cell.update_parameters_name(prefix + str(index) + ".")
        self._cells[str(index)] = cell

    def __delitem__(self, index):
        """
        Delete the cell(s) selected by an int or slice ``index`` and re-key the rest as 0..n-1.

        Raises:
            TypeError: If ``index`` is neither an int nor a slice.
        """
        cls_name = self.__class__.__name__
        if isinstance(index, int):
            index = _valid_index(len(self), index, cls_name)
            del self._cells[str(index)]
        elif isinstance(index, slice):
            keys = list(self._cells.keys())[index]
            for key in keys:
                del self._cells[key]
        else:
            raise TypeError(f"For '{cls_name}', the type of index should be int type or slice type, "
                            f"but got {type(index).__name__}")
        # Re-key the remaining cells so keys stay contiguous positional strings.
        prefix, key_index = _get_prefix_and_index(self._cells)
        temp_dict = OrderedDict()
        for idx, cell in enumerate(self._cells.values()):
            if self._auto_prefix:
                # Replace the old positional segment (at key_index) with the new position.
                for _, param in cell.parameters_and_names():
                    param.name = prefix + str(idx) + "." + ".".join(param.name.split(".")[key_index+1:])
            temp_dict[str(idx)] = cell
        self._cells = temp_dict

    def __len__(self):
        """Return the number of contained cells."""
        return len(self._cells)

    def __iter__(self):
        """Iterate over the contained cells in order."""
        return iter(self._cells.values())

    def __iadd__(self, cells):
        """Extend in place with ``cells`` (see ``extend``) and return ``self``."""
        self.extend(cells)
        return self

    def insert(self, index, cell):
        """
        Inserts a given cell before a given index in the list.

        ``index`` may also equal ``len(self)``, in which case ``cell`` is appended. This is what
        allows inserting into an empty CellList.

        Raises:
            TypeError: If ``index`` is not an int or ``cell`` is not a Cell.
            IndexError: If ``index`` is outside ``[-len(self), len(self)]``.
        """
        cls_name = self.__class__.__name__
        if isinstance(index, int) and index == len(self):
            self.append(cell)
            return
        idx = _valid_index(len(self), index, cls_name)
        _valid_cell(cell, cls_name)
        length = len(self)
        prefix, key_index = _get_prefix_and_index(self._cells)
        # Shift cells at positions >= idx one slot to the right, walking from the end.
        while length > idx:
            if self._auto_prefix:
                tmp_cell = self._cells[str(length-1)]
                for _, param in tmp_cell.parameters_and_names():
                    param.name = prefix + str(length) + "." + ".".join(param.name.split(".")[key_index+1:])
            self._cells[str(length)] = self._cells[str(length - 1)]
            length -= 1
        self._cells[str(idx)] = cell
        if self._auto_prefix:
            cell.update_parameters_name(prefix + str(idx) + ".")

    def extend(self, cells):
        """
        Appends cells from a list, tuple or CellList to the end of the list.

        Every element is validated before any is added, so an invalid element leaves the
        container unchanged.

        Returns:
            CellList, ``self``.

        Raises:
            TypeError: If the cells are not a list, tuple or CellList of subcells.
        """
        cls_name = self.__class__.__name__
        if not isinstance(cells, (list, tuple, CellList)):
            raise TypeError(f"For '{cls_name}', the new cells wanted to append "
                            f"should be instance of list, tuple or CellList.")
        # Snapshot first so extending a CellList with itself terminates.
        new_cells = list(cells)
        for cell in new_cells:
            _valid_cell(cell, cls_name)
        prefix, _ = _get_prefix_and_index(self._cells)
        for cell in new_cells:
            if self._auto_prefix:
                cell.update_parameters_name(prefix + str(len(self)) + ".")
            self._cells[str(len(self))] = cell
        return self

    def append(self, cell):
        """Appends a given cell to the end of the list."""
        if _valid_cell(cell, self.__class__.__name__):
            if self._auto_prefix:
                prefix, _ = _get_prefix_and_index(self._cells)
                cell.update_parameters_name(prefix + str(len(self)) + ".")
            self._cells[str(len(self))] = cell

    def set_grad(self, flag=True):
        """
        Set ``requires_grad`` on this container and on every contained cell.

        Args:
            flag (bool): Value propagated to ``requires_grad``. Default: True.

        Returns:
            CellList, ``self``, so calls can be chained.
        """
        self.requires_grad = flag
        for cell in self._cells.values():
            cell.set_grad(flag)
        return self

    def construct(self, *inputs):
        """CellList only holds cells; it has no forward computation of its own."""
        raise NotImplementedError
396