# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ds.ManifestDataset against the testManifestData fixtures."""
import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
import mindspore.dataset.transforms.c_transforms as data_trans
from mindspore import log as logger

DATA_FILE = "../data/dataset/testManifestData/test.manifest"


def test_manifest_dataset_train():
    """Read the manifest with the default usage and check the per-class sample counts."""
    data = ds.ManifestDataset(DATA_FILE, decode=True)
    count = 0
    cat_count = 0
    dog_count = 0
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("item[image] is {}".format(item["image"]))
        count += 1
        label = item["label"]
        # Only single-label samples are tallied; a multi-label sample counts
        # towards the total but towards neither class counter.
        if label.size == 1 and label == 0:
            cat_count += 1
        elif label.size == 1 and label == 1:
            dog_count += 1
    assert cat_count == 2
    assert dog_count == 1
    assert count == 4


def test_manifest_dataset_eval():
    """Read the "eval" usage and check every label is 0 or 1 and two samples are produced."""
    data = ds.ManifestDataset(DATA_FILE, "eval", decode=True)
    count = 0
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("item[image] is {}".format(item["image"]))
        count += 1
        label = item["label"]
        assert label == 0 or label == 1, "unexpected eval label: {}".format(label)
    assert count == 2


def test_manifest_dataset_class_index():
    """Restrict the dataset with class_indexing and check the remapped label value."""
    class_indexing = {"dog": 11}
    data = ds.ManifestDataset(DATA_FILE, decode=True, class_indexing=class_indexing)
    out_class_indexing = data.get_class_indexing()
    assert out_class_indexing == {"dog": 11}
    count = 0
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("item[image] is {}".format(item["image"]))
        count += 1
        label = item["label"]
        assert label == 11, "unexpected label with class_indexing: {}".format(label)
    assert count == 1


def test_manifest_dataset_get_class_index():
    """Check get_class_indexing() before and after a shuffle op is appended."""
    expected_indexing = {'cat': 0, 'dog': 1, 'flower': 2}
    data = ds.ManifestDataset(DATA_FILE, decode=True)
    assert data.get_class_indexing() == expected_indexing
    # The mapping must still be reachable through a downstream operation.
    data = data.shuffle(4)
    assert data.get_class_indexing() == expected_indexing
    count = 0
    for item in data.create_dict_iterator(num_epochs=1):
        logger.info("item[image] is {}".format(item["image"]))
        count += 1
    assert count == 4


def test_manifest_dataset_multi_label():
    """Check the raw labels, including the multi-label sample, in unshuffled order."""
    data = ds.ManifestDataset(DATA_FILE, decode=True, shuffle=False)
    count = 0
    expect_label = [1, 0, 0, [0, 2]]
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert item["label"].tolist() == expect_label[count]
        logger.info("item[image] is {}".format(item["image"]))
        count += 1
    assert count == 4


def multi_label_hot(x):
    """Collapse one-hot encoded label(s) into a single multi-hot vector.

    Args:
        x (numpy.ndarray): Either one one-hot vector (1-D), or one one-hot row
            per label with the class axis last (2-D) -- the latter is assumed
            to be what OneHot produces for a multi-label sample.

    Returns:
        numpy.ndarray: 1-D vector of length ``x.shape[-1]`` holding the
        element-wise sum of all rows.
    """
    num_classes = x.shape[-1]
    # The previous implementation sized the result as x.size // x.ndim and
    # summed range(x.ndim) rows, which matches the class count / row count only
    # when a sample carries exactly two labels. Size by the class axis and sum
    # over every row instead.
    summed = x.reshape(-1, num_classes).sum(axis=0)
    # Adding onto an int zero vector keeps the original dtype promotion, so
    # single-label and multi-label samples come out with the same dtype.
    return np.add(np.zeros(num_classes, dtype=int), summed)


def test_manifest_dataset_multi_label_onehot():
    """One-hot encode labels, merge multi-label rows and check the batched result."""
    data = ds.ManifestDataset(DATA_FILE, decode=True, shuffle=False)
    expect_label = [[[0, 1, 0], [1, 0, 0]], [[1, 0, 0], [1, 0, 1]]]
    one_hot_encode = data_trans.OneHot(3)
    data = data.map(operations=one_hot_encode, input_columns=["label"])
    data = data.map(operations=multi_label_hot, input_columns=["label"])
    data = data.batch(2)
    count = 0
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert item["label"].tolist() == expect_label[count]
        logger.info("item[image] is {}".format(item["image"]))
        count += 1
    # Without this check an iterator that yields nothing would pass the test.
    assert count == len(expect_label)


def test_manifest_dataset_get_num_class():
    """Check num_classes() on the dataset and on its concatenation with a PaddedDataset."""
    data = ds.ManifestDataset(DATA_FILE, decode=True, shuffle=False)
    assert data.num_classes() == 3

    padded_samples = [{'image': np.zeros(1, np.uint8), 'label': np.array(1, np.int32)}]
    padded_ds = ds.PaddedDataset(padded_samples)

    data = data.repeat(2)
    padded_ds = padded_ds.repeat(2)

    data1 = data + padded_ds
    assert data1.num_classes() == 3


def _assert_manifest_iteration_raises(data_file, expected_message, map_steps=()):
    """Build a ManifestDataset pipeline and assert that iterating it raises RuntimeError.

    Args:
        data_file (str): Manifest file passed to ds.ManifestDataset.
        expected_message (str): Substring that must occur in the error text.
        map_steps (iterable): ``(operation, column)`` pairs, applied in order
            as single-worker map operations.
    """
    try:
        data = ds.ManifestDataset(data_file)
        for operation, column in map_steps:
            data = data.map(operations=operation, input_columns=[column], num_parallel_workers=1)
        for _ in data:
            pass
    except RuntimeError as e:
        assert expected_message in str(e)
    else:
        assert False, "iterating {} did not raise RuntimeError".format(data_file)


def test_manifest_dataset_exception():
    """Check the error text for failing PyFunc maps and for invalid manifest files."""
    def exception_func(item):
        raise Exception("Error occur!")

    pyfunc_error = "map operation: [PyFunc] failed. The corresponding data files"

    # PyFunc failing on the raw image column.
    _assert_manifest_iteration_raises(DATA_FILE, pyfunc_error,
                                      map_steps=[(exception_func, "image")])
    # PyFunc failing after a successful Decode on the same column.
    _assert_manifest_iteration_raises(DATA_FILE, pyfunc_error,
                                      map_steps=[(vision.Decode(), "image"),
                                                 (exception_func, "image")])
    # PyFunc failing on the label column.
    _assert_manifest_iteration_raises(DATA_FILE, pyfunc_error,
                                      map_steps=[(exception_func, "label")])

    NO_SOURCE_DATA_FILE = "../data/dataset/testManifestData/invalidNoSource.manifest"
    _assert_manifest_iteration_raises(NO_SOURCE_DATA_FILE,
                                      "Invalid data, 'source' is not found in Manifest file")

    NO_USAGE_DATA_FILE = "../data/dataset/testManifestData/invalidNoUsage.manifest"
    _assert_manifest_iteration_raises(NO_USAGE_DATA_FILE,
                                      "Invalid data, 'usage' is not found in Manifest file")


if __name__ == '__main__':
    test_manifest_dataset_train()
    test_manifest_dataset_eval()
    test_manifest_dataset_class_index()
    test_manifest_dataset_get_class_index()
    test_manifest_dataset_multi_label()
    test_manifest_dataset_multi_label_onehot()
    test_manifest_dataset_get_num_class()
    test_manifest_dataset_exception()