# Copyright 2020-2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Model and parameters serialization."""
from __future__ import absolute_import
from __future__ import division

import binascii
import copy
import json
import os
import shutil
import stat
import threading
from threading import Thread, RLock
from collections import defaultdict, OrderedDict
from io import BytesIO

import math
import sys
import time
import google
import numpy as np

from mindspore.train.checkpoint_pb2 import Checkpoint
from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
from mindspore.train.print_pb2 import Print

import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore import log as logger
from mindspore._checkparam import check_input_data, check_input_dataset
from mindspore import _checkparam as Validator
from mindspore.common import dtype as mstype
from mindspore.common.api import _cell_graph_executor as _executor
from mindspore.common.api import _MindsporeFunctionExecutor
from mindspore.common.api import _get_parameter_layout
from mindspore.common.api import _generate_branch_control_input
from mindspore.common.initializer import initializer, One
from mindspore.common.parameter import Parameter, _offload_if_config
from mindspore.common.tensor import Tensor
from mindspore._c_expression import Tensor as Tensor_
from mindspore.common._utils import is_shape_unknown
from mindspore.common.file_system import FileSystem, _register_basic_file_system, _register_mindio_file_system
from mindspore.communication.management import get_rank, get_group_size
from mindspore.experimental import MapParameter
from mindspore.ops import Cast
from mindspore.parallel._cell_wrapper import get_allgather_cell
from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode
from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \
    _restore_group_info_list
from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \
    _store_warm_up_ptr_by_tensor_list, _cache_enable
from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters
from mindspore.train._utils import read_proto
from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
    split_mindir, split_dynamic_mindir
from mindspore.common.generator import Generator
from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy
from mindspore.parallel.parameter_broadcast import parameter_broadcast
from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs

tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
                     "Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64,
                     "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64,
                     "Bool": mstype.bool_, "str": mstype.string, "BFloat16": mstype.bfloat16, "Int4": mstype.qint4x2}

tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UInt16": np.uint16,
                     "Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64,
                     "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"}

np_type_convert = {"int32": np.int32, "float32": np.float32, "float16": np.float16, "float64": np.float64}

mindir_to_tensor_type = {1: mstype.float32, 2: mstype.uint8, 3: mstype.int8, 4: mstype.uint16,
                         5: mstype.int16, 6: mstype.int32, 7: mstype.int64, 10: mstype.float16,
                         11: mstype.float64, 12: mstype.uint32, 13: mstype.uint64}

_ckpt_mutex = RLock()

# unit is KB
SLICE_SIZE = 512 * 1024
PROTO_LIMIT_SIZE = 1024 * 1024 * 2
TOTAL_SAVE = 1024 * 1024
PARAMETER_SPLIT_SIZE = 1024 * 1024 * 1024
ENCRYPT_BLOCK_SIZE = 64 * 1024
INT_64_MAX = 9223372036854775807

cpu_cast = Cast().set_device("CPU")
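
# A minimal sketch of how cpu_cast is used below (hypothetical values): it routes
# bfloat16 tensors through float32 before asnumpy(), since NumPy has no bfloat16:
#   bf16_tensor = Tensor(np.ones((2, 2), np.float32), mstype.bfloat16)
#   np_value = cpu_cast(bf16_tensor, mstype.float32).asnumpy()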

_ckpt_fs = FileSystem()


def init_ckpt_file_system(fs: FileSystem):
    """Initialize checkpoint file system"""
    if _register_mindio_file_system(fs):
        return
    _register_basic_file_system(fs)


# Initialize checkpoint file system
init_ckpt_file_system(_ckpt_fs)


class ParamDictFuture:
    """Wrap a parameter-dict future and shut down its executor once the result is fetched."""

    def __init__(self, executor, param_dict_future):
        self.executor = executor
        self.param_dict_future = param_dict_future

    def result(self):
        param_dict = self.param_dict_future.result()
        self.executor.shutdown()
        return param_dict


def _special_process_par(par, new_par):
    """
    Handle the special case where shapes differ only by trailing singleton dimensions.

    For example (12, 2048, 1, 1) -> (12, 2048); such shapes are produced by GE 4-dimensional tensors.
    """
    par_shape_len = len(par.data.shape)
    new_par_shape_len = len(new_par.data.shape)
    if new_par_shape_len <= par_shape_len:
        return False

    for i in range(new_par_shape_len - par_shape_len):
        if new_par.data.shape[par_shape_len + i] != 1:
            return False

    if new_par.data.dtype == mstype.bfloat16:
        new_val = cpu_cast(new_par.data, mstype.float32).asnumpy()
    else:
        new_val = new_par.data.asnumpy()

    new_val = new_val.reshape(par.data.shape)
    par.set_data(Tensor(new_val, par.data.dtype))
    return True


def _update_param(param, new_param, strict_load):
    """Updates param's data from new_param's data."""
    if isinstance(param.data, Tensor) and isinstance(new_param.data, Tensor):
        if param.data.shape != new_param.data.shape:
            if not _special_process_par(param, new_param):
                logger.critical("Failed to combine the net and the parameters for param %s.", param.name)
                msg = (f"For 'load_param_into_net', {param.name} in the argument 'net' should have the same shape "
                       f"as {param.name} in the argument 'parameter_dict'. But got its shape {param.data.shape} in"
                       f" the argument 'net' and shape {new_param.data.shape} in the argument 'parameter_dict'. "
                       f"You may need to check whether the checkpoint you loaded is correct and whether the batch "
                       f"size and so on in 'net' and 'parameter_dict' are the same.")
                raise RuntimeError(msg)

        if param.data.dtype != new_param.data.dtype:
            if _type_convert(param, new_param, strict_load):
                if new_param.data.dtype == mstype.bfloat16:
                    new_tensor = cpu_cast(new_param.data, param.data.dtype)
                else:
                    new_tensor = Tensor(new_param.data.asnumpy(), param.data.dtype)
                param.set_data(new_tensor, param.sliced)
                return

            logger.critical("Failed to combine the net and the parameters for param %s.", param.name)
            msg = (f"For 'load_param_into_net', {param.name} in the argument 'net' should have the same type as "
                   f"{param.name} in the argument 'parameter_dict'. But got its type {param.data.dtype} in the "
                   f"argument 'net' and type {new_param.data.dtype} in the argument 'parameter_dict'. "
                   f"You may need to check whether the checkpoint you loaded is correct.")
            raise RuntimeError(msg)

        param.set_data(new_param.data, param.sliced)
        return

    if isinstance(param.data, Tensor) and not isinstance(new_param.data, Tensor):
        if param.data.shape != (1,) and param.data.shape != ():
            logger.critical("Failed to combine the net and the parameters for param %s.", param.name)
            msg = (f"For 'load_param_into_net', {param.name} in the argument 'parameter_dict' is "
                   f"scalar, so the shape of {param.name} in the argument 'net' should be "
                   f"(1,) or (), but got shape {param.data.shape}. "
                   f"You may need to check whether the checkpoint you loaded is correct.")
            raise RuntimeError(msg)
        param.set_data(initializer(new_param.data, param.data.shape, param.data.dtype))

    elif isinstance(new_param.data, Tensor) and not isinstance(param.data, Tensor):
        logger.critical("Failed to combine the net and the parameters for param %s.", param.name)
        msg = (f"For 'load_param_into_net', {param.name} in the argument 'parameter_dict' is Tensor, "
               f"so {param.name} in the argument 'net' should also be Tensor, but got {type(param.data)}. "
               f"You may need to check whether the checkpoint you loaded is correct.")
        raise RuntimeError(msg)

    else:
        param.set_data(type(param.data)(new_param.data))
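
# A minimal sketch of what _update_param does (hypothetical parameters):
#   net_param  = Parameter(Tensor(np.zeros((2, 2), np.float32)), name="w")
#   ckpt_param = Parameter(Tensor(np.ones((2, 2), np.float16)), name="w")
#   _update_param(net_param, ckpt_param, strict_load=False)
# With strict_load=False the float16 data is converted to float32 and copied
# into net_param; with strict_load=True the dtype mismatch raises RuntimeError.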


def _type_convert(param, new_param, strict_load):
    """Whether to convert parameter's type during load checkpoint into network."""
    float_type = (mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16)
    int_type = (mstype.int8, mstype.int16, mstype.int32, mstype.int64)
    if not strict_load and ({param.data.dtype, new_param.data.dtype}.issubset(float_type) or
                            {param.data.dtype, new_param.data.dtype}.issubset(int_type)):
        logger.warning(f"The type of {new_param.name}:{new_param.data.dtype} in 'parameter_dict' is different from "
                       f"its type in 'net':{param.data.dtype}, so it will be converted from "
                       f"{new_param.data.dtype} to {param.data.dtype} in the network.")
        return True
    return False
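
# Conversion is only permitted within the same numeric family; a sketch
# (hypothetical Parameters named fp32_p, fp16_p and int32_p):
#   _type_convert(fp32_p, fp16_p, strict_load=False)   # True: float -> float
#   _type_convert(fp32_p, int32_p, strict_load=False)  # False: crosses families
#   _type_convert(fp32_p, fp16_p, strict_load=True)    # False: strict mode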


def _save_weight(checkpoint_dir, model_name, iteration, params):
    """Save model weight into checkpoint."""
    logger.debug(f"Checkpoint dir is: '{checkpoint_dir}'")
    exist_ckpt_file_list = []
    if os.path.exists(checkpoint_dir):
        for exist_ckpt_name in os.listdir(checkpoint_dir):
            file_prefix = model_name + "_iteration_"
            if exist_ckpt_name.startswith(file_prefix):
                exist_ckpt_file_list.append(exist_ckpt_name)

        param_dict = OrderedDict()
        for key in params.keys():
            value = params[key]
            weight_type = value[0]
            weight_shape = value[1]
            weight_data = value[2]
            weight_size = value[3]
            weight_np = np.array(weight_data, dtype=weight_type.lower())
            logger.debug(f"weight_type: '{weight_type}', weight_shape: '{weight_shape}', weight_size: "
                         f"'{weight_size}', weight_np.nbytes: '{weight_np.nbytes}'")

            param_dict[key] = [weight_shape, weight_type, weight_np]
        ckpt_file_save_name = model_name + "_iteration_" + iteration + ".ckpt"
        ckpt_file_save_path = os.path.join(checkpoint_dir, ckpt_file_save_name)

        _exec_save(ckpt_file_save_path, param_dict)

        for exist_ckpt_name in exist_ckpt_file_list:
            os.remove(os.path.join(checkpoint_dir, exist_ckpt_name))
        logger.info(f"Save weight to checkpoint file path '{ckpt_file_save_path}' success.")
    else:
        logger.warning(f"Checkpoint dir: '{checkpoint_dir}' does not exist.")


def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False, crc_check=False):
    """Execute the process of saving checkpoint into file."""
    try:
        with _ckpt_mutex:
            if os.path.exists(ckpt_file_name):
                os.chmod(ckpt_file_name, stat.S_IWUSR)
                os.remove(ckpt_file_name)
            with _ckpt_fs.create(ckpt_file_name, *_ckpt_fs.create_args) as f:
                plain_data = None
                if enc_key is not None:
                    plain_data = BytesIO()

                crc_num = 0
                for name, value in data_list.items():
                    if name == "random_op":
                        _write_random_seed(name, value, f)
                        continue
                    if value[0] == "mapparameter":
                        _write_mapparameter(name, value, f, map_param_inc)
                        continue
                    if value[0] == "offload_parameter":
                        new_value = value[1:]
                        new_value[2] = value[3]
                        _write_parameter_bytes_data(name, new_value, f, enc_key, plain_data)
                        _offload_if_config(value[3])
                        continue
                    if value[1] == "str":
                        crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
                        continue
                    if isinstance(value[2], np.ndarray):
                        crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
                        continue
                    if isinstance(value[2], Tensor) and hasattr(value[2], "slice_num") and value[2].slice_num > 1:
                        _write_hugeparameter(name, value, f)
                        continue

                    crc_num = _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num, crc_check)

                if enc_key is not None:
                    plain_data.seek(0)
                    max_block_size = ENCRYPT_BLOCK_SIZE * 1024
                    block_data = plain_data.read(max_block_size)
                    while block_data:
                        f.write(_encrypt(block_data, len(block_data), enc_key, len(enc_key), enc_mode))
                        block_data = plain_data.read(max_block_size)

                if crc_check:
                    f.write('crc_num'.encode() + crc_num.to_bytes(10, byteorder='big'))

            os.chmod(ckpt_file_name, stat.S_IRUSR)

    except BaseException as e:
        logger.critical("Failed to save the checkpoint file %s. This may be because there is no permission to "
                        "write the file, the disk space is insufficient, and so on.", ckpt_file_name)
        raise e
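
# A minimal sketch of the data_list layout that _exec_save consumes (the same
# [dims, dtype_name, data] triples that save_checkpoint assembles below);
# names and values here are hypothetical:
#   data_list = OrderedDict()
#   data_list["conv1.weight"] = [[6, 1, 5, 5], "Float32",
#                                np.ones((6, 1, 5, 5), np.float32)]
#   _exec_save("demo.ckpt", data_list)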


def _write_random_seed(name, value, f):
    """Write random op into protobuf file."""
    checkpoint_list = Checkpoint()
    param_value = checkpoint_list.value.add()
    param_value.tag = name
    param_tensor = param_value.tensor
    param_tensor.dims.extend([0])
    param_tensor.tensor_type = "random_op"
    param_tensor.tensor_content = value
    f.write(checkpoint_list.SerializeToString())


def _write_parameter_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False):
    """Write parameter data into protobuf file."""
    data_size = value[2].nbytes / 1024
    if data_size > SLICE_SIZE:
        slice_count = math.ceil(data_size / SLICE_SIZE)
        param_slice_list = np.array_split(value[2], slice_count)
    else:
        param_slice_list = [value[2]]

    for param_slice in param_slice_list:
        checkpoint_list = Checkpoint()
        param_value = checkpoint_list.value.add()
        param_value.tag = name
        param_tensor = param_value.tensor
        param_tensor.dims.extend(value[0])
        param_tensor.tensor_type = value[1]
        param_tensor.tensor_content = param_slice.tobytes()

        if enc_key is None:
            output_data = checkpoint_list.SerializeToString()
            if crc_check:
                crc_num = binascii.crc32(output_data, crc_num)
            f.write(output_data)
        else:
            plain_data.write(checkpoint_list.SerializeToString())

    return crc_num
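
# Worked example of the slicing rule above: SLICE_SIZE is measured in KB, so
# the threshold is 512 * 1024 KB = 512 MB per protobuf message. A 1.5 GB
# float32 tensor gives data_size = 1572864 KB, and
# math.ceil(1572864 / 524288) = 3 slices, each written as its own Checkpoint
# message under the same tag.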


def _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False):
    """Write parameter bytes data into protobuf file."""
    bytes_value = value[2].get_bytes()
    chunk_size = 1024 * SLICE_SIZE

    for i in range(0, len(bytes_value), chunk_size):
        checkpoint_list = Checkpoint()
        param_value = checkpoint_list.value.add()
        param_value.tag = name
        param_tensor = param_value.tensor
        param_tensor.dims.extend(value[0])
        param_tensor.tensor_type = value[1]
        param_tensor.tensor_content = bytes_value[i:i + chunk_size]

        if enc_key is None:
            output_data = checkpoint_list.SerializeToString()
            if crc_check:
                crc_num = binascii.crc32(output_data, crc_num)
            f.write(output_data)
        else:
            plain_data.write(checkpoint_list.SerializeToString())

    return crc_num


def _write_mapparameter(name, value, f, map_param_inc=False):
    """Write map parameter into protobuf file."""
    while True:
        logger.info("Checkpoint save map_parameter.")
        data_map_slice = value[1].export_slice_data(map_param_inc)
        checkpoint_list = Checkpoint()
        param_value = checkpoint_list.value.add()
        param_value.tag = name
        map_tensor = param_value.maptensor
        for numpy_data in data_map_slice[:3]:
            tensor_pro = map_tensor.tensor.add()
            tensor_pro.dims.extend(numpy_data.shape)
            tensor_pro.tensor_type = str(numpy_data.dtype)
            tensor_pro.tensor_content = numpy_data.reshape(-1).tobytes()
        f.write(checkpoint_list.SerializeToString())
        if data_map_slice[3]:
            break


def _write_hugeparameter(name, value, f):
    """Write huge parameter into protobuf file."""
    slice_num = value[2].slice_num
    offset = 0
    max_size = value[0][0]
    for param_slice in range(slice_num):
        checkpoint_list = Checkpoint()
        param_value = checkpoint_list.value.add()
        param_value.tag = name
        param_tensor = param_value.tensor
        param_tensor.dims.extend(value[0])
        param_tensor.tensor_type = value[1]
        param_key = value[3]
        numpy_data = value[2].asnumpy_of_slice_persistent_data(param_key, param_slice)
        if offset + numpy_data.shape[0] > max_size:
            numpy_data = numpy_data[:max_size - offset]
        param_tensor.tensor_content = numpy_data.tobytes()
        f.write(checkpoint_list.SerializeToString())
        offset += numpy_data.shape[0]


def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
    """Check save_obj and ckpt_file_name for save_checkpoint."""
    if not isinstance(save_obj, (nn.Cell, list, dict)):
        raise TypeError("For 'save_checkpoint', the parameter 'save_obj' must be nn.Cell, list or dict, "
                        "but got {}.".format(type(save_obj)))
    if not isinstance(ckpt_file_name, str):
        raise TypeError("For 'save_checkpoint', the parameter {} for checkpoint file name is invalid, "
                        "'ckpt_file_name' must be "
                        "string, but got {}.".format(ckpt_file_name, type(ckpt_file_name)))
    ckpt_file_name = os.path.abspath(ckpt_file_name)
    if os.path.isdir(ckpt_file_name):
        raise IsADirectoryError("For 'save_checkpoint', the parameter `ckpt_file_name`: {} is a directory, "
                                "it must be a file name.".format(ckpt_file_name))
    if not ckpt_file_name.endswith('.ckpt'):
        ckpt_file_name += ".ckpt"
    return ckpt_file_name
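
# Behavior sketch (hypothetical path): the check normalizes the name and adds
# the extension, e.g.
#   _check_save_obj_and_ckpt_file_name(net, "model")
# returns an absolute path like "/current/dir/model.ckpt", while passing a
# directory path raises IsADirectoryError.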


def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
                    async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None,
                    crc_check=False, **kwargs):
    r"""
    Save checkpoint to a specified file.

    Note:
        The `enc_mode` and `crc_check` parameters are mutually exclusive and cannot be configured simultaneously.

    Args:
        save_obj (Union[Cell, list, dict]): The object to be saved. The data type can be :class:`mindspore.nn.Cell`,
            list, or dict. If a list, it can be the returned value of `Cell.trainable_params()`, or a list of dict
            elements(each element is a dictionary, like [{"name": param_name, "data": param_data},...], the type of
            `param_name` must be string, and the type of `param_data` must be parameter or Tensor); If dict,
            it can be the returned value of `mindspore.load_checkpoint()`.
        ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
        integrated_save (bool): Whether to merge and save split parameters in the automatic model parallel
            scenario. Default: ``True`` .
        async_save (bool): Whether to open an independent thread to save the checkpoint file. Default: ``False`` .
        append_dict (dict): Additional information that needs to be saved. The key of dict must be str, the value
                            of dict must be one of int, float, bool, string, Parameter or Tensor. Default: ``None`` .
        enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is ``None`` , the encryption
                                      is not required. Default: ``None`` .
        enc_mode (str): This parameter is valid only when enc_key is not set to ``None`` . Specifies the encryption
                        mode, currently supports ``"AES-GCM"`` , ``"AES-CBC"`` and ``"SM4-CBC"`` .
                        Default: ``"AES-GCM"`` .
        choice_func (function) : A function for saving custom selected parameters. The input value of `choice_func`
                                 is a parameter name in string type, and the returned value is a bool.
                                 If it returns ``True`` , the parameter matching the custom condition will be saved;
                                 if it returns ``False`` , the parameter will not be saved. Default: ``None`` .
        crc_check (bool) : Whether to perform crc32 calculation when saving checkpoint and save the calculation
            result to the file. Default: ``False`` .
        kwargs (dict): Configuration options dictionary.

    Raises:
        TypeError: If the parameter `save_obj` is not :class:`mindspore.nn.Cell` , list or dict type.
        TypeError: If the parameter `integrated_save` or `async_save` is not bool type.
        TypeError: If the parameter `ckpt_file_name` is not string type.

    Examples:
        >>> import mindspore as ms
        >>>
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> ms.save_checkpoint(net, "./lenet.ckpt",
        ...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
        >>> param_dict1 = ms.load_checkpoint("./lenet.ckpt")
        >>> print(param_dict1)
        {'conv2.weight': Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)}
        >>> params_list = net.trainable_params()
        >>> ms.save_checkpoint(params_list, "./lenet_list.ckpt",
        ...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv2"))
        >>> param_dict2 = ms.load_checkpoint("./lenet_list.ckpt")
        >>> print(param_dict2)
        {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
        >>> ms.save_checkpoint(param_dict2, "./lenet_dict.ckpt")
        >>> param_dict3 = ms.load_checkpoint("./lenet_dict.ckpt")
        >>> print(param_dict3)
        {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}

    Tutorial Examples:
        - `Saving and Loading the Model - Saving and Loading the Model Weight
          <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
    """
    ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name)
    integrated_save = Validator.check_bool(integrated_save)
    async_save = Validator.check_bool(async_save)
    append_dict = _check_append_dict(append_dict)
    enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
    enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
    crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
    map_param_inc = kwargs.get('incremental', False)
    logger.info("Execute the process of saving checkpoint files.")
    global_step_num = kwargs.get('global_step_num', None)

    save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)

    if append_dict:
        append_info_list = []
        for k_name, value in append_dict.items():
            if isinstance(value, Generator):
                value = value.get_state()
            elif not isinstance(value, str):
                value = Tensor(value)
            append_info_list.append({"name": k_name, "data": value})
        save_obj.extend(append_info_list)

    data_list = OrderedDict()
    data_list_np = OrderedDict()
    with _ckpt_mutex:
        for param in save_obj:
            if param["name"] == "random_op":
                if os.getenv("AITURBO") == "1":
                    data_list_np["random_op"] = param["data"]
                else:
                    data_list["random_op"] = param["data"]
                continue
            key = param["name"]
            data_list[key] = []
            if isinstance(param["data"], MapParameter):
                data_list[param["name"]].append("mapparameter")
                data_list[param["name"]].append(param["data"])
                continue
            if isinstance(param["data"], list):
                if param["data"][0] == "persistent_data":
                    _save_param_list_data(data_list, key, param)
                elif param["data"][0] == "offload_parameter":
                    data_list[key].append("offload_parameter")
                    _save_param_list_data(data_list, key, param)

            if isinstance(param["data"], str):
                if os.getenv("AITURBO") == "1":
                    data_list_np[key] = np.array(param["data"])
                else:
                    data_list[key].append([0])
                    data_list[key].append('str')
                    data = np.array(param["data"])
                    data_list[key].append(data)
            else:
                if isinstance(param["data"], Parameter):
                    param["data"].init_data()
                if os.getenv("AITURBO") == "1":
                    data_list_np[key] = param["data"].asnumpy()
                else:
                    dims = []
                    for dim in param['data'].shape:
                        dims.append(dim)
                    data_list[key].append(dims)
                    tensor_type = str(param["data"].dtype)
                    data_list[key].append(tensor_type)
                    data = param["data"]
                    data_list[key].append(data)

    if os.getenv("AITURBO") == "1":
        import aiturbo
        ckpt_name = os.path.basename(ckpt_file_name)
        aiturbo.save_ckpt(ckpt_name, global_step_num, data_list_np)
    elif async_save:
        data_copy = copy.deepcopy(data_list)
        thr = Thread(target=_exec_save, args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check),
                     name="asyn_save_ckpt")
        thr.start()
    else:
        _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check)

    logger.info("Saving checkpoint process is finished.")


def _convert_list_to_param_list(save_obj, choice_func):
    """Convert a list of Parameters to a param_list."""
    param_list = []
    if not save_obj:
        return param_list
    if isinstance(save_obj[0], dict):
        for param in save_obj:
            if isinstance(param, dict) and "name" in param and "data" in param:
                if not isinstance(param["name"], str):
                    raise TypeError(f"For save_checkpoint, when save_obj is a list of dict items, the name in dict "
                                    f"should be string, but got {type(param['name'])}.")
                if not isinstance(param["data"], Tensor):
                    raise TypeError(f"For save_checkpoint, when save_obj is a list of dict items, the data in dict "
                                    f"should be Parameter or Tensor, but got {type(param['data'])}.")
                if choice_func is not None and not choice_func(param["name"]):
                    continue
                each_param = {"name": param["name"], "data": param["data"]}
                param_list.append(each_param)
            else:
                raise TypeError(f"For save_checkpoint, save_obj should be a list of dict items, and the dict should "
                                f"have the keys 'name' and 'data', but got {type(param)} and {param}.")
    else:
        for param in save_obj:
            if isinstance(param, Parameter):
                if choice_func is not None and not choice_func(param.name):
                    continue
                each_param = {"name": param.name, "data": param}
                param_list.append(each_param)
            else:
                raise TypeError(f"For save_checkpoint, when save_obj is made up of a list of Parameters, "
                                f"each element should be a Parameter, but got {type(param)}")
    return param_list


def _convert_dict_to_param_dict(save_obj, choice_func):
    """Convert a dict of Parameters to a param_list."""
    param_list = []
    for (key, value) in save_obj.items():
        if isinstance(key, str) and isinstance(value, (Parameter, str)):
            if choice_func is not None and not choice_func(key):
                continue
            each_param = {"name": key, "data": value}
            param_list.append(each_param)
        else:
            raise TypeError(f"For save_checkpoint, when save_obj is made up of a dict, the key should be str and "
                            f"the value should be Parameter, but got the type of key {type(key)} and "
                            f"the type of value {type(value)}")
    return param_list


def _convert_cell_param_and_names_to_dict(save_obj, choice_func):
    """Convert cell.parameters_and_names to OrderedDict."""
    param_dict = OrderedDict()
    for _, param in save_obj.parameters_and_names():
        not_sliced = not param.sliced
        is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
        # All parameters are initialized immediately under PyNative mode, skip this judgement.
        judgment = not_sliced or param.has_init
        if is_graph_mode and _is_in_auto_parallel_mode() and judgment:
            continue
        if choice_func is not None and not choice_func(param.name):
            continue
        # Add suffix for cache_enabled parameter, so that the parameter can carry key info.
        # Notice that the suffix needs to be removed when loading into net.
        if param.cache_enable:
            param_dict[param.name + ".__param_key__" + str(param.key)] = param
        else:
            param_dict[param.name] = param
    return param_dict


def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func):
    """Convert nn.Cell to param_list."""
    sync_pipeline_shared_parameters(save_obj)
    param_list = []
    parameter_layout_dict = save_obj.parameter_layout_dict
    if _is_in_auto_parallel_mode() and not parameter_layout_dict:
        parameter_layout_dict = _get_parameter_layout()
    if not _is_in_auto_parallel_mode():
        save_obj.init_parameters_data()
    param_dict = _convert_cell_param_and_names_to_dict(save_obj, choice_func)
    if append_dict and "random_op" in append_dict:
        phase = 'train' + '.' + str(save_obj.create_time) + '.' + str(id(save_obj)) + '.' + save_obj.arguments_key
        if phase in save_obj.compile_cache and _executor.has_compiled(phase):
            random_byte = _executor._graph_executor.get_random_status(phase)
            param_list.append({"name": "random_op", "data": random_byte})
            append_dict.pop("random_op")
    for (key, value) in param_dict.items():
        each_param = {"name": key}
        if isinstance(value, MapParameter):
            each_param["data"] = value
            param_list.append(each_param)
            continue

        if value.data.is_persistent_data():
            # persistent data is saved as a list: ["persistent_data", Tensor, shape, type, param.key]
            param_data = ["persistent_data", value.data, value.param_info.origin_shape, str(value.dtype), value.key]
        elif value.data.offload_file_path() != "":
            # offload data is saved as a list: ["offload_parameter", Tensor, shape, type, param.key]
            param_data = ["offload_parameter"]
            param_tensor = value.data
            if key in parameter_layout_dict:
                param_tensor = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_tensor,
                                                      integrated_save)
            param_data.append(param_tensor)
            param_data.append(param_tensor.shape)
            param_data.append(str(param_tensor.dtype))
            param_data.append(value.key)
        else:
            param_data = value.data

            # in automatic model parallel scenario, some parameters were split to all the devices,
            # which should be combined before saving
            if key in parameter_layout_dict:
                param_data = Tensor(value.data)
                param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data,
                                                    integrated_save)

        each_param["data"] = param_data
        param_list.append(each_param)
    return param_list


def _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func):
    """Convert a save_obj to param_list."""
    if isinstance(save_obj, list):
        return _convert_list_to_param_list(save_obj, choice_func)

    if isinstance(save_obj, dict):
        return _convert_dict_to_param_dict(save_obj, choice_func)

    return _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func)


def _save_param_list_data(data_list, key, param):
    """Save persistent data into save_obj."""
    dims = []
    # persistent_data shape can not be ()
    for dim in param['data'][2]:
        dims.append(dim)
    data_list[key].append(dims)
    data_list[key].append(param['data'][3])
    data_list[key].append(param['data'][1])
    data_list[key].append(param['data'][4])


def _check_append_dict(append_dict):
    """Check the argument append_dict for save_checkpoint."""
    if append_dict is None:
        return append_dict
    if not isinstance(append_dict, dict):
        raise TypeError("For 'save_checkpoint', the argument 'append_dict' must be dict, but got "
                        "{}.".format(type(append_dict)))
    for key, value in append_dict.items():
        if not isinstance(key, str) or not isinstance(value, (int, float, bool, str, Parameter, Tensor, Generator)):
            raise TypeError(f"For 'save_checkpoint', the type of dict 'append_info' must be key: string, "
                            f"value: int, float, bool, string, Parameter, Tensor or Generator, "
                            f"but got key: {type(key)}, value: {type(value)}")
    return append_dict


def _check_load_obfuscate(**kwargs):
    """Check whether an obfuscated model is being loaded and register its customized function."""
    if 'obf_func' in kwargs.keys():
        customized_func = _check_customized_func(kwargs.get('obf_func'))
        clean_funcs()
        add_opaque_predicate(customized_func.__name__, customized_func)
        return True
    return False


def load(file_name, **kwargs):
    """
    Load MindIR.

    The returned object can be executed by a `GraphCell`, see class :class:`mindspore.nn.GraphCell` for more details.

    Args:
        file_name (str): MindIR file name.

        kwargs (dict): Configuration options dictionary.

            - dec_key (bytes): Byte-type key used for decryption. The valid length is 16, 24, or 32.
            - dec_mode (Union[str, function]): Specifies the decryption mode, to take effect when dec_key is set.

              - Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC' or customized decryption. Default: ``'AES-GCM'``.
              - For details of using the customized decryption, please check the `tutorial
                <https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_.

            - obf_func (function): A python function used for loading obfuscated MindIR model, which can refer to
              `obfuscate_model()
              <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.obfuscate_model.html>`_.

    Returns:
        GraphCell, a compiled graph that can be executed by `GraphCell`.

    Raises:
        ValueError: MindIR file does not exist or `file_name` is not a string.
        RuntimeError: Failed to parse MindIR file.

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>> from mindspore import context
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>>
        >>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones")
        >>> input_tensor = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        >>> ms.export(net, input_tensor, file_name="net", file_format="MINDIR")
        >>> graph = ms.load("net.mindir")
        >>> net = nn.GraphCell(graph)
        >>> output = net(input_tensor)
        >>> print(output)
        [[[[4. 6. 4.]
           [6. 9. 6.]
           [4. 6. 4.]]]]

    Tutorial Examples:
        - `Saving and Loading the Model - Saving and Loading MindIR
          <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_
    """
    if not isinstance(file_name, str):
        raise ValueError("For 'load', the argument 'file_name' must be string, but "
                         "got {}.".format(type(file_name)))
    if not file_name.endswith(".mindir"):
        raise ValueError("For 'load', the argument 'file_name'(MindIR file) should end with '.mindir', "
                         "please input the correct 'file_name'.")
    if not os.path.exists(file_name):
        raise ValueError("For 'load', the argument 'file_name'(MindIR file) does not exist, "
                         "please check whether the 'file_name' is correct.")
    file_name = os.path.abspath(file_name)

    # set customized functions for dynamic obfuscation
    obfuscated = _check_load_obfuscate(**kwargs)

    logger.info("Execute the process of loading mindir.")
    if 'dec_key' in kwargs.keys():
        dec_key = Validator.check_isinstance('dec_key', kwargs.get('dec_key'), bytes)
        dec_mode = "AES-GCM"
        dec_func = None
        if 'dec_mode' in kwargs.keys():
            if callable(kwargs.get('dec_mode')):
                dec_mode = "Customized"
                dec_func = kwargs.get('dec_mode')
            else:
                dec_mode = Validator.check_isinstance('dec_mode', kwargs.get('dec_mode'), str)
        graph = load_mindir(file_name, dec_key=dec_key, key_len=len(dec_key), dec_mode=dec_mode,
                            decrypt=dec_func, obfuscated=obfuscated)
    else:
        graph = load_mindir(file_name, obfuscated=obfuscated)

    if graph is None:
        if _is_cipher_file(file_name):
            raise RuntimeError("Load MindIR failed. The file may be encrypted and decryption failed; you "
                               "can check whether the values of the arguments 'dec_key' and 'dec_mode'"
                               " are the same as those used when the MindIR file was exported, or check "
                               "the file integrity.")
        raise RuntimeError("Load MindIR failed.")
    return graph


def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=True):
    """
    Auto Split MindIR.

    The returned object can be executed by a `GraphCell`, see class :class:`mindspore.nn.GraphCell` for more details.

    Args:
        file_name (str): MindIR file name.
        device_num (int): Device number. Default: ``8``.
        rank_id (int): Rank id. Default: ``0``.
        dynamic (bool): Indicates whether the model is a dynamic shape MindIR model. Default: ``True``.
        sapp (bool): Indicates whether to automatically generate the split strategy through SAPP. Default: ``True``.

    Raises:
        ValueError: MindIR file does not exist or `file_name` is not a string.
        RuntimeError: Failed to split MindIR file.

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import context
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>>
        >>> ms.export_split_mindir("net.mindir", device_num=8, rank_id=0)

    """
    if not isinstance(file_name, str):
        raise ValueError("For 'Split MindIR', the argument 'file_name' must be string, but "
                         "got {}.".format(type(file_name)))
    if not file_name.endswith(".mindir"):
        raise ValueError("For 'Split MindIR', the argument 'file_name'(MindIR file) should end with '.mindir', "
                         "please input the correct 'file_name'.")
    if not os.path.exists(file_name):
        raise ValueError("For 'Split MindIR', the argument 'file_name'(MindIR file) does not exist, "
                         "please check whether the 'file_name' is correct.")
    file_name = os.path.abspath(file_name)

    logger.info("Execute the process of export and split mindir.")
    if dynamic:
        graph = split_dynamic_mindir(file_name, device_num, rank_id, sapp)
    else:
        graph = split_mindir(file_name)

    if graph is None:
        if _is_cipher_file(file_name):
            raise RuntimeError("Export and split MindIR failed. The file may be encrypted and decryption failed; "
                               "you can check whether the values of the arguments 'dec_key' and 'dec_mode'"
                               " are the same as those used when the MindIR file was exported, or check the "
                               "file integrity.")
        raise RuntimeError("Export and split MindIR failed.")
    return graph


def _check_param_type(param_config, key, target_type, requested):
    """check type of parameters"""
    if key in param_config:
        if not isinstance(param_config[key], target_type):
            raise TypeError("The type of {} must be {}, but got {}.".format(key, target_type, type(param_config[key])))
        if key == 'obf_random_seed':
            if param_config[key] > INT_64_MAX or param_config[key] <= 0:
                raise ValueError(
                    "'obf_random_seed' must be in (0, INT_64_MAX({})], but got {}.".format(INT_64_MAX,
                                                                                           param_config[key]))
        return param_config[key]
    if requested:
        raise ValueError("The parameter {} is required, but was not provided.".format(key))
    if key == "obf_random_seed":
        return 0
    return None


def _check_customized_func(customized_func):
    """ check customized function of dynamic obfuscation """
    if not callable(customized_func):
        raise TypeError(
            "'customized_func' must be a function, but got {}.".format(type(customized_func)))
    # test customized_func
    try:
        func_result = customized_func(1.0, 1.0)
    except Exception as ex:
        raise TypeError("customized_func must be a function with two inputs, but got exception: {}".format(ex))
    else:
        if not isinstance(func_result, bool):
            raise TypeError("Return value of customized_func must be boolean, but got: {}".format(type(func_result)))
    return customized_func


def _check_obfuscate_params(obf_config):
    """Check obfuscation parameters, including obf_random_seed, obf_ratio, customized_func"""
    if 'obf_random_seed' not in obf_config.keys() and 'customized_func' not in obf_config.keys():
        raise ValueError(
            "At least one of 'obf_random_seed' or 'customized_func' must be set in obf_config, but neither was set.")
    obfuscate_type = _check_param_type(obf_config, "type", str, False)
    if obfuscate_type not in (None, "dynamic"):
        raise ValueError("Only 'dynamic' type is supported by now, but got {}.".format(obfuscate_type))
    if ('obf_ratio' in obf_config) and isinstance(obf_config['obf_ratio'], str):
        if obf_config['obf_ratio'] not in ["small", "medium", "large"]:
            raise ValueError("'obf_ratio' can only be 'small', 'medium', 'large' or float, but got {}.".format(
                obf_config['obf_ratio']))
        ratio_dict = {"small": 0.1, "medium": 0.3, "large": 0.6}
        obf_config['obf_ratio'] = ratio_dict.get(obf_config['obf_ratio'])
    obf_ratio = _check_param_type(obf_config, "obf_ratio", float, True)
    if (obf_ratio <= 0) or (obf_ratio > 1):
        raise ValueError("'obf_ratio' must be in (0, 1] if it is a float, but got {}.".format(obf_config['obf_ratio']))
    customized_funcs = []
    if 'customized_func' in obf_config.keys():
        device_target = context.get_context('device_target')
        if device_target in ["GPU", "Ascend"]:
            raise ValueError(
                "Customized function mode only supports 'device_target'='CPU', but got {}.".format(device_target))
        customized_funcs.append(_check_customized_func(obf_config['customized_func']))
    obf_random_seed = _check_param_type(obf_config, "obf_random_seed", int, False)
    return obf_ratio, customized_funcs, obf_random_seed
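
# A sketch of a valid obf_config for the check above (hypothetical values):
#   obf_config = {"obf_ratio": "small", "obf_random_seed": 173262358423}
#   obf_ratio, funcs, seed = _check_obfuscate_params(obf_config)
# "small" is mapped to 0.1, funcs stays empty because no customized_func is
# given, and seed is returned unchanged after the range check.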


def obfuscate_model(obf_config, **kwargs):
    """
    Obfuscate a model of MindIR format. Obfuscation means changing the structure of a network without affecting its
    prediction correctness. The obfuscated model can prevent attackers from stealing the model.

    Args:
        obf_config (dict): obfuscation config.

            - type (str): The type of obfuscation, only 'dynamic' is supported until now.
            - original_model_path (str): The path of the MindIR format model that needs to be obfuscated. If the
              original model is encrypted, then enc_key and enc_mode should be provided.
            - save_model_path (str): The path to save the obfuscated model.
            - model_inputs (list(Tensor)): The inputs of the original model, the values of Tensor can be random, which
              is the same as using :func:`mindspore.export`.
            - obf_ratio (Union(float, str)): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
              should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large"
              correspond to 0.1, 0.3, and 0.6 respectively.
            - customized_func (function): A python function used for customized function mode, which is used to
              control the switch branches of the obfuscation structure. The outputs of customized_func should be
              boolean and const (Reference to 'my_func()' in
              `tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_).
              This function needs to ensure that its result is constant for any input. Users can refer to opaque
              predicates. If customized_func is set, then it should be passed to :func:`mindspore.load` interface
              when loading obfuscated model.
            - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
              structure of obfuscated models corresponding to different random seeds is different. If
              `obf_random_seed` is set, then it should be passed to :class:`mindspore.nn.GraphCell`
              interface when loading
              obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
              be set, and the latter mode would be applied if both of them are set.

        kwargs (dict): Configuration options dictionary.

            - enc_key (bytes): Byte type key used for encryption. The valid length is 16, 24, or 32.
            - enc_mode (str): Specifies the encryption mode, to take effect when enc_key is set.
              Options: ``'AES-GCM'`` | ``'AES-CBC'`` | ``'SM4-CBC'``. Default: ``'AES-GCM'``.

    Raises:
        TypeError: If `obf_config` is not a dict.
        ValueError: If `enc_key` is passed and `enc_mode` is not in ["AES-GCM", "AES-CBC", "SM4-CBC"].
        ValueError: If `original_model_path` is not provided in `obf_config`.
        ValueError: If the model saved in `original_model_path` has been obfuscated.
        ValueError: If `save_model_path` is not provided in `obf_config`.
        ValueError: If `obf_ratio` is not provided in `obf_config`.
        ValueError: If both `customized_func` and `obf_random_seed` are not provided in `obf_config`.
        ValueError: If `obf_random_seed` is not in (0, 9223372036854775807].
        ValueError: If `original_model_path` does not exist or `original_model_path` does not end with '.mindir'.

    Examples:
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
        >>> import numpy as np
        >>> # Download ori_net.mindir
        >>> # https://gitee.com/mindspore/mindspore/blob/master/tests/ut/python/mindir/ori_net.mindir
        >>> input1 = ms.Tensor(np.ones((1, 1, 32, 32)).astype(np.float32))
        >>> obf_config = {'original_model_path': "./net.mindir",
        ...          'save_model_path': "./obf_net",
        ...          'model_inputs': [input1, ],
        ...          'obf_ratio': 0.1, 'obf_random_seed': 173262358423}
        >>> ms.obfuscate_model(obf_config)
        >>> obf_func = ms.load("obf_net.mindir")
        >>> obf_net = nn.GraphCell(obf_func, obf_random_seed=173262358423)
        >>> print(obf_net(input1).asnumpy())
    """
    if not isinstance(obf_config, dict):
        raise TypeError("'obf_config' must be a dict, but got {}.".format(type(obf_config)))
    file_path = _check_param_type(obf_config, "original_model_path", str, True)
    if not file_path.endswith(".mindir"):
        raise ValueError("For 'obfuscate_model', the argument 'file_path'(MindIR file) should end with '.mindir', "
                         "please input the correct 'file_path'.")
    if not os.path.exists(file_path):
        raise ValueError("For 'obfuscate_model', the argument 'file_path'(MindIR file) does not exist, "
                         "please check whether the 'file_path' is correct.")
    saved_path = _check_param_type(obf_config, "save_model_path", str, True)
    model_inputs = _check_param_type(obf_config, "model_inputs", list, True)
    for item in model_inputs:
        if not isinstance(item, Tensor):
            raise TypeError("The item in 'model_inputs' must be Tensor, but got {}.".format(type(item)))
        if -1 in item.shape:
            raise ValueError(
                "Dynamic shape input is not supported now, but got the shape of inputs: {}.".format(item.shape))
    obf_ratio, customized_funcs, obf_random_seed = _check_obfuscate_params(obf_config)
    if customized_funcs and obf_random_seed > 0:
        logger.warning("Although 'customized_func' and 'obf_random_seed' are set, the 'obf_random_seed' mode would be"
                       " applied, remember to set 'obf_random_seed' when loading obfuscated model.")

    if obf_random_seed == 0:  # apply customized_func mode
        clean_funcs()
        for func in customized_funcs:
            add_opaque_predicate(func.__name__, func)
        branch_control_input = 0
    else:  # apply password mode
        branch_control_input = _generate_branch_control_input(obf_random_seed)

    if 'enc_key' in kwargs.keys():
        enc_key = Validator.check_isinstance('enc_key', kwargs.get('enc_key'), bytes)
        enc_mode = "AES-GCM"
        if 'enc_mode' in kwargs.keys():
            enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
            if enc_mode not in ["AES-GCM", "AES-CBC", "SM4-CBC"]:
                raise ValueError(
                    "Only MindIR files encrypted with 'AES-GCM', 'AES-CBC' or 'SM4-CBC' are supported for "
                    "obfuscate_model(), but got {}.".format(enc_mode))
        obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio,
                                             branch_control_input=branch_control_input, dec_key=enc_key,
                                             key_len=len(enc_key),
                                             dec_mode=enc_mode)
    else:
        obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio,
                                             branch_control_input=branch_control_input)

    obf_net = nn.GraphCell(obf_graph)
    if obf_random_seed != 0:
        append_y_tensor = Tensor(np.ones((1, 1)).astype(np.int32))
        model_inputs += [append_y_tensor]
    export(obf_net, *model_inputs, file_name=saved_path, file_format="MINDIR", **kwargs)


1076def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
1077                          dec_mode, crc_check):
1078    """load parameter into parameter_dict"""
1079    ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
1080    checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check)
1081    try:
1082        param_data_list = []
1083        map_data_list = [[], [], []]
1084        map_shape_list = [0, 0, 0]
        if specify_prefix:
            logger.warning("For load_checkpoint, this parameter `specify_prefix` will be deprecated, "
                           "please use `choice_func` instead.")
1088        if filter_prefix:
1089            logger.warning("For load_checkpoint, this parameter `filter_prefix` will be deprecated, "
1090                           "please use `choice_func` instead.")
1091        for element_id, element in enumerate(checkpoint_list.value):
1092            if element.tag == "random_op":
1093                parameter_dict["random_op"] = element.tensor.tensor_content
1094                continue
1095            if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
1096                continue
1097            if specify_prefix is None and filter_prefix is None and \
1098                    choice_func is not None and not choice_func(element.tag):
1099                continue
1100            if element.tensor.ByteSize() == 0:
1101                _load_map_parameter(checkpoint_list, element, element_id, map_data_list, map_shape_list,
1102                                    parameter_dict)
1103                if element.tag in parameter_dict:
1104                    map_data_list = [[], [], []]
1105                    map_shape_list = [0, 0, 0]
1106                continue
1107            data = element.tensor.tensor_content
1108            data_type = element.tensor.tensor_type
1109            np_type = tensor_to_np_type.get(data_type)
1110            ms_type = tensor_to_ms_type[data_type]
1111            if data_type == 'str':
1112                str_length = int(len(data) / 4)
1113                np_type = np_type + str(str_length)
1114            param_data_list.append(data)
1115            if (element_id == len(checkpoint_list.value) - 1) or \
1116                    (element.tag != checkpoint_list.value[element_id + 1].tag):
1117                new_data = b"".join(param_data_list)
1118                param_data_list.clear()
1119                dims = element.tensor.dims
1120                if data_type == 'str':
1121                    str_value = np.frombuffer(new_data, np_type)
1122                    parameter_dict[element.tag] = str(str_value[0])
1123                else:
1124                    if dims == [0]:
1125                        dims = []
1126                    param_data = Tensor_.convert_bytes_to_tensor(new_data, tuple(dims), ms_type)
1127                    parameter = Parameter(param_data, name=element.tag)
1128                    parameter_dict[element.tag] = parameter
1129                    _offload_if_config(parameter)
1130
1131        logger.info("Loading checkpoint files process is finished.")
1132
1133    except BaseException as e:
1134        logger.critical("Failed to load the checkpoint file '%s'.", ckpt_file_name)
1135        raise ValueError(e.__str__() + "\nFor 'load_checkpoint', "
1136                                       "failed to load the checkpoint file {}.".format(ckpt_file_name)) from e
1137
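# A sketch of how the 'str' branch above decodes values, assuming
# tensor_to_np_type maps 'str' to numpy's unicode kind 'U' (4 bytes per code
# point), so a payload of len(data) bytes holds len(data) / 4 characters:
#
#   np_type = 'U' + str(len(data) // 4)
#   value = str(np.frombuffer(data, np_type)[0])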
1138
1139def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None,
1140                    dec_key=None, dec_mode="AES-GCM", specify_prefix=None, choice_func=None,
1141                    crc_check=False):
1142    """
1143    Load checkpoint info from a specified file.
1144
    Note:
        - `specify_prefix` and `filter_prefix` do not affect each other.
        - If none of the parameters are loaded from the checkpoint file, a ValueError will be raised.
        - `specify_prefix` and `filter_prefix` are in the process of being deprecated;
          `choice_func` is recommended instead. Using either of those two arguments
          will override `choice_func`.
1151
1152    Args:
1153        ckpt_file_name (str): Checkpoint file name.
1154        net (Cell): The network where the parameters will be loaded. Default: ``None`` .
        strict_load (bool): Whether to load the parameters into the net strictly. If ``False`` , parameters are
                            loaded into the net when a parameter name's suffix in the checkpoint file matches a
                            parameter name in the network, and type conversion is performed when the data types
                            are inconsistent, such as float32 to float16. Default: ``False`` .
1159        filter_prefix (Union[str, list[str], tuple[str]]): Deprecated(see `choice_func`). Parameters starting with the
1160            filter_prefix will not be loaded. Default: ``None`` .
1161        dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
1162                                      is not required. Default: ``None`` .
1163        dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption
1164                        mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` and ``"SM4-CBC"`` .
1165                        Default: ``"AES-GCM"`` .
1166        specify_prefix (Union[str, list[str], tuple[str]]): Deprecated(see `choice_func`). Parameters starting with the
1167            specify_prefix will be loaded. Default: ``None`` .
1168        choice_func (Union[None, function]) : Input value of the function is a Parameter name of type string,
1169            and the return value is a bool. If returns ``True`` , the Parameter
1170            that matches the custom condition will be loaded. If returns ``False`` , the Parameter that
1171            matches the custom condition will be removed. Default: ``None`` .
1172        crc_check (bool) : Whether to perform crc32 validation when loading checkpoint. Default: ``False`` .
1173
1174    Returns:
        Dict, key is parameter name, value is a Parameter or string. When the `append_dict` parameter of
        :func:`mindspore.save_checkpoint` and the `append_info` parameter of :class:`mindspore.train.CheckpointConfig`
        are used to save the checkpoint, `append_dict` and `append_info` are dict types and their values are
        strings; in that case the value returned by loading the checkpoint is a string, and in other cases it is a
        Parameter.
1180
1181    Raises:
1182        ValueError: Checkpoint file's format is incorrect.
1183        ValueError: Parameter's dict is None after load checkpoint file.
1184        TypeError: The type of `specify_prefix` or `filter_prefix` is incorrect.
1185
1186    Examples:
1187        >>> import mindspore as ms
1188        >>>
1189        >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
1190        >>> param_dict = ms.load_checkpoint(ckpt_file_name,
1191        ...                                 choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
1192        >>> print(param_dict["conv2.weight"])
1193        Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
1194        >>> def func(param_name):
1195        ...     whether_load = False
1196        ...     if param_name.startswith("conv"):
1197        ...         whether_load = True
1198        ...     if param_name.startswith("conv1"):
1199        ...         whether_load = False
1200        ...     return whether_load
1201        >>> param_dict1 = ms.load_checkpoint(ckpt_file_name, choice_func=func)
1202        >>> print(param_dict1["conv2.weight"])
1203        Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
1204        >>> def func(param_name):
1205        ...     whether_load = False
1206        ...     if param_name.startswith("conv1"):
1207        ...         whether_load = True
1208        ...     return whether_load
1209        >>> param_dict2 = ms.load_checkpoint(ckpt_file_name, choice_func=func)
1210        >>> print(param_dict2)
1211        {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
1212
1213    Tutorial Examples:
1214        - `Saving and Loading the Model - Saving and Loading the Model Weight
1215          <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
1216    """
1217    specify_prefix = _check_prefix(specify_prefix)
1218    filter_prefix = _check_prefix(filter_prefix)
1219    dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
1220    dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
1221    crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
1222    logger.info("Execute the process of loading checkpoint files.")
1223
1224    parameter_dict = {}
1225
1226    if os.getenv("AITURBO") == "1":
1227        rank_id = get_rank()
1228        import aiturbo
1229        ckpt_path = os.path.dirname(ckpt_file_name)
1230        ckpt_name = os.path.basename(ckpt_file_name)
1231        np_dict = aiturbo.load_ckpt(ckpt_path, ckpt_name, rank_id)
1232        for key, value in np_dict.items():
1233            if isinstance(value, str):
1234                parameter_dict[key] = value
1235            else:
1236                parameter_dict[key] = Parameter(Tensor(value), name=key)
1237    else:
1238        _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
1239                              dec_mode, crc_check)
1240
1241    if not parameter_dict:
1242        raise ValueError(f"The loaded parameter dict is empty after filter or specify, please check whether "
1243                         f"'filter_prefix' or 'specify_prefix' are set correctly.")
1244
1245    if _warm_up_host_cache_enabled(parameter_dict):
1246        (is_worker, net_dict, warm_up_dict) = _warm_up_host_cache(parameter_dict, net)
1247    if net is not None:
1248        load_param_into_net(net, parameter_dict, strict_load)
1249    if _warm_up_host_cache_enabled(parameter_dict):
1250        _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict)
1251
1252    return parameter_dict
1253
1254
1255def load_checkpoint_async(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None,
1256                          dec_mode="AES-GCM", specify_prefix=None, choice_func=None):
1257    """
    Load checkpoint info from a specified file asynchronously.
1259
1260    .. warning::
1261        This is an experimental API that is subject to change or deletion.
1262
    Note:
        - `specify_prefix` and `filter_prefix` do not affect each other.
        - If none of the parameters are loaded from the checkpoint file, a ValueError will be raised.
        - `specify_prefix` and `filter_prefix` are in the process of being deprecated;
          `choice_func` is recommended instead. Using either of those two arguments
          will override `choice_func`.
1269
1270    Args:
1271        ckpt_file_name (str): Checkpoint file name.
1272        net (Cell, optional): The network where the parameters will be loaded. Default: ``None`` .
        strict_load (bool, optional): Whether to load the parameters into the net strictly. If ``False`` ,
                                      parameters are loaded into the net when a parameter name's suffix in the
                                      checkpoint file matches a parameter name in the network, and type
                                      conversion is performed when the data types are inconsistent, such as
                                      float32 to float16. Default: ``False`` .
1278        filter_prefix (Union[str, list[str], tuple[str]], optional): Deprecated(see `choice_func`). Parameters
1279            starting with the `filter_prefix` will not be loaded. Default: ``None`` .
1280        dec_key (Union[None, bytes], optional): Byte type key used for decryption. If the value is ``None`` ,
1281                                                the decryption is not required. Default: ``None`` .
1282        dec_mode (str, optional): This parameter is valid only when dec_key is not set to ``None`` . Specifies
1283                                  the decryption mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"``
1284                                  and ``"SM4-CBC"`` . Default: ``"AES-GCM"`` .
1285        specify_prefix (Union[str, list[str], tuple[str]], optional): Deprecated(see `choice_func`). Parameters
1286            starting with the specify_prefix will be loaded. Default: ``None`` .
1287        choice_func (Union[None, function], optional): Input value of the function is a Parameter name of type
1288            string, and the return value is a bool. If returns ``True`` , the Parameter
1289            that matches the custom condition will be loaded. If returns ``False`` , the Parameter that
1290            matches the custom condition will be removed. Default: ``None`` .
1291
1292    Returns:
        A custom inner class; calling its `result` method yields the :func:`mindspore.load_checkpoint` result.
1294
1295    Raises:
1296        ValueError: Checkpoint file's format is incorrect.
1297        ValueError: Parameter's dict is None after load checkpoint file.
1298        TypeError: The type of `specify_prefix` or `filter_prefix` is incorrect.
1299
1300    Examples:
1301        >>> import mindspore
1302        >>> from mindspore import nn
1303        >>> from mindspore.train import Model
1304        >>> from mindspore.amp import FixedLossScaleManager
1305        >>> from mindspore import context
1306        >>> from mindspore import load_checkpoint_async
1307        >>> from mindspore import load_param_into_net
1308        >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
1309        >>> # Create the dataset taking MNIST as an example. Refer to
1310        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
1311        >>> dataset = create_dataset()
1312        >>> # Define the network structure of LeNet5. Refer to
1313        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
1314        >>> ckpt_file = "./checkpoint/LeNet5-1_32.ckpt"
1315        >>> net = LeNet5()
1316        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
1317        >>> loss_scale_manager = FixedLossScaleManager()
1318        >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
1319        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None,
1320        ...               loss_scale_manager=loss_scale_manager)
1321        >>> pd_future = load_checkpoint_async(ckpt_file)
1322        >>> model.build(train_dataset=dataset, epoch=2)
1323        >>> param_dict = pd_future.result()
1324        >>> load_param_into_net(net, param_dict)
1325        >>> model.train(2, dataset)
1326        >>> print("param dict len: ", len(param_dict), flush=True)
1327    """
1328    from concurrent.futures import ThreadPoolExecutor
1329    executor = ThreadPoolExecutor(max_workers=2)
1330    param_dict_future = executor.submit(load_checkpoint, ckpt_file_name, net, strict_load, filter_prefix,
1331                                        dec_key, dec_mode, specify_prefix, choice_func)
1332    return ParamDictFuture(executor, param_dict_future)
1333
1334
1335def _load_map_parameter(checkpoint_list, element, element_id, map_data_list,
1336                        map_shape_list, parameter_dict):
1337    """load map parameter."""
1338    logger.info("Checkpoint load map_parameter.")
1339    if (element_id != len(checkpoint_list.value) - 1) and \
1340            element.tag == checkpoint_list.value[element_id + 1].tag:
1341        for index, tensor in enumerate(element.maptensor.tensor):
1342            data = tensor.tensor_content
1343            data_type = tensor.tensor_type
1344            np_type = np_type_convert.get(data_type)
1345            element_data = np.frombuffer(data, np_type)
1346            map_data_list[index].append(element_data)
1347            map_shape_list[index] += tensor.dims[0]
1348    else:
1349        map_array = []
1350        for index, tensor in enumerate(element.maptensor.tensor):
1351            data = tensor.tensor_content
1352            data_type = tensor.tensor_type
1353            np_type = np_type_convert.get(data_type)
1354            element_data = np.frombuffer(data, np_type)
1355            map_data_list[index].append(element_data)
1356            new_data = b"".join(map_data_list[index])
1357            param_data = np.frombuffer(new_data, np_type)
1358            dims = tensor.dims
1359            dims[0] += map_shape_list[index]
1360            param_data = param_data.reshape(list(dims))
1361            map_array.append(param_data)
1362        parameter_dict[element.tag] = map_array
1363
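# A sketch of the map-parameter layout handled above: each proto element
# carries three aligned tensors (key, value, status). Fragments that share a
# tag are accumulated in map_data_list / map_shape_list until the last
# fragment arrives, then each of the three streams is concatenated and
# reshaped, so parameter_dict[tag] ends up as [key_array, value_array,
# status_array].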
1364
1365def _check_ckpt_file_name(ckpt_file_name):
1366    """Check function load_checkpoint's ckpt_file_name."""
1367    if not isinstance(ckpt_file_name, str):
1368        raise TypeError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, "
1369                        "but got {}.".format(type(ckpt_file_name)))
1370
1371    if ckpt_file_name[-5:] != ".ckpt":
1372        raise ValueError("For 'load_checkpoint', the checkpoint file should end with '.ckpt', please "
1373                         "input the correct 'ckpt_file_name'.")
1374
1375    ckpt_file_name = os.path.abspath(ckpt_file_name)
1376    if not os.path.exists(ckpt_file_name):
1377        raise ValueError("For 'load_checkpoint', the checkpoint file: {} does not exist, please check "
1378                         "whether the 'ckpt_file_name' is correct.".format(ckpt_file_name))
1379
1380    return ckpt_file_name
1381
1382
1383def _check_prefix(prefix):
1384    """Check the correctness of the parameters."""
1385    if prefix is None:
1386        return prefix
1387    if not isinstance(prefix, (str, list, tuple)):
1388        raise TypeError("For 'load_checkpoint', the type of 'specify_prefix' or 'filter_prefix' must be string, "
1389                        "list[string] or tuple[string], but got {}.".format(str(type(prefix))))
1390    if isinstance(prefix, str):
1391        prefix = (prefix,)
1392    if not prefix:
1393        raise ValueError("For 'load_checkpoint', the argument 'specify_prefix' or 'filter_prefix' can't be empty when"
1394                         " 'specify_prefix' or 'filter_prefix' is list or tuple.")
1395    for index, pre in enumerate(prefix):
1396        if not isinstance(pre, str):
1397            raise TypeError("For 'load_checkpoint', when 'specify_prefix' or 'filter_prefix' is list or tuple, "
1398                            "the element in it must be string, but got "
1399                            f"{str(type(pre))} at index {index}.")
1400        if pre == "":
1401            raise ValueError("For 'load_checkpoint', the value of 'specify_prefix' or 'filter_prefix' "
1402                             "can't include ''.")
1403    return prefix
1404
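# Normalization sketch for _check_prefix: a bare string becomes a one-element
# tuple, so callers can always iterate over the result:
#
#   _check_prefix("conv")            # -> ("conv",)
#   _check_prefix(["conv", "fc"])    # -> ["conv", "fc"]
#   _check_prefix(None)              # -> None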
1405
1406def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check):
1407    """Parse checkpoint protobuf."""
1408    checkpoint_list = Checkpoint()
1409    try:
1410        if dec_key is None:
1411            with _ckpt_fs.open(ckpt_file_name, *_ckpt_fs.open_args) as f:
1412                pb_content = f.read()
1413        else:
1414            pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
1415            if pb_content is None:
1416                raise ValueError("For 'load_checkpoint', failed to decrypt the checkpoint file.")
        if crc_check and pb_content[-17:-10] != b"crc_num":
            logger.warning("For 'load_checkpoint', the ckpt file does not contain the crc code, "
                           "please check the file.")
1419        if pb_content[-17:-10] == b"crc_num":
1420            crc_num_bytes = pb_content[-10:]
1421            pb_content = pb_content[:-17]
1422            if crc_check:
1423                crc_num = int.from_bytes(crc_num_bytes, byteorder='big')
1424                cal_crc_num = binascii.crc32(pb_content, 0)
1425                if cal_crc_num != crc_num:
1426                    raise ValueError("For 'load_checkpoint', the crc check is failed, "
1427                                     "please check whether the ckpt file is damaged.")
1428        checkpoint_list.ParseFromString(pb_content)
1429    except BaseException as e:
        if _is_cipher_file(ckpt_file_name):
            err_info = "Failed to read the checkpoint file {}. The file may be encrypted or tampered with, " \
                       "please pass in the correct 'dec_key' or check the file integrity.".format(ckpt_file_name)
        else:
            err_info = "Failed to read the checkpoint file {}. You may not have permission to read it; " \
                       "please check whether the file is correct.".format(ckpt_file_name)
1436        logger.error(err_info)
1437        raise ValueError(err_info) from e
1438    return checkpoint_list
1439
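# Trailer layout checked above (a sketch): a checkpoint saved with crc
# appends 17 bytes after the serialized proto -- the 7-byte marker
# b"crc_num" followed by a 10-byte big-endian crc32 of the proto:
#
#   pb_content = proto_bytes + b"crc_num" + crc32(proto_bytes).to_bytes(10, "big")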
1440
1441def _whether_load_param(specify_prefix, filter_prefix, param_name):
1442    """Checks whether the load the parameter after `specify_prefix` or `filter_prefix`."""
1443    whether_load = True
1444    if specify_prefix:
1445        whether_load = False
1446        for prefix in specify_prefix:
1447            if param_name.startswith(prefix):
1448                whether_load = True
1449                break
1450    if filter_prefix:
1451        for prefix in filter_prefix:
1452            if param_name.startswith(prefix):
1453                whether_load = False
1454                break
1455    return whether_load
1456
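# Filter semantics of _whether_load_param (a sketch): specify_prefix is an
# allow-list, filter_prefix a deny-list, and the deny-list wins. With
# specify_prefix=("conv",) and filter_prefix=("conv1",):
#
#   _whether_load_param(("conv",), ("conv1",), "conv2.weight")  # True
#   _whether_load_param(("conv",), ("conv1",), "conv1.weight")  # False
#   _whether_load_param(("conv",), ("conv1",), "fc1.weight")    # False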
1457
1458def _init_parameter_data_in_parallel_mode(net, parameter_dict):
1459    """In parallel mode, only init the paraemters in ckpt."""
1460    is_train_phase = net.phase.startswith('train')
1461    for _, param in net.parameters_and_names():
1462        if param.name in parameter_dict and param.from_ckpt and not is_train_phase:
1463            param.shape = tuple(parameter_dict[param.name].shape)
1464            continue
1465        if param.name in parameter_dict and param.has_init:
            logger.warning("{} is not initialized while loading ckpt.".format(param.name))
1467            new_tensor = param.init_data()
1468            param._update_tensor_data(new_tensor)
1469
1470
1471def _check_load_param_into_net(net, parameter_dict):
1472    """check load_param_into_net"""
1473    if not isinstance(net, nn.Cell):
1474        logger.critical("Failed to combine the net and the parameters.")
1475        msg = ("For 'load_param_into_net', the argument 'net' should be a Cell, but got {}.".format(type(net)))
1476        raise TypeError(msg)
1477    if not isinstance(parameter_dict, dict):
1478        logger.critical("Failed to combine the net and the parameters.")
1479        msg = ("For 'load_param_into_net', the argument 'parameter_dict' should be a dict, "
1480               "but got {}.".format(type(parameter_dict)))
1481        raise TypeError(msg)
1482    if "random_op" in parameter_dict.keys():
1483        net._add_attr("random_op_snapshot", parameter_dict["random_op"])
1484        parameter_dict.pop("random_op")
1485
1486
1487def load_param_into_net(net, parameter_dict, strict_load=False):
1488    """
    Load parameters into the network, and return the parameters that are not loaded into the network.
1490
1491    Args:
1492        net (Cell): The network where the parameters will be loaded.
1493        parameter_dict (dict): The dictionary generated by load checkpoint file,
1494                               it is a dictionary consisting of key: parameters's name, value: parameter.
        strict_load (bool): Whether to load the parameters into the net strictly. If ``False`` , parameters are
                            loaded into the net when a parameter name's suffix in the checkpoint file matches a
                            parameter name in the network, and type conversion is performed when the data types
                            are inconsistent, such as float32 to float16. Default: ``False`` .
1499
1500    Returns:
        - param_not_load (List), the names of the parameters in the model that are not loaded into the network.
        - ckpt_not_load (List), the names of the parameters in the checkpoint file that are not loaded into the
          network.
1503
1504    Raises:
1505        TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
1506
1507    Examples:
1508        >>> import mindspore as ms
1509        >>>
1510        >>> # Define the network structure of LeNet5. Refer to
1511        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
1512        >>> net = LeNet5()
1513        >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
1514        >>> param_dict = ms.load_checkpoint(ckpt_file_name, filter_prefix="conv1")
1515        >>> param_not_load, _ = ms.load_param_into_net(net, param_dict)
1516        >>> print(param_not_load)
1517        ['conv1.weight']
1518
1519    Tutorial Examples:
1520        - `Saving and Loading the Model - Saving and Loading the Model Weight
1521          <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
1522    """
1523    _check_load_param_into_net(net, parameter_dict)
1524    for key, value in parameter_dict.items():
1525        if not isinstance(key, str) or not isinstance(value, (Parameter, str, list)):
1526            logger.critical("Load parameters into net failed.")
            msg = ("For 'parameter_dict', the element in the argument 'parameter_dict' should be a "
                   "'str' and 'Parameter', but got {} and {}.".format(type(key), type(value)))
1529            raise TypeError(msg)
1530
1531    strict_load = Validator.check_bool(strict_load)
1532    logger.info("Execute the process of loading parameters into net.")
1533    for _, param in net.parameters_and_names():
1534        param.from_ckpt = True
1535    if not _is_in_auto_parallel_mode():
1536        net.init_parameters_data()
1537    else:
1538        _init_parameter_data_in_parallel_mode(net, parameter_dict)
1539    param_not_load = []
1540    ckpt_not_load = list(parameter_dict.keys())
1541    for _, param in net.parameters_and_names():
1542        if param.name in parameter_dict:
1543            if isinstance(param, MapParameter):
1544                param.import_data(parameter_dict[param.name])
1545                continue
1546            # Add has attr protection when load server checkpoint file on worker.
1547            if not hasattr(parameter_dict[param.name], "data"):
1548                continue
1549            new_param = parameter_dict[param.name]
1550            _update_param(param, new_param, strict_load)
1551            ckpt_not_load.remove(param.name)
1552        else:
1553            param_not_load.append(param.name)
1554
1555    if param_not_load and not strict_load:
1556        _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load)
1557
1558    logger.info("Loading parameters into net is finished.")
1559    if param_not_load:
1560        logger.warning("For 'load_param_into_net', "
1561                       "{} parameters in the 'net' are not loaded, because they are not in the "
1562                       "'parameter_dict', please check whether the network structure is consistent "
1563                       "when training and loading checkpoint.".format(len(param_not_load)))
1564        logger.warning("{} are not loaded.".format(param_not_load))
1565    if os.getenv("AITURBO") == "1" and net.parameter_layout_dict is not None:
1566        param_layout = net.parameter_layout_dict
1567        param_redundancy = get_parameter_redundancy(param_layout)
1568        remove_param_redundancy_dict = remove_param_redundancy(param_redundancy)
1569        target_parameter_name_set = set(parameter_dict.keys())
        for rank_id, param_name_set in remove_param_redundancy_dict.items():
1571            if param_name_set == target_parameter_name_set:
1572                parameter_broadcast(net, param_layout, rank_id)
1573    return param_not_load, ckpt_not_load
1574
1575
1576def _warm_up_host_cache_enabled(parameter_dict):
1577    """Warm up host cache enabled."""
1578    if _cache_enable():
1579        return True
1580    for key in parameter_dict.keys():
1581        if key.find(".__param_key__") != -1:
1582            return True
1583    return False
1584
1585
1586def _warm_up_host_cache(parameter_dict, net):
1587    """Warm up host cache."""
1588    ms_role = os.getenv("MS_ROLE")
1589    is_worker = ms_role == "MS_WORKER"
1590    param_key_dict = {}
1591    # Traverse key, value in parameter_dict, warm up param key and record param key into param_key_dict.
1592    if is_worker:
1593        net.init_parameters_data()
1594        net_dict = {}
1595        for name, value in net.parameters_and_names():
1596            net_dict[name] = value
1597        for param_name, value in parameter_dict.items():
1598            pos = param_name.find(".__param_key__")
1599            if pos != -1:
1600                net_param_name = param_name[:pos]
1601                param_key_dict[param_name] = net_param_name
1602                net_value = None
1603                if net_param_name not in net_dict:
1604                    logger.warning("net param name : %s is not in net", net_param_name)
1605                else:
1606                    net_value = net_dict.get(net_param_name, None)
1607                pos += len(".__param_key__")
1608                param_key = int(param_name[pos:])
1609                value_is_map_parameter = isinstance(value, list) and len(value) == 3
1610                if value_is_map_parameter and (net_value is None or isinstance(net_value, Parameter)):
1611                    key_tensor = Tensor.from_numpy(value[0])
1612                    value_tensor = Tensor.from_numpy(value[1])
1613                    status_tensor = Tensor.from_numpy(value[2])
1614                    _store_warm_up_ptr_by_tensor_list(param_key, key_tensor, value_tensor, status_tensor)
1615                elif not isinstance(value, list) and isinstance(net_value, Parameter):
1616                    _store_warm_up_ptr_by_tensor(param_key, value)
1617                else:
                    logger.warning("Unknown match between parameter type %s and net_value type %s",
                                   type(value), type(net_value))
1619    else:
1620        for param_name, value in parameter_dict.items():
1621            pos = param_name.find(".__param_key__")
1622            if pos != -1:
1623                net_param_name = param_name[:pos]
1624                param_key_dict[param_name] = net_param_name
1625    # Split param key from parameter_dict since worker cannot load param key.
1626    warm_up_dict = {}
1627    for key, value in param_key_dict.items():
1628        if is_worker:
1629            warm_up_dict[value] = parameter_dict.pop(key)
1630        else:
1631            parameter_dict[value] = parameter_dict.pop(key)
1632    return (is_worker, parameter_dict, warm_up_dict)
1633
1634
1635def _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict):
1636    """Warm up host cache post process."""
1637    if is_worker:
1638        net_dict.update(warm_up_dict)
1639    _set_checkpoint_load_status(True)
1640
1641
1642def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load):
1643    """When some net parameter did not load, try to continue loading."""
1644    prefix_name = ""
1645    longest_name = param_not_load[0]
1646    while prefix_name != longest_name and param_not_load:
        logger.debug("Count: {} parameters have not been loaded, try to continue loading.".format(len(param_not_load)))
1648        prefix_name = longest_name
1649        for net_param_name in param_not_load:
1650            for dict_name in parameter_dict:
1651                if dict_name.endswith(net_param_name):
1652                    prefix_name = dict_name[:-len(net_param_name)]
1653                    break
1654            if prefix_name != longest_name:
1655                break
1656
1657        if prefix_name != longest_name:
1658            logger.warning(f"For 'load_param_into_net', remove parameter prefix name: {prefix_name},"
1659                           f" continue to load.")
1660            for _, param in net.parameters_and_names():
1661                new_param_name = prefix_name + param.name
1662                if param.name in param_not_load and new_param_name in parameter_dict:
1663                    new_param = parameter_dict[new_param_name]
1664                    _update_param(param, new_param, strict_load)
1665                    param_not_load.remove(param.name)
1666
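# A sketch of the prefix-recovery retry above: if the net has "conv1.weight"
# but the checkpoint stored "backbone.conv1.weight", the common prefix
# "backbone." is detected from an unloaded name via endswith() and stripped,
# and loading is retried with parameter_dict[prefix_name + param.name].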
1667
1668def _save_graph(network, file_name):
1669    """
1670    Saves the graph of network to a file.
1671
1672    Args:
1673        network (Cell): Obtain a pipeline through network for saving graph.
1674        file_name (str): Graph file name into which the graph will be saved.
1675    """
1676    logger.info("Execute the process of saving graph.")
1677
1678    file_name = os.path.abspath(file_name)
1679    graph_pb = network.get_func_graph_proto()
1680    if graph_pb:
1681        with open(file_name, "wb") as f:
1682            os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
1683            f.write(graph_pb)
1684
1685
1686def _reshape_tensor(tensor, dst_shape):
1687    """reshape tensor to dst shape"""
1688    np_tensor = tensor.asnumpy()
1689    np_tensor = np_tensor.reshape(dst_shape)
1690    return Tensor(np_tensor, tensor.dtype)
1691
1692
1693def _check_param_for_integrate_save(pipeline_stages, uniform_split):
1694    """check whether current settings and parameters are supported in integrated save checkpoint mode"""
1695    if pipeline_stages > 1:
        raise RuntimeError("Pipeline parallel doesn't support integrated save of checkpoints now.")
1697    if uniform_split == 0:
        raise RuntimeError("For 'save_checkpoint' and in automatic model parallel scene, when 'integrated_save' "
                           "is set to True, the checkpoint will be saved in an integrated manner; only uniformly "
                           "split tensors are supported now.")
1701
1702
1703def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, integrated_save):
1704    """
    Gets the merged data (tensor) from tensor slices, by device arrangement and tensor map.

    Args:
        net (Cell): MindSpore network.
        parameter_layout_dict (dict): The mapping from parameter name to its layout.
        param_name (str): The parameter name, which is to be combined.
        param_data (Tensor): The parameter data on the local device, which was a slice of the whole parameter data.
        integrated_save (bool): Whether to integrated save in automatic model parallel scene.
    Returns:
        Tensor, the combined tensor with the whole data value.
1714    """
1715    layout = parameter_layout_dict[param_name]
    if len(layout) < 8:
        logger.info("The layout of parameter %s is incomplete, return the data unmerged.", param_name)
1718        return param_data
1719
1720    dev_mat = layout[0]
1721    tensor_map = layout[1]
1722    uniform_split = layout[4]
1723    opt_shard_group = layout[5]
1724    before_reshape_slice_shape = layout[2]
1725    before_reshape_full_shape = layout[6]
1726    after_reshape_slice_shape = layout[7]
1727    do_reshape = False
1728    if before_reshape_full_shape and after_reshape_slice_shape \
1729            and after_reshape_slice_shape != before_reshape_slice_shape:
1730        do_reshape = True
1731
1732    allgather_net = None
1733    mp_weight = False
1734    for dim in tensor_map:
1735        if dim != -1:
1736            mp_weight = True
1737            break
1738    if param_name in net.parallel_parameter_merge_net_dict:
1739        allgather_net = net.parallel_parameter_merge_net_dict[param_name]
1740    else:
1741        logger.info("Need to create allgather net for %s", param_name)
1742        if integrated_save:
1743            _check_param_for_integrate_save(context.get_auto_parallel_context("pipeline_stages"), uniform_split)
1744            # while any dim is not equal to -1, means param is split and needs to be merged
1745            # pipeline parallel need to be supported here later
1746            if mp_weight:
1747                allgather_net = get_allgather_cell(opt_shard_group, bool(opt_shard_group), do_reshape,
1748                                                   tuple(after_reshape_slice_shape))
1749                object.__setattr__(allgather_net, "keep_input_unchanged", True)
1750            elif opt_shard_group:
1751                allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
1752                                                   tuple(after_reshape_slice_shape))
1753        elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_aggregated_save"):
1754            allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
1755                                               tuple(after_reshape_slice_shape))
1756        net.parallel_parameter_merge_net_dict[param_name] = allgather_net
1757    if allgather_net:
1758        param_data = allgather_net(param_data)
1759    if mp_weight and integrated_save:
1760        param_data = _reshape_param_data(param_data, dev_mat, tensor_map)
1761        if do_reshape:
1762            param_data = _reshape_tensor(param_data, before_reshape_full_shape)
1763    return param_data
1764
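# Layout fields consumed above (a sketch of the tuple positions this function
# relies on): layout[0] device matrix, layout[1] tensor map, layout[2] slice
# shape before reshape, layout[4] uniform_split flag, layout[5] optimizer
# shard group, layout[6] full shape before reshape, layout[7] slice shape
# after reshape. A tensor map entry of -1 means that dimension is not split.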
1765
1766def export(net, *inputs, file_name, file_format, **kwargs):
1767    """
1768    Export the MindSpore network into an offline model in the specified format.
1769
1770    Note:
1771        1. When exporting AIR, ONNX format, the size of a single tensor can not exceed 2GB.
1772        2. When `file_name` does not have a suffix, the system will automatically add one
1773           according to the `file_format`.
1774        3. Exporting functions decorated with :func:`mindspore.jit` to mindir format is supported.
1775        4. When exporting a function decorated with :func:`mindspore.jit`, the function should not involve
1776           class properties in calculations.
1777        5. AIR format is deprecated, and will be removed in a future version, please use other format or use
1778           MindSpore Lite to do offline inference.
1779
1780    Args:
1781        net (Union[Cell, function]): MindSpore network.
        inputs (Union[Tensor, Dataset, List, Tuple, Number, Bool]): It represents the inputs
             of the `net`; if the network has multiple inputs, set them together. When its type is Dataset,
             it represents the preprocessing behavior of the `net`, and the data preprocessing operations will be
             serialized. In this case, you should manually adjust the batch size of the dataset script, which will
             impact the batch size of the 'net' input. Currently, only parsing the "image" column from the dataset
             is supported.
1787        file_name (str): File name of the model to be exported.
1788        file_format (str): MindSpore currently supports 'AIR', 'ONNX' and 'MINDIR' format for exported model.
1789
1790            - AIR: Ascend Intermediate Representation. An intermediate representation format of Ascend model.
1791            - ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
1792            - MINDIR: MindSpore Native Intermediate Representation for Anf. An intermediate representation format
1793              for MindSpore models.
1794
1795        kwargs (dict): Configuration options dictionary.
1796
1797            - enc_key (byte): Byte-type key used for encryption. The valid length is 16, 24, or 32.
1798            - enc_mode (Union[str, function]): Specifies the encryption mode, to take effect when enc_key is set.
1799
1800              - For 'AIR' and 'ONNX' models, only customized encryption is supported.
1801              - For 'MINDIR', all options are supported. Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC'
1802                or Customized encryption.
1803                Default: ``'AES-GCM'``.
1804              - For details of using the customized encryption, please check the `tutorial
1805                <https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_.
1806
1807            - dataset (Dataset): Specifies the preprocessing method of the dataset, which is used to import the
1808              preprocessing of the dataset into MindIR.
1809
1810            - obf_config (dict): obfuscation config.
1811
1812              - type (str): The type of obfuscation, only 'dynamic' is supported until now.
              - obf_ratio (float, str): The ratio of nodes in the original model that would be obfuscated. `obf_ratio`
                should be in the range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large"
                correspond to 0.1, 0.3, and 0.6 respectively.
              - customized_func (function): A python function used for the customized function mode, which is used to
                control the switch branch of the obfuscation structure. The output of customized_func should be a
                constant boolean (refer to 'my_func()' in the
                `tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_).
                This function needs to ensure that its result is constant for any input. Users can refer to opaque
                predicates. If customized_func is set, then it should be passed to the `load()` interface when loading
                the obfuscated model.
1823              - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
1824                structure of obfuscated models corresponding to different random seeds is different. If
1825                `obf_random_seed` is set, then it should be passed
1826                to :class:`mindspore.nn.GraphCell` interface when loading
1827                obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
1828                be set, and the latter mode would be applied if both of them are set.
1829
1830            - incremental (bool): export MindIR incrementally.
1831
            - custom_func (function): Function for user-defined export policies. This function will be used to
              customize the model during network export. Currently it is only supported for files in MindIR format.
              The function accepts exactly one input representing the proto object of the MindIR file. When
              modifying a model, it is necessary to ensure the correctness of the `custom_func`, otherwise it may
              lead to model loading failure or functional errors. Default: ``None`` .
1837
1838    Examples:
1839        >>> import mindspore as ms
1840        >>> import numpy as np
1841        >>> from mindspore import Tensor
1842        >>>
1843        >>> # Define the network structure of LeNet5. Refer to
1844        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
1845        >>> net = LeNet5()
1846        >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
1847        >>> ms.export(net, input_tensor, file_name='lenet', file_format='MINDIR')
1848        >>>
1849        >>> # Export model in MindIR format and modified the model info using custom_func
1850        >>> # The custom_func only support one input representing the Proto object of the model
1851        >>> # And custom_func does not support return value
1852        >>> def _custom_func(mindir_model):
1853        ...     mindir_model.producer_name = "test11111"
1854        ...     mindir_model.producer_version = "11.0"
1855        ...     mindir_model.user_info["version"] = "11.0"
1856        >>> ms.export(net, input_tensor, file_name="lenet", file_format='MINDIR', custom_func=_custom_func)
1857
1858
1859    Tutorial Examples:
1860        - `Saving and Loading the Model - Saving and Loading MindIR
1861          <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_
1862    """
1863    old_ms_jit_value = context.get_context("jit_syntax_level")
1864    context.set_context(jit_syntax_level=mindspore.STRICT)
1865
1866    supported_formats = ['AIR', 'ONNX', 'MINDIR']
1867    if file_format not in supported_formats:
1868        raise ValueError(f"For 'export', 'file_format' must be one of {supported_formats}, but got {file_format}.")
1869    if file_format == 'AIR':
1870        logger.warning("AIR format is deprecated, and will be removed in a future version, please use other format or "
1871                       "use MindSpore Lite to do offline inference")
1872    Validator.check_file_name_by_regular(file_name)
1873    logger.info("exporting model file:%s format:%s.", file_name, file_format)
1874
1875    if check_input_dataset(*inputs, dataset_type=mindspore.dataset.Dataset):
1876        if len(inputs) != 1:
            raise RuntimeError(f"You can only serialize one dataset into MindIR, but got {len(inputs)} datasets.")
1878        shapes, types, columns = inputs[0].output_shapes(), inputs[0].output_types(), inputs[0].get_col_names()
1879        kwargs['dataset'] = inputs[0]
1880        only_support_col = "image"
1881
1882        inputs_col = list()
1883        for c, s, t in zip(columns, shapes, types):
1884            if only_support_col != c:
1885                continue
1886            inputs_col.append(Tensor(np.random.uniform(-1.0, 1.0, size=s).astype(t)))
1887        if not inputs_col:
            raise RuntimeError("Only supports parsing the \"image\" column from the dataset now, but the given "
                               f"dataset has columns: {columns}.")
1890        inputs = tuple(inputs_col)
1891
1892    file_name = os.path.abspath(file_name)
1893    if 'enc_key' in kwargs.keys():
1894        kwargs['enc_key'], kwargs['enc_mode'] = _check_key_mode_type(file_format, **kwargs)
1895    _export(net, file_name, file_format, *inputs, **kwargs)
1896
1897    context.set_context(jit_syntax_level=old_ms_jit_value)
1898
1899
1900def _get_funcgraph(net, *inputs):
1901    """
1902    Compile the MindSpore network and get FuncGraph.
1903
    Args:
        net (Union[Cell, function]): MindSpore network.
        inputs (Union[Tensor, Dataset, List, Tuple, Number, Bool]): It represents the inputs
             of the `net`; if the network has multiple inputs, set them together. When its type is Dataset,
             it represents the preprocessing behavior of the `net`, and the data preprocessing operations will be
             serialized. In this case, you should manually adjust the batch size of the dataset script, which will
             impact the batch size of the 'net' input. Currently, only parsing the "image" column from the dataset
             is supported.
1911
1912    Returns:
1913        FuncGraph, a mindspore._c_expression.FuncGraph obj.
1914
1915    Raises:
1916        ValueError: input `net` is not a nn.Cell.
1917
1918    Examples:
1919        >>> import mindspore as ms
1920        >>> import numpy as np
1921        >>> from mindspore import Tensor
1922        >>>
1923        >>> # Define the network structure of LeNet5. Refer to
1924        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
1925        >>> net = LeNet5()
1926        >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
1927        >>> ms.get_funcgraph(net, input_tensor)
1928
1929    """
1930    if not isinstance(net, nn.Cell):
        raise ValueError("For get_funcgraph's parameter 'net', currently only Cell is supported.")
1932    phase_name = "lite_infer_predict" if _is_in_auto_parallel_mode() else "lite_infer_get_func_graph"
1933    graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
1934    # pylint: disable=protected-access
1935    func_graph = _executor._get_func_graph(net, graph_id)
1936    return func_graph
1937
1938
1939def _export(net, file_name, file_format, *inputs, **kwargs):
1940    """
1941    It is an internal conversion function. Export the MindSpore prediction model to a file in the specified format.
1942    """
1943    logger.info("exporting model file:%s format:%s.", file_name, file_format)
1944    if "obf_config" in kwargs and file_format != "MINDIR":
1945        raise ValueError(f"Dynamic obfuscation only support for MindIR format, but got {file_format} format.")
1946    if "custom_func" in kwargs and file_format != "MINDIR":
1947        raise ValueError(f"Currently only support custom_func for MindIR format, but got {file_format} format.")
1948    if file_format == 'AIR':
1949        _save_air(net, file_name, *inputs, **kwargs)
1950    elif file_format == 'ONNX':
1951        _save_onnx(net, file_name, *inputs, **kwargs)
1952    elif file_format == 'MINDIR':
1953        _save_mindir(net, file_name, *inputs, **kwargs)
1954
1955
1956def _check_key_mode_type(file_format, **kwargs):
1957    """check enc_key and enc_mode are valid"""
1958    enc_key = Validator.check_isinstance('enc_key', kwargs.get('enc_key'), bytes)
1959    enc_mode = kwargs.get('enc_mode')
1960
1961    if callable(enc_mode):
1962        return enc_key, enc_mode
1963
1964    enc_mode = 'AES-GCM'
1965    if 'enc_mode' in kwargs.keys():
1966        enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
1967
1968    if file_format in ('AIR', 'ONNX'):
1969        raise ValueError(f"AIR/ONNX only support customized encryption, but got {enc_mode}.")
1970
1971    if enc_mode in ('AES-CBC', 'AES-GCM', 'SM4-CBC'):
1972        return enc_key, enc_mode
1973    raise ValueError(f"MindIR only support AES-GCM/AES-CBC/SM4-CBC encryption, but got {enc_mode}")
1974
1975
1976def _save_air(net, file_name, *inputs, **kwargs):
1977    """Save AIR format file."""
1978    phase_name = 'export.air'
1979    graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
1980    if not file_name.endswith('.air'):
1981        file_name += ".air"
1982    if os.path.exists(file_name):
1983        os.chmod(file_name, stat.S_IWUSR)
1984    if "/" in file_name:
1985        real_path = os.path.abspath(file_name[:file_name.rfind("/")])
1986        os.makedirs(real_path, exist_ok=True)
1987    if 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys():
1988        _executor.export(file_name, graph_id, enc_key=kwargs.get('enc_key'), encrypt_func=kwargs.get('enc_mode'))
1989    else:
1990        _executor.export(file_name, graph_id)
1991    os.chmod(file_name, stat.S_IRUSR)
1992
1993
1994def _save_onnx(net, file_name, *inputs, **kwargs):
1995    """Save ONNX format file."""
1996    # When dumping ONNX file, switch network mode to infer when it is training(NOTE: ONNX only designed for prediction)
1997    if not isinstance(net, nn.Cell):
        raise ValueError(f"Export ONNX format model only supports nn.Cell object, but got {type(net)}.")
1999    _check_dynamic_input(inputs)
2000    cell_mode = net.training
2001    net.set_train(mode=False)
2002    total_size = _calculation_net_size(net)
2003    if total_size > PROTO_LIMIT_SIZE:
        raise RuntimeError('Export onnx model failed. Network size is: {}G, it exceeds the protobuf limit: {}G.'
                           .format(total_size / 1024 / 1024, PROTO_LIMIT_SIZE / 1024 / 1024))
2006    phase_name = 'export.onnx'
2007    graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
2008    onnx_stream = _executor._get_func_graph_proto(net, graph_id)
2009    if 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys():
2010        enc_mode = kwargs.get('enc_mode')
2011        onnx_stream = enc_mode(onnx_stream, kwargs.get('enc_key'))
2012    if not file_name.endswith('.onnx'):
2013        file_name += ".onnx"
2014    if os.path.exists(file_name):
2015        os.chmod(file_name, stat.S_IWUSR)
2016    with open(file_name, 'wb') as f:
2017        f.write(onnx_stream)
2018        os.chmod(file_name, stat.S_IRUSR)
2019    net.set_train(mode=cell_mode)
2020
2021
2022def _check_dynamic_input(inputs):
2023    for ele in inputs:
2024        if isinstance(ele, Tensor) and -1 in ele.shape:
            raise ValueError("Export ONNX format model does not support dynamic shape mode.")
2026
2027
2028def _generate_front_info_for_param_data_file(is_encrypt, kwargs):
2029    front_info = bytes()
2030    check_code = sys.byteorder == "little"
2031    front_info += check_code.to_bytes(1, byteorder=sys.byteorder)
2032    front_info += bytes(63)
2033    if is_encrypt():
2034        front_info = _encrypt(front_info, len(front_info), kwargs.get('enc_key'),
2035                              len(kwargs.get('enc_key')), kwargs.get('enc_mode'))
2036    return front_info
2037
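# Front-info layout produced above (a sketch): 64 reserved bytes at the head
# of every parameter data file -- byte 0 is 1 when the writer is
# little-endian, bytes 1..63 are zero padding -- optionally encrypted as a
# whole when an 'enc_key' is supplied.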
2038
2039def _change_file(f, dirname, external_local, is_encrypt, kwargs):
2040    """Change to another file to write parameter data."""
    # The parameter data has not been written into the file yet.
2042    front_info = _generate_front_info_for_param_data_file(is_encrypt, kwargs)
2043    f.seek(0, 0)
2044    f.write(front_info)
2045    f.close()
2046    ori_data_file_name = f.name
2047    os.chmod(ori_data_file_name, stat.S_IRUSR)
2048    if os.path.getsize(ori_data_file_name) == 64:
        raise RuntimeError("The parameter size exceeds 1T, cannot export to the file.")
2050    data_file_name = os.path.join(dirname, external_local)
2051    return _get_data_file(is_encrypt, kwargs, data_file_name)
2052
2053
2054def _get_data_file(is_encrypt, kwargs, data_file_name):
2055    """Get Data File to write parameter data."""
2056    # Reserves 64 bytes as spare information such as check data
2057    offset = 64
2058    if os.path.exists(data_file_name):
2059        os.chmod(data_file_name, stat.S_IWUSR)
2060
2061    place_holder_data = bytes(offset)
2062    if is_encrypt():
2063        place_holder_data = _encrypt(place_holder_data, len(place_holder_data), kwargs["enc_key"],
2064                                     len(kwargs["enc_key"]), kwargs["enc_mode"])
2065    parameter_size = (offset / 1024)
    f = open(data_file_name, "wb")
    try:
        f.write(place_holder_data)
    except IOError:
        f.close()
        raise
2071
2072    return f, parameter_size, offset
2073
2074
2075def _encrypt_data(is_encrypt, write_data, kwargs):
2076    """Encrypt parameter data."""
2077    if is_encrypt():
2078        if callable(kwargs.get('enc_mode')):
2079            enc_func = kwargs.get('enc_mode')
2080            write_data = enc_func(write_data, kwargs.get('enc_key'))
2081        else:
2082            write_data = _encrypt(write_data, len(write_data), kwargs.get('enc_key'),
2083                                  len(kwargs.get('enc_key')), kwargs.get('enc_mode'))
2084    return write_data
2085
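# A customized 'enc_mode' callable, as dispatched above, takes the plain
# bytes and the key and returns the cipher bytes. A sketch (xor_cipher is a
# hypothetical placeholder, not a secure cipher):
#
#   def xor_cipher(plain_data, key):
#       return bytes(b ^ key[i % len(key)] for i, b in enumerate(plain_data))
#
#   export(net, inputs, file_name="net", file_format="MINDIR",
#          enc_key=b"0123456789ABCDEF", enc_mode=xor_cipher)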
2086
2087def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
2088    """The function to save parameter data."""
    logger.warning("The total size of parameters in the net exceeds 1G, save the MindIR model and parameters "
                   "separately.")
2090    # save parameter
2091    if model.graph.map_parameter:
        raise ValueError("MapParameter does not support saving in split MindIR files now.")
2093    file_prefix = file_name.split("/")[-1]
2094    if file_prefix.endswith(".mindir"):
2095        file_prefix = file_prefix[:-7]
2096    current_path = os.path.abspath(file_name)
2097    dirname = os.path.dirname(current_path)
2098    data_path = os.path.join(dirname, file_prefix + "_variables")
2099    if os.path.exists(data_path):
2100        shutil.rmtree(data_path)
2101    os.makedirs(data_path, exist_ok=True)
2102    os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
2103    index = 0
2104    external_local = os.path.join(file_prefix + "_variables", "data_" + str(index))
2105    data_file_name = os.path.join(dirname, external_local)
2106    f, parameter_size, offset = _get_data_file(is_encrypt, kwargs, data_file_name)
2107    try:
        write_round = 0
        names = []
        for param_proto in model.graph.parameter:
            name = param_proto.name[param_proto.name.find(":") + 1:]
            names.append((name, param_proto))
        names.sort(key=lambda x: x[0])
2114        for pairs in names:
2115            name = pairs[0]
2116            param_proto = pairs[1]
2117            param = net_dict[name]
2118            raw_data = param.data.get_bytes()
2119            data_length = len(raw_data)
2120            append_size = 0
2121            if data_length % 64 != 0:
2122                append_size = 64 - (data_length % 64)
2123                parameter_size += ((append_size + data_length) / 1024)
2124            if parameter_size > PARAMETER_SPLIT_SIZE:
2125                index += 1
2126                external_local = os.path.join(file_prefix + "_variables", "data_" + str(index))
2127                f, parameter_size, offset = _change_file(f, dirname, external_local, is_encrypt, kwargs)
2128                parameter_size += ((append_size + data_length) / 1024)
2129            param_proto.external_data.location = external_local
2130            param_proto.external_data.length = data_length
2131            param_proto.external_data.offset = offset
2132            write_data = raw_data + bytes(append_size)
2133            offset += (data_length + append_size)
2134            write_data = _encrypt_data(is_encrypt, write_data, kwargs)
2135            f.write(write_data)
2136            round += 1
2137            logger.debug(f"writing {round}th split data, name:{name}")
2138
2139        graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir")
2140        if os.path.exists(graph_file_name):
2141            os.chmod(graph_file_name, stat.S_IWUSR)
2142        with open(graph_file_name, 'wb') as model_file:
2143            os.chmod(graph_file_name, stat.S_IRUSR | stat.S_IWUSR)
2144            model_string = model.SerializeToString()
2145            if is_encrypt():
2146                model_string = _encrypt(model_string, len(model_string), kwargs.get('enc_key'),
2147                                        len(kwargs.get('enc_key')), kwargs.get('enc_mode'))
2148            model_file.write(model_string)
2149            os.chmod(graph_file_name, stat.S_IRUSR)
2150
2151        front_info = _generate_front_info_for_param_data_file(is_encrypt, kwargs)
2152        f.seek(0, 0)
2153        f.write(front_info)
2154    finally:
2155        f.close()
2156        os.chmod(data_file_name, stat.S_IRUSR)
2157
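# Resulting on-disk layout of a split export (inferred from _split_save above),
# for an original file name "net.mindir":
#
#     net_graph.mindir       # graph protobuf; parameters referenced externally
#     net_variables/data_0   # parameter data, split at PARAMETER_SPLIT_SIZE
#     net_variables/data_1
#     ...
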

def _msfunc_info(net, *inputs):
    """Get the MindIR stream and parameter dict of a ms_function."""
    # pylint: disable=protected-access
    net_dict = OrderedDict()
    _ms_func_executor = _MindsporeFunctionExecutor(net, time.time() * 1e9)
    graph_id = _ms_func_executor.compile(net.__name__, *inputs)
    mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir')
    params = _ms_func_executor._graph_executor.get_params(graph_id)
    for name, value in params.items():
        net_dict[name] = Parameter(value, name=name)
    return mindir_stream, net_dict


def _cell_info(net, incremental, *inputs):
    """Get the MindIR stream and parameter dict of a cell."""
    phase_name = "export.mindir"
    graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
    # pylint: disable=protected-access
    mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir', incremental=incremental)
    # Clean the obfuscation config so it does not affect the next call.
    _executor.obfuscate_config = None

    net_dict = net.parameters_dict()
    return mindir_stream, net_dict


def _set_obfuscate_config(**kwargs):
    """Set obfuscation config for the executor."""
    logger.warning("Obfuscate model.")
    if 'enc_mode' in kwargs.keys():
        enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
        if enc_mode not in ["AES-GCM", "AES-CBC", "SM4-CBC"]:
            raise ValueError(
                "Only MindIR files encrypted with 'AES-GCM', 'AES-CBC' or 'SM4-CBC' are supported for "
                "obfuscation, but got {}.".format(enc_mode))
    obf_ratio, customized_funcs, obf_random_seed = _check_obfuscate_params(kwargs.get('obf_config'))
    if customized_funcs and obf_random_seed > 0:
        logger.warning("Although 'customized_func' and 'obf_random_seed' are both set, the 'obf_random_seed' mode "
                       "will be applied; remember to set 'obf_random_seed' when loading the obfuscated model.")

    if obf_random_seed == 0:  # apply customized_func mode
        device_target = context.get_context('device_target')
        if device_target in ["GPU", "Ascend"]:
            raise ValueError(
                "Customized func mode only supports 'device_target'='CPU', but got {}.".format(device_target))
        clean_funcs()
        for func in customized_funcs:
            add_opaque_predicate(func.__name__, func)
    _executor.obfuscate_config = {'obf_ratio': obf_ratio, 'obf_random_seed': obf_random_seed}

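# Hedged sketch of an ``obf_config`` passed through the export kwargs (the key
# names and the predicate signature are assumptions for illustration; see
# _check_obfuscate_params for the exact accepted keys):
#
#     def my_predicate(x, y):
#         return x < y    # opaque predicate registered via add_opaque_predicate
#
#     obf_config = {'obf_ratio': 0.8, 'customized_func': my_predicate}
#     ms.export(net, inputs, file_name="net", file_format="MINDIR", obf_config=obf_config)
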

def _save_mindir(net, file_name, *inputs, **kwargs):
    """Save MindIR format file."""
    # set obfuscate configs
    if 'obf_config' in kwargs.keys():
        _set_obfuscate_config(**kwargs)
        for item in inputs:
            if -1 in item.shape:
                raise ValueError(
                    "Dynamic shape input is not supported now, but got the shape of inputs: {}.".format(item.shape))

    incremental = kwargs.get('incremental', False)

    model = mindir_model()
    if not isinstance(net, nn.Cell):
        mindir_stream, net_dict = _msfunc_info(net, *inputs)
    else:
        mindir_stream, net_dict = _cell_info(net, incremental, *inputs)
    model.ParseFromString(mindir_stream)

    if kwargs.get('dataset'):
        check_input_data(kwargs.get('dataset'), data_class=mindspore.dataset.Dataset)
        dataset = kwargs.get('dataset')
        _save_dataset_to_mindir(model, dataset)

    custom_func = kwargs.get('custom_func', None)
    if custom_func is not None:
        custom_func(model)

    save_together = _save_together(net_dict, model)
    is_encrypt = lambda: 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys()
    if save_together:
        _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs)
    else:
        _split_save(net_dict, model, file_name, is_encrypt, **kwargs)

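# Decision flow above, in short: if the serialized parameters fit under the
# TOTAL_SAVE threshold, the graph and weights go into a single .mindir file;
# otherwise _split_save writes "<prefix>_graph.mindir" plus external data
# files. A typical encrypted export call might look like (sketch; key and
# mode values are placeholders):
#
#     ms.export(net, Tensor(input_np), file_name="net", file_format="MINDIR",
#               enc_key=b"0123456789ABCDEF", enc_mode="AES-GCM")
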

def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
    """Save graph and parameters together."""
    for param_proto in model.graph.parameter:
        param_name = param_proto.name[param_proto.name.find(":") + 1:]
        if param_name in net_dict.keys():
            param_data = net_dict[param_name].data.get_bytes()
            param_proto.raw_data = param_data
        else:
            raise ValueError("The parameter '{}' does not belong to any cell, "
                             "the data of the parameter cannot be exported.".format(param_proto.name))
    incremental = kwargs.get('incremental', False)
    for map_param_proto in model.graph.map_parameter:
        map_param_name = map_param_proto.name[map_param_proto.name.find(":") + 1:]
        if map_param_name in net_dict.keys():
            map_parameter = net_dict[map_param_name]
            key_bytes, value_bytes, status_bytes = map_parameter.export_bytes(incremental)
            map_param_proto.key_tensor.raw_data = key_bytes
            map_param_proto.value_tensor.raw_data = value_bytes
            map_param_proto.status_tensor.raw_data = status_bytes
        else:
            raise ValueError("The map_parameter '{}' does not belong to any cell, "
                             "the data of the parameter cannot be exported.".format(map_param_proto.name))
    if not file_name.endswith('.mindir'):
        file_name += ".mindir"
    current_path = os.path.abspath(file_name)
    dirname = os.path.dirname(current_path)
    os.makedirs(dirname, exist_ok=True)
    if os.path.exists(file_name):
        os.chmod(file_name, stat.S_IWUSR)
    with open(file_name, 'wb') as f:
        os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
        model_string = model.SerializeToString()
        if is_encrypt():
            if callable(kwargs.get('enc_mode')):
                enc_func = kwargs.get('enc_mode')
                model_string = enc_func(model_string, kwargs.get('enc_key'))
            else:
                model_string = _encrypt(model_string, len(model_string), kwargs.get('enc_key'),
                                        len(kwargs.get('enc_key')), kwargs.get('enc_mode'))
        f.write(model_string)
        os.chmod(file_name, stat.S_IRUSR)


def _save_together(net_dict, model):
    """Determine whether the graph and parameters can be saved together when saving a MindIR model."""
    data_total = 0
    for param_proto in model.graph.parameter:
        name = param_proto.name[param_proto.name.find(":") + 1:]
        if name in net_dict.keys():
            data_total += sys.getsizeof(net_dict[name].data.get_bytes()) / 1024
        else:
            raise ValueError("The parameter '{}' does not belong to any cell, "
                             "the data of the parameter cannot be exported.".format(param_proto.name))
        if data_total > TOTAL_SAVE:
            return False
    return True


def _save_dataset_to_mindir(model, dataset):
    """Save dataset preprocess operations into the MindIR model."""
    dataset_json = dataset.to_json()
    reverse_dataset = []
    while dataset_json:
        reverse_dataset = [dataset_json] + reverse_dataset
        if len(dataset_json['children']) > 1:
            logger.warning("A dataset node with more than one child is not supported yet, using child 0 by default.")
        dataset_json = dataset_json['children'][0] if dataset_json['children'] else []

    for op in reverse_dataset:
        if op['op_type'] == 'Map':
            model.preprocessor.op.add()
            model.preprocessor.op[-1].input_columns = json.dumps(op['input_columns'])
            model.preprocessor.op[-1].output_columns = json.dumps(op['output_columns'])
            model.preprocessor.op[-1].op_type = json.dumps(op['op_type'])
            model.preprocessor.op[-1].operations = json.dumps(op['operations'])
            model.preprocessor.op[-1].offload = op['offload'] if 'offload' in op.keys() else False

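# Sketch of exporting a model together with its dataset preprocessing (the
# pipeline below is illustrative; any mindspore.dataset pipeline whose Map
# operations should ship with the model would do):
#
#     import mindspore.dataset as ds
#     import mindspore.dataset.vision as vision
#
#     dataset = ds.ImageFolderDataset("./data")
#     dataset = dataset.map(operations=[vision.Resize((32, 32))], input_columns="image")
#     ms.export(net, inputs, file_name="net", file_format="MINDIR", dataset=dataset)
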

def check_checkpoint(ckpt_file_name):
    """
    Check whether the checkpoint is valid.

    Args:
        ckpt_file_name (str): Checkpoint file name.

    Returns:
        bool, whether the checkpoint is valid.

    Examples:
        >>> import mindspore as ms
        >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
        >>> check_result = ms.check_checkpoint(ckpt_file_name)
        >>> print(check_result)
        True
    """
    if not ckpt_file_name.endswith('.ckpt'):
        return False
    checkpoint_list = Checkpoint()
    with _ckpt_fs.open(ckpt_file_name, *_ckpt_fs.open_args) as f:
        pb_content = f.read()
        # Optional 17-byte trailer: b"crc_num" followed by a 10-byte big-endian CRC32.
        if pb_content[-17:-10] == b"crc_num":
            crc_num_bytes = pb_content[-10:]
            pb_content = pb_content[:-17]
            crc_num = int.from_bytes(crc_num_bytes, byteorder='big')
            cal_crc_num = binascii.crc32(pb_content, 0)
            if cal_crc_num != crc_num:
                logger.warning("For 'check_checkpoint', the checkpoint CRC check failed.")
                return False
        try:
            checkpoint_list.ParseFromString(pb_content)
        except google.protobuf.message.DecodeError as e:
            logger.warning("For 'check_checkpoint', failed to parse the checkpoint file.")
            logger.warning(e)
            return False
        return True

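# Writer-side sketch of the CRC trailer verified above (illustrative only;
# ``content`` stands for the serialized Checkpoint protobuf bytes):
#
#     crc = binascii.crc32(content, 0)
#     content += b"crc_num" + crc.to_bytes(10, byteorder='big')
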

def parse_print(print_file_name):
    """
    Parse data file generated by :class:`mindspore.ops.Print`.

    Args:
        print_file_name (str): Name of the file to be parsed.

    Returns:
        List, each element of the list is a Tensor.

    Raises:
        ValueError: The print file does not exist or is empty.
        RuntimeError: Failed to parse the file.

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore import nn, Tensor, ops
        >>> ms.set_context(mode=ms.GRAPH_MODE, print_file_path='log.data')
        >>> class PrintInputTensor(nn.Cell):
        ...         def __init__(self):
        ...             super().__init__()
        ...             self.print = ops.Print()
        ...
        ...         def construct(self, input_pra):
        ...             self.print('print:', input_pra)
        ...             return input_pra
        >>> x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]).astype(np.float32)
        >>> input_pra = Tensor(x)
        >>> net = PrintInputTensor()
        >>> net(input_pra)
        >>>
        >>> data = ms.parse_print('./log.data')
        >>> print(data)
        ['print:', Tensor(shape=[2, 4], dtype=Float32, value=
        [[ 1.00000000e+00,  2.00000000e+00,  3.00000000e+00,  4.00000000e+00],
        [ 5.00000000e+00,  6.00000000e+00,  7.00000000e+00,  8.00000000e+00]])]
    """
    print_file_path = os.path.abspath(print_file_name)

    if os.path.getsize(print_file_path) == 0:
        raise ValueError("For 'parse_print', the print file may be empty, please make sure you enter the correct "
                         "'print_file_name'.")

    logger.info("Execute load print process.")
    print_list = Print()

    try:
        with open(print_file_path, "rb") as f:
            pb_content = f.read()
        print_list.ParseFromString(pb_content)
    except BaseException as e:
        logger.critical("Failed to read the print file %s, please check whether the file is "
                        "correct.", print_file_name)
        raise ValueError(e.__str__() + "\nFailed to read the print file {}, please check whether "
                                       "the file is correct.".format(print_file_name)) from e

    tensor_list = []

    try:
        for print_ in print_list.value:
            # String type
            if print_.HasField("desc"):
                tensor_list.append(print_.desc)
            elif print_.HasField("tensor"):
                dims = print_.tensor.dims
                data_type = print_.tensor.tensor_type
                data = print_.tensor.tensor_content
                np_type = tensor_to_np_type.get(data_type)
                param_data = np.frombuffer(data, np_type)
                ms_type = tensor_to_ms_type.get(data_type)
                if dims and dims != [0]:
                    param_value = param_data.reshape(dims)
                    tensor_list.append(Tensor(param_value, ms_type))
                # Scalar type
                else:
                    data_type_ = data_type.lower()
                    if 'float' in data_type_:
                        param_data = float(param_data[0])
                    elif 'int' in data_type_:
                        param_data = int(param_data[0])
                    elif 'bool' in data_type_:
                        param_data = bool(param_data[0])
                    tensor_list.append(Tensor(param_data, ms_type))

    except BaseException as e:
        logger.critical("Failed to load the print file %s.", print_file_name)
        raise RuntimeError(e.__str__() + "\nFailed to load the print file {}.".format(print_file_name)) from e

    return tensor_list


def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
    """
    Merge data slices to one tensor with whole data when strategy is not None.

    Args:
        sliced_data (list[numpy.ndarray]): Data slices in order of rank_id.
        parameter_name (str): Name of parameter.
        strategy (dict): Parameter slice strategy.
        is_even (bool): Slice manner that True represents slicing evenly and False represents slicing unevenly.

    Returns:
        Tensor, the merged Tensor which has the whole data.

    Raises:
        ValueError: Failed to merge.
    """
    layout = strategy.get(parameter_name)
    try:
        dev_mat = list(layout.dev_matrix[0].dim)
        tensor_map = list(layout.tensor_map[0].dim)
        param_split_shape = list(layout.param_split_shape[0].dim)
        field_size = int(layout.field)
    except BaseException as e:
        raise ValueError(f"{e.__str__()}. For 'merge_sliced_parameter'"
                         f", please make sure that 'strategy' is correct.") from e

    device_count = 1
    for dim in dev_mat:
        device_count *= dim

    if len(sliced_data) != device_count:
        raise ValueError(f"For 'merge_sliced_parameter', the length of 'sliced_parameters' should be equal to "
                         f"device_count. The length of 'sliced_parameters' is {len(sliced_data)}, but "
                         f"device_count is {device_count}.")

    if not param_split_shape:
        if not is_even:
            raise ValueError("For 'merge_sliced_parameter', the shape of every parameter in 'sliced_parameters' "
                             "should be the same when the slice manner is even.")

        all_gather_tensor = Tensor(np.concatenate(sliced_data))

        if field_size > 0:
            merged_tensor = _reshape_param_data_with_weight(all_gather_tensor, dev_mat, field_size)
        else:
            merged_tensor = _reshape_param_data(all_gather_tensor, dev_mat, tensor_map)

    else:
        tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)

        slice_count = 1
        for dim in tensor_strategy:
            slice_count *= dim

        if len(param_split_shape) != slice_count:
            raise ValueError(f"For 'merge_sliced_parameter', the param_split_shape length in 'strategy' should be "
                             f"{slice_count}, but got {len(param_split_shape)}.")

        # Reorder the slices from rank order into slice-index order.
        tensor_slices_new = list(range(slice_count))
        tensor_slices = sliced_data
        for i in range(device_count):
            slice_index = int(_get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, i))
            if tensor_slices[i].shape[0] != param_split_shape[slice_index]:
                raise ValueError(f"For 'merge_sliced_parameter', the slice {slice_index} should be "
                                 f"{param_split_shape[slice_index]} in 0 axis, but got "
                                 f"{tensor_slices[i].shape[0]}.")
            tensor_slices_new[slice_index] = np.array(tensor_slices[i])

        # Concatenate the slices axis by axis, starting from the last axis.
        dim_len = len(tensor_strategy)
        for i in range(dim_len):
            ele_count = int(len(tensor_slices_new) / tensor_strategy[dim_len - 1 - i])
            tensor_slices_new_inner = []
            for j in range(ele_count):
                new_tensor = tensor_slices_new[j * tensor_strategy[dim_len - 1 - i]]
                for k in range(j * tensor_strategy[dim_len - 1 - i] + 1,
                               (j + 1) * tensor_strategy[dim_len - 1 - i]):
                    new_tensor = np.concatenate((new_tensor, tensor_slices_new[k]), axis=dim_len - 1 - i)
                tensor_slices_new_inner.append(np.array(new_tensor))
            tensor_slices_new = tensor_slices_new_inner
        merged_tensor = Tensor(tensor_slices_new[0])

    return merged_tensor

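# A concrete sketch of the layout vocabulary used above (values illustrative):
# with dev_mat = [2, 2] (4 devices) and tensor_map = [1, -1], a (4, 6)
# parameter is split into 2 slices along axis 0 (map entry 1 picks the
# dev_mat dimension of size 2) and replicated along axis 1 (-1 means "not
# split"), so tensor_strategy == [2, 1] and the merge concatenates the two
# distinct (2, 6) slices back into the full (4, 6) tensor.
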

def restore_group_info_list(group_info_file_name):
    """
    Build the rank list: the checkpoints of the ranks in this list have the same contents as the local rank
    that saved the `group_info_file_name`. To save the group info file, set the GROUP_INFO_FILE
    environment variable, e.g. "export GROUP_INFO_FILE=/data/group_info.pb".

    Args:
        group_info_file_name (str): Name of group information file.

    Returns:
        List, the rank list.

    Raises:
        ValueError: group information file is incorrect.
        TypeError: `group_info_file_name` is not str.

    Examples:
        >>> import mindspore as ms
        >>> restore_list = ms.restore_group_info_list("./group_info.pb")
    """
    if not isinstance(group_info_file_name, str):
        raise TypeError(f"For 'restore_group_info_list', the argument 'group_info_file_name' should be str, "
                        f"but got {type(group_info_file_name)}.")

    if not os.path.isfile(group_info_file_name):
        raise ValueError(f"For 'restore_group_info_list', no such group information file: {group_info_file_name}.")

    if os.path.getsize(group_info_file_name) == 0:
        raise ValueError("For 'restore_group_info_list', the group information file should not be empty.")

    return _restore_group_info_list(group_info_file_name)


def build_searched_strategy(strategy_filename):
    """
    Build the strategy of every parameter in the network. Used in the case of distributed inference.

    Args:
        strategy_filename (str): Name of strategy file.

    Returns:
        Dict, whose key is parameter name and value is slice strategy of this parameter.

    Raises:
        ValueError: Strategy file is incorrect.
        TypeError: `strategy_filename` is not a string.

    Examples:
        >>> import mindspore as ms
        >>> strategy = ms.build_searched_strategy("./strategy_train.ckpt")
    """
    return _build_searched_strategy(strategy_filename)


def merge_sliced_parameter(sliced_parameters, strategy=None):
    """
    Merge parameter slices into one parameter. Used in the case of distributed inference.

    Args:
        sliced_parameters (list[Parameter]): Parameter slices in order of rank id.
        strategy (Optional[dict]): Parameter slice strategy, whose key is parameter name and
            value is slice strategy of this parameter. If strategy is None, just merge
            parameter slices in 0 axis order. Default: ``None``.

    Returns:
        Parameter, the merged parameter which has the whole data.

    Raises:
        ValueError: Failed to merge.
        TypeError: The sliced_parameters is incorrect or strategy is not dict.
        KeyError: The parameter name is not in keys of strategy.

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore import Tensor, Parameter
        >>>
        >>> sliced_parameters = [
        ...                      Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])),
        ...                                "network.embedding_table"),
        ...                      Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])),
        ...                                "network.embedding_table"),
        ...                      Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])),
        ...                                "network.embedding_table"),
        ...                      Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])),
        ...                                "network.embedding_table")]
        >>> merged_parameter = ms.merge_sliced_parameter(sliced_parameters)
        >>> print(merged_parameter)
        Parameter (name=network.embedding_table, shape=(12,), dtype=Float64, requires_grad=True)
    """
    if not isinstance(sliced_parameters, list):
        raise TypeError(f"For 'merge_sliced_parameter', the argument 'sliced_parameters' should be list, "
                        f"but got {type(sliced_parameters)}.")

    if not sliced_parameters:
        raise ValueError("For 'merge_sliced_parameter', the argument 'sliced_parameters' should not be empty.")

    if strategy and not isinstance(strategy, dict):
        raise TypeError(f"For 'merge_sliced_parameter', the argument 'strategy' should be dict, "
                        f"but got {type(strategy)}.")

    try:
        parameter_name = sliced_parameters[0].name
        parameter_shape = sliced_parameters[0].data.shape
        parameter_shape_length = len(parameter_shape)
    except BaseException as e:
        raise TypeError(e.__str__() + f" For 'merge_sliced_parameter', the element in 'sliced_parameters' should be "
                                      f"'Parameter', but got {type(sliced_parameters[0])} at index 0.") from e

    is_even = True
    for index, parameter in enumerate(sliced_parameters):
        if not isinstance(parameter, Parameter):
            raise TypeError(f"For 'merge_sliced_parameter', the element in 'sliced_parameters' should be 'Parameter', "
                            f"but got {type(parameter)} at index {index}.")

        if parameter.name != parameter_name \
                or len(parameter.data.shape) != parameter_shape_length \
                or parameter.data.shape[1:] != parameter_shape[1:]:
            raise ValueError(f"For 'merge_sliced_parameter', please make sure that the elements in 'sliced_parameters'"
                             f" have the same name, dimension length and shape except 0 axis. The name, dimension "
                             f"length, shape except 0 axis should be {parameter_name}, {parameter_shape_length}, "
                             f"{parameter_shape[1:]}, but got name: {parameter.name}, dimension length: "
                             f"{len(parameter.data.shape)}, shape except 0 axis: {parameter.data.shape[1:]} "
                             f"at index {index}.")

        if parameter.data.shape != parameter_shape:
            is_even = False

    layerwise_parallel = sliced_parameters[0].layerwise_parallel
    requires_grad = sliced_parameters[0].requires_grad
    sliced_data = []
    for parameter in sliced_parameters:
        if parameter.data.dtype == mstype.bfloat16:
            sliced_data.append(cpu_cast(parameter.data, mstype.float32).asnumpy())
        else:
            sliced_data.append(parameter.data.asnumpy())

    if not strategy:
        merged_tensor = Tensor(np.concatenate(sliced_data))
        merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)

    else:
        if parameter_name not in strategy.keys():
            raise KeyError(f"For 'merge_sliced_parameter', the parameter name {parameter_name} should be a key in "
                           f"the 'strategy'. Please check 'sliced_parameters' and 'strategy'.")
        merged_tensor = _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even)
        merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)

    return merged_parameter



def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None,
                                train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM'):
    """
    Load checkpoint into net for distributed prediction. Used in the case of distributed inference.

    Args:
        network (Cell): Network for distributed prediction.
        checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
        predict_strategy (dict): Strategy of the prediction process. One device is used for prediction
                                 when predict_strategy is set to None. Default: ``None`` .
        train_strategy_filename (str): The filename of training strategy protocol buffer file.
                                       When train_strategy_filename is None, the training strategy file will be
                                       obtained from context.get_auto_parallel_context("strategy_ckpt_load_file").
                                       Therefore, the training strategy file needs to be specified
                                       in at least one of them. Default: ``None`` .
        strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
                            into net when parameter name's suffix in checkpoint file is the same as the
                            parameter in the network. When the types are inconsistent, perform type conversion
                            on the parameters of the same type, such as float32 to float16. Default: ``False`` .
        dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
                                      is not required. Default: ``None`` .
        dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption
                        mode, currently supports ``'AES-GCM'`` , ``'AES-CBC'``  and ``'SM4-CBC'`` .
                        Default: ``'AES-GCM'`` .

    Raises:
        TypeError: The type of inputs do not match the requirements.
        ValueError: Failed to load checkpoint into net.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
            Please see the `rank table startup
            <https://www.mindspore.cn/tutorials/experts/en/master/parallel/rank_table.html>`_
            for more details.

            For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
            <https://www.mindspore.cn/tutorials/experts/en/master/parallel/mpirun.html>`_ .

            For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
            Startup <https://www.mindspore.cn/tutorials/experts/en/master/parallel/dynamic_cluster.html>`_ .

        >>> import os
        >>> import numpy as np
        >>> import mindspore as ms
        >>> import mindspore.dataset as ds
        >>> from mindspore import nn, ops, train
        >>> from mindspore.communication import init
        >>>
        >>> step_per_epoch = 4
        >>> device_num = 8
        >>>
        >>> # Define the network structure.
        >>> class Net(nn.Cell):
        ...     def __init__(self, matmul_size, strategy=None):
        ...         super().__init__()
        ...         matmul_np = np.full(matmul_size, 0.5, dtype=np.float32)
        ...         self.matmul_weight = ms.Parameter(ms.Tensor(matmul_np))
        ...         self.matmul = ops.MatMul()
        ...         self.neg = ops.Neg()
        ...         if strategy is not None:
        ...             self.matmul.shard(strategy)
        ...
        ...     def construct(self, inputs):
        ...         x = self.matmul(inputs, self.matmul_weight)
        ...         x = self.neg(x)
        ...         return x
        >>>
        >>> # Create dataset.
        >>> def get_dataset(*inputs):
        ...     def generate():
        ...         for _ in range(step_per_epoch):
        ...             yield inputs
        ...     return generate
        >>>
        >>> # Train network and save distributed checkpoint.
        >>> def train_net():
        ...     ms.set_context(mode=ms.GRAPH_MODE)
        ...     init()
        ...     np.random.seed(1)
        ...     input_data = np.random.rand(16, 96).astype(np.float32)
        ...     label_data = np.random.rand(16, 16).astype(np.float32)
        ...     fake_dataset = get_dataset(input_data, label_data)
        ...     dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
        ...
        ...     # Set parallel strategy.
        ...     strategy = ((1, 4), (4, 1))
        ...     ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num,
        ...                                  strategy_ckpt_save_file="./train_strategy.ckpt")
        ...     network = Net(matmul_size=(96, 16), strategy=strategy)
        ...     net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
        ...     net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
        ...     model = ms.Model(network=network, loss_fn=net_loss, optimizer=net_opt)
        ...     ckpt_config = train.CheckpointConfig(keep_checkpoint_max=1, integrated_save=False)
        ...     global_rank_id = int(os.getenv("RANK_ID"))
        ...     ckpt_path = "./rank_{}_ckpt".format(global_rank_id)
        ...     ckpt_callback = train.ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config)
        ...     model.train(epoch=2, train_dataset=dataset, callbacks=[ckpt_callback], dataset_sink_mode=False)
        ...     ms.reset_auto_parallel_context()
        >>>
        >>> # Load distributed checkpoint and test.
        >>> def load_model():
        ...     ms.set_context(mode=ms.GRAPH_MODE)
        ...     init()
        ...     ms.set_auto_parallel_context(full_batch=True, parallel_mode="semi_auto_parallel",
        ...                                  strategy_ckpt_load_file="./train_strategy.ckpt", device_num=device_num)
        ...     predict_data = ms.Tensor(np.random.randn(128, 96).astype(np.float32))
        ...     network = Net(matmul_size=(96, 16))
        ...     model = ms.Model(network)
        ...     predict_layout = model.infer_predict_layout(ms.Tensor(predict_data))
        ...     ckpt_file_list = ["./rank_{}_ckpt/parallel-2_4.ckpt".format(i) for i in range(0, device_num)]
        ...     ms.load_distributed_checkpoint(network, ckpt_file_list, predict_layout)
        ...     predict_result = model.predict(predict_data)
        ...     print(predict_result)
        >>>
        >>> train_net()
        >>> load_model()
        [[-7.3259363 -7.497216  -7.398196  ... -7.374962  -7.204874  -7.234935 ]
        [ 3.362938   3.3535435  3.3832688 ...  3.4263954  3.279045   3.3202887]
        ...
        [ 1.6067538  1.6244187  1.5384722 ...  1.5449994  1.6195512  1.6176052]]
    """
    network = Validator.check_isinstance("network", network, nn.Cell)
    _check_checkpoint_file(checkpoint_filenames)
    _check_predict_strategy(predict_strategy)

    dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
    dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)

    if train_strategy_filename is None:
        train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
    _train_strategy = build_searched_strategy(train_strategy_filename)
    train_strategy = _convert_to_list(_train_strategy)

    train_dev_count = 1
    ckpt_file_len = len(checkpoint_filenames)
    for dim in train_strategy[list(train_strategy.keys())[0]][0]:
        train_dev_count *= dim
    if train_dev_count != ckpt_file_len:
        raise ValueError(f"For 'load_distributed_checkpoint', the length of 'checkpoint_filenames' should be "
                         f"equal to the device count of training process. "
                         f"But got the length of 'checkpoint_filenames'"
                         f" is {ckpt_file_len} and the device count is {train_dev_count}.")
    rank_list = _infer_rank_list(train_strategy, predict_strategy)

    param_total_dict = defaultdict(dict)
    for file_index, file_name in enumerate(checkpoint_filenames):
        ckpt_dict = load_checkpoint(file_name, dec_key=dec_key, dec_mode=dec_mode)
        for param_name, param in ckpt_dict.items():
            param_total_dict[param_name][file_index] = param

    param_dict = {}
    param_not_in_strategy = []
    param_not_in_ckpt = []
    for _, param in network.parameters_and_names():
        sliced_params = []
        if param.name not in rank_list.keys():
            param_not_in_strategy.append(param.name)
            continue
        if param.name not in param_total_dict:
            param_not_in_ckpt.append(param.name)
            continue

        param_rank = rank_list.get(param.name)[0]
        skip_merge_split = rank_list.get(param.name)[1]
        shard_stride = train_strategy.get(param.name)[4]
        # shard_size is the number of ranks in one optimizer-shard group; use an
        # integer since it is used as a range step and slice length below.
        if train_strategy.get(param.name)[5]:
            shard_size = int(ckpt_file_len / shard_stride / train_strategy.get(param.name)[5])
        else:
            shard_size = 0
        for rank in param_rank:
            param_total_list = list(range(0, ckpt_file_len))
            if shard_size > 0:
                shard_total_list = [param_total_list[i:i + shard_size]
                                    for i in range(0, ckpt_file_len, shard_size)]
                param_total_list = shard_total_list[rank // shard_size]
            if shard_stride > 0:
                param_stride = []
                # Collect the ranks that hold the strided slices of this parameter.
                param_index = param_total_list[0:param_total_list.index(rank) + 1][::-1][::shard_stride]
                param_index.extend(param_total_list[param_total_list.index(rank):][::shard_stride])
                param_index = list(set(param_index))
                param_index.sort()
                for rank_num in param_index:
                    if param_total_dict[param.name][rank_num].data.dtype == mstype.bfloat16:
                        param_stride.append(
                            cpu_cast(param_total_dict[param.name][rank_num].data, mstype.float32).asnumpy())
                    else:
                        param_stride.append(param_total_dict[param.name][rank_num].data.asnumpy())

                sliced_param = Parameter(Tensor(np.concatenate(param_stride)), name=param.name)
            else:
                sliced_param = param_total_dict[param.name][rank]

            sliced_params.append(sliced_param)
        if skip_merge_split:
            split_param = sliced_params[0]
        else:
            param_unique_strategy = _remove_repeated_slices(train_strategy[param.name])
            _param_unique_strategy = _convert_to_layout(param.name, param_unique_strategy)
            split_param = _merge_and_split(sliced_params, _param_unique_strategy, predict_strategy)
        opt_shard_group = predict_strategy[param.name][5] if predict_strategy else None
        if opt_shard_group:
            if split_param.data.dtype == mstype.bfloat16:
                data = cpu_cast(split_param.data, mstype.float32).asnumpy()
            else:
                data = split_param.data.asnumpy()
            rank = get_rank(opt_shard_group)
            size = get_group_size(opt_shard_group)
            try:
                data_slice = np.split(data, size)[rank]
            except BaseException as e:
                logger.critical("Failed to load opt shard slice in load distributed checkpoint for {}. Data shape is "
                                "{} and group is {}".format(param.name, split_param.data.shape, opt_shard_group))
                raise RuntimeError(e.__str__() + f"\nFor 'load_distributed_checkpoint', failed to load opt shard slice"
                                                 f" in load distributed checkpoint for {param.name}. Data shape is "
                                                 f"{split_param.data.shape} and group is {opt_shard_group}.") from e
            split_param = Parameter(Tensor(data_slice), param.name,
                                    split_param.requires_grad, split_param.layerwise_parallel)
        param_dict[param.name] = split_param

    if param_not_in_strategy:
        logger.warning("For 'load_distributed_checkpoint', {} parameters in network are not in the slice strategy, "
                       "you can check whether 'predict_strategy' or 'train_strategy_filename' is correct."
                       .format(param_not_in_strategy))
    if param_not_in_ckpt:
        logger.warning("For 'load_distributed_checkpoint', {} parameters in network and slice strategy but not in "
                       "the checkpoint file, please check whether 'checkpoint_filenames' is correct."
                       .format(param_not_in_ckpt))

    load_param_into_net(network, param_dict, strict_load=strict_load)


def async_ckpt_thread_status():
    """
    Get the status of the asynchronous save checkpoint thread.

    When saving a checkpoint asynchronously, this can be used to determine whether the asynchronous
    thread has completed.

    Returns:
        bool, ``True`` if the asynchronous save checkpoint thread is running, otherwise ``False``.

    Examples:
        >>> import mindspore as ms
        >>> ms.async_ckpt_thread_status()
        False
    """
    thr_list = threading.enumerate()
    return any(ele.name == "asyn_save_ckpt" for ele in thr_list)


def _check_predict_strategy(predict_strategy):
    """Check predict strategy."""

    def _check_int_list(arg):
        if not isinstance(arg, list):
            return False
        for item in arg:
            if not isinstance(item, int):
                return False
        return True

    if predict_strategy is None:
        return

    flag = True
    predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict)
    for key in predict_strategy.keys():
        if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \
                or len(predict_strategy[key]) < 4:
            flag = False
            break
        dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4]
        if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \
                not (_check_int_list(param_split_shape) or not param_split_shape) or \
                not (isinstance(field_size, int) and field_size == 0):
            flag = False

    if not flag:
        raise ValueError(f"For 'load_distributed_checkpoint', the argument 'predict_strategy' is dict, "
                         f"the key of it must be string, and the value of it must be list or tuple that "
                         f"the first four elements must be dev_matrix (list[int]), tensor_map (list[int]), "
                         f"param_split_shape (list[int]) and field_size (int, whose value is 0). "
                         f"Please check whether 'predict_strategy' is correct.")

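# Shape of a well-formed ``predict_strategy`` entry checked above (values are
# illustrative only); the first four elements are dev_matrix, tensor_map,
# param_split_shape and field_size:
#
#     predict_strategy = {
#         "network.embedding_table": ([2, 4], [1, -1], [], 0),
#     }
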

def _check_checkpoint_file(checkpoint_filenames):
    """Check checkpoint file name."""
    for index, filename in enumerate(checkpoint_filenames):
        if not isinstance(filename, str) or not os.path.exists(filename) \
                or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0:
            raise ValueError(f"For 'load_distributed_checkpoint', please check 'checkpoint_filenames', and "
                             f"make sure the {filename} at index {index} is a valid checkpoint file, it must "
                             f"be a string ending with '.ckpt', and the checkpoint file it represents must "
                             f"exist and not be empty.")


def _merge_and_split(sliced_params, train_strategy, predict_strategy):
    """Merge sliced parameter and split it according to the predict strategy."""
    merged_param = merge_sliced_parameter(sliced_params, train_strategy)
    if predict_strategy is None:
        return merged_param
    param_name = merged_param.name
    tensor_layout = predict_strategy[param_name]
    rank = get_rank()
    split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank_id=rank)
    requires_grad = merged_param.requires_grad
    layerwise_parallel = merged_param.layerwise_parallel
    if merged_param.data.dtype == mstype.bfloat16:
        split_param = Parameter(Tensor(split_tensor, mstype.bfloat16), param_name, requires_grad, layerwise_parallel)
    else:
        split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
    return split_param


def _calculation_net_size(net):
    """Calculate the size of parameters in the network."""
    data_total = 0
    net_dict = net.parameters_dict()
    for name in net_dict:
        data_total += sys.getsizeof(net_dict[name].data.get_bytes()) / 1024

    return data_total


def _get_mindir_inputs(file_name):
    """
    Get MindIR file's inputs.

    Note:
        1. Parsing encrypted MindIR file is not supported.
        2. Parsing dynamic shape MindIR file is not supported.

    Args:
        file_name (str): MindIR file name.

    Returns:
        Tensor, list(Tensor), the input of MindIR file.

    Raises:
        TypeError: If the parameter file_name is not `str`.
        RuntimeError: MindIR's input is not tensor type or has no dims.

    Examples:
        >>> input_tensor = _get_mindir_inputs("lenet.mindir")
    """
    Validator.check_file_name_by_regular(file_name)
    file_name = os.path.abspath(file_name)
    model = read_proto(file_name)
    input_tensor = []

    for ele_input in model.graph.input:
        input_shape = []
        if not hasattr(ele_input, "tensor") or not hasattr(ele_input.tensor[0], "dims"):
            raise RuntimeError("MindIR's input has no tensor or the tensor has no dims, please check the MindIR file.")

        for ele_shape in ele_input.tensor[0].dims:
            input_shape.append(ele_shape)
        if is_shape_unknown(input_shape):
            raise RuntimeError(f"MindIR input's shape is: {input_shape}, dynamic shape is not supported.")

        mindir_type = ele_input.tensor[0].data_type
        if mindir_type not in mindir_to_tensor_type:
            raise RuntimeError(f"MindIR input's type: {mindir_type} is not supported.")

        input_type = mindir_to_tensor_type.get(mindir_type)
        input_tensor.append(Tensor(shape=input_shape, dtype=input_type, init=One()))

    if not input_tensor:
        logger.warning("The MindIR model has no input, return None.")
        return None
    return input_tensor[0] if len(input_tensor) == 1 else input_tensor


def convert_model(mindir_file, convert_file, file_format):
    """
    Convert a MindIR model to a model of another format. The current version only supports conversion to ONNX models.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Args:
        mindir_file (str): MindIR file name.
        convert_file (str): Name of the converted model file.
        file_format (str): Format of the converted model, current version only supports "ONNX".

    Raises:
        TypeError: If the parameter `mindir_file` is not `str`.
        TypeError: If the parameter `convert_file` is not `str`.
        ValueError: If the parameter `file_format` is not "ONNX".

    Examples:
        >>> import mindspore as ms
        >>> ms.convert_model("lenet.mindir", "lenet.onnx", "ONNX")
    """
    Validator.check_file_name_by_regular(mindir_file)
    Validator.check_file_name_by_regular(convert_file)
    if file_format != "ONNX":
        raise ValueError(f"For 'convert_model', 'file_format' must be 'ONNX', but got {file_format}.")
    net_input = _get_mindir_inputs(mindir_file)
    graph = load(mindir_file)
    net = nn.GraphCell(graph)
    if isinstance(net_input, Tensor):
        export(net, net_input, file_name=convert_file, file_format=file_format)
    else:
        export(net, *net_input, file_name=convert_file, file_format=file_format)