• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16import pytest
17import numpy as np
18from mindspore import Tensor
19from mindspore.ops import operations as P
20import mindspore.nn as nn
21import mindspore.context as context
22from mindspore.common import dtype as mstype
23
# Run every network defined in this file in graph mode on the CPU backend.
context.set_context(
    mode=context.GRAPH_MODE,
    device_target='CPU',
)
25
class NetGatherV2_axis0(nn.Cell):
    """Minimal network that applies ``P.Gather`` with ``axis=0``."""

    def __init__(self):
        super().__init__()
        self.gatherv2 = P.Gather()

    def construct(self, params, indices):
        # Select the slices of ``params`` named by ``indices`` on axis 0.
        output = self.gatherv2(params, indices, 0)
        return output
33
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gatherv2_axis0():
    """Gather indices [1, 2] along axis 0 of a (3, 2, 2) float32 tensor.

    The expected result is ``x[[1, 2]]``: shape (2, 2, 2), float32, with
    values matching to an absolute tolerance of 1e-6.
    """
    x = Tensor(np.arange(3 * 2 * 2).reshape(3, 2, 2), mstype.float32)
    indices = Tensor(np.array([1, 2]), mstype.int32)
    gatherv2 = NetGatherV2_axis0()
    ms_output = gatherv2(x, indices)
    print("output:\n", ms_output)
    output = ms_output.asnumpy()
    expect = np.array([[[4., 5.],
                        [6., 7.]],
                       [[8., 9.],
                        [10., 11.]]], dtype=np.float32)
    # Gather keeps the dtype of its params input.
    assert output.dtype == np.float32
    # assert_allclose fails on a shape mismatch, so an output with the
    # wrong rank cannot slip through via broadcasting the way an
    # elementwise `diff < error` comparison would allow.
    np.testing.assert_allclose(output, expect, rtol=0, atol=1.0e-6)
51
class NetGatherV2_axis1(nn.Cell):
    """Minimal network that applies ``P.Gather`` with ``axis=1``."""

    def __init__(self):
        super().__init__()
        self.gatherv2 = P.Gather()

    def construct(self, params, indices):
        # Select the slices of ``params`` named by ``indices`` on axis 1.
        output = self.gatherv2(params, indices, 1)
        return output
59
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gatherv2_axis1():
    """Gather indices [1, 2] along axis 1 of a (2, 3, 2) float32 tensor.

    The expected result is ``x[:, [1, 2], :]``: shape (2, 2, 2), float32,
    with values matching to an absolute tolerance of 1e-6.
    """
    x = Tensor(np.arange(2 * 3 * 2).reshape(2, 3, 2), mstype.float32)
    indices = Tensor(np.array([1, 2]), mstype.int32)
    gatherv2 = NetGatherV2_axis1()
    ms_output = gatherv2(x, indices)
    print("output:\n", ms_output)
    output = ms_output.asnumpy()
    expect = np.array([[[2., 3.],
                        [4., 5.]],
                       [[8., 9.],
                        [10., 11.]]], dtype=np.float32)
    # Gather keeps the dtype of its params input.
    assert output.dtype == np.float32
    # assert_allclose fails on a shape mismatch, so an output with the
    # wrong rank cannot slip through via broadcasting the way an
    # elementwise `diff < error` comparison would allow.
    np.testing.assert_allclose(output, expect, rtol=0, atol=1.0e-6)
77
class NetGatherV2_axisN1(nn.Cell):
    """Minimal network that applies ``P.Gather`` with ``axis=-1`` (last axis)."""

    def __init__(self):
        super().__init__()
        self.gatherv2 = P.Gather()

    def construct(self, params, indices):
        # Select the slices of ``params`` named by ``indices`` on the last axis.
        output = self.gatherv2(params, indices, -1)
        return output
85
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gatherv2_axisN1():
    """Gather indices [1, 2] along axis -1 of a (2, 2, 3) float32 tensor.

    The expected result is ``x[..., [1, 2]]``: shape (2, 2, 2), float32,
    with values matching to an absolute tolerance of 1e-6.
    """
    x = Tensor(np.arange(2 * 2 * 3).reshape(2, 2, 3), mstype.float32)
    indices = Tensor(np.array([1, 2]), mstype.int32)
    gatherv2 = NetGatherV2_axisN1()
    ms_output = gatherv2(x, indices)
    print("output:\n", ms_output)
    output = ms_output.asnumpy()
    expect = np.array([[[1., 2.],
                        [4., 5.]],
                       [[7., 8.],
                        [10., 11.]]], dtype=np.float32)
    # Gather keeps the dtype of its params input.
    assert output.dtype == np.float32
    # assert_allclose fails on a shape mismatch, so an output with the
    # wrong rank cannot slip through via broadcasting the way an
    # elementwise `diff < error` comparison would allow.
    np.testing.assert_allclose(output, expect, rtol=0, atol=1.0e-6)
103
if __name__ == '__main__':
    # Run each Gather test in declaration order when executed as a script.
    for _case in (test_gatherv2_axis0, test_gatherv2_axis1, test_gatherv2_axisN1):
        _case()
108