# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The configuration module provides various functions to set and get the supported
configuration parameters.

Common imported modules in corresponding API examples are as follows:

.. code-block::

    from mindspore.mindrecord import set_enc_key, set_enc_mode, set_dec_mode, set_hash_mode
"""

import hashlib
import os
import shutil
import stat
import time

from mindspore import log as logger
from mindspore._c_expression import _encrypt, _decrypt_data
from .shardutils import MIN_FILE_SIZE


__all__ = ['set_enc_key',
           'set_enc_mode',
           'set_dec_mode',
           'set_hash_mode']


# default encode key and hash mode
ENC_KEY = None
ENC_MODE = "AES-GCM"
DEC_MODE = None
HASH_MODE = None


# the final mindrecord after hash check and encode should be like below
# 1. for create new mindrecord: should do hash first, then encode
#    mindrecord ->
#        mindrecord+hash_value+len(4bytes)+'HASH' ->
#            enc_mindrecord+'ENCRYPT'
# 2. for read mindrecord, should decode first, then do hash check
#    enc_mindrecord+'ENCRYPT' ->
#        mindrecord+hash_value+len(4bytes)+'HASH'


# mindrecord file encode end flag, we will append 'ENCRYPT' to the end of file
ENCRYPT_END_FLAG = 'ENCRYPT'.encode('utf-8')


# mindrecord file hash check flag, we will append hash value+'HASH' to the end of file
HASH_END_FLAG = 'HASH'.encode('utf-8')


# length of hash value (4bytes) + 'HASH'
LEN_HASH_WITH_END_FLAG = 4 + len(HASH_END_FLAG)


# directory which stored decrypt mindrecord files
DECRYPT_DIRECTORY = ".decrypt_mindrecord"
DECRYPT_DIRECTORY_LIST = []


# time for warning when encrypt/decrypt or calculate hash takes too long time
CALCULATE_HASH_TIME = 0
VERIFY_HASH_TIME = 0
ENCRYPT_TIME = 0
DECRYPT_TIME = 0
WARNING_INTERVAL = 30   # 30s


# built-in cipher modes accepted by set_enc_mode / set_dec_mode
_SUPPORTED_CIPHER_MODES = ("AES-GCM", "AES-CBC", "SM4-CBC")

# built-in hash modes accepted by set_hash_mode; each one is also the name of a hashlib constructor
_SUPPORTED_HASH_MODES = ("sha256", "sha384", "sha512", "sha3_256", "sha3_384", "sha3_512")

# the mode name stored in an encrypted file always has this many characters
_ENC_MODE_STR_LEN = 7

# mode name stored in the file when a user defined encrypt / decrypt function is used
_UDF_MODE_STR = "UDF-ENC"

# files are hashed / encrypted in chunks of 64M
_CHUNK_SIZE = 64 * 1024 * 1024


def _check_regular_file(filename):
    """Raise RuntimeError unless `filename` exists and is a regular file."""
    if not os.path.exists(filename):
        raise RuntimeError("The input: {} is not exists.".format(filename))

    if not os.path.isfile(filename):
        raise RuntimeError("The input: {} should be a regular file.".format(filename))


def _seek_file(f, filename, position):
    """Seek `f` to `position`; any failure is re-raised as a RuntimeError which names the file."""
    try:
        f.seek(position)
    except Exception as e:  # pylint: disable=W0703
        raise RuntimeError("Seek the file: {} to position: {} failed. Error: {}"
                           .format(filename, position, str(e))) from e


def _read_at(filename, position, size):
    """Open `filename`, seek to `position` and return the (at most) `size` bytes read from there."""
    with open(filename, 'rb') as f:
        _seek_file(f, filename, position)
        return f.read(size)


def _iter_file_chunks(f, filename, total_size, chunk_size=_CHUNK_SIZE):
    """Yield the first `total_size` bytes of `f` as consecutive chunks of at most `chunk_size` bytes."""
    current_offset = 0
    while current_offset < total_size:
        read_size = min(chunk_size, total_size - current_offset)
        _seek_file(f, filename, current_offset)
        yield f.read(read_size)
        current_offset += read_size


def set_enc_key(enc_key):
    """
    Set the encode key.

    Note:
        When the encryption algorithm is ``"SM4-CBC"`` , only a key of length 16 is supported.

    Args:
        enc_key (str): Str-type key used for encryption. The valid length is 16, 24, or 32.
            ``None`` indicates that encryption is not enabled.

    Raises:
        ValueError: The input is not str or length error.

    Examples:
        >>> from mindspore.mindrecord import set_enc_key
        >>>
        >>> set_enc_key("0123456789012345")
    """
    global ENC_KEY

    if enc_key is None:
        ENC_KEY = None
        return

    if not isinstance(enc_key, str):
        raise ValueError("The input enc_key is not str.")

    if len(enc_key) not in (16, 24, 32):
        raise ValueError("The length of input enc_key is not 16, 24, 32.")

    ENC_KEY = enc_key


def _get_enc_key():
    """Get the encode key. If the enc_key is not set, it will return ``None``."""
    return ENC_KEY


def set_enc_mode(enc_mode="AES-GCM"):
    """
    Set the encode mode.

    Args:
        enc_mode (Union[str, function], optional): This parameter is valid only when enc_key is not set to ``None`` .
            Specifies the encryption mode or customized encryption function, currently supports ``"AES-GCM"``,
            ``"AES-CBC"`` and ``"SM4-CBC"`` . Default: ``"AES-GCM"`` . If it is customized encryption, users need
            to ensure its correctness and raise exceptions when errors occur.

    Raises:
        ValueError: The input is not valid encode mode or callable function.

    Examples:
        >>> from mindspore.mindrecord import set_enc_mode
        >>>
        >>> set_enc_mode("AES-GCM")
    """
    global ENC_MODE

    if callable(enc_mode):
        ENC_MODE = enc_mode
        return

    if not isinstance(enc_mode, str):
        raise ValueError("The input enc_mode is not str.")

    if enc_mode not in _SUPPORTED_CIPHER_MODES:
        raise ValueError("The input enc_mode is invalid.")

    ENC_MODE = enc_mode


def _get_enc_mode():
    """Get the encode mode. If the enc_mode is not set, it will return default encode mode ``"AES-GCM"``."""
    return ENC_MODE


def set_dec_mode(dec_mode="AES-GCM"):
    """
    Set the decode mode.

    If the built-in `enc_mode` is used and `dec_mode` is not specified, the encryption algorithm specified by `enc_mode`
    is used for decryption. If you are using customized encryption function, you must specify customized decryption
    function at read time.

    Args:
        dec_mode (Union[str, function], optional): This parameter is valid only when enc_key is not set to ``None`` .
            Specifies the decryption mode or customized decryption function, currently supports ``"AES-GCM"``,
            ``"AES-CBC"`` and ``"SM4-CBC"`` . Default: ``"AES-GCM"`` . ``None`` indicates that decryption
            mode is not defined. If it is customized decryption, users need to ensure its correctness and raise
            exceptions when errors occur.

    Raises:
        ValueError: The input is not valid decode mode or callable function.

    Examples:
        >>> from mindspore.mindrecord import set_dec_mode
        >>>
        >>> set_dec_mode("AES-GCM")
    """
    global DEC_MODE

    if dec_mode is None:
        DEC_MODE = None
        return

    if callable(dec_mode):
        DEC_MODE = dec_mode
        return

    if not isinstance(dec_mode, str):
        raise ValueError("The input dec_mode is not str.")

    if dec_mode not in _SUPPORTED_CIPHER_MODES:
        raise ValueError("The input dec_mode is invalid.")

    DEC_MODE = dec_mode


def _get_dec_mode():
    """
    Get the decode mode. If the dec_mode is not set, it will return encode mode.

    Raises:
        RuntimeError: The dec_mode is not set and the enc_mode is a customized function.
    """
    if DEC_MODE is None:
        if callable(ENC_MODE):
            raise RuntimeError("You use custom encryption, so you must also define custom decryption.")
        return ENC_MODE

    return DEC_MODE


def _get_enc_mode_as_str():
    """
    Get the encode mode as utf-8 encoded bytes. The length of mode should be 7.

    A customized encryption function is represented by ``"UDF-ENC"``.

    Raises:
        RuntimeError: The mode name does not have exactly 7 characters.
    """
    valid_enc_mode = _UDF_MODE_STR if callable(ENC_MODE) else ENC_MODE

    if len(valid_enc_mode) != _ENC_MODE_STR_LEN:
        raise RuntimeError("The length of enc_mode string is not 7.")

    return str(valid_enc_mode).encode('utf-8')


def _get_dec_mode_as_str():
    """
    Get the decode mode as utf-8 encoded bytes. The length of mode should be 7.

    A customized decryption function is represented by ``"UDF-ENC"`` so that it matches the mode
    name written by `_get_enc_mode_as_str` for a customized encryption function.

    Raises:
        RuntimeError: The dec_mode is not set and the enc_mode is a customized function, or the
            mode name does not have exactly 7 characters.
    """
    # _get_dec_mode falls back to the enc_mode and rejects custom encryption without custom decryption
    dec_mode = _get_dec_mode()
    valid_dec_mode = _UDF_MODE_STR if callable(dec_mode) else dec_mode

    if len(valid_dec_mode) != _ENC_MODE_STR_LEN:
        raise RuntimeError("The length of dec_mode string is not 7.")

    return str(valid_dec_mode).encode('utf-8')


def set_hash_mode(hash_mode):
    """
    Set the hash mode to ensure mindrecord file integrity.

    Args:
        hash_mode (Union[str, function]): The parameter is used to specify the hash mode. Specifies the hash
            mode or customized hash function, currently supports ``None``, ``"sha256"``,
            ``"sha384"``, ``"sha512"``, ``"sha3_256"``, ``"sha3_384"``
            and ``"sha3_512"``. ``None`` indicates that hash check is not enabled.

    Raises:
        ValueError: The input is not valid hash mode or callable function.

    Examples:
        >>> from mindspore.mindrecord import set_hash_mode
        >>>
        >>> set_hash_mode("sha256")
    """
    global HASH_MODE

    if hash_mode is None:
        HASH_MODE = None
        return

    if callable(hash_mode):
        HASH_MODE = hash_mode
        return

    if not isinstance(hash_mode, str):
        raise ValueError("The input hash_mode is not str.")

    if hash_mode not in _SUPPORTED_HASH_MODES:
        raise ValueError("The input hash_mode is invalid.")

    HASH_MODE = hash_mode


def _get_hash_func():
    """
    Get the hash func by hash mode.

    Returns:
        The customized hash function when one is configured, otherwise a new hashlib hash object.

    Raises:
        RuntimeError: The hash mode is ``None`` or is not a supported mode.
    """
    if HASH_MODE is None:
        raise RuntimeError("The HASH_MODE is None, no matching hash function.")

    if callable(HASH_MODE):
        return HASH_MODE

    if HASH_MODE in _SUPPORTED_HASH_MODES:
        # every supported mode is the name of the matching hashlib constructor
        return getattr(hashlib, HASH_MODE)()
    raise RuntimeError("The HASH_MODE: {} is invalid.".format(HASH_MODE))


def _get_hash_mode():
    """Get the hash check mode."""
    return HASH_MODE


def calculate_file_hash(filename, whole=True):
    """
    Calculate the file's hash.

    Args:
        filename (str): The file to hash.
        whole (bool): ``True`` to hash the entire file. ``False`` to hash only the content which
            precedes the stored ``hash_value+len(4bytes)+'HASH'`` trailer. Default: ``True`` .

    Returns:
        bytes, the digest. For a customized hash function it is the value returned by its last
        call (empty bytes when there is nothing to hash).

    Raises:
        RuntimeError: The file is not a regular file, no hash function is configured, seeking
            fails, or the customized hash function returns ``None`` or a non-bytes value.
    """
    _check_regular_file(filename)

    # get the hash func
    m = _get_hash_func()
    user_defined = callable(m)

    # a customized hash function receives its previous result, starting from empty bytes
    hash_value = b""

    with open(filename, 'rb') as f:
        # get the size of the content which should be hashed
        if whole:
            file_size = os.path.getsize(filename)
        else:
            len_hash_offset = os.path.getsize(filename) - LEN_HASH_WITH_END_FLAG
            _seek_file(f, filename, len_hash_offset)

            len_hash = int.from_bytes(f.read(4), byteorder='big')  # length of hash value is 4 bytes
            file_size = os.path.getsize(filename) - LEN_HASH_WITH_END_FLAG - len_hash

        # read the file chunk by chunk and update the hash
        for data in _iter_file_chunks(f, filename, file_size):
            if user_defined:
                hash_value = m(data, hash_value)
                # check None first, otherwise the bytes type check below hides this error
                if hash_value is None:
                    raise RuntimeError("User defined hash function return empty.")
                if not isinstance(hash_value, bytes):
                    raise RuntimeError("User defined hash function should return hash value "
                                       "which is bytes type.")
            else:
                m.update(data)

    if user_defined:
        return hash_value
    return m.digest()


def append_hash_to_file(filename):
    """
    Append the hash value to the end of file.

    The file is extended with ``hash_value+len(4bytes)+'HASH'`` and its mode is changed to
    owner read / write.

    Returns:
        bool, always ``True`` .

    Raises:
        RuntimeError: The file is not a regular file or the hash can not be calculated.
    """
    global CALCULATE_HASH_TIME

    _check_regular_file(filename)

    logger.info("Begin to calculate the hash of the file: {}.".format(filename))
    start = time.time()

    hash_value = calculate_file_hash(filename)

    # append hash value, length of hash value (4bytes) and HASH_END_FLAG to the file
    with open(filename, 'ab') as f:
        f.write(hash_value)
        f.write(len(hash_value).to_bytes(4, byteorder='big', signed=False))
        f.write(HASH_END_FLAG)

    # accumulate the elapsed time and warn once for every WARNING_INTERVAL seconds spent
    CALCULATE_HASH_TIME += time.time() - start
    if CALCULATE_HASH_TIME > WARNING_INTERVAL:
        logger.warning("It takes another " + str(WARNING_INTERVAL) +
                       "s to calculate the hash value of the mindrecord file.")
        CALCULATE_HASH_TIME = CALCULATE_HASH_TIME - WARNING_INTERVAL

    # change the file mode
    os.chmod(filename, stat.S_IRUSR | stat.S_IWUSR)

    return True


def get_hash_end_flag(filename):
    """
    Get the hash end flag from the file.

    Returns:
        bytes, the last ``len(HASH_END_FLAG)`` bytes of the file.

    Raises:
        RuntimeError: The file is not a regular file or seeking fails (e.g. the file is too small).
    """
    _check_regular_file(filename)

    offset = os.path.getsize(filename) - len(HASH_END_FLAG)
    return _read_at(filename, offset, len(HASH_END_FLAG))


def get_hash_value(filename):
    """
    Get the hash value which is stored in the file.

    Returns:
        bytes, the stored hash value.

    Raises:
        RuntimeError: The file is not a regular file or seeking fails.
    """
    _check_regular_file(filename)

    file_size = os.path.getsize(filename)

    # the hash_value+len(4bytes)+'HASH' is stored in the end of the file
    offset = file_size - LEN_HASH_WITH_END_FLAG

    with open(filename, 'rb') as f:
        # seek the position for the length of hash value
        _seek_file(f, filename, offset)
        len_hash = int.from_bytes(f.read(4), byteorder='big')  # length of hash value is 4 bytes

        # seek the position for the hash value and read it
        hash_value_offset = file_size - len_hash - LEN_HASH_WITH_END_FLAG
        _seek_file(f, filename, hash_value_offset)
        return f.read(len_hash)


def verify_file_hash(filename):
    """
    Calculate the file hash and compare it with the hash value which is stored in the file.

    Returns:
        bool, always ``True`` . When hash check is disabled and the file is not hashed, nothing
        is verified.

    Raises:
        RuntimeError: The file is not a regular file, the hash configuration does not match the
            file (hashed file without hash mode or the reverse), or the hash values differ.
    """
    global VERIFY_HASH_TIME

    _check_regular_file(filename)

    # verify the hash end flag
    stored_hash_end_flag = get_hash_end_flag(filename)
    if _get_hash_mode() is not None:
        if stored_hash_end_flag != HASH_END_FLAG:
            raise RuntimeError("The mindrecord file is not hashed. You can set " +
                               "'mindspore.mindrecord.config.set_hash_mode(None)' to disable the hash check.")
    else:
        if stored_hash_end_flag == HASH_END_FLAG:
            raise RuntimeError("The mindrecord file is hashed. You need to configure " +
                               "'mindspore.mindrecord.config.set_hash_mode(...)' to enable the hash check.")
        return True

    # get the pre hash value from the end of the file
    stored_hash_value = get_hash_value(filename)

    logger.info("Begin to verify the hash of the file: {}.".format(filename))
    start = time.time()

    # calculate hash by the file, without the stored trailer
    current_hash = calculate_file_hash(filename, False)

    if stored_hash_value != current_hash:
        raise RuntimeError("The input file: " + filename + " hash check fail. The file may be damaged. "
                           "Or configure a correct hash mode.")

    # accumulate the elapsed time and warn once for every WARNING_INTERVAL seconds spent
    VERIFY_HASH_TIME += time.time() - start
    if VERIFY_HASH_TIME > WARNING_INTERVAL:
        logger.warning("It takes another " + str(WARNING_INTERVAL) +
                       "s to verify the hash value of the mindrecord file.")
        VERIFY_HASH_TIME = VERIFY_HASH_TIME - WARNING_INTERVAL

    return True


def encrypt(filename, enc_key, enc_mode):
    """
    Encrypt the file and the original file will be deleted.

    The encrypted content is written to ``filename + ".encrypt"`` which is then moved over
    `filename`.

    Args:
        filename (str): The file to encrypt.
        enc_key (str): The key passed to the encryption.
        enc_mode (Union[str, function]): Built-in mode name, or a customized function which is
            called as ``enc_mode(f, file_size, f_encrypt, enc_key)`` .

    Returns:
        bool, always ``True`` .

    Raises:
        RuntimeError: The file is not a regular file or seeking fails.
    """
    global ENCRYPT_TIME

    _check_regular_file(filename)

    logger.info("Begin to encrypt file: {}.".format(filename))
    start = time.time()

    file_size = os.path.getsize(filename)

    # create new encrypt file
    encrypt_filename = filename + ".encrypt"

    try:
        with open(filename, 'rb') as f, open(encrypt_filename, 'wb') as f_encrypt:
            if callable(enc_mode):
                enc_mode(f, file_size, f_encrypt, enc_key)
            else:
                # read the file with offset and do encrypt
                # original mindrecord file like:
                #     |64M|64M|64M|64M|...
                # encrypted mindrecord file like:
                #     len+encrypt_data|len+encrypt_data|len+encrypt_data|...|0|enc_mode|ENCRYPT_END_FLAG
                for data in _iter_file_chunks(f, filename, file_size):
                    encode_data = _encrypt(data, len(data), enc_key, len(enc_key), enc_mode)

                    # write length of data, then the data, to encrypt file
                    f_encrypt.write(len(encode_data).to_bytes(length=4, byteorder='big', signed=True))
                    f_encrypt.write(encode_data)

            # writing 0 at the end indicates that all encrypted data has been written.
            f_encrypt.write((0).to_bytes(length=4, byteorder='big', signed=True))

            # write enc_mode
            f_encrypt.write(_get_enc_mode_as_str())

            # write ENCRYPT_END_FLAG
            f_encrypt.write(ENCRYPT_END_FLAG)
    except Exception:
        # both files are already closed here; restrict the mode of the partial output like a finished one
        if os.path.isfile(encrypt_filename):
            os.chmod(encrypt_filename, stat.S_IRUSR | stat.S_IWUSR)
        raise

    # accumulate the elapsed time and warn once for every WARNING_INTERVAL seconds spent
    ENCRYPT_TIME += time.time() - start
    if ENCRYPT_TIME > WARNING_INTERVAL:
        logger.warning("It takes another " + str(WARNING_INTERVAL) + "s to encrypt the mindrecord file.")
        ENCRYPT_TIME = ENCRYPT_TIME - WARNING_INTERVAL

    # change the file mode
    os.chmod(encrypt_filename, stat.S_IRUSR | stat.S_IWUSR)

    # move the encrypt file to origin file
    shutil.move(encrypt_filename, filename)

    return True


def _get_encrypt_end_flag(filename):
    """
    Get encrypt end flag from the file.

    Returns:
        bytes, the last ``len(ENCRYPT_END_FLAG)`` bytes of the file.

    Raises:
        RuntimeError: The file is not a regular file or seeking fails.
    """
    _check_regular_file(filename)

    offset = os.path.getsize(filename) - len(ENCRYPT_END_FLAG)
    return _read_at(filename, offset, len(ENCRYPT_END_FLAG))


def _get_enc_mode_from_file(filename):
    """
    Get the encode mode name from the file.

    Returns:
        bytes, the 7 bytes which are stored directly before the encrypt end flag.

    Raises:
        RuntimeError: The file is not a regular file or seeking fails.
    """
    _check_regular_file(filename)

    # the enc_mode str which length is 7 is stored in front of the ENCRYPT_END_FLAG
    offset = os.path.getsize(filename) - len(ENCRYPT_END_FLAG) - _ENC_MODE_STR_LEN
    return _read_at(filename, offset, _ENC_MODE_STR_LEN)


def _decrypt_blocks(f, f_decrypt, filename, enc_key, dec_mode):
    """Decrypt every ``len+encrypt_data`` block of `f` with a built-in mode and write it to `f_decrypt`."""
    # encrypted mindrecord file like:
    #     len+encrypt_data|len+encrypt_data|len+encrypt_data|...|0|enc_mode|ENCRYPT_END_FLAG
    current_offset = 0  # use this to seek file
    length = int.from_bytes(f.read(4), byteorder='big', signed=True)
    while length != 0:
        if length < 0:
            # a negative length would read to the end of file and move the offset backwards
            raise RuntimeError("Invalid encrypted block length: {} in file: {}. The file may be damaged."
                               .format(length, filename))

        # current_offset is the encrypted data
        current_offset += 4
        _seek_file(f, filename, current_offset)

        data = f.read(length)
        decode_data = _decrypt_data(data, len(data), enc_key, len(enc_key), dec_mode)

        if decode_data is None:
            raise RuntimeError("Failed to decrypt data, " +
                               "please check if enc_key and enc_mode / dec_mode is valid.")

        # write to decrypt file
        f_decrypt.write(decode_data)

        # current_offset is the length of next encrypted data block
        current_offset += length
        _seek_file(f, filename, current_offset)

        length = int.from_bytes(f.read(4), byteorder='big', signed=True)


def decrypt(filename, enc_key, dec_mode):
    """
    Decrypt the file by enc_key and dec_mode.

    The decrypted copy is stored in the ``.decrypt_mindrecord`` directory next to the
    (resolved) input file, under the same file name.

    Args:
        filename (str): The file to decrypt.
        enc_key (str): The key passed to the decryption.
        dec_mode (Union[str, function]): Built-in mode name, or a customized function which is
            called as ``dec_mode(f, file_size, f_decrypt, enc_key)`` .

    Returns:
        str, `filename` itself when no enc_key is configured and the file is not encrypted,
        otherwise the path of the decrypted file.

    Raises:
        RuntimeError: The file is not a regular file, is smaller than `MIN_FILE_SIZE`, the
            configuration does not match the file, or the data can not be decrypted.
    """
    global DECRYPT_TIME

    _check_regular_file(filename)

    whole_file_size = os.path.getsize(filename)
    if whole_file_size < MIN_FILE_SIZE:
        raise RuntimeError("Invalid file, the size of mindrecord file: " + str(whole_file_size) +
                           " is smaller than the lower limit: " + str(MIN_FILE_SIZE) +
                           ".\n Please check file path: " + filename +
                           " and use 'FileWriter' to generate valid mindrecord files.")

    # check ENCRYPT_END_FLAG
    stored_encrypt_end_flag = _get_encrypt_end_flag(filename)
    if _get_enc_key() is not None:
        if stored_encrypt_end_flag != ENCRYPT_END_FLAG:
            raise RuntimeError("The mindrecord file is not encrypted. You can set " +
                               "'mindspore.mindrecord.config.set_enc_key(None)' to disable the decryption.")
    else:
        if stored_encrypt_end_flag == ENCRYPT_END_FLAG:
            raise RuntimeError("The mindrecord file is encrypted. You need to configure " +
                               "'mindspore.mindrecord.config.set_enc_key(...)' and " +
                               "'mindspore.mindrecord.config.set_enc_mode(...)' for decryption.")
        return filename

    # check dec_mode with enc_mode
    enc_mode_from_file = _get_enc_mode_from_file(filename)
    if enc_mode_from_file != _get_dec_mode_as_str():
        raise RuntimeError("Failed to decrypt data, please check if enc_key and enc_mode / dec_mode is valid.")

    logger.info("Begin to decrypt file: {}.".format(filename))
    start = time.time()

    file_size = os.path.getsize(filename) - len(ENCRYPT_END_FLAG)

    real_path_filename = os.path.realpath(filename)
    parent_dir = os.path.dirname(real_path_filename)
    only_filename = os.path.basename(real_path_filename)
    current_decrypt_dir = os.path.join(parent_dir, DECRYPT_DIRECTORY)
    try:
        # mkdir directly instead of checking first: another reader may create the directory concurrently
        os.mkdir(current_decrypt_dir)
    except FileExistsError:
        pass
    else:
        os.chmod(current_decrypt_dir, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
        logger.info("Create directory: {} to store decrypt mindrecord files."
                    .format(current_decrypt_dir))

    # warn only once for every decrypt directory
    if current_decrypt_dir not in DECRYPT_DIRECTORY_LIST:
        DECRYPT_DIRECTORY_LIST.append(current_decrypt_dir)
        logger.warning("The decrypt mindrecord file will be stored in [" + current_decrypt_dir + "] directory. "
                       "If you don't use it anymore after train / eval, you need to delete it manually.")

    # create new decrypt file
    decrypt_filename = os.path.join(current_decrypt_dir, only_filename)
    if os.path.isfile(decrypt_filename):
        # the file which had been decrypted early maybe update by user, so we remove the old decrypted one
        os.remove(decrypt_filename)

    try:
        with open(filename, 'rb') as f, open(decrypt_filename, 'wb+') as f_decrypt:
            if callable(dec_mode):
                dec_mode(f, file_size, f_decrypt, enc_key)
            else:
                # read the file and do decrypt
                _decrypt_blocks(f, f_decrypt, filename, enc_key, dec_mode)
    except Exception:
        # both files are already closed here; restrict the mode of the partial output like a finished one
        if os.path.isfile(decrypt_filename):
            os.chmod(decrypt_filename, stat.S_IRUSR | stat.S_IWUSR)
        raise

    # accumulate the elapsed time and warn once for every WARNING_INTERVAL seconds spent
    DECRYPT_TIME += time.time() - start
    if DECRYPT_TIME > WARNING_INTERVAL:
        logger.warning("It takes another " + str(WARNING_INTERVAL) + "s to decrypt the mindrecord file.")
        DECRYPT_TIME = DECRYPT_TIME - WARNING_INTERVAL

    # change the file mode
    os.chmod(decrypt_filename, stat.S_IRUSR | stat.S_IWUSR)

    return decrypt_filename