# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
from tests.st.utils import test_utils

import mindspore as ms
from mindspore import mint, Tensor, jit, context, JitConfig, ops
# from tests.st.ops.dynamic_shape.test_op_utils import TEST_OP

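# Helpers used by the tests below. mint.chunk splits a tensor into `chunks`
# pieces along `dim`; when the axis length is not evenly divisible, the last
# piece is smaller (e.g. a (5, 2) tensor chunked into 3 along dim 0 yields
# shapes (2, 2), (2, 2) and (1, 2), as the forward tests check).
# chunk_backward_func returns the gradient with respect to the input tensor
# only (position 0).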
@test_utils.run_with_cell
def chunk_forward_func(x, chunks, dim):
    return mint.chunk(x, chunks, dim)

@test_utils.run_with_cell
def chunk_backward_func(x, chunks, dim):
    return ops.grad(chunk_forward_func, (0,))(x, chunks, dim)


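# The forward tests below run in three configurations: 'pynative'
# (PYNATIVE_MODE), 'KBK' (GRAPH_MODE with the function jitted at
# jit_level="O0") and 'GE' (plain GRAPH_MODE).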
def do_test_chunk_forward(mode):
    """
    Feature: Chunk
    Description: test op Chunk
    Expectation: expect correct result.
    """
    np_x = np.array(np.arange(10).reshape((5, 2)), dtype=np.float32)
    x = ms.Tensor(np_x, dtype=ms.float32)
    dims = 0
    chunks = 3
    expect = [np.array(np.arange(4).reshape((2, 2)), dtype=np.float32),
              np.array(np.arange(4, 8).reshape((2, 2)), dtype=np.float32),
              np.array(np.arange(8, 10).reshape((1, 2)), dtype=np.float32)]
    if mode == 'pynative':
        context.set_context(mode=ms.PYNATIVE_MODE)
        out = chunk_forward_func(x, chunks, dims)
    elif mode == 'KBK':
        context.set_context(mode=ms.GRAPH_MODE)
        out = (jit(chunk_forward_func, jit_config=JitConfig(jit_level="O0")))(x, chunks, dims)
    else:
        context.set_context(mode=ms.GRAPH_MODE)
        out = chunk_forward_func(x, chunks, dims)
    for res, exp in zip(out, expect):
        assert np.allclose(res.asnumpy(), exp)


@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize("mode", ['GE', 'pynative', 'KBK'])
def test_chunk_forward_with_minus_dim(mode):
    """
    Feature: Chunk
    Description: test op Chunk
    Expectation: expect correct result.
    """
    np_x = np.array(np.arange(10).reshape((5, 2)), dtype=np.float32)
    x = ms.Tensor(np_x, dtype=ms.float32)
    dims = -1
    chunks = 2
    expect = [np.array([[0], [2], [4], [6], [8]], dtype=np.float32),
              np.array([[1], [3], [5], [7], [9]], dtype=np.float32)]
    if mode == 'pynative':
        context.set_context(mode=ms.PYNATIVE_MODE)
        out = chunk_forward_func(x, chunks, dims)
    elif mode == 'KBK':
        context.set_context(mode=ms.GRAPH_MODE)
        out = (jit(chunk_forward_func, jit_config=JitConfig(jit_level="O0")))(x, chunks, dims)
    else:
        context.set_context(mode=ms.GRAPH_MODE)
        out = chunk_forward_func(x, chunks, dims)
    for res, exp in zip(out, expect):
        assert np.allclose(res.asnumpy(), exp)


@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize("mode", ['GE', 'pynative'])
def test_chunk_forward(mode):
    """
    Feature: Chunk
    Description: test op Chunk
    Expectation: expect correct result.
    """
    do_test_chunk_forward(mode)


@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
def test_chunk_forward_kbk():
    """
    Feature: Chunk
    Description: test op Chunk
    Expectation: expect correct result.
    """
    do_test_chunk_forward('KBK')


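# The backward of chunk routes the output gradients (ones, as ops.grad uses by
# default when no sens is passed) straight back to the corresponding slices of
# the input, so the expected input gradient is an all-ones tensor of the same
# shape as x.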
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize("mode", ['GE', 'pynative', 'KBK'])
def test_chunk_backward(mode):
    """
    Feature: Auto grad.
    Description: test auto grad of op Chunk.
    Expectation: expect correct result.
    """
    x = Tensor(np.arange(20).reshape(10, 2), dtype=ms.float32)
    chunks = 2
    expect_grad = np.ones((10, 2))
    if mode == 'pynative':
        context.set_context(mode=ms.PYNATIVE_MODE)
        grad = chunk_backward_func(x, chunks, 0)
    elif mode == 'KBK':
        context.set_context(mode=ms.GRAPH_MODE)
        grad = (jit(chunk_backward_func, jit_config=JitConfig(jit_level="O0")))(x, chunks, 0)
    else:
        context.set_context(mode=ms.GRAPH_MODE)
        grad = chunk_backward_func(x, chunks, 0)
    assert np.allclose(grad.asnumpy(), expect_grad)
    assert grad.asnumpy().shape == x.shape


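# The dynamic-shape tests declare a partially unknown input shape via
# set_inputs and then feed concrete tensors of different shapes through the
# same cell.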
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize("context_mode", [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_chunk_forward_dynamic_shape(context_mode):
    """
    Feature: chunk ops.
    Description: test ops chunk with dynamic shape tensor input.
    Expectation: output the right result.
    """
    context.set_context(mode=context_mode)
    input_dyn = Tensor(shape=[4, None, None], dtype=ms.int64)
    chunks = 3
    dims = 0
    test_cell = test_utils.to_cell_obj(mint.chunk)
    test_cell.set_inputs(input_dyn, chunks, dims)
    input_tensor = Tensor(np.arange(60).reshape((4, 3, 5)).astype(np.int64))
    out = test_cell(input_tensor, chunks, dims)
    expect_output = [np.array(np.arange(30).reshape((2, 3, 5)), dtype=np.float32),
                     np.array(np.arange(30, 60).reshape((2, 3, 5)), dtype=np.float32)]
    for res, exp in zip(out, expect_output):
        assert np.allclose(res.asnumpy(), exp)

    input_tensor = Tensor(np.arange(24).reshape((4, 2, 3)).astype(np.int64))
    out = test_cell(input_tensor, chunks, dims)
    expect_output = [np.array(np.arange(12).reshape((2, 2, 3)), dtype=np.float32),
                     np.array(np.arange(12, 24).reshape((2, 2, 3)), dtype=np.float32)]
    for res, exp in zip(out, expect_output):
        assert np.allclose(res.asnumpy(), exp)

    if context_mode == ms.GRAPH_MODE:
        # Chunking along an axis whose length is unknown (None) cannot be
        # statically shaped in graph mode, so this is expected to fail.
        dims = 2
        test_cell.set_inputs(input_dyn, chunks, dims)
        input_tensor = Tensor(np.arange(24).reshape((4, 2, 3)).astype(np.int64))
        with pytest.raises(RuntimeError):
            _ = test_cell(input_tensor, chunks, dims)


@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize("context_mode", [ms.GRAPH_MODE])
def test_chunk_forward_dynamic_rank(context_mode):
    """
    Feature: chunk ops.
    Description: test ops chunk with dynamic rank tensor input.
    Expectation: raise RuntimeError.
    """
    context.set_context(mode=context_mode)
    input_dyn = Tensor(shape=None, dtype=ms.int64)
    chunks = 3
    dims = 0
    test_cell = test_utils.to_cell_obj(mint.chunk)
    test_cell.set_inputs(input_dyn, chunks, dims)
    input_tensor = Tensor(np.arange(24).reshape((4, 2, 3)).astype(np.int64))
    with pytest.raises(RuntimeError):
        _ = test_cell(input_tensor, chunks, dims)


@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize("context_mode", [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_chunk_backward_dynamic_shape(context_mode):
    """
    Feature: chunk ops.
    Description: test the backward of ops chunk with dynamic shape tensor input.
    Expectation: output the right result.
    """
    context.set_context(mode=context_mode)
    input_dyn = Tensor(shape=[None, 4, None], dtype=ms.float32)
    chunks = 3
    dims = 1
    test_cell = test_utils.to_cell_obj(ops.grad(mint.chunk, (0,)))
    test_cell.set_inputs(input_dyn, chunks, dims)
    input_tensor = Tensor(np.arange(60).reshape((3, 4, 5)).astype(np.float32))
    out = test_cell(input_tensor, chunks, dims)
    expect_output = np.ones((3, 4, 5))
    assert np.allclose(out.asnumpy(), expect_output)

    input_tensor = Tensor(np.arange(24).reshape((2, 4, 3)).astype(np.float32))
    out = test_cell(input_tensor, chunks, dims)
    expect_output = np.ones((2, 4, 3))
    assert np.allclose(out.asnumpy(), expect_output)


@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize("context_mode", [ms.GRAPH_MODE])
def test_chunk_backward_dynamic_rank(context_mode):
    """
    Feature: chunk ops.
    Description: test the backward of ops chunk with dynamic rank tensor input.
    Expectation: raise RuntimeError.
    """
    context.set_context(mode=context_mode)
    input_dyn = Tensor(shape=None, dtype=ms.float64)
    chunks = 3
    dims = 1
    test_cell = test_utils.to_cell_obj(ops.grad(mint.chunk, (0,)))
    test_cell.set_inputs(input_dyn, chunks, dims)
    input_tensor = Tensor(np.arange(24).reshape((4, 2, 3)).astype(np.float64))
    with pytest.raises(RuntimeError):
        _ = test_cell(input_tensor, chunks, dims)


@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize("context_mode", [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_chunk_forward_mutable(context_mode):
    """
    Feature: Chunk
    Description: test op Chunk with mutable chunks and dims inputs.
    Expectation: expect correct result in pynative mode and RuntimeError in graph mode.
    """
    context.set_context(mode=context_mode)
    x = Tensor(np.arange(20).reshape(10, 2), dtype=ms.float32)
    chunks = 2
    dims = 0
    expect = [np.array(np.arange(10).reshape((5, 2)), dtype=np.float32),
              np.array(np.arange(10, 20).reshape((5, 2)), dtype=np.float32)]
    if context_mode == ms.GRAPH_MODE:
        # With mutable chunks or dims, the number and shapes of the output
        # chunks are not compile-time constants, so graph mode rejects this.
        with pytest.raises(RuntimeError):
            _ = chunk_forward_func(x, ms.mutable(chunks), dims)

        with pytest.raises(RuntimeError):
            _ = chunk_forward_func(x, chunks, ms.mutable(dims))
    else:
        out = chunk_forward_func(x, ms.mutable(chunks), ms.mutable(dims))
        for res, exp in zip(out, expect):
            assert np.allclose(res.asnumpy(), exp)

'''
# Dynamic length tuple output is not supported for now
@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_chunk_dynamic(mode):
    """
    Feature: test dynamic chunk.
    Description: test op Chunk with dynamic inputs.
    Expectation: expect correct result.
    """
    np_x1 = np.arange(4 * 4).reshape(4, 4)
    x1 = ms.Tensor(np_x1, ms.float32)
    np_x2 = np.arange(4 * 4 * 5).reshape(4, 4, 5)
    x2 = ms.Tensor(np_x2, ms.float32)
    TEST_OP(chunk_forward_func, [[x1, 2, 1], [x2, 4, 0]], mode=mode, grad=False)
    TEST_OP(chunk_forward_func, [[x1, 2, 1], [x2, 4, 0]], mode=mode, grad=True)
'''