# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GPU tests for the Split operator, in both static- and dynamic-shape modes."""
import numpy as np
import pytest

import mindspore.context as context
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops import operations as P


class Net(nn.Cell):
    """Thin cell wrapping P.Split so it can run under graph mode.

    Args:
        axis (int): axis along which to split the input tensor.
        out_nums (int): number of equal pieces to split into.
    """

    def __init__(self, axis=0, out_nums=1):
        super(Net, self).__init__()
        self.split = P.Split(axis, out_nums)

    def construct(self, x):
        return self.split(x)


class NetDynamic(nn.Cell):
    """Like Net, but first converts the input to a dynamic shape.

    GpuConvertToDynamicShape forces the GPU Split kernel down its
    dynamic-shape code path so that path gets test coverage too.
    """

    def __init__(self, axis=0, out_nums=1):
        super(NetDynamic, self).__init__()
        self.conv = inner.GpuConvertToDynamicShape()
        self.split = P.Split(axis, out_nums)

    def construct(self, x):
        x_conv = self.conv(x)
        x_split = self.split(x_conv)
        return x_split


context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


def split_basic(nptype):
    """Split a fixed (3, 2, 3) tensor of dtype *nptype* along axis 0 into
    three pieces and check each piece equals the corresponding slice."""
    x = np.array([[[1, -1, 1], [2, -2, 2]],
                  [[3, -3, 3], [4, -4, 4]],
                  [[5, -5, 5], [6, -6, 6]]]).astype(nptype)

    split_op = Net(0, 3)
    outputs = split_op(Tensor(x))
    for i, out in enumerate(outputs):
        assert (out.asnumpy() == x[i]).all()


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_split_basic_float16():
    split_basic(np.float16)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_split_basic_float32():
    split_basic(np.float32)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_split_basic_float64():
    split_basic(np.float64)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_split_basic_int32():
    split_basic(np.int32)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_split_basic_uint32():
    split_basic(np.uint32)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_split_basic_int64():
    split_basic(np.int64)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_split_basic_bool():
    # np.bool was a deprecated alias for the builtin bool and was removed
    # in NumPy 1.24; np.bool_ is the actual NumPy boolean scalar type.
    split_basic(np.bool_)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_split_4d():
    """Split a random 4-D tensor along axis 1 and compare with np.split."""
    x_np = np.random.randn(2, 6, 4, 4).astype(np.float32)
    y = np.split(x_np, 3, axis=1)

    split_op = Net(1, 3)
    outputs = split_op(Tensor(x_np))

    for i, out in enumerate(outputs):
        assert (out.asnumpy() == y[i]).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_split_dynamic():
    """Dynamic-shape split along axis 0."""
    x = np.array([[[1, -1, 1], [2, -2, 2]],
                  [[3, -3, 3], [4, -4, 4]],
                  [[5, -5, 5], [6, -6, 6]]]).astype(np.float32)

    net = NetDynamic(0, 3)
    x_split = net(Tensor(x))
    for i, out in enumerate(x_split):
        assert (out.asnumpy() == x[i]).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_split_dynamic_axis1():
    """Dynamic-shape split along axis 1, checked against np.split."""
    x = np.array([[[1, -1, 1], [2, -2, 2]],
                  [[3, -3, 3], [4, -4, 4]],
                  [[5, -5, 5], [6, -6, 6]]]).astype(np.int32)
    y = np.split(x, 2, axis=1)

    net = NetDynamic(1, 2)
    x_split = net(Tensor(x))
    for i, out in enumerate(x_split):
        assert (out.asnumpy() == y[i]).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_split_dynamic_axis2():
    """Dynamic-shape split along axis 2, checked against np.split."""
    x = np.array([[[1, -1, 1], [2, -2, 2]],
                  [[3, -3, 3], [4, -4, 4]],
                  [[5, -5, 5], [6, -6, 6]]]).astype(np.int32)
    y = np.split(x, 3, axis=2)

    net = NetDynamic(2, 3)
    x_split = net(Tensor(x))
    for i, out in enumerate(x_split):
        assert (out.asnumpy() == y[i]).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_split_invalid_input():
    """Invalid constructor arguments and inputs must raise the right errors."""
    with pytest.raises(TypeError):
        _ = Net(0.1, 3)  # axis must be an int

    with pytest.raises(TypeError):
        _ = Net(0, 3.0)  # output_num must be an int

    with pytest.raises(ValueError):
        _ = Net(0, -3)  # output_num must be positive

    x = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)
    split_net = Net(2, 2)
    with pytest.raises(ValueError):
        _ = split_net(Tensor(x))  # axis 2 out of range for a 2-D input

    with pytest.raises(TypeError):
        _ = split_net(x)  # raw ndarray is not accepted; must be a Tensor