• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# pylint: disable=unused-variable
16import numpy as np
17import pytest
18import mindspore as ms
19from mindspore import mint, jit, JitConfig
20from tests.st.ops.dynamic_shape.test_op_utils import TEST_OP
21
22
def generate_random_input(shape, dtype):
    """Draw two standard-normal numpy arrays of ``shape`` and cast them to ``dtype``.

    The first array is drawn before the second, so under a fixed
    ``np.random`` seed the pair is reproducible.

    Returns:
        A ``(first, second)`` tuple of arrays, each with the given shape and dtype.
    """
    first = np.random.randn(*shape).astype(dtype)
    second = np.random.randn(*shape).astype(dtype)
    return first, second
25
26
def generate_expect_forward_output(x, y):
    """Reference forward result: element-wise ``numpy.arctan2(x, y)``."""
    angle = np.arctan2(x, y)
    return angle
29
30
def generate_expect_backward_output(x, y):
    """Reference gradients of ``atan2(x, y)`` with respect to ``x`` and ``y``.

    d/dx atan2(x, y) = y / (x^2 + y^2)
    d/dy atan2(x, y) = -x / (x^2 + y^2)

    Returns:
        A ``(d_x, d_y)`` tuple. The result is undefined (division by zero)
        where both ``x`` and ``y`` are zero.
    """
    # Shared denominator x^2 + y^2 (a squared norm, not a reciprocal).
    sq_norm = x * x + y * y
    d_x = y / sq_norm
    d_y = -x / sq_norm
    return d_x, d_y
34
35
def atan2_forward_func(x, y):
    """Forward function under test: element-wise ``mint.atan2(x, y)``."""
    result = mint.atan2(x, y)
    return result
38
39
def atan2_backward_func(x, y):
    """Gradients of ``atan2_forward_func`` with respect to both inputs.

    ``grad_position=(0, 1)`` asks ``ms.ops.grad`` to differentiate with
    respect to the first and second positional arguments. The callers in this
    module index the result as ``[0]`` (d/dx) and ``[1]`` (d/dy).
    """
    grad_fn = ms.ops.grad(atan2_forward_func, (0, 1))
    return grad_fn(x, y)
43
44
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize('mode', ['pynative', 'KBK'])
def test_atan2_std(mode):
    """
    Feature: mint
    Description: Verify the forward result and both input gradients of mint.atan2
        against numpy references. 'pynative' sets PYNATIVE_MODE and calls the
        functions directly; 'KBK' wraps them with jit(jit_level="O0").
    Expectation: success
    """
    x, y = generate_random_input((2, 3), np.float32)

    expect_forward = generate_expect_forward_output(x, y)
    expect_grad = generate_expect_backward_output(x, y)

    x_tensor, y_tensor = ms.Tensor(x), ms.Tensor(y)
    if mode == 'pynative':
        ms.set_context(mode=ms.PYNATIVE_MODE)
        output_forward = atan2_forward_func(x_tensor, y_tensor)
        output_grad = atan2_backward_func(x_tensor, y_tensor)
    else:
        jit_config = JitConfig(jit_level="O0")
        output_forward = jit(atan2_forward_func, jit_config=jit_config)(x_tensor, y_tensor)
        output_grad = jit(atan2_backward_func, jit_config=jit_config)(x_tensor, y_tensor)

    np.testing.assert_allclose(output_forward.asnumpy(), expect_forward, rtol=1e-5, atol=1e-5)
    # output_grad is (d/dx, d/dy), matching the order of expect_grad.
    np.testing.assert_allclose(output_grad[0].asnumpy(), expect_grad[0], rtol=1e-5, atol=1e-5)
    np.testing.assert_allclose(output_grad[1].asnumpy(), expect_grad[1], rtol=1e-5, atol=1e-5)
75
76
@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
def test_atan2_dynamic_shape():
    """
    Feature: Test atan2 with dynamic shape.
    Description: call mint.atan2 through TEST_OP with two input/other pairs of
        different shape and rank, (2, 3) and (2, 3, 4); 'GRAPH_MODE' is passed
        in disable_mode.
    Expectation: return the correct value.
    """
    inputs_seq = []
    for shape in ((2, 3), (2, 3, 4)):
        input_np, other_np = generate_random_input(shape, np.float32)
        inputs_seq.append([ms.Tensor(input_np), ms.Tensor(other_np)])

    TEST_OP(atan2_forward_func, inputs_seq, 'atan2_ext', disable_mode=['GRAPH_MODE'])
92
93
@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE])
def test_atan2_bfloat16(mode):
    """
    Feature: test atan2 functional API.
    Description: testcase for atan2 functional API with bfloat16 inputs, compared
        against a float32 numpy reference computed from the same bfloat16-rounded inputs.
    Expectation: the result match with expected result.
    """
    ms.set_context(mode=mode)
    x, y = generate_random_input((2, 3), np.float32)
    x_bf16 = ms.Tensor(x, dtype=ms.bfloat16)
    y_bf16 = ms.Tensor(y, dtype=ms.bfloat16)
    output = atan2_forward_func(x_bf16, y_bf16)

    # Build the reference from the bfloat16-rounded inputs, widened back to
    # float32, so the tolerance only has to absorb the kernel's own rounding
    # and not the error introduced by quantizing x and y.
    x_ref = x_bf16.float().asnumpy()
    y_ref = y_bf16.float().asnumpy()
    expect = generate_expect_forward_output(x_ref, y_ref).astype(np.float32)
    np.testing.assert_allclose(output.float().asnumpy(), expect, rtol=0.004, atol=0.004)
109