# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
from mindspore import context
from mindspore import ops

from test_argmin import argmin_argmax_case, argmin_argmax_case_dyn, argmin_argmax_case_vmap


def argmax_(input_x, axis, output_type):
    """
    Build an ``ops.Argmax`` primitive with ``axis`` and ``output_type`` and apply it to ``input_x``.

    This wrapper gives the shared cases in ``test_argmin`` a callable with the
    same ``(input_x, axis, output_type)`` signature they presumably also use for
    argmin (NOTE(review): confirm against ``test_argmin``).
    """
    return ops.Argmax(axis, output_type)(input_x)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_argmax(mode):
    """
    Feature: Test argmax op.
    Description: Test argmax against the numpy reference ``np.argmax``.
    Expectation: the result matches the expected result.
    """
    context.set_context(mode=mode)
    # Reuse the shared argmin/argmax case, swapping in argmax and its numpy reference.
    argmin_argmax_case(argmax_, np.argmax)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_argmax_vmap(mode):
    """
    Feature: Test argmax op.
    Description: Test argmax vmap.
    Expectation: the result matches the expected result.
    """
    context.set_context(mode=mode)
    argmin_argmax_case_vmap(argmax_)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_argmax_dyn(mode):
    """
    Feature: Test argmax op.
    Description: Test argmax dynamic shape against the numpy reference ``np.argmax``.
    Expectation: the result matches the expected result.
    """
    context.set_context(mode=mode)
    argmin_argmax_case_dyn(argmax_, np.argmax)