# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
from tests.st.utils import test_utils

import mindspore as ms
from mindspore import mint, Tensor, jit, context, JitConfig, ops
# from tests.st.ops.dynamic_shape.test_op_utils import TEST_OP


@test_utils.run_with_cell
def chunk_forward_func(x, chunks, dim):
    """Forward wrapper under test: return ``mint.chunk(x, chunks, dim)``."""
    return mint.chunk(x, chunks, dim)


@test_utils.run_with_cell
def chunk_backward_func(x, chunks, dim):
    """Return the gradient of ``chunk_forward_func`` w.r.t. ``x`` (input index 0)."""
    return ops.grad(chunk_forward_func, (0,))(x, chunks, dim)


def _run_in_mode(func, mode, *args):
    """Set the execution context selected by ``mode`` and return ``func(*args)``.

    Args:
        func: Callable under test.
        mode: 'pynative' runs ``func`` in PYNATIVE_MODE; 'KBK' runs it in
            GRAPH_MODE wrapped by ``jit`` with ``jit_level="O0"``; any other
            value (the tests use 'GE') runs ``func`` directly in GRAPH_MODE.
        *args: Positional arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns.
    """
    if mode == 'pynative':
        context.set_context(mode=ms.PYNATIVE_MODE)
        return func(*args)
    context.set_context(mode=ms.GRAPH_MODE)
    if mode == 'KBK':
        return jit(func, jit_config=JitConfig(jit_level="O0"))(*args)
    return func(*args)


def _assert_chunks_match(out, expect):
    """Assert ``out`` contains exactly the ``expect`` pieces, in order.

    ``zip`` alone stops at the shorter sequence, so a missing or extra chunk
    would go unnoticed; the length check closes that gap. The per-piece shape
    check guards against ``np.allclose`` broadcasting a wrongly-shaped piece.
    Dtypes are not compared (some callers compare int64 output with float32
    expectations).
    """
    assert len(out) == len(expect), f"expected {len(expect)} chunks, got {len(out)}"
    for res, exp in zip(out, expect):
        res_np = res.asnumpy()
        assert res_np.shape == exp.shape
        assert np.allclose(res_np, exp)


def do_test_chunk_forward(mode):
    """
    Feature: Chunk
    Description: test op Chunk on a (5, 2) input split into 3 chunks along dim 0.
    Expectation: expect correct result.
    """
    np_x = np.array(np.arange(10).reshape((5, 2)), dtype=np.float32)
    x = ms.Tensor(np_x, dtype=ms.float32)
    dims = 0
    chunks = 3
    # 5 rows in 3 chunks -> pieces of 2, 2 and 1 rows.
    expect = [np.array(np.arange(4).reshape((2, 2)), dtype=np.float32),
              np.array(np.arange(4, 8).reshape((2, 2)), dtype=np.float32),
              np.array(np.arange(8, 10).reshape((1, 2)), dtype=np.float32)]
    out = _run_in_mode(chunk_forward_func, mode, x, chunks, dims)
    _assert_chunks_match(out, expect)


@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize("mode", ['GE', 'pynative', 'KBK'])
def test_chunk_forward_with_minus_dim(mode):
    """
    Feature: Chunk
    Description: test op Chunk with a negative dim (-1).
    Expectation: expect correct result.
    """
    np_x = np.array(np.arange(10).reshape((5, 2)), dtype=np.float32)
    x = ms.Tensor(np_x, dtype=ms.float32)
    dims = -1
    chunks = 2
    expect = [np.array([[0], [2], [4], [6], [8]], dtype=np.float32),
              np.array([[1], [3], [5], [7], [9]], dtype=np.float32)]
    out = _run_in_mode(chunk_forward_func, mode, x, chunks, dims)
    _assert_chunks_match(out, expect)


@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize("mode", ['GE', 'pynative'])
def test_chunk_forward(mode):
    """
    Feature: Chunk
    Description: test op Chunk
    Expectation: expect correct result.
    """
    do_test_chunk_forward(mode)


@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
def test_chunk_forward_kbk():
    """
    Feature: Chunk
    Description: test op Chunk
    Expectation: expect correct result.
    """
    do_test_chunk_forward('KBK')


@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize("mode", ['GE', 'pynative', 'KBK'])
def test_chunk_backward(mode):
    """
    Feature: Auto grad.
    Description: test auto grad of op Chunk.
    Expectation: expect correct result.
    """
    x = Tensor(np.arange(20).reshape(10, 2), dtype=ms.float32)
    chunks = 2
    expect_grad = np.ones((10, 2))
    grad = _run_in_mode(chunk_backward_func, mode, x, chunks, 0)
    assert np.allclose(grad.asnumpy(), expect_grad)
    assert grad.asnumpy().shape == x.shape


@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize("context_mode", [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_chunk_forward_dynamic_shape(context_mode):
    """
    Feature: chunk ops.
    Description: test ops chunk with dynamic shape tensor input.
    Expectation: output the right result; in graph mode, chunking along a
        dynamic dimension raises RuntimeError.
    """
    context.set_context(mode=context_mode)
    input_dyn = Tensor(shape=[4, None, None], dtype=ms.int64)
    chunks = 3
    dims = 0
    test_cell = test_utils.to_cell_obj(mint.chunk)
    test_cell.set_inputs(input_dyn, chunks, dims)

    # 4 rows in 3 chunks -> two pieces of 2 rows each.
    input_tensor = Tensor(np.arange(60).reshape((4, 3, 5)).astype(np.int64))
    out = test_cell(input_tensor, chunks, dims)
    expect_output = [np.array(np.arange(30).reshape((2, 3, 5)), dtype=np.float32),
                     np.array(np.arange(30, 60).reshape((2, 3, 5)), dtype=np.float32)]
    _assert_chunks_match(out, expect_output)

    input_tensor = Tensor(np.arange(24).reshape((4, 2, 3)).astype(np.int64))
    out = test_cell(input_tensor, chunks, dims)
    expect_output = [np.array(np.arange(12).reshape((2, 2, 3)), dtype=np.float32),
                     np.array(np.arange(12, 24).reshape((2, 2, 3)), dtype=np.float32)]
    _assert_chunks_match(out, expect_output)

    if context_mode == ms.GRAPH_MODE:
        # dim 2 is None in input_dyn; graph mode is expected to reject it.
        dims = 2
        test_cell.set_inputs(input_dyn, chunks, dims)
        input_tensor = Tensor(np.arange(24).reshape((4, 2, 3)).astype(np.int64))
        with pytest.raises(RuntimeError):
            _ = test_cell(input_tensor, chunks, dims)


@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize("context_mode", [ms.GRAPH_MODE])
def test_chunk_forward_dynamic_rank(context_mode):
    """
    Feature: chunk ops.
    Description: test ops chunk with dynamic rank tensor input.
    Expectation: raise RuntimeError.
    """
    context.set_context(mode=context_mode)
    input_dyn = Tensor(shape=None, dtype=ms.int64)
    chunks = 3
    dims = 0
    test_cell = test_utils.to_cell_obj(mint.chunk)
    test_cell.set_inputs(input_dyn, chunks, dims)
    input_tensor = Tensor(np.arange(24).reshape((4, 2, 3)).astype(np.int64))
    with pytest.raises(RuntimeError):
        _ = test_cell(input_tensor, chunks, dims)


@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize("context_mode", [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_chunk_backward_dynamic_shape(context_mode):
    """
    Feature: chunk ops.
    Description: test auto grad of ops chunk with dynamic shape tensor input.
    Expectation: output the right result.
    """
    context.set_context(mode=context_mode)
    input_dyn = Tensor(shape=[None, 4, None], dtype=ms.float32)
    chunks = 3
    dims = 1
    test_cell = test_utils.to_cell_obj(ops.grad(mint.chunk, (0,)))
    test_cell.set_inputs(input_dyn, chunks, dims)

    input_tensor = Tensor(np.arange(60).reshape((3, 4, 5)).astype(np.float32))
    out = test_cell(input_tensor, chunks, dims)
    expect_output = np.ones((3, 4, 5))
    assert out.asnumpy().shape == expect_output.shape
    assert np.allclose(out.asnumpy(), expect_output)

    input_tensor = Tensor(np.arange(24).reshape((2, 4, 3)).astype(np.float32))
    out = test_cell(input_tensor, chunks, dims)
    expect_output = np.ones((2, 4, 3))
    assert out.asnumpy().shape == expect_output.shape
    assert np.allclose(out.asnumpy(), expect_output)


@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize("context_mode", [ms.GRAPH_MODE])
def test_chunk_backward_dynamic_rank(context_mode):
    """
    Feature: chunk ops.
    Description: test auto grad of ops chunk with dynamic rank tensor input.
    Expectation: raise RuntimeError.
    """
    context.set_context(mode=context_mode)
    input_dyn = Tensor(shape=None, dtype=ms.float64)
    chunks = 3
    dims = 1
    test_cell = test_utils.to_cell_obj(ops.grad(mint.chunk, (0,)))
    test_cell.set_inputs(input_dyn, chunks, dims)
    input_tensor = Tensor(np.arange(24).reshape((4, 2, 3)).astype(np.float64))
    with pytest.raises(RuntimeError):
        _ = test_cell(input_tensor, chunks, dims)


@pytest.mark.level1
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize("context_mode", [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_chunk_forward_mutable(context_mode):
    """
    Feature: Chunk
    Description: test op Chunk with mutable `chunks` / `dim` arguments.
    Expectation: graph mode raises RuntimeError; pynative mode gives the correct result.
    """
    context.set_context(mode=context_mode)
    x = Tensor(np.arange(20).reshape(10, 2), dtype=ms.float32)
    chunks = 2
    dims = 0
    expect = [np.array(np.arange(10).reshape((5, 2)), dtype=np.float32),
              np.array(np.arange(10, 20).reshape((5, 2)), dtype=np.float32)]
    if context_mode == ms.GRAPH_MODE:
        with pytest.raises(RuntimeError):
            _ = chunk_forward_func(x, ms.mutable(chunks), dims)

        with pytest.raises(RuntimeError):
            _ = chunk_forward_func(x, chunks, ms.mutable(dims))
    else:
        out = chunk_forward_func(x, ms.mutable(chunks), ms.mutable(dims))
        _assert_chunks_match(out, expect)


# NOTE: the test below is disabled by wrapping it in a string literal (and the
# TEST_OP import above is commented out); see the reason in its first line.
'''
# Dynamic length tuple output is not support for now
@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_chunk_dynamic(mode):
    """
    Feature: test dynamic split.
    Description: test auto grad of op Split.
    Expectation: expect correct result.
    """
    np_x1 = np.arange(4 * 4).reshape(4, 4)
    x1 = ms.Tensor(np_x1, ms.float32)
    np_x2 = np.arange(4 * 4 * 5).reshape(4, 4, 5)
    x2 = ms.Tensor(np_x2, ms.float32)
    TEST_OP(chunk_forward_func, [[x1, 2, 1], [x2, 4, 0]], mode = mode, grad = False)
    TEST_OP(chunk_forward_func, [[x1, 2, 1], [x2, 4, 0]], mode = mode, grad = True)
'''