# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This module is to write data into mindrecord.
"""
import os
import platform
import queue
import re
import shutil
import stat
import time
import multiprocessing as mp
import numpy as np
from mindspore import log as logger
from .shardwriter import ShardWriter
from .shardreader import ShardReader
from .shardheader import ShardHeader
from .shardindexgenerator import ShardIndexGenerator
from .shardutils import MIN_SHARD_COUNT, MAX_SHARD_COUNT, VALID_ATTRIBUTES, VALID_ARRAY_ATTRIBUTES, \
    check_filename, VALUE_TYPE_MAP, SUCCESS
from .common.exceptions import ParamValueError, ParamTypeError, MRMInvalidSchemaError, MRMDefineIndexError
from .config import _get_enc_key, _get_enc_mode, _get_dec_mode, _get_hash_mode, encrypt, decrypt, append_hash_to_file, \
    verify_file_hash

__all__ = ['FileWriter']


class FileWriter:
    r"""
    Class to write user defined raw data into MindRecord files.

    Note:
        After the MindRecord file is generated, if the file name is changed,
        the file may fail to be read.

    Args:
        file_name (str): File name of MindRecord file.
        shard_num (int, optional): The Number of MindRecord files.
            It should be between [1, 1000]. Default: ``1`` .
        overwrite (bool, optional): Whether to overwrite if the file already exists. Default: ``False`` .

    Raises:
        ParamValueError: If `file_name` or `shard_num` or `overwrite` is invalid.

    Examples:
        >>> from mindspore.mindrecord import FileWriter
        >>>
        >>> writer = FileWriter(file_name="test.mindrecord", shard_num=1, overwrite=True)
        >>> schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
        >>> writer.add_schema(schema_json, "test_schema")
        >>> indexes = ["file_name", "label"]
        >>> writer.add_index(indexes)
        >>> for i in range(10):
        ...     data = [{"file_name": str(i) + ".jpg", "label": i,
        ...              "data": b"\x10c\xb3w\xa8\xee$o&<q\x8c\x8e(\xa2\x90\x90\x96\xbc\xb1\x1e\xd4QER\x13?\xff"}]
        ...     writer.write_raw_data(data)
        >>> writer.commit()
    """

    def __init__(self, file_name, shard_num=1, overwrite=False):
        # Normalize Windows path separators so downstream C++ layers see '/'.
        if platform.system().lower() == "windows":
            file_name = file_name.replace("\\", "/")
        check_filename(file_name)
        self._file_name = file_name

        if shard_num is not None:
            if isinstance(shard_num, int):
                if shard_num < MIN_SHARD_COUNT or shard_num > MAX_SHARD_COUNT:
                    raise ParamValueError("Parameter shard_num's value: {} should between {} and {}."
                                          .format(shard_num, MIN_SHARD_COUNT, MAX_SHARD_COUNT))
            else:
                raise ParamValueError("Parameter shard_num's type is not int.")
        else:
            raise ParamValueError("Parameter shard_num is None.")

        if not isinstance(overwrite, bool):
            raise ParamValueError("Parameter overwrite's type is not bool.")

        self._shard_num = shard_num
        self._index_generator = True
        # Width of the zero-padded shard suffix, e.g. shard_num=10 -> "0".."9".
        suffix_shard_size = len(str(self._shard_num - 1))

        if self._shard_num == 1:
            self._paths = [self._file_name]
        else:
            # Encryption / hash check operate on a single output file, so they
            # cannot be combined with automatic sharding.
            if _get_enc_key() is not None or _get_hash_mode() is not None:
                raise RuntimeError("When encode mode or hash check is enabled, " +
                                   "the automatic sharding function is unavailable.")
            self._paths = ["{}{}".format(self._file_name,
                                         str(x).rjust(suffix_shard_size, '0'))
                           for x in range(self._shard_num)]

        self._overwrite = overwrite
        self._append = False
        self._flush = False
        self._header = ShardHeader()
        self._writer = ShardWriter()
        self._generator = None

        # State used only by the parallel write mode (write_raw_data with
        # parallel_writer=True): one ShardWriter + worker process per shard,
        # fed through a shared multiprocessing queue.
        self._parallel_writer = None
        self._writers = None
        self._queue = None
        self._workers = None
        self._index_workers = None

    @classmethod
    def open_for_append(cls, file_name):
        r"""
        Open MindRecord file and get ready to append data.

        Args:
            file_name (str): String of MindRecord file name.

        Returns:
            FileWriter, file writer object for the opened MindRecord file.

        Raises:
            ParamValueError: If file_name is invalid.
            FileNameError: If path contains invalid characters.
            MRMOpenError: If failed to open MindRecord file.
            MRMOpenForAppendError: If failed to open file for appending data.

        Examples:
            >>> from mindspore.mindrecord import FileWriter
            >>>
            >>> data = [{"file_name": "0.jpg", "label": 0,
            ...          "data": b"\x10c\xb3w\xa8\xee$o&<q\x8c\x8e(\xa2\x90\x90\x96\xbc\xb1\x1e\xd4QER\x13?\xff"}]
            >>> writer = FileWriter(file_name="test.mindrecord", shard_num=1, overwrite=True)
            >>> schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
            >>> writer.add_schema(schema_json, "test_schema")
            >>> writer.write_raw_data(data)
            >>> writer.commit()
            >>>
            >>> write_append = FileWriter.open_for_append("test.mindrecord")
            >>> for i in range(9):
            ...     data = [{"file_name": str(i+1) + ".jpg", "label": i,
            ...              "data": b"\x10c\xb3w\xa8\xee$o&<q\x8c\x8e(\xa2\x90\x90\x96\xbc\xb1\x1e\xd4QER\x13?\xff"}]
            ...     write_append.write_raw_data(data)
            >>> write_append.commit()
        """
        if platform.system().lower() == "windows":
            file_name = file_name.replace("\\", "/")
        check_filename(file_name)

        # decrypt the data file and index file
        index_file_name = file_name + ".db"
        decrypt_filename = decrypt(file_name, _get_enc_key(), _get_dec_mode())
        decrypt_index_filename = decrypt(index_file_name, _get_enc_key(), _get_dec_mode())

        # verify integrity check
        verify_file_hash(decrypt_filename)
        verify_file_hash(decrypt_index_filename)

        # Only replace the original files after BOTH decryption and hash check
        # succeeded, so a failure leaves the encrypted originals untouched.
        if decrypt_filename != file_name:
            shutil.move(decrypt_filename, file_name)
            shutil.move(decrypt_index_filename, index_file_name)

        # construct ShardHeader from the existing file
        reader = ShardReader()
        reader.open(file_name, False)
        header = ShardHeader(reader.get_header())
        reader.close()

        # "append" is a placeholder name; init_append overwrites it below.
        instance = cls("append")
        instance.init_append(file_name, header)
        return instance

    def init_append(self, file_name, header):
        """Switch this writer into append mode over an existing file `file_name`
        reusing the already-built `header`."""
        self._append = True

        if platform.system().lower() == "windows":
            self._file_name = file_name.replace("\\", "/")
        else:
            self._file_name = file_name

        self._header = header
        self._writer.open_for_append(self._file_name)
        self._paths = [self._file_name]

    def add_schema(self, content, desc=None):
        """
        The schema is added to describe the raw data to be written.

        Note:
            Please refer to the Examples of :class:`mindspore.mindrecord.FileWriter` .

            .. list-table:: The data types supported by MindRecord.
               :widths: 25 25 50
               :header-rows: 1

               * - Data Type
                 - Data Shape
                 - Details
               * - int32
                 - /
                 - integer number
               * - int64
                 - /
                 - integer number
               * - float32
                 - /
                 - real number
               * - float64
                 - /
                 - real number
               * - string
                 - /
                 - string data
               * - bytes
                 - /
                 - binary data
               * - int32
                 - [-1] / [-1, 32, 32] / [3, 224, 224]
                 - numpy ndarray
               * - int64
                 - [-1] / [-1, 32, 32] / [3, 224, 224]
                 - numpy ndarray
               * - float32
                 - [-1] / [-1, 32, 32] / [3, 224, 224]
                 - numpy ndarray
               * - float64
                 - [-1] / [-1, 32, 32] / [3, 224, 224]
                 - numpy ndarray

        Args:
            content (dict): Dictionary of schema content.
            desc (str, optional): String of schema description, Default: ``None`` .

        Raises:
            MRMInvalidSchemaError: If schema is invalid.
            MRMBuildSchemaError: If failed to build schema.
            MRMAddSchemaError: If failed to add schema.

        Examples:
            >>> # Examples of available schemas
            >>> schema1 = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
            >>> schema2 = {"input_ids": {"type": "int32", "shape": [-1]},
            ...            "input_masks": {"type": "int32", "shape": [-1]}}
        """
        ret, error_msg = self._validate_schema(content)
        if ret is False:
            raise MRMInvalidSchemaError(error_msg)
        schema = self._header.build_schema(content, desc)
        self._header.add_schema(schema)

    def add_index(self, index_fields):
        """
        Select index fields from schema to accelerate reading.
        schema is added through `add_schema` .

        Note:
            The index fields should be primitive type. e.g. int/float/str.
            If the function is not called, the fields of the primitive type
            in schema are set as indexes by default.

            Please refer to the Examples of :class:`mindspore.mindrecord.FileWriter` .

        Args:
            index_fields (list[str]): fields from schema.

        Raises:
            ParamTypeError: If index field is invalid.
            MRMDefineIndexError: If index field is not primitive type.
            MRMAddIndexError: If failed to add index field.
            MRMGetMetaError: If the schema is not set or failed to get meta.
        """
        if not index_fields or not isinstance(index_fields, list):
            raise ParamTypeError('index_fields', 'list')

        for field in index_fields:
            # blob fields (bytes / ndarray) are not indexable in the database
            if field in self._header.blob_fields:
                raise MRMDefineIndexError("Failed to set field {} since it's not primitive type.".format(field))
            if not isinstance(field, str):
                raise ParamTypeError('index field', 'str')
        self._header.add_index_fields(index_fields)

    def write_raw_data(self, raw_data, parallel_writer=False):
        """
        Convert raw data into a series of consecutive MindRecord \
        files after the raw data is verified against the schema.

        Note:
            Please refer to the Examples of :class:`mindspore.mindrecord.FileWriter` .

        Args:
            raw_data (list[dict]): List of raw data.
            parallel_writer (bool, optional): Write raw data in parallel if it equals to True. Default: ``False`` .

        Raises:
            ParamTypeError: If index field is invalid.
            MRMOpenError: If failed to open MindRecord file.
            MRMValidateDataError: If data does not match blob fields.
            MRMSetHeaderError: If failed to set header.
            MRMWriteDatasetError: If failed to write dataset.
            TypeError: If parallel_writer is not bool.
        """
        if not isinstance(parallel_writer, bool):
            raise TypeError("The parameter `parallel_writer` must be bool.")

        # Lock the mode on first call; mixing serial and parallel writes on the
        # same writer would leave the two writer pipelines inconsistent.
        if self._parallel_writer is None:
            self._parallel_writer = parallel_writer
        if self._parallel_writer != parallel_writer:
            raise RuntimeError("The parameter `parallel_writer` must be consistent during use.")
        if not self._parallel_writer:
            if not self._writer.is_open:
                self._writer.open(self._paths, self._overwrite)
            if not self._writer.get_shard_header():
                self._writer.set_shard_header(self._header)
            if not isinstance(raw_data, list):
                raise ParamTypeError('raw_data', 'list')
            if self._flush and not self._append:
                raise RuntimeError("Not allowed to call `write_raw_data` on flushed MindRecord files." \
                                   "When creating new MindRecord files, please remove `commit` before " \
                                   "`write_raw_data`. In other cases, when appending to existing MindRecord files, " \
                                   "please call `open_for_append` first and then `write_raw_data`.")
            for each_raw in raw_data:
                if not isinstance(each_raw, dict):
                    raise ParamTypeError('raw_data item', 'dict')
            self._verify_based_on_schema(raw_data)
            self._writer.write_raw_data(raw_data, True, parallel_writer)
            return

        ## parallel write mode
        # init the _writers and launch the workers on first parallel write
        if self._writers is None:
            self._writers = [None] * len(self._paths)   # writers used by worker
            self._queue = mp.Queue(len(self._paths) * 2)  # queue for worker
            self._workers = [None] * len(self._paths)   # worker process
            for i, path in enumerate(self._paths):
                self._writers[i] = ShardWriter()
                self._writers[i].open([path], self._overwrite)
                self._writers[i].set_shard_header(self._header)

                # launch the workers for parallel write
                self._queue._joincancelled = True  # pylint: disable=W0212
                p = mp.Process(target=self._write_worker, name="WriterWorker" + str(i), args=(i, self._queue))
                p.daemon = True
                p.start()
                logger.info("Start worker process(pid:{}) to parallel write.".format(p.pid))
                self._workers[i] = p

        # Push the batch into the shared queue; when it is full, warn
        # periodically and verify the workers are still alive before retrying.
        check_interval = 0.5  # 0.5s
        start_time = time.time()
        while True:
            try:
                self._queue.put(raw_data, block=False)
            except queue.Full:
                if time.time() - start_time > check_interval:
                    start_time = time.time()
                    logger.warning("Because there are too few MindRecord file shards, the efficiency of parallel " \
                                   "writing is too low. You can stop the current task and add the parameter " \
                                   "`shard_num` of `FileWriter` to upgrade the task.")

                # check the status of worker process
                for i in range(len(self._paths)):
                    if not self._workers[i].is_alive():
                        raise RuntimeError("Worker process(pid:{}) has stopped abnormally. Please check " \
                                           "the above log".format(self._workers[i].pid))
                continue
            return

    def set_header_size(self, header_size):
        """
        Set the size of header which contains shard information, schema information, \
        page meta information, etc. The larger a header, the more data \
        the MindRecord file can store. If the size of header is larger than \
        the default size (16MB), users need to call the API to set a proper size.

        Args:
            header_size (int): Size of header, in bytes, which between 16*1024(16KB) and
                128*1024*1024(128MB).

        Raises:
            MRMInvalidHeaderSizeError: If failed to set header size.

        Examples:
            >>> from mindspore.mindrecord import FileWriter
            >>> writer = FileWriter(file_name="test.mindrecord", shard_num=1)
            >>> writer.set_header_size(1 << 25) # 32MB
        """
        self._writer.set_header_size(header_size)

    def set_page_size(self, page_size):
        """
        Set the size of page that represents the area where data is stored, \
        and the areas are divided into two types: raw page and blob page. \
        The larger a page, the more data the page can store. If the size of \
        a sample is larger than the default size (32MB), users need to call the API \
        to set a proper size.

        Args:
            page_size (int): Size of page, in bytes, which between 32*1024(32KB) and
                256*1024*1024(256MB).

        Raises:
            MRMInvalidPageSizeError: If failed to set page size.

        Examples:
            >>> from mindspore.mindrecord import FileWriter
            >>> writer = FileWriter(file_name="test.mindrecord", shard_num=1)
            >>> writer.set_page_size(1 << 26) # 64MB
        """
        self._writer.set_page_size(page_size)

    def commit(self):  # pylint: disable=W0212
        """
        Flush data in memory to disk and generate the corresponding database files.

        Note:
            Please refer to the Examples of :class:`mindspore.mindrecord.FileWriter` .

        Raises:
            MRMOpenError: If failed to open MindRecord file.
            MRMSetHeaderError: If failed to set header.
            MRMIndexGeneratorError: If failed to create index generator.
            MRMGenerateIndexError: If failed to write to database.
            MRMCommitError: If failed to flush data to disk.
            RuntimeError: Parallel write failed.
        """
        if not self._parallel_writer:
            self._flush = True
            if not self._writer.is_open:
                self._writer.open(self._paths, self._overwrite)
            # permit commit without data
            if not self._writer.get_shard_header():
                self._writer.set_shard_header(self._header)
            self._writer.commit()
            if self._index_generator:
                if self._append:
                    self._generator = ShardIndexGenerator(self._file_name, self._append)
                elif len(self._paths) >= 1:
                    self._generator = ShardIndexGenerator(os.path.realpath(self._paths[0]), self._append)
                self._generator.build()
                self._generator.write_to_db()
        else:
            # maybe a empty mindrecord, so need check _writers
            # NOTE(review): this path also implies self._workers is None, which
            # _parallel_commit dereferences — looks unreachable in the normal
            # flow (parallel mode always launches workers); verify before use.
            if self._writers is None:
                self._writers = [None] * len(self._paths)
                for i, path in enumerate(self._paths):
                    self._writers[i] = ShardWriter()
                    # FIX: ShardWriter.open takes a list of paths (as in
                    # write_raw_data); a bare string was passed here before.
                    self._writers[i].open([path], self._overwrite)
                    self._writers[i].set_shard_header(self._header)

            self._parallel_commit()

        # change file mode first, because encrypt / hash check may failed
        mindrecord_files = []
        index_files = []
        for item in self._paths:
            if os.path.exists(item):
                os.chmod(item, stat.S_IRUSR | stat.S_IWUSR)
                mindrecord_files.append(item)
            index_file = item + ".db"
            if os.path.exists(index_file):
                os.chmod(index_file, stat.S_IRUSR | stat.S_IWUSR)
                index_files.append(index_file)

        for item in self._paths:
            if os.path.exists(item):
                # add the integrity check string
                if _get_hash_mode() is not None:
                    append_hash_to_file(item)
                    append_hash_to_file(item + ".db")

                # encrypt the mindrecord file
                if _get_enc_key() is not None:
                    encrypt(item, _get_enc_key(), _get_enc_mode())
                    encrypt(item + ".db", _get_enc_key(), _get_enc_mode())

        logger.info("The list of mindrecord files created are: {}, and the list of index files are: {}".format(
            mindrecord_files, index_files))

    def _index_worker(self, i):
        """The worker do the index generator"""
        generator = ShardIndexGenerator(os.path.realpath(self._paths[i]), False)
        generator.build()
        generator.write_to_db()

    def _parallel_commit(self):
        """Parallel commit"""
        # if some workers stopped, error may occur
        alive_count = 0
        for i in range(len(self._paths)):
            if self._workers[i].is_alive():
                alive_count += 1
        if alive_count != len(self._paths):
            raise RuntimeError("Parallel write worker error, please check the above log.")

        # send EOF to worker process: one sentinel per worker, each worker
        # consumes exactly one "EOF" and then commits its shard
        for i in range(len(self._paths)):
            while True:
                try:
                    self._queue.put("EOF", block=False)
                except queue.Full:
                    time.sleep(1)
                    if not self._workers[i].is_alive():
                        raise RuntimeError("Worker process(pid:{}) has stopped abnormally. Please check " \
                                           "the above log".format(self._workers[i].pid))
                    continue
                break

        # wait the worker processing
        while True:
            alive_count = 0
            for i in range(len(self._paths)):
                if self._workers[i].is_alive():
                    alive_count += 1
            if alive_count == 0:
                break
            time.sleep(1)
            logger.info("Waiting for all the parallel workers to finish.")

        del self._queue

        # wait for worker process stop and surface any abnormal exit code
        for index in range(len(self._paths)):
            while True:
                logger.info("Waiting for the worker process(pid:{}) to process all the data.".format(
                    self._workers[index].pid))
                if self._workers[index].is_alive():
                    time.sleep(1)
                    continue
                elif self._workers[index].exitcode != 0:
                    raise RuntimeError("Worker process(pid:{}) has stopped abnormally. Please check " \
                                       "the above log".format(self._workers[index].pid))
                break

        if self._index_generator:
            # use parallel index workers to generator index
            self._index_workers = [None] * len(self._paths)
            for index in range(len(self._paths)):
                p = mp.Process(target=self._index_worker, name="IndexWorker" + str(index), args=(index,))
                p.daemon = True
                p.start()
                logger.info("Start worker process(pid:{}) to generate index.".format(p.pid))
                self._index_workers[index] = p

            # wait the index workers stop
            for index in range(len(self._paths)):
                self._index_workers[index].join()

    def _validate_array(self, k, v):
        """
        Validate array item in schema

        Args:
            k (str): Key in dict.
            v (dict): Sub dict in schema

        Returns:
            bool, whether the array item is valid.
            str, error message.
        """
        if v['type'] not in VALID_ARRAY_ATTRIBUTES:
            error = "Field '{}' contain illegal " \
                    "attribute '{}'.".format(k, v['type'])
            return False, error
        if 'shape' in v:
            if isinstance(v['shape'], list) is False:
                error = "Field '{}' contain illegal " \
                        "attribute '{}'.".format(k, v['shape'])
                return False, error
        else:
            error = "Field '{}' contains illegal attributes.".format(v)
            return False, error
        return True, ''

    def _verify_based_on_schema(self, raw_data):
        """
        Verify data according to schema and remove invalid data if validation failed.

        1) allowed data type contains: "int32", "int64", "float32", "float64", "string", "bytes".

        Args:
            raw_data (list[dict]): List of raw data.
        """
        error_data_dic = {}
        schema_content = self._header.schema
        for field in schema_content:
            for i, v in enumerate(raw_data):
                # a sample already marked invalid is skipped for later fields
                if i in error_data_dic:
                    continue

                if field not in v:
                    error_data_dic[i] = "for schema, {} th data is wrong, " \
                                        "there is not field: '{}' in the raw data.".format(i, field)
                    continue
                field_type = type(v[field]).__name__
                if field_type not in VALUE_TYPE_MAP:
                    error_data_dic[i] = "for schema, {} th data is wrong, " \
                                        "data type: '{}' for field: '{}' is not matched.".format(i, field_type, field)
                    continue

                if schema_content[field]["type"] not in VALUE_TYPE_MAP[field_type]:
                    error_data_dic[i] = "for schema, {} th data is wrong, " \
                                        "data type: '{}' for field: '{}' is not matched." \
                                        .format(i, schema_content[field]["type"], field)
                    continue

                if field_type == 'ndarray':
                    if 'shape' not in schema_content[field]:
                        error_data_dic[i] = "for schema, {} th data is wrong, " \
                                            "data shape for field: '{}' is not specified.".format(i, field)
                    elif 'type' not in schema_content[field]:
                        error_data_dic[i] = "for schema, {} th data is wrong, " \
                                            "data type for field: '{}' is not specified.".format(i, field)
                    elif schema_content[field]['type'] != str(v[field].dtype):
                        error_data_dic[i] = "for schema, {} th data is wrong, " \
                                            "data type: '{}' for field: '{}' is not matched." \
                                            .format(i, str(v[field].dtype), field)
                    else:
                        try:
                            np.reshape(v[field], schema_content[field]['shape'])
                        except ValueError:
                            error_data_dic[i] = "for schema, {} th data is wrong, " \
                                                "data shape: '{}' for field: '{}' is not matched." \
                                                .format(i, str(v[field].shape), field)
        # pop from the highest index first so earlier indices stay valid
        error_data_dic = sorted(error_data_dic.items(), reverse=True)
        for i, v in error_data_dic:
            raw_data.pop(i)
            logger.warning(v)

    def _validate_schema(self, content):
        """
        Validate schema and return validation result and error message.

        Args:
            content (dict): Dict of raw schema.

        Returns:
            bool, whether the schema is valid.
            str, error message.
        """
        error = ''
        if not content:
            error = 'Schema content is empty.'
            return False, error
        if isinstance(content, dict) is False:
            error = 'Schema content should be dict.'
            return False, error
        for k, v in content.items():
            if not re.match(r'^[0-9a-zA-Z\_]+$', k):
                error = "Field '{}' should be composed of " \
                        "'0-9' or 'a-z' or 'A-Z' or '_'.".format(k)
                return False, error
            if v and isinstance(v, dict):
                if len(v) == 1 and 'type' in v:
                    # scalar field: only a 'type' attribute is allowed
                    if v['type'] not in VALID_ATTRIBUTES:
                        error = "Field '{}' contain illegal " \
                                "attribute '{}'.".format(k, v['type'])
                        return False, error
                elif len(v) == 2 and 'type' in v:
                    # array field: 'type' plus 'shape'
                    res_1, res_2 = self._validate_array(k, v)
                    if not res_1:
                        return res_1, res_2
                else:
                    error = "Field '{}' contains illegal attributes.".format(v)
                    return False, error
            else:
                error = "Field '{}' should be dict.".format(k)
                return False, error
        return True, error

    def _write_worker(self, i, in_queue):
        """The worker do the data check and write to disk for parallel mode"""
        while True:
            # try to get new raw_data from master
            try:
                raw_data = in_queue.get(block=False)
            except queue.Empty:
                continue

            # get EOF from master, worker should commit and stop
            if raw_data == "EOF":
                ret = self._writers[i].commit()
                if ret != SUCCESS:
                    raise RuntimeError("Commit the {}th shard of MindRecord file failed.".format(i))
                break

            # check the raw_data
            if not isinstance(raw_data, list):
                raise ParamTypeError('raw_data', 'list')
            for each_raw in raw_data:
                if not isinstance(each_raw, dict):
                    raise ParamTypeError('raw_data item', 'dict')

            self._verify_based_on_schema(raw_data)
            self._writers[i].write_raw_data(raw_data, True, False)