• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import numpy as np
16import pytest
17
18import mindspore.dataset as ds
19import mindspore.dataset.vision.c_transforms as vision
20
21
# TFRecord data file(s) and the matching schema file read by test_tf_skip,
# which expects exactly one row to remain after skipping two.
DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
24
25
def test_tf_skip():
    """
    Skip two rows of a decoded and resized TFRecord dataset and count the rest.
    """
    pipeline = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)

    target_size = (32, 32)
    decoder = vision.Decode()
    resizer = vision.Resize(target_size, interpolation=ds.transforms.vision.Inter.LINEAR)
    pipeline = (
        pipeline
        .map(operations=decoder, input_columns=["image"])
        .map(operations=resizer, input_columns=["image"])
        .skip(2)
    )

    # Exactly one row is expected to be left once two have been skipped.
    remaining = sum(1 for _ in pipeline.create_dict_iterator(num_epochs=1))
    assert remaining == 1
43
44
def generator_md():
    """
    Yield five one-element tuples holding the arrays [0], [1], [2], [3], [4].
    """
    yield from ((np.array([value]),) for value in range(5))
51
52
def test_generator_skip():
    """
    Skip three rows of a generator dataset built with num_parallel_workers=4.
    """
    # The source yields [0, 1, 2, 3, 4]; dropping the first three leaves [3, 4].
    skipped = ds.GeneratorDataset(generator_md, ["data"], num_parallel_workers=4).skip(3)

    values = [row[0][0] for row in skipped.create_tuple_iterator(output_numpy=True)]
    assert len(values) == 2
    assert values == [3, 4]
64
65
def test_skip_1():
    """
    Skip more rows than the generator dataset yields and expect no output.
    """
    # Seven rows are skipped while only five exist, so nothing should remain.
    skipped = ds.GeneratorDataset(generator_md, ["data"]).skip(7)

    values = [row[0][0] for row in skipped.create_tuple_iterator(output_numpy=True)]
    assert values == []
76
77
def test_skip_2():
    """
    Skip zero rows and expect the generator dataset to come back unchanged.
    """
    # skip(0) should leave all of [0, 1, 2, 3, 4] in place.
    skipped = ds.GeneratorDataset(generator_md, ["data"]).skip(0)

    values = [row[0][0] for row in skipped.create_tuple_iterator(output_numpy=True)]
    assert len(values) == 5
    assert values == [0, 1, 2, 3, 4]
89
90
def test_skip_repeat_1():
    """
    Apply repeat(2) and then skip(3) to the generator dataset.
    """
    # repeat(2) gives [0, 1, 2, 3, 4, 0, 1, 2, 3, 4];
    # skip(3) then leaves [3, 4, 0, 1, 2, 3, 4].
    pipeline = ds.GeneratorDataset(generator_md, ["data"]).repeat(2).skip(3)

    values = [row[0][0] for row in pipeline.create_tuple_iterator(output_numpy=True)]
    assert len(values) == 7
    assert values == [3, 4, 0, 1, 2, 3, 4]
105
106
def test_skip_repeat_2():
    """
    Apply skip(3) and then repeat(2) to the generator dataset.
    """
    # skip(3) gives [3, 4]; repeat(2) then yields [3, 4, 3, 4].
    pipeline = ds.GeneratorDataset(generator_md, ["data"]).skip(3).repeat(2)

    values = [row[0][0] for row in pipeline.create_tuple_iterator(output_numpy=True)]
    assert len(values) == 4
    assert values == [3, 4, 3, 4]
121
122
def test_skip_repeat_3():
    """
    Apply repeat(2), skip(8) and repeat(3) in that order to the generator dataset.
    """
    source = ds.GeneratorDataset(generator_md, ["data"])

    # repeat(2) -> [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    # skip(8)   -> [3, 4]
    # repeat(3) -> [3, 4, 3, 4, 3, 4]
    pipeline = source.repeat(2).skip(8).repeat(3)

    values = [row[0][0] for row in pipeline.create_tuple_iterator(output_numpy=True)]
    assert len(values) == 6
    assert values == [3, 4, 3, 4, 3, 4]
140
141
def test_skip_take_1():
    """
    Apply take(4) and then skip(2) to the generator dataset.
    """
    # take(4) gives [0, 1, 2, 3]; skip(2) then leaves [2, 3].
    pipeline = ds.GeneratorDataset(generator_md, ["data"]).take(4).skip(2)

    values = [row[0][0] for row in pipeline.create_tuple_iterator(output_numpy=True)]
    assert len(values) == 2
    assert values == [2, 3]
156
157
def test_skip_take_2():
    """
    Apply skip(2) and then take(2) to the generator dataset.
    """
    # skip(2) gives [2, 3, 4]; take(2) then leaves [2, 3].
    pipeline = ds.GeneratorDataset(generator_md, ["data"]).skip(2).take(2)

    values = [row[0][0] for row in pipeline.create_tuple_iterator(output_numpy=True)]
    assert len(values) == 2
    assert values == [2, 3]
172
173
def generator_1d():
    """
    Yield 64 one-element tuples holding the arrays [0] through [63] in order.
    """
    yield from ((np.array([value]),) for value in range(64))
177
178
def test_skip_filter_1():
    """
    Apply skip(5) and then a `data < 11` filter to the 64-row generator dataset.
    """
    pipeline = ds.GeneratorDataset(generator_1d, ['data']).skip(5)
    pipeline = pipeline.filter(predicate=lambda data: data < 11, num_parallel_workers=4)

    # Rows 0-4 are skipped first, then only values below 11 are kept.
    values = [row[0][0] for row in pipeline.create_tuple_iterator(output_numpy=True)]
    assert values == [5, 6, 7, 8, 9, 10]
188
189
def test_skip_filter_2():
    """
    Apply a `data < 11` filter and then skip(5) to the 64-row generator dataset.
    """
    pipeline = ds.GeneratorDataset(generator_1d, ['data'])
    pipeline = pipeline.filter(predicate=lambda data: data < 11, num_parallel_workers=4).skip(5)

    # The filter keeps 0-10; skipping the first five of those leaves 5-10.
    values = [row[0][0] for row in pipeline.create_tuple_iterator(output_numpy=True)]
    assert values == [5, 6, 7, 8, 9, 10]
199
200
def test_skip_exception_1():
    """
    Skipping a negative count passed by keyword must raise ValueError.
    """
    data1 = ds.GeneratorDataset(generator_md, ["data"])

    # pytest.raises makes the test fail when no ValueError is raised at all;
    # a plain try/except would pass silently in that case.
    with pytest.raises(ValueError) as e:
        data1 = data1.skip(count=-1)
        # Iteration stays inside the block so that a ValueError raised only
        # while the pipeline is consumed is accepted as well.
        for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
    assert "Input count is not within the required interval" in str(e.value)
212
213
def test_skip_exception_2():
    """
    Skipping a negative count passed positionally must raise ValueError.
    """
    source = ds.GeneratorDataset(generator_md, ["data"])

    with pytest.raises(ValueError) as raised:
        source.skip(-2)

    message = str(raised.value)
    assert "Input count is not within the required interval" in message
220
221
222
if __name__ == "__main__":
    # Run every test of this file, in definition order, when executed as a script.
    for _test_case in (
            test_tf_skip,
            test_generator_skip,
            test_skip_1,
            test_skip_2,
            test_skip_repeat_1,
            test_skip_repeat_2,
            test_skip_repeat_3,
            test_skip_take_1,
            test_skip_take_2,
            test_skip_filter_1,
            test_skip_filter_2,
            test_skip_exception_1,
            test_skip_exception_2,
    ):
        _test_case()
237