• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15
16import pytest
17import numpy as np
18import mindspore as ms
19from mindspore import ops
20from mindspore.mint import flatten
21from tests.st.ops.dynamic_shape.test_op_utils import TEST_OP
22from tests.st.utils import test_utils
23
def generate_random_input(shape, dtype):
    """Draw standard-normal samples of the given shape and cast them to ``dtype``.

    Args:
        shape: tuple of ints, unpacked into ``np.random.randn``.
        dtype: numpy dtype the samples are converted to.

    Returns:
        numpy.ndarray of shape ``shape`` and dtype ``dtype``.
    """
    samples = np.random.randn(*shape)
    return samples.astype(dtype)
26
27
def flatten_func(x, start_dim=0, end_dim=-1):
    """Apply ``mindspore.mint.flatten`` to ``x`` over dims ``start_dim``..``end_dim``.

    Arguments are forwarded positionally, unchanged; this thin wrapper exists so
    the same callable can be run directly and differentiated via ``ops.grad``.
    """
    flattened = flatten(x, start_dim, end_dim)
    return flattened
30
31
@test_utils.run_with_cell
def flatten_forward_func(x, start_dim=0, end_dim=-1):
    """Forward entry point for the tests: the result of ``flatten_func``.

    NOTE(review): ``test_utils.run_with_cell`` presumably runs this function
    inside a Cell so the active context mode (graph / pynative) applies — confirm.
    """
    result = flatten_func(x, start_dim, end_dim)
    return result
35
36
def flatten_bwd_func(x, start_dim=0, end_dim=-1):
    """Return the gradient of ``flatten_func`` with respect to ``x`` (input position 0)."""
    # Only input position 0 (``x``) is differentiated; the dim arguments are plain ints.
    grad_fn = ops.grad(flatten_func, (0,))
    return grad_fn(x, start_dim, end_dim)
39
40
@test_utils.run_with_cell
def flatten_backward_func(x, start_dim=0, end_dim=-1):
    """Backward entry point for the tests: the gradient computed by ``flatten_bwd_func``.

    NOTE(review): ``test_utils.run_with_cell`` presumably runs this function
    inside a Cell so the active context mode applies — confirm.
    """
    gradient = flatten_bwd_func(x, start_dim, end_dim)
    return gradient
44
45
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize("mode", [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_flatten_forward(mode):
    """
    Feature: Ops.
    Description: test op flatten forward on a 4-D float32 tensor over two dim ranges.
    Expectation: expect correct result.
    """
    ms.set_context(jit_level='O0')
    ms.set_context(mode=mode)
    x = generate_random_input((2, 3, 4, 5), np.float32)
    # (start_dim, end_dim, shape numpy.reshape must produce for the same data)
    cases = (
        (1, 2, (2, 12, 5)),
        (0, 2, (24, 5)),
    )
    for start, end, flat_shape in cases:
        out = flatten_forward_func(ms.Tensor(x), start, end)
        np.testing.assert_allclose(out.asnumpy(), x.reshape(flat_shape), rtol=1e-4)
68
69
@pytest.mark.level1
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
@pytest.mark.parametrize("mode", [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_flatten_bfloat16(mode):
    """
    Feature: mint.flatten functional API.
    Description: testcase for flatten with bfloat16 input.
    Expectation: the result match with expected result.
    """
    ms.set_context(jit_level='O0')
    ms.set_context(mode=mode)
    test_shape = (2, 3, 4)
    x = generate_random_input(test_shape, np.float32)

    # Build genuine bfloat16 inputs; a float32 Tensor never exercised the bf16 path.
    output = flatten_forward_func(ms.Tensor(x, dtype=ms.bfloat16))
    expect = x.reshape(24)
    # numpy has no bfloat16, so cast back to float32 before comparing. The loose
    # tolerances absorb the float32 -> bfloat16 rounding of the input values.
    np.testing.assert_allclose(output.float().asnumpy(), expect, rtol=5e-3, atol=5e-3)

    output2 = flatten_forward_func(ms.Tensor(x, dtype=ms.bfloat16), 0, 1)
    expect2 = x.reshape((6, 4))
    np.testing.assert_allclose(output2.float().asnumpy(), expect2, rtol=5e-3, atol=5e-3)
91
92
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize("mode", [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_flatten_backward(mode):
    """
    Feature: Ops.
    Description: test op flatten backward on a 4-D float32 tensor over two dim ranges.
    Expectation: expect correct result.
    """
    ms.set_context(jit_level='O0')
    ms.set_context(mode=mode)
    shape = (2, 3, 4, 5)
    x = generate_random_input(shape, np.float32)
    # The expected gradient is all ones of the input shape for every dim range tested.
    expected_grad = np.ones(shape, dtype=np.float32)
    for start, end in ((1, 3), (0, 2)):
        grad = flatten_backward_func(ms.Tensor(x), start, end)
        np.testing.assert_allclose(grad.asnumpy(), expected_grad, rtol=1e-4)
115
116
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
def test_flatten_dynamic_shape():
    """
    Feature: Test dynamic shape.
    Description: test function flatten dynamic feature.
    Expectation: expect correct result.
    """
    # Two inputs of different rank (5-D and 4-D) with different dim ranges, so
    # TEST_OP can drive flatten_forward_func through dynamic shapes/ranks.
    input_5d = generate_random_input((2, 3, 4, 5, 6), np.float32)
    input_4d = generate_random_input((3, 4, 5, 6), np.float32)
    TEST_OP(flatten_forward_func, [[ms.Tensor(input_5d), 2, 3],
                                   [ms.Tensor(input_4d), 0, 1]], '', disable_yaml_check=True,
            disable_mode=['GRAPH_MODE'])
132