# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing Slice op in DE
"""
import numpy as np
import pytest

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as ops


def slice_compare(array, indexing, expected_array):
    """
    Apply ops.Slice to a one-row dataset built from `array` and check the output.

    Args:
        array: Input data. It is wrapped in a list, so NumpySlicesDataset yields it
            as a single row in column 'column_0'.
        indexing: Either a single Slice argument (None, Ellipsis, bool, int, slice,
            or a list of ints), or a non-empty list whose first element is not an
            int (e.g. slices or index lists). In the latter case each element is
            passed to ops.Slice as a separate per-dimension argument.
        expected_array: Expected contents of 'column_0' after slicing.
    """
    data = ds.NumpySlicesDataset([array])
    # A list starting with a non-int describes one slice option per dimension,
    # so unpack it; a list of ints is a single index list and is passed as-is.
    if isinstance(indexing, list) and indexing and not isinstance(indexing[0], int):
        data = data.map(operations=ops.Slice(*indexing))
    else:
        data = data.map(operations=ops.Slice(indexing))
    num_rows = 0
    for d in data.create_dict_iterator(output_numpy=True):
        np.testing.assert_array_equal(expected_array, d['column_0'])
        num_rows += 1
    # Guard against a vacuous pass: exactly one row must have been compared.
    assert num_rows == 1


def test_slice_all():
    """
    Test that None, Ellipsis and True each return the whole input
    """
    slice_compare([1, 2, 3, 4, 5], None, [1, 2, 3, 4, 5])
    slice_compare([1, 2, 3, 4, 5], ..., [1, 2, 3, 4, 5])
    slice_compare([1, 2, 3, 4, 5], True, [1, 2, 3, 4, 5])


def test_slice_single_index():
    """
    Test slicing with a single int index (positive or negative) or a one-element list
    """
    slice_compare([1, 2, 3, 4, 5], 0, [1])
    slice_compare([1, 2, 3, 4, 5], -3, [3])
    slice_compare([1, 2, 3, 4, 5], [0], [1])


def test_slice_indices_multidim():
    """
    Test per-dimension index lists on a 2D input
    """
    slice_compare([[1, 2, 3, 4, 5]], [[0], [0]], 1)
    slice_compare([[1, 2, 3, 4, 5]], [[0], [0, 3]], [[1, 4]])
    slice_compare([[1, 2, 3, 4, 5]], [0], [[1, 2, 3, 4, 5]])
    slice_compare([[1, 2, 3, 4, 5]], [[0], [0, -4]], [[1, 2]])


def test_slice_list_index():
    """
    Test index lists, including unordered, negative and repeated indices
    """
    slice_compare([1, 2, 3, 4, 5], [0, 1, 4], [1, 2, 5])
    slice_compare([1, 2, 3, 4, 5], [4, 1, 0], [5, 2, 1])
    slice_compare([1, 2, 3, 4, 5], [-1, 1, 0], [5, 2, 1])
    slice_compare([1, 2, 3, 4, 5], [-1, -4, -2], [5, 2, 4])
    slice_compare([1, 2, 3, 4, 5], [3, 3, 3], [4, 4, 4])


def test_slice_index_and_slice():
    """
    Test mixing index lists and slice objects across dimensions
    """
    slice_compare([[1, 2, 3, 4, 5]], [slice(0, 1), [4]], [[5]])
    slice_compare([[1, 2, 3, 4, 5]], [[0], slice(0, 2)], [[1, 2]])
    slice_compare([[1, 2, 3, 4], [5, 6, 7, 8]], [[1], slice(2, 4, 1)], [[7, 8]])


def test_slice_slice_obj_1s():
    """
    Test slice objects built with only a stop value
    """
    slice_compare([1, 2, 3, 4, 5], slice(1), [1])
    slice_compare([1, 2, 3, 4, 5], slice(4), [1, 2, 3, 4])
    slice_compare([[1, 2, 3, 4], [5, 6, 7, 8]], [slice(2), slice(2)], [[1, 2], [5, 6]])
    slice_compare([1, 2, 3, 4, 5], slice(10), [1, 2, 3, 4, 5])


def test_slice_slice_obj_2s():
    """
    Test slice objects built with start and stop values
    """
    slice_compare([1, 2, 3, 4, 5], slice(0, 2), [1, 2])
    slice_compare([1, 2, 3, 4, 5], slice(2, 4), [3, 4])
    slice_compare([[1, 2, 3, 4], [5, 6, 7, 8]], [slice(0, 2), slice(1, 2)], [[2], [6]])
    slice_compare([1, 2, 3, 4, 5], slice(4, 10), [5])


def test_slice_slice_obj_2s_multidim():
    """
    Test slice objects on 2D inputs, including fewer slices than dimensions
    """
    slice_compare([[1, 2, 3, 4, 5]], [slice(0, 1)], [[1, 2, 3, 4, 5]])
    slice_compare([[1, 2, 3, 4, 5]], [slice(0, 1), slice(4)], [[1, 2, 3, 4]])
    slice_compare([[1, 2, 3, 4, 5]], [slice(0, 1), slice(0, 3)], [[1, 2, 3]])
    slice_compare([[1, 2, 3, 4], [5, 6, 7, 8]], [slice(0, 2, 2), slice(2, 4, 1)], [[3, 4]])
    slice_compare([[1, 2, 3, 4], [5, 6, 7, 8]], [slice(1, 0, -1), slice(1)], [[5]])


def test_slice_slice_obj_3s():
    """
    Test passing in all parameters to the slice objects
    """
    slice_compare([1, 2, 3, 4, 5], slice(0, 2, 1), [1, 2])
    slice_compare([1, 2, 3, 4, 5], slice(0, 4, 1), [1, 2, 3, 4])
    slice_compare([1, 2, 3, 4, 5], slice(0, 10, 1), [1, 2, 3, 4, 5])
    slice_compare([1, 2, 3, 4, 5], slice(0, 5, 2), [1, 3, 5])
    slice_compare([1, 2, 3, 4, 5], slice(0, 2, 2), [1])
    slice_compare([1, 2, 3, 4, 5], slice(0, 1, 2), [1])
    slice_compare([1, 2, 3, 4, 5], slice(4, 5, 1), [5])
    slice_compare([1, 2, 3, 4, 5], slice(2, 5, 3), [3])
    slice_compare([[1, 2, 3, 4], [5, 6, 7, 8]], [slice(0, 2, 1)], [[1, 2, 3, 4], [5, 6, 7, 8]])
    slice_compare([[1, 2, 3, 4], [5, 6, 7, 8]], [slice(0, 2, 3)], [[1, 2, 3, 4]])
    slice_compare([[1, 2, 3, 4], [5, 6, 7, 8]], [slice(0, 2, 2), slice(0, 1, 2)], [[1]])
    slice_compare([[1, 2, 3, 4], [5, 6, 7, 8]], [slice(0, 2, 1), slice(0, 1, 2)], [[1], [5]])
    slice_compare([[[1, 2, 3, 4], [5, 6, 7, 8]], [[1, 2, 3, 4], [5, 6, 7, 8]]],
                  [slice(0, 2, 1), slice(0, 1, 1), slice(0, 4, 2)],
                  [[[1, 3]], [[1, 3]]])


def test_slice_obj_3s_double():
    """
    Test start/stop/step slice objects on float input
    """
    slice_compare([1., 2., 3., 4., 5.], slice(0, 2, 1), [1., 2.])
    slice_compare([1., 2., 3., 4., 5.], slice(0, 4, 1), [1., 2., 3., 4.])
    slice_compare([1., 2., 3., 4., 5.], slice(0, 5, 2), [1., 3., 5.])
    slice_compare([1., 2., 3., 4., 5.], slice(0, 2, 2), [1.])
    slice_compare([1., 2., 3., 4., 5.], slice(0, 1, 2), [1.])
    slice_compare([1., 2., 3., 4., 5.], slice(4, 5, 1), [5.])
    slice_compare([1., 2., 3., 4., 5.], slice(2, 5, 3), [3.])


def test_out_of_bounds_slicing():
    """
    Test passing indices outside of the input to the slice objects
    """
    slice_compare([1, 2, 3, 4, 5], slice(-15, -1), [1, 2, 3, 4])
    slice_compare([1, 2, 3, 4, 5], slice(-15, 15), [1, 2, 3, 4, 5])
    slice_compare([1, 2, 3, 4], slice(-15, -7), [])


def test_slice_multiple_rows():
    """
    Test passing in multiple rows
    """
    dataset = [[1], [3, 4, 5], [1, 2], [1, 2, 3, 4, 5, 6, 7]]
    exp_dataset = [[], [4, 5], [2], [2, 3, 4]]

    def gen():
        for row in dataset:
            yield (np.array(row),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    indexing = slice(1, 4)
    data = data.map(operations=ops.Slice(indexing))
    num_rows = 0
    for d in data.create_dict_iterator(output_numpy=True):
        # Index instead of zip() so missing or extra rows cannot pass silently.
        np.testing.assert_array_equal(exp_dataset[num_rows], d['col'])
        num_rows += 1
    assert num_rows == len(exp_dataset)


def test_slice_none_and_ellipsis():
    """
    Test passing None and Ellipsis to Slice
    """
    dataset = [[1], [3, 4, 5], [1, 2], [1, 2, 3, 4, 5, 6, 7]]
    # Both None and Ellipsis are expected to leave every row unchanged.
    exp_dataset = [[1], [3, 4, 5], [1, 2], [1, 2, 3, 4, 5, 6, 7]]

    def gen():
        for row in dataset:
            yield (np.array(row),)

    for indexing in (None, Ellipsis):
        data = ds.GeneratorDataset(gen, column_names=["col"])
        data = data.map(operations=ops.Slice(indexing))
        num_rows = 0
        for d in data.create_dict_iterator(output_numpy=True):
            # Index instead of zip() so missing or extra rows cannot pass silently.
            np.testing.assert_array_equal(exp_dataset[num_rows], d['col'])
            num_rows += 1
        assert num_rows == len(exp_dataset)


def test_slice_obj_neg():
    """
    Test slice objects with negative start/stop/step values
    """
    slice_compare([1, 2, 3, 4, 5], slice(-1, -5, -1), [5, 4, 3, 2])
    slice_compare([1, 2, 3, 4, 5], slice(-1), [1, 2, 3, 4])
    slice_compare([1, 2, 3, 4, 5], slice(-2), [1, 2, 3])
    slice_compare([1, 2, 3, 4, 5], slice(-1, -5, -2), [5, 3])
    slice_compare([1, 2, 3, 4, 5], slice(-5, -1, 2), [1, 3])
    slice_compare([1, 2, 3, 4, 5], slice(-5, -1), [1, 2, 3, 4])


def test_slice_all_str():
    """
    Test that None and Ellipsis return the whole bytes input
    """
    slice_compare([b"1", b"2", b"3", b"4", b"5"], None, [b"1", b"2", b"3", b"4", b"5"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], ..., [b"1", b"2", b"3", b"4", b"5"])


def test_slice_single_index_str():
    """
    Test single int indices and short index lists on bytes input
    """
    slice_compare([b"1", b"2", b"3", b"4", b"5"], [0, 1], [b"1", b"2"])
    # Bare int index, mirroring test_slice_single_index (this line previously
    # duplicated the [0, 1] case above and added no coverage).
    slice_compare([b"1", b"2", b"3", b"4", b"5"], 0, [b"1"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], [4], [b"5"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], [-1], [b"5"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], [-5], [b"1"])


def test_slice_indexes_multidim_str():
    """
    Test per-dimension index options on a 2D bytes input
    """
    slice_compare([[b"1", b"2", b"3", b"4", b"5"]], [[0], 0], [[b"1"]])
    slice_compare([[b"1", b"2", b"3", b"4", b"5"]], [[0], [0, 1]], [[b"1", b"2"]])


def test_slice_list_index_str():
    """
    Test index lists, including unordered and repeated indices, on bytes input
    """
    slice_compare([b"1", b"2", b"3", b"4", b"5"], [0, 1, 4], [b"1", b"2", b"5"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], [4, 1, 0], [b"5", b"2", b"1"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], [3, 3, 3], [b"4", b"4", b"4"])


def test_slice_index_and_slice_str():
    """
    Test mixing int/list indices and slice objects across dimensions on bytes input
    """
    slice_compare([[b"1", b"2", b"3", b"4", b"5"]], [slice(0, 1), 4], [[b"5"]])
    slice_compare([[b"1", b"2", b"3", b"4", b"5"]], [[0], slice(0, 2)], [[b"1", b"2"]])
    slice_compare([[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]], [[1], slice(2, 4, 1)],
                  [[b"7", b"8"]])


def test_slice_slice_obj_1s_str():
    """
    Test slice objects built with only a stop value on bytes input
    """
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(1), [b"1"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(4), [b"1", b"2", b"3", b"4"])
    slice_compare([[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]],
                  [slice(2), slice(2)],
                  [[b"1", b"2"], [b"5", b"6"]])


def test_slice_slice_obj_2s_str():
    """
    Test slice objects built with start and stop values on bytes input
    """
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 2), [b"1", b"2"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(2, 4), [b"3", b"4"])
    slice_compare([[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]],
                  [slice(0, 2), slice(1, 2)], [[b"2"], [b"6"]])


def test_slice_slice_obj_2s_multidim_str():
    """
    Test slice objects on 2D bytes inputs, including fewer slices than dimensions
    """
    slice_compare([[b"1", b"2", b"3", b"4", b"5"]], [slice(0, 1)], [[b"1", b"2", b"3", b"4", b"5"]])
    slice_compare([[b"1", b"2", b"3", b"4", b"5"]], [slice(0, 1), slice(4)],
                  [[b"1", b"2", b"3", b"4"]])
    slice_compare([[b"1", b"2", b"3", b"4", b"5"]], [slice(0, 1), slice(0, 3)],
                  [[b"1", b"2", b"3"]])
    slice_compare([[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]],
                  [slice(0, 2, 2), slice(2, 4, 1)],
                  [[b"3", b"4"]])


def test_slice_slice_obj_3s_str():
    """
    Test passing in all parameters to the slice objects
    """
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 2, 1), [b"1", b"2"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 4, 1), [b"1", b"2", b"3", b"4"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 5, 2), [b"1", b"3", b"5"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 2, 2), [b"1"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 1, 2), [b"1"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(4, 5, 1), [b"5"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(2, 5, 3), [b"3"])
    slice_compare([[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]], [slice(0, 2, 1)],
                  [[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]])
    slice_compare([[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]], slice(0, 2, 3), [[b"1", b"2", b"3", b"4"]])
    slice_compare([[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]],
                  [slice(0, 2, 2), slice(0, 1, 2)], [[b"1"]])
    slice_compare([[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]],
                  [slice(0, 2, 1), slice(0, 1, 2)],
                  [[b"1"], [b"5"]])
    slice_compare([[[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]],
                   [[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]]],
                  [slice(0, 2, 1), slice(0, 1, 1), slice(0, 4, 2)],
                  [[[b"1", b"3"]], [[b"1", b"3"]]])


def test_slice_obj_neg_str():
    """
    Test slice objects with negative start/stop/step values on bytes input
    """
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-1, -5, -1), [b"5", b"4", b"3", b"2"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-1), [b"1", b"2", b"3", b"4"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-2), [b"1", b"2", b"3"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-1, -5, -2), [b"5", b"3"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-5, -1, 2), [b"1", b"3"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-5, -1), [b"1", b"2", b"3", b"4"])


def test_out_of_bounds_slicing_str():
    """
    Test passing indices outside of the input to the slice objects
    """
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-15, -1), [b"1", b"2", b"3", b"4"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-15, 15), [b"1", b"2", b"3", b"4", b"5"])
    # The empty expectation is an explicitly bytes-typed array; a bare [] would
    # become a float64 array rather than a bytes one.
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-15, -7), np.array([], dtype="S"))


def test_slice_exceptions():
    """
    Test passing in invalid parameters
    """
    with pytest.raises(RuntimeError) as info:
        slice_compare([b"1", b"2", b"3", b"4", b"5"], [5], [b"1", b"2", b"3", b"4", b"5"])
    assert "Index 5 is out of bounds." in str(info.value)

    with pytest.raises(RuntimeError) as info:
        slice_compare([b"1", b"2", b"3", b"4", b"5"], [], [b"1", b"2", b"3", b"4", b"5"])
    assert "Both indices and slices can not be empty." in str(info.value)

    with pytest.raises(TypeError) as info:
        slice_compare([b"1", b"2", b"3", b"4", b"5"], [[[0, 1]]], [b"1", b"2", b"3", b"4", b"5"])
    assert "Argument slice_option[0] with value [0, 1] is not of type " \
           "[<class 'int'>]" in str(info.value)

    with pytest.raises(TypeError) as info:
        slice_compare([b"1", b"2", b"3", b"4", b"5"], [[slice(3)]], [b"1", b"2", b"3", b"4", b"5"])
    assert "Argument slice_option[0] with value slice(None, 3, None) is not of type " \
           "[<class 'int'>]" in str(info.value)


if __name__ == "__main__":
    # Keep this list in sync with every test_* function defined in this module.
    test_slice_all()
    test_slice_single_index()
    test_slice_indices_multidim()
    test_slice_list_index()
    test_slice_index_and_slice()
    test_slice_slice_obj_1s()
    test_slice_slice_obj_2s()
    test_slice_slice_obj_2s_multidim()
    test_slice_slice_obj_3s()
    test_slice_obj_3s_double()
    test_out_of_bounds_slicing()
    test_slice_multiple_rows()
    test_slice_none_and_ellipsis()
    test_slice_obj_neg()
    test_slice_all_str()
    test_slice_single_index_str()
    test_slice_indexes_multidim_str()
    test_slice_list_index_str()
    test_slice_index_and_slice_str()
    test_slice_slice_obj_1s_str()
    test_slice_slice_obj_2s_str()
    test_slice_slice_obj_2s_multidim_str()
    test_slice_slice_obj_3s_str()
    test_slice_obj_neg_str()
    test_out_of_bounds_slicing_str()
    test_slice_exceptions()