#!/usr/bin/env python
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Negative-path tests for MindDataset: each case feeds MindRecord readers
invalid input (missing files, inconsistent metadata, bad shard arguments)
and checks that the expected exception and message are raised."""

import os
import pytest

import mindspore.dataset as ds

from mindspore.mindrecord import FileWriter

CV_FILE_NAME = "./imagenet.mindrecord"
CV1_FILE_NAME = "./imagenet1.mindrecord"


def _remove_file_and_db(file_name):
    """Delete a mindrecord file and its companion ``.db`` index, if present."""
    if os.path.exists(file_name):
        os.remove(file_name)
    db_name = f"{file_name}.db"
    if os.path.exists(db_name):
        os.remove(db_name)


def create_cv_mindrecord(files_num):
    """Write a one-sample imagenet-style mindrecord file to CV_FILE_NAME."""
    _remove_file_and_db(CV_FILE_NAME)
    writer = FileWriter(CV_FILE_NAME, files_num)
    schema = {"file_name": {"type": "string"},
              "label": {"type": "int32"},
              "data": {"type": "bytes"}}
    samples = [{"file_name": "001.jpg",
                "label": 43,
                "data": bytes('0xffsafdafda', encoding='utf-8')}]
    writer.add_schema(schema, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(samples)
    writer.commit()


def create_diff_schema_cv_mindrecord(files_num):
    """Write CV1_FILE_NAME with a field name ("file_name_1") that differs
    from the schema used by create_cv_mindrecord."""
    _remove_file_and_db(CV1_FILE_NAME)
    writer = FileWriter(CV1_FILE_NAME, files_num)
    schema = {"file_name_1": {"type": "string"},
              "label": {"type": "int32"},
              "data": {"type": "bytes"}}
    samples = [{"file_name_1": "001.jpg",
                "label": 43,
                "data": bytes('0xffsafdafda', encoding='utf-8')}]
    writer.add_schema(schema, "img_schema")
    writer.add_index(["file_name_1", "label"])
    writer.write_raw_data(samples)
    writer.commit()


def create_diff_page_size_cv_mindrecord(files_num):
    """Write CV1_FILE_NAME with a non-default page size so its metadata is
    inconsistent with CV_FILE_NAME."""
    _remove_file_and_db(CV1_FILE_NAME)
    writer = FileWriter(CV1_FILE_NAME, files_num)
    writer.set_page_size(1 << 26)  # 64MB
    schema = {"file_name": {"type": "string"},
              "label": {"type": "int32"},
              "data": {"type": "bytes"}}
    samples = [{"file_name": "001.jpg",
                "label": 43,
                "data": bytes('0xffsafdafda', encoding='utf-8')}]
    writer.add_schema(schema, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(samples)
    writer.commit()


def test_cv_lack_json():
    """A non-existent json schema path must make MindDataset raise."""
    create_cv_mindrecord(1)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    with pytest.raises(Exception):
        ds.MindDataset(CV_FILE_NAME, "no_exist.json",
                       columns_list, num_readers)
    os.remove(CV_FILE_NAME)
    os.remove(f"{CV_FILE_NAME}.db")


def test_cv_lack_mindrecord():
    """A non-existent mindrecord path must be reported as missing/denied."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    with pytest.raises(Exception, match="does not exist or permission denied"):
        _ = ds.MindDataset("no_exist.mindrecord", columns_list, num_readers)


def test_invalid_mindrecord():
    """A file with bogus content must be rejected as an invalid mindrecord."""
    with open('dummy.mindrecord', 'w') as f:
        f.write('just for test')
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    with pytest.raises(RuntimeError, match="Unexpected error. Invalid file "
                                           "content, incorrect file or file header is exceeds the upper limit."):
        data_set = ds.MindDataset(
            'dummy.mindrecord', columns_list, num_readers)
        for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
    os.remove('dummy.mindrecord')


def test_minddataset_lack_db():
    """Reading without the companion .db meta file must fail, yielding no rows."""
    create_cv_mindrecord(1)
    os.remove(f"{CV_FILE_NAME}.db")
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    with pytest.raises(RuntimeError, match="Unexpected error. Invalid database file, path:"):
        data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers)
        num_iter = 0
        for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            num_iter += 1
    # Clean up whether or not the assertion holds, then propagate any failure.
    try:
        assert num_iter == 0
    except Exception as error:
        os.remove(CV_FILE_NAME)
        raise error
    else:
        os.remove(CV_FILE_NAME)


def test_cv_minddataset_pk_sample_error_class_column():
    """PKSampler pointing at a missing class column must fail at read time."""
    create_cv_mindrecord(1)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.PKSampler(5, None, True, 'no_exist_column')
    with pytest.raises(RuntimeError, match="Unexpected error. Failed to launch read threads."):
        data_set = ds.MindDataset(
            CV_FILE_NAME, columns_list, num_readers, sampler=sampler)
        num_iter = 0
        for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            num_iter += 1
    os.remove(CV_FILE_NAME)
    os.remove(f"{CV_FILE_NAME}.db")


def test_cv_minddataset_pk_sample_exclusive_shuffle():
    """Passing both a sampler and shuffle must be rejected."""
    create_cv_mindrecord(1)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.PKSampler(2)
    with pytest.raises(Exception, match="sampler and shuffle cannot be specified at the same time."):
        data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
                                  sampler=sampler, shuffle=False)
        num_iter = 0
        for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            num_iter += 1
    os.remove(CV_FILE_NAME)
    os.remove(f"{CV_FILE_NAME}.db")


def test_cv_minddataset_reader_different_schema():
    """Reading two files whose schemas disagree must fail a consistency check."""
    create_cv_mindrecord(1)
    create_diff_schema_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4
    with pytest.raises(RuntimeError, match="Unexpected error. Invalid data, "
                                           "MindRecord files meta data is not consistent."):
        data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
                                  num_readers)
        num_iter = 0
        for _ in data_set.create_dict_iterator(num_epochs=1):
            num_iter += 1
    os.remove(CV_FILE_NAME)
    os.remove(f"{CV_FILE_NAME}.db")
    os.remove(CV1_FILE_NAME)
    os.remove(f"{CV1_FILE_NAME}.db")


def test_cv_minddataset_reader_different_page_size():
    """Reading two files with different page sizes must fail a consistency check."""
    create_cv_mindrecord(1)
    create_diff_page_size_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4
    with pytest.raises(RuntimeError, match="Unexpected error. Invalid data, "
                                           "MindRecord files meta data is not consistent."):
        data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
                                  num_readers)
        num_iter = 0
        for _ in data_set.create_dict_iterator(num_epochs=1):
            num_iter += 1
    os.remove(CV_FILE_NAME)
    os.remove(f"{CV_FILE_NAME}.db")
    os.remove(CV1_FILE_NAME)
    os.remove(f"{CV1_FILE_NAME}.db")


def test_minddataset_invalidate_num_shards():
    """shard_id >= num_shards (2 vs 1 shard) must be rejected with a range error."""
    create_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4
    with pytest.raises(Exception) as error_info:
        data_set = ds.MindDataset(
            CV_FILE_NAME, columns_list, num_readers, True, 1, 2)
        num_iter = 0
        for _ in data_set.create_dict_iterator(num_epochs=1):
            num_iter += 1
    try:
        assert 'Input shard_id is not within the required interval of [0, 0].' in str(
            error_info.value)
    except Exception as error:
        os.remove(CV_FILE_NAME)
        os.remove(f"{CV_FILE_NAME}.db")
        raise error
    else:
        os.remove(CV_FILE_NAME)
        os.remove(f"{CV_FILE_NAME}.db")


def test_minddataset_invalidate_shard_id():
    """A negative shard_id must be rejected with a range error."""
    create_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4
    with pytest.raises(Exception) as error_info:
        data_set = ds.MindDataset(
            CV_FILE_NAME, columns_list, num_readers, True, 1, -1)
        num_iter = 0
        for _ in data_set.create_dict_iterator(num_epochs=1):
            num_iter += 1
    try:
        assert 'Input shard_id is not within the required interval of [0, 0].' in str(
            error_info.value)
    except Exception as error:
        os.remove(CV_FILE_NAME)
        os.remove(f"{CV_FILE_NAME}.db")
        raise error
    else:
        os.remove(CV_FILE_NAME)
        os.remove(f"{CV_FILE_NAME}.db")


def test_minddataset_shard_id_bigger_than_num_shard():
    """shard_id at or beyond num_shards (2 and 5 vs 2 shards) must be rejected."""
    create_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4
    with pytest.raises(Exception) as error_info:
        data_set = ds.MindDataset(
            CV_FILE_NAME, columns_list, num_readers, True, 2, 2)
        num_iter = 0
        for _ in data_set.create_dict_iterator(num_epochs=1):
            num_iter += 1
    # On failure clean up and re-raise; on success keep the files for the
    # second sub-case below.
    try:
        assert 'Input shard_id is not within the required interval of [0, 1].' in str(
            error_info.value)
    except Exception as error:
        os.remove(CV_FILE_NAME)
        os.remove(f"{CV_FILE_NAME}.db")
        raise error

    with pytest.raises(Exception) as error_info:
        data_set = ds.MindDataset(
            CV_FILE_NAME, columns_list, num_readers, True, 2, 5)
        num_iter = 0
        for _ in data_set.create_dict_iterator(num_epochs=1):
            num_iter += 1
    try:
        assert 'Input shard_id is not within the required interval of [0, 1].' in str(
            error_info.value)
    except Exception as error:
        os.remove(CV_FILE_NAME)
        os.remove(f"{CV_FILE_NAME}.db")
        raise error
    else:
        os.remove(CV_FILE_NAME)
        os.remove(f"{CV_FILE_NAME}.db")


def test_cv_minddataset_partition_num_samples_equals_0():
    """num_samples=-1 must be rejected with a ValueError."""
    create_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4

    def partitions(num_shards):
        # Iterate every partition; the invalid num_samples should raise on
        # the first MindDataset construction.
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id, num_samples=-1)
            num_iter = 0
            for _ in data_set.create_dict_iterator(num_epochs=1):
                num_iter += 1

    with pytest.raises(ValueError) as error_info:
        partitions(5)
    try:
        assert 'num_samples exceeds the boundary between 0 and 9223372036854775807(INT64_MAX)' in str(
            error_info.value)
    except Exception as error:
        os.remove(CV_FILE_NAME)
        os.remove(f"{CV_FILE_NAME}.db")
        raise error
    else:
        os.remove(CV_FILE_NAME)
        os.remove(f"{CV_FILE_NAME}.db")


def test_mindrecord_exception():
    """A map() op that raises must surface a RuntimeError naming the data files."""

    def exception_func(item):
        raise Exception("Error occur!")

    create_cv_mindrecord(1)
    columns_list = ["data", "file_name", "label"]
    # Same scenario applied to each input column, in the original order.
    for column in ("data", "file_name", "label"):
        with pytest.raises(RuntimeError, match="The corresponding data files"):
            data_set = ds.MindDataset(CV_FILE_NAME, columns_list, shuffle=False)
            data_set = data_set.map(operations=exception_func,
                                    input_columns=[column],
                                    num_parallel_workers=1)
            num_iter = 0
            for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
                num_iter += 1
    os.remove(CV_FILE_NAME)
    os.remove(f"{CV_FILE_NAME}.db")


if __name__ == '__main__':
    test_cv_lack_json()
    test_cv_lack_mindrecord()
    test_invalid_mindrecord()
    test_minddataset_lack_db()
    test_cv_minddataset_pk_sample_error_class_column()
    test_cv_minddataset_pk_sample_exclusive_shuffle()
    test_cv_minddataset_reader_different_schema()
    test_cv_minddataset_reader_different_page_size()
    test_minddataset_invalidate_num_shards()
    test_minddataset_invalidate_shard_id()
    test_minddataset_shard_id_bigger_than_num_shard()
    test_cv_minddataset_partition_num_samples_equals_0()
    test_mindrecord_exception()