• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import numpy as np
16
17import mindspore.dataset as ds
18import mindspore.dataset.vision.c_transforms as vision
19
# Every fixture lives under one shared root; each constant below is that root
# joined with a dataset-specific suffix (plain concatenation, so the values
# are identical on every platform).
_DATASET_ROOT = "../data/dataset"

# Image / label style datasets.
CELEBA_DIR = _DATASET_ROOT + "/testCelebAData"
CIFAR10_DIR = _DATASET_ROOT + "/testCifar10Data"
CIFAR100_DIR = _DATASET_ROOT + "/testCifar100Data"
IMAGE_FOLDER_DIR = _DATASET_ROOT + "/testPK/data/"
MANIFEST_DIR = _DATASET_ROOT + "/testManifestData/test.manifest"
MNIST_DIR = _DATASET_ROOT + "/testMnistData"
VOC_DIR = _DATASET_ROOT + "/testVOC2012"

# COCO needs both the image directory and its annotation file.
COCO_DIR = _DATASET_ROOT + "/testCOCO/train"
COCO_ANNOTATION = _DATASET_ROOT + "/testCOCO/annotations/train.json"

# Text / tabular / record datasets.
CLUE_DIR = _DATASET_ROOT + "/testCLUE/afqmc/train.json"
CSV_DIR = _DATASET_ROOT + "/testCSV/1.csv"
# TFRecordDataset is given a list of files, hence the one-element list.
TFRECORD_DIR = [_DATASET_ROOT + "/testTFTestAllTypes/test.data"]
TFRECORD_SCHEMA = _DATASET_ROOT + "/testTFTestAllTypes/datasetSchema.json"
33
34
def test_get_column_name_celeba():
    """CelebADataset reports an "image" column followed by an "attr" column."""
    celeba = ds.CelebADataset(CELEBA_DIR)
    expected = ["image", "attr"]
    assert celeba.get_col_names() == expected
38
39
def test_get_column_name_cifar10():
    """Cifar10Dataset reports "image" and "label" columns, in that order."""
    cifar10 = ds.Cifar10Dataset(CIFAR10_DIR)
    expected = ["image", "label"]
    assert cifar10.get_col_names() == expected
43
44
def test_get_column_name_cifar100():
    """Cifar100Dataset reports an image column plus coarse and fine label columns."""
    cifar100 = ds.Cifar100Dataset(CIFAR100_DIR)
    expected = ["image", "coarse_label", "fine_label"]
    assert cifar100.get_col_names() == expected
48
49
def test_get_column_name_clue():
    """CLUEDataset for the AFQMC train split reports label, sentence1 and sentence2."""
    clue = ds.CLUEDataset(CLUE_DIR, usage="train", task="AFQMC")
    expected = ["label", "sentence1", "sentence2"]
    assert clue.get_col_names() == expected
53
54
def test_get_column_name_coco():
    """CocoDataset in "Detection" mode reports image, bbox, category_id and iscrowd."""
    coco = ds.CocoDataset(
        COCO_DIR,
        annotation_file=COCO_ANNOTATION,
        task="Detection",
        decode=True,
        shuffle=False,
    )
    expected = ["image", "bbox", "category_id", "iscrowd"]
    assert coco.get_col_names() == expected
59
60
def test_get_column_name_csv():
    """CSVDataset reports the fixture's default names, or the given column_names."""
    # Without column_names, the fixture file yields columns "1".."4".
    default_csv = ds.CSVDataset(CSV_DIR)
    assert default_csv.get_col_names() == ["1", "2", "3", "4"]

    # An explicit column_names list replaces those names one-for-one.
    named_csv = ds.CSVDataset(CSV_DIR, column_names=["col1", "col2", "col3", "col4"])
    assert named_csv.get_col_names() == ["col1", "col2", "col3", "col4"]
66
67
def test_get_column_name_generator():
    """GeneratorDataset reports exactly the column names it was constructed with."""
    def sample_source():
        # Emits 64 one-column rows, each a 1-element numpy array.
        yield from ((np.array([value]),) for value in range(64))

    generated = ds.GeneratorDataset(sample_source, ["data"])
    assert generated.get_col_names() == ["data"]
75
76
def test_get_column_name_imagefolder():
    """ImageFolderDataset reports "image" and "label" columns."""
    folder = ds.ImageFolderDataset(IMAGE_FOLDER_DIR)
    expected = ["image", "label"]
    assert folder.get_col_names() == expected
80
81
def test_get_column_name_iterator():
    """Tuple and dict iterators both expose the dataset's column names."""
    cifar10 = ds.Cifar10Dataset(CIFAR10_DIR)
    # Tuple iterator first, then dict iterator, each over a single epoch.
    for make_iterator in (cifar10.create_tuple_iterator, cifar10.create_dict_iterator):
        itr = make_iterator(num_epochs=1)
        assert itr.get_col_names() == ["image", "label"]
88
89
def test_get_column_name_manifest():
    """ManifestDataset reports "image" and "label" columns."""
    manifest = ds.ManifestDataset(MANIFEST_DIR)
    expected = ["image", "label"]
    assert manifest.get_col_names() == expected
93
94
def test_get_column_name_map():
    """map() keeps, renames or reorders columns according to its column arguments.

    Each case builds a fresh Cifar10Dataset, applies the same CenterCrop op
    with the given extra map() keyword arguments, and checks the resulting
    column names.
    """
    center_crop_op = vision.CenterCrop(10)
    cases = (
        # Only input_columns: the output keeps the input column name.
        ({"input_columns": ["image"]}, ["image", "label"]),
        # Output name equal to input name: nothing changes.
        ({"input_columns": ["image"], "output_columns": ["image"]}, ["image", "label"]),
        # A different output name replaces "image" in place.
        ({"input_columns": ["image"], "output_columns": ["col1"]}, ["col1", "label"]),
        # column_order selects and orders the final columns.
        (
            {
                "input_columns": ["image"],
                "output_columns": ["col1", "col2"],
                "column_order": ["col2", "col1"],
            },
            ["col2", "col1"],
        ),
    )
    for map_kwargs, expected in cases:
        data = ds.Cifar10Dataset(CIFAR10_DIR)
        data = data.map(operations=center_crop_op, **map_kwargs)
        assert data.get_col_names() == expected
110
111
def test_get_column_name_mnist():
    """MnistDataset reports "image" and "label" columns."""
    mnist = ds.MnistDataset(MNIST_DIR)
    expected = ["image", "label"]
    assert mnist.get_col_names() == expected
115
116
def test_get_column_name_numpy_slices():
    """NumpySlicesDataset uses dict keys as names; a plain list gets "column_0"."""
    keyed = ds.NumpySlicesDataset({"a": [1, 2], "b": [3, 4]}, shuffle=False)
    assert keyed.get_col_names() == ["a", "b"]

    unnamed = ds.NumpySlicesDataset([1, 2, 3], shuffle=False)
    assert unnamed.get_col_names() == ["column_0"]
123
124
def test_get_column_name_tfrecord():
    """TFRecordDataset column names follow the schema, columns_list or file.

    Covers four cases:
    - a schema file;
    - a schema file narrowed by columns_list, which also sets the order;
    - no schema, where the fixture reports an extra "col_sint8" column;
    - an in-memory ds.Schema.
    """
    from_schema_file = ds.TFRecordDataset(TFRECORD_DIR, TFRECORD_SCHEMA)
    assert from_schema_file.get_col_names() == [
        "col_1d", "col_2d", "col_3d", "col_binary",
        "col_float", "col_sint16", "col_sint32", "col_sint64",
    ]

    subset = ds.TFRecordDataset(
        TFRECORD_DIR,
        TFRECORD_SCHEMA,
        columns_list=["col_sint16", "col_sint64", "col_2d", "col_binary"],
    )
    assert subset.get_col_names() == ["col_sint16", "col_sint64", "col_2d", "col_binary"]

    schemaless = ds.TFRecordDataset(TFRECORD_DIR)
    assert schemaless.get_col_names() == [
        "col_1d", "col_2d", "col_3d", "col_binary", "col_float",
        "col_sint16", "col_sint32", "col_sint64", "col_sint8",
    ]

    text_schema = ds.Schema()
    for column_name, column_type, column_shape in (
            ("line", "string", []),
            ("words", "string", [-1]),
            ("chinese", "string", []),
    ):
        text_schema.add_column(column_name, column_type, column_shape)

    text_tfrecord_file = "../data/dataset/testTextTFRecord/text.tfrecord"
    text_data = ds.TFRecordDataset(text_tfrecord_file, shuffle=False, schema=text_schema)
    assert text_data.get_col_names() == ["line", "words", "chinese"]
143
144
def test_get_column_name_to_device():
    """The dataset returned by to_device() still reports the source columns."""
    on_device = ds.Cifar10Dataset(CIFAR10_DIR).to_device()
    assert on_device.get_col_names() == ["image", "label"]
149
150
def test_get_column_name_voc():
    """VOCDataset segmentation reports image/target, plus a meta column on request."""
    cases = (
        ({}, ["image", "target"]),
        ({"extra_metadata": True}, ["image", "target", "_meta-filename"]),
    )
    for extra_kwargs, expected in cases:
        voc = ds.VOCDataset(VOC_DIR, task="Segmentation", usage="train", decode=True,
                            shuffle=False, **extra_kwargs)
        assert voc.get_col_names() == expected
156
157
def test_get_column_name_project():
    """project() narrows the reported columns to those requested."""
    cifar10 = ds.Cifar10Dataset(CIFAR10_DIR)
    assert cifar10.get_col_names() == ["image", "label"]

    image_only = cifar10.project(columns=["image"])
    assert image_only.get_col_names() == ["image"]
163
164
def test_get_column_name_rename():
    """rename() maps each listed input column to the matching output name."""
    cifar10 = ds.Cifar10Dataset(CIFAR10_DIR)
    assert cifar10.get_col_names() == ["image", "label"]

    renamed = cifar10.rename(["image", "label"], ["test1", "test2"])
    assert renamed.get_col_names() == ["test1", "test2"]
170
171
def test_get_column_name_zip():
    """zip() reports the inputs' column names concatenated in argument order."""
    cifar_columns = ["image", "label"]
    csv_columns = ["1", "2", "3", "4"]

    cifar_data = ds.Cifar10Dataset(CIFAR10_DIR)
    assert cifar_data.get_col_names() == cifar_columns
    csv_data = ds.CSVDataset(CSV_DIR)
    assert csv_data.get_col_names() == csv_columns

    zipped = ds.zip((cifar_data, csv_data))
    assert zipped.get_col_names() == cifar_columns + csv_columns
179
180
if __name__ == "__main__":
    # Run every test in this module, in declaration order.
    for test_fn in (
            test_get_column_name_celeba,
            test_get_column_name_cifar10,
            test_get_column_name_cifar100,
            test_get_column_name_clue,
            test_get_column_name_coco,
            test_get_column_name_csv,
            test_get_column_name_generator,
            test_get_column_name_imagefolder,
            test_get_column_name_iterator,
            test_get_column_name_manifest,
            test_get_column_name_map,
            test_get_column_name_mnist,
            test_get_column_name_numpy_slices,
            test_get_column_name_tfrecord,
            test_get_column_name_to_device,
            test_get_column_name_voc,
            test_get_column_name_project,
            test_get_column_name_rename,
            test_get_column_name_zip,
    ):
        test_fn()
201