• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15
16import pytest
17import numpy as np
18from mindspore import Tensor
19from mindspore.ops import operations as P
20import mindspore.nn as nn
21import mindspore.ops as ops
22import mindspore.context as context
23
# Configure MindSpore once, at import time, for every test in this module:
# graph mode (GRAPH_MODE) on the CPU device target.
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
25
class ConcatV10(nn.Cell):
    """Cell applying ``Concat(axis=2)`` to a single constant 3-D tensor.

    Args:
        nptype: NumPy dtype the constant (shape (2, 2, 3)) is cast to.
    """

    def __init__(self, nptype):
        super().__init__()
        values = np.array([[[0., 0., 1.],
                            [1., 2., 3.]],
                           [[2., 4., 5.],
                            [3., 6., 7.]]])
        self.cat = P.Concat(axis=2)
        self.x1 = Tensor(values.astype(nptype))

    def construct(self):
        # A one-element tuple: Concat over a single input.
        single = (self.x1,)
        return self.cat(single)
38
39
def axis10(nptype):
    """Run ConcatV10 and check its output.

    The expected result equals the cell's constant input, since a single
    tensor is concatenated with nothing else.

    Args:
        nptype: NumPy dtype used both for the cell's input and for the
            expected result.
    """
    cat = ConcatV10(nptype)
    output = cat()
    expect = np.array([[[0., 0., 1.],
                        [1., 2., 3.]],
                       [[2., 4., 5.],
                        [3., 6., 7.]]]).astype(nptype)
    print(output)
    # assert_array_equal rejects a shape mismatch; `(a == b).all()` would
    # broadcast and could let a wrongly shaped output pass.
    np.testing.assert_array_equal(output.asnumpy(), expect)
49
50
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_axis10_float32():
    """Single-tensor Concat on axis 2 with float32 data."""
    axis10(nptype=np.float32)
56
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_axis10_int32():
    """Single-tensor Concat on axis 2 with int32 data."""
    axis10(nptype=np.int32)
62
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_axis10_bool():
    """Single-tensor Concat on axis 2 with boolean data.

    Uses ``np.bool_``. The ``np.bool`` alias was deprecated in NumPy 1.20 and
    removed in 1.24, where accessing it raises AttributeError.
    """
    axis10(np.bool_)
68
69
class ConcatV32(nn.Cell):
    """Cell concatenating two constant 3-D tensors along axis 2.

    The inputs are ``arange`` data of shapes (2, 2, 1) and (2, 2, 2).

    Args:
        nptype: NumPy dtype both constants are cast to.
    """

    def __init__(self, nptype):
        super().__init__()
        self.cat = P.Concat(axis=2)
        self.x1 = Tensor(np.arange(4).reshape((2, 2, 1)).astype(nptype))
        self.x2 = Tensor(np.arange(8).reshape((2, 2, 2)).astype(nptype))

    def construct(self):
        pair = (self.x1, self.x2)
        return self.cat(pair)
80
81
def axis32(nptype):
    """Run ConcatV32 and compare with the hand-computed axis-2 concatenation.

    Args:
        nptype: NumPy dtype used both for the cell's inputs and for the
            expected result.
    """
    cat = ConcatV32(nptype)
    output = cat()
    expect = np.array([[[0., 0., 1.],
                        [1., 2., 3.]],
                       [[2., 4., 5.],
                        [3., 6., 7.]]]).astype(nptype)
    print(output)
    # Shape-strict comparison; a bare `==` would broadcast.
    np.testing.assert_array_equal(output.asnumpy(), expect)
91
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_axis32_float32():
    """Two-tensor Concat on axis 2 with float32 data."""
    axis32(nptype=np.float32)
97
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_axis32_int32():
    """Two-tensor Concat on axis 2 with int32 data."""
    axis32(nptype=np.int32)
103
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_axis32_bool():
    """Two-tensor Concat on axis 2 with boolean data.

    Uses ``np.bool_``. The ``np.bool`` alias was removed in NumPy 1.24.
    """
    axis32(np.bool_)
109
110
class ConcatWithList(nn.Cell):
    """Cell concatenating its two inputs along axis 2, passed as a ``list``."""

    def __init__(self):
        super().__init__()
        self.concat = P.Concat(axis=2)

    def construct(self, x, y):
        # Deliberately a list, not a tuple: the list form is what is exercised.
        return self.concat([x, y])
119
120
class ConcatWithTuple(nn.Cell):
    """Cell concatenating its two inputs along axis 2, passed as a ``tuple``."""

    def __init__(self):
        super().__init__()
        self.concat = P.Concat(axis=2)

    def construct(self, x, y):
        # Deliberately a tuple: the tuple form is what is exercised.
        return self.concat((x, y))
129
130
class GradConcat(nn.Cell):
    """Return the gradient of ``network`` via ``ops.GradOperation()``.

    GradOperation is built with its default arguments. The tests that use this
    cell expect a result shaped like the first input ``x``.

    Args:
        network: Cell taking two inputs ``(x, y)``.
    """

    def __init__(self, network):
        super().__init__()
        self.grad = ops.GradOperation()
        self.network = network

    def construct(self, x, y):
        grad_fn = self.grad(self.network)
        return grad_fn(x, y)
140
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_concat_list_grad():
    """Gradient of list-input Concat(axis=2) must be all ones shaped like x1."""
    x1 = Tensor(np.arange(2 * 2 * 1).reshape(2, 2, 1).astype(np.float32))
    x2 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2).astype(np.float32))
    concat = ConcatWithList()
    output = GradConcat(concat)(x1, x2)
    expect = np.ones((2, 2, 1), dtype=np.float32)
    print(output)
    # assert_array_equal checks the shape too. With `(a == b).all()`, an
    # all-ones gradient of any broadcastable shape (e.g. the full concat
    # output) would wrongly pass.
    np.testing.assert_array_equal(output.asnumpy(), expect)
155
156
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_concat_tuple_grad():
    """Gradient of tuple-input Concat(axis=2) must be all ones shaped like x1."""
    x1 = Tensor(np.arange(2 * 2 * 1).reshape(2, 2, 1).astype(np.float32))
    x2 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2).astype(np.float32))
    concat = ConcatWithTuple()
    output = GradConcat(concat)(x1, x2)
    expect = np.ones((2, 2, 1), dtype=np.float32)
    print(output)
    # Shape-strict: a broadcastable all-ones result of the wrong shape fails.
    np.testing.assert_array_equal(output.asnumpy(), expect)
171
172
class ConcatV43(nn.Cell):
    """Cell concatenating two constant 4-D tensors along axis 3.

    The inputs are ``arange`` data of shapes (2, 2, 2, 2) and (2, 2, 2, 3).

    Args:
        nptype: NumPy dtype both constants are cast to.
    """

    def __init__(self, nptype):
        super().__init__()
        self.cat = P.Concat(axis=3)
        self.x1 = Tensor(np.arange(16).reshape((2, 2, 2, 2)).astype(nptype))
        self.x2 = Tensor(np.arange(24).reshape((2, 2, 2, 3)).astype(nptype))

    def construct(self):
        pair = (self.x1, self.x2)
        return self.cat(pair)
183
184
def axis43(nptype):
    """Run ConcatV43 and compare with the hand-computed axis-3 concatenation.

    Args:
        nptype: NumPy dtype used both for the cell's inputs and for the
            expected result.
    """
    cat = ConcatV43(nptype)
    output = cat()
    expect = np.array([[[[0., 1., 0., 1., 2.],
                         [2., 3., 3., 4., 5.]],
                        [[4., 5., 6., 7., 8.],
                         [6., 7., 9., 10., 11.]]],
                       [[[8., 9., 12., 13., 14.],
                         [10., 11., 15., 16., 17.]],
                        [[12., 13., 18., 19., 20.],
                         [14., 15., 21., 22., 23.]]]]).astype(nptype)
    # Print before asserting so the actual output appears in the captured
    # stdout when the check fails.
    print(output)
    np.testing.assert_array_equal(output.asnumpy(), expect)
198
199
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_axis43_float32():
    """4-D Concat on axis 3 with float32 data."""
    axis43(nptype=np.float32)
205
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_axis43_int32():
    """4-D Concat on axis 3 with int32 data."""
    axis43(nptype=np.int32)
211
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_axis43_bool():
    """4-D Concat on axis 3 with boolean data.

    Uses ``np.bool_``. The ``np.bool`` alias was removed in NumPy 1.24.
    """
    axis43(np.bool_)
217
218
class ConcatV21(nn.Cell):
    """Cell concatenating two constant 2-D tensors along axis 1.

    The inputs are ``arange`` data of shapes (2, 2) and (2, 3).

    Args:
        nptype: NumPy dtype both constants are cast to.
    """

    def __init__(self, nptype):
        super().__init__()
        self.cat = P.Concat(axis=1)
        self.x1 = Tensor(np.arange(4).reshape((2, 2)).astype(nptype))
        self.x2 = Tensor(np.arange(6).reshape((2, 3)).astype(nptype))

    def construct(self):
        pair = (self.x1, self.x2)
        return self.cat(pair)
229
230
def axis21(nptype):
    """Run ConcatV21 and compare with the hand-computed axis-1 concatenation.

    Args:
        nptype: NumPy dtype used both for the cell's inputs and for the
            expected result.
    """
    cat = ConcatV21(nptype)
    output = cat()
    expect = np.array([[0., 1., 0., 1., 2.],
                       [2., 3., 3., 4., 5.]]).astype(nptype)
    # Print before asserting so the actual output appears in the captured
    # stdout when the check fails.
    print(output)
    np.testing.assert_array_equal(output.asnumpy(), expect)
238
239
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_axis21_float32():
    """2-D Concat on axis 1 with float32 data."""
    axis21(nptype=np.float32)
245
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_axis21_int32():
    """2-D Concat on axis 1 with int32 data."""
    axis21(nptype=np.int32)
251
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_axis21_bool():
    """2-D Concat on axis 1 with boolean data.

    Uses ``np.bool_``. The ``np.bool`` alias was removed in NumPy 1.24.
    """
    axis21(np.bool_)
257
258
class Concat3INet(nn.Cell):
    """Cell concatenating its three inputs along axis 1."""

    def __init__(self):
        super().__init__()
        self.cat = P.Concat(axis=1)

    def construct(self, x1, x2, x3):
        inputs = (x1, x2, x3)
        return self.cat(inputs)
266
267
def concat_3i(nptype):
    """Concat three random 4-D tensors along axis 1 and compare with NumPy.

    Args:
        nptype: NumPy dtype the ``randn`` inputs are cast to before being
            wrapped in MindSpore tensors.
    """
    cat = Concat3INet()

    x1_np = np.random.randn(32, 4, 224, 224).astype(nptype)
    x2_np = np.random.randn(32, 8, 224, 224).astype(nptype)
    x3_np = np.random.randn(32, 10, 224, 224).astype(nptype)
    output_np = np.concatenate((x1_np, x2_np, x3_np), axis=1)

    x1_ms = Tensor(x1_np)
    x2_ms = Tensor(x2_np)
    x3_ms = Tensor(x3_np)
    output_ms = cat(x1_ms, x2_ms, x3_ms)

    # Concat only moves data, so the result must match NumPy exactly. A
    # one-sided check such as `np.all(diff < tol)` would accept any element
    # that came out too small (negative diff). assert_array_equal is exact,
    # two-sided and shape-strict.
    np.testing.assert_array_equal(output_ms.asnumpy(), output_np)
284
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_concat_3i_float32():
    """Three-input Concat on axis 1 with float32 data."""
    concat_3i(nptype=np.float32)
290
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_concat_3i_int32():
    """Three-input Concat on axis 1 with int32 data."""
    concat_3i(nptype=np.int32)
296
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_concat_3i_bool():
    """Concat three random boolean 4-D tensors along axis 1 and compare with NumPy."""
    cat = Concat3INet()

    # np.bool_ rather than the np.bool alias, which NumPy 1.24 removed.
    x1_np = np.random.choice([True, False], (32, 4, 224, 224)).astype(np.bool_)
    x2_np = np.random.choice([True, False], (32, 8, 224, 224)).astype(np.bool_)
    x3_np = np.random.choice([True, False], (32, 10, 224, 224)).astype(np.bool_)
    output_np = np.concatenate((x1_np, x2_np, x3_np), axis=1)

    x1_ms = Tensor(x1_np)
    x2_ms = Tensor(x2_np)
    x3_ms = Tensor(x3_np)
    output_ms = cat(x1_ms, x2_ms, x3_ms)

    # Exact and shape-strict; a bare `==` would broadcast.
    np.testing.assert_array_equal(output_ms.asnumpy(), output_np)
314
315
class Concat4INet(nn.Cell):
    """Cell concatenating its four inputs along axis 1."""

    def __init__(self):
        super().__init__()
        self.cat = P.Concat(axis=1)

    def construct(self, x1, x2, x3, x4):
        inputs = (x1, x2, x3, x4)
        return self.cat(inputs)
323
324
def concat_4i(nptype):
    """Concat four random 4-D tensors along axis 1 and compare with NumPy.

    Args:
        nptype: NumPy dtype the ``randn`` inputs are cast to before being
            wrapped in MindSpore tensors.
    """
    cat = Concat4INet()

    x1_np = np.random.randn(32, 4, 224, 224).astype(nptype)
    x2_np = np.random.randn(32, 8, 224, 224).astype(nptype)
    x3_np = np.random.randn(32, 10, 224, 224).astype(nptype)
    x4_np = np.random.randn(32, 5, 224, 224).astype(nptype)
    output_np = np.concatenate((x1_np, x2_np, x3_np, x4_np), axis=1)

    x1_ms = Tensor(x1_np)
    x2_ms = Tensor(x2_np)
    x3_ms = Tensor(x3_np)
    x4_ms = Tensor(x4_np)
    output_ms = cat(x1_ms, x2_ms, x3_ms, x4_ms)

    # Concat only moves data, so the result must match NumPy exactly. A
    # one-sided `np.all(diff < tol)` would accept any element that came out
    # too small. Comparing arrays directly also avoids relying on subtraction
    # semantics, which wrap around for unsigned dtypes such as uint64.
    np.testing.assert_array_equal(output_ms.asnumpy(), output_np)
343
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_concat_4i_float32():
    """Four-input Concat on axis 1 with float32 data."""
    concat_4i(nptype=np.float32)
349
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_concat_4i_int32():
    """Four-input Concat on axis 1 with int32 data."""
    concat_4i(nptype=np.int32)
355
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_concat_4i_int8():
    """Four-input Concat on axis 1 with int8 data."""
    concat_4i(nptype=np.int8)
361
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_concat_4i_uint64():
    """Four-input Concat on axis 1 with uint64 data."""
    concat_4i(nptype=np.uint64)
367
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_concat_4i_bool():
    """Concat four random boolean 4-D tensors along axis 1 and compare with NumPy."""
    cat = Concat4INet()

    # np.bool_ rather than the np.bool alias, which NumPy 1.24 removed.
    x1_np = np.random.choice([True, False], (32, 4, 224, 224)).astype(np.bool_)
    x2_np = np.random.choice([True, False], (32, 8, 224, 224)).astype(np.bool_)
    x3_np = np.random.choice([True, False], (32, 10, 224, 224)).astype(np.bool_)
    x4_np = np.random.choice([True, False], (32, 5, 224, 224)).astype(np.bool_)
    output_np = np.concatenate((x1_np, x2_np, x3_np, x4_np), axis=1)

    x1_ms = Tensor(x1_np)
    x2_ms = Tensor(x2_np)
    x3_ms = Tensor(x3_np)
    x4_ms = Tensor(x4_np)
    output_ms = cat(x1_ms, x2_ms, x3_ms, x4_ms)

    # Exact and shape-strict; a bare `==` would broadcast.
    np.testing.assert_array_equal(output_ms.asnumpy(), output_np)
387