• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15
16import numpy as np
17import pytest
18
19from mindspore import Tensor
20from mindspore.ops import composite as C
21import mindspore.nn as nn
22import mindspore.context as context
23
class RepeatElementsNet(nn.Cell):
    """Cell that applies ``C.repeat_elements`` with a fixed repeat count and axis."""

    def __init__(self, rep, axis):
        super().__init__()
        self.axis = axis
        self.rep = rep

    def construct(self, x):
        # rep and axis were fixed at construction time; only x varies per call.
        repeated = C.repeat_elements(x, self.rep, self.axis)
        return repeated
32
33
def repeat_elements(x, rep, axis, dtype=np.int32):
    """Run ``RepeatElementsNet`` on a numpy array and return the result as numpy.

    Args:
        x: numpy array to repeat.
        rep: repeat count forwarded to ``RepeatElementsNet``.
        axis: axis forwarded to ``RepeatElementsNet``.
        dtype: numpy dtype ``x`` is converted to before it is wrapped in a
            ``Tensor``. Defaults to ``np.int32``, which is the conversion this
            helper has always applied; note that with the default an input of
            any other dtype (e.g. float16) is converted to int32 as well.

    Returns:
        The network output converted back with ``asnumpy()``.
    """
    net = RepeatElementsNet(rep, axis)
    tensor_in = Tensor(x.astype(dtype))
    return net(tensor_in).asnumpy()
37
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_1d_one_element_rep_1():
    """One-element 1-D input repeated once must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(1)

    expected = np.repeat(data, 1, axis=0)
    actual = repeat_elements(data, 1, 0)
    np.testing.assert_array_equal(expected, actual)
48
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_1d_one_element_rep_many():
    """One-element 1-D input repeated 5 and 513 times must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(1)

    for rep in (5, 513):
        expected = np.repeat(data, rep, axis=0)
        actual = repeat_elements(data, rep, 0)
        np.testing.assert_array_equal(expected, actual)
63
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_1d_rep_1():
    """24-element 1-D input repeated once must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(24)

    expected = np.repeat(data, 1, axis=0)
    actual = repeat_elements(data, 1, 0)
    np.testing.assert_array_equal(expected, actual)
74
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_1d_rep_many():
    """24-element 1-D input repeated 231 times must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(24)

    expected = np.repeat(data, 231, axis=0)
    actual = repeat_elements(data, 231, 0)
    np.testing.assert_array_equal(expected, actual)
85
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_2d_one_element_rep_1():
    """(1, 1) input repeated once along each axis must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(1).reshape(1, 1)

    # Axes are checked in ascending order: 0, then 1.
    for axis in range(data.ndim):
        expected = np.repeat(data, 1, axis=axis)
        actual = repeat_elements(data, 1, axis)
        np.testing.assert_array_equal(expected, actual)
100
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_2d_one_element_rep_many():
    """(1, 1) input repeated 13 times along each axis must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(1).reshape(1, 1)
    rep = 13

    # Axes are checked in ascending order: 0, then 1.
    for axis in range(data.ndim):
        expected = np.repeat(data, rep, axis=axis)
        actual = repeat_elements(data, rep, axis)
        np.testing.assert_array_equal(expected, actual)
115
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_2d_rep_1():
    """(12, 2) input repeated once along each axis must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(24).reshape(12, 2)

    # Axes are checked in ascending order: 0, then 1.
    for axis in range(data.ndim):
        expected = np.repeat(data, 1, axis=axis)
        actual = repeat_elements(data, 1, axis)
        np.testing.assert_array_equal(expected, actual)
130
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_2d_rep_many():
    """(8, 3) input repeated 23 times along each axis must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(24).reshape(8, 3)
    rep = 23

    # Axes are checked in ascending order: 0, then 1.
    for axis in range(data.ndim):
        expected = np.repeat(data, rep, axis=axis)
        actual = repeat_elements(data, rep, axis)
        np.testing.assert_array_equal(expected, actual)
145
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_3d_one_element_rep_1():
    """(1, 1, 1) input repeated once along each axis must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(1).reshape(1, 1, 1)

    # Axes 0, 1 and 2 are checked in ascending order.
    for axis in range(data.ndim):
        expected = np.repeat(data, 1, axis=axis)
        actual = repeat_elements(data, 1, axis)
        np.testing.assert_array_equal(expected, actual)
164
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_3d_one_element_rep_many():
    """(1, 1, 1) input repeated 43 times along each axis must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(1).reshape(1, 1, 1)
    rep = 43

    # Axes 0, 1 and 2 are checked in ascending order.
    for axis in range(data.ndim):
        expected = np.repeat(data, rep, axis=axis)
        actual = repeat_elements(data, rep, axis)
        np.testing.assert_array_equal(expected, actual)
183
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_3d_rep_1():
    """(6, 2, 5) input repeated once along each axis must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(60).reshape(6, 2, 5)

    # Axes 0, 1 and 2 are checked in ascending order.
    for axis in range(data.ndim):
        expected = np.repeat(data, 1, axis=axis)
        actual = repeat_elements(data, 1, axis)
        np.testing.assert_array_equal(expected, actual)
202
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_3d_rep_many():
    """(3, 4, 5) input repeated 14 times along each axis must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(60).reshape(3, 4, 5)
    rep = 14

    # Axes 0, 1 and 2 are checked in ascending order.
    for axis in range(data.ndim):
        expected = np.repeat(data, rep, axis=axis)
        actual = repeat_elements(data, rep, axis)
        np.testing.assert_array_equal(expected, actual)
221
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_4d_one_element_rep_1():
    """(1, 1, 1, 1) input repeated once along each axis must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(1).reshape(1, 1, 1, 1)

    # Axes 0 through 3 are checked in ascending order.
    for axis in range(data.ndim):
        expected = np.repeat(data, 1, axis=axis)
        actual = repeat_elements(data, 1, axis)
        np.testing.assert_array_equal(expected, actual)
244
245
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_4d_one_element_rep_many():
    """(1, 1, 1, 1) input repeated 17 times along each axis must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(1).reshape(1, 1, 1, 1)
    rep = 17

    # Axes 0 through 3 are checked in ascending order.
    for axis in range(data.ndim):
        expected = np.repeat(data, rep, axis=axis)
        actual = repeat_elements(data, rep, axis)
        np.testing.assert_array_equal(expected, actual)
268
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_4d_rep_1():
    """(4, 3, 2, 1) input repeated once along each axis must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(24).reshape(4, 3, 2, 1)

    # Axes 0 through 3 are checked in ascending order.
    for axis in range(data.ndim):
        expected = np.repeat(data, 1, axis=axis)
        actual = repeat_elements(data, 1, axis)
        np.testing.assert_array_equal(expected, actual)
291
292
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_4d_rep_many():
    """(2, 2, 2, 3) input repeated 23 times along each axis must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(24).reshape(2, 2, 2, 3)
    rep = 23

    # Axes 0 through 3 are checked in ascending order.
    for axis in range(data.ndim):
        expected = np.repeat(data, rep, axis=axis)
        actual = repeat_elements(data, rep, axis)
        np.testing.assert_array_equal(expected, actual)
315
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_5d_one_element_rep_1():
    """One-element 5-D input repeated once along each axis must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(1).reshape(1, 1, 1, 1, 1)

    # Axes 0 through 4 are checked in ascending order.
    for axis in range(data.ndim):
        expected = np.repeat(data, 1, axis=axis)
        actual = repeat_elements(data, 1, axis)
        np.testing.assert_array_equal(expected, actual)
342
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_5d_one_element_rep_many():
    """One-element 5-D input repeated 19 times along each axis must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(1).reshape(1, 1, 1, 1, 1)
    rep = 19

    # Axes 0 through 4 are checked in ascending order.
    for axis in range(data.ndim):
        expected = np.repeat(data, rep, axis=axis)
        actual = repeat_elements(data, rep, axis)
        np.testing.assert_array_equal(expected, actual)
369
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_5d_rep_1():
    """(8, 2, 1, 7, 2) input repeated once along each axis must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(224).reshape(8, 2, 1, 7, 2)

    # Axes 0 through 4 are checked in ascending order.
    for axis in range(data.ndim):
        expected = np.repeat(data, 1, axis=axis)
        actual = repeat_elements(data, 1, axis)
        np.testing.assert_array_equal(expected, actual)
396
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_5d_rep_many():
    """(1, 7, 4, 4, 2) input repeated 7 times along each axis must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(224).reshape(1, 7, 4, 4, 2)
    rep = 7

    # Axes 0 through 4 are checked in ascending order.
    for axis in range(data.ndim):
        expected = np.repeat(data, rep, axis=axis)
        actual = repeat_elements(data, rep, axis)
        np.testing.assert_array_equal(expected, actual)
423
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_large_one_element_rep_1():
    """One-element 8-D input repeated once along each axis must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(1).reshape(1, 1, 1, 1, 1, 1, 1, 1)

    # Axes 0 through 7 are checked in ascending order.
    for axis in range(data.ndim):
        expected = np.repeat(data, 1, axis=axis)
        actual = repeat_elements(data, 1, axis)
        np.testing.assert_array_equal(expected, actual)
462
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_large_one_element_rep_many():
    """One-element 8-D input repeated 42 times along each axis must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(1).reshape(1, 1, 1, 1, 1, 1, 1, 1)
    rep = 42

    # Axes 0 through 7 are checked in ascending order.
    for axis in range(data.ndim):
        expected = np.repeat(data, rep, axis=axis)
        actual = repeat_elements(data, rep, axis)
        np.testing.assert_array_equal(expected, actual)
501
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_large_rep_1():
    """(2, 3, 4, 8, 1, 1, 2, 3) input repeated once along each axis must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(1152).reshape(2, 3, 4, 8, 1, 1, 2, 3)

    # Axes 0 through 7 are checked in ascending order.
    for axis in range(data.ndim):
        expected = np.repeat(data, 1, axis=axis)
        actual = repeat_elements(data, 1, axis)
        np.testing.assert_array_equal(expected, actual)
540
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_large_rep_many():
    """(4, 3, 4, 2, 1, 1, 4, 3) input repeated 4 times along each axis must match numpy's repeat."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(1152).reshape(4, 3, 4, 2, 1, 1, 4, 3)
    rep = 4

    # Axes 0 through 7 are checked in ascending order.
    for axis in range(data.ndim):
        expected = np.repeat(data, rep, axis=axis)
        actual = repeat_elements(data, rep, axis)
        np.testing.assert_array_equal(expected, actual)
579
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_half():
    """float16 input repeated 4 times along each of its 8 axes must match numpy's repeat.

    The module-level ``repeat_elements`` helper converts its input with
    ``astype(np.int32)``, so routing a float16 array through it would only
    exercise the int32 path. The float16 tensor is therefore fed to
    ``RepeatElementsNet`` directly, and the output dtype is checked so the test
    cannot silently fall back to another dtype.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # Every integer in 0..1151 is exactly representable in float16.
    a = np.arange(1152).astype(np.float16).reshape(4, 3, 4, 2, 1, 1, 4, 3)
    rep = 4

    for axis in range(a.ndim):
        net = RepeatElementsNet(rep, axis)
        ms_out = net(Tensor(a)).asnumpy()
        np_out = a.repeat(rep, axis)
        assert ms_out.dtype == np.float16
        np.testing.assert_array_equal(np_out, ms_out)
618
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_net_multi_use():
    """A single RepeatElementsNet instance is called on three differently shaped inputs."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)

    rep, axis = 3, 4
    net = RepeatElementsNet(rep, axis)

    # (element count, shape) of each input; the same net instance handles all of them.
    cases = (
        (64, (2, 2, 2, 2, 2, 2)),
        (128, (2, 2, 4, 2, 2, 2)),
        (18, (1, 1, 3, 2, 3, 1)),
    )
    for size, shape in cases:
        data = np.arange(size).reshape(*shape)
        actual = net(Tensor(data.astype(np.int32))).asnumpy()
        expected = np.repeat(data, rep, axis=axis)
        np.testing.assert_array_equal(expected, actual)
643
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_invalid_input():
    """Each invalid (rep, axis) combination is expected to raise ValueError."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    data = np.arange(64).reshape(2, 2, 2, 2, 2, 2)

    # (rep, axis) pairs: a zero repeat count, then axis values 6 and -7,
    # which lie just outside the -6..5 range of a 6-D input.
    invalid_args = ((0, 0), (1, 6), (1, -7))
    for rep, axis in invalid_args:
        with pytest.raises(ValueError):
            repeat_elements(data, rep, axis)
658