# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for the tensor_split, vsplit, hsplit, dsplit and split_ext ops.

Each test runs the op in both GRAPH_MODE and PYNATIVE_MODE. It compares
every output piece against a NumPy reference (``np.array_split`` /
``np.split``) or against hand-written expected arrays.
"""

import numpy as np
import pytest

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.ops.function.array_func import split_ext


def _assert_split_outputs_close(out, expect):
    """Assert that the pieces returned by a split op match the expected pieces.

    Args:
        out: Sequence of MindSpore tensors returned by the split op.
        expect: Sequence of NumPy arrays holding the expected pieces, in order.
    """
    # Compare the piece counts first: ``zip`` alone stops at the shorter
    # sequence and would silently accept a missing or extra output piece.
    assert len(out) == len(expect)
    for res, exp in zip(out, expect):
        assert np.allclose(res.asnumpy(), exp)


class TensorSplitNet(nn.Cell):
    """Cell wrapper around ``ops.tensor_split`` so it can run in either mode."""

    def construct(self, x, indices_or_sections, axis=0):
        out = ops.tensor_split(x, indices_or_sections, axis)
        return out


@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_tensor_split_int(mode):
    """
    Feature: tensor_split
    Description: Verify the result of tensor_split when the type of `indices_or_sections` is int.
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = TensorSplitNet()
    a = np.array(np.arange(20).reshape((10, 2)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    indices_or_sections = 3
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections)
    _assert_split_outputs_close(out, expect)


@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_tensor_split_list(mode):
    """
    Feature: tensor_split
    Description: Verify the result of tensor_split when the type of `indices_or_sections` is list(int) or tuple(int).
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = TensorSplitNet()
    a = np.array(np.arange(10).reshape((5, 2)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    indices_or_sections = [2, 4]
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections)
    _assert_split_outputs_close(out, expect)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_tensor_split_list2(mode):
    """
    Feature: tensor_split
    Description: Verify the result of tensor_split when `indices_or_sections` contains an index beyond the
        length of the split axis.
    Expectation: success
    """
    ms.set_context(mode=mode)
    a = np.arange(10).reshape((5, 2))
    # Index 7 exceeds the 5 rows of axis 0, so trailing pieces are empty.
    indices_or_sections = [1, 4, 7]
    net = TensorSplitNet()
    x = ms.Tensor(a, dtype=ms.int64)
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections)
    _assert_split_outputs_close(out, expect)


@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_tensor_split_list3(mode):
    """
    Feature: tensor_split
    Description: Verify the result of tensor_split when `indices_or_sections` contains a negative index.
    Expectation: success
    """
    ms.set_context(mode=mode)
    a = np.arange(10).reshape((5, 2))
    indices_or_sections = [-5, 4, 3, 7]
    net = TensorSplitNet()
    x = ms.Tensor(a, dtype=ms.int64)
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections)
    _assert_split_outputs_close(out, expect)


@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_tensor_split_list4(mode):
    """
    Feature: tensor_split
    Description: Verify the result of tensor_split when `indices_or_sections` contains negative indices that are
        out of range.
    Expectation: success
    """
    ms.set_context(mode=mode)
    a = np.arange(12)
    indices_or_sections = [-18, -14, -10]
    net = TensorSplitNet()
    x = ms.Tensor(a, dtype=ms.int64)
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections)
    _assert_split_outputs_close(out, expect)


@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_tensor_split_list5(mode):
    """
    Feature: tensor_split
    Description: Verify the result of tensor_split when `indices_or_sections` is not sorted in ascending order.
    Expectation: success
    """
    ms.set_context(mode=mode)
    a = np.arange(12)
    indices_or_sections = [-18, -10, -14, 2]
    net = TensorSplitNet()
    x = ms.Tensor(a, dtype=ms.int64)
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections)
    _assert_split_outputs_close(out, expect)


class VSplitNet(nn.Cell):
    """Cell wrapper around ``ops.vsplit`` so it can run in either mode."""

    def construct(self, x, indices_or_sections):
        out = ops.vsplit(x, indices_or_sections)
        return out


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_vsplit_int(mode):
    """
    Feature: vsplit
    Description: Verify the result of vsplit when the type of `indices_or_sections` is int.
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = VSplitNet()
    a = np.arange(20).reshape((10, 2))
    x = ms.Tensor(a, dtype=ms.float32)
    indices_or_sections = 3
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections, axis=0)
    _assert_split_outputs_close(out, expect)


@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_vsplit_list(mode):
    """
    Feature: vsplit
    Description: Verify the result of vsplit when the type of `indices_or_sections` is list(int) or tuple(int).
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = VSplitNet()
    a = np.array(np.arange(10).reshape((5, 2)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    indices_or_sections = [2, 4]
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections, axis=0)
    _assert_split_outputs_close(out, expect)


class HSplitNet(nn.Cell):
    """Cell wrapper around ``ops.hsplit`` so it can run in either mode."""

    def construct(self, x, indices_or_sections):
        out = ops.hsplit(x, indices_or_sections)
        return out


@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_hsplit_int(mode):
    """
    Feature: hsplit
    Description: Verify the result of hsplit when the type of `indices_or_sections` is int.
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = HSplitNet()
    a = np.array(np.arange(20).reshape((2, 10)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    indices_or_sections = 3
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections, axis=1)
    _assert_split_outputs_close(out, expect)


@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_hsplit_list(mode):
    """
    Feature: hsplit
    Description: Verify the result of hsplit when the type of `indices_or_sections` is list(int) or tuple(int).
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = HSplitNet()
    a = np.array(np.arange(10).reshape((2, 5)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    indices_or_sections = [2, 4]
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections, axis=1)
    _assert_split_outputs_close(out, expect)


class DSplitNet(nn.Cell):
    """Cell wrapper around ``ops.dsplit`` so it can run in either mode."""

    def construct(self, x, indices_or_sections):
        out = ops.dsplit(x, indices_or_sections)
        return out


@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_dsplit_int(mode):
    """
    Feature: dsplit
    Description: Verify the result of dsplit when the type of `indices_or_sections` is int.
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = DSplitNet()
    a = np.array(np.arange(20).reshape((1, 2, 10)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    indices_or_sections = 3
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections, axis=2)
    _assert_split_outputs_close(out, expect)


@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_dsplit_list(mode):
    """
    Feature: dsplit
    Description: Verify the result of dsplit when the type of `indices_or_sections` is list(int) or tuple(int).
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = DSplitNet()
    a = np.array(np.arange(20).reshape((1, 2, 10)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    indices_or_sections = [2, 4]
    out = net(x, indices_or_sections)
    expect = np.array_split(a, indices_or_sections, axis=2)
    _assert_split_outputs_close(out, expect)


class SplitNet(nn.Cell):
    """Cell wrapper around ``split_ext`` so it can run in either mode."""

    def construct(self, x, split_size_or_sections, axis=0):
        out = split_ext(x, split_size_or_sections, axis)
        return out


@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_split_ext_int(mode):
    """
    Feature: split
    Description: Verify the result of split_ext when `split_size_or_sections` is an int chunk size.
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = SplitNet()
    a = np.array(np.arange(20).reshape((10, 2)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    split_size_or_sections = 5
    out = net(x, split_size_or_sections)
    expect = [np.array(np.arange(10).reshape((5, 2)), dtype=np.float32),
              np.array(np.arange(10, 20).reshape((5, 2)), dtype=np.float32)]
    _assert_split_outputs_close(out, expect)


@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_split_ext_int32(mode):
    """
    Feature: split
    Description: Verify the result of split_ext when an int `split_size_or_sections` yields 100 outputs.
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = SplitNet()
    a = np.array(np.arange(2000).reshape((1000, 2)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    split_size_or_sections = 10
    out = net(x, split_size_or_sections)
    # 1000 rows in chunks of 10 rows -> 100 equally sized pieces along axis 0.
    expect = np.split(a, 100, axis=0)
    _assert_split_outputs_close(out, expect)


@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_split_ext_list(mode):
    """
    Feature: split
    Description: Verify the result of split_ext when `split_size_or_sections` is a list of section sizes along
        axis 1.
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = SplitNet()
    a = np.array(np.arange(20).reshape((2, 10)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    split_size_or_sections = [2, 3, 5]
    out = net(x, split_size_or_sections, axis=1)
    expect = [np.array([[0, 1], [10, 11]], dtype=np.float32),
              np.array([[2, 3, 4], [12, 13, 14]], dtype=np.float32),
              np.array([[5, 6, 7, 8, 9], [15, 16, 17, 18, 19]], dtype=np.float32)]
    _assert_split_outputs_close(out, expect)


@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_f_split_ext_list32(mode):
    """
    Feature: split
    Description: Verify the result of split_ext when a list `split_size_or_sections` of 100 section sizes yields
        100 outputs.
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = SplitNet()
    a = np.array(np.arange(2000).reshape((2, 1000)), dtype=np.float32)
    x = ms.Tensor(a, dtype=ms.float32)
    split_size_or_sections = [10] * 100
    out = net(x, split_size_or_sections, axis=1)
    # np.split expects cut points rather than section sizes: the running sum
    # of the sizes, excluding the final total (the axis length).
    expect = np.split(a, np.cumsum(split_size_or_sections)[:-1], axis=1)
    _assert_split_outputs_close(out, expect)