# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import pytest
import numpy as np
import mindspore as ms
from mindspore import ops
from mindspore.mint import flatten
from tests.st.ops.dynamic_shape.test_op_utils import TEST_OP
from tests.st.utils import test_utils


def generate_random_input(shape, dtype):
    """Return a standard-normal numpy array of the given shape, cast to `dtype`."""
    return np.random.randn(*shape).astype(dtype)


def flatten_func(x, start_dim=0, end_dim=-1):
    """Flatten dims `start_dim`..`end_dim` (inclusive, per the expectations below) of `x`."""
    return flatten(x, start_dim, end_dim)


@test_utils.run_with_cell
def flatten_forward_func(x, start_dim=0, end_dim=-1):
    """Forward entry point, wrapped by `test_utils.run_with_cell` so it runs in either mode."""
    return flatten_func(x, start_dim, end_dim)


def flatten_bwd_func(x, start_dim=0, end_dim=-1):
    """Return the gradient of `flatten_func` with respect to `x` (grad position 0)."""
    return ops.grad(flatten_func, (0,))(x, start_dim, end_dim)


@test_utils.run_with_cell
def flatten_backward_func(x, start_dim=0, end_dim=-1):
    """Backward entry point, wrapped by `test_utils.run_with_cell`."""
    return flatten_bwd_func(x, start_dim, end_dim)


@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize("mode", [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_flatten_forward(mode):
    """
    Feature: Ops.
    Description: test op flatten.
    Expectation: expect correct result.
    """
    ms.set_context(jit_level='O0')
    ms.set_context(mode=mode)
    test_shape = (2, 3, 4, 5)
    x = generate_random_input(test_shape, np.float32)
    output = flatten_forward_func(ms.Tensor(x), 1, 2)
    expect = x.reshape((2, 12, 5))
    np.testing.assert_allclose(output.asnumpy(), expect, rtol=1e-4)

    output2 = flatten_forward_func(ms.Tensor(x), 0, 2)
    expect2 = x.reshape((24, 5))
    np.testing.assert_allclose(output2.asnumpy(), expect2, rtol=1e-4)


@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
@pytest.mark.parametrize("mode", [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_flatten_bfloat16(mode):
    """
    Feature: test flatten functional API with bfloat16 input.
    Description: flatten a bfloat16 tensor with default and explicit dims.
    Expectation: the result matches the numpy reshape within bfloat16 precision.
    """
    ms.set_context(jit_level='O0')
    ms.set_context(mode=mode)
    test_shape = (2, 3, 4)
    x = generate_random_input(test_shape, np.float32)
    # numpy has no bfloat16: cast on the MindSpore side, and convert results
    # back to float32 via .float() before calling asnumpy().
    output = flatten_forward_func(ms.Tensor(x, ms.bfloat16))
    assert output.dtype == ms.bfloat16
    expect = x.reshape(24)
    np.testing.assert_allclose(output.float().asnumpy(), expect, rtol=5e-3, atol=5e-3)

    output2 = flatten_forward_func(ms.Tensor(x, ms.bfloat16), 0, 1)
    assert output2.dtype == ms.bfloat16
    expect2 = x.reshape((6, 4))
    np.testing.assert_allclose(output2.float().asnumpy(), expect2, rtol=5e-3, atol=5e-3)


@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize("mode", [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_flatten_backward(mode):
    """
    Feature: Ops.
    Description: test op flatten.
    Expectation: expect correct result.
    """
    ms.set_context(jit_level='O0')
    ms.set_context(mode=mode)
    test_shape = (2, 3, 4, 5)
    x = generate_random_input(test_shape, np.float32)
    # Flatten only rearranges elements, so every input element gets gradient 1.
    output = flatten_backward_func(ms.Tensor(x), 1, 3)
    expect = np.ones(test_shape).astype(np.float32)
    np.testing.assert_allclose(output.asnumpy(), expect, rtol=1e-4)

    output2 = flatten_backward_func(ms.Tensor(x), 0, 2)
    expect2 = np.ones(test_shape).astype(np.float32)
    np.testing.assert_allclose(output2.asnumpy(), expect2, rtol=1e-4)


@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
def test_flatten_dynamic_shape():
    """
    Feature: Test dynamic shape.
    Description: test function flatten dynamic feature.
    Expectation: expect correct result.
    """
    ms_data1 = generate_random_input((2, 3, 4, 5, 6), np.float32)
    ms_data2 = generate_random_input((3, 4, 5, 6), np.float32)
    TEST_OP(flatten_forward_func, [[ms.Tensor(ms_data1), 2, 3],
                                   [ms.Tensor(ms_data2), 0, 1]], '', disable_yaml_check=True,
            disable_mode=['GRAPH_MODE'])