• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15import pytest
16import numpy as np
17from mindspore import context
18from mindspore import ops
19
20from test_argmin import argmin_argmax_case, argmin_argmax_case_dyn, argmin_argmax_case_vmap
21
def argmax_(input_x, axis, output_type):
    """Build an ``ops.Argmax`` primitive and apply it to ``input_x``.

    Args:
        input_x: Input handed to the constructed primitive.
        axis: First positional constructor argument of ``ops.Argmax``.
        output_type: Second positional constructor argument of ``ops.Argmax``.

    Returns:
        The result of calling the constructed primitive on ``input_x``.
    """
    argmax_op = ops.Argmax(axis, output_type)
    return argmax_op(input_x)
24
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_argmax(mode):
    """
    Feature: Test argmax op.
    Description: Test argmax in graph mode and pynative mode.
    Expectation: the result matches the expected result.
    """
    # Run under the execution mode supplied by the parametrize decorator.
    context.set_context(mode=mode)
    # The case helper is shared with the argmin tests (imported from test_argmin);
    # np.argmax is passed next to the op wrapper, presumably as the reference.
    argmin_argmax_case(argmax_, np.argmax)
39
40
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_argmax_vmap(mode):
    """
    Feature: Test argmax op.
    Description: Test argmax vmap in graph mode and pynative mode.
    Expectation: the result matches the expected result.
    """
    # Run under the execution mode supplied by the parametrize decorator.
    context.set_context(mode=mode)
    # The vmap case helper is shared with the argmin tests (imported from
    # test_argmin); only the op wrapper is passed to it here.
    argmin_argmax_case_vmap(argmax_)
55
56
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_argmax_dyn(mode):
    """
    Feature: Test argmax op.
    Description: Test argmax dynamic shape in graph mode and pynative mode.
    Expectation: the result matches the expected result.
    """
    # Run under the execution mode supplied by the parametrize decorator.
    context.set_context(mode=mode)
    # The dynamic-shape case helper is shared with the argmin tests (imported
    # from test_argmin); np.argmax is passed along, presumably as the reference.
    argmin_argmax_case_dyn(argmax_, np.argmax)
71