# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from functools import reduce
import pytest
import numpy as np
import mindspore as ms
from mindspore import ops
from mindspore.mint import stack
from mindspore import jit, JitConfig
from tests.st.utils import test_utils


def stack_func(x1, x2, axis):
    """Stack ``x1`` and ``x2`` along ``axis`` with ``mindspore.mint.stack``."""
    return stack((x1, x2), axis)


@test_utils.run_with_cell
def stack_forward_func(x1, x2, axis=2):
    """Forward entry point; wrapped by ``test_utils.run_with_cell`` (see that helper)."""
    return stack_func(x1, x2, axis)


def stack_bwd_func(x1, x2, axis):
    """Return the gradients of ``stack_func`` w.r.t. ``x1`` and ``x2`` (positions 0 and 1)."""
    return ops.grad(stack_func, (0, 1))(x1, x2, axis)


@test_utils.run_with_cell
def stack_backward_func(x1, x2, axis=2):
    """Backward entry point; wrapped by ``test_utils.run_with_cell`` (see that helper)."""
    return stack_bwd_func(x1, x2, axis)


def stack_fwd_data_prepare(shape, axis=2, dtype=np.float16, ms_dtype=None):
    """Build the two stack inputs and the numpy reference result.

    Args:
        shape: Shape of each input.
        axis: Axis used for the numpy reference ``np.stack``.
        dtype: numpy dtype of the reference data (default float16).
        ms_dtype: Optional MindSpore dtype for the input tensors. ``None`` keeps
            the numpy dtype. Use it for dtypes numpy lacks, e.g. ``ms.bfloat16``.

    Returns:
        ``((x1_tensor, x2_tensor), expect)``. ``x1`` is all zeros, ``x2`` is
        ``0..num-1`` reshaped to ``shape``, and ``expect`` is
        ``np.stack((x1, x2), axis)``.
    """
    # Initial value 1 so an empty (scalar) shape yields one element instead of raising.
    num = reduce(lambda x, y: x * y, shape, 1)
    x1 = np.zeros(shape, dtype=dtype)
    x2 = np.arange(num).reshape(shape).astype(dtype)
    tensor_inputs = (ms.Tensor(x1, dtype=ms_dtype), ms.Tensor(x2, dtype=ms_dtype))
    expect = np.stack((x1, x2), axis)
    return tensor_inputs, expect


@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize('mode', ['pynative', 'KBK'])
def test_stack_forward_backward(mode):
    """
    Feature: Ops.
    Description: test op stack.
    Expectation: expect correct result.
    """
    test_shape = (2, 2, 2, 2)
    tensor_inputs, expect = stack_fwd_data_prepare(test_shape)
    # Stack only rearranges elements, so each input's expected gradient is all ones.
    expect_grads = (np.ones(test_shape).astype(np.float16),
                    np.ones(test_shape).astype(np.float16))

    if mode == 'pynative':
        ms.set_context(mode=ms.PYNATIVE_MODE)
        output = stack_forward_func(tensor_inputs[0], tensor_inputs[1])
        grads = stack_backward_func(tensor_inputs[0], tensor_inputs[1])
    else:
        # 'KBK': compile the entry points with jit at jit_level O0.
        output = jit(stack_forward_func, jit_config=JitConfig(jit_level="O0"))(
            tensor_inputs[0], tensor_inputs[1])
        grads = jit(stack_backward_func, jit_config=JitConfig(jit_level="O0"))(
            tensor_inputs[0], tensor_inputs[1])

    assert np.allclose(output.asnumpy(), expect)
    # Guard against zip() silently truncating if the gradient count differs.
    assert len(grads) == len(expect_grads)
    for grad, expect_grad in zip(grads, expect_grads):
        assert np.allclose(grad.asnumpy(), expect_grad)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE])
def test_stack_bfloat16(mode):
    """
    Feature: test stack functional API.
    Description: testcase for stack functional API with bfloat16 inputs.
    Expectation: the result match with expected result.
    """
    ms.set_context(mode=mode, device_target="Ascend")
    test_shape = (2, 3, 4)
    # numpy has no bfloat16: build float32 reference data and cast the tensors.
    # The values 0..23 are exactly representable in bfloat16.
    tensor_inputs, expect = stack_fwd_data_prepare(
        test_shape, dtype=np.float32, ms_dtype=ms.bfloat16)
    output = stack_forward_func(tensor_inputs[0], tensor_inputs[1])
    assert np.allclose(output.float().asnumpy(), expect, rtol=0.004, atol=0.004)