# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CPU tests for the ``GatherNd`` operator.

Each case gathers entries/slices of ``params`` addressed by the last axis of
``indices``. The expected values below follow the rule
``output.shape == indices.shape[:-1] + params.shape[indices.shape[-1]:]``.
"""

import numpy as np
import pytest

import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')

# Shared (2, 2, 2) input holding the values 0..7, used by most cases below.
_PARAMS_2X2X2 = [[[0, 1], [2, 3]],
                 [[4, 5], [6, 7]]]


class OpNetWrapper(nn.Cell):
    """Cell whose ``construct`` forwards all inputs to the wrapped primitive."""

    def __init__(self, op):
        super(OpNetWrapper, self).__init__()
        self.op = op

    def construct(self, *inputs):
        return self.op(*inputs)


def _run_gather_nd(params, indices):
    """Apply ``P.GatherNd`` to ``(params, indices)`` through ``OpNetWrapper``.

    The result Tensor is printed (useful when debugging a failing case) and
    returned.
    """
    op_wrapper = OpNetWrapper(P.GatherNd())
    outputs = op_wrapper(params, indices)
    print(outputs)
    return outputs


def _assert_output(outputs, expected):
    """Assert ``outputs`` matches ``expected`` in shape and (closely) in value.

    ``np.allclose`` broadcasts its arguments, so on its own it would accept a
    result whose shape is wrong but broadcast-compatible (e.g. ``(2, 2)``
    where ``(1, 2, 2)`` is expected). The explicit shape check rejects that.
    """
    actual = outputs.asnumpy()
    expected = np.array(expected)
    assert actual.shape == expected.shape, \
        "shape mismatch: got {}, expected {}".format(actual.shape, expected.shape)
    assert np.allclose(actual, expected)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case1_basic_func():
    """Full-length index rows each select one scalar from 2-D params."""
    indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)  # (2, 2)
    params = Tensor(np.array([[0, 1], [2, 3]]), mindspore.float32)  # (2, 2)
    outputs = _run_gather_nd(params, indices)
    _assert_output(outputs, [0, 3])  # (2,)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case2_indices_to_matrix():
    """Length-1 index rows select whole rows of 2-D params."""
    indices = Tensor(np.array([[1], [0]]), mindspore.int32)  # (2, 1)
    params = Tensor(np.array([[0, 1], [2, 3]]), mindspore.float32)  # (2, 2)
    outputs = _run_gather_nd(params, indices)
    _assert_output(outputs, [[2, 3], [0, 1]])  # (2, 2)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case3_indices_to_3d_tensor():
    """A single length-1 index selects a full 2-D slice of 3-D params."""
    indices = Tensor(np.array([[1]]), mindspore.int32)  # (1, 1)
    params = Tensor(np.array(_PARAMS_2X2X2), mindspore.float32)  # (2, 2, 2)
    outputs = _run_gather_nd(params, indices)
    _assert_output(outputs, [[[4, 5], [6, 7]]])  # (1, 2, 2)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case4():
    """Length-2 index rows select 1-D slices of 3-D params."""
    indices = Tensor(np.array([[0, 1], [1, 0]]), mindspore.int32)  # (2, 2)
    params = Tensor(np.array(_PARAMS_2X2X2), mindspore.float32)  # (2, 2, 2)
    outputs = _run_gather_nd(params, indices)
    _assert_output(outputs, [[2, 3], [4, 5]])  # (2, 2)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case5():
    """Length-3 index rows fully index 3-D params, yielding scalars."""
    indices = Tensor(np.array([[0, 0, 1], [1, 0, 1]]), mindspore.int32)  # (2, 3)
    params = Tensor(np.array(_PARAMS_2X2X2), mindspore.float32)  # (2, 2, 2)
    outputs = _run_gather_nd(params, indices)
    _assert_output(outputs, [1, 5])  # (2,)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case6():
    """Batched (rank-3) indices with length-2 rows select 1-D slices."""
    indices = Tensor(np.array([[[0, 0]], [[0, 1]]]), mindspore.int32)  # (2, 1, 2)
    params = Tensor(np.array(_PARAMS_2X2X2), mindspore.float32)  # (2, 2, 2)
    outputs = _run_gather_nd(params, indices)
    _assert_output(outputs, [[[0, 1]], [[2, 3]]])  # (2, 1, 2)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case7():
    """Batched (rank-3) indices with length-1 rows select 2-D slices."""
    indices = Tensor(np.array([[[1]], [[0]]]), mindspore.int32)  # (2, 1, 1)
    params = Tensor(np.array(_PARAMS_2X2X2), mindspore.float32)  # (2, 2, 2)
    outputs = _run_gather_nd(params, indices)
    _assert_output(outputs, [[[[4, 5], [6, 7]]], [[[0, 1], [2, 3]]]])  # (2, 1, 2, 2)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case8():
    """Batched (rank-3) indices with length-2 rows, two rows per batch."""
    indices = Tensor(np.array([[[0, 1], [1, 0]], [[0, 0], [1, 1]]]), mindspore.int32)  # (2, 2, 2)
    params = Tensor(np.array(_PARAMS_2X2X2), mindspore.float32)  # (2, 2, 2)
    outputs = _run_gather_nd(params, indices)
    _assert_output(outputs, [[[2, 3], [4, 5]], [[0, 1], [6, 7]]])  # (2, 2, 2)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case9():
    """Batched full-length indices into int64 params yield a (2, 2) result."""
    indices = Tensor(np.array([[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]]), mindspore.int32)  # (2, 2, 3)
    params = Tensor(np.array(_PARAMS_2X2X2), mindspore.int64)  # (2, 2, 2)
    outputs = _run_gather_nd(params, indices)
    # Each length-3 index row addresses a single scalar, so the result is
    # indices.shape[:-1] == (2, 2).
    _assert_output(outputs, [[1, 5], [3, 6]])  # (2, 2)


if __name__ == '__main__':
    test_case1_basic_func()
    test_case2_indices_to_matrix()
    test_case3_indices_to_3d_tensor()
    test_case4()
    test_case5()
    test_case6()
    test_case7()
    test_case8()
    test_case9()