• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15
16import numpy as np
17import pytest
18
19import mindspore.context as context
20import mindspore.nn as nn
21import mindspore.ops.operations.array_ops as P
22from mindspore import Tensor
23from mindspore.common.api import ms_function
24from mindspore.common.initializer import initializer
25from mindspore.common.parameter import Parameter
26
27
class Net(nn.Cell):
    """Cell that applies ``Unstack(axis=3)`` to a fixed 5-D parameter.

    The parameter ``x1`` has shape (3, 3, 2, 2, 2). Its index-0 slice along
    axis 3 is all zeros. Its index-1 slice holds the integers -2..33 in
    C order, i.e. ``arange(-2, 34).reshape(3, 3, 2, 2)``.

    Args:
        nptype: NumPy dtype the test data is cast to before being wrapped
            in a ``Parameter``.
    """

    def __init__(self, nptype):
        super().__init__()

        self.unstack = P.Unstack(axis=3)

        # Build the input without a hand-written literal: start from zeros,
        # then fill the index-1 slice of axis 3 with -2..33.
        grid = np.zeros((3, 3, 2, 2, 2), dtype=int)
        grid[:, :, :, 1, :] = np.arange(-2, 34).reshape(3, 3, 2, 2)
        self.data_np = grid.astype(nptype)

        initial = initializer(Tensor(self.data_np), [3, 3, 2, 2, 2])
        self.x1 = Parameter(initial, name='x1')

    @ms_function
    def construct(self):
        """Return the tuple of slices obtained by unstacking ``x1`` on axis 3."""
        return self.unstack(self.x1)
74
75
def unpack(nptype):
    """Check ``Unstack(axis=3)`` in graph mode on CPU for one dtype.

    Runs ``Net(nptype)`` and compares each returned slice against the
    expected values. The index-0 slice of axis 3 is all zeros. The index-1
    slice is ``arange(-2, 34)`` reshaped to (3, 3, 2, 2). Both are cast to
    ``nptype``.

    Args:
        nptype: NumPy dtype used for the input data and the expected outputs.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    unpack_ = Net(nptype)
    output = unpack_()
    expect = (np.zeros((3, 3, 2, 2), dtype=nptype),
              np.arange(-2, 34, 1).reshape(3, 3, 2, 2).astype(nptype))

    # Unstacking a size-2 axis must yield exactly two tensors. Iterating over
    # `expect` alone would silently ignore any extra outputs.
    assert len(output) == len(expect)
    for i, exp in enumerate(expect):
        # assert_array_equal also rejects shape mismatches, which
        # `(a == b).all()` would hide through broadcasting.
        np.testing.assert_array_equal(output[i].asnumpy(), exp)
85
86
def unpack_pynative(nptype):
    """Check ``Unstack(axis=3)`` in PyNative mode on CPU for one dtype.

    The input has shape (3, 3, 2, 2, 2). Its index-0 slice along axis 3 is
    all zeros, and its index-1 slice holds -2..33 in C order. The two
    returned slices are compared exactly against those values, cast to
    ``nptype``.

    Args:
        nptype: NumPy dtype used for the input data and the expected outputs.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
    # Zeros everywhere except the index-1 slice of axis 3, which holds
    # -2..33 (the same values the original hand-written literal spelled out).
    data = np.zeros((3, 3, 2, 2, 2), dtype=int)
    data[:, :, :, 1, :] = np.arange(-2, 34).reshape(3, 3, 2, 2)
    x1 = Tensor(data.astype(nptype))
    expect = (np.zeros((3, 3, 2, 2), dtype=nptype),
              np.arange(-2, 34, 1).reshape(3, 3, 2, 2).astype(nptype))
    output = P.Unstack(axis=3)(x1)

    # Unstacking a size-2 axis must yield exactly two tensors. Iterating over
    # `expect` alone would silently ignore any extra outputs.
    assert len(output) == len(expect)
    for i, exp in enumerate(expect):
        # assert_array_equal also rejects shape mismatches, which
        # `(a == b).all()` would hide through broadcasting.
        np.testing.assert_array_equal(output[i].asnumpy(), exp)
132
133
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_graph_float32():
    """Graph-mode Unstack(axis=3) on CPU with float32 data."""
    dtype = np.float32
    unpack(dtype)
139
140
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_graph_float16():
    """Graph-mode Unstack(axis=3) on CPU with float16 data."""
    dtype = np.float16
    unpack(dtype)
146
147
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_graph_int32():
    """Graph-mode Unstack(axis=3) on CPU with int32 data."""
    dtype = np.int32
    unpack(dtype)
153
154
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_graph_int16():
    """Graph-mode Unstack(axis=3) on CPU with int16 data."""
    dtype = np.int16
    unpack(dtype)
160
161
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_graph_uint8():
    """Graph-mode Unstack(axis=3) on CPU with uint8 data."""
    # The negative inputs (-2, -1) wrap on the cast to uint8; the expected
    # values in `unpack` are cast the same way, so they still match.
    dtype = np.uint8
    unpack(dtype)
167
168
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_graph_bool():
    """Graph-mode Unstack(axis=3) on CPU with boolean data."""
    # `np.bool` was a deprecated alias of the builtin `bool` (NumPy 1.20) and
    # was removed in NumPy 1.24, where it raises AttributeError. `np.bool_`
    # is the NumPy boolean scalar type on every release.
    unpack(np.bool_)
174
175
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_pynative_float32():
    """PyNative-mode Unstack(axis=3) on CPU with float32 data."""
    dtype = np.float32
    unpack_pynative(dtype)
181
182
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_pynative_float16():
    """PyNative-mode Unstack(axis=3) on CPU with float16 data."""
    dtype = np.float16
    unpack_pynative(dtype)
188
189
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_pynative_int32():
    """PyNative-mode Unstack(axis=3) on CPU with int32 data."""
    dtype = np.int32
    unpack_pynative(dtype)
195
196
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_pynative_int16():
    """PyNative-mode Unstack(axis=3) on CPU with int16 data."""
    dtype = np.int16
    unpack_pynative(dtype)
202
203
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_pynative_uint8():
    """PyNative-mode Unstack(axis=3) on CPU with uint8 data."""
    # The negative inputs (-2, -1) wrap on the cast to uint8; the expected
    # values in `unpack_pynative` are cast the same way, so they still match.
    dtype = np.uint8
    unpack_pynative(dtype)
209
210
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_pynative_bool():
    """PyNative-mode Unstack(axis=3) on CPU with boolean data."""
    # `np.bool` was a deprecated alias of the builtin `bool` (NumPy 1.20) and
    # was removed in NumPy 1.24, where it raises AttributeError. `np.bool_`
    # is the NumPy boolean scalar type on every release.
    unpack_pynative(np.bool_)
216