• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# pylint: disable=unused-variable
16import numpy as np
17import pytest
18import mindspore as ms
19from mindspore import mint, jit, JitConfig
20from tests.st.ops.dynamic_shape.test_op_utils import TEST_OP
21
22
def generate_random_input(shape, dtype):
    """Draw standard-normal samples of the given ``shape`` and cast them to ``dtype``.

    Uses NumPy's global random state, so the values depend on any prior seeding.
    """
    samples = np.random.randn(*shape)
    return samples.astype(dtype)
25
26
def generate_expect_forward_output(x):
    """NumPy reference for ``mint.inverse``: return the matrix inverse of ``x``."""
    expected = np.linalg.inv(x)
    return expected
29
30
def generate_expect_backward_output(x):
    """NumPy reference for the gradient of ``mint.inverse`` with respect to ``x``.

    For ``Y = inv(X)`` and an upstream gradient ``G``, ``dL/dX = -Y^T @ G @ Y^T``.
    ``G`` is taken to be all ones. This presumably matches the default sensitivity
    that ``ms.ops.grad`` feeds a non-scalar output; confirm if that default changes.

    Args:
        x: square matrix, or a stack of square matrices with shape ``(..., n, n)``.

    Returns:
        The expected input gradient, with the same shape as ``x``.
    """
    res = generate_expect_forward_output(x)
    # Transpose only the trailing (matrix) axes so stacked inputs are handled
    # correctly; ``res.T`` would reverse every axis. For 2-D input this is
    # identical to ``res.T``.
    res_t = np.swapaxes(res, -1, -2)
    temp = np.matmul(np.ones_like(res, np.float32), res_t)
    out = -1 * np.matmul(res_t, temp)
    return out
36
37
def inverse_forward_func(x):
    """Forward function under test: invert ``x`` with ``mint.inverse``."""
    inverted = mint.inverse(x)
    return inverted
40
41
def inverse_backward_func(x):
    """Return the gradient of ``inverse_forward_func`` with respect to its first input (position 0)."""
    return ms.ops.grad(inverse_forward_func, 0)(x)
45
46
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize('mode', ['pynative', 'KBK'])
def test_inverse_std(mode):
    """
    Feature: mint.inverse
    Description: Verify the forward result and the gradient of mint.inverse against the NumPy
        reference on a well-conditioned 9x9 float32 matrix, in pynative mode and in KBK mode
        (functions wrapped with jit, jit_level "O0").
    Expectation: success
    """
    dim = 9
    # A raw standard-normal matrix is occasionally close to singular. Float32
    # rounding in both the device result and the NumPy reference is then
    # amplified enough to break the tolerances below intermittently. Adding
    # ``dim`` to the diagonal moves the eigenvalues away from zero, so the
    # matrix is well-conditioned with overwhelming probability.
    x = generate_random_input((dim, dim), np.float32) + dim * np.eye(dim, dtype=np.float32)

    expect_forward = generate_expect_forward_output(x)
    expect_grad = generate_expect_backward_output(x)

    if mode == 'pynative':
        ms.set_context(mode=ms.PYNATIVE_MODE)
        output_forward = inverse_forward_func(ms.Tensor(x))
        output_grad = inverse_backward_func(ms.Tensor(x))
    else:
        # NOTE(review): the context mode is not set in this branch, so it runs
        # under whatever mode is currently active (e.g. one left by an earlier test).
        jit_config = JitConfig(jit_level="O0")
        output_forward = jit(inverse_forward_func, jit_config=jit_config)(ms.Tensor(x))
        output_grad = jit(inverse_backward_func, jit_config=jit_config)(ms.Tensor(x))

    assert np.allclose(output_forward.asnumpy(), expect_forward, rtol=1e-2, atol=1e-2)
    assert np.allclose(output_grad.asnumpy(), expect_grad, rtol=5e-1, atol=5e-1)
73
74
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize('mode', ['pynative', 'KBK'])
def test_inverse_dynamic_shape(mode):
    """
    Feature: Test mint.inverse with dynamic shape.
    Description: call mint.inverse through TEST_OP with a (3, 3) matrix and a (2, 4, 4) batch of
        matrices, with GRAPH_MODE disabled.
    Expectation: return the correct value.
    """
    # NOTE(review): ``mode`` is never used in this body. TEST_OP receives
    # ``disable_mode`` and presumably selects the execution modes itself, so both
    # parametrizations appear to run the same checks. Confirm, and consider
    # dropping the parametrize.
    x1 = generate_random_input((3, 3), np.float32)
    x2 = generate_random_input((2, 4, 4), np.float32)

    TEST_OP(inverse_forward_func, [[ms.Tensor(x1)], [ms.Tensor(x2)]], 'matrix_inverse_ext',
            disable_input_check=True, disable_mode=['GRAPH_MODE'])
91