# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# pylint: disable=unused-variable
import numpy as np
import pytest
import mindspore as ms
from mindspore import mint, jit, JitConfig
from tests.st.ops.dynamic_shape.test_op_utils import TEST_OP


def generate_random_input(shape, dtype):
    """Return two independently drawn standard-normal arrays of `shape`, cast to `dtype`."""
    return np.random.randn(*shape).astype(dtype), np.random.randn(*shape).astype(dtype)


def generate_expect_forward_output(x, y):
    """Return the NumPy reference result, element-wise ``arctan2(x, y)``."""
    return np.arctan2(x, y)


def generate_expect_backward_output(x, y):
    """Return the analytic partial derivatives of ``arctan2(x, y)``.

    The result is the tuple ``(d/dx, d/dy) = (y / (x^2 + y^2), -x / (x^2 + y^2))``,
    evaluated element-wise.
    """
    denom = x * x + y * y
    return y / denom, -x / denom


def atan2_forward_func(x, y):
    """Call ``mint.atan2`` on the two inputs; this is the function under test."""
    return mint.atan2(x, y)


def atan2_backward_func(x, y):
    """Return the gradients of `atan2_forward_func` with respect to both inputs."""
    # Positions (0, 1) request gradients for `x` and `y`.
    input_grad = ms.ops.grad(atan2_forward_func, (0, 1))(x, y)
    return input_grad


@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize('mode', ['pynative', 'KBK'])
def test_atan2_std(mode):
    """
    Feature: mint.atan2
    Description: Compare the forward result and both input gradients of mint.atan2 with NumPy
        references, in PyNative mode and through jit with jit_level "O0".
    Expectation: All outputs match the references within rtol=1e-5 and atol=1e-5.
    """
    x, y = generate_random_input((2, 3), np.float32)

    expect_forward = generate_expect_forward_output(x, y)
    expect_grad = generate_expect_backward_output(x, y)

    if mode == 'pynative':
        ms.set_context(mode=ms.PYNATIVE_MODE)
        output_forward = atan2_forward_func(ms.Tensor(x), ms.Tensor(y))
        output_grad = atan2_backward_func(ms.Tensor(x), ms.Tensor(y))
    else:
        # NOTE(review): this branch relies on jit alone and does not set a context mode,
        # so it runs under whatever mode is currently active -- confirm this is intended.
        output_forward = (jit(atan2_forward_func, jit_config=JitConfig(jit_level="O0")))(ms.Tensor(x),
                                                                                         ms.Tensor(y))
        output_grad = (jit(atan2_backward_func, jit_config=JitConfig(jit_level="O0")))(ms.Tensor(x),
                                                                                       ms.Tensor(y))

    np.testing.assert_allclose(output_forward.asnumpy(), expect_forward, 1e-5, 1e-5)
    np.testing.assert_allclose(output_grad[0].asnumpy(), expect_grad[0], 1e-5, 1e-5)
    np.testing.assert_allclose(output_grad[1].asnumpy(), expect_grad[1], 1e-5, 1e-5)


@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
def test_atan2_dynamic_shape():
    """
    Feature: Test atan2 with dynamic shape.
    Description: Run mint.atan2 through the TEST_OP helper with two input pairs of different
        rank, (2, 3) and (2, 3, 4); 'GRAPH_MODE' is passed in disable_mode.
    Expectation: TEST_OP completes without raising.
    """
    input1, other1 = generate_random_input((2, 3), np.float32)
    input2, other2 = generate_random_input((2, 3, 4), np.float32)

    TEST_OP(atan2_forward_func, [[ms.Tensor(input1), ms.Tensor(other1)], [ms.Tensor(input2), ms.Tensor(other2)]],
            'atan2_ext', disable_mode=['GRAPH_MODE'])


@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE])
def test_atan2_bfloat16(mode):
    """
    Feature: test atan2 functional API.
    Description: Run mint.atan2 on bfloat16 tensors and compare with the float32 NumPy reference.
    Expectation: The result matches the reference within rtol=0.004 and atol=0.004.
    """
    ms.set_context(mode=mode)
    x, y = generate_random_input((2, 3), np.float32)
    output = atan2_forward_func(ms.Tensor(x, dtype=ms.bfloat16), ms.Tensor(y, dtype=ms.bfloat16))
    # The reference is computed from the unrounded float32 inputs.
    expect = generate_expect_forward_output(x, y).astype(np.float32)
    # The output is cast to float32 before asnumpy() is called.
    np.testing.assert_allclose(output.float().asnumpy(), expect, rtol=0.004, atol=0.004)