# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

# Index dtypes every ScatterNd case below is exercised with.
_INDEX_DTYPES = (np.int32, np.int64)

# Absolute tolerance when comparing the kernel output with the expectation.
_ATOL = 1.0e-6


class Net(nn.Cell):
    """Cell wrapping ``P.ScatterNd`` with a fixed output shape.

    Args:
        _shape: output shape passed to ``ScatterNd`` on every call.
    """

    def __init__(self, _shape):
        super().__init__()
        self.shape = _shape
        self.scatternd = P.ScatterNd()

    def construct(self, indices, update):
        return self.scatternd(indices, update, self.shape)


def scatternd_net(indices, update, _shape, expect):
    """Run ScatterNd on numpy inputs and assert the result matches ``expect``.

    Args:
        indices: numpy index array converted to a ``Tensor``.
        update: numpy update array converted to a ``Tensor``.
        _shape: output shape handed to ``Net``.
        expect: numpy array with the expected output values.

    Raises:
        AssertionError: if the output shape differs from ``expect`` or any
            element differs by more than ``_ATOL``.
    """
    scatternd = Net(_shape)
    output = scatternd(Tensor(indices), Tensor(update)).asnumpy()
    # Compare in float64: subtracting in the output's own dtype wraps around
    # for unsigned integer types, which makes a signed tolerance check
    # meaningless. assert_allclose also rejects shape mismatches and reports
    # the offending elements on failure.
    np.testing.assert_allclose(output.astype(np.float64),
                               np.asarray(expect).astype(np.float64),
                               rtol=0, atol=_ATOL)


def _check_with_all_index_dtypes(indices, updates, shape, expect, nptype):
    """Check one ScatterNd case once per index dtype in ``_INDEX_DTYPES``.

    ``updates`` and ``expect`` are cast to ``nptype`` with ``astype``. For an
    integer ``nptype`` the float literals are therefore truncated toward zero
    (numpy float->int cast) before the kernel sees them.
    """
    for index_dtype in _INDEX_DTYPES:
        arr_indices = np.array(indices).astype(index_dtype)
        arr_update = np.array(updates).astype(nptype)
        expected = np.array(expect).astype(nptype)
        scatternd_net(arr_indices, arr_update, shape, expected)


def scatternd_positive(nptype):
    """Repeated indices are summed; mixed-sign updates give a positive total."""
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    # [0, 1] receives 3.2 + 5.3 - 2.2 - 1.0 = 5.3 and [1, 1] receives 1.1
    # (for integer dtypes: 3 + 5 - 2 - 1 = 5 and 1).
    _check_with_all_index_dtypes(
        indices=[[0, 1], [1, 1], [0, 1], [0, 1], [0, 1]],
        updates=[3.2, 1.1, 5.3, -2.2, -1.0],
        shape=(2, 2),
        expect=[[0., 5.3],
                [0., 1.1]],
        nptype=nptype)


def scatternd_negative(nptype):
    """Repeated indices are summed; mixed-sign updates give a negative total."""
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    # [1, 0] receives -13.4 + 5.1 - 12.1 - 1.0 = -21.4 and [1, 1] receives -3.1
    # (for integer dtypes: -13 + 5 - 12 - 1 = -21 and -3).
    _check_with_all_index_dtypes(
        indices=[[1, 0], [1, 1], [1, 0], [1, 0], [1, 0]],
        updates=[-13.4, -3.1, 5.1, -12.1, -1.0],
        shape=(2, 2),
        expect=[[0., 0.],
                [-21.4, -3.1]],
        nptype=nptype)


def scatternd_positive_uint(nptype):
    """Repeated indices are summed for unsigned dtypes (non-negative updates only)."""
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    # Each update is truncated first, so [0, 1] receives 3 + 5 + 3 + 1 = 12
    # (not 13.5) and [1, 1] receives 1.
    _check_with_all_index_dtypes(
        indices=[[0, 1], [1, 1], [0, 1], [0, 1], [0, 1]],
        updates=[3.2, 1.1, 5.3, 3.8, 1.2],
        shape=(2, 2),
        expect=[[0., 12.],
                [0., 1.]],
        nptype=nptype)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_scatternd_float64():
    scatternd_positive(np.float64)
    scatternd_negative(np.float64)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_scatternd_float32():
    scatternd_positive(np.float32)
    scatternd_negative(np.float32)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_scatternd_int64():
    scatternd_positive(np.int64)
    scatternd_negative(np.int64)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_scatternd_int16():
    scatternd_positive(np.int16)
    scatternd_negative(np.int16)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_scatternd_uint64():
    scatternd_positive_uint(np.uint64)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_scatternd_uint32():
    scatternd_positive_uint(np.uint32)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_scatternd_uint16():
    scatternd_positive_uint(np.uint16)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_scatternd_uint8():
    scatternd_positive_uint(np.uint8)