• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2024 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16from functools import reduce
17import pytest
18import numpy as np
19import mindspore as ms
20from mindspore import ops
21from mindspore.mint import stack
22from mindspore import jit, JitConfig
23from tests.st.utils import test_utils
24
25
def stack_func(x1, x2, axis):
    """Stack ``x1`` and ``x2`` along ``axis`` using ``mindspore.mint.stack``."""
    pair = (x1, x2)
    return stack(pair, axis)
28
29
@test_utils.run_with_cell
def stack_forward_func(x1, x2, axis=2):
    """Forward entry point of the stack test: stack ``x1`` and ``x2`` along ``axis``.

    Decorated with ``test_utils.run_with_cell`` (presumably runs the call
    inside a Cell — see tests.st.utils.test_utils to confirm).
    """
    stacked = stack_func(x1, x2, axis)
    return stacked
33
34
def stack_bwd_func(x1, x2, axis):
    """Return the gradients of ``stack_func`` with respect to ``x1`` and ``x2``.

    ``ops.grad`` is asked for input positions 0 and 1, so the result holds
    one gradient per stacked input.
    """
    grad_fn = ops.grad(stack_func, (0, 1))
    return grad_fn(x1, x2, axis)
37
38
@test_utils.run_with_cell
def stack_backward_func(x1, x2, axis=2):
    """Backward entry point of the stack test: gradients w.r.t. ``x1`` and ``x2``.

    Decorated with ``test_utils.run_with_cell`` (presumably runs the call
    inside a Cell — see tests.st.utils.test_utils to confirm).
    """
    grads = stack_bwd_func(x1, x2, axis)
    return grads
42
43
def stack_fwd_data_prepare(shape, axis=2, dtype=np.float16):
    """Build a pair of input tensors for stack and the numpy reference output.

    Args:
        shape (tuple[int]): shape of each of the two inputs.
        axis (int): axis the reference result is stacked along; it should match
            the axis used when calling the op under test. Default: 2.
        dtype (numpy dtype): element type of the inputs. Default: ``np.float16``.

    Returns:
        tuple: ``((Tensor x1, Tensor x2), expect)``, where ``x1`` is all zeros,
        ``x2`` holds 0, 1, 2, ... laid out in row-major order, and ``expect`` is
        ``np.stack((x1, x2), axis)`` computed on the numpy arrays.
    """
    x1 = np.zeros(shape).astype(dtype)
    # np.prod(()) is 1, so an empty (0-d) shape no longer raises like
    # reduce() on an empty iterable did.
    num = int(np.prod(shape))
    x2 = np.arange(num).reshape(shape).astype(dtype)
    tensor_inputs = (ms.Tensor(x1), ms.Tensor(x2))
    expect = np.stack((x1, x2), axis)
    return tensor_inputs, expect
51
52
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize('mode', ['pynative', 'KBK'])
def test_stack_forward_backward(mode):
    """
    Feature: Ops.
    Description: test op stack forward and backward, in PyNative and in KBK (jit_level O0).
    Expectation: expect correct result.
    """
    test_shape = (2, 2, 2, 2)
    tensor_inputs, expect = stack_fwd_data_prepare(test_shape)
    # Expected gradient of each of the two inputs: all ones.
    expect_grads = (np.ones(test_shape).astype(np.float16), np.ones(test_shape).astype(np.float16))

    # Pin the context mode for both cases so the result does not depend on test
    # order; KBK is selected by wrapping the entry points with jit(jit_level="O0").
    ms.set_context(mode=ms.PYNATIVE_MODE)
    if mode == 'pynative':
        forward_fn = stack_forward_func
        backward_fn = stack_backward_func
    else:
        kbk_config = JitConfig(jit_level="O0")
        forward_fn = jit(stack_forward_func, jit_config=kbk_config)
        backward_fn = jit(stack_backward_func, jit_config=kbk_config)

    output = forward_fn(tensor_inputs[0], tensor_inputs[1])
    grads = backward_fn(tensor_inputs[0], tensor_inputs[1])

    assert np.allclose(output.asnumpy(), expect)
    # Guard against zip() silently truncating if fewer gradients come back.
    assert len(grads) == len(expect_grads)
    for grad, expect_grad in zip(grads, expect_grads):
        assert np.allclose(grad.asnumpy(), expect_grad)
80
81
@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE])
def test_stack_bfloat16(mode):
    """
    Feature: test stack functional API.
    Description: testcase for stack functional API with bfloat16 inputs.
    Expectation: the result match with expected result.
    """
    ms.set_context(mode=mode, device_target="Ascend")
    test_shape = (2, 3, 4)
    tensor_inputs, expect = stack_fwd_data_prepare(test_shape)
    # stack_fwd_data_prepare yields float16 tensors; cast them so that bfloat16
    # is actually exercised. Values 0..23 are exactly representable in bfloat16.
    x1, x2 = (tensor.astype(ms.bfloat16) for tensor in tensor_inputs)
    output = stack_forward_func(x1, x2)
    assert output.dtype == ms.bfloat16
    assert np.allclose(output.float().asnumpy(), expect, 0.004, 0.004)
97