• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test cell """
16import copy
17import numpy as np
18import pytest
19
20import mindspore.nn as nn
21from mindspore import Tensor, Parameter
22from mindspore.common.api import _cell_graph_executor
23
24
class ModA(nn.Cell):
    """Leaf cell that owns one Parameter registered under the name "weight"."""

    def __init__(self, tensor):
        super().__init__()
        self.weight = Parameter(tensor, name="weight")

    def construct(self, *inputs):
        """Do nothing; the cell exists only to hold its parameter."""
        return None
32
33
class ModB(nn.Cell):
    """Second leaf cell type; like ModA it owns a single "weight" Parameter."""

    def __init__(self, tensor):
        super().__init__()
        self.weight = Parameter(tensor, name="weight")

    def construct(self, *inputs):
        """Do nothing; the cell exists only to hold its parameter."""
        return None
41
42
class ModC(nn.Cell):
    """Container cell nesting a ModA (``mod1``) and a ModB (``mod2``)."""

    def __init__(self, ta, tb):
        super().__init__()
        # Attribute order determines child registration order.
        self.mod1 = ModA(ta)
        self.mod2 = ModB(tb)

    def construct(self, *inputs):
        """Do nothing; only the cell hierarchy matters for these tests."""
        return None
51
52
class Net(nn.Cell):
    """ Net definition """

    # Expected counts used by the tests:
    name_len = 4   # parameters: mod1, mod2, mod3.mod1, mod3.mod2
    cells_num = 3  # direct children: mod1, mod2, mod3

    def __init__(self, ta, tb):
        super().__init__()
        self.mod1 = ModA(ta)
        self.mod2 = ModB(tb)
        self.mod3 = ModC(ta, tb)

    def construct(self, *inputs):
        """Do nothing; only the cell hierarchy matters for these tests."""
        return None
66
67
class Net2(nn.Cell):
    """Same hierarchy as Net, but built with ``auto_prefix=False``.

    test_stop_update_name expects this cell's direct children to keep the
    bare parameter name "weight" rather than "mod1.weight" / "mod2.weight".
    """

    def __init__(self, ta, tb):
        super().__init__(auto_prefix=False)
        self.mod1 = ModA(ta)
        self.mod2 = ModB(tb)
        self.mod3 = ModC(ta, tb)

    def construct(self, *inputs):
        """Do nothing; only the cell hierarchy matters for these tests."""
        return None
77
78
class ConvNet(nn.Cell):
    """ ConvNet definition

    conv(7x7, stride 2) -> BN -> ReLU -> maxpool(3x3, stride 2) -> flatten -> dense.
    """
    image_h = 224
    image_w = 224
    output_ch = 64

    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(3, ConvNet.output_ch, kernel_size=7, stride=2, pad_mode="pad", padding=3)
        self.bn1 = nn.BatchNorm2d(ConvNet.output_ch)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
        self.flatten = nn.Flatten()
        # conv1 (k=7, s=2, p=3) maps a side s to floor((s - 1) / 2) + 1 == ceil(s / 2).
        # The "same"-padded stride-2 maxpool halves it again (rounding up).
        # So each spatial side ends at ceil(s / 4). Integer ceil-division avoids
        # the float round-trip and stays correct when a side is not a multiple of 4.
        feat_h = -(-ConvNet.image_h // 4)
        feat_w = -(-ConvNet.image_w // 4)
        self.fc = nn.Dense(feat_h * feat_w * ConvNet.output_ch, num_classes)

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x
104
105
def test_basic():
    """parameters_dict() keys carry the full attribute-path prefix."""
    net = Net(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))
    keys = list(net.parameters_dict())
    assert len(keys) == net.name_len
    assert keys == ["mod1.weight", "mod2.weight", "mod3.mod1.weight", "mod3.mod2.weight"]
116
117
def test_parameter_name():
    """ test_parameter_name """
    ta = Tensor(np.ones([2, 3]))
    tb = Tensor(np.ones([1, 4]))
    n = Net(ta, tb)
    # Keep only entries with a non-empty name, as the original loop did.
    names = [name for name, _ in n.parameters_and_names() if name]
    # Compare the whole list so unexpected extra parameters also fail the test.
    assert names == ["mod1.weight", "mod2.weight", "mod3.mod1.weight", "mod3.mod2.weight"]
131
132
def test_cell_name():
    """ test_cell_name """
    ta = Tensor(np.ones([2, 3]))
    tb = Tensor(np.ones([1, 4]))
    n = Net(ta, tb)
    # A child registered as None must not show up in the traversal.
    n.insert_child_to_cell('modNone', None)
    # Entries with an empty name (presumably the root cell itself) are skipped.
    names = [name for name, _ in n.cells_and_names() if name]
    # Exact comparison: catches a leaked 'modNone' or any other extra entry,
    # which index-only checks would miss.
    assert names == ["mod1", "mod2", "mod3", "mod3.mod1", "mod3.mod2"]
148
149
def test_cells():
    """cells() reports exactly Net's direct children."""
    net = Net(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))
    children = list(net.cells())
    assert len(children) == net.cells_num
156
157
def test_exceptions():
    """ test_exceptions """
    data = Tensor(np.ones([2, 3]))

    # Assigning a Parameter before Cell.__init__ has run must raise AttributeError.
    class ModError(nn.Cell):
        def __init__(self, tensor):
            self.weight = Parameter(tensor, name="weight")
            super(ModError, self).__init__()

        def construct(self, *inputs):
            pass

    with pytest.raises(AttributeError):
        ModError(data)

    # Rebinding a Parameter attribute to None and then to a Cell must raise TypeError.
    class ModError1(nn.Cell):
        def __init__(self, tensor):
            super().__init__()
            self.weight = Parameter(tensor, name="weight")
            self.weight = None
            self.weight = ModA(tensor)

        def construct(self, *inputs):
            pass

    with pytest.raises(TypeError):
        ModError1(data)

    # Rebinding a Cell attribute to None and then to a Tensor must raise TypeError.
    class ModError2(nn.Cell):
        def __init__(self, tensor):
            super().__init__()
            self.mod = ModA(tensor)
            self.mod = None
            self.mod = tensor

        def construct(self, *inputs):
            pass

    with pytest.raises(TypeError):
        ModError2(data)

    # A bare Cell's construct returns None.
    bare = nn.Cell()
    assert bare.construct() is None
201
202
def test_cell_copy():
    """Deep-copying a ConvNet succeeds and yields a distinct ConvNet object."""
    net = ConvNet()
    clone = copy.deepcopy(net)
    # Originally the test only checked that deepcopy does not raise; also
    # verify it produced a new object of the right type.
    assert clone is not net
    assert isinstance(clone, ConvNet)
206
207
def test_del():
    """ test_del """
    net = Net(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))
    total = net.name_len

    def param_count():
        # Number of entries currently in net.parameters_dict().
        return len(net.parameters_dict())

    assert param_count() == total

    # Deleting a child cell removes its parameter.
    del net.mod1
    assert param_count() == total - 1

    # The deleted child can no longer be reached.
    with pytest.raises(AttributeError):
        del net.mod1.weight

    # Deleting a parameter attribute on a remaining child removes it as well.
    del net.mod2.weight
    assert param_count() == total - 2

    # Deleting an attribute that was never set.
    with pytest.raises(AttributeError):
        del net.mod
225
226
def test_add_attr():
    """ test_add_attr """
    t_a = Tensor(np.ones([2, 3]))
    t_b = Tensor(np.ones([1, 4]))
    param = Parameter(t_a, name="weight")
    cell = nn.Cell()
    cell.insert_param_to_cell('weight', param)

    # A Parameter cannot be registered as a child cell.
    with pytest.raises(TypeError):
        cell.insert_child_to_cell("network", param)

    # Empty or dotted parameter names are rejected.
    for bad_name in ('', 'a.b'):
        with pytest.raises(KeyError):
            cell.insert_param_to_cell(bad_name, param)

    # Re-inserting under an existing, valid name is accepted.
    cell.insert_param_to_cell('weight', param)

    # Empty or dotted child names are rejected; the child is built inside the check.
    for bad_name, child_cls, arg in (('', ModA, t_a), ('a.b', ModB, t_b)):
        with pytest.raises(KeyError):
            cell.insert_child_to_cell(bad_name, child_cls(arg))

    # Wrong value types for each registration API.
    with pytest.raises(TypeError):
        cell.insert_child_to_cell('buffer', t_b)
    with pytest.raises(TypeError):
        cell.insert_param_to_cell('w', t_a)
    with pytest.raises(TypeError):
        cell.insert_child_to_cell('m', param)

    # Assigning a child Cell before Cell.__init__ has run must raise AttributeError.
    class ModAddCellError(nn.Cell):
        def __init__(self, tensor):
            self.mod = ModA(tensor)
            super().__init__()

        def construct(self, *inputs):
            pass

    with pytest.raises(AttributeError):
        ModAddCellError(t_a)
265
266
def test_train_eval():
    """set_train() toggles the training flag on and set_train(False) turns it off."""
    cell = nn.Cell()
    # A fresh cell is not in training mode.
    assert cell.training is not True and not cell.training
    cell.set_train()
    assert cell.training
    cell.set_train(False)
    assert not cell.training
274
275
def test_stop_update_name():
    """With auto_prefix=False, Net2 does not prefix its children's parameter names."""
    net = Net2(Tensor(np.ones([2, 3])), Tensor(np.ones([1, 4])))
    keys = list(net.parameters_dict())
    for position, expected in enumerate(["weight", "mod1.weight", "mod2.weight"]):
        assert keys[position] == expected
284
285
class ModelName(nn.Cell):
    """Cell whose parameter names are deliberately invalid for graph compilation.

    w2 and w1 share the explicit name "weight"; w3 and w4 are created with
    name=None. test_cell_names expects compiling this cell to raise ValueError.
    """

    def __init__(self, tensor):
        super().__init__()
        # Assignment order is kept as-is.
        self.w2 = Parameter(tensor, name="weight")
        self.w1 = Parameter(tensor, name="weight")
        self.w3 = Parameter(tensor, name=None)
        self.w4 = Parameter(tensor, name=None)

    def construct(self, *inputs):
        """Do nothing; only parameter registration matters here."""
        return None
296
297
def test_cell_names():
    """Graph-compiling ModelName (duplicate / unnamed parameters) raises ValueError."""
    model = ModelName(Tensor(np.ones([2, 3])))
    with pytest.raises(ValueError):
        _cell_graph_executor.compile(model)
303