• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16import numpy as np
17import pytest
18import mindspore.context as context
19import mindspore.nn as nn
20from mindspore import Tensor, Parameter
21from mindspore.ops import operations as P
22from mindspore.ops.operations import _inner_ops as inner
23
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
# Every expected value in this module was validated against dchip.

# Short operation name used by the tests -> MindSpore scatter primitive class.
func_map = dict(
    update=P.ScatterUpdate,
    add=P.ScatterAdd,
    sub=P.ScatterSub,
)
32
33
class TestScatterFuncNet(nn.Cell):
    """Cell applying one scatter primitive to three Parameter-held tensors.

    Args:
        func: key into ``func_map`` ("update", "add" or "sub").
        lock: forwarded to the primitive constructor as ``use_locking``.
        inputx: tensor stored as the ``inputx`` Parameter (first operand).
        indices: tensor stored as the ``indices`` Parameter.
        updates: tensor stored as the ``updates`` Parameter.
    """

    def __init__(self, func, lock, inputx, indices, updates):
        super().__init__()
        op_cls = func_map[func]
        self.scatter_func = op_cls(use_locking=lock)
        self.inputx = Parameter(inputx, name="inputx")
        self.indices = Parameter(indices, name="indices")
        self.updates = Parameter(updates, name="updates")

    def construct(self):
        # Apply the scatter primitive to the stored parameters and return its output.
        return self.scatter_func(self.inputx, self.indices, self.updates)
46
47
def scatter_func_net(func, inputx, indices, updates):
    """Build a TestScatterFuncNet with ``use_locking=True`` and return its output."""
    return TestScatterFuncNet(func, True, inputx, indices, updates)()
52
53
def scatter_func_use_locking_false_net(func, inputx, indices, updates):
    """Build a TestScatterFuncNet with ``use_locking=False`` and return its output."""
    unlocked_net = TestScatterFuncNet(func, False, inputx, indices, updates)
    return unlocked_net()
58
59
class TestScatterFuncDynamicNet(nn.Cell):
    """Scatter cell whose indices/updates pass through GpuConvertToDynamicShape.

    All three tensors are stored as Parameters. ``construct`` runs ``indices``
    and ``updates`` through ``GpuConvertToDynamicShape`` and then applies the
    scatter primitive (built with its default constructor arguments).

    Args:
        func: key into ``func_map`` ("update", "add" or "sub").
        inputx: tensor stored as the ``inputx`` Parameter (first operand).
        indices: tensor stored as the ``indices`` Parameter.
        updates: tensor stored as the ``updates`` Parameter.
    """

    def __init__(self, func, inputx, indices, updates):
        super().__init__()
        self.scatter_func = func_map[func]()
        self.test_dynamic = inner.GpuConvertToDynamicShape()
        self.inputx = Parameter(inputx, name="inputx")
        self.indices = Parameter(indices, name="indices")
        self.updates = Parameter(updates, name="updates")

    def construct(self):
        # Convert both operands to dynamic shape before scattering.
        dyn_indices = self.test_dynamic(self.indices)
        dyn_updates = self.test_dynamic(self.updates)
        return self.scatter_func(self.inputx, dyn_indices, dyn_updates)
74
75
def scatter_func_d_net(func, inputx, indices, updates):
    """Run ``func`` through TestScatterFuncDynamicNet in GPU graph mode."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    return TestScatterFuncDynamicNet(func, inputx, indices, updates)()
80
81
class TestScatterFuncDynamicNet2(nn.Cell):
    """Scatter cell taking indices/updates as call inputs, converted to dynamic shape.

    Only ``inputx`` is stored as a Parameter; ``indices`` and ``updates`` are
    passed to ``construct`` and run through ``GpuConvertToDynamicShape``.

    Args:
        func: key into ``func_map`` ("update", "add" or "sub").
        inputx: tensor stored as the ``inputx`` Parameter (first operand).
    """

    def __init__(self, func, inputx):
        super().__init__()
        self.scatter_func = func_map[func]()
        self.test_dynamic = inner.GpuConvertToDynamicShape()
        self.inputx = Parameter(inputx, name="inputx")

    def construct(self, indices, updates):
        # Indices are converted first, then updates, then the scatter runs.
        return self.scatter_func(
            self.inputx, self.test_dynamic(indices), self.test_dynamic(updates)
        )
94
95
def scatter_func_d2_net(func, inputx, indices_1, updates_1, indices_2, updates_2):
    """Call one TestScatterFuncDynamicNet2 twice and return both outputs.

    The same cell instance, and therefore the same ``inputx`` Parameter, is used
    for both calls.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    shared_net = TestScatterFuncDynamicNet2(func, inputx)
    first = shared_net(indices_1, updates_1)
    second = shared_net(indices_2, updates_2)
    return first, second
102
103
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_func_small_float32():
    """Scatter update/add/sub on a 2x3 float32 input with repeated indices."""
    inputx = Tensor(np.zeros((2, 3)).astype(np.float32))
    indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32))
    updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32))

    cases = (
        ("update", [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),
        ("add", [[6.0, 8.0, 10.0], [12.0, 14.0, 16.0]]),
        ("sub", [[-6.0, -8.0, -10.0], [-12.0, -14.0, -16.0]]),
    )
    for func, rows in cases:
        output = scatter_func_net(func, inputx, indices, updates)
        np.testing.assert_array_almost_equal(output.asnumpy(), np.array(rows))
126
127
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_func_input_updated():
    """After running the cell, its ``inputx`` Parameter holds the scattered result."""
    inputx = Tensor(np.zeros((2, 3)).astype(np.float32))
    indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32))
    updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32))

    expected_by_func = {
        "update": [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
        "add": [[6.0, 8.0, 10.0], [12.0, 14.0, 16.0]],
        "sub": [[-6.0, -8.0, -10.0], [-12.0, -14.0, -16.0]],
    }
    for func, rows in expected_by_func.items():
        net = TestScatterFuncNet(func, True, inputx, indices, updates)
        net()
        np.testing.assert_array_almost_equal(net.inputx.asnumpy(), np.array(rows))
154
155
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_func_large_shape_float32():
    """Scatter ops on a (4, 2, 3, 4) float32 input with distinct 2-D indices."""
    inputx = Tensor(np.ones((4, 2, 3, 4)).astype(np.float32))
    indices = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32))
    updates = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.float32))

    # The four target rows are distinct, so the scatter is unambiguous:
    # row indices[i][j] receives slice updates[i][j]. Flattened, indices are
    # [0, 2, 3, 1], and the updates are consecutive 24-element blocks of arange(96).
    scattered = np.empty((4, 2, 3, 4))
    scattered[[0, 2, 3, 1]] = np.arange(96.0).reshape((4, 2, 3, 4))

    cases = (
        ("update", scattered),
        ("add", 1.0 + scattered),  # input is all ones
        ("sub", 1.0 - scattered),
    )
    for func, expected in cases:
        output = scatter_func_net(func, inputx, indices, updates)
        np.testing.assert_array_almost_equal(output.asnumpy(), expected)
263
264
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_func_small_float32_use_locking_false():
    """Scatter ops with use_locking=False and a permutation of row indices."""
    inputx = Tensor(np.zeros((2, 3)).astype(np.float32))
    indices = Tensor(np.array([1, 0]).astype(np.int32))
    updates = Tensor(np.arange(6).reshape((2, 3)).astype(np.float32))

    # Row 1 receives updates[0] and row 0 receives updates[1].
    permuted = [[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]
    cases = (
        ("update", permuted),
        ("add", permuted),
        ("sub", [[-3.0, -4.0, -5.0], [0.0, -1.0, -2.0]]),
    )
    for func, rows in cases:
        output = scatter_func_use_locking_false_net(func, inputx, indices, updates)
        np.testing.assert_array_almost_equal(output.asnumpy(), np.array(rows))
287
288
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_func_input_less_than_1_float32():
    """Scatter ops on a fractional float32 input with heavily repeated indices.

    add/sub sum up to five float32 updates into the same row, reaching
    magnitudes of about 270. There, one float32 ulp is roughly 1.5e-5 to 3e-5.
    Float addition is not associative, so any change in the order the
    duplicate-index updates are accumulated can move the result by an ulp.
    The default ``decimal=6`` (tolerance 1.5e-6) is tighter than one ulp and
    would effectively demand a bit-exact order. ``decimal=4`` (tolerance 1.5e-4)
    allows a few ulps and still catches any real error.
    """
    inputx = Tensor(
        np.array(
            [
                [0.214141, 0.415151, 0.51516],
                [0.876542, 0.451611, 0.55112],
                [0.111244, 0.633333, 0.34444],
            ]
        ).astype(np.float32)
    )
    indices = Tensor(np.array([[[1, 0, 2], [2, 2, 0]], [[1, 0, 1], [2, 1, 2]]]).astype(np.int32))
    updates = Tensor(np.arange(34, 70).reshape((2, 2, 3, 3)).astype(np.float32))

    # Tolerance sized for float32 results of magnitude ~270 (see docstring).
    decimal = 4
    cases = (
        ("update", [[37.0, 38.0, 39.0], [34.0, 35.0, 66.0], [67.0, 68.0, 69.0]]),
        (
            "add",
            [
                [141.21414, 144.41515, 147.51517],
                [208.87654, 212.45161, 216.55112],
                [257.11124, 262.63333, 267.34442],
            ],
        ),
        (
            "sub",
            [
                [-140.78586, -143.58485, -146.48483],
                [-207.12346, -211.54839, -215.44888],
                [-256.88876, -261.36667, -266.65558],
            ],
        ),
    )
    for func, rows in cases:
        output = scatter_func_net(func, inputx, indices, updates)
        expected = np.array(rows, dtype=np.float32)
        np.testing.assert_array_almost_equal(output.asnumpy(), expected, decimal=decimal)
335
336
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_func_float16():
    """Scatter update/add/sub on a 2x3 float16 input with repeated indices."""
    inputx = Tensor(np.zeros((2, 3)).astype(np.float16))
    indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32))
    updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float16))

    summed = np.array([[6.0, 8.0, 10.0], [12.0, 14.0, 16.0]])
    cases = [
        ("update", np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])),
        ("add", summed),
        ("sub", -summed),
    ]
    for func, expected in cases:
        output = scatter_func_net(func, inputx, indices, updates)
        np.testing.assert_array_almost_equal(output.asnumpy(), expected)
359
360
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_func_large_float16():
    """Scatter ops on a (2, 3, 4) float16 input where each row is targeted twice."""
    inputx = Tensor(np.zeros((2, 3, 4)).astype(np.float16))
    indices = Tensor(np.array([[0, 0], [1, 1]]).astype(np.int32))
    updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.float16))

    # indices[i] == [i, i], so row i accumulates updates[i][0] + updates[i][1].
    # All sums are small integers and therefore exact in float16.
    row_sums = np.arange(63.0, 111.0).reshape((2, 2, 3, 4)).sum(axis=1)
    update_expected = np.array(
        [
            [[63.0, 64.0, 65.0, 66.0], [67.0, 68.0, 69.0, 70.0], [71.0, 72.0, 73.0, 74.0]],
            [[99.0, 100.0, 101.0, 102.0], [103.0, 104.0, 105.0, 106.0], [95.0, 96.0, 97.0, 98.0]],
        ]
    )
    cases = (
        ("update", update_expected),
        ("add", row_sums),
        ("sub", -row_sums),
    )
    for func, expected in cases:
        output = scatter_func_net(func, inputx, indices, updates)
        np.testing.assert_array_almost_equal(output.asnumpy(), expected)
414
415
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_func_disordered_float16():
    """Scatter ops on a flipped float16 input with unordered, repeated indices."""
    inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.float16)))
    indices = Tensor(np.array([[[0, 1, 2], [2, 1, 0]], [[0, 0, 0], [2, 2, 2]]]).astype(np.int32))
    updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.float16))

    cases = (
        (
            "update",
            [[95.0, 96.0, 97.0, 98.0], [67.0, 68.0, 69.0, 70.0], [99.0, 100.0, 101.0, 102.0]],
        ),
        (
            "add",
            [[464.0, 468.0, 472.0, 476.0], [187.0, 188.0, 189.0, 190.0], [492.0, 496.0, 500.0, 504.0]],
        ),
        (
            "sub",
            [
                [-374.0, -380.0, -386.0, -392.0],
                [-105.0, -108.0, -111.0, -114.0],
                [-418.0, -424.0, -430.0, -436.0],
            ],
        ),
    )
    for func, rows in cases:
        output = scatter_func_net(func, inputx, indices, updates)
        np.testing.assert_array_almost_equal(output.asnumpy(), np.array(rows))
448
449
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_func_large_int32():
    """Scatter ops on a (2, 3, 4) int32 input where each row is targeted twice."""
    inputx = Tensor(np.zeros((2, 3, 4)).astype(np.int32))
    indices = Tensor(np.array([[0, 0], [1, 1]]).astype(np.int32))
    updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.int32))

    # indices[i] == [i, i], so row i accumulates updates[i][0] + updates[i][1].
    row_sums = np.arange(63, 111).reshape((2, 2, 3, 4)).sum(axis=1).astype(np.int32)
    update_expected = np.array(
        [
            [[63, 64, 65, 66], [67, 68, 69, 70], [71, 72, 73, 74]],
            [[99, 100, 101, 102], [103, 104, 105, 106], [95, 96, 97, 98]],
        ],
        dtype=np.int32,
    )
    cases = (
        ("update", update_expected),
        ("add", row_sums),
        ("sub", -row_sums),
    )
    for func, expected in cases:
        output = scatter_func_net(func, inputx, indices, updates)
        np.testing.assert_array_almost_equal(output.asnumpy(), expected)
503
504
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_func_disordered_int32():
    """Scatter ops on a flipped int32 input with unordered, repeated indices."""
    inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32)))
    indices = Tensor(np.array([[[0, 1, 2], [2, 1, 0]], [[0, 0, 0], [2, 2, 2]]]).astype(np.int32))
    updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.int32))

    cases = (
        ("update", [[95, 96, 97, 98], [67, 68, 69, 70], [99, 100, 101, 102]]),
        ("add", [[464, 468, 472, 476], [187, 188, 189, 190], [492, 496, 500, 504]]),
        (
            "sub",
            [[-374, -380, -386, -392], [-105, -108, -111, -114], [-418, -424, -430, -436]],
        ),
    )
    for func, rows in cases:
        output = scatter_func_net(func, inputx, indices, updates)
        expected = np.array(rows, dtype=np.int32)
        np.testing.assert_array_almost_equal(output.asnumpy(), expected)
537
538
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_func_disordered_dynamic_int32():
    """Dynamic-shape scatter ops on a flipped int32 input with repeated indices."""
    inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32)))
    indices = Tensor(np.array([[[0, 1, 2], [2, 1, 0]], [[0, 0, 0], [2, 2, 2]]]).astype(np.int32))
    updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.int32))

    cases = (
        ("update", [[95, 96, 97, 98], [67, 68, 69, 70], [99, 100, 101, 102]]),
        ("add", [[464, 468, 472, 476], [187, 188, 189, 190], [492, 496, 500, 504]]),
        (
            "sub",
            [[-374, -380, -386, -392], [-105, -108, -111, -114], [-418, -424, -430, -436]],
        ),
    )
    for func, rows in cases:
        output = scatter_func_d_net(func, inputx, indices, updates)
        expected = np.array(rows, dtype=np.int32)
        np.testing.assert_array_almost_equal(output.asnumpy(), expected)
571
572
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_func_disordered_dynamic_int8():
    """Dynamic-shape scatter ops on int8, where add/sub results overflow and wrap.

    The expected arrays list the exact (unbounded) results and narrow them with
    an integer-to-integer ``astype(np.int8)``, which wraps modulo 2**8.

    The previous form, ``np.array([[464.0, ...]]).astype(np.int8)``, cast
    out-of-range *floats* to int8. Underneath, that is a C float-to-integer
    conversion of an unrepresentable value (undefined behaviour, C11 6.3.1.4),
    so its result is platform-dependent. The integer path is well-defined and
    yields the same values that x86 produced.
    """
    inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int8)))
    indices = Tensor(np.array([[[0, 1, 2], [2, 1, 0]], [[0, 0, 0], [2, 2, 2]]]).astype(np.int32))
    updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.int8))

    # Unwrapped results; astype(np.int8) below applies two's-complement wraparound
    # (e.g. 464 -> -48, -374 -> -118).
    cases = (
        ("update", [[95, 96, 97, 98], [67, 68, 69, 70], [99, 100, 101, 102]]),
        ("add", [[464, 468, 472, 476], [187, 188, 189, 190], [492, 496, 500, 504]]),
        (
            "sub",
            [[-374, -380, -386, -392], [-105, -108, -111, -114], [-418, -424, -430, -436]],
        ),
    )
    for func, rows in cases:
        output = scatter_func_d_net(func, inputx, indices, updates)
        expected = np.array(rows, dtype=np.int64).astype(np.int8)
        np.testing.assert_array_almost_equal(output.asnumpy(), expected)
605
606
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_func_disordered_dynamic_uint8():
    """Dynamic-shape scatter ops on uint8, where add/sub results overflow and wrap.

    The expected arrays list the exact (unbounded) results and narrow them with
    an integer-to-integer ``astype(np.uint8)``, which wraps modulo 2**8.

    The previous add expectation, ``np.array([[464.0, ...]]).astype(np.uint8)``,
    cast out-of-range *floats* to uint8. Underneath, that is a C
    float-to-integer conversion of an unrepresentable value (undefined
    behaviour, C11 6.3.1.4), so its result is platform-dependent. The integer
    path is well-defined and yields the same values that x86 produced.
    """
    inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.uint8)))
    indices = Tensor(np.array([[[0, 1, 2], [2, 1, 0]], [[0, 0, 0], [2, 2, 2]]]).astype(np.int32))
    updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.uint8))

    # Unwrapped results; astype(np.uint8) below reduces them modulo 256
    # (e.g. 464 -> 208, -374 -> 138).
    cases = (
        ("update", [[95, 96, 97, 98], [67, 68, 69, 70], [99, 100, 101, 102]]),
        ("add", [[464, 468, 472, 476], [187, 188, 189, 190], [492, 496, 500, 504]]),
        (
            "sub",
            [[-374, -380, -386, -392], [-105, -108, -111, -114], [-418, -424, -430, -436]],
        ),
    )
    for func, rows in cases:
        output = scatter_func_d_net(func, inputx, indices, updates)
        expected = np.array(rows, dtype=np.int64).astype(np.uint8)
        np.testing.assert_array_almost_equal(output.asnumpy(), expected)
635
636
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_func_input_less_than_1_dynamic_float32():
    """Dynamic-shape scatter ops on a fractional float32 input with repeated indices.

    add/sub sum up to five float32 updates into the same row, reaching
    magnitudes of about 270. There, one float32 ulp is roughly 1.5e-5 to 3e-5.
    Float addition is not associative, so any change in the order the
    duplicate-index updates are accumulated can move the result by an ulp.
    The default ``decimal=6`` (tolerance 1.5e-6) is tighter than one ulp and
    would effectively demand a bit-exact order. ``decimal=4`` (tolerance 1.5e-4)
    allows a few ulps and still catches any real error.
    """
    inputx = Tensor(
        np.array(
            [
                [0.214141, 0.415151, 0.51516],
                [0.876542, 0.451611, 0.55112],
                [0.111244, 0.633333, 0.34444],
            ]
        ).astype(np.float32)
    )
    indices = Tensor(np.array([[[1, 0, 2], [2, 2, 0]], [[1, 0, 1], [2, 1, 2]]]).astype(np.int32))
    updates = Tensor(np.arange(34, 70).reshape((2, 2, 3, 3)).astype(np.float32))

    # Tolerance sized for float32 results of magnitude ~270 (see docstring).
    decimal = 4
    cases = (
        ("update", [[37.0, 38.0, 39.0], [34.0, 35.0, 66.0], [67.0, 68.0, 69.0]]),
        (
            "add",
            [
                [141.21414, 144.41515, 147.51517],
                [208.87654, 212.45161, 216.55112],
                [257.11124, 262.63333, 267.34442],
            ],
        ),
        (
            "sub",
            [
                [-140.78586, -143.58485, -146.48483],
                [-207.12346, -211.54839, -215.44888],
                [-256.88876, -261.36667, -266.65558],
            ],
        ),
    )
    for func, rows in cases:
        output = scatter_func_d_net(func, inputx, indices, updates)
        expected = np.array(rows, dtype=np.float32)
        np.testing.assert_array_almost_equal(output.asnumpy(), expected, decimal=decimal)
683
684
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_func_dynamic_two_inputs():
    """One dynamic-shape cell called twice with differently shaped indices/updates."""
    inputx = Tensor(np.zeros((2, 3)).astype(np.float32))
    indices_1 = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32))
    updates_1 = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32))
    indices_2 = Tensor(np.array([[0, 0], [1, 1], [1, 0]]).astype(np.int32))
    updates_2 = Tensor(np.flip(np.arange(18).reshape((3, 2, 3)).astype(np.float32)))

    # Each entry: (func, expected first output, expected second output).
    # The second call reuses the same cell, so for add/sub it builds on the
    # first call's result (e.g. 39 = 6 + 17 + 14 + 2).
    cases = (
        (
            "update",
            [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
            [[17.0, 16.0, 15.0], [11.0, 10.0, 9.0]],
        ),
        (
            "add",
            [[6.0, 8.0, 10.0], [12.0, 14.0, 16.0]],
            [[39.0, 38.0, 37.0], [36.0, 35.0, 34.0]],
        ),
        (
            "sub",
            [[-6.0, -8.0, -10.0], [-12.0, -14.0, -16.0]],
            [[-39.0, -38.0, -37.0], [-36.0, -35.0, -34.0]],
        ),
    )
    for func, rows_1, rows_2 in cases:
        output_1, output_2 = scatter_func_d2_net(
            func, inputx, indices_1, updates_1, indices_2, updates_2
        )
        np.testing.assert_array_almost_equal(output_1.asnumpy(), np.array(rows_1))
        np.testing.assert_array_almost_equal(output_2.asnumpy(), np.array(rows_2))
721