# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15
16import numpy as np
17import pytest
18
19import mindspore
20import mindspore.context as context
21import mindspore.nn as nn
22from mindspore import Tensor
23from mindspore.ops import operations as P
24
# Configure MindSpore's global context once, at import time, so every test
# below runs in GRAPH_MODE on the CPU device target.
context.set_context(mode=context.GRAPH_MODE,
                    device_target='CPU')
26
27
class OpNetWrapper(nn.Cell):
    """Wrap a single operator in an ``nn.Cell`` so it can be called like a network.

    Args:
        op: The operator instance to invoke; ``construct`` forwards every
            positional input to it unchanged.
    """

    def __init__(self, op):
        super().__init__()
        # Keep the wrapped operator as a Cell attribute for construct().
        self.op = op

    def construct(self, *inputs):
        # Pure pass-through: apply the wrapped op to all inputs as given and
        # return whatever it produces.
        return self.op(*inputs)
35
36
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case1_basic_func():
    """GatherNd with full-depth index pairs into a 2-D params tensor.

    Each index row ``[i, j]`` selects the scalar ``params[i, j]``, so two
    index rows yield a 1-D result of shape (2,).
    """
    op = P.GatherNd()
    op_wrapper = OpNetWrapper(op)

    indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)  # (2, 2)
    params = Tensor(np.array([[0, 1], [2, 3]]), mindspore.float32)  # (2, 2)
    outputs = op_wrapper(params, indices)
    print(outputs)
    expected = np.array([0, 3])  # (2,)
    result = outputs.asnumpy()
    # np.allclose broadcasts its arguments, so a wrongly shaped but
    # broadcast-compatible result would pass; pin the shape explicitly.
    assert result.shape == expected.shape
    assert np.allclose(result, expected)
50
51
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case2_indices_to_matrix():
    """GatherNd with length-1 indices selecting whole rows of a 2-D params tensor.

    The indices ``[[1], [0]]`` pick ``params[1]`` then ``params[0]``, i.e. the
    two rows swapped, giving shape (2, 2).
    """
    op = P.GatherNd()
    op_wrapper = OpNetWrapper(op)

    indices = Tensor(np.array([[1], [0]]), mindspore.int32)  # (2, 1)
    params = Tensor(np.array([[0, 1], [2, 3]]), mindspore.float32)  # (2, 2)
    outputs = op_wrapper(params, indices)
    print(outputs)
    expected = np.array([[2, 3], [0, 1]])  # (2, 2)
    result = outputs.asnumpy()
    # np.allclose broadcasts its arguments, so a wrongly shaped but
    # broadcast-compatible result would pass; pin the shape explicitly.
    assert result.shape == expected.shape
    assert np.allclose(result, expected)
65
66
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case3_indices_to_3d_tensor():
    """GatherNd with a single length-1 index into a 3-D params tensor.

    The index ``[[1]]`` selects the whole 2-D slice ``params[1]``; the leading
    index axis is kept, so the result has shape (1, 2, 2).
    """
    op = P.GatherNd()
    op_wrapper = OpNetWrapper(op)

    indices = Tensor(np.array([[1]]), mindspore.int32)  # (1, 1)
    params = Tensor(np.array([[[0, 1], [2, 3]],
                              [[4, 5], [6, 7]]]), mindspore.float32)  # (2, 2, 2)
    outputs = op_wrapper(params, indices)
    print(outputs)
    expected = np.array([[[4, 5], [6, 7]]])  # (1, 2, 2)
    result = outputs.asnumpy()
    # np.allclose broadcasts its arguments: a (2, 2) result would compare
    # equal to the (1, 2, 2) expectation, so pin the shape explicitly.
    assert result.shape == expected.shape
    assert np.allclose(result, expected)
81
82
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case4():
    """GatherNd with length-2 indices into a 3-D params tensor.

    Each index row ``[i, j]`` selects the 1-D slice ``params[i, j]``; two
    index rows give a result of shape (2, 2).
    """
    op = P.GatherNd()
    op_wrapper = OpNetWrapper(op)

    indices = Tensor(np.array([[0, 1], [1, 0]]), mindspore.int32)  # (2, 2)
    params = Tensor(np.array([[[0, 1], [2, 3]],
                              [[4, 5], [6, 7]]]), mindspore.float32)  # (2, 2, 2)
    outputs = op_wrapper(params, indices)
    print(outputs)
    expected = np.array([[2, 3], [4, 5]])  # (2, 2)
    result = outputs.asnumpy()
    # np.allclose broadcasts its arguments, so a wrongly shaped but
    # broadcast-compatible result would pass; pin the shape explicitly.
    assert result.shape == expected.shape
    assert np.allclose(result, expected)
97
98
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case5():
    """GatherNd with full-depth (length-3) indices into a 3-D params tensor.

    Each index row ``[i, j, k]`` selects the scalar ``params[i, j, k]``; two
    index rows give a 1-D result of shape (2,).
    """
    op = P.GatherNd()
    op_wrapper = OpNetWrapper(op)

    indices = Tensor(np.array([[0, 0, 1], [1, 0, 1]]), mindspore.int32)  # (2, 3)
    params = Tensor(np.array([[[0, 1], [2, 3]],
                              [[4, 5], [6, 7]]]), mindspore.float32)  # (2, 2, 2)
    outputs = op_wrapper(params, indices)
    print(outputs)
    expected = np.array([1, 5])  # (2,)
    result = outputs.asnumpy()
    # np.allclose broadcasts its arguments, so a wrongly shaped but
    # broadcast-compatible result would pass; pin the shape explicitly.
    assert result.shape == expected.shape
    assert np.allclose(result, expected)
113
114
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case6():
    """GatherNd with batched (2, 1, 2) indices into a 3-D params tensor.

    The leading (2, 1) axes of ``indices`` are kept, and each ``[i, j]``
    selects the 1-D slice ``params[i, j]``, giving shape (2, 1, 2).
    """
    op = P.GatherNd()
    op_wrapper = OpNetWrapper(op)

    indices = Tensor(np.array([[[0, 0]], [[0, 1]]]), mindspore.int32)  # (2, 1, 2)
    params = Tensor(np.array([[[0, 1], [2, 3]],
                              [[4, 5], [6, 7]]]), mindspore.float32)  # (2, 2, 2)
    outputs = op_wrapper(params, indices)
    print(outputs)
    expected = np.array([[[0, 1]], [[2, 3]]])  # (2, 1, 2)
    result = outputs.asnumpy()
    # np.allclose broadcasts its arguments, so a wrongly shaped but
    # broadcast-compatible result would pass; pin the shape explicitly.
    assert result.shape == expected.shape
    assert np.allclose(result, expected)
129
130
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case7():
    """GatherNd with batched (2, 1, 1) indices selecting whole 2-D slices.

    Each single-element index ``[i]`` selects ``params[i]`` (shape (2, 2)),
    and the leading (2, 1) index axes are kept, giving shape (2, 1, 2, 2).
    """
    op = P.GatherNd()
    op_wrapper = OpNetWrapper(op)

    indices = Tensor(np.array([[[1]], [[0]]]), mindspore.int32)  # (2, 1, 1)
    params = Tensor(np.array([[[0, 1], [2, 3]],
                              [[4, 5], [6, 7]]]), mindspore.float32)  # (2, 2, 2)
    outputs = op_wrapper(params, indices)
    print(outputs)
    expected = np.array([[[[4, 5], [6, 7]]], [[[0, 1], [2, 3]]]])  # (2, 1, 2, 2)
    result = outputs.asnumpy()
    # np.allclose broadcasts its arguments, so a wrongly shaped but
    # broadcast-compatible result would pass; pin the shape explicitly.
    assert result.shape == expected.shape
    assert np.allclose(result, expected)
145
146
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case8():
    """GatherNd with batched (2, 2, 2) indices selecting 1-D slices.

    Each ``[i, j]`` selects ``params[i, j]``; the leading (2, 2) index axes are
    kept, giving shape (2, 2, 2).
    """
    op = P.GatherNd()
    op_wrapper = OpNetWrapper(op)

    indices = Tensor(np.array([[[0, 1], [1, 0]], [[0, 0], [1, 1]]]), mindspore.int32)  # (2, 2, 2)
    params = Tensor(np.array([[[0, 1], [2, 3]],
                              [[4, 5], [6, 7]]]), mindspore.float32)  # (2, 2, 2)
    outputs = op_wrapper(params, indices)
    print(outputs)
    expected = np.array([[[2, 3], [4, 5]], [[0, 1], [6, 7]]])  # (2, 2, 2)
    result = outputs.asnumpy()
    # np.allclose broadcasts its arguments, so a wrongly shaped but
    # broadcast-compatible result would pass; pin the shape explicitly.
    assert result.shape == expected.shape
    assert np.allclose(result, expected)
161
162
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_case9():
    """GatherNd with batched full-depth (2, 2, 3) indices into int64 params.

    Each ``[i, j, k]`` fully addresses the 3-D params tensor and selects a
    scalar, so only the leading (2, 2) index axes remain: the result has
    shape (2, 2), not (2, 2, 2).
    """
    op = P.GatherNd()
    op_wrapper = OpNetWrapper(op)

    indices = Tensor(np.array([[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]]), mindspore.int32)  # (2, 2, 3)
    params = Tensor(np.array([[[0, 1], [2, 3]],
                              [[4, 5], [6, 7]]]), mindspore.int64)  # (2, 2, 2)
    outputs = op_wrapper(params, indices)
    print(outputs)
    expected = np.array([[1, 5], [3, 6]])  # (2, 2)
    result = outputs.asnumpy()
    # np.allclose broadcasts its arguments, so a wrongly shaped but
    # broadcast-compatible result would pass; pin the shape explicitly.
    assert result.shape == expected.shape
    assert np.allclose(result, expected)
177
178
if __name__ == '__main__':
    # Direct execution (outside pytest): run every case in file order.
    for _case in (
            test_case1_basic_func,
            test_case2_indices_to_matrix,
            test_case3_indices_to_3d_tensor,
            test_case4,
            test_case5,
            test_case6,
            test_case7,
            test_case8,
            test_case9,
    ):
        _case()
189