# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

# Every scatter case is run once per index dtype listed here.
_INDEX_DTYPES = (np.int32, np.int64)


class Net(nn.Cell):
    """Cell that applies P.ScatterNd with an output shape fixed at construction."""

    def __init__(self, _shape):
        super(Net, self).__init__()
        self.shape = _shape
        self.scatternd = P.ScatterNd()

    def construct(self, indices, update):
        return self.scatternd(indices, update, self.shape)


def _as_dtype(values, nptype):
    """Build an array of dtype `nptype` from (nested lists of) Python floats.

    Casting a float that is out of range for an integer dtype (e.g. -2.2 to
    uint8) is undefined in NumPy: the result is platform dependent and recent
    NumPy versions emit "invalid value encountered in cast". For integer
    targets the values are therefore truncated toward zero into int64 first;
    the following integer-to-integer cast follows C conversion rules (modulo
    2**bits for unsigned targets), which is deterministic.
    """
    arr = np.asarray(values, dtype=np.float64)
    if np.issubdtype(nptype, np.integer):
        arr = arr.astype(np.int64)
    return arr.astype(nptype)


def scatternd_net(indices, update, _shape, expect, tol=1.0e-6):
    """Run ScatterNd(indices, update, _shape) and compare against `expect`.

    Asserts that the output has the same shape as `expect` and that every
    element differs from it by less than `tol` (absolute difference).
    """
    scatternd = Net(_shape)
    output = scatternd(Tensor(indices), Tensor(update)).asnumpy()
    expect = np.asarray(expect)
    # NumPy broadcasting in the subtraction below could otherwise hide a
    # shape mismatch.
    assert output.shape == expect.shape, \
        "shape mismatch: got {}, expected {}".format(output.shape, expect.shape)
    # Compare in float64: subtracting/negating uint8 or int16 arrays directly
    # wraps around (e.g. uint8 0 - 1 == 255, int16 -(-32768) == -32768).
    diff = np.abs(output.astype(np.float64) - expect.astype(np.float64))
    assert np.all(diff < tol), \
        "max abs error {} >= {}\noutput:\n{}\nexpect:\n{}".format(
            diff.max(), tol, output, expect)


def scatternd_positive(nptype):
    """Scatter into (2, 2) with repeated index [0, 1]; expected sums are >= 0."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    shape = (2, 2)
    for index_type in _INDEX_DTYPES:
        indices = np.array([[0, 1], [1, 1], [0, 1], [0, 1], [0, 1]]).astype(index_type)
        # The expected values assume duplicate indices accumulate:
        # [0, 1] <- 3.2 + 5.3 - 2.2 - 1.0 = 5.3, [1, 1] <- 1.1.
        update = _as_dtype([3.2, 1.1, 5.3, -2.2, -1.0], nptype)
        expect = _as_dtype([[0., 5.3],
                            [0., 1.1]], nptype)
        scatternd_net(indices, update, shape, expect)


def scatternd_negative(nptype):
    """Scatter into (2, 2) with repeated index [1, 0]; expected sums are < 0."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    shape = (2, 2)
    for index_type in _INDEX_DTYPES:
        indices = np.array([[1, 0], [1, 1], [1, 0], [1, 0], [1, 0]]).astype(index_type)
        # [1, 0] <- -13.4 + 5.1 - 12.1 - 1.0 = -21.4, [1, 1] <- -3.1.
        update = _as_dtype([-13.4, -3.1, 5.1, -12.1, -1.0], nptype)
        expect = _as_dtype([[0., 0.],
                            [-21.4, -3.1]], nptype)
        scatternd_net(indices, update, shape, expect)


# NOTE(review): "platform_x86_gpu_traning" looks like a misspelling of
# "training"; kept unchanged here - confirm it matches the marker the CI
# selects on, otherwise these tests may never be scheduled.
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_traning
@pytest.mark.env_onecard
def test_scatternd_float32():
    scatternd_positive(np.float32)
    scatternd_negative(np.float32)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_traning
@pytest.mark.env_onecard
def test_scatternd_float16():
    scatternd_positive(np.float16)
    scatternd_negative(np.float16)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_traning
@pytest.mark.env_onecard
def test_scatternd_int16():
    scatternd_positive(np.int16)
    scatternd_negative(np.int16)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_traning
@pytest.mark.env_onecard
def test_scatternd_uint8():
    # Only the positive case: the negative case expects negative results,
    # which uint8 cannot represent. The positive case relies on uint8
    # wrap-around (3 + 5 + 254 + 255 == 5 mod 256).
    scatternd_positive(np.uint8)