• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import numpy as np
16import pytest
17import mindspore.dataset as ds
18from mindspore import log as logger
19
20
def generator():
    """Yield 3 rows whose values are 0, 1 and 2.

    Each row is a 1-tuple holding a 1-element numpy array, i.e. one value
    for the single "data" column the tests below declare.
    """
    yield from ((np.array([value]),) for value in range(3))
25
26
def generator_10():
    """Yield 10 rows whose values are 0, 1, 2 ... 9.

    Each row is a 1-tuple holding a 1-element numpy array, i.e. one value
    for the single "data" column the tests below declare.
    """
    for i in range(10):
        yield (np.array([i]),)
31
32
def filter_func_ge(data):
    """Filter predicate that keeps a row only when its value is not above 3.

    Note the name: despite the "_ge" suffix, rows with values <= 3 are kept
    (True) and rows with values > 3 are dropped (False).

    Args:
        data: the row's "data" value (presumably a 1-element numpy array when
            called by the dataset filter -- confirm against MindSpore docs).

    Returns:
        True to keep the row, False to drop it.
    """
    return not (data > 3)
37
38
def test_take_01():
    """
    Test take: the source has 3 rows and take(1) keeps only the first one;
    in this case take will not meet eoe or eof.
    """
    logger.info("test_take_01")
    data1 = ds.GeneratorDataset(generator, ["data"])

    data1 = data1.take(1)
    data1 = data1.repeat(2)

    # Gather every row in a single pass so the value check and the row count
    # below describe the same run of the pipeline. Each of the 2 repeats
    # yields only the first source row, whose value is 0.
    rows = [d[0].asnumpy().flatten().tolist() for d in data1]
    assert rows == [[0], [0]]
54
55
def test_take_02():
    """
    Test take: the source has 3 rows and take(2) keeps 2 of them; in this
    case take will meet eoe.
    """
    logger.info("test_take_02")
    dataset = ds.GeneratorDataset(generator, ["data"]).take(2).repeat(2)

    # Every repeat replays 0, 1, so row `idx` must hold idx % 2.
    for idx, row in enumerate(dataset):
        assert idx % 2 == row[0].asnumpy()[0]

    assert sum(1 for _ in dataset) == 4
71
72
def test_take_03():
    """
    Test take: the source has 3 rows and take(3) keeps all of them; in this
    case take will meet both eoe and eof.
    """
    logger.info("test_take_03")
    dataset = ds.GeneratorDataset(generator, ["data"]).take(3).repeat(2)

    # Every repeat replays 0, 1, 2, so row `idx` must hold idx % 3.
    for idx, row in enumerate(dataset):
        assert idx % 3 == row[0].asnumpy()[0]

    assert sum(1 for _ in dataset) == 6
88
89
def test_take_04():
    """
    Test take: the source has 3 rows and take(4) asks for more rows than
    exist, so every source row is still kept.
    """
    logger.info("test_take_04")
    dataset = ds.GeneratorDataset(generator, ["data"]).take(4).repeat(2)

    # Each repeat still yields only 0, 1, 2, so row `idx` must hold idx % 3.
    for idx, row in enumerate(dataset):
        assert idx % 3 == row[0].asnumpy()[0]

    assert sum(1 for _ in dataset) == 6
105
106
def test_take_05():
    """
    Test take on its own, with no repeat op: take(2) yields rows 0 and 1.
    """
    logger.info("test_take_05")
    dataset = ds.GeneratorDataset(generator, ["data"]).take(2)

    # Without a repeat, the row index and the row value coincide.
    for idx, row in enumerate(dataset):
        assert idx == row[0].asnumpy()[0]

    assert sum(1 for _ in dataset) == 2
121
122
def test_take_06():
    """
    Test take: repeat runs before take, so take(4) counts rows across the
    repeated epochs and yields 0, 1, 2, 0.
    """
    logger.info("test_take_06")
    dataset = ds.GeneratorDataset(generator, ["data"]).repeat(2).take(4)

    # Values cycle through the 3 source rows, so row `idx` holds idx % 3.
    for idx, row in enumerate(dataset):
        assert idx % 3 == row[0].asnumpy()[0]

    assert sum(1 for _ in dataset) == 4
138
139
def test_take_07():
    """
    Test take: take is before batch, so take(N) counts rows. take(2) keeps
    rows 0 and 1, which batch(2) packs into a single batch.
    """
    logger.info("test_take_07")
    data1 = ds.GeneratorDataset(generator, ["data"])

    data1 = data1.take(2)
    data1 = data1.batch(2)

    # Check the batch contents as well as the batch count, so the test also
    # proves take(2) selected the first two rows.
    batches = [d[0].asnumpy().flatten().tolist() for d in data1]
    assert batches == [[0, 1]]
150
151
def test_take_08():
    """
    Test take: take is after batch, so take(N) counts batches. batch(2) turns
    the 3 rows into batches [0, 1] and [2], and take(2) keeps both.
    """
    logger.info("test_take_08")
    data1 = ds.GeneratorDataset(generator, ["data"])

    data1 = data1.batch(2)
    data1 = data1.take(2)

    # Check the batch contents as well as the batch count; the last batch
    # holds only the leftover row 2.
    batches = [d[0].asnumpy().flatten().tolist() for d in data1]
    assert batches == [[0, 1], [2]]
162
163
def test_take_09():
    """
    Test take: take(-1) reads the whole dataset; here take comes after repeat.
    """
    logger.info("test_take_09")
    dataset = ds.GeneratorDataset(generator, ["data"]).repeat(2).take(-1)

    # Both repeats pass through untouched: 0, 1, 2, 0, 1, 2.
    for idx, row in enumerate(dataset):
        assert idx % 3 == row[0].asnumpy()[0]

    assert sum(1 for _ in dataset) == 6
179
180
def test_take_10():
    """
    Test take: take(-1) reads the whole dataset; here take comes before repeat.
    """
    logger.info("test_take_10")
    dataset = ds.GeneratorDataset(generator, ["data"]).take(-1).repeat(2)

    # Both repeats pass through untouched: 0, 1, 2, 0, 1, 2.
    for idx, row in enumerate(dataset):
        assert idx % 3 == row[0].asnumpy()[0]

    assert sum(1 for _ in dataset) == 6
196
197
def test_take_11():
    """
    Test take: batch first, then repeat, then take(-1) over the batches.
    """
    logger.info("test_take_11")
    dataset = ds.GeneratorDataset(generator, ["data"]).batch(2).repeat(2).take(-1)

    # Each repeat yields batches starting with 0 and 2, so batch `idx`
    # starts with 2 * (idx % 2).
    for idx, row in enumerate(dataset):
        assert 2 * (idx % 2) == row[0].asnumpy()[0]

    assert sum(1 for _ in dataset) == 4
214
215
def test_take_12():
    """
    Test take: take first, then batch and repeat. take(2) keeps rows 0 and 1,
    batch(2) packs them into one batch, and repeat(2) replays that batch.
    """
    logger.info("test_take_12")
    data1 = ds.GeneratorDataset(generator, ["data"])

    data1 = data1.take(2)
    data1 = data1.batch(2)
    data1 = data1.repeat(2)

    # Gather every batch in a single pass so the content check and the batch
    # count describe the same run of the pipeline.
    batches = [d[0].asnumpy().flatten().tolist() for d in data1]
    assert batches == [[0, 1], [0, 1]]
232
233
def test_take_13():
    """
    Test take: skip first, then take, batch and repeat. skip(2) leaves only
    row 2, which forms a single batch replayed by repeat(2).
    """
    logger.info("test_take_13")
    data1 = ds.GeneratorDataset(generator, ["data"])

    data1 = data1.skip(2)
    data1 = data1.take(-1)
    data1 = data1.batch(2)
    data1 = data1.repeat(2)

    # Gather every batch in a single pass so the content check and the batch
    # count describe the same run of the pipeline.
    batches = [d[0].asnumpy().flatten().tolist() for d in data1]
    assert batches == [[2], [2]]
251
252
def test_take_14():
    """
    Test take: take first, then batch, skip and repeat. batch(2) yields
    [0, 1] and [2], skip(1) drops the first batch, and repeat(2) replays [2].
    """
    logger.info("test_take_14")
    data1 = ds.GeneratorDataset(generator, ["data"])

    data1 = data1.take(-1)
    data1 = data1.batch(2)
    data1 = data1.skip(1)
    data1 = data1.repeat(2)

    # Gather every batch in a single pass so the content check and the batch
    # count describe the same run of the pipeline.
    batches = [d[0].asnumpy().flatten().tolist() for d in data1]
    assert batches == [[2], [2]]
270
271
def test_take_15():
    """
    Test take: on the 10-row source, take(6) keeps rows 0..5 and skip(2) then
    drops the first two, leaving 2..5.
    """
    logger.info("test_take_15")
    dataset = ds.GeneratorDataset(generator_10, ["data"]).take(6).skip(2)

    offset = 2  # rows 0 and 1 are skipped
    for idx, row in enumerate(dataset):
        assert (idx + offset) == row[0].asnumpy()[0]

    assert sum(1 for _ in dataset) == 4
287
288
def test_take_16():
    """
    Test take: on the 10-row source, skip(3) drops rows 0..2 and take(5) then
    keeps the next five, 3..7.
    """
    logger.info("test_take_16")
    dataset = ds.GeneratorDataset(generator_10, ["data"]).skip(3).take(5)

    offset = 3  # rows 0, 1 and 2 are skipped
    for idx, row in enumerate(dataset):
        assert (idx + offset) == row[0].asnumpy()[0]

    assert sum(1 for _ in dataset) == 5
304
305
def test_take_17():
    """
    Test take: take(8) keeps rows 0..7, then the filter keeps only the rows
    whose value is at most 3, leaving 0..3.
    """
    logger.info("test_take_17")
    dataset = ds.GeneratorDataset(generator_10, ["data"]).take(8).filter(
        predicate=filter_func_ge, num_parallel_workers=4)

    # The surviving rows are 0..3 in order, so index and value coincide.
    for idx, row in enumerate(dataset):
        assert idx == row[0].asnumpy()[0]

    assert sum(1 for _ in dataset) == 4
321
322
def test_take_18():
    """
    Test take: take first, then filter, skip, batch and repeat. take(8) and the
    filter leave rows 0..3, skip(2) leaves 2 and 3, batch(2) packs them into
    one batch, and repeat(2) replays it.
    """
    logger.info("test_take_18")
    data1 = ds.GeneratorDataset(generator_10, ["data"])

    data1 = data1.take(8)
    data1 = data1.filter(predicate=filter_func_ge, num_parallel_workers=4)
    data1 = data1.skip(2)

    data1 = data1.batch(2)
    data1 = data1.repeat(2)

    # Gather every batch in a single pass so the content check and the batch
    # count describe the same run of the pipeline.
    batches = [d[0].asnumpy().flatten().tolist() for d in data1]
    assert batches == [[2, 3], [2, 3]]
342
343
def test_take_19():
    """
    Test take: take(0) is rejected with a ValueError whose message says the
    count is not within the required interval.
    """
    logger.info("test_take_19")
    data1 = ds.GeneratorDataset(generator, ["data"])
    data1 = data1.batch(2)

    # Only take(0) belongs inside the raises block. A ValueError raised while
    # building the pipeline must not make this test pass by accident.
    with pytest.raises(ValueError) as info:
        data1.take(0)
    assert "within the required interval" in str(info.value)
355
if __name__ == '__main__':
    # Run every take test in declaration order when executed as a script.
    for case in (test_take_01, test_take_02, test_take_03, test_take_04,
                 test_take_05, test_take_06, test_take_07, test_take_08,
                 test_take_09, test_take_10, test_take_11, test_take_12,
                 test_take_13, test_take_14, test_take_15, test_take_16,
                 test_take_17, test_take_18, test_take_19):
        case()
    logger.info('== test take operation finished ==')
377