• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Testing Slice op in DE
17"""
18import numpy as np
19import pytest
20
21import mindspore.dataset as ds
22import mindspore.dataset.transforms.c_transforms as ops
23
24
def slice_compare(array, indexing, expected_array):
    """
    Feed `array` through a Slice op as a single dataset row and check the output

    Args:
        array: Row contents. It is wrapped in a one-element list, so
            NumpySlicesDataset (which slices along the first dimension) yields
            exactly one row, read back from "column_0".
        indexing: Slice specification. A non-empty list whose first element is
            not an int (e.g. [slice(0, 1), [4]] or [[0], 0]) is unpacked into one
            ops.Slice argument per dimension. Anything else (None, Ellipsis,
            bool, int, slice object, list of ints, empty list) is passed to
            ops.Slice as a single argument.
        expected_array: Expected contents of the sliced row.

    Raises:
        AssertionError: If the sliced row differs from `expected_array`, or if
            the pipeline does not produce exactly one row.
    """
    data = ds.NumpySlicesDataset([array])
    if isinstance(indexing, list) and indexing and not isinstance(indexing[0], int):
        data = data.map(operations=ops.Slice(*indexing))
    else:
        data = data.map(operations=ops.Slice(indexing))

    num_rows = 0
    for d in data.create_dict_iterator(output_numpy=True):
        np.testing.assert_array_equal(expected_array, d['column_0'])
        num_rows += 1
    # Without this check, a pipeline that yields no rows would make every
    # comparison above pass vacuously.
    assert num_rows == 1, f"expected exactly one row, got {num_rows}"
33
34
def test_slice_all():
    """
    Test that None, Ellipsis and True each select the whole 1-D input
    """
    for select_all in (None, ..., True):
        slice_compare([1, 2, 3, 4, 5], select_all, [1, 2, 3, 4, 5])
39
40
def test_slice_single_index():
    """
    Test selecting one element with a bare int (positive or negative) or a one-item list
    """
    for index, expected in ((0, [1]), (-3, [3]), ([0], [1])):
        slice_compare([1, 2, 3, 4, 5], index, expected)
45
46
def test_slice_indices_multidim():
    """
    Test per-dimension index lists on a 2-D input
    """
    # Expect [[1]] rather than the scalar 1: assert_array_equal broadcasts a
    # scalar against any shape, so a scalar could not detect a wrong output shape.
    slice_compare([[1, 2, 3, 4, 5]], [[0], [0]], [[1]])
    slice_compare([[1, 2, 3, 4, 5]], [[0], [0, 3]], [[1, 4]])
    slice_compare([[1, 2, 3, 4, 5]], [0], [[1, 2, 3, 4, 5]])
    slice_compare([[1, 2, 3, 4, 5]], [[0], [0, -4]], [[1, 2]])
52
53
def test_slice_list_index():
    """
    Test gathering with a list of indices: arbitrary order, negative and repeated indices
    """
    gathers = [
        ([0, 1, 4], [1, 2, 5]),
        ([4, 1, 0], [5, 2, 1]),
        ([-1, 1, 0], [5, 2, 1]),
        ([-1, -4, -2], [5, 2, 4]),
        ([3, 3, 3], [4, 4, 4]),
    ]
    for indices, expected in gathers:
        slice_compare([1, 2, 3, 4, 5], indices, expected)
60
61
def test_slice_index_and_slice():
    """
    Test combining an index list on one dimension with a slice object on another
    """
    cases = [
        ([[1, 2, 3, 4, 5]], [slice(0, 1), [4]], [[5]]),
        ([[1, 2, 3, 4, 5]], [[0], slice(0, 2)], [[1, 2]]),
        ([[1, 2, 3, 4], [5, 6, 7, 8]], [[1], slice(2, 4, 1)], [[7, 8]]),
    ]
    for array, indexing, expected in cases:
        slice_compare(array, indexing, expected)
66
67
def test_slice_slice_obj_1s():
    """
    Test slice objects built from a stop value only, including a stop past the end
    """
    cases = [
        ([1, 2, 3, 4, 5], slice(1), [1]),
        ([1, 2, 3, 4, 5], slice(4), [1, 2, 3, 4]),
        ([[1, 2, 3, 4], [5, 6, 7, 8]], [slice(2), slice(2)], [[1, 2], [5, 6]]),
        ([1, 2, 3, 4, 5], slice(10), [1, 2, 3, 4, 5]),
    ]
    for case in cases:
        slice_compare(*case)
73
74
def test_slice_slice_obj_2s():
    """
    Test slice objects given a start and a stop, including a stop past the end
    """
    cases = [
        ([1, 2, 3, 4, 5], slice(0, 2), [1, 2]),
        ([1, 2, 3, 4, 5], slice(2, 4), [3, 4]),
        ([[1, 2, 3, 4], [5, 6, 7, 8]], [slice(0, 2), slice(1, 2)], [[2], [6]]),
        ([1, 2, 3, 4, 5], slice(4, 10), [5]),
    ]
    for array, indexing, expected in cases:
        slice_compare(array, indexing, expected)
80
81
def test_slice_slice_obj_2s_multidim():
    """
    Test start/stop (and stepped) slice objects applied per dimension of a 2-D input
    """
    single_row_cases = [
        ([slice(0, 1)], [[1, 2, 3, 4, 5]]),
        ([slice(0, 1), slice(4)], [[1, 2, 3, 4]]),
        ([slice(0, 1), slice(0, 3)], [[1, 2, 3]]),
    ]
    for indexing, expected in single_row_cases:
        slice_compare([[1, 2, 3, 4, 5]], indexing, expected)

    two_row_cases = [
        ([slice(0, 2, 2), slice(2, 4, 1)], [[3, 4]]),
        # A negative step walks backwards: from row 1 down to (excluding) row 0.
        ([slice(1, 0, -1), slice(1)], [[5]]),
    ]
    for indexing, expected in two_row_cases:
        slice_compare([[1, 2, 3, 4], [5, 6, 7, 8]], indexing, expected)
88
89
def test_slice_slice_obj_3s():
    """
    Test passing start, stop and step to each slice object
    """
    # 1-D input: (start, stop, step) -> expected result.
    flat_cases = [
        ((0, 2, 1), [1, 2]),
        ((0, 4, 1), [1, 2, 3, 4]),
        ((0, 10, 1), [1, 2, 3, 4, 5]),
        ((0, 5, 2), [1, 3, 5]),
        ((0, 2, 2), [1]),
        ((0, 1, 2), [1]),
        ((4, 5, 1), [5]),
        ((2, 5, 3), [3]),
    ]
    for bounds, expected in flat_cases:
        slice_compare([1, 2, 3, 4, 5], slice(*bounds), expected)

    # 2-D input: one or two slice objects, applied to the leading dimensions.
    matrix_cases = [
        ([slice(0, 2, 1)], [[1, 2, 3, 4], [5, 6, 7, 8]]),
        ([slice(0, 2, 3)], [[1, 2, 3, 4]]),
        ([slice(0, 2, 2), slice(0, 1, 2)], [[1]]),
        ([slice(0, 2, 1), slice(0, 1, 2)], [[1], [5]]),
    ]
    for indexing, expected in matrix_cases:
        slice_compare([[1, 2, 3, 4], [5, 6, 7, 8]], indexing, expected)

    # 3-D input: a slice object for every dimension.
    slice_compare([[[1, 2, 3, 4], [5, 6, 7, 8]], [[1, 2, 3, 4], [5, 6, 7, 8]]],
                  [slice(0, 2, 1), slice(0, 1, 1), slice(0, 4, 2)],
                  [[[1, 3]], [[1, 3]]])
109
110
def test_slice_obj_3s_double():
    """
    Test three-argument slice objects on floating-point input
    """
    float_cases = [
        (slice(0, 2, 1), [1., 2.]),
        (slice(0, 4, 1), [1., 2., 3., 4.]),
        (slice(0, 5, 2), [1., 3., 5.]),
        (slice(0, 2, 2), [1.]),
        (slice(0, 1, 2), [1.]),
        (slice(4, 5, 1), [5.]),
        (slice(2, 5, 3), [3.]),
    ]
    for indexing, expected in float_cases:
        slice_compare([1., 2., 3., 4., 5.], indexing, expected)
119
120
def test_out_of_bounds_slicing():
    """
    Test slice bounds outside the input: they are clipped to the valid range
    """
    clipped_cases = [
        ([1, 2, 3, 4, 5], slice(-15, -1), [1, 2, 3, 4]),
        ([1, 2, 3, 4, 5], slice(-15, 15), [1, 2, 3, 4, 5]),
        # Both bounds fall before the first element, so nothing is selected.
        ([1, 2, 3, 4], slice(-15, -7), []),
    ]
    for array, indexing, expected in clipped_cases:
        slice_compare(array, indexing, expected)
128
129
def test_slice_multiple_rows():
    """
    Test passing multiple rows of different lengths through one Slice op
    """
    dataset = [[1], [3, 4, 5], [1, 2], [1, 2, 3, 4, 5, 6, 7]]
    # slice(1, 4) keeps positions 1..3 of each row; shorter rows are clipped,
    # so the one-element row becomes empty.
    exp_dataset = [[], [4, 5], [2], [2, 3, 4]]

    def gen():
        for row in dataset:
            yield (np.array(row),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    indexing = slice(1, 4)
    data = data.map(operations=ops.Slice(indexing))

    # Count rows instead of zipping with exp_dataset: zip() stops at the
    # shorter input, so missing or extra rows would otherwise go unnoticed.
    num_rows = 0
    for d in data.create_dict_iterator(output_numpy=True):
        assert num_rows < len(exp_dataset), "pipeline produced more rows than expected"
        np.testing.assert_array_equal(exp_dataset[num_rows], d['col'])
        num_rows += 1
    assert num_rows == len(exp_dataset), \
        f"expected {len(exp_dataset)} rows, got {num_rows}"
146
147
def test_slice_none_and_ellipsis():
    """
    Test passing None and Ellipsis to Slice: both keep every element of each row
    """
    dataset = [[1], [3, 4, 5], [1, 2], [1, 2, 3, 4, 5, 6, 7]]
    exp_dataset = [[1], [3, 4, 5], [1, 2], [1, 2, 3, 4, 5, 6, 7]]

    def gen():
        for row in dataset:
            yield (np.array(row),)

    for indexing in (None, Ellipsis):
        data = ds.GeneratorDataset(gen, column_names=["col"])
        data = data.map(operations=ops.Slice(indexing))
        # Count rows rather than zip() with exp_dataset, which would silently
        # ignore missing or extra rows.
        num_rows = 0
        for d in data.create_dict_iterator(output_numpy=True):
            assert num_rows < len(exp_dataset), \
                f"Slice({indexing!r}) produced more rows than expected"
            np.testing.assert_array_equal(exp_dataset[num_rows], d['col'])
            num_rows += 1
        assert num_rows == len(exp_dataset), \
            f"Slice({indexing!r}): expected {len(exp_dataset)} rows, got {num_rows}"
168
169
def test_slice_obj_neg():
    """
    Test slice objects with negative start, stop and/or step values
    """
    negative_cases = [
        (slice(-1, -5, -1), [5, 4, 3, 2]),
        (slice(-1), [1, 2, 3, 4]),
        (slice(-2), [1, 2, 3]),
        (slice(-1, -5, -2), [5, 3]),
        (slice(-5, -1, 2), [1, 3]),
        (slice(-5, -1), [1, 2, 3, 4]),
    ]
    for indexing, expected in negative_cases:
        slice_compare([1, 2, 3, 4, 5], indexing, expected)
177
178
def test_slice_all_str():
    """
    Test that None and Ellipsis select the whole bytes input
    """
    for select_all in (None, ...):
        slice_compare([b"1", b"2", b"3", b"4", b"5"], select_all, [b"1", b"2", b"3", b"4", b"5"])
182
183
def test_slice_single_index_str():
    """
    Test selecting elements of a bytes input by bare int and by index list
    """
    # Bare int indices, mirroring test_slice_single_index. These replace a
    # duplicated copy of the [0, 1] case below.
    slice_compare([b"1", b"2", b"3", b"4", b"5"], 0, [b"1"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], -3, [b"3"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], [0, 1], [b"1", b"2"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], [4], [b"5"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], [-1], [b"5"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], [-5], [b"1"])
190
191
def test_slice_indexes_multidim_str():
    """
    Test per-dimension indices on a 2-D bytes input
    """
    index_cases = [
        ([[0], 0], [[b"1"]]),
        ([[0], [0, 1]], [[b"1", b"2"]]),
    ]
    for indexing, expected in index_cases:
        slice_compare([[b"1", b"2", b"3", b"4", b"5"]], indexing, expected)
195
196
def test_slice_list_index_str():
    """
    Test gathering bytes elements with index lists, including reversed and repeated indices
    """
    gathers = [
        ([0, 1, 4], [b"1", b"2", b"5"]),
        ([4, 1, 0], [b"5", b"2", b"1"]),
        ([3, 3, 3], [b"4", b"4", b"4"]),
    ]
    for indices, expected in gathers:
        slice_compare([b"1", b"2", b"3", b"4", b"5"], indices, expected)
201
202
def test_slice_index_and_slice_str():
    """
    Test mixing index and slice objects across the dimensions of a bytes input
    """
    cases = [
        ([[b"1", b"2", b"3", b"4", b"5"]], [slice(0, 1), 4], [[b"5"]]),
        ([[b"1", b"2", b"3", b"4", b"5"]], [[0], slice(0, 2)], [[b"1", b"2"]]),
        ([[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]], [[1], slice(2, 4, 1)],
         [[b"7", b"8"]]),
    ]
    for array, indexing, expected in cases:
        slice_compare(array, indexing, expected)
209
210
def test_slice_slice_obj_1s_str():
    """
    Test stop-only slice objects on bytes input
    """
    for stop, expected in ((1, [b"1"]), (4, [b"1", b"2", b"3", b"4"])):
        slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(stop), expected)
    slice_compare([[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]],
                  [slice(2), slice(2)],
                  [[b"1", b"2"], [b"5", b"6"]])
217
218
def test_slice_slice_obj_2s_str():
    """
    Test start/stop slice objects on bytes input
    """
    for start, stop, expected in ((0, 2, [b"1", b"2"]), (2, 4, [b"3", b"4"])):
        slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(start, stop), expected)
    slice_compare([[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]],
                  [slice(0, 2), slice(1, 2)],
                  [[b"2"], [b"6"]])
224
225
def test_slice_slice_obj_2s_multidim_str():
    """
    Test per-dimension slice objects on 2-D bytes input
    """
    one_row_cases = [
        ([slice(0, 1)], [[b"1", b"2", b"3", b"4", b"5"]]),
        ([slice(0, 1), slice(4)], [[b"1", b"2", b"3", b"4"]]),
        ([slice(0, 1), slice(0, 3)], [[b"1", b"2", b"3"]]),
    ]
    for indexing, expected in one_row_cases:
        slice_compare([[b"1", b"2", b"3", b"4", b"5"]], indexing, expected)

    slice_compare([[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]],
                  [slice(0, 2, 2), slice(2, 4, 1)],
                  [[b"3", b"4"]])
235
236
def test_slice_slice_obj_3s_str():
    """
    Test passing start, stop and step to each slice object on bytes input
    """
    # 1-D input: (start, stop, step) -> expected result.
    flat_cases = [
        ((0, 2, 1), [b"1", b"2"]),
        ((0, 4, 1), [b"1", b"2", b"3", b"4"]),
        ((0, 5, 2), [b"1", b"3", b"5"]),
        ((0, 2, 2), [b"1"]),
        ((0, 1, 2), [b"1"]),
        ((4, 5, 1), [b"5"]),
        ((2, 5, 3), [b"3"]),
    ]
    for bounds, expected in flat_cases:
        slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(*bounds), expected)

    # 2-D input.
    matrix_cases = [
        ([slice(0, 2, 1)], [[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]]),
        # A bare slice object (not wrapped in a list) is also accepted.
        (slice(0, 2, 3), [[b"1", b"2", b"3", b"4"]]),
        ([slice(0, 2, 2), slice(0, 1, 2)], [[b"1"]]),
        ([slice(0, 2, 1), slice(0, 1, 2)], [[b"1"], [b"5"]]),
    ]
    for indexing, expected in matrix_cases:
        slice_compare([[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]],
                      indexing, expected)

    # 3-D input: a slice object for every dimension.
    slice_compare([[[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]],
                   [[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]]],
                  [slice(0, 2, 1), slice(0, 1, 1), slice(0, 4, 2)],
                  [[[b"1", b"3"]], [[b"1", b"3"]]])
260
261
def test_slice_obj_neg_str():
    """
    Test slice objects with negative start, stop and/or step values on bytes input
    """
    negative_cases = [
        (slice(-1, -5, -1), [b"5", b"4", b"3", b"2"]),
        (slice(-1), [b"1", b"2", b"3", b"4"]),
        (slice(-2), [b"1", b"2", b"3"]),
        (slice(-1, -5, -2), [b"5", b"3"]),
        (slice(-5, -1, 2), [b"1", b"3"]),
        (slice(-5, -1), [b"1", b"2", b"3", b"4"]),
    ]
    for indexing, expected in negative_cases:
        slice_compare([b"1", b"2", b"3", b"4", b"5"], indexing, expected)
269
270
def test_out_of_bounds_slicing_str():
    """
    Test passing indices outside of the input to the slice objects
    """
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-15, -1), [b"1", b"2", b"3", b"4"])
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-15, 15), [b"1", b"2", b"3", b"4", b"5"])
    # Both bounds fall before the first element, so nothing is selected. The
    # empty expectation gets an explicit bytes dtype to match the string tensor
    # (np.array([]) alone would be float64).
    slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-15, -7), np.array([], dtype="S"))
285
286
def test_slice_exceptions():
    """
    Test that invalid Slice arguments raise errors with descriptive messages
    """
    invalid_cases = [
        # Index 5 is past the end of a 5-element input.
        (RuntimeError, [5], "Index 5 is out of bounds."),
        # An empty list gives Slice nothing to select.
        (RuntimeError, [], "Both indices and slices can not be empty."),
        # A list holding a list or a slice object fails the int type check.
        (TypeError, [[[0, 1]]],
         "Argument slice_option[0] with value [0, 1] is not of type "
         "[<class 'int'>]"),
        (TypeError, [[slice(3)]],
         "Argument slice_option[0] with value slice(None, 3, None) is not of type "
         "[<class 'int'>]"),
    ]
    for error_type, indexing, message in invalid_cases:
        with pytest.raises(error_type) as info:
            slice_compare([b"1", b"2", b"3", b"4", b"5"], indexing, [b"1", b"2", b"3", b"4", b"5"])
        assert message in str(info.value)
308
309
if __name__ == "__main__":
    # Run every test in definition order when the file is executed as a script.
    # test_out_of_bounds_slicing and test_slice_none_and_ellipsis were
    # previously missing from this list.
    test_slice_all()
    test_slice_single_index()
    test_slice_indices_multidim()
    test_slice_list_index()
    test_slice_index_and_slice()
    test_slice_slice_obj_1s()
    test_slice_slice_obj_2s()
    test_slice_slice_obj_2s_multidim()
    test_slice_slice_obj_3s()
    test_slice_obj_3s_double()
    test_out_of_bounds_slicing()
    test_slice_multiple_rows()
    test_slice_none_and_ellipsis()
    test_slice_obj_neg()
    test_slice_all_str()
    test_slice_single_index_str()
    test_slice_indexes_multidim_str()
    test_slice_list_index_str()
    test_slice_index_and_slice_str()
    test_slice_slice_obj_1s_str()
    test_slice_slice_obj_2s_str()
    test_slice_slice_obj_2s_multidim_str()
    test_slice_slice_obj_3s_str()
    test_slice_obj_neg_str()
    test_out_of_bounds_slicing_str()
    test_slice_exceptions()
335