# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CPU tests for the MindSpore ``P.Split`` operator.

Each test splits an ``np.arange``-filled tensor into equal chunks along one
axis. It then checks the result twice:

* in full, against ``np.split`` (dtype, shape and exact values of every
  output);
* at a few hand-computed spot values, which document the expected layout
  independently of NumPy.
"""

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')


class OpNetWrapper(nn.Cell):
    """Wrap a single primitive op in an ``nn.Cell`` so it can be run as a network."""

    def __init__(self, op):
        super(OpNetWrapper, self).__init__()
        self.op = op

    def construct(self, *inputs):
        return self.op(*inputs)


def _split_and_check(input_np, axis, output_num):
    """Run ``P.Split(axis, output_num)`` on *input_np* and verify every output.

    Each output is compared exactly (dtype, shape and values) with the
    matching chunk of ``np.split(input_np, output_num, axis=axis)``. This
    catches a kernel error anywhere in the tensor, not only in the slices
    the callers spot-check.

    Args:
        input_np: NumPy array to split; it is wrapped in a ``Tensor``.
        axis: Axis to split along (negative values count from the end).
        output_num: Number of equal-sized outputs to produce.

    Returns:
        The tuple of output Tensors produced by the operator, so the caller
        can add its own spot checks.
    """
    op_wrapper = OpNetWrapper(P.Split(axis, output_num))
    outputs = op_wrapper(Tensor(input_np))
    expected = np.split(input_np, output_num, axis=axis)

    assert len(outputs) == output_num
    for actual, want in zip(outputs, expected):
        actual_np = actual.asnumpy()
        assert actual_np.dtype == want.dtype
        assert actual_np.shape == want.shape
        # Split is a pure copy, so exact equality is expected even for floats.
        assert np.array_equal(actual_np, want)
    return outputs


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_out1_axis0():
    """A single-output split along axis 0 returns the input unchanged."""
    input_np = np.arange(24).astype(np.int32).reshape((2, 2, 6))
    outputs = _split_and_check(input_np, 0, 1)

    assert outputs[0].shape == (2, 2, 6)
    assert np.allclose(outputs[0].asnumpy()[0, 0, :], [0, 1, 2, 3, 4, 5])


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_out2_axis2():
    """Split a rank-3 int32 tensor into two halves along axis 2."""
    input_np = np.arange(24).astype(np.int32).reshape((2, 2, 6))
    outputs = _split_and_check(input_np, 2, 2)

    assert outputs[0].shape == (2, 2, 3)
    assert outputs[1].shape == (2, 2, 3)
    assert np.allclose(outputs[0].asnumpy()[0, 0, :], [0, 1, 2])
    assert np.allclose(outputs[1].asnumpy()[0, 0, :], [3, 4, 5])


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_out2_axis1neg():
    """Split a rank-3 float32 tensor into two halves along a negative axis.

    Despite the test name, the axis used is -1, i.e. the last axis (axis 2)
    of this rank-3 input. The name is kept so external references still work.
    """
    input_np = np.arange(24).astype(np.float32).reshape((2, 2, 6))
    outputs = _split_and_check(input_np, -1, 2)

    assert np.allclose(outputs[0].asnumpy()[0, :, :], [[0., 1., 2.], [6., 7., 8.]])
    assert np.allclose(outputs[1].asnumpy()[0, :, :], [[3., 4., 5.], [9., 10., 11.]])


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_out_float32():
    """Split a rank-6 float32 tensor along its last axis into 2 and 3 parts."""
    input_np = np.arange(192).astype(np.float32).reshape((2, 2, 2, 2, 2, 6))

    outputs = _split_and_check(input_np, 5, 2)
    assert np.allclose(outputs[0].asnumpy()[0, 0, 0, 0, 0, :], [0., 1., 2.])
    assert np.allclose(outputs[1].asnumpy()[0, 0, 0, 0, 0, :], [3., 4., 5.])

    outputs = _split_and_check(input_np, 5, 3)
    assert np.allclose(outputs[0].asnumpy()[0, 0, 0, 0, 0, :], [0., 1.])
    assert np.allclose(outputs[1].asnumpy()[0, 0, 0, 0, 0, :], [2., 3.])
    assert np.allclose(outputs[2].asnumpy()[0, 0, 0, 0, 0, :], [4., 5.])


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_out_float64():
    """Split a rank-6 float64 tensor along its last axis into 2 and 3 parts."""
    input_np = np.arange(192).astype(np.float64).reshape((2, 2, 2, 2, 2, 6))

    outputs = _split_and_check(input_np, 5, 2)
    assert np.allclose(outputs[0].asnumpy()[0, 0, 0, 0, 0, :], [0., 1., 2.])
    assert np.allclose(outputs[1].asnumpy()[0, 0, 0, 0, 0, :], [3., 4., 5.])

    outputs = _split_and_check(input_np, 5, 3)
    assert np.allclose(outputs[0].asnumpy()[0, 0, 0, 0, 0, :], [0., 1.])
    assert np.allclose(outputs[1].asnumpy()[0, 0, 0, 0, 0, :], [2., 3.])
    assert np.allclose(outputs[2].asnumpy()[0, 0, 0, 0, 0, :], [4., 5.])


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_out_float16():
    """Split a rank-6 float16 tensor along axis -1 into 2 and 5 parts."""
    # Values 0..319 are exactly representable in float16, so the exact
    # comparison in _split_and_check is safe.
    input_np = np.arange(320).astype(np.float16).reshape((2, 2, 2, 2, 2, 10))

    outputs = _split_and_check(input_np, -1, 2)
    assert np.allclose(outputs[0].asnumpy()[0, 0, 0, 0, 0, :], [0., 1., 2., 3., 4.])
    assert np.allclose(outputs[1].asnumpy()[0, 0, 0, 0, 0, :], [5., 6., 7., 8., 9.])

    outputs = _split_and_check(input_np, -1, 5)
    assert np.allclose(outputs[0].asnumpy()[0, 0, 0, 0, 0, :], [0., 1.])
    assert np.allclose(outputs[1].asnumpy()[0, 0, 0, 0, 0, :], [2., 3.])
    assert np.allclose(outputs[2].asnumpy()[0, 0, 0, 0, 0, :], [4., 5.])
    assert np.allclose(outputs[3].asnumpy()[0, 0, 0, 0, 0, :], [6., 7.])
    assert np.allclose(outputs[4].asnumpy()[0, 0, 0, 0, 0, :], [8., 9.])


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_out_int32():
    """Split a rank-6 int32 tensor along its last axis into 2 and 3 parts."""
    input_np = np.arange(192).astype(np.int32).reshape((2, 2, 2, 2, 2, 6))

    outputs = _split_and_check(input_np, 5, 2)
    assert np.allclose(outputs[0].asnumpy()[0, 0, 0, 0, 0, :], [0, 1, 2])
    assert np.allclose(outputs[1].asnumpy()[0, 0, 0, 0, 0, :], [3, 4, 5])

    # Index 1 on axis 0 starts at element 96 (= 192 / 2).
    outputs = _split_and_check(input_np, 5, 3)
    assert np.allclose(outputs[0].asnumpy()[1, 0, 0, 0, 0, :], [96, 97])
    assert np.allclose(outputs[1].asnumpy()[1, 0, 0, 0, 0, :], [98, 99])
    assert np.allclose(outputs[2].asnumpy()[1, 0, 0, 0, 0, :], [100, 101])


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_out_int64():
    """Split a rank-6 int64 tensor along its last axis into 2 and 3 parts."""
    input_np = np.arange(192).astype(np.int64).reshape((2, 2, 2, 2, 2, 6))

    outputs = _split_and_check(input_np, 5, 2)
    assert np.allclose(outputs[0].asnumpy()[0, 0, 0, 0, 0, :], [0, 1, 2])
    assert np.allclose(outputs[1].asnumpy()[0, 0, 0, 0, 0, :], [3, 4, 5])

    # Index 1 on axis 0 starts at element 96 (= 192 / 2).
    outputs = _split_and_check(input_np, 5, 3)
    assert np.allclose(outputs[0].asnumpy()[1, 0, 0, 0, 0, :], [96, 97])
    assert np.allclose(outputs[1].asnumpy()[1, 0, 0, 0, 0, :], [98, 99])
    assert np.allclose(outputs[2].asnumpy()[1, 0, 0, 0, 0, :], [100, 101])


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_out_uint32():
    """Split a uint32 tensor along axes -1 and -2, plus a 1-element 1-D edge case."""
    input_np = np.arange(320).astype(np.uint32).reshape((2, 2, 2, 2, 2, 10))

    outputs = _split_and_check(input_np, -1, 2)
    assert np.allclose(outputs[0].asnumpy()[0, 0, 0, 0, 0, :], [0, 1, 2, 3, 4])
    assert np.allclose(outputs[1].asnumpy()[0, 0, 0, 0, 0, :], [5, 6, 7, 8, 9])

    outputs = _split_and_check(input_np, -1, 5)
    assert np.allclose(outputs[0].asnumpy()[1, 1, 1, 1, 1, :], [310, 311])
    assert np.allclose(outputs[1].asnumpy()[1, 1, 1, 1, 1, :], [312, 313])
    assert np.allclose(outputs[2].asnumpy()[1, 1, 1, 1, 1, :], [314, 315])
    assert np.allclose(outputs[3].asnumpy()[1, 1, 1, 1, 1, :], [316, 317])
    assert np.allclose(outputs[4].asnumpy()[1, 1, 1, 1, 1, :], [318, 319])

    # Axis -2 has size 2, so each output keeps a size-1 axis there.
    outputs = _split_and_check(input_np, -2, 2)
    assert np.allclose(outputs[0].asnumpy()[0, 0, 0, 0, :, 0], [0])
    assert np.allclose(outputs[1].asnumpy()[0, 0, 0, 0, :, 1], [11])
    assert np.allclose(outputs[0].asnumpy()[1, 0, 0, 0, :, 2], [162])
    assert np.allclose(outputs[1].asnumpy()[1, 0, 0, 0, :, 3], [173])
    assert np.allclose(outputs[0].asnumpy()[1, 1, 0, 0, :, 4], [244])
    assert np.allclose(outputs[1].asnumpy()[1, 1, 0, 0, :, 5], [255])
    assert np.allclose(outputs[0].asnumpy()[1, 1, 1, 0, :, 6], [286])
    assert np.allclose(outputs[1].asnumpy()[1, 1, 1, 0, :, 7], [297])
    assert np.allclose(outputs[0].asnumpy()[1, 1, 1, 1, :, 8], [308])
    assert np.allclose(outputs[1].asnumpy()[1, 1, 1, 1, :, 9], [319])

    # Edge case: a single-element 1-D tensor split into one output.
    outputs = _split_and_check(np.arange(1).astype(np.uint32), -1, 1)
    assert np.allclose(outputs[0].asnumpy(), [0])


if __name__ == '__main__':
    # Run every test, not just a subset, when the file is executed directly.
    test_out1_axis0()
    test_out2_axis2()
    test_out2_axis1neg()
    test_out_float32()
    test_out_float64()
    test_out_float16()
    test_out_int32()
    test_out_int64()
    test_out_uint32()