• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test container """
16from collections import OrderedDict
17import numpy as np
18import pytest
19
20import mindspore.nn as nn
21from mindspore import context, Tensor
22
# Execute the container tests below in MindSpore's PyNative mode.
context.set_context(mode=context.PYNATIVE_MODE)


# Hyper-parameters; `kernel_size` and `stride` configure `avg_pool` below.
kernel_size = 3
stride = 2
padding = 1  # NOTE(review): not referenced by the tests visible in this file

# Module-level tensor and cells shared by the test classes below.
weight = Tensor(np.ones((2, 2)))  # NOTE(review): not referenced by the tests visible here
conv2 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), stride=2, padding=0)
avg_pool = nn.AvgPool2d(kernel_size=kernel_size, stride=stride)
33
34
class TestSequentialCell():
    """Tests for ``nn.SequentialCell`` construction, indexing, mutation and append."""

    @staticmethod
    def _make_named_seq():
        """Build a two-cell SequentialCell (conv2, avg_pool) from an OrderedDict."""
        return nn.SequentialCell(OrderedDict(
            [('cov2d', conv2), ('avg_pool', avg_pool)]))

    def test_SequentialCell_init(self):
        """An empty SequentialCell can be constructed and holds no cells."""
        m = nn.SequentialCell()
        assert isinstance(m, nn.SequentialCell)
        assert len(m) == 0

    def test_SequentialCell_init2(self):
        """A one-element list produces a container of length 1."""
        m = nn.SequentialCell([conv2])
        assert len(m) == 1

    def test_SequentialCell_init3(self):
        """A two-element list produces a container of length 2."""
        m = nn.SequentialCell([conv2, avg_pool])
        assert len(m) == 2

    def test_SequentialCell_init4(self):
        """An OrderedDict of named cells produces a container of length 2."""
        m = self._make_named_seq()
        assert len(m) == 2

    def test_getitem1(self):
        """Integer indexing returns the stored cell object itself."""
        m = self._make_named_seq()
        assert m[0] is conv2

    def test_getitem2(self):
        """Slicing returns a container holding the same cell objects."""
        m = self._make_named_seq()
        assert len(m[0:2]) == 2
        assert m[:2][1] is avg_pool

    def test_setitem1(self):
        """Assigning by index replaces the cell at that position."""
        m = self._make_named_seq()
        m[1] = conv2
        assert m[0] is conv2
        assert m[1] is conv2

    def test_setitem2(self):
        """A non-integer index is rejected with TypeError."""
        m = self._make_named_seq()
        with pytest.raises(TypeError):
            m[1.0] = conv2

    def test_delitem1(self):
        """Deleting index 0 leaves only the second cell."""
        m = self._make_named_seq()
        del m[0]
        assert len(m) == 1
        assert m[0] is avg_pool

    def test_delitem2(self):
        """Deleting the full slice empties the container but keeps its type."""
        m = self._make_named_seq()
        del m[:]
        assert isinstance(m, nn.SequentialCell)
        assert len(m) == 0

    def test_sequentialcell_append(self):
        """append() yields the same pipeline as constructing with both cells."""
        input_np = np.ones((1, 3)).astype(np.float32)
        input_me = Tensor(input_np)
        relu = nn.ReLU()
        tanh = nn.Tanh()
        seq = nn.SequentialCell([relu])
        seq.append(tanh)
        assert len(seq) == 2
        out_me = seq(input_me)

        seq1 = nn.SequentialCell([relu, tanh])
        out = seq1(input_me)

        # Compare every element (not just [0][0]) against both the directly
        # constructed container and a NumPy reference of tanh(relu(x)).
        expected = np.tanh(np.maximum(input_np, 0))
        assert np.allclose(out_me.asnumpy(), out.asnumpy())
        assert np.allclose(out_me.asnumpy(), expected, atol=1e-6)
103
104
class TestCellList():
    """Tests for ``nn.CellList`` construction, indexing, mutation and list-style methods."""

    def test_init1(self):
        """A CellList built from two cells reports length 2."""
        cell_list = nn.CellList([conv2, avg_pool])
        assert len(cell_list) == 2

    def test_init2(self):
        """A non-Cell element (a str) must raise TypeError."""
        with pytest.raises(TypeError):
            nn.CellList(["test"])

    def test_getitem(self):
        """Integer and slice indexing return the stored cell objects."""
        cell_list = nn.CellList([conv2, avg_pool])
        assert cell_list[0] is conv2
        temp_cells = cell_list[:]
        assert temp_cells[1] is avg_pool

    def test_setitem(self):
        """Assigning by index replaces the cell at that position."""
        cell_list = nn.CellList([conv2, avg_pool])
        cell_list[0] = avg_pool
        assert cell_list[0] is avg_pool
        assert cell_list[1] is avg_pool

    def test_delitem(self):
        """Deleting by index and by full slice shrinks the list accordingly."""
        cell_list = nn.CellList([conv2, avg_pool])
        del cell_list[0]
        assert len(cell_list) == 1
        assert cell_list[0] is avg_pool
        del cell_list[:]
        assert isinstance(cell_list, nn.CellList)
        assert len(cell_list) == 0

    def test_iter(self):
        """Iteration yields every cell in insertion order."""
        cell_list = nn.CellList([conv2, avg_pool])
        # Materialize the iterator instead of reading a loop variable after
        # the loop: that raised NameError on an empty list and only ever
        # inspected the last element.
        cells = list(cell_list)
        assert len(cells) == 2
        assert cells[0] is conv2
        assert cells[1] is avg_pool

    def test_add(self):
        """In-place ``+=`` with a list appends its cells."""
        cell_list = nn.CellList([conv2, avg_pool])
        cell_list += [conv2]
        assert len(cell_list) == 3
        assert cell_list[2] is conv2
        assert cell_list[0] is cell_list[2]

    def test_insert(self):
        """insert(0, cell) shifts existing cells one position to the right."""
        cell_list = nn.CellList([conv2, avg_pool])
        cell_list.insert(0, avg_pool)
        assert len(cell_list) == 3
        assert cell_list[0] is avg_pool
        assert cell_list[1] is conv2
        assert cell_list[2] is avg_pool

    def test_append(self):
        """append() adds the cell at the end."""
        cell_list = nn.CellList([conv2, avg_pool])
        cell_list.append(conv2)
        assert len(cell_list) == 3
        assert cell_list[2] is conv2
        assert cell_list[0] is cell_list[2]

    def test_extend(self):
        """extend() on an empty CellList adds all cells in order."""
        cell_list = nn.CellList()
        cell_list.extend([conv2, avg_pool])
        assert len(cell_list) == 2
        assert cell_list[0] is conv2
        assert cell_list[1] is avg_pool
163