• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16import numpy as np
17import pytest
18
19import mindspore.context as context
20import mindspore.nn as nn
21from mindspore import Tensor
22from mindspore.ops.operations import _inner_ops as inner
23from mindspore.ops import operations as P
24
25
26class GatherNet(nn.Cell):
27    def __init__(self):
28        super(GatherNet, self).__init__()
29        self.gather = P.Gather()
30
31    def construct(self, x, indices):
32        return self.gather(x, indices, 1)
33
34
35@pytest.mark.level0
36@pytest.mark.platform_x86_gpu_training
37@pytest.mark.env_onecard
38def test_gather0():
39    x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5))
40    indices = Tensor(np.ones((2, 2, 4, 5), dtype='i4'))
41    expect = np.array([[[[[[[20., 21., 22., 23., 24.],
42                            [25., 26., 27., 28., 29.],
43                            [30., 31., 32., 33., 34.],
44                            [35., 36., 37., 38., 39.]],
45
46                           [[20., 21., 22., 23., 24.],
47                            [25., 26., 27., 28., 29.],
48                            [30., 31., 32., 33., 34.],
49                            [35., 36., 37., 38., 39.]],
50
51                           [[20., 21., 22., 23., 24.],
52                            [25., 26., 27., 28., 29.],
53                            [30., 31., 32., 33., 34.],
54                            [35., 36., 37., 38., 39.]],
55
56                           [[20., 21., 22., 23., 24.],
57                            [25., 26., 27., 28., 29.],
58                            [30., 31., 32., 33., 34.],
59                            [35., 36., 37., 38., 39.]],
60
61                           [[20., 21., 22., 23., 24.],
62                            [25., 26., 27., 28., 29.],
63                            [30., 31., 32., 33., 34.],
64                            [35., 36., 37., 38., 39.]]],
65
66                          [[[20., 21., 22., 23., 24.],
67                            [25., 26., 27., 28., 29.],
68                            [30., 31., 32., 33., 34.],
69                            [35., 36., 37., 38., 39.]],
70
71                           [[20., 21., 22., 23., 24.],
72                            [25., 26., 27., 28., 29.],
73                            [30., 31., 32., 33., 34.],
74                            [35., 36., 37., 38., 39.]],
75
76                           [[20., 21., 22., 23., 24.],
77                            [25., 26., 27., 28., 29.],
78                            [30., 31., 32., 33., 34.],
79                            [35., 36., 37., 38., 39.]],
80
81                           [[20., 21., 22., 23., 24.],
82                            [25., 26., 27., 28., 29.],
83                            [30., 31., 32., 33., 34.],
84                            [35., 36., 37., 38., 39.]],
85
86                           [[20., 21., 22., 23., 24.],
87                            [25., 26., 27., 28., 29.],
88                            [30., 31., 32., 33., 34.],
89                            [35., 36., 37., 38., 39.]]],
90
91                          [[[20., 21., 22., 23., 24.],
92                            [25., 26., 27., 28., 29.],
93                            [30., 31., 32., 33., 34.],
94                            [35., 36., 37., 38., 39.]],
95
96                           [[20., 21., 22., 23., 24.],
97                            [25., 26., 27., 28., 29.],
98                            [30., 31., 32., 33., 34.],
99                            [35., 36., 37., 38., 39.]],
100
101                           [[20., 21., 22., 23., 24.],
102                            [25., 26., 27., 28., 29.],
103                            [30., 31., 32., 33., 34.],
104                            [35., 36., 37., 38., 39.]],
105
106                           [[20., 21., 22., 23., 24.],
107                            [25., 26., 27., 28., 29.],
108                            [30., 31., 32., 33., 34.],
109                            [35., 36., 37., 38., 39.]],
110
111                           [[20., 21., 22., 23., 24.],
112                            [25., 26., 27., 28., 29.],
113                            [30., 31., 32., 33., 34.],
114                            [35., 36., 37., 38., 39.]]],
115
116                          [[[20., 21., 22., 23., 24.],
117                            [25., 26., 27., 28., 29.],
118                            [30., 31., 32., 33., 34.],
119                            [35., 36., 37., 38., 39.]],
120
121                           [[20., 21., 22., 23., 24.],
122                            [25., 26., 27., 28., 29.],
123                            [30., 31., 32., 33., 34.],
124                            [35., 36., 37., 38., 39.]],
125
126                           [[20., 21., 22., 23., 24.],
127                            [25., 26., 27., 28., 29.],
128                            [30., 31., 32., 33., 34.],
129                            [35., 36., 37., 38., 39.]],
130
131                           [[20., 21., 22., 23., 24.],
132                            [25., 26., 27., 28., 29.],
133                            [30., 31., 32., 33., 34.],
134                            [35., 36., 37., 38., 39.]],
135
136                           [[20., 21., 22., 23., 24.],
137                            [25., 26., 27., 28., 29.],
138                            [30., 31., 32., 33., 34.],
139                            [35., 36., 37., 38., 39.]]]],
140
141                         [[[[20., 21., 22., 23., 24.],
142                            [25., 26., 27., 28., 29.],
143                            [30., 31., 32., 33., 34.],
144                            [35., 36., 37., 38., 39.]],
145
146                           [[20., 21., 22., 23., 24.],
147                            [25., 26., 27., 28., 29.],
148                            [30., 31., 32., 33., 34.],
149                            [35., 36., 37., 38., 39.]],
150
151                           [[20., 21., 22., 23., 24.],
152                            [25., 26., 27., 28., 29.],
153                            [30., 31., 32., 33., 34.],
154                            [35., 36., 37., 38., 39.]],
155
156                           [[20., 21., 22., 23., 24.],
157                            [25., 26., 27., 28., 29.],
158                            [30., 31., 32., 33., 34.],
159                            [35., 36., 37., 38., 39.]],
160
161                           [[20., 21., 22., 23., 24.],
162                            [25., 26., 27., 28., 29.],
163                            [30., 31., 32., 33., 34.],
164                            [35., 36., 37., 38., 39.]]],
165
166                          [[[20., 21., 22., 23., 24.],
167                            [25., 26., 27., 28., 29.],
168                            [30., 31., 32., 33., 34.],
169                            [35., 36., 37., 38., 39.]],
170
171                           [[20., 21., 22., 23., 24.],
172                            [25., 26., 27., 28., 29.],
173                            [30., 31., 32., 33., 34.],
174                            [35., 36., 37., 38., 39.]],
175
176                           [[20., 21., 22., 23., 24.],
177                            [25., 26., 27., 28., 29.],
178                            [30., 31., 32., 33., 34.],
179                            [35., 36., 37., 38., 39.]],
180
181                           [[20., 21., 22., 23., 24.],
182                            [25., 26., 27., 28., 29.],
183                            [30., 31., 32., 33., 34.],
184                            [35., 36., 37., 38., 39.]],
185
186                           [[20., 21., 22., 23., 24.],
187                            [25., 26., 27., 28., 29.],
188                            [30., 31., 32., 33., 34.],
189                            [35., 36., 37., 38., 39.]]],
190
191                          [[[20., 21., 22., 23., 24.],
192                            [25., 26., 27., 28., 29.],
193                            [30., 31., 32., 33., 34.],
194                            [35., 36., 37., 38., 39.]],
195
196                           [[20., 21., 22., 23., 24.],
197                            [25., 26., 27., 28., 29.],
198                            [30., 31., 32., 33., 34.],
199                            [35., 36., 37., 38., 39.]],
200
201                           [[20., 21., 22., 23., 24.],
202                            [25., 26., 27., 28., 29.],
203                            [30., 31., 32., 33., 34.],
204                            [35., 36., 37., 38., 39.]],
205
206                           [[20., 21., 22., 23., 24.],
207                            [25., 26., 27., 28., 29.],
208                            [30., 31., 32., 33., 34.],
209                            [35., 36., 37., 38., 39.]],
210
211                           [[20., 21., 22., 23., 24.],
212                            [25., 26., 27., 28., 29.],
213                            [30., 31., 32., 33., 34.],
214                            [35., 36., 37., 38., 39.]]],
215
216                          [[[20., 21., 22., 23., 24.],
217                            [25., 26., 27., 28., 29.],
218                            [30., 31., 32., 33., 34.],
219                            [35., 36., 37., 38., 39.]],
220
221                           [[20., 21., 22., 23., 24.],
222                            [25., 26., 27., 28., 29.],
223                            [30., 31., 32., 33., 34.],
224                            [35., 36., 37., 38., 39.]],
225
226                           [[20., 21., 22., 23., 24.],
227                            [25., 26., 27., 28., 29.],
228                            [30., 31., 32., 33., 34.],
229                            [35., 36., 37., 38., 39.]],
230
231                           [[20., 21., 22., 23., 24.],
232                            [25., 26., 27., 28., 29.],
233                            [30., 31., 32., 33., 34.],
234                            [35., 36., 37., 38., 39.]],
235
236                           [[20., 21., 22., 23., 24.],
237                            [25., 26., 27., 28., 29.],
238                            [30., 31., 32., 33., 34.],
239                            [35., 36., 37., 38., 39.]]]]],
240
241                        [[[[[20., 21., 22., 23., 24.],
242                            [25., 26., 27., 28., 29.],
243                            [30., 31., 32., 33., 34.],
244                            [35., 36., 37., 38., 39.]],
245
246                           [[20., 21., 22., 23., 24.],
247                            [25., 26., 27., 28., 29.],
248                            [30., 31., 32., 33., 34.],
249                            [35., 36., 37., 38., 39.]],
250
251                           [[20., 21., 22., 23., 24.],
252                            [25., 26., 27., 28., 29.],
253                            [30., 31., 32., 33., 34.],
254                            [35., 36., 37., 38., 39.]],
255
256                           [[20., 21., 22., 23., 24.],
257                            [25., 26., 27., 28., 29.],
258                            [30., 31., 32., 33., 34.],
259                            [35., 36., 37., 38., 39.]],
260
261                           [[20., 21., 22., 23., 24.],
262                            [25., 26., 27., 28., 29.],
263                            [30., 31., 32., 33., 34.],
264                            [35., 36., 37., 38., 39.]]],
265
266                          [[[20., 21., 22., 23., 24.],
267                            [25., 26., 27., 28., 29.],
268                            [30., 31., 32., 33., 34.],
269                            [35., 36., 37., 38., 39.]],
270
271                           [[20., 21., 22., 23., 24.],
272                            [25., 26., 27., 28., 29.],
273                            [30., 31., 32., 33., 34.],
274                            [35., 36., 37., 38., 39.]],
275
276                           [[20., 21., 22., 23., 24.],
277                            [25., 26., 27., 28., 29.],
278                            [30., 31., 32., 33., 34.],
279                            [35., 36., 37., 38., 39.]],
280
281                           [[20., 21., 22., 23., 24.],
282                            [25., 26., 27., 28., 29.],
283                            [30., 31., 32., 33., 34.],
284                            [35., 36., 37., 38., 39.]],
285
286                           [[20., 21., 22., 23., 24.],
287                            [25., 26., 27., 28., 29.],
288                            [30., 31., 32., 33., 34.],
289                            [35., 36., 37., 38., 39.]]],
290
291                          [[[20., 21., 22., 23., 24.],
292                            [25., 26., 27., 28., 29.],
293                            [30., 31., 32., 33., 34.],
294                            [35., 36., 37., 38., 39.]],
295
296                           [[20., 21., 22., 23., 24.],
297                            [25., 26., 27., 28., 29.],
298                            [30., 31., 32., 33., 34.],
299                            [35., 36., 37., 38., 39.]],
300
301                           [[20., 21., 22., 23., 24.],
302                            [25., 26., 27., 28., 29.],
303                            [30., 31., 32., 33., 34.],
304                            [35., 36., 37., 38., 39.]],
305
306                           [[20., 21., 22., 23., 24.],
307                            [25., 26., 27., 28., 29.],
308                            [30., 31., 32., 33., 34.],
309                            [35., 36., 37., 38., 39.]],
310
311                           [[20., 21., 22., 23., 24.],
312                            [25., 26., 27., 28., 29.],
313                            [30., 31., 32., 33., 34.],
314                            [35., 36., 37., 38., 39.]]],
315
316                          [[[20., 21., 22., 23., 24.],
317                            [25., 26., 27., 28., 29.],
318                            [30., 31., 32., 33., 34.],
319                            [35., 36., 37., 38., 39.]],
320
321                           [[20., 21., 22., 23., 24.],
322                            [25., 26., 27., 28., 29.],
323                            [30., 31., 32., 33., 34.],
324                            [35., 36., 37., 38., 39.]],
325
326                           [[20., 21., 22., 23., 24.],
327                            [25., 26., 27., 28., 29.],
328                            [30., 31., 32., 33., 34.],
329                            [35., 36., 37., 38., 39.]],
330
331                           [[20., 21., 22., 23., 24.],
332                            [25., 26., 27., 28., 29.],
333                            [30., 31., 32., 33., 34.],
334                            [35., 36., 37., 38., 39.]],
335
336                           [[20., 21., 22., 23., 24.],
337                            [25., 26., 27., 28., 29.],
338                            [30., 31., 32., 33., 34.],
339                            [35., 36., 37., 38., 39.]]]],
340
341                         [[[[20., 21., 22., 23., 24.],
342                            [25., 26., 27., 28., 29.],
343                            [30., 31., 32., 33., 34.],
344                            [35., 36., 37., 38., 39.]],
345
346                           [[20., 21., 22., 23., 24.],
347                            [25., 26., 27., 28., 29.],
348                            [30., 31., 32., 33., 34.],
349                            [35., 36., 37., 38., 39.]],
350
351                           [[20., 21., 22., 23., 24.],
352                            [25., 26., 27., 28., 29.],
353                            [30., 31., 32., 33., 34.],
354                            [35., 36., 37., 38., 39.]],
355
356                           [[20., 21., 22., 23., 24.],
357                            [25., 26., 27., 28., 29.],
358                            [30., 31., 32., 33., 34.],
359                            [35., 36., 37., 38., 39.]],
360
361                           [[20., 21., 22., 23., 24.],
362                            [25., 26., 27., 28., 29.],
363                            [30., 31., 32., 33., 34.],
364                            [35., 36., 37., 38., 39.]]],
365
366                          [[[20., 21., 22., 23., 24.],
367                            [25., 26., 27., 28., 29.],
368                            [30., 31., 32., 33., 34.],
369                            [35., 36., 37., 38., 39.]],
370
371                           [[20., 21., 22., 23., 24.],
372                            [25., 26., 27., 28., 29.],
373                            [30., 31., 32., 33., 34.],
374                            [35., 36., 37., 38., 39.]],
375
376                           [[20., 21., 22., 23., 24.],
377                            [25., 26., 27., 28., 29.],
378                            [30., 31., 32., 33., 34.],
379                            [35., 36., 37., 38., 39.]],
380
381                           [[20., 21., 22., 23., 24.],
382                            [25., 26., 27., 28., 29.],
383                            [30., 31., 32., 33., 34.],
384                            [35., 36., 37., 38., 39.]],
385
386                           [[20., 21., 22., 23., 24.],
387                            [25., 26., 27., 28., 29.],
388                            [30., 31., 32., 33., 34.],
389                            [35., 36., 37., 38., 39.]]],
390
391                          [[[20., 21., 22., 23., 24.],
392                            [25., 26., 27., 28., 29.],
393                            [30., 31., 32., 33., 34.],
394                            [35., 36., 37., 38., 39.]],
395
396                           [[20., 21., 22., 23., 24.],
397                            [25., 26., 27., 28., 29.],
398                            [30., 31., 32., 33., 34.],
399                            [35., 36., 37., 38., 39.]],
400
401                           [[20., 21., 22., 23., 24.],
402                            [25., 26., 27., 28., 29.],
403                            [30., 31., 32., 33., 34.],
404                            [35., 36., 37., 38., 39.]],
405
406                           [[20., 21., 22., 23., 24.],
407                            [25., 26., 27., 28., 29.],
408                            [30., 31., 32., 33., 34.],
409                            [35., 36., 37., 38., 39.]],
410
411                           [[20., 21., 22., 23., 24.],
412                            [25., 26., 27., 28., 29.],
413                            [30., 31., 32., 33., 34.],
414                            [35., 36., 37., 38., 39.]]],
415
416                          [[[20., 21., 22., 23., 24.],
417                            [25., 26., 27., 28., 29.],
418                            [30., 31., 32., 33., 34.],
419                            [35., 36., 37., 38., 39.]],
420
421                           [[20., 21., 22., 23., 24.],
422                            [25., 26., 27., 28., 29.],
423                            [30., 31., 32., 33., 34.],
424                            [35., 36., 37., 38., 39.]],
425
426                           [[20., 21., 22., 23., 24.],
427                            [25., 26., 27., 28., 29.],
428                            [30., 31., 32., 33., 34.],
429                            [35., 36., 37., 38., 39.]],
430
431                           [[20., 21., 22., 23., 24.],
432                            [25., 26., 27., 28., 29.],
433                            [30., 31., 32., 33., 34.],
434                            [35., 36., 37., 38., 39.]],
435
436                           [[20., 21., 22., 23., 24.],
437                            [25., 26., 27., 28., 29.],
438                            [30., 31., 32., 33., 34.],
439                            [35., 36., 37., 38., 39.]]]]]],
440
441                       [[[[[[80., 81., 82., 83., 84.],
442                            [85., 86., 87., 88., 89.],
443                            [90., 91., 92., 93., 94.],
444                            [95., 96., 97., 98., 99.]],
445
446                           [[80., 81., 82., 83., 84.],
447                            [85., 86., 87., 88., 89.],
448                            [90., 91., 92., 93., 94.],
449                            [95., 96., 97., 98., 99.]],
450
451                           [[80., 81., 82., 83., 84.],
452                            [85., 86., 87., 88., 89.],
453                            [90., 91., 92., 93., 94.],
454                            [95., 96., 97., 98., 99.]],
455
456                           [[80., 81., 82., 83., 84.],
457                            [85., 86., 87., 88., 89.],
458                            [90., 91., 92., 93., 94.],
459                            [95., 96., 97., 98., 99.]],
460
461                           [[80., 81., 82., 83., 84.],
462                            [85., 86., 87., 88., 89.],
463                            [90., 91., 92., 93., 94.],
464                            [95., 96., 97., 98., 99.]]],
465
466                          [[[80., 81., 82., 83., 84.],
467                            [85., 86., 87., 88., 89.],
468                            [90., 91., 92., 93., 94.],
469                            [95., 96., 97., 98., 99.]],
470
471                           [[80., 81., 82., 83., 84.],
472                            [85., 86., 87., 88., 89.],
473                            [90., 91., 92., 93., 94.],
474                            [95., 96., 97., 98., 99.]],
475
476                           [[80., 81., 82., 83., 84.],
477                            [85., 86., 87., 88., 89.],
478                            [90., 91., 92., 93., 94.],
479                            [95., 96., 97., 98., 99.]],
480
481                           [[80., 81., 82., 83., 84.],
482                            [85., 86., 87., 88., 89.],
483                            [90., 91., 92., 93., 94.],
484                            [95., 96., 97., 98., 99.]],
485
486                           [[80., 81., 82., 83., 84.],
487                            [85., 86., 87., 88., 89.],
488                            [90., 91., 92., 93., 94.],
489                            [95., 96., 97., 98., 99.]]],
490
491                          [[[80., 81., 82., 83., 84.],
492                            [85., 86., 87., 88., 89.],
493                            [90., 91., 92., 93., 94.],
494                            [95., 96., 97., 98., 99.]],
495
496                           [[80., 81., 82., 83., 84.],
497                            [85., 86., 87., 88., 89.],
498                            [90., 91., 92., 93., 94.],
499                            [95., 96., 97., 98., 99.]],
500
501                           [[80., 81., 82., 83., 84.],
502                            [85., 86., 87., 88., 89.],
503                            [90., 91., 92., 93., 94.],
504                            [95., 96., 97., 98., 99.]],
505
506                           [[80., 81., 82., 83., 84.],
507                            [85., 86., 87., 88., 89.],
508                            [90., 91., 92., 93., 94.],
509                            [95., 96., 97., 98., 99.]],
510
511                           [[80., 81., 82., 83., 84.],
512                            [85., 86., 87., 88., 89.],
513                            [90., 91., 92., 93., 94.],
514                            [95., 96., 97., 98., 99.]]],
515
516                          [[[80., 81., 82., 83., 84.],
517                            [85., 86., 87., 88., 89.],
518                            [90., 91., 92., 93., 94.],
519                            [95., 96., 97., 98., 99.]],
520
521                           [[80., 81., 82., 83., 84.],
522                            [85., 86., 87., 88., 89.],
523                            [90., 91., 92., 93., 94.],
524                            [95., 96., 97., 98., 99.]],
525
526                           [[80., 81., 82., 83., 84.],
527                            [85., 86., 87., 88., 89.],
528                            [90., 91., 92., 93., 94.],
529                            [95., 96., 97., 98., 99.]],
530
531                           [[80., 81., 82., 83., 84.],
532                            [85., 86., 87., 88., 89.],
533                            [90., 91., 92., 93., 94.],
534                            [95., 96., 97., 98., 99.]],
535
536                           [[80., 81., 82., 83., 84.],
537                            [85., 86., 87., 88., 89.],
538                            [90., 91., 92., 93., 94.],
539                            [95., 96., 97., 98., 99.]]]],
540
541                         [[[[80., 81., 82., 83., 84.],
542                            [85., 86., 87., 88., 89.],
543                            [90., 91., 92., 93., 94.],
544                            [95., 96., 97., 98., 99.]],
545
546                           [[80., 81., 82., 83., 84.],
547                            [85., 86., 87., 88., 89.],
548                            [90., 91., 92., 93., 94.],
549                            [95., 96., 97., 98., 99.]],
550
551                           [[80., 81., 82., 83., 84.],
552                            [85., 86., 87., 88., 89.],
553                            [90., 91., 92., 93., 94.],
554                            [95., 96., 97., 98., 99.]],
555
556                           [[80., 81., 82., 83., 84.],
557                            [85., 86., 87., 88., 89.],
558                            [90., 91., 92., 93., 94.],
559                            [95., 96., 97., 98., 99.]],
560
561                           [[80., 81., 82., 83., 84.],
562                            [85., 86., 87., 88., 89.],
563                            [90., 91., 92., 93., 94.],
564                            [95., 96., 97., 98., 99.]]],
565
566                          [[[80., 81., 82., 83., 84.],
567                            [85., 86., 87., 88., 89.],
568                            [90., 91., 92., 93., 94.],
569                            [95., 96., 97., 98., 99.]],
570
571                           [[80., 81., 82., 83., 84.],
572                            [85., 86., 87., 88., 89.],
573                            [90., 91., 92., 93., 94.],
574                            [95., 96., 97., 98., 99.]],
575
576                           [[80., 81., 82., 83., 84.],
577                            [85., 86., 87., 88., 89.],
578                            [90., 91., 92., 93., 94.],
579                            [95., 96., 97., 98., 99.]],
580
581                           [[80., 81., 82., 83., 84.],
582                            [85., 86., 87., 88., 89.],
583                            [90., 91., 92., 93., 94.],
584                            [95., 96., 97., 98., 99.]],
585
586                           [[80., 81., 82., 83., 84.],
587                            [85., 86., 87., 88., 89.],
588                            [90., 91., 92., 93., 94.],
589                            [95., 96., 97., 98., 99.]]],
590
591                          [[[80., 81., 82., 83., 84.],
592                            [85., 86., 87., 88., 89.],
593                            [90., 91., 92., 93., 94.],
594                            [95., 96., 97., 98., 99.]],
595
596                           [[80., 81., 82., 83., 84.],
597                            [85., 86., 87., 88., 89.],
598                            [90., 91., 92., 93., 94.],
599                            [95., 96., 97., 98., 99.]],
600
601                           [[80., 81., 82., 83., 84.],
602                            [85., 86., 87., 88., 89.],
603                            [90., 91., 92., 93., 94.],
604                            [95., 96., 97., 98., 99.]],
605
606                           [[80., 81., 82., 83., 84.],
607                            [85., 86., 87., 88., 89.],
608                            [90., 91., 92., 93., 94.],
609                            [95., 96., 97., 98., 99.]],
610
611                           [[80., 81., 82., 83., 84.],
612                            [85., 86., 87., 88., 89.],
613                            [90., 91., 92., 93., 94.],
614                            [95., 96., 97., 98., 99.]]],
615
616                          [[[80., 81., 82., 83., 84.],
617                            [85., 86., 87., 88., 89.],
618                            [90., 91., 92., 93., 94.],
619                            [95., 96., 97., 98., 99.]],
620
621                           [[80., 81., 82., 83., 84.],
622                            [85., 86., 87., 88., 89.],
623                            [90., 91., 92., 93., 94.],
624                            [95., 96., 97., 98., 99.]],
625
626                           [[80., 81., 82., 83., 84.],
627                            [85., 86., 87., 88., 89.],
628                            [90., 91., 92., 93., 94.],
629                            [95., 96., 97., 98., 99.]],
630
631                           [[80., 81., 82., 83., 84.],
632                            [85., 86., 87., 88., 89.],
633                            [90., 91., 92., 93., 94.],
634                            [95., 96., 97., 98., 99.]],
635
636                           [[80., 81., 82., 83., 84.],
637                            [85., 86., 87., 88., 89.],
638                            [90., 91., 92., 93., 94.],
639                            [95., 96., 97., 98., 99.]]]]],
640
641                        [[[[[80., 81., 82., 83., 84.],
642                            [85., 86., 87., 88., 89.],
643                            [90., 91., 92., 93., 94.],
644                            [95., 96., 97., 98., 99.]],
645
646                           [[80., 81., 82., 83., 84.],
647                            [85., 86., 87., 88., 89.],
648                            [90., 91., 92., 93., 94.],
649                            [95., 96., 97., 98., 99.]],
650
651                           [[80., 81., 82., 83., 84.],
652                            [85., 86., 87., 88., 89.],
653                            [90., 91., 92., 93., 94.],
654                            [95., 96., 97., 98., 99.]],
655
656                           [[80., 81., 82., 83., 84.],
657                            [85., 86., 87., 88., 89.],
658                            [90., 91., 92., 93., 94.],
659                            [95., 96., 97., 98., 99.]],
660
661                           [[80., 81., 82., 83., 84.],
662                            [85., 86., 87., 88., 89.],
663                            [90., 91., 92., 93., 94.],
664                            [95., 96., 97., 98., 99.]]],
665
666                          [[[80., 81., 82., 83., 84.],
667                            [85., 86., 87., 88., 89.],
668                            [90., 91., 92., 93., 94.],
669                            [95., 96., 97., 98., 99.]],
670
671                           [[80., 81., 82., 83., 84.],
672                            [85., 86., 87., 88., 89.],
673                            [90., 91., 92., 93., 94.],
674                            [95., 96., 97., 98., 99.]],
675
676                           [[80., 81., 82., 83., 84.],
677                            [85., 86., 87., 88., 89.],
678                            [90., 91., 92., 93., 94.],
679                            [95., 96., 97., 98., 99.]],
680
681                           [[80., 81., 82., 83., 84.],
682                            [85., 86., 87., 88., 89.],
683                            [90., 91., 92., 93., 94.],
684                            [95., 96., 97., 98., 99.]],
685
686                           [[80., 81., 82., 83., 84.],
687                            [85., 86., 87., 88., 89.],
688                            [90., 91., 92., 93., 94.],
689                            [95., 96., 97., 98., 99.]]],
690
691                          [[[80., 81., 82., 83., 84.],
692                            [85., 86., 87., 88., 89.],
693                            [90., 91., 92., 93., 94.],
694                            [95., 96., 97., 98., 99.]],
695
696                           [[80., 81., 82., 83., 84.],
697                            [85., 86., 87., 88., 89.],
698                            [90., 91., 92., 93., 94.],
699                            [95., 96., 97., 98., 99.]],
700
701                           [[80., 81., 82., 83., 84.],
702                            [85., 86., 87., 88., 89.],
703                            [90., 91., 92., 93., 94.],
704                            [95., 96., 97., 98., 99.]],
705
706                           [[80., 81., 82., 83., 84.],
707                            [85., 86., 87., 88., 89.],
708                            [90., 91., 92., 93., 94.],
709                            [95., 96., 97., 98., 99.]],
710
711                           [[80., 81., 82., 83., 84.],
712                            [85., 86., 87., 88., 89.],
713                            [90., 91., 92., 93., 94.],
714                            [95., 96., 97., 98., 99.]]],
715
716                          [[[80., 81., 82., 83., 84.],
717                            [85., 86., 87., 88., 89.],
718                            [90., 91., 92., 93., 94.],
719                            [95., 96., 97., 98., 99.]],
720
721                           [[80., 81., 82., 83., 84.],
722                            [85., 86., 87., 88., 89.],
723                            [90., 91., 92., 93., 94.],
724                            [95., 96., 97., 98., 99.]],
725
726                           [[80., 81., 82., 83., 84.],
727                            [85., 86., 87., 88., 89.],
728                            [90., 91., 92., 93., 94.],
729                            [95., 96., 97., 98., 99.]],
730
731                           [[80., 81., 82., 83., 84.],
732                            [85., 86., 87., 88., 89.],
733                            [90., 91., 92., 93., 94.],
734                            [95., 96., 97., 98., 99.]],
735
736                           [[80., 81., 82., 83., 84.],
737                            [85., 86., 87., 88., 89.],
738                            [90., 91., 92., 93., 94.],
739                            [95., 96., 97., 98., 99.]]]],
740
741                         [[[[80., 81., 82., 83., 84.],
742                            [85., 86., 87., 88., 89.],
743                            [90., 91., 92., 93., 94.],
744                            [95., 96., 97., 98., 99.]],
745
746                           [[80., 81., 82., 83., 84.],
747                            [85., 86., 87., 88., 89.],
748                            [90., 91., 92., 93., 94.],
749                            [95., 96., 97., 98., 99.]],
750
751                           [[80., 81., 82., 83., 84.],
752                            [85., 86., 87., 88., 89.],
753                            [90., 91., 92., 93., 94.],
754                            [95., 96., 97., 98., 99.]],
755
756                           [[80., 81., 82., 83., 84.],
757                            [85., 86., 87., 88., 89.],
758                            [90., 91., 92., 93., 94.],
759                            [95., 96., 97., 98., 99.]],
760
761                           [[80., 81., 82., 83., 84.],
762                            [85., 86., 87., 88., 89.],
763                            [90., 91., 92., 93., 94.],
764                            [95., 96., 97., 98., 99.]]],
765
766                          [[[80., 81., 82., 83., 84.],
767                            [85., 86., 87., 88., 89.],
768                            [90., 91., 92., 93., 94.],
769                            [95., 96., 97., 98., 99.]],
770
771                           [[80., 81., 82., 83., 84.],
772                            [85., 86., 87., 88., 89.],
773                            [90., 91., 92., 93., 94.],
774                            [95., 96., 97., 98., 99.]],
775
776                           [[80., 81., 82., 83., 84.],
777                            [85., 86., 87., 88., 89.],
778                            [90., 91., 92., 93., 94.],
779                            [95., 96., 97., 98., 99.]],
780
781                           [[80., 81., 82., 83., 84.],
782                            [85., 86., 87., 88., 89.],
783                            [90., 91., 92., 93., 94.],
784                            [95., 96., 97., 98., 99.]],
785
786                           [[80., 81., 82., 83., 84.],
787                            [85., 86., 87., 88., 89.],
788                            [90., 91., 92., 93., 94.],
789                            [95., 96., 97., 98., 99.]]],
790
791                          [[[80., 81., 82., 83., 84.],
792                            [85., 86., 87., 88., 89.],
793                            [90., 91., 92., 93., 94.],
794                            [95., 96., 97., 98., 99.]],
795
796                           [[80., 81., 82., 83., 84.],
797                            [85., 86., 87., 88., 89.],
798                            [90., 91., 92., 93., 94.],
799                            [95., 96., 97., 98., 99.]],
800
801                           [[80., 81., 82., 83., 84.],
802                            [85., 86., 87., 88., 89.],
803                            [90., 91., 92., 93., 94.],
804                            [95., 96., 97., 98., 99.]],
805
806                           [[80., 81., 82., 83., 84.],
807                            [85., 86., 87., 88., 89.],
808                            [90., 91., 92., 93., 94.],
809                            [95., 96., 97., 98., 99.]],
810
811                           [[80., 81., 82., 83., 84.],
812                            [85., 86., 87., 88., 89.],
813                            [90., 91., 92., 93., 94.],
814                            [95., 96., 97., 98., 99.]]],
815
816                          [[[80., 81., 82., 83., 84.],
817                            [85., 86., 87., 88., 89.],
818                            [90., 91., 92., 93., 94.],
819                            [95., 96., 97., 98., 99.]],
820
821                           [[80., 81., 82., 83., 84.],
822                            [85., 86., 87., 88., 89.],
823                            [90., 91., 92., 93., 94.],
824                            [95., 96., 97., 98., 99.]],
825
826                           [[80., 81., 82., 83., 84.],
827                            [85., 86., 87., 88., 89.],
828                            [90., 91., 92., 93., 94.],
829                            [95., 96., 97., 98., 99.]],
830
831                           [[80., 81., 82., 83., 84.],
832                            [85., 86., 87., 88., 89.],
833                            [90., 91., 92., 93., 94.],
834                            [95., 96., 97., 98., 99.]],
835
836                           [[80., 81., 82., 83., 84.],
837                            [85., 86., 87., 88., 89.],
838                            [90., 91., 92., 93., 94.],
839                            [95., 96., 97., 98., 99.]]]]]]])
840
841    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
842    gather = GatherNet()
843    output = gather(x, indices)
844    error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
845    diff = output.asnumpy() - expect
846    assert np.all(diff < error)
847    assert np.all(-diff < error)
848
849
850class GatherNet1(nn.Cell):
851    def __init__(self):
852        super(GatherNet1, self).__init__()
853        self.gather = P.Gather()
854
855    def construct(self, x, indices):
856        return self.gather(x, indices, -1)
857
858
859@pytest.mark.level0
860@pytest.mark.platform_x86_gpu_training
861@pytest.mark.env_onecard
862def test_gather1():
863    x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5))
864    indices = Tensor(np.array([1, 3, 4], dtype='i4'))
865    expect = np.array([[[[1., 3., 4.],
866                         [6., 8., 9.],
867                         [11., 13., 14.],
868                         [16., 18., 19.]],
869
870                        [[21., 23., 24.],
871                         [26., 28., 29.],
872                         [31., 33., 34.],
873                         [36., 38., 39.]],
874
875                        [[41., 43., 44.],
876                         [46., 48., 49.],
877                         [51., 53., 54.],
878                         [56., 58., 59.]]],
879
880                       [[[61., 63., 64.],
881                         [66., 68., 69.],
882                         [71., 73., 74.],
883                         [76., 78., 79.]],
884
885                        [[81., 83., 84.],
886                         [86., 88., 89.],
887                         [91., 93., 94.],
888                         [96., 98., 99.]],
889
890                        [[101., 103., 104.],
891                         [106., 108., 109.],
892                         [111., 113., 114.],
893                         [116., 118., 119.]]]])
894
895    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
896    gather = GatherNet1()
897    output = gather(x, indices)
898    error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
899    diff = output.asnumpy() - expect
900    assert np.all(diff < error)
901    assert np.all(-diff < error)
902
903
904class GatherNet2(nn.Cell):
905    def __init__(self):
906        super(GatherNet2, self).__init__()
907        self.gather = P.Gather()
908
909    def construct(self, x, indices):
910        return self.gather(x, indices, 0)
911
912
913@pytest.mark.level0
914@pytest.mark.platform_x86_gpu_training
915@pytest.mark.env_onecard
916def test_gather2():
917    x = Tensor(np.array([[4., 5., 4., 1., 5.],
918                         [4., 9., 5., 6., 4.],
919                         [9., 8., 4., 3., 6.],
920                         [0., 4., 2., 2., 8.],
921                         [1., 8., 6., 2., 8.],
922                         [8., 1., 9., 7., 3.],
923                         [7., 9., 2., 5., 7.],
924                         [9., 8., 6., 8., 5.],
925                         [3., 7., 2., 7., 4.],
926                         [4., 2., 8., 2., 9.]]
927                        ).astype(np.float32))
928
929    indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int64))
930    expect = np.array([[[0., 0., 0., 0., 0.],
931                        [4., 9., 5., 6., 4.],
932                        [0., 0., 0., 0., 0.]]])
933
934    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
935    gather = GatherNet2()
936    output = gather(x, indices)
937    error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
938    diff = output.asnumpy() - expect
939    assert np.all(diff < error)
940    assert np.all(-diff < error)
941
942
943# Dynamic Shape testing ahead
944class GatherNetDynamic(nn.Cell):
945    def __init__(self, axis=0, dyn_a=True, dyn_b=True):
946        super(GatherNetDynamic, self).__init__()
947        self.gather = P.Gather()
948        self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
949        self.to_dyn_1 = dyn_a
950        self.to_dyn_2 = dyn_b
951        self.axis = axis
952
953    def construct(self, x, indices):
954        # testing selective inputs being dynamic
955        if self.to_dyn_1:
956            x = self.gpu_convert_to_dynamic_shape(x)
957        if self.to_dyn_2:
958            indices = self.gpu_convert_to_dynamic_shape(indices)
959        return self.gather(x, indices, self.axis)
960
961
962@pytest.mark.level0
963@pytest.mark.platform_x86_gpu_training
964@pytest.mark.env_onecard
965def test_gatherV2_dyn_ab():
966    """
967    Tests for Dynamic shape with both inputs dynamic
968    """
969    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
970    gather = GatherNetDynamic()
971    x = Tensor(np.array([[4., 5., 4., 1., 5.],
972                         [4., 9., 5., 6., 4.],
973                         [9., 8., 4., 3., 6.],
974                         [0., 4., 2., 2., 8.],
975                         [1., 8., 6., 2., 8.],
976                         [8., 1., 9., 7., 3.],
977                         [7., 9., 2., 5., 7.],
978                         [9., 8., 6., 8., 5.],
979                         [3., 7., 2., 7., 4.],
980                         [4., 2., 8., 2., 9.]]
981                        ).astype(np.float32))
982    indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int32))
983    expect = np.array([[[0., 0., 0., 0., 0.],
984                        [4., 9., 5., 6., 4.],
985                        [0., 0., 0., 0., 0.]]])
986    output = gather(x, indices)
987    error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
988    diff = output.asnumpy() - expect
989    assert np.all(diff < error)
990    assert np.all(-diff < error)
991
992
993@pytest.mark.level0
994@pytest.mark.platform_x86_gpu_training
995@pytest.mark.env_onecard
996def test_gatherV2_dyn_a():
997    """
998    Tests for Dynamic shape with only first input dynamic
999    """
1000    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
1001    gather = GatherNetDynamic(-1, True, False)
1002    # test 1
1003    x = Tensor(np.array([[4., 5., 4., 1., 5.],
1004                         [4., 9., 5., 6., 4.],
1005                         [9., 8., 4., 3., 6.],
1006                         [0., 4., 2., 2., 8.],
1007                         [1., 8., 6., 2., 8.],
1008                         [8., 1., 9., 7., 3.],
1009                         [7., 9., 2., 5., 7.],
1010                         [9., 8., 6., 8., 5.],
1011                         [3., 7., 2., 7., 4.],
1012                         [4., 2., 8., 2., 9.]]
1013                        ).astype(np.float32))
1014    indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int64))
1015    expect = np.array([[[0., 5., 0.]],
1016                       [[0., 9., 0.]],
1017                       [[0., 8., 0.]],
1018                       [[0., 4., 0.]],
1019                       [[0., 8., 0.]],
1020                       [[0., 1., 0.]],
1021                       [[0., 9., 0.]],
1022                       [[0., 8., 0.]],
1023                       [[0., 7., 0.]],
1024                       [[0., 2., 0.]]]).astype(np.float32)
1025    output = gather(x, indices)
1026    error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
1027    diff = output.asnumpy() - expect
1028    assert np.all(diff < error)
1029    assert np.all(-diff < error)
1030    # test 2
1031    x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5))
1032    indices = Tensor(np.array([1, 3, 4], dtype='i4'))
1033    expect = np.array([[[[1., 3., 4.],
1034                         [6., 8., 9.],
1035                         [11., 13., 14.],
1036                         [16., 18., 19.]],
1037
1038                        [[21., 23., 24.],
1039                         [26., 28., 29.],
1040                         [31., 33., 34.],
1041                         [36., 38., 39.]],
1042
1043                        [[41., 43., 44.],
1044                         [46., 48., 49.],
1045                         [51., 53., 54.],
1046                         [56., 58., 59.]]],
1047
1048                       [[[61., 63., 64.],
1049                         [66., 68., 69.],
1050                         [71., 73., 74.],
1051                         [76., 78., 79.]],
1052
1053                        [[81., 83., 84.],
1054                         [86., 88., 89.],
1055                         [91., 93., 94.],
1056                         [96., 98., 99.]],
1057
1058                        [[101., 103., 104.],
1059                         [106., 108., 109.],
1060                         [111., 113., 114.],
1061                         [116., 118., 119.]]]])
1062    output = gather(x, indices)
1063    error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
1064    diff = output.asnumpy() - expect
1065    assert np.all(diff < error)
1066    assert np.all(-diff < error)
1067
1068
1069@pytest.mark.level0
1070@pytest.mark.platform_x86_gpu_training
1071@pytest.mark.env_onecard
1072def test_gatherV2_dyn_b():
1073    """
1074    Tests for Dynamic shape with only second input dynamic
1075    """
1076    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
1077    gather = GatherNetDynamic(-1, False, True)
1078    # test 1
1079    x = Tensor(np.array([[4., 5., 4., 1., 5.],
1080                         [4., 9., 5., 6., 4.],
1081                         [9., 8., 4., 3., 6.],
1082                         [0., 4., 2., 2., 8.],
1083                         [1., 8., 6., 2., 8.],
1084                         [8., 1., 9., 7., 3.],
1085                         [7., 9., 2., 5., 7.],
1086                         [9., 8., 6., 8., 5.],
1087                         [3., 7., 2., 7., 4.],
1088                         [4., 2., 8., 2., 9.]]
1089                        ).astype(np.float32))
1090    indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int32))
1091    expect = np.array([[[0., 5., 0.]],
1092                       [[0., 9., 0.]],
1093                       [[0., 8., 0.]],
1094                       [[0., 4., 0.]],
1095                       [[0., 8., 0.]],
1096                       [[0., 1., 0.]],
1097                       [[0., 9., 0.]],
1098                       [[0., 8., 0.]],
1099                       [[0., 7., 0.]],
1100                       [[0., 2., 0.]]]).astype(np.float32)
1101    output = gather(x, indices)
1102    error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
1103    diff = output.asnumpy() - expect
1104    assert np.all(diff < error)
1105    assert np.all(-diff < error)
1106    # test 2
1107    x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5))
1108    indices = Tensor(np.array([1, 3, 4], dtype='i4'))
1109    expect = np.array([[[[1., 3., 4.],
1110                         [6., 8., 9.],
1111                         [11., 13., 14.],
1112                         [16., 18., 19.]],
1113                        [[21., 23., 24.],
1114                         [26., 28., 29.],
1115                         [31., 33., 34.],
1116                         [36., 38., 39.]],
1117                        [[41., 43., 44.],
1118                         [46., 48., 49.],
1119                         [51., 53., 54.],
1120                         [56., 58., 59.]]],
1121                       [[[61., 63., 64.],
1122                         [66., 68., 69.],
1123                         [71., 73., 74.],
1124                         [76., 78., 79.]],
1125                        [[81., 83., 84.],
1126                         [86., 88., 89.],
1127                         [91., 93., 94.],
1128                         [96., 98., 99.]],
1129                        [[101., 103., 104.],
1130                         [106., 108., 109.],
1131                         [111., 113., 114.],
1132                         [116., 118., 119.]]]])
1133    output = gather(x, indices)
1134    error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
1135    diff = output.asnumpy() - expect
1136    assert np.all(diff < error)
1137    assert np.all(-diff < error)
1138
1139
1140@pytest.mark.level0
1141@pytest.mark.platform_x86_gpu_training
1142@pytest.mark.env_onecard
1143def test_gather1_float64():
1144    x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float64).reshape(2, 3, 4, 5))
1145    indices = Tensor(np.array([1, 3, 4], dtype='i4'))
1146    expect = np.array([[[[1., 3., 4.],
1147                         [6., 8., 9.],
1148                         [11., 13., 14.],
1149                         [16., 18., 19.]],
1150
1151                        [[21., 23., 24.],
1152                         [26., 28., 29.],
1153                         [31., 33., 34.],
1154                         [36., 38., 39.]],
1155
1156                        [[41., 43., 44.],
1157                         [46., 48., 49.],
1158                         [51., 53., 54.],
1159                         [56., 58., 59.]]],
1160
1161                       [[[61., 63., 64.],
1162                         [66., 68., 69.],
1163                         [71., 73., 74.],
1164                         [76., 78., 79.]],
1165
1166                        [[81., 83., 84.],
1167                         [86., 88., 89.],
1168                         [91., 93., 94.],
1169                         [96., 98., 99.]],
1170
1171                        [[101., 103., 104.],
1172                         [106., 108., 109.],
1173                         [111., 113., 114.],
1174                         [116., 118., 119.]]]]).astype(np.float64)
1175
1176    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
1177    gather = GatherNet1()
1178    output = gather(x, indices)
1179    error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
1180    diff = output.asnumpy() - expect
1181    assert np.all(diff < error)
1182    assert np.all(-diff < error)
1183
1184
1185@pytest.mark.level0
1186@pytest.mark.platform_x86_gpu_training
1187@pytest.mark.env_onecard
1188def test_gather1_int32():
1189    x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.int32).reshape(2, 3, 4, 5))
1190    indices = Tensor(np.array([1, 3, 4], dtype='i4'))
1191    expect = np.array([[[[1., 3., 4.],
1192                         [6., 8., 9.],
1193                         [11., 13., 14.],
1194                         [16., 18., 19.]],
1195
1196                        [[21., 23., 24.],
1197                         [26., 28., 29.],
1198                         [31., 33., 34.],
1199                         [36., 38., 39.]],
1200
1201                        [[41., 43., 44.],
1202                         [46., 48., 49.],
1203                         [51., 53., 54.],
1204                         [56., 58., 59.]]],
1205
1206                       [[[61., 63., 64.],
1207                         [66., 68., 69.],
1208                         [71., 73., 74.],
1209                         [76., 78., 79.]],
1210
1211                        [[81., 83., 84.],
1212                         [86., 88., 89.],
1213                         [91., 93., 94.],
1214                         [96., 98., 99.]],
1215
1216                        [[101., 103., 104.],
1217                         [106., 108., 109.],
1218                         [111., 113., 114.],
1219                         [116., 118., 119.]]]]).astype(np.int32)
1220
1221    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
1222    gather = GatherNet1()
1223    output = gather(x, indices)
1224    error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
1225    diff = output.asnumpy() - expect
1226    assert np.all(diff < error)
1227    assert np.all(-diff < error)
1228
1229
1230@pytest.mark.level1
1231@pytest.mark.platform_x86_gpu_training
1232@pytest.mark.env_onecard
1233def test_gather1_int16():
1234    x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.int16).reshape(2, 3, 4, 5))
1235    indices = Tensor(np.array([1, 3, 4], dtype='i4'))
1236    expect = np.array([[[[1., 3., 4.],
1237                         [6., 8., 9.],
1238                         [11., 13., 14.],
1239                         [16., 18., 19.]],
1240
1241                        [[21., 23., 24.],
1242                         [26., 28., 29.],
1243                         [31., 33., 34.],
1244                         [36., 38., 39.]],
1245
1246                        [[41., 43., 44.],
1247                         [46., 48., 49.],
1248                         [51., 53., 54.],
1249                         [56., 58., 59.]]],
1250
1251                       [[[61., 63., 64.],
1252                         [66., 68., 69.],
1253                         [71., 73., 74.],
1254                         [76., 78., 79.]],
1255
1256                        [[81., 83., 84.],
1257                         [86., 88., 89.],
1258                         [91., 93., 94.],
1259                         [96., 98., 99.]],
1260
1261                        [[101., 103., 104.],
1262                         [106., 108., 109.],
1263                         [111., 113., 114.],
1264                         [116., 118., 119.]]]]).astype(np.int16)
1265
1266    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
1267    gather = GatherNet1()
1268    output = gather(x, indices)
1269    error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
1270    diff = output.asnumpy() - expect
1271    assert np.all(diff < error)
1272    assert np.all(-diff < error)
1273
1274
1275@pytest.mark.level1
1276@pytest.mark.platform_x86_gpu_training
1277@pytest.mark.env_onecard
1278def test_gather1_int8():
1279    x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.int8).reshape(2, 3, 4, 5))
1280    indices = Tensor(np.array([1, 3, 4], dtype='i4'))
1281    expect = np.array([[[[1., 3., 4.],
1282                         [6., 8., 9.],
1283                         [11., 13., 14.],
1284                         [16., 18., 19.]],
1285
1286                        [[21., 23., 24.],
1287                         [26., 28., 29.],
1288                         [31., 33., 34.],
1289                         [36., 38., 39.]],
1290
1291                        [[41., 43., 44.],
1292                         [46., 48., 49.],
1293                         [51., 53., 54.],
1294                         [56., 58., 59.]]],
1295
1296                       [[[61., 63., 64.],
1297                         [66., 68., 69.],
1298                         [71., 73., 74.],
1299                         [76., 78., 79.]],
1300
1301                        [[81., 83., 84.],
1302                         [86., 88., 89.],
1303                         [91., 93., 94.],
1304                         [96., 98., 99.]],
1305
1306                        [[101., 103., 104.],
1307                         [106., 108., 109.],
1308                         [111., 113., 114.],
1309                         [116., 118., 119.]]]]).astype(np.int8)
1310
1311    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
1312    gather = GatherNet1()
1313    output = gather(x, indices)
1314    error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
1315    diff = output.asnumpy() - expect
1316    assert np.all(diff < error)
1317    assert np.all(-diff < error)
1318
1319
1320@pytest.mark.level1
1321@pytest.mark.platform_x86_gpu_training
1322@pytest.mark.env_onecard
1323def test_gather1_uint8():
1324    x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.uint8).reshape(2, 3, 4, 5))
1325    indices = Tensor(np.array([1, 3, 4], dtype='i4'))
1326    expect = np.array([[[[1., 3., 4.],
1327                         [6., 8., 9.],
1328                         [11., 13., 14.],
1329                         [16., 18., 19.]],
1330
1331                        [[21., 23., 24.],
1332                         [26., 28., 29.],
1333                         [31., 33., 34.],
1334                         [36., 38., 39.]],
1335
1336                        [[41., 43., 44.],
1337                         [46., 48., 49.],
1338                         [51., 53., 54.],
1339                         [56., 58., 59.]]],
1340
1341                       [[[61., 63., 64.],
1342                         [66., 68., 69.],
1343                         [71., 73., 74.],
1344                         [76., 78., 79.]],
1345
1346                        [[81., 83., 84.],
1347                         [86., 88., 89.],
1348                         [91., 93., 94.],
1349                         [96., 98., 99.]],
1350
1351                        [[101., 103., 104.],
1352                         [106., 108., 109.],
1353                         [111., 113., 114.],
1354                         [116., 118., 119.]]]]).astype(np.uint8)
1355
1356    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
1357    gather = GatherNet1()
1358    output = gather(x, indices)
1359    error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
1360    diff = output.asnumpy() - expect
1361    assert np.all(diff < error)
1362    assert np.all(-diff < error)
1363
1364
1365@pytest.mark.level1
1366@pytest.mark.platform_x86_gpu_training
1367@pytest.mark.env_onecard
1368def test_gather1_bool():
1369    x = Tensor(np.array([[0, 1, 1, 0], [1, 0, 0, 0], [1, 0, 1, 0]], dtype=np.bool))
1370    indices = Tensor(np.array(([1, 2]), dtype='i4'))
1371    expect = np.array([[1, 1], [0, 0], [0, 1]]).astype(np.bool)
1372
1373    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
1374    gather = GatherNet1()
1375    output = gather(x, indices)
1376    assert np.all(expect == output.asnumpy())
1377