# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.py_transforms
import mindspore.dataset.vision.py_transforms as F
from mindspore import log as logger


# The generator dataset has 3 rows; its values are 0, 1, 2
def generator():
    for i in range(3):
        yield (np.array([i]),)


# The generator_10 dataset has 7 rows; its values are 3, 4, 5 ... 9
def generator_10():
    for i in range(3, 10):
        yield (np.array([i]),)


# The generator_20 dataset has 10 rows; its values are 10, 11, 12 ... 19
def generator_20():
    for i in range(10, 20):
        yield (np.array([i]),)


# The generator_29 dataset has 9 rows; its values are 20, 21, 22 ... 28
def generator_29():
    for i in range(20, 29):
        yield (np.array([i]),)


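# Note: as the first two tests below exercise, "data1 + data2" and
# "data1.concat(data2)" build the same concatenated dataset, with the rows of
# the left operand emitted first. A minimal sketch, using the generators
# defined above:
#
#   left = ds.GeneratorDataset(generator, ["col1"])      # rows 0, 1, 2
#   right = ds.GeneratorDataset(generator_10, ["col1"])  # rows 3 ... 9
#   combined = left + right                              # rows 0 ... 9, in order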
def test_concat_01():
    """
    Test concat: concatenate two datasets that have the same column name and data type
    """
    logger.info("test_concat_01")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data3 = data1 + data2

    # Here i is the row index and d is the row's data
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        logger.info("data: %i", d[0][0])
        assert i == d[0][0]

    assert sum(1 for _ in data3) == 10


def test_concat_02():
    """
    Test concat: concatenate two datasets with the concat() method rather than the "+" operator
    """
    logger.info("test_concat_02")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data3 = data1.concat(data2)

    # Here i is the row index and d is the row's data
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        logger.info("data: %i", d[0][0])
        assert i == d[0][0]

    assert sum(1 for _ in data3) == 10


def test_concat_03():
    """
    Test concat: concatenate two datasets whose column names differ
    """
    logger.info("test_concat_03")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col2"])

    data3 = data1 + data2

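    # The column names differ ("col1" vs "col2"); the mismatch is only detected
    # once the pipeline is actually iterated, so the RuntimeError is expected here.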
    try:
        for _, _ in enumerate(data3):
            pass
        assert False
    except RuntimeError:
        pass


def test_concat_04():
    """
    Test concat: concatenate two datasets whose tensor ranks differ
    """
    logger.info("test_concat_04")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data2 = data2.batch(3)
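    # Batching adds a dimension, so rows of data2 now have rank 2 while rows of
    # data1 still have rank 1; iterating the concat is expected to fail.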

    data3 = data1 + data2

    try:
        for _, _ in enumerate(data3):
            pass
        assert False
    except RuntimeError:
        pass


def test_concat_05():
    """
    Test concat: concatenate two datasets whose data types differ
    """
    logger.info("test_concat_05")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    type_cast_op = C.TypeCast(mstype.float32)
    data1 = data1.map(operations=type_cast_op, input_columns=["col1"])
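    # "col1" of data1 is now float32 while data2 still yields the original
    # integer type; the dtype mismatch is expected to surface during iteration.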

    data3 = data1 + data2

    try:
        for _, _ in enumerate(data3):
            pass
        assert False
    except RuntimeError:
        pass


def test_concat_06():
    """
    Test concat: concatenate multiple datasets in a single expression
    """
    logger.info("test_concat_06")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = ds.GeneratorDataset(generator_20, ["col1"])

    dataset = data1 + data2 + data3

    # Here i is the row index and d is the row's data
    for i, d in enumerate(dataset.create_tuple_iterator(output_numpy=True)):
        logger.info("data: %i", d[0][0])
        assert i == d[0][0]

    assert sum(1 for _ in dataset) == 20


def test_concat_07():
    """
    Test concat: concatenate one dataset with a list of datasets
    """
    logger.info("test_concat_07")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = ds.GeneratorDataset(generator_20, ["col1"])

    dataset = [data2] + [data3]
    data4 = data1 + dataset

    # Here i is the row index and d is the row's data
    for i, d in enumerate(data4.create_tuple_iterator(output_numpy=True)):
        logger.info("data: %i", d[0][0])
        assert i == d[0][0]

    assert sum(1 for _ in data4) == 20


def test_concat_08():
    """
    Test concat: concatenate two datasets, then repeat the result
    """
    logger.info("test_concat_08")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data3 = data1 + data2
    data3 = data3.repeat(2)
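    # repeat() after concat replays the whole 10-row sequence (0 ... 9) twice,
    # hence the index is compared modulo 10 below.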

    # Here i is the row index and d is the row's data
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        logger.info("data: %i", d[0][0])
        assert i % 10 == d[0][0]

    assert sum(1 for _ in data3) == 20


def test_concat_09():
    """
    Test concat: concatenate two datasets, both of which have been repeated beforehand
    """
    logger.info("test_concat_09")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data1 = data1.repeat(2)
    data2 = data2.repeat(2)
    data3 = data1 + data2

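    # Each input is repeated before the concat, so both passes of data1
    # (0, 1, 2 twice) come first, followed by both passes of data2.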
    res = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 4, 5, 6, 7, 8, 9]
    # Here i is the row index and d is the row's data
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum(1 for _ in data3) == 20


def test_concat_10():
    """
    Test concat: concatenate two datasets, one of which has been repeated beforehand
    """
    logger.info("test_concat_10")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data1 = data1.repeat(2)
    data3 = data1 + data2

    res = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    # Here i is the row index and d is the row's data
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum(1 for _ in data3) == 13


def test_concat_11():
    """
    Test concat: batch each dataset, then concatenate
    """
    logger.info("test_concat_11")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_20, ["col1"])

    data1 = data1.batch(3)
    data2 = data2.batch(5)
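    # After batching, data1 contributes one row (the batch [0, 1, 2]) and data2
    # contributes two rows (batches [10..14] and [15..19]); res below lists the
    # first value of each batch.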

    data3 = data1 + data2
    res = [0, 10, 15]

    # Here i is the row index and d is the row's data
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum(1 for _ in data3) == 3


def test_concat_12():
    """
    Test concat: concatenate two datasets, then shuffle the result
    """
    logger.info("test_concat_12")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data3 = data1 + data2
    res = [8, 6, 2, 5, 0, 4, 9, 3, 7, 1]
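    # With the seed fixed to 1 the shuffle order is deterministic; res is the
    # expected permutation of rows 0 ... 9 for that seed.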

    ds.config.set_seed(1)
    assert data3.get_dataset_size() == 10
    data3 = data3.shuffle(buffer_size=10)

    # Here i is the row index and d is the row's data
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum(1 for _ in data3) == 10


def test_concat_13():
    """
    Test concat: batch each dataset, then concatenate and shuffle
    """
    logger.info("test_concat_13")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_20, ["col1"])

    data1 = data1.batch(3)
    data2 = data2.batch(5)

    data3 = data1 + data2
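    # The three rows are the batch heads 0, 10 and 15; with seed 1 the shuffle
    # yields them in this deterministic order.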
    res = [15, 0, 10]

    ds.config.set_seed(1)
    assert data3.get_dataset_size() == 3

    data3 = data3.shuffle(buffer_size=int(data3.get_dataset_size()))

    # Here i is the row index and d is the row's data
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum(1 for _ in data3) == 3


def test_concat_14():
    """
    Test concat: concatenate two different source datasets that share the same map pipeline
    """
    logger.info("test_concat_14")
    DATA_DIR = "../data/dataset/testPK/data"
    DATA_DIR2 = "../data/dataset/testImageNetData/train/"

    data1 = ds.ImageFolderDataset(DATA_DIR, num_samples=3)
    data2 = ds.ImageFolderDataset(DATA_DIR2, num_samples=2)

    transforms1 = mindspore.dataset.transforms.py_transforms.Compose([F.Decode(),
                                                                      F.Resize((224, 224)),
                                                                      F.ToTensor()])

    data1 = data1.map(operations=transforms1, input_columns=["image"])
    data2 = data2.map(operations=transforms1, input_columns=["image"])
    data3 = data1 + data2

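    # Iterating data1 and then data2 should produce exactly the rows of data3,
    # in the same order; collect both sides and compare them below.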
    expected, output = [], []
    for d in data1.create_tuple_iterator(output_numpy=True):
        expected.append(d[0])
    for d in data2.create_tuple_iterator(output_numpy=True):
        expected.append(d[0])
    for d in data3.create_tuple_iterator(output_numpy=True):
        output.append(d[0])

    assert len(expected) == len(output)
    assert np.array_equal(np.array(output), np.array(expected))

    assert sum(1 for _ in data3) == 5
    assert data3.get_dataset_size() == 5


def test_concat_15():
    """
    Test concat: create datasets from different file formats, then concatenate
    """
    logger.info("test_concat_15")
    DATA_DIR = "../data/dataset/testPK/data"
    DATA_DIR2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]

    data1 = ds.ImageFolderDataset(DATA_DIR)
    data2 = ds.TFRecordDataset(DATA_DIR2, columns_list=["image"])

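    # project() keeps only the shared "image" column of data1 so that both
    # inputs expose the same schema before the concat.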
    data1 = data1.project(["image"])
    data3 = data1 + data2

    # 47 = 44 rows from the ImageFolder dataset + 3 rows from the TFRecord file
    assert sum(1 for _ in data3) == 47


def test_concat_16():
    """
    Test concat: test get_dataset_size on nested concats
    """
    logger.info("test_concat_16")
    DATA_DIR = "../data/dataset/testPK/data"
    DATA_DIR2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]

    data1 = ds.ImageFolderDataset(DATA_DIR)
    data2 = ds.TFRecordDataset(DATA_DIR2, columns_list=["image"])

    data3 = ds.GeneratorDataset(generator, ["col1"])
    data4 = ds.GeneratorDataset(generator_10, ["col1"])

    data5 = data1 + data2
    data6 = data3 + data4
    data7 = data5 + data6

    ds.config.set_seed(1)

    # 57 is the total size of all 4 leaf datasets: 44 + 3 + 3 + 7
    assert data7.get_dataset_size() == 57


def test_concat_17():
    """
    Test concat: test get_dataset_size on nested concats (with sampler)
    """
    logger.info("test_concat_17")

    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data3 = ds.GeneratorDataset(generator_20, ["col1"])
    data4 = ds.GeneratorDataset(generator_29, ["col1"])

    data5 = data1 + data2
    data6 = data3 + data4
    data7 = data5 + data6

    ds.config.set_seed(1)
    shard_num = 10
    counter = 0

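    # Iterate every shard: use_sampler() re-partitions the nested concat across
    # num_shards, so summing the per-shard row counts must recover every row.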
    for i in range(shard_num):
        distributed_sampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None)
        data7.use_sampler(distributed_sampler)
        iter_counter = 0
        for _ in data7.create_dict_iterator(num_epochs=1, output_numpy=True):
            counter += 1
            iter_counter += 1
        assert data7.get_dataset_size() == iter_counter

    # 29 is the total size of all 4 leaf datasets: 3 + 7 + 10 + 9
    assert counter == 29


if __name__ == "__main__":
    test_concat_01()
    test_concat_02()
    test_concat_03()
    test_concat_04()
    test_concat_05()
    test_concat_06()
    test_concat_07()
    test_concat_08()
    test_concat_09()
    test_concat_10()
    test_concat_11()
    test_concat_12()
    test_concat_13()
    test_concat_14()
    test_concat_15()
    test_concat_16()
    test_concat_17()