• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16import numpy as np
17import pytest
18
19import mindspore.context as context
20import mindspore.nn as nn
21from mindspore import Tensor
22from mindspore.ops import operations as P
23
24
class NetArgmaxWithValue(nn.Cell):
    """Cell that applies ArgMaxWithValue to one input along three axis settings.

    Three operator instances are built: one for axis 0, one for axis -1 and
    one constructed with no arguments (the operator's default axis).
    """

    def __init__(self):
        super(NetArgmaxWithValue, self).__init__()
        self.argmax1 = P.ArgMaxWithValue(0)
        self.argmax2 = P.ArgMaxWithValue(-1)
        self.argmax3 = P.ArgMaxWithValue()

    def construct(self, x):
        # Run every configured operator on the same input and hand the
        # results back in a fixed order: axis 0, axis -1, default axis.
        along_first = self.argmax1(x)
        along_last = self.argmax2(x)
        along_default = self.argmax3(x)
        return (along_first, along_last, along_default)
36
37
class NetArgmaxWithValueBig(nn.Cell):
    """Cell wrapping a single ArgMaxWithValue operator for a chosen axis.

    Args:
        axis: Axis handed to ``P.ArgMaxWithValue``; defaults to 0.
    """

    def __init__(self, axis=0):
        super(NetArgmaxWithValueBig, self).__init__()
        self.argmax = P.ArgMaxWithValue(axis)

    def construct(self, x):
        # Forward the operator's result unchanged.
        result = self.argmax(x)
        return result
45
46
def argmaxwithvalue_base(data_type):
    """Check ArgMaxWithValue on a small 4x3 matrix in PyNative and graph mode.

    The input is cast to ``data_type`` and run through ``NetArgmaxWithValue``,
    which yields three results (axis 0, axis -1, operator default axis). Each
    result is indexed as ``[0]`` for the argmax indices and ``[1]`` for the
    maximum values, and both are compared with hand-computed expectations.

    Args:
        data_type: NumPy dtype used for the input tensor and for the expected
            maximum values (the tests in this file pass ``np.float32`` and
            ``np.float16``).

    Raises:
        AssertionError: If any index or value output differs from the
            expectation in either execution mode.
    """
    x = Tensor(np.array([[1., 20., 5.],
                         [67., 8., 9.],
                         [130., 24., 15.],
                         [0.3, -0.4, -15.]]).astype(data_type))

    # Indices are positions, so they stay integral; only the expected maxima
    # are cast to the dtype under test.
    index_axis0 = np.array([2, 2, 2])
    index_last = np.array([1, 0, 0, 0])
    value_axis0 = np.array([130, 24, 15]).astype(data_type)
    value_last = np.array([20, 67, 130, 0.3]).astype(data_type)

    # One (indices, values) expectation per network output, in output order.
    # The default-axis operator is expected to match the axis-0 result.
    expected = (
        (index_axis0, value_axis0),
        (index_last, value_last),
        (index_axis0, value_axis0),
    )

    # Same checks in both execution modes; graph mode runs last, so it is the
    # mode left configured when this helper returns (as before).
    for mode in (context.PYNATIVE_MODE, context.GRAPH_MODE):
        context.set_context(mode=mode, device_target="GPU")
        argmax = NetArgmaxWithValue()
        output = argmax(x)
        for pos, (expect_index, expect_value) in enumerate(expected):
            label = "mode={}, output[{}]".format(mode, pos)
            # assert_array_equal reports the mismatching elements and rejects
            # incompatible shapes instead of silently broadcasting.
            np.testing.assert_array_equal(output[pos][0].asnumpy(), expect_index,
                                          err_msg=label + " indices")
            np.testing.assert_array_equal(output[pos][1].asnumpy(), expect_value,
                                          err_msg=label + " values")
75
76
def argmaxwithvalue_3d(data_type, shape_x):
    """Compare ArgMaxWithValue against NumPy along every axis of a random tensor.

    A reproducible random array of shape ``shape_x`` is generated, cast to
    ``data_type`` and reduced with ``NetArgmaxWithValueBig`` once per axis.
    The operator's index output is checked against ``np.argmax`` and its
    value output against ``np.maximum.reduce`` for the same axis.

    This helper does not configure the MindSpore context itself; it runs in
    whatever mode and device the caller has set beforehand.

    Args:
        data_type: NumPy dtype the random input is cast to.
        shape_x: Shape of the random input. The tests in this file pass
            3-D shapes; every axis of the given shape is exercised.

    Raises:
        AssertionError: If the indices or maxima differ from NumPy's result
            on any axis.
    """
    np.random.seed(2)  # fixed seed keeps the random input reproducible
    x_np = np.random.random(shape_x).astype(data_type)
    x = Tensor(x_np)

    # Iterate over the axes actually present instead of repeating the same
    # block for hard-coded axes 0, 1 and 2.
    for axis in range(x_np.ndim):
        argmax = NetArgmaxWithValueBig(axis)
        output = argmax(x)
        label = "axis={}".format(axis)
        # assert_array_equal reports the mismatching elements and rejects
        # incompatible shapes instead of silently broadcasting.
        np.testing.assert_array_equal(output[0].asnumpy(), np.argmax(x_np, axis=axis),
                                      err_msg=label + " indices")
        np.testing.assert_array_equal(output[1].asnumpy(), np.maximum.reduce(x_np, axis),
                                      err_msg=label + " values")
102
103
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmaxwithvalue_base_float32():
    """Run the small-matrix ArgMaxWithValue check with float32 data."""
    dtype = np.float32
    argmaxwithvalue_base(dtype)
109
110
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmaxwithvalue_base_float16():
    """Run the small-matrix ArgMaxWithValue check with float16 data."""
    dtype = np.float16
    argmaxwithvalue_base(dtype)
116
117
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmaxwithvalue_3d_float32():
    """Check a (2, 32, 256) float32 tensor in PyNative mode, then graph mode."""
    shape_x = (2, 32, 256)
    for mode in (context.PYNATIVE_MODE, context.GRAPH_MODE):
        context.set_context(mode=mode, device_target="GPU")
        argmaxwithvalue_3d(np.float32, shape_x)
127
128
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmaxwithvalue_3d_float16():
    """Check a (2, 64, 128) float16 tensor in graph mode."""
    # NOTE(review): only graph mode is exercised here, unlike the float32 3-D
    # tests which also run PyNative mode -- confirm this is intentional.
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    dims = (2, 64, 128)
    argmaxwithvalue_3d(np.float16, dims)
136
137
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmaxwithvalue_3d_big_float32():
    """Check a (128, 1024, 1) float32 tensor in PyNative mode, then graph mode."""
    shape_x = (128, 1024, 1)
    for mode in (context.PYNATIVE_MODE, context.GRAPH_MODE):
        context.set_context(mode=mode, device_target="GPU")
        argmaxwithvalue_3d(np.float32, shape_x)
147