# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CPU tests for ``P.Gather`` with axis 0, axis 1 and axis -1."""

import pytest
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
import mindspore.context as context
from mindspore.common import dtype as mstype

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')


def _assert_gather_output(ms_output, expect, atol=1.0e-6):
    """Assert that a Gather result equals ``expect`` in shape and value.

    Args:
        ms_output: MindSpore ``Tensor`` returned by the network.
        expect: NumPy array holding the expected result.
        atol: Absolute tolerance for the elementwise comparison.

    Raises:
        AssertionError: If the shapes differ or any element differs from
            ``expect`` by more than ``atol``.
    """
    actual = ms_output.asnumpy()
    # An explicit shape check is required: a plain elementwise difference
    # would broadcast, letting a wrongly shaped output slip through.
    assert actual.shape == expect.shape, \
        "shape mismatch: {} != {}".format(actual.shape, expect.shape)
    np.testing.assert_allclose(actual, expect, rtol=0, atol=atol)


class NetGatherV2_axis0(nn.Cell):
    """Cell that applies ``P.Gather`` along axis 0."""

    def __init__(self):
        super(NetGatherV2_axis0, self).__init__()
        self.gatherv2 = P.Gather()

    def construct(self, params, indices):
        return self.gatherv2(params, indices, 0)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gatherv2_axis0():
    """Gather slices 1 and 2 of a (3, 2, 2) tensor along axis 0."""
    x = Tensor(np.arange(3 * 2 * 2).reshape(3, 2, 2), mstype.float32)
    indices = Tensor(np.array([1, 2]), mstype.int32)
    gatherv2 = NetGatherV2_axis0()
    ms_output = gatherv2(x, indices)
    print("output:\n", ms_output)
    expect = np.array([[[4., 5.],
                        [6., 7.]],
                       [[8., 9.],
                        [10., 11.]]])
    _assert_gather_output(ms_output, expect)


class NetGatherV2_axis1(nn.Cell):
    """Cell that applies ``P.Gather`` along axis 1."""

    def __init__(self):
        super(NetGatherV2_axis1, self).__init__()
        self.gatherv2 = P.Gather()

    def construct(self, params, indices):
        return self.gatherv2(params, indices, 1)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gatherv2_axis1():
    """Gather slices 1 and 2 of a (2, 3, 2) tensor along axis 1."""
    x = Tensor(np.arange(2 * 3 * 2).reshape(2, 3, 2), mstype.float32)
    indices = Tensor(np.array([1, 2]), mstype.int32)
    gatherv2 = NetGatherV2_axis1()
    ms_output = gatherv2(x, indices)
    print("output:\n", ms_output)
    expect = np.array([[[2., 3.],
                        [4., 5.]],
                       [[8., 9.],
                        [10., 11.]]])
    _assert_gather_output(ms_output, expect)


class NetGatherV2_axisN1(nn.Cell):
    """Cell that applies ``P.Gather`` along the last axis (axis=-1)."""

    def __init__(self):
        super(NetGatherV2_axisN1, self).__init__()
        self.gatherv2 = P.Gather()

    def construct(self, params, indices):
        return self.gatherv2(params, indices, -1)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gatherv2_axisN1():
    """Gather slices 1 and 2 of a (2, 2, 3) tensor along axis -1."""
    x = Tensor(np.arange(2 * 2 * 3).reshape(2, 2, 3), mstype.float32)
    indices = Tensor(np.array([1, 2]), mstype.int32)
    gatherv2 = NetGatherV2_axisN1()
    ms_output = gatherv2(x, indices)
    print("output:\n", ms_output)
    expect = np.array([[[1., 2.],
                        [4., 5.]],
                       [[7., 8.],
                        [10., 11.]]])
    _assert_gather_output(ms_output, expect)


if __name__ == '__main__':
    test_gatherv2_axis0()
    test_gatherv2_axis1()
    test_gatherv2_axisN1()