• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test_cell """
16import numpy as np
17import pytest
18
19import mindspore.nn as nn
20from mindspore import Tensor, Parameter
21from ...ut_filter import non_graph_engine
22
23
class ModA(nn.Cell):
    """Leaf cell that owns exactly one Parameter, registered as ``weight``.

    Serves as a building block for the naming/traversal tests; its forward
    pass does nothing.
    """

    def __init__(self, tensor):
        super().__init__()
        # The Parameter is attached only after Cell.__init__ has run; doing it
        # earlier raises AttributeError (see test_exceptions).
        self.weight = Parameter(tensor, name="weight")

    def construct(self, *inputs):
        """No-op forward pass."""
        return None
33
34
class ModB(nn.Cell):
    """Leaf cell that owns exactly one Parameter, registered as ``weight``.

    Structurally the same as ModA but a distinct type, so tests can tell the
    two children of a composite cell apart.
    """

    def __init__(self, tensor):
        super().__init__()
        # Must follow Cell.__init__ so the parameter can be registered.
        self.weight = Parameter(tensor, name="weight")

    def construct(self, *inputs):
        """No-op forward pass."""
        return None
44
45
class ModC(nn.Cell):
    """Composite cell holding a ModA as ``mod1`` and a ModB as ``mod2``.

    Adds a second level of nesting: when registered as ``mod3`` on a parent,
    its parameters show up as ``mod3.mod1.weight`` and ``mod3.mod2.weight``.

    Args:
        ta: tensor used to build the ModA child's parameter.
        tb: tensor used to build the ModB child's parameter.
    """

    def __init__(self, ta, tb):
        super().__init__()
        self.mod1 = ModA(ta)
        self.mod2 = ModB(tb)

    def construct(self, *inputs):
        """No-op forward pass."""
        return None
56
57
class Net(nn.Cell):
    """Two-level test network with children mod1 (ModA), mod2 (ModB), mod3 (ModC).

    Args:
        ta: tensor for every ModA parameter (including the one inside mod3).
        tb: tensor for every ModB parameter (including the one inside mod3).
    """
    # Number of named parameters reachable from an instance:
    # mod1.weight, mod2.weight, mod3.mod1.weight, mod3.mod2.weight.
    name_len = 4
    # Number of direct child cells: mod1, mod2, mod3.
    cells_num = 3

    def __init__(self, ta, tb):
        super().__init__()
        self.mod1 = ModA(ta)
        self.mod2 = ModB(tb)
        self.mod3 = ModC(ta, tb)

    def construct(self, *inputs):
        """No-op forward pass."""
        return None
71
72
class Net2(nn.Cell):
    """Same layout as Net, but built with ``auto_prefix=False``.

    Used by test_stop_update_name to check that this cell does not prefix its
    direct children's parameter names with their attribute names.

    Args:
        ta: tensor for every ModA parameter.
        tb: tensor for every ModB parameter.
    """

    def __init__(self, ta, tb):
        super().__init__(auto_prefix=False)
        self.mod1 = ModA(ta)
        self.mod2 = ModB(tb)
        self.mod3 = ModC(ta, tb)

    def construct(self, *inputs):
        """No-op forward pass."""
        return None
84
85
class ConvNet(nn.Cell):
    """Small classifier: conv1 -> bn1 -> relu -> maxpool -> flatten -> fc.

    The class attributes fix the expected input resolution (image_h x image_w)
    and the conv channel count. The dense layer's input width is derived from
    them.

    Args:
        num_classes: output width of the final dense layer.
    """
    image_h = 224
    image_w = 224
    output_ch = 64

    def __init__(self, num_classes=10):
        super().__init__()
        channels = ConvNet.output_ch
        # Child cells are assigned in the same order as before, so their
        # registration order is unchanged.
        self.conv1 = nn.Conv2d(3, channels, kernel_size=7, stride=2, pad_mode='pad', padding=3)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
        self.flatten = nn.Flatten()
        # conv1 and maxpool both use stride 2, so H and W are each presumably
        # divided by 4 before flattening; hence the 4 * 4 divisor.
        flat_features = ConvNet.image_h * ConvNet.image_w * channels // (4 * 4)
        self.fc = nn.Dense(flat_features, num_classes)

    def construct(self, x):
        """Run the layers in sequence and return the fc output."""
        features = self.conv1(x)
        features = self.bn1(features)
        activated = self.relu(features)
        pooled = self.maxpool(activated)
        flat = self.flatten(pooled)
        return self.fc(flat)
111
112
def test_basic():
    """parameters_dict() keys are the dotted attribute paths of each Parameter."""
    ta = Tensor(np.ones([2, 3]))
    tb = Tensor(np.ones([1, 4]))
    net = Net(ta, tb)
    expected = ["mod1.weight", "mod2.weight", "mod3.mod1.weight", "mod3.mod2.weight"]

    names = list(net.parameters_dict().keys())
    assert len(names) == net.name_len
    assert names == expected
124
125
def test_parameter_name():
    """parameters_and_names() yields the same dotted names as parameters_dict().

    The whole name list is checked, not just its first four entries, so an
    unexpected extra parameter also fails the test.
    """
    ta = Tensor(np.ones([2, 3]))
    tb = Tensor(np.ones([1, 4]))
    net = Net(ta, tb)
    # Entries with an empty name are skipped, as in the original filter.
    names = [item[0] for item in net.parameters_and_names() if item[0]]
    assert len(names) == net.name_len
    assert names == ["mod1.weight", "mod2.weight", "mod3.mod1.weight", "mod3.mod2.weight"]
139
140
def test_cell_name():
    """cells_and_names() yields dotted paths for all descendant cells.

    A child slot holding None is registered first. It must not appear in the
    traversal, so the complete name list is compared rather than only its
    first five entries (the original check could not detect 'modNone').
    """
    ta = Tensor(np.ones([2, 3]))
    tb = Tensor(np.ones([1, 4]))
    net = Net(ta, tb)
    net.insert_child_to_cell('modNone', None)
    # Drop entries with an empty name (presumably the root cell itself).
    names = [item[0] for item in net.cells_and_names() if item[0]]
    assert "modNone" not in names
    assert names == ["mod1", "mod2", "mod3", "mod3.mod1", "mod3.mod2"]
156
157
def test_cells():
    """cells() yields only direct children (3), not nested descendants (5)."""
    net = Net(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))
    children = list(net.cells())
    assert len(children) == net.cells_num
165
166
def test_exceptions():
    """Misusing Cell attributes raises; the base Cell.construct returns None.

    Expected errors:
      * Parameter assigned before Cell.__init__           -> AttributeError
      * Parameter attr set to None, then to a Cell         -> TypeError
      * child-cell attr set to None, then to a plain Tensor -> TypeError
    """
    data = Tensor(np.ones([2, 3]))

    class ModError(nn.Cell):
        """Assigns a Parameter before calling Cell.__init__."""

        def __init__(self, tensor):
            self.weight = Parameter(tensor, name="weight")
            super().__init__()

        def construct(self, *inputs):
            pass

    class ModError1(nn.Cell):
        """Rebinds a Parameter attribute to None and then to a Cell."""

        def __init__(self, tensor):
            super().__init__()
            self.weight = Parameter(tensor, name="weight")
            self.weight = None
            self.weight = ModA(tensor)

        def construct(self, *inputs):
            pass

    class ModError2(nn.Cell):
        """Rebinds a child-cell attribute to None and then to a plain Tensor."""

        def __init__(self, tensor):
            super().__init__()
            self.mod = ModA(tensor)
            self.mod = None
            self.mod = tensor

        def construct(self, *inputs):
            pass

    with pytest.raises(AttributeError):
        ModError(data)
    with pytest.raises(TypeError):
        ModError1(data)
    with pytest.raises(TypeError):
        ModError2(data)

    base_cell = nn.Cell()
    assert base_cell.construct() is None
216
217
def test_del():
    """Deleting a child cell or a parameter removes it from parameters_dict()."""
    ta = Tensor(np.ones([2, 3]))
    tb = Tensor(np.ones([1, 4]))
    net = Net(ta, tb)

    def param_count():
        # Number of names currently reported by parameters_dict().
        return len(list(net.parameters_dict().keys()))

    assert param_count() == net.name_len

    # Removing the mod1 child drops its single parameter.
    del net.mod1
    assert param_count() == net.name_len - 1
    # mod1 no longer exists, so reaching through it fails.
    with pytest.raises(AttributeError):
        del net.mod1.weight

    # Deleting a parameter directly on a child also removes it.
    del net.mod2.weight
    assert param_count() == net.name_len - 2

    # Deleting an attribute that was never set fails.
    with pytest.raises(AttributeError):
        del net.mod
235
236
def test_add_attr():
    """insert_param_to_cell / insert_child_to_cell validate names and value types.

    The names '' and 'a.b' raise KeyError. Registering the wrong kind of
    object (a Parameter or plain Tensor as a child, a plain Tensor as a
    parameter) raises TypeError. Assigning a child cell before Cell.__init__
    raises AttributeError.
    """
    ta = Tensor(np.ones([2, 3]))
    tb = Tensor(np.ones([1, 4]))
    param = Parameter(ta, name="weight")
    cell = nn.Cell()
    cell.insert_param_to_cell('weight', param)

    # A Parameter is not accepted as a child cell.
    with pytest.raises(TypeError):
        cell.insert_child_to_cell("network", param)

    # Invalid parameter names.
    for bad_name in ('', 'a.b'):
        with pytest.raises(KeyError):
            cell.insert_param_to_cell(bad_name, param)
    # Re-inserting under the already-used valid name is accepted.
    cell.insert_param_to_cell('weight', param)
    # Invalid child-cell names.
    with pytest.raises(KeyError):
        cell.insert_child_to_cell('', ModA(ta))
    with pytest.raises(KeyError):
        cell.insert_child_to_cell('a.b', ModB(tb))

    # Wrong value types for each kind of registration.
    with pytest.raises(TypeError):
        cell.insert_child_to_cell('buffer', tb)
    with pytest.raises(TypeError):
        cell.insert_param_to_cell('w', ta)
    with pytest.raises(TypeError):
        cell.insert_child_to_cell('m', param)

    class ModAddCellError(nn.Cell):
        """Assigns a child cell before calling Cell.__init__."""

        def __init__(self, tensor):
            self.mod = ModA(tensor)
            super().__init__()

        def construct(self, *inputs):
            pass

    with pytest.raises(AttributeError):
        ModAddCellError(ta)
277
278
def test_train_eval():
    """A fresh Cell is not training; set_train() enables and set_train(False) disables."""
    cell = nn.Cell()
    assert not cell.training

    cell.set_train()
    assert cell.training

    cell.set_train(False)
    assert not cell.training
287
288
def test_stop_update_name():
    """With auto_prefix=False, Net2 leaves its direct children's names unprefixed.

    mod1's parameter keeps the bare name "weight". The parameters of mod3,
    a ModC built with the default prefixing, appear as "mod1.weight" and
    "mod2.weight" rather than "mod3.mod1.weight" and "mod3.mod2.weight".
    NOTE(review): mod1 and mod2 both appear to end up named "weight", so
    they presumably collapse into one parameters_dict() key; confirm.
    """
    net = Net2(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))
    names = list(net.parameters_dict().keys())
    assert names[:3] == ["weight", "mod1.weight", "mod2.weight"]
298
299
@non_graph_engine
def test_net_call():
    """Expects ValueError from building ConvNet and running construct on a
    random float32 batch of shape [1, 3, image_h, image_w].

    NOTE(review): construction and the call both sit inside pytest.raises, so
    the test cannot tell which step raises. Confirm and narrow if possible.
    """
    with pytest.raises(ValueError):
        net = ConvNet()
        batch_shape = [1, 3, net.image_h, net.image_w]
        pixels = np.random.randint(0, 255, batch_shape).astype(np.float32)
        net.construct(Tensor(pixels))
308