# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


class NetArgmaxWithValue(nn.Cell):
    """Apply ArgMaxWithValue along axis 0, axis -1 and the default axis.

    ``construct`` returns a 3-tuple; each element is the (index, value)
    output of one ArgMaxWithValue operator.
    """

    def __init__(self):
        super(NetArgmaxWithValue, self).__init__()
        axis1 = 0
        axis2 = -1
        self.argmax1 = P.ArgMaxWithValue(axis1)
        self.argmax2 = P.ArgMaxWithValue(axis2)
        # No axis argument: exercises the operator's default axis. The
        # assertions in argmaxwithvalue_base expect it to match axis 0.
        self.argmax3 = P.ArgMaxWithValue()

    def construct(self, x):
        return (self.argmax1(x), self.argmax2(x), self.argmax3(x))


class NetArgmaxWithValueBig(nn.Cell):
    """Apply a single ArgMaxWithValue along ``axis``; returns (index, value)."""

    def __init__(self, axis=0):
        super(NetArgmaxWithValueBig, self).__init__()
        self.argmax = P.ArgMaxWithValue(axis)

    def construct(self, x):
        return self.argmax(x)


def argmaxwithvalue_base(data_type):
    """Check ArgMaxWithValue on a fixed 4x3 matrix in PyNative and Graph mode.

    Args:
        data_type: numpy float dtype used for the input tensor and for the
            expected maximum values.
    """
    x = Tensor(np.array([[1., 20., 5.],
                         [67., 8., 9.],
                         [130., 24., 15.],
                         [0.3, -0.4, -15.]]).astype(data_type))
    # Indices are integers regardless of the input dtype.
    expect_idx_axis0 = np.array([2, 2, 2], dtype=np.int32)
    expect_idx_axis_last = np.array([1, 0, 0, 0], dtype=np.int32)
    # Values are cast to data_type so that e.g. float16 rounding of 0.3
    # matches the rounding applied to the input tensor.
    expect_val_axis0 = np.array([130, 24, 15]).astype(data_type)
    expect_val_axis_last = np.array([20, 67, 130, 0.3]).astype(data_type)

    # One (index, value) expectation per output of NetArgmaxWithValue:
    # axis 0, axis -1, then the default axis (expected to equal axis 0).
    expected = ((expect_idx_axis0, expect_val_axis0),
                (expect_idx_axis_last, expect_val_axis_last),
                (expect_idx_axis0, expect_val_axis0))

    for mode in (context.PYNATIVE_MODE, context.GRAPH_MODE):
        context.set_context(mode=mode, device_target="GPU")
        # Build a fresh cell after switching execution mode.
        output = NetArgmaxWithValue()(x)
        for i, (exp_idx, exp_val) in enumerate(expected):
            # assert_array_equal also fails on shape mismatch, unlike
            # (a == b).all(), which can pass via broadcasting.
            np.testing.assert_array_equal(output[i][0].asnumpy(), exp_idx)
            np.testing.assert_array_equal(output[i][1].asnumpy(), exp_val)


def argmaxwithvalue_3d(data_type, shape_x):
    """Compare ArgMaxWithValue with NumPy along every axis of a random input.

    Uses whatever execution mode / device the caller configured through
    ``context.set_context``.

    Args:
        data_type: numpy float dtype of the generated input.
        shape_x: shape of the random input tensor.
    """
    # Fixed seed keeps the generated input (and hence the expected argmax)
    # reproducible across runs.
    np.random.seed(2)
    x_np = np.random.random(shape_x).astype(data_type)
    x = Tensor(x_np)

    for axis in range(x_np.ndim):
        output = NetArgmaxWithValueBig(axis)(x)
        np.testing.assert_array_equal(output[0].asnumpy(),
                                      np.argmax(x_np, axis=axis))
        np.testing.assert_array_equal(output[1].asnumpy(),
                                      np.max(x_np, axis=axis))


def _run_3d_in_both_modes(data_type, shape_x):
    """Run argmaxwithvalue_3d on GPU in PyNative mode, then in Graph mode."""
    for mode in (context.PYNATIVE_MODE, context.GRAPH_MODE):
        context.set_context(mode=mode, device_target="GPU")
        argmaxwithvalue_3d(data_type, shape_x)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmaxwithvalue_base_float32():
    argmaxwithvalue_base(np.float32)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmaxwithvalue_base_float16():
    argmaxwithvalue_base(np.float16)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmaxwithvalue_3d_float32():
    _run_3d_in_both_modes(np.float32, (2, 32, 256))


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmaxwithvalue_3d_float16():
    # Previously only Graph mode was exercised here; run both modes like the
    # other tests in this file.
    _run_3d_in_both_modes(np.float16, (2, 64, 128))


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmaxwithvalue_3d_big_float32():
    _run_3d_in_both_modes(np.float32, (128, 1024, 1))