• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test_container """
16from collections import OrderedDict
17import numpy as np
18import pytest
19
20import mindspore.nn as nn
21from mindspore import context, Tensor
22
# Select MindSpore's PyNative execution mode for every test in this module.
context.set_context(mode=context.PYNATIVE_MODE)


# Pooling hyper-parameters. NOTE(review): ``padding`` is not passed to any
# cell in this module; ``avg_pool`` only consumes ``kernel_size``/``stride``.
kernel_size, stride, padding = 3, 2, 1

# Cells shared by the container tests below. ``weight`` is not referenced by
# any test in this module.
weight = Tensor(np.ones((2, 2)))
conv2 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3),
                  stride=2, padding=0)
avg_pool = nn.AvgPool2d(kernel_size=kernel_size, stride=stride)
33
34
class TestSequentialCell:
    """Tests for ``nn.SequentialCell`` construction, indexing and mutation.

    The tests reuse the module-level ``conv2`` and ``avg_pool`` cells, but each
    test builds its own container so mutations in one test cannot leak into
    another.
    """

    @staticmethod
    def _conv_then_pool():
        """Return a fresh SequentialCell holding ``conv2`` then ``avg_pool``."""
        return nn.SequentialCell(OrderedDict(
            [('cov2d', conv2), ('avg_pool', avg_pool)]))

    def test_SequentialCell_init(self):
        """A SequentialCell built with no arguments is empty (and falsy)."""
        m = nn.SequentialCell()
        assert not m
        assert len(m) == 0

    def test_SequentialCell_init2(self):
        """A one-element list produces a one-element container."""
        m = nn.SequentialCell([conv2])
        assert len(m) == 1

    def test_SequentialCell_init3(self):
        """A two-element list produces a two-element container."""
        m = nn.SequentialCell([conv2, avg_pool])
        assert len(m) == 2

    def test_SequentialCell_init4(self):
        """An OrderedDict of named cells produces one entry per cell."""
        m = self._conv_then_pool()
        assert len(m) == 2

    def test_getitem1(self):
        """Integer indexing returns the very cell object that was inserted."""
        m = self._conv_then_pool()
        assert m[0] is conv2
        assert m[1] is avg_pool

    def test_getitem2(self):
        """Slicing returns a container with the sliced cells, in order."""
        m = self._conv_then_pool()
        assert len(m[0:2]) == 2
        assert m[:2][1] is avg_pool

    def test_setitem1(self):
        """Assigning by index replaces the cell stored at that position."""
        m = self._conv_then_pool()
        m[1] = conv2
        assert m[1] is m[0]

    def test_setitem2(self):
        """A non-integer index is rejected with TypeError."""
        m = self._conv_then_pool()
        with pytest.raises(TypeError):
            m[1.0] = conv2

    def test_delitem1(self):
        """Deleting index 0 removes the first cell and keeps the second."""
        m = self._conv_then_pool()
        del m[0]
        assert len(m) == 1
        # Verify *which* cell survived, not just how many.
        assert m[0] is avg_pool

    def test_delitem2(self):
        """Deleting the full slice empties the container."""
        m = self._conv_then_pool()
        del m[:]
        assert not m

    def test_construct(self):
        """``construct`` chains conv2 -> avg_pool on an NCHW float32 batch."""
        m = self._conv_then_pool()
        out = m.construct(Tensor(np.ones([1, 3, 16, 50], dtype=np.float32)))
        # conv2 maps 3 -> 64 channels; pooling keeps batch and channel dims, so
        # the leading two output dims are fixed regardless of spatial padding.
        assert out.shape[:2] == (1, 64)
94
95
class TestCellList:
    """Tests for ``nn.CellList`` construction, indexing, iteration and mutation.

    Each test builds its own CellList from the module-level ``conv2`` and
    ``avg_pool`` cells, so mutations stay local to the test.
    """

    def test_init1(self):
        """A list of two cells produces a two-element CellList."""
        cell_list = nn.CellList([conv2, avg_pool])
        assert len(cell_list) == 2

    def test_init2(self):
        """Non-Cell elements are rejected with TypeError."""
        with pytest.raises(TypeError):
            nn.CellList(["test"])

    def test_getitem(self):
        """Index and slice access return the original cell objects."""
        cell_list = nn.CellList([conv2, avg_pool])
        assert cell_list[0] is conv2
        temp_cells = cell_list[:]
        assert temp_cells[1] is avg_pool

    def test_setitem(self):
        """Assigning by index replaces the cell stored at that position."""
        cell_list = nn.CellList([conv2, avg_pool])
        cell_list[0] = avg_pool
        assert cell_list[0] is cell_list[1]

    def test_delitem(self):
        """Deleting by index, then by full slice, shrinks and then empties."""
        cell_list = nn.CellList([conv2, avg_pool])
        del cell_list[0]
        assert len(cell_list) == 1
        # The remaining cell must be the one that was *not* deleted.
        assert list(cell_list)[0] is avg_pool
        del cell_list[:]
        assert not cell_list

    def test_iter(self):
        """Iteration yields every cell, in insertion order."""
        cell_list = nn.CellList([conv2, avg_pool])
        # The original loop broke after the first element and asserted
        # nothing, so it could never fail; check the full iteration instead.
        cells = list(cell_list)
        assert len(cells) == 2
        assert cells[0] is conv2
        assert cells[1] is avg_pool

    def test_add(self):
        """In-place ``+=`` with a list appends its cells."""
        cell_list = nn.CellList([conv2, avg_pool])
        cell_list += [conv2]
        assert len(cell_list) == 3
        assert cell_list[0] is cell_list[2]

    def test_insert(self):
        """``insert(0, cell)`` prepends and shifts existing cells right."""
        cell_list = nn.CellList([conv2, avg_pool])
        cell_list.insert(0, avg_pool)
        assert len(cell_list) == 3
        assert cell_list[0] is cell_list[2]
        assert cell_list[1] is conv2

    def test_append(self):
        """``append`` adds the cell at the end."""
        cell_list = nn.CellList([conv2, avg_pool])
        cell_list.append(conv2)
        assert len(cell_list) == 3
        assert cell_list[0] is cell_list[2]

    def test_extend(self):
        """``extend`` on an empty CellList adds all cells in order."""
        cell_list = nn.CellList()
        cell_list.extend([conv2, avg_pool])
        assert len(cell_list) == 2
        assert cell_list[0] is conv2
        assert cell_list[1] is avg_pool
153