• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16The configuration module provides various functions to set and get the supported
17configuration parameters.
18
19Common imported modules in corresponding API examples are as follows:
20
21.. code-block::
22
23    from mindspore.mindrecord import set_enc_key, set_enc_mode, set_dec_mode, set_hash_mode
24"""
25
26import hashlib
27import os
28import shutil
29import stat
30import time
31
32from mindspore import log as logger
33from mindspore._c_expression import _encrypt, _decrypt_data
34from .shardutils import MIN_FILE_SIZE
35
36
# Names exported by a star import of this module.
__all__ = [
    'set_enc_key',
    'set_enc_mode',
    'set_dec_mode',
    'set_hash_mode',
]
41
42
# Defaults: no key, "AES-GCM" encryption, no explicit decode mode, hash check off.
ENC_KEY = None
ENC_MODE = "AES-GCM"
DEC_MODE = None
HASH_MODE = None


# Layout of the final mindrecord file after hash check and encode:
# 1. when creating a new mindrecord: hash first, then encode
#    mindrecord
#      -> mindrecord+hash_value+len(4bytes)+'HASH'
#      -> enc_mindrecord+'ENCRYPT'
# 2. when reading a mindrecord: decode first, then do the hash check
#    enc_mindrecord+'ENCRYPT'
#      -> mindrecord+hash_value+len(4bytes)+'HASH'


# End flag appended to an encrypted mindrecord file.
ENCRYPT_END_FLAG = b'ENCRYPT'


# End flag appended (after the hash value and its length) to a hashed mindrecord file.
HASH_END_FLAG = b'HASH'


# Size of the fixed trailer: 4-byte hash length field + 'HASH'.
LEN_HASH_WITH_END_FLAG = len(HASH_END_FLAG) + 4


# Name of the sub-directory that stores decrypted mindrecord files, and the
# list of such directories already reported to the user.
DECRYPT_DIRECTORY = ".decrypt_mindrecord"
DECRYPT_DIRECTORY_LIST = []


# Accumulated seconds spent per operation; a warning is logged each time an
# accumulator exceeds WARNING_INTERVAL.
CALCULATE_HASH_TIME = 0
VERIFY_HASH_TIME = 0
ENCRYPT_TIME = 0
DECRYPT_TIME = 0
WARNING_INTERVAL = 30   # 30s
83
84
def set_enc_key(enc_key):
    """
    Set the encode key.

    Note:
        When the encryption algorithm is ``"SM4-CBC"`` , only keys of length 16 are supported.

    Args:
        enc_key (str): Str-type key used for encryption. The valid length is 16, 24, or 32.
            ``None`` indicates that encryption is not enabled.

    Raises:
        ValueError: The input is not str or length error.

    Examples:
        >>> from mindspore.mindrecord import set_enc_key
        >>>
        >>> set_enc_key("0123456789012345")
    """
    global ENC_KEY

    # ``None`` disables encryption and needs no validation.
    if enc_key is not None:
        if not isinstance(enc_key, str):
            raise ValueError("The input enc_key is not str.")
        if len(enc_key) not in (16, 24, 32):
            raise ValueError("The length of input enc_key is not 16, 24, 32.")

    ENC_KEY = enc_key
117
118
def _get_enc_key():
    """Return the configured encode key, or ``None`` when no key has been set."""
    return ENC_KEY
124
125
def set_enc_mode(enc_mode="AES-GCM"):
    """
    Set the encode mode.

    Args:
        enc_mode (Union[str, function], optional): This parameter is valid only when enc_key is not set to ``None`` .
            Specifies the encryption mode or customized encryption function, currently supports ``"AES-GCM"``,
            ``"AES-CBC"`` and ``"SM4-CBC"`` . Default: ``"AES-GCM"`` . If it is customized encryption, users need
            to ensure its correctness and raise exceptions when errors occur.

    Raises:
        ValueError: The input is not valid encode mode or callable function.

    Examples:
        >>> from mindspore.mindrecord import set_enc_mode
        >>>
        >>> set_enc_mode("AES-GCM")
    """
    global ENC_MODE

    # A callable is accepted as-is; anything else must be a supported mode name.
    if not callable(enc_mode):
        if not isinstance(enc_mode, str):
            raise ValueError("The input enc_mode is not str.")
        if enc_mode not in ("AES-GCM", "AES-CBC", "SM4-CBC"):
            raise ValueError("The input enc_mode is invalid.")

    ENC_MODE = enc_mode
157
158
def _get_enc_mode():
    """Return the configured encode mode; it is ``"AES-GCM"`` unless it has been changed."""
    return ENC_MODE
164
165
def set_dec_mode(dec_mode="AES-GCM"):
    """
    Set the decode mode.

    If the built-in `enc_mode` is used and `dec_mode` is not specified, the encryption algorithm specified by `enc_mode`
    is used for decryption. If you are using customized encryption function, you must specify customized decryption
    function at read time.

    Args:
        dec_mode (Union[str, function], optional): This parameter is valid only when enc_key is not set to ``None`` .
            Specifies the decryption mode or customized decryption function, currently supports ``"AES-GCM"``,
            ``"AES-CBC"`` and ``"SM4-CBC"`` . Default: ``"AES-GCM"`` . ``None`` indicates that decryption
            mode is not defined. If it is customized decryption, users need to ensure its correctness and raise
            exceptions when errors occur.

    Raises:
        ValueError: The input is not valid decode mode or callable function.

    Examples:
        >>> from mindspore.mindrecord import set_dec_mode
        >>>
        >>> set_dec_mode("AES-GCM")
    """
    global DEC_MODE

    # ``None`` and callables are stored untouched; only mode names are validated.
    needs_name_check = dec_mode is not None and not callable(dec_mode)
    if needs_name_check:
        if not isinstance(dec_mode, str):
            raise ValueError("The input dec_mode is not str.")
        if dec_mode not in ("AES-GCM", "AES-CBC", "SM4-CBC"):
            raise ValueError("The input dec_mode is invalid.")

    DEC_MODE = dec_mode
206
207
def _get_dec_mode():
    """
    Return the decode mode, falling back to the encode mode when none is set.

    Raises:
        RuntimeError: No decode mode is set while the encode mode is a custom function.
    """
    if DEC_MODE is not None:
        return DEC_MODE

    # No explicit decode mode: reuse the encode mode, which only works for built-in names.
    if callable(ENC_MODE):
        raise RuntimeError("You use custom encryption, so you must also define custom decryption.")
    return ENC_MODE
219
220
def _get_enc_mode_as_str():
    """
    Return the encode mode as a 7-byte UTF-8 tag.

    A custom (callable) encode mode is represented by the tag ``"UDF-ENC"``.

    Raises:
        RuntimeError: The mode string is not exactly 7 characters long.
    """
    mode_name = "UDF-ENC" if callable(ENC_MODE) else ENC_MODE

    if len(mode_name) != 7:
        raise RuntimeError("The length of enc_mode string is not 7.")

    return str(mode_name).encode('utf-8')
235
236
def _get_dec_mode_as_str():
    """
    Return the decode mode as a 7-byte UTF-8 tag.

    When no decode mode is set, the encode mode is used instead. A custom
    (callable) decode mode is represented by the tag ``"UDF-ENC"``.

    Raises:
        RuntimeError: No decode mode is set while the encode mode is a custom
            function, or the mode string is not exactly 7 characters long.
    """
    if DEC_MODE is None:
        if callable(ENC_MODE):
            raise RuntimeError("You use custom encryption, so you must also define custom decryption.")
        valid_dec_mode = ENC_MODE   # "AES-GCM" / "AES-CBC" / "SM4-CBC"
    elif callable(DEC_MODE):
        valid_dec_mode = "UDF-ENC"
    else:
        valid_dec_mode = DEC_MODE

    if len(valid_dec_mode) != 7:
        # The message names dec_mode: this check validates the decode mode, not the encode mode.
        raise RuntimeError("The length of dec_mode string is not 7.")

    return str(valid_dec_mode).encode('utf-8')
257
258
def set_hash_mode(hash_mode):
    """
    Set the hash mode to ensure mindrecord file integrity.

    Args:
        hash_mode (Union[str, function]): The parameter is used to specify the hash mode. Specifies the hash
            mode or customized hash function, currently supports ``None``, ``"sha256"``,
            ``"sha384"``, ``"sha512"``, ``"sha3_256"``, ``"sha3_384"``
            and ``"sha3_512"``. ``None`` indicates that hash check is not enabled.

    Raises:
        ValueError: The input is not valid hash mode or callable function.

    Examples:
        >>> from mindspore.mindrecord import set_hash_mode
        >>>
        >>> set_hash_mode("sha256")
    """
    global HASH_MODE

    # ``None`` and callables are stored untouched; only mode names are validated.
    needs_name_check = hash_mode is not None and not callable(hash_mode)
    if needs_name_check:
        if not isinstance(hash_mode, str):
            raise ValueError("The input hash_mode is not str.")
        if hash_mode not in ("sha256", "sha384", "sha512", "sha3_256", "sha3_384", "sha3_512"):
            raise ValueError("The input hash_mode is invalid.")

    HASH_MODE = hash_mode
294
295
def _get_hash_func():
    """
    Return the hasher selected by the hash mode.

    Returns:
        The user-defined callable when the hash mode is a function, otherwise
        a fresh ``hashlib`` hash object for the configured algorithm name.

    Raises:
        RuntimeError: The hash mode is ``None`` or is not a supported name.
    """
    if HASH_MODE is None:
        raise RuntimeError("The HASH_MODE is None, no matching hash function.")

    if callable(HASH_MODE):
        return HASH_MODE

    # Each supported name is also the name of the hashlib constructor.
    if HASH_MODE in ("sha256", "sha384", "sha512", "sha3_256", "sha3_384", "sha3_512"):
        return getattr(hashlib, HASH_MODE)()

    raise RuntimeError("The HASH_MODE: {} is invalid.".format(HASH_MODE))
319
320
def _get_hash_mode():
    """Return the configured hash check mode (``None`` when no hash mode is set)."""
    return HASH_MODE
326
327
def calculate_file_hash(filename, whole=True):
    """
    Calculate the hash of a file, reading it in 64M chunks.

    Args:
        filename (str): Path of the file to hash.
        whole (bool): If ``True``, hash the entire file. If ``False``, the file is
            expected to end with ``hash_value+len(4bytes)+'HASH'`` and only the
            bytes before that trailer are hashed. Default: ``True``.

    Returns:
        bytes, the digest of the hashed region. With a user-defined hash function
        this is the value returned by its last call.

    Raises:
        RuntimeError: The file does not exist, is not a regular file, a seek fails,
            or the user-defined hash function returns ``None`` or a non-bytes value.
    """
    if not os.path.exists(filename):
        raise RuntimeError("The input: {} is not exists.".format(filename))

    if not os.path.isfile(filename):
        raise RuntimeError("The input: {} should be a regular file.".format(filename))

    # Either a hashlib object (has update/digest) or a user callable (data, previous_hash) -> bytes.
    m = _get_hash_func()
    use_custom_hash = callable(m)

    chunk_size = 64 * 1024 * 1024    # read the file 64M at a time

    # The context manager closes the file on every exit path, including errors
    # raised by read() or by the user-defined hash function.
    with open(filename, 'rb') as f:
        if whole:
            file_size = os.path.getsize(filename)
        else:
            # The 4-byte hash length sits right before the trailing 'HASH' flag.
            len_hash_offset = os.path.getsize(filename) - LEN_HASH_WITH_END_FLAG
            try:
                f.seek(len_hash_offset)
            except Exception as e:  # pylint: disable=W0703
                raise RuntimeError("Seek the file: {} to position: {} failed. Error: {}"
                                   .format(filename, len_hash_offset, str(e)))

            len_hash = int.from_bytes(f.read(4), byteorder='big')  # length of hash value is 4 bytes
            file_size = os.path.getsize(filename) - LEN_HASH_WITH_END_FLAG - len_hash

        hash_value = b""
        current_offset = 0
        while current_offset < file_size:
            read_size = min(chunk_size, file_size - current_offset)

            try:
                f.seek(current_offset)
            except Exception as e:  # pylint: disable=W0703
                raise RuntimeError("Seek the file: {} to position: {} failed. Error: {}"
                                   .format(filename, current_offset, str(e)))

            data = f.read(read_size)
            if use_custom_hash:
                hash_value = m(data, hash_value)
                # Check None first; after the bytes check this branch could never be reached.
                if hash_value is None:
                    raise RuntimeError("User defined hash function return empty.")
                if not isinstance(hash_value, bytes):
                    raise RuntimeError("User defined hash function should return hash value which is bytes type.")
            else:
                m.update(data)

            current_offset += read_size

    if use_custom_hash:
        return hash_value
    return m.digest()
394
395
def append_hash_to_file(filename):
    """
    Append the hash of the file to the end of the file.

    The trailer written is ``hash_value+len(4bytes)+'HASH'``. After writing, the
    file mode is changed to owner read/write only.

    Args:
        filename (str): Path of the file to hash and extend.

    Returns:
        bool, always ``True`` on success.

    Raises:
        RuntimeError: The file does not exist or is not a regular file.
    """
    global CALCULATE_HASH_TIME

    if not os.path.exists(filename):
        raise RuntimeError("The input: {} is not exists.".format(filename))

    if not os.path.isfile(filename):
        raise RuntimeError("The input: {} should be a regular file.".format(filename))

    logger.info("Begin to calculate the hash of the file: {}.".format(filename))
    start = time.time()

    hash_value = calculate_file_hash(filename)

    # Append hash value, its 4-byte big-endian length and HASH_END_FLAG; the
    # context manager closes the file even if a write fails.
    with open(filename, 'ab') as f:
        f.write(hash_value)
        f.write(len(hash_value).to_bytes(4, byteorder='big', signed=False))
        f.write(HASH_END_FLAG)

    end = time.time()
    CALCULATE_HASH_TIME += end - start
    if CALCULATE_HASH_TIME > WARNING_INTERVAL:
        logger.warning("It takes another " + str(WARNING_INTERVAL) +
                       "s to calculate the hash value of the mindrecord file.")
        CALCULATE_HASH_TIME = CALCULATE_HASH_TIME - WARNING_INTERVAL

    # change the file mode
    os.chmod(filename, stat.S_IRUSR | stat.S_IWUSR)

    return True
428
429
def get_hash_end_flag(filename):
    """
    Read the trailing bytes of the file where the hash end flag would be stored.

    Args:
        filename (str): Path of the file to inspect.

    Returns:
        bytes, the last ``len(HASH_END_FLAG)`` bytes of the file. The caller
        compares them with ``HASH_END_FLAG``.

    Raises:
        RuntimeError: The file does not exist, is not a regular file, or the seek fails.
    """
    if not os.path.exists(filename):
        raise RuntimeError("The input: {} is not exists.".format(filename))

    if not os.path.isfile(filename):
        raise RuntimeError("The input: {} should be a regular file.".format(filename))

    flag_len = len(HASH_END_FLAG)
    offset = os.path.getsize(filename) - flag_len

    # The context manager closes the file on every exit path, including a failing read.
    with open(filename, 'rb') as f:
        try:
            f.seek(offset)
        except Exception as e:  # pylint: disable=W0703
            raise RuntimeError("Seek the file: {} to position: {} failed. Error: {}".format(filename, offset, str(e)))

        return f.read(flag_len)
454
455
def get_hash_value(filename):
    """
    Read the hash value stored at the end of the file.

    The file is expected to end with ``hash_value+len(4bytes)+'HASH'``; the
    4-byte big-endian length gives the size of the hash value that precedes it.

    Args:
        filename (str): Path of the hashed file.

    Returns:
        bytes, the stored hash value.

    Raises:
        RuntimeError: The file does not exist, is not a regular file, or a seek fails.
    """
    if not os.path.exists(filename):
        raise RuntimeError("The input: {} is not exists.".format(filename))

    if not os.path.isfile(filename):
        raise RuntimeError("The input: {} should be a regular file.".format(filename))

    file_size = os.path.getsize(filename)

    # the hash_value+len(4bytes)+'HASH' is stored in the end of the file
    offset = file_size - LEN_HASH_WITH_END_FLAG

    # The context manager closes the file on every exit path, including a failing read.
    with open(filename, 'rb') as f:
        # seek the position for the length of hash value
        try:
            f.seek(offset)
        except Exception as e:  # pylint: disable=W0703
            raise RuntimeError("Seek the file: {} to position: {} failed. Error: {}".format(filename, offset, str(e)))

        len_hash = int.from_bytes(f.read(4), byteorder='big')  # length of hash value is 4 bytes
        hash_value_offset = file_size - len_hash - LEN_HASH_WITH_END_FLAG

        # seek the position for the hash value
        try:
            f.seek(hash_value_offset)
        except Exception as e:  # pylint: disable=W0703
            raise RuntimeError("Seek the file: {} to position: {} failed. Error: {}"
                               .format(filename, hash_value_offset, str(e)))

        return f.read(len_hash)
494
495
def verify_file_hash(filename):
    """
    Check the hash stored in the file against a freshly calculated one.

    When no hash mode is configured, the file must not carry the hash end flag
    and no hash is calculated.

    Args:
        filename (str): Path of the mindrecord file to verify.

    Returns:
        bool, always ``True`` when verification passes or is not enabled.

    Raises:
        RuntimeError: The file does not exist, is not a regular file, its hashed
            state does not match the configured hash mode, or the hash check fails.
    """
    global VERIFY_HASH_TIME

    if not os.path.exists(filename):
        raise RuntimeError("The input: {} is not exists.".format(filename))

    if not os.path.isfile(filename):
        raise RuntimeError("The input: {} should be a regular file.".format(filename))

    # The configured hash mode and the file's hashed state must agree.
    file_is_hashed = get_hash_end_flag(filename) == HASH_END_FLAG
    if _get_hash_mode() is None:
        if file_is_hashed:
            raise RuntimeError("The mindrecord file is hashed. You need to configure " +
                               "'mindspore.mindrecord.config.set_hash_mode(...)' to enable the hash check.")
        return True

    if not file_is_hashed:
        raise RuntimeError("The mindrecord file is not hashed. You can set " +
                           "'mindspore.mindrecord.config.set_hash_mode(None)' to disable the hash check.")

    # Hash value recorded in the file trailer.
    expected_hash = get_hash_value(filename)

    logger.info("Begin to verify the hash of the file: {}.".format(filename))
    begin = time.time()

    # Hash of the file content that precedes the trailer.
    actual_hash = calculate_file_hash(filename, False)

    if expected_hash != actual_hash:
        raise RuntimeError("The input file: " + filename + " hash check fail. The file may be damaged. "
                           "Or configure a correct hash mode.")

    VERIFY_HASH_TIME += time.time() - begin
    if VERIFY_HASH_TIME > WARNING_INTERVAL:
        logger.warning("It takes another " + str(WARNING_INTERVAL) +
                       "s to verify the hash value of the mindrecord file.")
        VERIFY_HASH_TIME = VERIFY_HASH_TIME - WARNING_INTERVAL

    return True
538
539
def encrypt(filename, enc_key, enc_mode):
    """
    Encrypt the file in place; the original content is replaced by the encrypted one.

    The encrypted data is first written to ``filename + ".encrypt"`` and then moved
    over ``filename``. With a built-in mode the output layout is
    ``len+encrypt_data|len+encrypt_data|...|0|enc_mode|ENCRYPT_END_FLAG``.

    Args:
        filename (str): Path of the file to encrypt.
        enc_key (str): Key passed to the encryption routine.
        enc_mode (Union[str, function]): Built-in mode name, or a callable invoked as
            ``enc_mode(f, file_size, f_encrypt, enc_key)`` that writes the encrypted
            data itself.

    Returns:
        bool, always ``True`` on success.

    Raises:
        RuntimeError: The file does not exist, is not a regular file, or a seek fails.
    """
    global ENCRYPT_TIME

    if not os.path.exists(filename):
        raise RuntimeError("The input: {} is not exists.".format(filename))

    if not os.path.isfile(filename):
        raise RuntimeError("The input: {} should be a regular file.".format(filename))

    logger.info("Begin to encrypt file: {}.".format(filename))
    start = time.time()

    chunk_size = 64 * 1024 * 1024    # encrypt the file 64M at a time
    file_size = os.path.getsize(filename)
    encrypt_filename = filename + ".encrypt"

    # Both handles are closed by the context manager on every exit path,
    # including failures while writing the trailer.
    with open(filename, 'rb') as f, open(encrypt_filename, 'wb') as f_encrypt:
        try:
            if callable(enc_mode):
                enc_mode(f, file_size, f_encrypt, enc_key)
            else:
                # original mindrecord file like:
                # |64M|64M|64M|64M|...
                # encrypted mindrecord file like:
                # len+encrypt_data|len+encrypt_data|len+encrypt_data|...|0|enc_mode|ENCRYPT_END_FLAG
                current_offset = 0
                while current_offset < file_size:
                    read_size = min(chunk_size, file_size - current_offset)

                    try:
                        f.seek(current_offset)
                    except Exception as e:  # pylint: disable=W0703
                        raise RuntimeError("Seek the file: {} to position: {} failed. Error: {}"
                                           .format(filename, current_offset, str(e)))

                    data = f.read(read_size)
                    encode_data = _encrypt(data, len(data), enc_key, len(enc_key), enc_mode)

                    # each block is prefixed with its 4-byte big-endian length
                    f_encrypt.write(len(encode_data).to_bytes(length=4, byteorder='big', signed=True))
                    f_encrypt.write(encode_data)

                    current_offset += read_size

            # writing 0 at the end indicates that all encrypted data has been written.
            f_encrypt.write((0).to_bytes(length=4, byteorder='big', signed=True))

            # write enc_mode
            f_encrypt.write(_get_enc_mode_as_str())

            # write ENCRYPT_END_FLAG
            f_encrypt.write(ENCRYPT_END_FLAG)
        except Exception:
            # Restrict the partial output to the owner before propagating the error.
            os.chmod(encrypt_filename, stat.S_IRUSR | stat.S_IWUSR)
            raise

    end = time.time()
    ENCRYPT_TIME += end - start
    if ENCRYPT_TIME > WARNING_INTERVAL:
        logger.warning("It takes another " + str(WARNING_INTERVAL) + "s to encrypt the mindrecord file.")
        ENCRYPT_TIME = ENCRYPT_TIME - WARNING_INTERVAL

    # change the file mode
    os.chmod(encrypt_filename, stat.S_IRUSR | stat.S_IWUSR)

    # move the encrypt file to origin file
    shutil.move(encrypt_filename, filename)

    return True
629
630
def _get_encrypt_end_flag(filename):
    """
    Read the trailing bytes of the file where the encrypt end flag would be stored.

    Args:
        filename (str): Path of the file to inspect.

    Returns:
        bytes, the last ``len(ENCRYPT_END_FLAG)`` bytes of the file. The caller
        compares them with ``ENCRYPT_END_FLAG``.

    Raises:
        RuntimeError: The file does not exist, is not a regular file, or the seek fails.
    """
    if not os.path.exists(filename):
        raise RuntimeError("The input: {} is not exists.".format(filename))

    if not os.path.isfile(filename):
        raise RuntimeError("The input: {} should be a regular file.".format(filename))

    flag_len = len(ENCRYPT_END_FLAG)
    offset = os.path.getsize(filename) - flag_len

    # The context manager closes the file on every exit path, including a failing read.
    with open(filename, 'rb') as f:
        try:
            f.seek(offset)
        except Exception as e:  # pylint: disable=W0703
            raise RuntimeError("Seek the file: {} to position: {} failed. Error: {}".format(filename, offset, str(e)))

        return f.read(flag_len)
656
657
def _get_enc_mode_from_file(filename):
    """
    Read the 7-byte encode mode tag stored just before the encrypt end flag.

    Args:
        filename (str): Path of the encrypted file.

    Returns:
        bytes, the 7 bytes that precede the trailing ``ENCRYPT_END_FLAG``.

    Raises:
        RuntimeError: The file does not exist, is not a regular file, or the seek fails.
    """
    if not os.path.exists(filename):
        raise RuntimeError("The input: {} is not exists.".format(filename))

    if not os.path.isfile(filename):
        raise RuntimeError("The input: {} should be a regular file.".format(filename))

    mode_len = 7  # the enc_mode tag written by encrypt() is always 7 bytes long
    offset = os.path.getsize(filename) - len(ENCRYPT_END_FLAG) - mode_len

    # The context manager closes the file on every exit path, including a failing read.
    with open(filename, 'rb') as f:
        # position on the enc_mode tag
        try:
            f.seek(offset)
        except Exception as e:  # pylint: disable=W0703
            raise RuntimeError("Seek the file: {} to position: {} failed. Error: {}".format(filename, offset, str(e)))

        return f.read(mode_len)
684
685
def decrypt(filename, enc_key, dec_mode):
    """
    Decrypt the file by enc_key and dec_mode.

    The decrypted copy is written to a ``.decrypt_mindrecord`` directory next to
    the (resolved) input file. When no encode key is configured and the file is
    not encrypted, nothing is decrypted and the input path is returned.

    Args:
        filename (str): Path of the mindrecord file.
        enc_key (str): Key passed to the decryption routine.
        dec_mode (Union[str, function]): Built-in mode name, or a callable invoked as
            ``dec_mode(f, file_size, f_decrypt, enc_key)`` that writes the decrypted
            data itself; ``file_size`` excludes the trailing ``ENCRYPT_END_FLAG``.

    Returns:
        str, path of the decrypted file, or `filename` itself when decryption is
        not enabled.

    Raises:
        RuntimeError: The file does not exist, is not a regular file, is too small,
            its encrypted state does not match the configured key, the stored
            encode mode does not match the decode mode, a seek fails, or the data
            cannot be decrypted.
    """
    global DECRYPT_TIME

    if not os.path.exists(filename):
        raise RuntimeError("The input: {} is not exists.".format(filename))

    if not os.path.isfile(filename):
        raise RuntimeError("The input: {} should be a regular file.".format(filename))

    whole_file_size = os.path.getsize(filename)
    if whole_file_size < MIN_FILE_SIZE:
        raise RuntimeError("Invalid file, the size of mindrecord file: " + str(whole_file_size) +
                           " is smaller than the lower limit: " + str(MIN_FILE_SIZE) +
                           ".\n Please check file path: " + filename +
                           " and use 'FileWriter' to generate valid mindrecord files.")

    # check ENCRYPT_END_FLAG: the configured key and the file's encrypted state must agree
    stored_encrypt_end_flag = _get_encrypt_end_flag(filename)
    if _get_enc_key() is not None:
        if stored_encrypt_end_flag != ENCRYPT_END_FLAG:
            raise RuntimeError("The mindrecord file is not encrypted. You can set " +
                               "'mindspore.mindrecord.config.set_enc_key(None)' to disable the decryption.")
    else:
        if stored_encrypt_end_flag == ENCRYPT_END_FLAG:
            raise RuntimeError("The mindrecord file is encrypted. You need to configure " +
                               "'mindspore.mindrecord.config.set_enc_key(...)' and " +
                               "'mindspore.mindrecord.config.set_enc_mode(...)' for decryption.")
        return filename

    # check dec_mode with enc_mode
    enc_mode_from_file = _get_enc_mode_from_file(filename)
    if enc_mode_from_file != _get_dec_mode_as_str():
        raise RuntimeError("Failed to decrypt data, please check if enc_key and enc_mode / dec_mode is valid.")

    logger.info("Begin to decrypt file: {}.".format(filename))
    start = time.time()

    file_size = os.path.getsize(filename) - len(ENCRYPT_END_FLAG)

    real_path_filename = os.path.realpath(filename)
    parent_dir = os.path.dirname(real_path_filename)
    only_filename = os.path.basename(real_path_filename)
    current_decrypt_dir = os.path.join(parent_dir, DECRYPT_DIRECTORY)

    # Try to create the directory and tolerate it already existing, so that
    # concurrent readers do not fail between an existence check and mkdir.
    try:
        os.mkdir(current_decrypt_dir)
    except FileExistsError:
        pass
    else:
        os.chmod(current_decrypt_dir, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
        logger.info("Create directory: {} to store decrypt mindrecord files."
                    .format(os.path.join(parent_dir, DECRYPT_DIRECTORY)))

    if current_decrypt_dir not in DECRYPT_DIRECTORY_LIST:
        DECRYPT_DIRECTORY_LIST.append(current_decrypt_dir)
        logger.warning("The decrypt mindrecord file will be stored in [" + current_decrypt_dir + "] directory. "
                       "If you don't use it anymore after train / eval, you need to delete it manually.")

    # create new decrypt file
    decrypt_filename = os.path.join(current_decrypt_dir, only_filename)
    if os.path.isfile(decrypt_filename):
        # the file which had been decrypted early maybe update by user, so we remove the old decrypted one
        os.remove(decrypt_filename)

    # Both handles are opened only after the directory is prepared and are
    # closed by the context manager on every exit path.
    with open(filename, 'rb') as f, open(decrypt_filename, 'wb+') as f_decrypt:
        try:
            if callable(dec_mode):
                dec_mode(f, file_size, f_decrypt, enc_key)
            else:
                # encrypted mindrecord file like:
                # len+encrypt_data|len+encrypt_data|len+encrypt_data|...|0|enc_mode|ENCRYPT_END_FLAG
                current_offset = 0
                length = int.from_bytes(f.read(4), byteorder='big', signed=True)
                while length != 0:
                    # skip the 4-byte length field: current_offset now points at the encrypted data
                    current_offset += 4
                    try:
                        f.seek(current_offset)
                    except Exception as e:  # pylint: disable=W0703
                        raise RuntimeError("Seek the file: {} to position: {} failed. Error: {}"
                                           .format(filename, current_offset, str(e)))

                    data = f.read(length)
                    decode_data = _decrypt_data(data, len(data), enc_key, len(enc_key), dec_mode)

                    if decode_data is None:
                        raise RuntimeError("Failed to decrypt data, " +
                                           "please check if enc_key and enc_mode / dec_mode is valid.")

                    # write to decrypt file
                    f_decrypt.write(decode_data)

                    # current_offset now points at the length field of the next block
                    current_offset += length
                    try:
                        f.seek(current_offset)
                    except Exception as e:  # pylint: disable=W0703
                        raise RuntimeError("Seek the file: {} to position: {} failed. Error: {}"
                                           .format(filename, current_offset, str(e)))

                    length = int.from_bytes(f.read(4), byteorder='big', signed=True)
        except Exception:
            # Restrict the partial output to the owner before propagating the error.
            os.chmod(decrypt_filename, stat.S_IRUSR | stat.S_IWUSR)
            raise

    end = time.time()
    DECRYPT_TIME += end - start
    if DECRYPT_TIME > WARNING_INTERVAL:
        logger.warning("It takes another " + str(WARNING_INTERVAL) + "s to decrypt the mindrecord file.")
        DECRYPT_TIME = DECRYPT_TIME - WARNING_INTERVAL

    # change the file mode
    os.chmod(decrypt_filename, stat.S_IRUSR | stat.S_IWUSR)

    return decrypt_filename
810