• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16import numpy as np
17import pytest
18
19import mindspore.context as context
20import mindspore.nn as nn
21from mindspore import Tensor
22from mindspore.ops import operations as P
23
# Module-level configuration: run every test in this file in graph mode on the CPU backend.
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
25
26
class OpNetWrapper(nn.Cell):
    """Thin Cell that stores a single operator and forwards its inputs to it."""

    def __init__(self, op):
        """Keep ``op`` on the instance so ``construct`` can call it."""
        super().__init__()
        self.op = op

    def construct(self, *inputs):
        """Apply the stored operator to ``inputs`` and return whatever it returns."""
        return self.op(*inputs)
34
35
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_out1_axis0():
    """P.Split(0, 1) on a (2, 2, 6) int32 tensor: the single output keeps the input shape."""
    tensor = Tensor(np.arange(24).astype(np.int32).reshape((2, 2, 6)))
    net = OpNetWrapper(P.Split(0, 1))
    result = net(tensor)

    print(result)
    only_piece = result[0]
    assert only_piece.shape == (2, 2, 6)
    first_row = only_piece.asnumpy()[0, 0, :]
    assert np.allclose(first_row, [0, 1, 2, 3, 4, 5])
49
50
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_out2_axis2():
    """P.Split(2, 2) on a (2, 2, 6) int32 tensor: two (2, 2, 3) outputs with the expected rows."""
    tensor = Tensor(np.arange(24).astype(np.int32).reshape((2, 2, 6)))
    net = OpNetWrapper(P.Split(2, 2))
    result = net(tensor)

    print(result)
    # Shapes are checked for both outputs before any values, as separate passes.
    for idx in range(2):
        assert result[idx].shape == (2, 2, 3)
    for idx, expected in enumerate(([0, 1, 2], [3, 4, 5])):
        assert np.allclose(result[idx].asnumpy()[0, 0, :], expected)
66
67
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_out2_axis1neg():
    """P.Split(-1, 2) on a (2, 2, 6) float32 tensor: check the first slab of both outputs."""
    tensor = Tensor(np.arange(24).astype(np.float32).reshape((2, 2, 6)))
    net = OpNetWrapper(P.Split(-1, 2))
    result = net(tensor)

    print(result)
    expected_slabs = (
        [[0., 1., 2.], [6., 7., 8.]],
        [[3., 4., 5.], [9., 10., 11.]],
    )
    for idx, expected in enumerate(expected_slabs):
        assert np.allclose(result[idx].asnumpy()[0, :, :], expected)
81
82
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_out_float32():
    """P.Split on axis 5 of a 6-D float32 tensor, into 2 and then 3 outputs."""
    tensor = Tensor(np.arange(192).astype(np.float32).reshape((2, 2, 2, 2, 2, 6)))

    halves = OpNetWrapper(P.Split(5, 2))(tensor)
    for idx, expected in enumerate(([0., 1., 2.], [3., 4., 5.])):
        assert np.allclose(halves[idx].asnumpy()[0, 0, 0, 0, 0, :], expected)

    thirds = OpNetWrapper(P.Split(5, 3))(tensor)
    for idx, expected in enumerate(([0., 1.], [2., 3.], [4., 5.])):
        assert np.allclose(thirds[idx].asnumpy()[0, 0, 0, 0, 0, :], expected)
103
104
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_out_float64():
    """P.Split on axis 5 of a 6-D float64 tensor, into 2 and then 3 outputs."""
    tensor = Tensor(np.arange(192).astype(np.float64).reshape((2, 2, 2, 2, 2, 6)))

    halves = OpNetWrapper(P.Split(5, 2))(tensor)
    for idx, expected in enumerate(([0., 1., 2.], [3., 4., 5.])):
        assert np.allclose(halves[idx].asnumpy()[0, 0, 0, 0, 0, :], expected)

    thirds = OpNetWrapper(P.Split(5, 3))(tensor)
    for idx, expected in enumerate(([0., 1.], [2., 3.], [4., 5.])):
        assert np.allclose(thirds[idx].asnumpy()[0, 0, 0, 0, 0, :], expected)
125
126
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_out_float16():
    """P.Split on axis -1 of a 6-D float16 tensor, into 2 and then 5 outputs."""
    tensor = Tensor(np.arange(320).astype(np.float16).reshape((2, 2, 2, 2, 2, 10)))

    halves = OpNetWrapper(P.Split(-1, 2))(tensor)
    expected_halves = ([0., 1., 2., 3., 4.], [5., 6., 7., 8., 9.])
    for idx, expected in enumerate(expected_halves):
        assert np.allclose(halves[idx].asnumpy()[0, 0, 0, 0, 0, :], expected)

    fifths = OpNetWrapper(P.Split(-1, 5))(tensor)
    expected_fifths = ([0., 1.], [2., 3.], [4., 5.], [6., 7.], [8., 9.])
    for idx, expected in enumerate(expected_fifths):
        assert np.allclose(fifths[idx].asnumpy()[0, 0, 0, 0, 0, :], expected)
149
150
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_out_int32():
    """P.Split on axis 5 of a 6-D int32 tensor, into 2 and then 3 outputs."""
    tensor = Tensor(np.arange(192).astype(np.int32).reshape((2, 2, 2, 2, 2, 6)))

    halves = OpNetWrapper(P.Split(5, 2))(tensor)
    for idx, expected in enumerate(([0, 1, 2], [3, 4, 5])):
        assert np.allclose(halves[idx].asnumpy()[0, 0, 0, 0, 0, :], expected)

    # The three-way split is inspected at leading index 1 rather than 0.
    thirds = OpNetWrapper(P.Split(5, 3))(tensor)
    for idx, expected in enumerate(([96, 97], [98, 99], [100, 101])):
        assert np.allclose(thirds[idx].asnumpy()[1, 0, 0, 0, 0, :], expected)
171
172
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_out_int64():
    """P.Split on axis 5 of a 6-D int64 tensor, into 2 and then 3 outputs."""
    tensor = Tensor(np.arange(192).astype(np.int64).reshape((2, 2, 2, 2, 2, 6)))

    halves = OpNetWrapper(P.Split(5, 2))(tensor)
    for idx, expected in enumerate(([0, 1, 2], [3, 4, 5])):
        assert np.allclose(halves[idx].asnumpy()[0, 0, 0, 0, 0, :], expected)

    # The three-way split is inspected at leading index 1 rather than 0.
    thirds = OpNetWrapper(P.Split(5, 3))(tensor)
    for idx, expected in enumerate(([96, 97], [98, 99], [100, 101])):
        assert np.allclose(thirds[idx].asnumpy()[1, 0, 0, 0, 0, :], expected)
193
194
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_out_uint32():
    """P.Split on uint32 input: axis -1 (2 and 5 outputs), axis -2 (2 outputs), and a 1-element 1-D tensor."""
    tensor = Tensor(np.arange(320).astype(np.uint32).reshape((2, 2, 2, 2, 2, 10)))

    halves = OpNetWrapper(P.Split(-1, 2))(tensor)
    for idx, expected in enumerate(([0, 1, 2, 3, 4], [5, 6, 7, 8, 9])):
        assert np.allclose(halves[idx].asnumpy()[0, 0, 0, 0, 0, :], expected)

    # The five-way split is inspected at the very last row of the tensor.
    fifths = OpNetWrapper(P.Split(-1, 5))(tensor)
    expected_fifths = ([310, 311], [312, 313], [314, 315], [316, 317], [318, 319])
    for idx, expected in enumerate(expected_fifths):
        assert np.allclose(fifths[idx].asnumpy()[1, 1, 1, 1, 1, :], expected)

    pairs = OpNetWrapper(P.Split(-2, 2))(tensor)
    # Each entry: (output index, first four indices, last-axis column, expected value).
    spot_checks = (
        (0, (0, 0, 0, 0), 0, 0),
        (1, (0, 0, 0, 0), 1, 11),
        (0, (1, 0, 0, 0), 2, 162),
        (1, (1, 0, 0, 0), 3, 173),
        (0, (1, 1, 0, 0), 4, 244),
        (1, (1, 1, 0, 0), 5, 255),
        (0, (1, 1, 1, 0), 6, 286),
        (1, (1, 1, 1, 0), 7, 297),
        (0, (1, 1, 1, 1), 8, 308),
        (1, (1, 1, 1, 1), 9, 319),
    )
    for out_idx, lead, column, value in spot_checks:
        # Equivalent to indexing with [a, b, c, d, :, column].
        picked = pairs[out_idx].asnumpy()[lead + (slice(None), column)]
        assert np.allclose(picked, [value])

    scalar_like = Tensor(np.arange(1).astype(np.uint32))
    single = OpNetWrapper(P.Split(-1, 1))(scalar_like)
    assert np.allclose(single[0].asnumpy(), [0])
239
240
if __name__ == '__main__':
    # Run every test defined in this file when it is executed directly as a
    # script, so the per-dtype cases are exercised as well as the basic ones.
    test_out1_axis0()
    test_out2_axis2()
    test_out2_axis1neg()
    test_out_float32()
    test_out_float64()
    test_out_float16()
    test_out_int32()
    test_out_int64()
    test_out_uint32()
245