# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.py_transforms
import mindspore.dataset.vision.py_transforms as F
from mindspore import log as logger


# In generator dataset: number of rows is 3; its values are 0, 1, 2
def generator():
    for i in range(3):
        yield (np.array([i]),)


# In generator_10 dataset: number of rows is 7; its values are 3, 4, 5 ... 9
def generator_10():
    for i in range(3, 10):
        yield (np.array([i]),)


# In generator_20 dataset: number of rows is 10; its values are 10, 11, 12 ... 19
def generator_20():
    for i in range(10, 20):
        yield (np.array([i]),)


# In generator_29 dataset: number of rows is 9; its values are 20, 21, 22 ... 28
def generator_29():
    for i in range(20, 29):
        yield (np.array([i]),)


def test_concat_01():
    """
    Test concat: concatenate two datasets that have the same column name and data type
    """
    logger.info("test_concat_01")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data3 = data1 + data2

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        t = d
        logger.info("data: %i", t[0][0])
        assert i == t[0][0]

    assert sum([1 for _ in data3]) == 10


def test_concat_02():
    """
    Test concat: concatenate two datasets using the concat() method instead of the "+" operator
    """
    logger.info("test_concat_02")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data3 = data1.concat(data2)

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        t = d
        logger.info("data: %i", t[0][0])
        assert i == t[0][0]

    assert sum([1 for _ in data3]) == 10


def test_concat_03():
    """
    Test concat: concatenate datasets whose column names differ; a RuntimeError is expected
    """
    logger.info("test_concat_03")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col2"])

    data3 = data1 + data2

    try:
        for _, _ in enumerate(data3):
            pass
        assert False
    except RuntimeError:
        pass


def test_concat_04():
    """
    Test concat: concatenate datasets whose column ranks differ (one side is batched);
    a RuntimeError is expected
    """
    logger.info("test_concat_04")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col2"])
    data2 = data2.batch(3)

    data3 = data1 + data2

    try:
        for _, _ in enumerate(data3):
            pass
        assert False
    except RuntimeError:
        pass


def test_concat_05():
    """
    Test concat: concatenate datasets whose column data types differ; a RuntimeError is expected
    """
    logger.info("test_concat_05")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    type_cast_op = C.TypeCast(mstype.float32)
    data1 = data1.map(operations=type_cast_op, input_columns=["col1"])

    data3 = data1 + data2

    try:
        for _, _ in enumerate(data3):
            pass
        assert False
    except RuntimeError:
        pass
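
# The three failure cases above (03-05) all flag a missing error with the
# try/except + `assert False` pattern. A minimal alternative sketch, assuming
# pytest is available in this environment (this file does not otherwise depend
# on it), kept commented out so no new dependency is introduced:
#
#     import pytest
#
#     def check_concat_raises(left, right):
#         # Iterating the concat of mismatched datasets should raise RuntimeError
#         with pytest.raises(RuntimeError):
#             for _ in left + right:
#                 pass
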
def test_concat_06():
    """
    Test concat: concatenate multiple datasets in a single expression
    """
    logger.info("test_concat_06")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = ds.GeneratorDataset(generator_20, ["col1"])

    dataset = data1 + data2 + data3

    # Here i refers to index, d refers to data element
    for i, d in enumerate(dataset.create_tuple_iterator(output_numpy=True)):
        t = d
        logger.info("data: %i", t[0][0])
        assert i == t[0][0]

    assert sum([1 for _ in dataset]) == 20


def test_concat_07():
    """
    Test concat: concatenate one dataset with a list of datasets
    """
    logger.info("test_concat_07")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = ds.GeneratorDataset(generator_20, ["col1"])

    dataset = [data2] + [data3]
    data4 = data1 + dataset

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data4.create_tuple_iterator(output_numpy=True)):
        t = d
        logger.info("data: %i", t[0][0])
        assert i == t[0][0]

    assert sum([1 for _ in data4]) == 20


def test_concat_08():
    """
    Test concat: concatenate two datasets, then repeat the result
    """
    logger.info("test_concat_08")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data3 = data1 + data2
    data3 = data3.repeat(2)

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        t = d
        logger.info("data: %i", t[0][0])
        assert i % 10 == t[0][0]

    assert sum([1 for _ in data3]) == 20


def test_concat_09():
    """
    Test concat: concatenate two datasets that have each been repeated beforehand
    """
    logger.info("test_concat_09")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data1 = data1.repeat(2)
    data2 = data2.repeat(2)
    data3 = data1 + data2

    res = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 4, 5, 6, 7, 8, 9]

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        t = d
        logger.info("data: %i", t[0][0])
        assert res[i] == t[0][0]

    assert sum([1 for _ in data3]) == 20


def test_concat_10():
    """
    Test concat: concatenate two datasets, one of which has been repeated beforehand
    """
    logger.info("test_concat_10")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data1 = data1.repeat(2)
    data3 = data1 + data2

    res = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        t = d
        logger.info("data: %i", t[0][0])
        assert res[i] == t[0][0]

    assert sum([1 for _ in data3]) == 13


def test_concat_11():
    """
    Test concat: batch each dataset, then concatenate
    """
    logger.info("test_concat_11")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_20, ["col1"])

    data1 = data1.batch(3)
    data2 = data2.batch(5)
    data3 = data1 + data2

    # First element of each expected batch; only the first 3 entries are consumed
    res = [0, 10, 15, 20]

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        t = d
        logger.info("data: %i", t[0][0])
        assert res[i] == t[0][0]

    assert sum([1 for _ in data3]) == 3
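
# A minimal sketch of the batch arithmetic behind test_concat_11 above
# (`expected_batches` is a hypothetical helper, not part of the MindSpore API):
# batch() without drop_remainder yields ceil(rows / batch_size) rows per
# dataset, so 3 rows batched by 3 give 1 row, 10 rows batched by 5 give 2 rows,
# and the concat holds 3 rows in total. The two shuffle tests that follow pin
# ds.config.set_seed(1), so their expected orders (`res`) are tied to that seed.
def expected_batches(num_rows, batch_size, drop_remainder=False):
    """Rows a batched dataset is expected to yield for num_rows input rows."""
    if drop_remainder:
        return num_rows // batch_size
    return (num_rows + batch_size - 1) // batch_size  # ceiling division
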
def test_concat_12():
    """
    Test concat: concatenate two datasets, then shuffle the result
    """
    logger.info("test_concat_12")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data3 = data1 + data2
    res = [8, 6, 2, 5, 0, 4, 9, 3, 7, 1]

    ds.config.set_seed(1)
    assert data3.get_dataset_size() == 10

    data3 = data3.shuffle(buffer_size=10)

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        t = d
        logger.info("data: %i", t[0][0])
        assert res[i] == t[0][0]

    assert sum([1 for _ in data3]) == 10


def test_concat_13():
    """
    Test concat: batch each dataset, concatenate, then shuffle the result
    """
    logger.info("test_concat_13")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_20, ["col1"])

    data1 = data1.batch(3)
    data2 = data2.batch(5)

    data3 = data1 + data2
    res = [15, 0, 10]

    ds.config.set_seed(1)
    assert data3.get_dataset_size() == 3

    data3 = data3.shuffle(buffer_size=int(data3.get_dataset_size()))

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        t = d
        logger.info("data: %i", t[0][0])
        assert res[i] == t[0][0]

    assert sum([1 for _ in data3]) == 3


def test_concat_14():
    """
    Test concat: concatenate two different source datasets with different dataset operations
    """
    logger.info("test_concat_14")
    DATA_DIR = "../data/dataset/testPK/data"
    DATA_DIR2 = "../data/dataset/testImageNetData/train/"

    data1 = ds.ImageFolderDataset(DATA_DIR, num_samples=3)
    data2 = ds.ImageFolderDataset(DATA_DIR2, num_samples=2)

    transforms1 = mindspore.dataset.transforms.py_transforms.Compose([F.Decode(),
                                                                      F.Resize((224, 224)),
                                                                      F.ToTensor()])

    data1 = data1.map(operations=transforms1, input_columns=["image"])
    data2 = data2.map(operations=transforms1, input_columns=["image"])
    data3 = data1 + data2

    expected, output = [], []
    for d in data1.create_tuple_iterator(output_numpy=True):
        expected.append(d[0])
    for d in data2.create_tuple_iterator(output_numpy=True):
        expected.append(d[0])
    for d in data3.create_tuple_iterator(output_numpy=True):
        output.append(d[0])

    assert len(expected) == len(output)
    assert np.array_equal(np.array(output), np.array(expected))

    assert sum([1 for _ in data3]) == 5
    assert data3.get_dataset_size() == 5


def test_concat_15():
    """
    Test concat: create datasets from different file formats, then concatenate
    """
    logger.info("test_concat_15")
    DATA_DIR = "../data/dataset/testPK/data"
    DATA_DIR2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]

    data1 = ds.ImageFolderDataset(DATA_DIR)
    data2 = ds.TFRecordDataset(DATA_DIR2, columns_list=["image"])

    data1 = data1.project(["image"])
    data3 = data1 + data2

    assert sum([1 for _ in data3]) == 47


def test_concat_16():
    """
    Test concat: test get_dataset_size on nested concats
    """
    logger.info("test_concat_16")
    DATA_DIR = "../data/dataset/testPK/data"
    DATA_DIR2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]

    data1 = ds.ImageFolderDataset(DATA_DIR)
    data2 = ds.TFRecordDataset(DATA_DIR2, columns_list=["image"])
    data3 = ds.GeneratorDataset(generator, ["col1"])
    data4 = ds.GeneratorDataset(generator_10, ["col1"])

    data5 = data1 + data2
    data6 = data3 + data4
    data7 = data5 + data6

    ds.config.set_seed(1)

    # 57 is the total size of all 4 leaf datasets
    assert data7.get_dataset_size() == 57
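
# test_concat_17 below shards the nested concat across 10 shards. A minimal
# pure-Python model of round-robin sharding (`model_shard` is a hypothetical
# sketch, not the MindSpore implementation): row k lands on shard
# k % num_shards, so 29 rows split over 10 shards as nine shards of 3 rows and
# one shard of 2, and iterating every shard once visits all 29 rows exactly once.
def model_shard(rows, num_shards, shard_id):
    """Return the rows a given shard would see under round-robin sharding."""
    return [row for k, row in enumerate(rows) if k % num_shards == shard_id]
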
def test_concat_17():
    """
    Test concat: test get_dataset_size on nested concats (with sampler)
    """
    logger.info("test_concat_17")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = ds.GeneratorDataset(generator_20, ["col1"])
    data4 = ds.GeneratorDataset(generator_29, ["col1"])

    data5 = data1 + data2
    data6 = data3 + data4
    data7 = data5 + data6

    ds.config.set_seed(1)
    shard_num = 10
    counter = 0

    for i in range(shard_num):
        distributed_sampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i,
                                                    shuffle=False, num_samples=None)
        data7.use_sampler(distributed_sampler)

        iter_counter = 0
        for _ in data7.create_dict_iterator(num_epochs=1, output_numpy=True):
            counter += 1
            iter_counter += 1
        assert data7.get_dataset_size() == iter_counter

    # 29 is the total size of all 4 leaf datasets
    assert counter == 29


if __name__ == "__main__":
    test_concat_01()
    test_concat_02()
    test_concat_03()
    test_concat_04()
    test_concat_05()
    test_concat_06()
    test_concat_07()
    test_concat_08()
    test_concat_09()
    test_concat_10()
    test_concat_11()
    test_concat_12()
    test_concat_13()
    test_concat_14()
    test_concat_15()
    test_concat_16()
    test_concat_17()
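
# Note: because these cases follow the pytest test_* naming convention, the
# file can also be collected by a pytest run instead of executing the
# __main__ block above.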