• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15
16import pytest
17import mindspore as ms
18import mindspore.nn as nn
19from mindspore import Tensor
20import numpy as np
21from collections import OrderedDict
22
23
class TestGetitemMethodNet(nn.Cell):
    """Net whose construct fetches the 'conv' entry of a CellDict via __getitem__."""

    def __init__(self):
        super().__init__()
        layers = [
            ['conv', nn.Conv2d(6, 16, 5, pad_mode='valid')],
            ['relu', nn.ReLU()],
            ['max_pool2d', nn.MaxPool2d(kernel_size=4, stride=4)],
        ]
        self.cell_dict = nn.CellDict(layers)

    def construct(self):
        # Hand back the stored cell itself (not its output) so the caller can run it.
        return self.cell_dict['conv']
34
35
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_celldict_getitem_method(mode):
    """
    Feature: CellDict.__getitem__()
    Description: Verify the result of CellDict.__getitem__().
    Expectation: success
    """
    # Apply the parametrized execution mode; otherwise every parametrization
    # runs in the same, previously active mode.
    ms.set_context(mode=mode)
    net = TestGetitemMethodNet()
    x = Tensor(np.ones([1, 6, 16, 5]), ms.float32)
    conv2d_op = nn.Conv2d(6, 16, 5, pad_mode='valid')
    expect_output = conv2d_op(x)
    net_op = net()
    output = net_op(x)
    # Exact tuple comparison: np.allclose broadcasts, so e.g. (4, 4) vs (4,) would pass.
    assert output.shape == expect_output.shape
57
58
class TestSetitemMethodNet(nn.Cell):
    """Net whose construct overwrites the 'conv' entry of a CellDict and reads it back."""

    def __init__(self):
        super().__init__()
        layers = [
            ['conv', nn.Conv2d(10, 16, 5, pad_mode='valid')],
            ['relu', nn.ReLU()],
            ['max_pool2d', nn.MaxPool2d(kernel_size=4, stride=4)],
        ]
        self.cell_dict = nn.CellDict(layers)

    def construct(self):
        # Swap the 10-input-channel conv for a 6-input-channel one, then return the stored cell.
        replacement = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.cell_dict['conv'] = replacement
        return self.cell_dict['conv']
70
71
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE])
def test_celldict_setitem_method(mode):
    """
    Feature: CellDict.__setitem__()
    Description: Verify the result of CellDict.__setitem__().
    Expectation: success
    """
    # Apply the parametrized execution mode (it was previously ignored).
    ms.set_context(mode=mode)
    net = TestSetitemMethodNet()
    x = Tensor(np.ones([1, 6, 16, 5]), ms.float32)
    conv2d_op = nn.Conv2d(6, 16, 5, pad_mode='valid')
    expect_output = conv2d_op(x)
    net_op = net()
    output = net_op(x)
    # Exact tuple comparison instead of the broadcasting np.allclose.
    assert output.shape == expect_output.shape
93
94
class TestSetitemMethodErrCaseNet(nn.Cell):
    """Net that stores a caller-supplied cell under a caller-supplied key, then reads it back."""

    def __init__(self):
        super().__init__()
        layers = [
            ['conv', nn.Conv2d(10, 16, 5, pad_mode='valid')],
            ['relu', nn.ReLU()],
            ['max_pool2d', nn.MaxPool2d(kernel_size=4, stride=4)],
        ]
        self.cell_dict = nn.CellDict(layers)

    def construct(self, key, cell):
        # The accompanying test feeds invalid key/cell combinations and expects an exception here.
        self.cell_dict[key] = cell
        return self.cell_dict[key]
106
107
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE])
def test_celldict_setitem_error_case_method(mode):
    """
    Feature: CellDict.__setitem__()
    Description: Verify the result of CellDict.__setitem__() in error input.
    Expectation: raise TypeError for a non-string key (1) or a non-Cell value (None, 1),
        and KeyError for the keys "_scope", ".conv1d" and "".
    """
    # Apply the parametrized execution mode (it was previously ignored).
    ms.set_context(mode=mode)
    net = TestSetitemMethodErrCaseNet()

    # Invalid keys paired with a valid cell.
    invalid_key_cases = [
        (1, TypeError),
        ("_scope", KeyError),
        (".conv1d", KeyError),
        ("", KeyError),
    ]
    for key, expected_error in invalid_key_cases:
        # Build the cell outside pytest.raises so a failure while constructing it
        # cannot be mistaken for the expected error.
        cell = nn.Conv1d(120, 240, 4, has_bias=False, weight_init='normal')
        with pytest.raises(expected_error):
            net(key, cell)

    # A valid key paired with values that are not cells.
    for cell in (None, 1):
        with pytest.raises(TypeError):
            net("conv1d", cell)
153
class TestDelitemMethodNet(nn.Cell):
    """Net that deletes two entries from a three-entry CellDict and reports the remaining size."""

    def __init__(self):
        super().__init__()
        layers = [
            ['conv', nn.Conv2d(10, 16, 5, pad_mode='valid')],
            ['relu', nn.ReLU()],
            ['max_pool2d', nn.MaxPool2d(kernel_size=4, stride=4)],
        ]
        self.cell_dict = nn.CellDict(layers)

    def construct(self, key1, key2):
        # Remove both keys via __delitem__, then count what is left.
        del self.cell_dict[key1]
        del self.cell_dict[key2]
        remaining = len(self.cell_dict)
        return remaining
166
167
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE])
def test_celldict_delitem_method(mode):
    """
    Feature: CellDict.__delitem__()
    Description: Verify the result of CellDict.__delitem__().
    Expectation: success
    """
    # Apply the parametrized execution mode (it was previously ignored).
    ms.set_context(mode=mode)
    net = TestDelitemMethodNet()
    expect_output = 1
    output = net('conv', 'relu')
    # The result is an integer count; compare it exactly.
    assert output == expect_output
186
187
class TestContainsMethodNet(nn.Cell):
    """Net that reports whether each of two keys is present in a CellDict."""

    def __init__(self):
        super().__init__()
        layers = [
            ['conv', nn.Conv2d(10, 16, 5, pad_mode='valid')],
            ['relu', nn.ReLU()],
            ['max_pool2d', nn.MaxPool2d(kernel_size=4, stride=4)],
        ]
        self.cell_dict = nn.CellDict(layers)

    def construct(self, key1, key2):
        # Membership tests go through CellDict.__contains__.
        has_first = key1 in self.cell_dict
        has_second = key2 in self.cell_dict
        return has_first, has_second
200
201
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_celldict_contains_method(mode):
    """
    Feature: CellDict.__contains__()
    Description: Verify the result of CellDict.__contains__().
    Expectation: success
    """
    # Apply the parametrized execution mode; otherwise the GRAPH_MODE and
    # PYNATIVE_MODE cases run identically.
    ms.set_context(mode=mode)
    net = TestContainsMethodNet()
    expect_output1 = True
    expect_output2 = False
    output1, output2 = net('conv', 'relu1')
    assert expect_output1 == output1
    assert expect_output2 == output2
222
223
class TestClearMethodNet(nn.Cell):
    """Net that empties its CellDict with clear() and reports the resulting size."""

    def __init__(self):
        super().__init__()
        layers = [
            ['conv', nn.Conv2d(10, 16, 5, pad_mode='valid')],
            ['relu', nn.ReLU()],
            ['max_pool2d', nn.MaxPool2d(kernel_size=4, stride=4)],
        ]
        self.cell_dict = nn.CellDict(layers)

    def construct(self):
        self.cell_dict.clear()
        remaining = len(self.cell_dict)
        return remaining
235
236
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE])
def test_celldict_clear_method(mode):
    """
    Feature: CellDict.clear()
    Description: Verify the result of CellDict.clear().
    Expectation: success
    """
    # Apply the parametrized execution mode (it was previously ignored).
    ms.set_context(mode=mode)
    net = TestClearMethodNet()
    expect_output = 0
    output = net()
    # The result is an integer count; compare it exactly.
    assert output == expect_output
255
256
class TestPopMethodNet(nn.Cell):
    """Net that pops one entry from its CellDict and returns it with the remaining size."""

    def __init__(self):
        super().__init__()
        layers = [
            ['conv', nn.Conv2d(10, 16, 5, pad_mode='valid')],
            ['relu', nn.ReLU()],
            ['max_pool2d', nn.MaxPool2d(kernel_size=4, stride=4)],
        ]
        self.cell_dict = nn.CellDict(layers)

    def construct(self, key):
        # The size is read after pop(), so it reflects the removal.
        popped = self.cell_dict.pop(key)
        remaining = len(self.cell_dict)
        return popped, remaining
269
270
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE])
def test_celldict_pop_method(mode):
    """
    Feature: CellDict.pop()
    Description: Verify the result of CellDict.pop().
    Expectation: success
    """
    # Apply the parametrized execution mode (it was previously ignored).
    ms.set_context(mode=mode)
    net = TestPopMethodNet()
    conv_op = nn.Conv2d(10, 16, 5, pad_mode='valid')
    x = Tensor(np.ones([1, 10, 6, 5]), ms.float32)
    expect_output = conv_op(x)
    expect_len = 2
    op, cell_dict_len = net('conv')
    output = op(x)
    # Exact comparisons instead of the broadcasting np.allclose.
    assert output.shape == expect_output.shape
    assert cell_dict_len == expect_len
294
295
class TestKeysMethodNet(nn.Cell):
    """Net whose construct returns the keys of a three-entry CellDict."""

    def __init__(self):
        super().__init__()
        layers = [
            ['conv', nn.Conv2d(10, 16, 5, pad_mode='valid')],
            ['relu', nn.ReLU()],
            ['max_pool2d', nn.MaxPool2d(kernel_size=4, stride=4)],
        ]
        self.cell_dict = nn.CellDict(layers)

    def construct(self):
        # Expose CellDict.keys() directly to the caller.
        return self.cell_dict.keys()
306
307
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_celldict_keys_method(mode):
    """
    Feature: CellDict.keys()
    Description: Verify the result of CellDict.keys().
    Expectation: success
    """
    # Apply the parametrized execution mode; otherwise the GRAPH_MODE and
    # PYNATIVE_MODE cases run identically.
    ms.set_context(mode=mode)
    net = TestKeysMethodNet()
    expect_keys = ['conv', 'relu', 'max_pool2d']
    cell_dict_keys = net()
    # Compare the whole sequence: a zip-based loop silently passes when fewer
    # (or no) keys are returned.
    assert list(cell_dict_keys) == expect_keys
327
328
class TestValuesMethodNet(nn.Cell):
    """Net whose construct returns the cells (values) of a three-entry CellDict."""

    def __init__(self):
        super().__init__()
        layers = [
            ['conv', nn.Conv2d(10, 16, 5, pad_mode='valid')],
            ['relu', nn.ReLU()],
            ['max_pool2d', nn.MaxPool2d(kernel_size=4, stride=4)],
        ]
        self.cell_dict = nn.CellDict(layers)

    def construct(self):
        # Expose CellDict.values() directly to the caller.
        return self.cell_dict.values()
339
340
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_celldict_values_method(mode):
    """
    Feature: CellDict.values()
    Description: Verify the result of CellDict.values().
    Expectation: success
    """
    # Apply the parametrized execution mode; otherwise the GRAPH_MODE and
    # PYNATIVE_MODE cases run identically.
    ms.set_context(mode=mode)
    net = TestValuesMethodNet()
    conv2d_op = nn.Conv2d(10, 16, 5, pad_mode='valid')
    relu_op = nn.ReLU()
    maxpool2d_op = nn.MaxPool2d(kernel_size=4, stride=4)
    x = Tensor(np.ones([1, 10, 16, 10]), ms.float32)
    expect_x = conv2d_op(x)
    expect_x = relu_op(expect_x)
    expect_x = maxpool2d_op(expect_x)

    # Materialize once so the count check and the loop see the same cells.
    cell_dict_values = list(net())
    assert len(cell_dict_values) == 3
    for cell in cell_dict_values:
        x = cell(x)

    # Exact tuple comparison instead of the broadcasting np.allclose.
    assert x.shape == expect_x.shape
369
370
class TestItemsMethodNet(nn.Cell):
    """Net whose construct returns the (key, cell) items of a three-entry CellDict."""

    def __init__(self):
        super().__init__()
        layers = [
            ['conv', nn.Conv2d(10, 16, 5, pad_mode='valid')],
            ['relu', nn.ReLU()],
            ['max_pool2d', nn.MaxPool2d(kernel_size=4, stride=4)],
        ]
        self.cell_dict = nn.CellDict(layers)

    def construct(self):
        # Expose CellDict.items() directly to the caller.
        return self.cell_dict.items()
381
382
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_celldict_items_method(mode):
    """
    Feature: CellDict.items()
    Description: Verify the result of CellDict.items().
    Expectation: success
    """
    # Apply the parametrized execution mode; otherwise the GRAPH_MODE and
    # PYNATIVE_MODE cases run identically.
    ms.set_context(mode=mode)
    net = TestItemsMethodNet()
    expect_keys = ['conv', 'relu', 'max_pool2d']
    # Materialize once: the items are iterated twice below, and a one-shot
    # iterator would make the second pass silently empty.
    cell_dict_items = list(net())
    # Compare the whole key sequence; a zip-based loop passes vacuously on empty output.
    assert [item[0] for item in cell_dict_items] == expect_keys

    conv2d_op = nn.Conv2d(10, 16, 5, pad_mode='valid')
    relu_op = nn.ReLU()
    maxpool2d_op = nn.MaxPool2d(kernel_size=4, stride=4)
    x = Tensor(np.ones([1, 10, 16, 10]), ms.float32)
    expect_x = conv2d_op(x)
    expect_x = relu_op(expect_x)
    expect_x = maxpool2d_op(expect_x)
    for item in cell_dict_items:
        x = item[1](x)
    # Exact tuple comparison instead of the broadcasting np.allclose.
    assert x.shape == expect_x.shape
413
414
class TestUpdateMethodNet(nn.Cell):
    """Net that refills its CellDict via update() from a list, an OrderedDict and a CellDict.

    construct() returns the outputs obtained after each refill, in that order.
    """

    def __init__(self):
        super().__init__()
        layers = [
            ['conv', nn.Conv2d(10, 16, 5, pad_mode='same')],
            ['relu', nn.ReLU()],
            ['max_pool2d', nn.MaxPool2d(kernel_size=4, stride=4)],
        ]
        self.cell_dict = nn.CellDict(layers)

    def _run_in_order(self, inputs):
        # Feed `inputs` through every cell currently stored, in iteration order of values().
        result = inputs
        for cell in self.cell_dict.values():
            result = cell(result)
        return result

    def construct(self):
        dense_in = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), ms.float32)
        image_in = Tensor(np.ones([1, 10, 6, 10]), ms.float32)

        # Update the CellDict from a list of [key, cell] pairs.
        self.cell_dict.clear()
        self.cell_dict.update([['dense1', nn.Dense(3, 4)],
                               ['dense2', nn.Dense(4, 6)],
                               ['dense3', nn.Dense(6, 8)]])
        output1 = self._run_in_order(dense_in)

        # Update the CellDict from an OrderedDict.
        self.cell_dict.clear()
        self.cell_dict.update(OrderedDict([('conv', nn.Conv2d(10, 6, 5, pad_mode='same')),
                                           ('relu', nn.ReLU()),
                                           ('max_pool2d', nn.MaxPool2d(kernel_size=4, stride=4))]))
        output2 = self._run_in_order(image_in)

        # Update the CellDict from another CellDict.
        self.cell_dict.clear()
        self.cell_dict.update(nn.CellDict([['conv', nn.Conv2d(10, 6, 5, pad_mode='same')],
                                           ['relu', nn.ReLU()],
                                           ['max_pool2d', nn.MaxPool2d(kernel_size=4, stride=4)]]))
        output3 = self._run_in_order(image_in)

        return output1, output2, output3
458
459
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE])
def test_celldict_update_method(mode):
    """
    Feature: CellDict.update()
    Description: Verify the result of CellDict.update().
    Expectation: success
    """
    # Apply the parametrized execution mode (it was previously ignored).
    ms.set_context(mode=mode)
    net = TestUpdateMethodNet()
    x = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), ms.float32)
    y = Tensor(np.ones([1, 10, 6, 10]), ms.float32)

    dense_op1 = nn.Dense(3, 4)
    dense_op2 = nn.Dense(4, 6)
    dense_op3 = nn.Dense(6, 8)
    expect_dense_output = x
    expect_dense_output = dense_op1(expect_dense_output)
    expect_dense_output = dense_op2(expect_dense_output)
    expect_dense_output = dense_op3(expect_dense_output)

    conv2d_op = nn.Conv2d(10, 6, 5, pad_mode='same')
    relu_op = nn.ReLU()
    maxpool2d_op = nn.MaxPool2d(kernel_size=4, stride=4)
    expect_output = y
    expect_output = conv2d_op(expect_output)
    expect_output = relu_op(expect_output)
    expect_output = maxpool2d_op(expect_output)

    output1, output2, output3 = net()
    # Exact tuple comparisons instead of the broadcasting np.allclose.
    assert output1.shape == expect_dense_output.shape
    assert output2.shape == expect_output.shape
    assert output3.shape == expect_output.shape
498
499
class TestUpdateMethodEmbeddedNet(nn.Cell):
    """Net whose construct passes an arbitrary object straight to CellDict.update()."""

    def __init__(self):
        super().__init__()
        layers = [
            ['conv', nn.Conv2d(10, 16, 5, pad_mode='same')],
            ['relu', nn.ReLU()],
            ['max_pool2d', nn.MaxPool2d(kernel_size=4, stride=4)],
        ]
        self.cell_dict = nn.CellDict(layers)

    def construct(self, object_list):
        # Nothing is returned; the caller only checks whether update() raises.
        self.cell_dict.update(object_list)
509
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE])
def test_celldict_update_method_embedded_case(mode):
    """
    Feature: CellDict.update()
    Description: Verify the result of CellDict.update() in embedded_case.
    Expectation: raise TypeError when the update source embeds a CellDict,
        a CellList or a SequentialCell as a value.
    """
    # Apply the parametrized execution mode (it was previously ignored).
    ms.set_context(mode=mode)
    net = TestUpdateMethodEmbeddedNet()
    cell_dict = nn.CellDict({'conv': nn.Conv2d(1, 1, 3), 'Dense': nn.Dense(2, 2)})
    cell_list = nn.CellList([nn.Dense(2, 2)])
    conv = nn.Conv2d(3, 2, 3, pad_mode='valid', weight_init="ones")
    relu = nn.ReLU()
    seq_cell = nn.SequentialCell([conv, relu])

    celldict_embedded_list = [['cell_dict', cell_dict]]
    celllist_embedded_list = [['cell_list', cell_list]]
    seqcell_embedded_list = [['seq_cell', seq_cell]]

    for embedded_list in (celldict_embedded_list, celllist_embedded_list, seqcell_embedded_list):
        with pytest.raises(TypeError):
            net(embedded_list)
543
class DupParaNameNet1(nn.Cell):
    """Net holding two identically keyed CellDicts ('conv2d' and 'pool2d')."""

    def __init__(self):
        super().__init__()
        self.cell_dict1 = self._new_cell_dict()
        self.cell_dict2 = self._new_cell_dict()

    @staticmethod
    def _new_cell_dict():
        # Each call builds fresh cells, so the two dicts share keys but not parameters.
        return nn.CellDict({'conv2d': nn.Conv2d(20, 20, 5),
                            'pool2d': nn.MaxPool2d(7)})

    def construct(self, x1, x2):
        # Run each input through its own dict's 'conv2d' cell and sum the results.
        out1 = self.cell_dict1['conv2d'](x1)
        out2 = self.cell_dict2['conv2d'](x2)
        return out1 + out2
558
559
class DupParaNameNet2(nn.Cell):
    """Net holding two CellDicts that each map 'dense' to a Dense(3, 4) cell."""

    def __init__(self):
        super().__init__()
        self.cell_dict1 = self._new_cell_dict()
        self.cell_dict2 = self._new_cell_dict()

    @staticmethod
    def _new_cell_dict():
        # Each call builds a fresh Dense cell, so the two dicts share keys but not parameters.
        return nn.CellDict({'dense': nn.Dense(3, 4)})

    def construct(self, x1, x2):
        # Run each input through its own dict's 'dense' cell and sum the results.
        out1 = self.cell_dict1['dense'](x1)
        out2 = self.cell_dict2['dense'](x2)
        return out1 + out2
570
571
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE])
def test_duplicate_para_name_case(mode):
    """
    Feature: Verify the same parameter names of two CellDicts within the same net can be distinguished.
    Description: Within a net, construct two identical CellDicts, run the net, and check that
        all trainable parameter names are unique.
    Expectation: success
    """
    # Apply the parametrized execution mode (it was previously ignored).
    ms.set_context(mode=mode)

    def assert_unique_param_names(network):
        # The feature under test: identically keyed cells in sibling CellDicts
        # must not end up with colliding parameter names.
        names = [param.name for param in network.trainable_params()]
        assert len(names) == len(set(names)), names

    net = DupParaNameNet1()
    x1 = Tensor(np.ones([1, 20, 20, 10]), ms.float32)
    x2 = Tensor(np.ones([1, 20, 20, 1]), ms.float32)
    output = net(x1, x2)
    expect_output_shape = (1, 20, 20, 10)
    # Exact tuple comparison instead of the broadcasting np.allclose.
    assert output.shape == expect_output_shape
    assert_unique_param_names(net)

    net = DupParaNameNet2()
    x1 = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), ms.float32)
    x2 = Tensor(np.array([[110, 134, 150], [224, 148, 347]]), ms.float32)
    output = net(x1, x2)
    expect_output_shape = (2, 4)
    assert output.shape == expect_output_shape
    assert_unique_param_names(net)
599