# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ``get_col_names()`` on source datasets, pipeline ops and iterators.

Each test builds a dataset (or iterator) from the test data directories below
and asserts the exact, ordered list of column names it reports.
"""
import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision

# Paths to the test fixtures, relative to the directory the tests are run from.
CELEBA_DIR = "../data/dataset/testCelebAData"
CIFAR10_DIR = "../data/dataset/testCifar10Data"
CIFAR100_DIR = "../data/dataset/testCifar100Data"
CLUE_DIR = "../data/dataset/testCLUE/afqmc/train.json"
COCO_DIR = "../data/dataset/testCOCO/train"
COCO_ANNOTATION = "../data/dataset/testCOCO/annotations/train.json"
CSV_DIR = "../data/dataset/testCSV/1.csv"
IMAGE_FOLDER_DIR = "../data/dataset/testPK/data/"
MANIFEST_DIR = "../data/dataset/testManifestData/test.manifest"
MNIST_DIR = "../data/dataset/testMnistData"
# Previously hard-coded inline in test_get_column_name_tfrecord; hoisted here so
# every fixture path lives in one place.
TEXT_TFRECORD_DIR = "../data/dataset/testTextTFRecord/text.tfrecord"
TFRECORD_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
TFRECORD_SCHEMA = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
VOC_DIR = "../data/dataset/testVOC2012"


def test_get_column_name_celeba():
    """CelebADataset reports columns ["image", "attr"]."""
    data = ds.CelebADataset(CELEBA_DIR)
    assert data.get_col_names() == ["image", "attr"]


def test_get_column_name_cifar10():
    """Cifar10Dataset reports columns ["image", "label"]."""
    data = ds.Cifar10Dataset(CIFAR10_DIR)
    assert data.get_col_names() == ["image", "label"]


def test_get_column_name_cifar100():
    """Cifar100Dataset reports an image column plus coarse and fine labels."""
    data = ds.Cifar100Dataset(CIFAR100_DIR)
    assert data.get_col_names() == ["image", "coarse_label", "fine_label"]


def test_get_column_name_clue():
    """CLUEDataset with task AFQMC / usage train reports label and two sentences."""
    data = ds.CLUEDataset(CLUE_DIR, task="AFQMC", usage="train")
    assert data.get_col_names() == ["label", "sentence1", "sentence2"]


def test_get_column_name_coco():
    """CocoDataset in Detection mode reports image, bbox, category_id, iscrowd."""
    data = ds.CocoDataset(COCO_DIR, annotation_file=COCO_ANNOTATION, task="Detection",
                          decode=True, shuffle=False)
    assert data.get_col_names() == ["image", "bbox", "category_id", "iscrowd"]


def test_get_column_name_csv():
    """CSVDataset reports the file's default names, or the caller-supplied ones."""
    data = ds.CSVDataset(CSV_DIR)
    assert data.get_col_names() == ["1", "2", "3", "4"]
    data = ds.CSVDataset(CSV_DIR, column_names=["col1", "col2", "col3", "col4"])
    assert data.get_col_names() == ["col1", "col2", "col3", "col4"]


def test_get_column_name_generator():
    """GeneratorDataset reports exactly the column names passed to it."""

    def generator():
        for i in range(64):
            yield (np.array([i]),)

    data = ds.GeneratorDataset(generator, ["data"])
    assert data.get_col_names() == ["data"]


def test_get_column_name_imagefolder():
    """ImageFolderDataset reports columns ["image", "label"]."""
    data = ds.ImageFolderDataset(IMAGE_FOLDER_DIR)
    assert data.get_col_names() == ["image", "label"]


def test_get_column_name_iterator():
    """Both tuple and dict iterators expose the dataset's column names."""
    data = ds.Cifar10Dataset(CIFAR10_DIR)
    itr = data.create_tuple_iterator(num_epochs=1)
    assert itr.get_col_names() == ["image", "label"]
    itr = data.create_dict_iterator(num_epochs=1)
    assert itr.get_col_names() == ["image", "label"]


def test_get_column_name_manifest():
    """ManifestDataset reports columns ["image", "label"]."""
    data = ds.ManifestDataset(MANIFEST_DIR)
    assert data.get_col_names() == ["image", "label"]


def test_get_column_name_map():
    """map() keeps, renames or reorders columns per output_columns/column_order."""
    data = ds.Cifar10Dataset(CIFAR10_DIR)
    center_crop_op = vision.CenterCrop(10)
    # No output_columns: the input column keeps its name.
    data = data.map(operations=center_crop_op, input_columns=["image"])
    assert data.get_col_names() == ["image", "label"]
    # output_columns identical to input_columns: unchanged.
    data = ds.Cifar10Dataset(CIFAR10_DIR)
    data = data.map(operations=center_crop_op, input_columns=["image"], output_columns=["image"])
    assert data.get_col_names() == ["image", "label"]
    # A different output name replaces the input column in place.
    data = ds.Cifar10Dataset(CIFAR10_DIR)
    data = data.map(operations=center_crop_op, input_columns=["image"], output_columns=["col1"])
    assert data.get_col_names() == ["col1", "label"]
    # column_order selects and orders the resulting columns.
    data = ds.Cifar10Dataset(CIFAR10_DIR)
    data = data.map(operations=center_crop_op, input_columns=["image"], output_columns=["col1", "col2"],
                    column_order=["col2", "col1"])
    assert data.get_col_names() == ["col2", "col1"]


def test_get_column_name_mnist():
    """MnistDataset reports columns ["image", "label"]."""
    data = ds.MnistDataset(MNIST_DIR)
    assert data.get_col_names() == ["image", "label"]


def test_get_column_name_numpy_slices():
    """NumpySlicesDataset uses dict keys, or "column_0" for an unnamed list."""
    np_data = {"a": [1, 2], "b": [3, 4]}
    data = ds.NumpySlicesDataset(np_data, shuffle=False)
    assert data.get_col_names() == ["a", "b"]
    data = ds.NumpySlicesDataset([1, 2, 3], shuffle=False)
    assert data.get_col_names() == ["column_0"]


def test_get_column_name_tfrecord():
    """TFRecordDataset column names come from the schema file, columns_list, the data, or a Schema object."""
    # Schema file only.
    data = ds.TFRecordDataset(TFRECORD_DIR, TFRECORD_SCHEMA)
    assert data.get_col_names() == ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", "col_sint16", "col_sint32",
                                    "col_sint64"]
    # columns_list selects a subset and fixes its order.
    data = ds.TFRecordDataset(TFRECORD_DIR, TFRECORD_SCHEMA,
                              columns_list=["col_sint16", "col_sint64", "col_2d", "col_binary"])
    assert data.get_col_names() == ["col_sint16", "col_sint64", "col_2d", "col_binary"]

    # No schema: names are read from the data, which here includes col_sint8.
    data = ds.TFRecordDataset(TFRECORD_DIR)
    assert data.get_col_names() == ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", "col_sint16", "col_sint32",
                                    "col_sint64", "col_sint8"]

    # Programmatic Schema object: column order follows add_column order.
    s = ds.Schema()
    s.add_column("line", "string", [])
    s.add_column("words", "string", [-1])
    s.add_column("chinese", "string", [])

    data = ds.TFRecordDataset(TEXT_TFRECORD_DIR, shuffle=False, schema=s)
    assert data.get_col_names() == ["line", "words", "chinese"]


def test_get_column_name_to_device():
    """The dataset returned by to_device() still reports the source columns."""
    data = ds.Cifar10Dataset(CIFAR10_DIR)
    data = data.to_device()
    assert data.get_col_names() == ["image", "label"]


def test_get_column_name_voc():
    """VOCDataset Segmentation reports image/target, plus metadata when requested."""
    data = ds.VOCDataset(VOC_DIR, task="Segmentation", usage="train", decode=True, shuffle=False)
    assert data.get_col_names() == ["image", "target"]
    data = ds.VOCDataset(VOC_DIR, task="Segmentation", usage="train", decode=True, shuffle=False, extra_metadata=True)
    assert data.get_col_names() == ["image", "target", "_meta-filename"]


def test_get_column_name_project():
    """project() narrows the reported columns to the projected ones."""
    data = ds.Cifar10Dataset(CIFAR10_DIR)
    assert data.get_col_names() == ["image", "label"]
    data = data.project(columns=["image"])
    assert data.get_col_names() == ["image"]


def test_get_column_name_rename():
    """rename() replaces the reported column names."""
    data = ds.Cifar10Dataset(CIFAR10_DIR)
    assert data.get_col_names() == ["image", "label"]
    data = data.rename(["image", "label"], ["test1", "test2"])
    assert data.get_col_names() == ["test1", "test2"]


def test_get_column_name_zip():
    """zip() concatenates the column names of its inputs in order."""
    data1 = ds.Cifar10Dataset(CIFAR10_DIR)
    assert data1.get_col_names() == ["image", "label"]
    data2 = ds.CSVDataset(CSV_DIR)
    assert data2.get_col_names() == ["1", "2", "3", "4"]
    data = ds.zip((data1, data2))
    assert data.get_col_names() == ["image", "label", "1", "2", "3", "4"]


if __name__ == "__main__":
    test_get_column_name_celeba()
    test_get_column_name_cifar10()
    test_get_column_name_cifar100()
    test_get_column_name_clue()
    test_get_column_name_coco()
    test_get_column_name_csv()
    test_get_column_name_generator()
    test_get_column_name_imagefolder()
    test_get_column_name_iterator()
    test_get_column_name_manifest()
    test_get_column_name_map()
    test_get_column_name_mnist()
    test_get_column_name_numpy_slices()
    test_get_column_name_tfrecord()
    test_get_column_name_to_device()
    test_get_column_name_voc()
    test_get_column_name_project()
    test_get_column_name_rename()
    test_get_column_name_zip()