# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Model and parameters serialization."""
import os
import sys
import stat
import math
import shutil
import time
import copy
import json
import threading
from threading import Thread, Lock
from collections import defaultdict

import numpy as np

import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore import log as logger
from mindspore.train.checkpoint_pb2 import Checkpoint
from mindspore.train.print_pb2 import Print
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap, ParallelLayouts
from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
from mindspore.train.mind_ir_pb2 import GraphProto as graph_proto
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.common.api import _cell_graph_executor as _executor
from mindspore.common import dtype as mstype
from mindspore._checkparam import check_input_data, Validator
from mindspore.compression.export import quant_export
from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices
from mindspore.communication.management import get_rank, get_group_size
from mindspore.parallel._tensor import _reshape_param_data_with_weight
from mindspore.parallel._cell_wrapper import get_allgather_cell
from mindspore.parallel._tensor import _reshape_param_data
from .._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file


tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
                     "Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64,
                     "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64,
                     "Bool": mstype.bool_}

tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uint16": np.uint16,
                     "Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
                     "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}

_ckpt_mutex = Lock()

# unit is KB
SLICE_SIZE = 512 * 1024
PROTO_LIMIT_SIZE = 1024 * 1024 * 2
TOTAL_SAVE = 1024 * 1024

def _special_process_par(par, new_par):
    """
    Processes the special condition.

    Like (12,2048,1,1)->(12,2048), this case is caused by GE's 4-dimensional tensor format.
    """
    par_shape_len = len(par.data.shape)
    new_par_shape_len = len(new_par.data.shape)
    if new_par_shape_len <= par_shape_len:
        return False

    for i in range(new_par_shape_len - par_shape_len):
        if new_par.data.shape[par_shape_len + i] != 1:
            return False

    new_val = new_par.data.asnumpy()
    new_val = new_val.reshape(par.data.shape)
    par.set_data(Tensor(new_val, par.data.dtype))
    return True

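# Editorial sketch, not part of the original module: it illustrates the shape
# relationship `_special_process_par` handles, using plain numpy. The shapes
# and values below are made up for demonstration.
def _demo_special_process_par():
    """Sketch: a (12, 2048, 1, 1) checkpoint value collapses onto a (12, 2048) net parameter."""
    new_val = np.ones((12, 2048, 1, 1), np.float32)   # GE-style 4-D tensor from the checkpoint
    net_shape = (12, 2048)                            # shape expected by the network parameter
    # Every trailing dim beyond the net shape must be 1 for the squeeze to be legal.
    assert all(d == 1 for d in new_val.shape[len(net_shape):])
    squeezed = new_val.reshape(net_shape)
    return squeezed.shape                             # (12, 2048)

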
def _update_param(param, new_param, strict_load):
    """Updates param's data from new_param's data."""
    if isinstance(param.data, Tensor) and isinstance(new_param.data, Tensor):
        if param.data.shape != new_param.data.shape:
            if not _special_process_par(param, new_param):
                logger.error("Failed to combine the net and the parameters for param %s.", param.name)
                msg = ("Net parameters {} shape({}) is different from parameter_dict's({})."
                       .format(param.name, param.data.shape, new_param.data.shape))
                raise RuntimeError(msg)

        if param.data.dtype != new_param.data.dtype:
            if _type_convert(param, new_param, strict_load):
                new_tensor = Tensor(new_param.data.asnumpy(), param.data.dtype)
                param.set_data(new_tensor)
                return

            logger.error("Failed to combine the net and the parameters for param %s.", param.name)
            msg = ("Net parameters {} type({}) is different from parameter_dict's({})."
                   .format(param.name, param.data.dtype, new_param.data.dtype))
            raise RuntimeError(msg)

        param.set_data(new_param.data, param.sliced)
        return

    if isinstance(param.data, Tensor) and not isinstance(new_param.data, Tensor):
        if param.data.shape != (1,) and param.data.shape != ():
            logger.error("Failed to combine the net and the parameters for param %s.", param.name)
            msg = ("Net parameters {} shape({}) is not (1,) or (), inconsistent with parameter_dict's (scalar)."
                   .format(param.name, param.data.shape))
            raise RuntimeError(msg)
        param.set_data(initializer(new_param.data, param.data.shape, param.data.dtype))

    elif isinstance(new_param.data, Tensor) and not isinstance(param.data, Tensor):
        logger.error("Failed to combine the net and the parameters for param %s.", param.name)
        msg = ("Net parameters {} type({}) is different from parameter_dict's({})."
               .format(param.name, type(param.data), type(new_param.data)))
        raise RuntimeError(msg)

    else:
        param.set_data(type(param.data)(new_param.data))

def _type_convert(param, new_param, strict_load):
    """Checks whether the parameter's type should be converted when loading a checkpoint into the network."""
    float_type = (mstype.float16, mstype.float32, mstype.float64)
    int_type = (mstype.int8, mstype.int16, mstype.int32, mstype.int64)
    if not strict_load and ({param.data.dtype, new_param.data.dtype}.issubset(float_type) or
                            {param.data.dtype, new_param.data.dtype}.issubset(int_type)):
        logger.warning("ckpt_dict parameter: {}'s type is {}, convert to {} in the network."
                       .format(new_param.name, new_param.data.dtype, param.data.dtype))
        return True
    return False

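# Editorial sketch, not part of the original module: `_type_convert` only
# permits conversion inside the same numeric family, which the pair of checks
# below illustrates with concrete dtypes.
def _demo_type_convert_families():
    """Sketch: float16 <-> float32 is convertible, float32 <-> int32 is not."""
    float_type = (mstype.float16, mstype.float32, mstype.float64)
    int_type = (mstype.int8, mstype.int16, mstype.int32, mstype.int64)
    same_family = {mstype.float16, mstype.float32}.issubset(float_type)       # True
    cross_family = {mstype.float32, mstype.int32}.issubset(float_type) or \
                   {mstype.float32, mstype.int32}.issubset(int_type)          # False
    return same_family, cross_family

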
def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
    """Execute the process of saving checkpoint into file."""
    try:
        with _ckpt_mutex:
            if os.path.exists(ckpt_file_name):
                os.remove(ckpt_file_name)
            with open(ckpt_file_name, "ab") as f:
                if enc_key is not None:
                    plain_data = bytes(0)
                    cipher_data = bytes(0)

                for name, value in data_list.items():
                    data_size = value[2].nbytes / 1024
                    if data_size > SLICE_SIZE:
                        slice_count = math.ceil(data_size / SLICE_SIZE)
                        param_slice_list = np.array_split(value[2], slice_count)
                    else:
                        param_slice_list = [value[2]]

                    for param_slice in param_slice_list:
                        checkpoint_list = Checkpoint()
                        param_value = checkpoint_list.value.add()
                        param_value.tag = name
                        param_tensor = param_value.tensor
                        param_tensor.dims.extend(value[0])
                        param_tensor.tensor_type = value[1]
                        param_tensor.tensor_content = param_slice.tobytes()

                        if enc_key is None:
                            f.write(checkpoint_list.SerializeToString())
                        else:
                            plain_data += checkpoint_list.SerializeToString()

                            max_block_size = SLICE_SIZE*1024
                            while len(plain_data) >= max_block_size:
                                cipher_data += _encrypt(plain_data[0: max_block_size], max_block_size, enc_key,
                                                        len(enc_key), enc_mode)
                                plain_data = plain_data[max_block_size:]

                if enc_key is not None:
                    if plain_data:
                        cipher_data += _encrypt(plain_data, len(plain_data), enc_key, len(enc_key), enc_mode)
                    f.write(cipher_data)

        os.chmod(ckpt_file_name, stat.S_IRUSR)

    except BaseException as e:
        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
        raise e

def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
                    async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM"):
    """
    Save checkpoint to a specified file.

    Args:
        save_obj (Union[Cell, list]): The cell object or a data list (each element is a dictionary, like
                                      [{"name": param_name, "data": param_data},...]; param_name must be
                                      a string and param_data must be a Parameter or Tensor).
        ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
        integrated_save (bool): Whether to merge and save the split Tensor in the automatic model parallel
                                scenario. Default: True.
        async_save (bool): Whether to open an independent thread to save the checkpoint file. Default: False.
        append_dict (dict): Additional information that needs to be saved. The key of dict must be str,
            the value of dict must be one of int, float and bool. Default: None.
        enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption
                                      is not required. Default: None.
        enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption
                        mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.

    Raises:
        TypeError: If the parameter save_obj is not an `nn.Cell` or a list, or if the parameter
                   `integrated_save` or `async_save` is not a bool.

    Examples:
        >>> from mindspore import save_checkpoint
        >>>
        >>> net = Net()
        >>> save_checkpoint(net, "lenet.ckpt")
    """

    if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list):
        raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj)))
    integrated_save = Validator.check_bool(integrated_save)
    async_save = Validator.check_bool(async_save)
    append_dict = _check_append_dict(append_dict)
    enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
    enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)

    logger.info("Execute the process of saving checkpoint files.")

    if isinstance(save_obj, nn.Cell):
        save_obj.init_parameters_data()
        param_dict = {}
        for _, param in save_obj.parameters_and_names():
            param_dict[param.name] = param
        param_list = []
        for (key, value) in param_dict.items():
            each_param = {"name": key}
            param_data = Tensor(value.data)

            # in the automatic model parallel scenario, some parameters are split across the devices
            # and should be combined before saving
            if key in save_obj.parameter_layout_dict:
                param_data = _get_merged_param_data(save_obj, key, param_data, integrated_save)

            each_param["data"] = param_data
            param_list.append(each_param)
        save_obj = param_list

    if append_dict:
        append_info_list = []
        for k_name, value in append_dict.items():
            append_info_list.append({"name": k_name, "data": Tensor(value)})
        save_obj.extend(append_info_list)

    data_list = {}
    with _ckpt_mutex:
        for param in save_obj:
            key = param["name"]
            data_list[key] = []
            if isinstance(param["data"], Parameter):
                param["data"].init_data()
            dims = []
            if param['data'].shape == ():
                dims.append(0)
            else:
                for dim in param['data'].shape:
                    dims.append(dim)
            data_list[key].append(dims)
            tensor_type = str(param["data"].dtype)
            data_list[key].append(tensor_type)
            data = param["data"].asnumpy().reshape(-1)
            data_list[key].append(data)

    ckpt_file_name = os.path.realpath(ckpt_file_name)
    if async_save:
        thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list, enc_key, enc_mode), name="asyn_save_ckpt")
        thr.start()
    else:
        _exec_save(ckpt_file_name, data_list, enc_key, enc_mode)

    logger.info("Saving checkpoint process is finished.")

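# Editorial sketch, not part of the original module: besides a Cell,
# `save_checkpoint` accepts a raw list of {"name": ..., "data": ...}
# dictionaries, as its docstring describes. The parameter name, values and
# file name below are made up.
def _demo_save_checkpoint_with_list():
    """Sketch: save a hand-built parameter list with extra metadata appended."""
    params = [{"name": "fc.weight", "data": Tensor(np.ones((2, 2), np.float32))}]
    # append_dict values must be int, float or bool; they are stored as tensors.
    save_checkpoint(params, "demo.ckpt", append_dict={"epoch": 3, "done": False})

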
def _check_param_prefix(filter_prefix, param_name):
    """Checks whether the prefix of parameter name matches the given filter_prefix."""
    for prefix in filter_prefix:
        if param_name.find(prefix) == 0 \
                and (param_name == prefix or param_name[len(prefix)] == "." or (prefix and prefix[-1] == ".")):
            return True
    return False

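# Editorial sketch, not part of the original module: `_check_param_prefix`
# matches a prefix only on full dotted-name boundaries, which these cases
# illustrate with illustrative parameter names.
def _demo_check_param_prefix():
    """Sketch: "conv1" filters "conv1.weight" but not "conv10.weight"."""
    assert _check_param_prefix(("conv1",), "conv1.weight")        # boundary is "."
    assert _check_param_prefix(("conv1",), "conv1")               # exact match
    assert not _check_param_prefix(("conv1",), "conv10.weight")   # not a name boundary
    return True

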
def _check_append_dict(append_dict):
    if append_dict is None:
        return append_dict
    if not isinstance(append_dict, dict):
        raise TypeError(f"The type of append_dict must be dict, but got {str(type(append_dict))}.")
    if not all(isinstance(ele, str) for ele in append_dict.keys()) or \
            not all(isinstance(ele, (int, float, bool)) for ele in append_dict.values()):
        raise TypeError("The type of element in append_dict must be key: str, value: int, float or bool.")
    return append_dict

def load(file_name, **kwargs):
    """
    Load MindIR.

    The returned object can be executed by a `GraphCell`, see class :class:`mindspore.nn.GraphCell` for more details.

    Args:
        file_name (str): MindIR file name.

        kwargs (dict): Configuration options dictionary.

            - dec_key (bytes): Byte type key used for decryption. The valid length is 16, 24, or 32.
            - dec_mode (str): Specifies the decryption mode, takes effect when dec_key is set.
              Option: 'AES-GCM' | 'AES-CBC'. Default: 'AES-GCM'.

    Returns:
        Object, a compiled graph that can be executed by `GraphCell`.

    Raises:
        ValueError: MindIR file name is incorrect.
        RuntimeError: Failed to parse MindIR file.

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor, export, load
        >>>
        >>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones")
        >>> input_tensor = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        >>> export(net, input_tensor, file_name="net", file_format="MINDIR")
        >>> graph = load("net.mindir")
        >>> net = nn.GraphCell(graph)
        >>> output = net(input_tensor)
        >>> print(output)
        [[[[4. 6. 4.]
           [6. 9. 6.]
           [4. 6. 4.]]]]
    """
    if not isinstance(file_name, str):
        raise ValueError("The file name must be a string.")
    if not file_name.endswith(".mindir"):
        raise ValueError("The MindIR file name should end with '.mindir', please input the correct file name.")
    if not os.path.exists(file_name):
        raise ValueError("The file does not exist.")
    file_name = os.path.realpath(file_name)

    logger.info("Execute the process of loading mindir.")
    if 'dec_key' in kwargs.keys():
        dec_key = Validator.check_isinstance('dec_key', kwargs['dec_key'], bytes)
        dec_mode = 'AES-GCM'
        if 'dec_mode' in kwargs.keys():
            dec_mode = Validator.check_isinstance('dec_mode', kwargs['dec_mode'], str)
        graph = load_mindir(file_name, dec_key=dec_key, key_len=len(dec_key), dec_mode=dec_mode)
    else:
        graph = load_mindir(file_name)

    if graph is None:
        if _is_cipher_file(file_name):
            raise RuntimeError("Load MindIR failed. The file may be encrypted, please pass in the "
                               "correct dec_key and dec_mode.")
        raise RuntimeError("Load MindIR failed.")
    return graph

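# Editorial sketch, not part of the original module: loading an encrypted
# MindIR with the optional `dec_key`/`dec_mode` kwargs documented above. The
# key bytes and file name are placeholders, not working values.
def _demo_load_encrypted_mindir():
    """Sketch: decrypt-and-load; the key must be the one used at export time."""
    key = b'0123456789ABCDEF'  # placeholder 16-byte key
    graph = load("net.mindir", dec_key=key, dec_mode='AES-GCM')
    return nn.GraphCell(graph)

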
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM"):
    """
    Load checkpoint info from a specified file.

    Args:
        ckpt_file_name (str): Checkpoint file name.
        net (Cell): The network where the parameters will be loaded. Default: None.
        strict_load (bool): Whether to load the parameters into net strictly. If False, a parameter in the
                            checkpoint file will be loaded when its name suffix matches a parameter in the
                            network, and when the types are inconsistent, type conversion is performed
                            between parameters of the same numeric family, such as float32 to float16.
                            Default: False.
        filter_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the filter_prefix
            will not be loaded. Default: None.
        dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption
                                      is not required. Default: None.
        dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption
                        mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.

    Returns:
        Dict, key is parameter name, value is a Parameter.

    Raises:
        ValueError: Checkpoint file is incorrect.

    Examples:
        >>> from mindspore import load_checkpoint
        >>>
        >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
        >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
        >>> print(param_dict["conv2.weight"])
        Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
    """
    ckpt_file_name, filter_prefix = _check_checkpoint_param(ckpt_file_name, filter_prefix)
    dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
    dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
    logger.info("Execute the process of loading checkpoint files.")
    checkpoint_list = Checkpoint()

    try:
        if dec_key is None:
            with open(ckpt_file_name, "rb") as f:
                pb_content = f.read()
        else:
            pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
            if pb_content is None:
                raise ValueError
        checkpoint_list.ParseFromString(pb_content)
    except BaseException as e:
        if _is_cipher_file(ckpt_file_name):
            logger.error("Failed to read the checkpoint file `%s`. The file may be encrypted, please pass in the "
                         "correct dec_key.", ckpt_file_name)
        else:
            logger.error("Failed to read the checkpoint file `%s`, please check the correctness of the file.",
                         ckpt_file_name)
        raise ValueError(e.__str__())

    parameter_dict = {}
    try:
        param_data_list = []
        for element_id, element in enumerate(checkpoint_list.value):
            if filter_prefix is not None and _check_param_prefix(filter_prefix, element.tag):
                continue
            data = element.tensor.tensor_content
            data_type = element.tensor.tensor_type
            np_type = tensor_to_np_type[data_type]
            ms_type = tensor_to_ms_type[data_type]
            element_data = np.frombuffer(data, np_type)
            param_data_list.append(element_data)
            if (element_id == len(checkpoint_list.value) - 1) or \
                    (element.tag != checkpoint_list.value[element_id + 1].tag):
                param_data = np.concatenate((param_data_list), axis=0)
                param_data_list.clear()
                dims = element.tensor.dims
                if dims == [0]:
                    if 'Float' in data_type:
                        param_data = float(param_data[0])
                    elif 'Int' in data_type:
                        param_data = int(param_data[0])
                    parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
                elif dims == [1]:
                    parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
                else:
                    param_dim = []
                    for dim in dims:
                        param_dim.append(dim)
                    param_value = param_data.reshape(param_dim)
                    parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)

        logger.info("Loading checkpoint files process is finished.")

    except BaseException as e:
        logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
        raise RuntimeError(e.__str__())

    if not parameter_dict:
        raise ValueError("The loaded parameter dict is empty after filtering, please check filter_prefix.")

    if net is not None:
        load_param_into_net(net, parameter_dict, strict_load)

    return parameter_dict

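# Editorial sketch, not part of the original module: `filter_prefix` may be a
# single string or a list/tuple of prefixes, and passing `net` loads the
# surviving parameters immediately. File and prefix names are illustrative.
def _demo_load_checkpoint_filtered(net):
    """Sketch: drop optimizer state by prefix, then load the rest into `net`."""
    param_dict = load_checkpoint("./checkpoint/LeNet5-1_32.ckpt",
                                 net=net,
                                 filter_prefix=["moments", "global_step"])
    return param_dict

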
def _check_checkpoint_param(ckpt_file_name, filter_prefix=None):
    """Check function load_checkpoint's parameter."""
    if not isinstance(ckpt_file_name, str):
        raise ValueError("The ckpt_file_name must be a string.")

    if not os.path.exists(ckpt_file_name):
        raise ValueError("The checkpoint file does not exist.")

    if ckpt_file_name[-5:] != ".ckpt":
        raise ValueError("Please input the correct checkpoint file name.")
    ckpt_file_name = os.path.realpath(ckpt_file_name)

    if filter_prefix is not None:
        if not isinstance(filter_prefix, (str, list, tuple)):
            raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str] "
                            f"when filter_prefix is not None, but got {str(type(filter_prefix))}.")
        if isinstance(filter_prefix, str):
            filter_prefix = (filter_prefix,)
        if not filter_prefix:
            raise ValueError("The filter_prefix can't be empty when filter_prefix is list or tuple.")
        for index, prefix in enumerate(filter_prefix):
            if not isinstance(prefix, str):
                raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], "
                                f"but got {str(type(prefix))} at index {index}.")
    return ckpt_file_name, filter_prefix

def load_param_into_net(net, parameter_dict, strict_load=False):
    """
    Load parameters into network.

    Args:
        net (Cell): The network where the parameters will be loaded.
        parameter_dict (dict): The dictionary generated by loading a checkpoint file.
        strict_load (bool): Whether to load the parameters into net strictly. If False, a parameter in the
                            checkpoint file will be loaded when its name suffix matches a parameter in the
                            network, and when the types are inconsistent, type conversion is performed
                            between parameters of the same numeric family, such as float32 to float16.
                            Default: False.

    Returns:
        List, the parameter names that are not loaded into the network.

    Raises:
        TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.

    Examples:
        >>> from mindspore import load_checkpoint, load_param_into_net
        >>>
        >>> net = Net()
        >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
        >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
        >>> param_not_load = load_param_into_net(net, param_dict)
        >>> print(param_not_load)
        ['conv1.weight']
    """
    if not isinstance(net, nn.Cell):
        logger.error("Failed to combine the net and the parameters.")
        msg = ("Argument net should be a Cell, but got {}.".format(type(net)))
        raise TypeError(msg)

    if not isinstance(parameter_dict, dict):
        logger.error("Failed to combine the net and the parameters.")
        msg = ("Argument parameter_dict should be a dict, but got {}.".format(type(parameter_dict)))
        raise TypeError(msg)

    strict_load = Validator.check_bool(strict_load)
    logger.info("Execute the process of loading parameters into net.")
    net.init_parameters_data()
    param_not_load = []
    for _, param in net.parameters_and_names():
        if param.name in parameter_dict:
            new_param = copy.deepcopy(parameter_dict[param.name])
            if not isinstance(new_param, Parameter):
                logger.error("Failed to combine the net and the parameters.")
                msg = ("Argument parameter_dict element should be a Parameter, but got {}.".format(type(new_param)))
                raise TypeError(msg)
            _update_param(param, new_param, strict_load)
        else:
            param_not_load.append(param.name)

    if param_not_load and not strict_load:
        _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load)

    logger.debug("Params not matched (in net but not in parameter_dict):")
    for param_name in param_not_load:
        logger.debug("%s", param_name)

    logger.info("Loading parameters into net is finished.")
    if param_not_load:
        logger.warning("{} parameters in the net are not loaded.".format(len(param_not_load)))
        for param_name in param_not_load:
            logger.warning("{} is not loaded.".format(param_name))
    return param_not_load

def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load):
    """When some net parameters were not loaded, try to continue loading by removing a mismatched prefix."""
    prefix_name = ""
    longest_name = param_not_load[0]
    while prefix_name != longest_name and param_not_load:
        logger.debug("Count: {} parameters have not been loaded, trying to continue.".format(len(param_not_load)))
        prefix_name = longest_name
        for net_param_name in param_not_load:
            for dict_name in parameter_dict:
                if dict_name.endswith(net_param_name):
                    prefix_name = dict_name[:-len(net_param_name)]
                    break
            if prefix_name != longest_name:
                break

        if prefix_name != longest_name:
            logger.warning("Removing parameter prefix name: {}, continuing to load.".format(prefix_name))
            for _, param in net.parameters_and_names():
                new_param_name = prefix_name + param.name
                if param.name in param_not_load and new_param_name in parameter_dict:
                    new_param = parameter_dict[new_param_name]
                    _update_param(param, new_param, strict_load)
                    param_not_load.remove(param.name)

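# Editorial sketch, not part of the original module: with strict_load=False,
# `load_param_into_net` falls back to `_load_dismatch_prefix_params`, so a
# checkpoint saved with an extra wrapper prefix can still be matched. The
# file name and the "network." prefix mentioned below are illustrative.
def _demo_prefix_tolerant_load(net):
    """Sketch: "network.conv1.weight" in the file can match net's "conv1.weight"."""
    param_dict = load_checkpoint("./checkpoint/wrapped.ckpt")
    # strict_load=False (the default) enables the prefix-stripping retry above.
    return load_param_into_net(net, param_dict, strict_load=False)

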
def _save_graph(network, file_name):
    """
    Saves the graph of network to a file.

    Args:
        network (Cell): Obtain a pipeline through network for saving graph.
        file_name (str): Graph file name into which the graph will be saved.
    """
    logger.info("Execute the process of saving graph.")

    file_name = os.path.realpath(file_name)
    graph_pb = network.get_func_graph_proto()
    if graph_pb:
        with open(file_name, "wb") as f:
            os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
            f.write(graph_pb)

def _get_merged_param_data(net, param_name, param_data, integrated_save):
    """
    Gets the merged data (tensor) from tensor slices, by device arrangement and tensor map.

    Args:
        net (Cell): MindSpore network.
        param_name (str): The parameter name to be combined.
        param_data (Tensor): The parameter data on the local device, which is a slice of the whole parameter data.
        integrated_save (bool): Whether to merge and save the split Tensor in the automatic model parallel scenario.

    Returns:
        Tensor, the combined tensor with the whole data value.
    """
    layout = net.parameter_layout_dict[param_name]
    if len(layout) < 6:
        logger.info("layout dict does not contain the key %s", param_name)
        return param_data

    dev_mat = layout[0]
    tensor_map = layout[1]
    uniform_split = layout[4]
    opt_shard_group = layout[5]

    allgather_net = None
    mp_weight = False
    for dim in tensor_map:
        if dim != -1:
            mp_weight = True
            break
    if param_name in net.parallel_parameter_merge_net_dict:
        allgather_net = net.parallel_parameter_merge_net_dict[param_name]
    else:
        logger.info("need to create allgather net for %s", param_name)
        if integrated_save:
            if context.get_auto_parallel_context("pipeline_stages") > 1:
                raise RuntimeError("Pipeline parallel doesn't support integrated save checkpoint now.")
            if uniform_split == 0:
                raise RuntimeError("Integrated save checkpoint only supports uniformly split tensors now.")
            # if any dim is not equal to -1, the param is split and needs to be merged
            # pipeline parallel needs to be supported here later
            if mp_weight:
                allgather_net = get_allgather_cell(opt_shard_group, bool(opt_shard_group))
            elif opt_shard_group:
                allgather_net = get_allgather_cell(opt_shard_group, False)
        elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_aggregated_save"):
            allgather_net = get_allgather_cell(opt_shard_group, False)
        net.parallel_parameter_merge_net_dict[param_name] = allgather_net
    if allgather_net:
        param_data = allgather_net(param_data)
    if mp_weight and integrated_save:
        param_data = _reshape_param_data(param_data, dev_mat, tensor_map)
    return param_data

def _fill_param_into_net(net, parameter_list):
    """
    Fills parameter_list into net.

    Args:
        net (Cell): train network.
        parameter_list (list): parameters list from ge callback.
    """
    parameter_dict = {}
    for each_param in parameter_list:
        param_name = each_param["name"]
        if isinstance(each_param["data"], Parameter):
            each_param["data"].init_data()
        np_val = each_param["data"].asnumpy()
        if np_val.shape == (1,):
            parameter_dict[param_name] = Parameter(np_val, name=param_name)
        elif np_val.shape == ():
            parameter_dict[param_name] = Parameter(Tensor(np_val.tolist(), mstype.pytype_to_dtype(np_val.dtype)),
                                                   name=param_name)
        else:
            parameter_dict[param_name] = Parameter(Tensor(np_val), name=param_name)

    load_param_into_net(net, parameter_dict)

def export(net, *inputs, file_name, file_format='AIR', **kwargs):
    """
    Export the mindspore network into an offline model in the specified format.

    Note:
        1. When exporting to the AIR or ONNX format, the size of a single tensor cannot exceed 2GB.
        2. When file_name does not have a suffix, the system will automatically add one according to the file_format.

    Args:
        net (Cell): MindSpore network.
        inputs (Tensor): Inputs of the `net`; if the network has multiple inputs, pass a tuple of Tensors.
        file_name (str): File name of the model to be exported.
        file_format (str): MindSpore currently supports 'AIR', 'ONNX' and 'MINDIR' format for exported model.

            - AIR: Ascend Intermediate Representation. An intermediate representation format of Ascend model.
            - ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
            - MINDIR: MindSpore Native Intermediate Representation for Anf. An intermediate representation format
              for MindSpore models.

        kwargs (dict): Configuration options dictionary.

            - quant_mode (str): If the network is a quantization-aware training network, quant_mode should
              be set to "QUANT"; otherwise quant_mode should be set to "NONQUANT".
            - mean (float): The mean of input data after preprocessing, used for quantizing the first layer of network.
              Default: 127.5.
            - std_dev (float): The standard deviation of input data after preprocessing,
              used for quantizing the first layer of network. Default: 127.5.
            - enc_key (byte): Byte type key used for encryption. The valid length is 16, 24, or 32.
            - enc_mode (str): Specifies the encryption mode, takes effect when enc_key is set.
              Option: 'AES-GCM' | 'AES-CBC'. Default: 'AES-GCM'.
            - dataset (Dataset): Specifies the preprocessing method of the network.

    Examples:
        >>> import numpy as np
        >>> from mindspore import export, Tensor
        >>>
        >>> net = LeNet()
        >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
        >>> export(net, Tensor(input_tensor), file_name='lenet', file_format='MINDIR')
    """
    logger.info("exporting model file:%s format:%s.", file_name, file_format)
    check_input_data(*inputs, data_class=Tensor)
    Validator.check_file_name_by_regular(file_name)
    file_name = os.path.realpath(file_name)
    net = _quant_export(net, *inputs, file_format=file_format, **kwargs)
    if 'enc_key' in kwargs.keys():
        if file_format != 'MINDIR':
            raise ValueError(f"enc_key can be passed in only when file_format=='MINDIR', but got {file_format}")

        enc_key = Validator.check_isinstance('enc_key', kwargs['enc_key'], bytes)
        enc_mode = 'AES-GCM'
        if 'enc_mode' in kwargs.keys():
            enc_mode = Validator.check_isinstance('enc_mode', kwargs['enc_mode'], str)
        dataset = kwargs['dataset'] if 'dataset' in kwargs.keys() else None
        _export(net, file_name, file_format, *inputs, enc_key=enc_key, enc_mode=enc_mode, dataset=dataset)
    else:
        _export(net, file_name, file_format, *inputs, **kwargs)

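# Editorial sketch, not part of the original module: exporting an encrypted
# MINDIR via the `enc_key`/`enc_mode` kwargs validated above; encryption is
# rejected for other formats. The network, input shape and key bytes are
# placeholders.
def _demo_export_encrypted_mindir(net):
    """Sketch: encrypt at export, then decrypt with the same key in `load`."""
    inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
    key = b'0123456789ABCDEF'  # placeholder 16-byte key
    export(net, inp, file_name='lenet_enc', file_format='MINDIR',
           enc_key=key, enc_mode='AES-GCM')

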
def _export(net, file_name, file_format, *inputs, **kwargs):
    """
    It is an internal conversion function. Export the MindSpore prediction model to a file in the specified format.
    """
    logger.info("exporting model file:%s format:%s.", file_name, file_format)
    check_input_data(*inputs, data_class=Tensor)
    if 'dataset' in kwargs.keys() and kwargs['dataset'] is not None:
        check_input_data(kwargs['dataset'], data_class=mindspore.dataset.Dataset)

    if file_format == 'GEIR':
        logger.warning("Format 'GEIR' is deprecated, it will be removed in a future release, use 'AIR' instead.")
        file_format = 'AIR'

    supported_formats = ['AIR', 'ONNX', 'MINDIR']
    if file_format not in supported_formats:
        raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}')
    # When dumping an ONNX file, switch the network to inference mode if it is training
    # (NOTE: ONNX is only designed for prediction)
    is_dump_onnx_in_training = net.training and file_format == 'ONNX'
    if is_dump_onnx_in_training:
        net.set_train(mode=False)

    if file_format == 'AIR':
        phase_name = 'export.air'
        graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
        if not file_name.endswith('.air'):
            file_name += ".air"
        if os.path.exists(file_name):
            os.chmod(file_name, stat.S_IWUSR)
        if "/" in file_name:
            real_path = os.path.realpath(file_name[:file_name.rfind("/")])
            os.makedirs(real_path, exist_ok=True)
        _executor.export(file_name, graph_id)
        os.chmod(file_name, stat.S_IRUSR)
    elif file_format == 'ONNX':
        total_size = _calculation_net_size(net)
        if total_size > PROTO_LIMIT_SIZE:
            raise RuntimeError('Export onnx model failed. Network size is: {}G, it exceeds the protobuf limit of {}G.'
                               .format(total_size/1024/1024, PROTO_LIMIT_SIZE/1024/1024))
        phase_name = 'export.onnx'
        graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
        onnx_stream = _executor._get_func_graph_proto(net, graph_id)
        if not file_name.endswith('.onnx'):
            file_name += ".onnx"
        if os.path.exists(file_name):
            os.chmod(file_name, stat.S_IWUSR)
        with open(file_name, 'wb') as f:
            f.write(onnx_stream)
            os.chmod(file_name, stat.S_IRUSR)
    elif file_format == 'MINDIR':
        _save_mindir(net, file_name, *inputs, **kwargs)

    if is_dump_onnx_in_training:
        net.set_train(mode=True)

def _save_mindir(net, file_name, *inputs, **kwargs):
    """Save MindIR format file."""
    model = mindir_model()

    phase_name = "predict" if net._auto_parallel_mode else "export.mindir"

    graph_id, _ = _executor.compile(net, *inputs, phase=phase_name,
                                    do_convert=False, auto_parallel_mode=net._auto_parallel_mode)
    mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir')

    net_dict = net.parameters_dict()
    model.ParseFromString(mindir_stream)

    if 'dataset' in kwargs.keys() and kwargs['dataset'] is not None:
        dataset = kwargs['dataset']
        model.preprocessor = json.dumps(dataset.to_json(), indent=2)

    save_together = _save_together(net_dict, model)
    is_encrypt = lambda: 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys()
    if save_together:
        _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs)
    else:
        logger.warning("The total size of the net parameters exceeds 1G, saving the MindIR model and "
                       "parameters separately.")
        # save parameters
        file_prefix = file_name.split("/")[-1]
        if file_prefix.endswith(".mindir"):
            file_prefix = file_prefix[:-7]
        current_path = os.path.abspath(file_name)
        dirname = os.path.dirname(current_path)
        data_path = os.path.join(dirname, file_prefix + "_variables")
        if os.path.exists(data_path):
            shutil.rmtree(data_path)
        os.makedirs(data_path, exist_ok=True)
        os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
        index = 0
        graphproto = graph_proto()
        data_size = 0

        for name, param in net_dict.items():
            for param_proto in model.graph.parameter:
                if name == param_proto.name[param_proto.name.find(":") + 1:]:
                    parameter = graphproto.parameter.add()
                    parameter.name = param_proto.name
                    parameter.data_type = param_proto.data_type
                    for dim in param_proto.dims:
                        parameter.dims.append(dim)
                    byte_data = param.data.asnumpy().tobytes()
                    parameter.raw_data = byte_data
                    data_size += sys.getsizeof(byte_data) / 1024
                    break
            if data_size > TOTAL_SAVE:
                data_file_name = os.path.join(data_path, "data_" + str(index))
                if os.path.exists(data_file_name):
                    os.chmod(data_file_name, stat.S_IWUSR)
                with open(data_file_name, "ab") as f:
                    os.chmod(data_file_name, stat.S_IRUSR | stat.S_IWUSR)
                    graph_string = graphproto.SerializeToString()
                    if is_encrypt():
                        graph_string = _encrypt(graph_string, len(graph_string), kwargs['enc_key'],
                                                len(kwargs['enc_key']), kwargs['enc_mode'])
                    f.write(graph_string)
                    os.chmod(data_file_name, stat.S_IRUSR)
                index += 1
                data_size = 0
                del graphproto.parameter[:]

        if graphproto.parameter:
            data_file_name = os.path.join(data_path, "data_" + str(index))
            if os.path.exists(data_file_name):
                os.chmod(data_file_name, stat.S_IWUSR)
            with open(data_file_name, "ab") as f:
                os.chmod(data_file_name, stat.S_IRUSR | stat.S_IWUSR)
                graph_string = graphproto.SerializeToString()
                if is_encrypt():
                    graph_string = _encrypt(graph_string, len(graph_string), kwargs['enc_key'], len(kwargs['enc_key']),
                                            kwargs['enc_mode'])
                f.write(graph_string)
                os.chmod(data_file_name, stat.S_IRUSR)

        # save graph
        del model.graph.parameter[:]
        graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir")
        if os.path.exists(graph_file_name):
            os.chmod(graph_file_name, stat.S_IWUSR)
        with open(graph_file_name, 'wb') as f:
            os.chmod(graph_file_name, stat.S_IRUSR | stat.S_IWUSR)
            model_string = model.SerializeToString()
            if is_encrypt():
                model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'], len(kwargs['enc_key']),
                                        kwargs['enc_mode'])
            f.write(model_string)
            os.chmod(graph_file_name, stat.S_IRUSR)

def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
    """Save graph and parameters together."""
    for param_proto in model.graph.parameter:
        param_name = param_proto.name[param_proto.name.find(":") + 1:]
        if param_name in net_dict.keys():
            param_data = net_dict[param_name].data.asnumpy().tobytes()
            param_proto.raw_data = param_data
        else:
            logger.error("The parameter %s in the graph is not in the network.", param_name)
            raise ValueError("The parameters in the graph must be in the network.")
    if not file_name.endswith('.mindir'):
        file_name += ".mindir"
    current_path = os.path.abspath(file_name)
    dirname = os.path.dirname(current_path)
    os.makedirs(dirname, exist_ok=True)
    if os.path.exists(file_name):
        os.chmod(file_name, stat.S_IWUSR)
    with open(file_name, 'wb') as f:
        os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
        model_string = model.SerializeToString()
        if is_encrypt():
            model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'], len(kwargs['enc_key']),
                                    kwargs['enc_mode'])
        f.write(model_string)
        os.chmod(file_name, stat.S_IRUSR)

def _save_together(net_dict, model):
    """Whether the graph and parameters are saved together when saving the MindIR model."""
    data_total = 0
    for param_proto in model.graph.parameter:
        name = param_proto.name[param_proto.name.find(":") + 1:]
        if name in net_dict.keys():
            data_total += sys.getsizeof(net_dict[name].data.asnumpy().tobytes()) / 1024
        else:
            raise RuntimeError('Graph parameter: {} is not defined in the network.'.format(param_proto.name))
        if data_total > TOTAL_SAVE:
            return False
    return True

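# Editorial sketch, not part of the original module: `_save_together` sums raw
# parameter bytes in KB against TOTAL_SAVE (1024 * 1024 KB, i.e. 1 GB). The
# helper below mirrors that accounting for an arbitrary Cell.
def _demo_estimate_save_mode(net):
    """Sketch: returns True when the net's parameters fit in one MindIR file."""
    total_kb = 0
    for _, param in net.parameters_and_names():
        total_kb += sys.getsizeof(param.data.asnumpy().tobytes()) / 1024
        if total_kb > TOTAL_SAVE:
            return False  # graph and parameters would be saved separately
    return True

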
def quant_mode_manage(func):
    """
    Inherits the quant_mode from the old version.
    """
    def wrapper(network, *inputs, file_format, **kwargs):
        if 'quant_mode' not in kwargs:
            return network
        quant_mode = kwargs['quant_mode']
        if not isinstance(quant_mode, str):
            raise TypeError("The type of quant_mode should be str, but got {}.".format(type(quant_mode)))
        if quant_mode in ('AUTO', 'MANUAL'):
            kwargs['quant_mode'] = 'QUANT'
        return func(network, *inputs, file_format=file_format, **kwargs)
    return wrapper

@quant_mode_manage
def _quant_export(network, *inputs, file_format, **kwargs):
    """
    Exports a MindSpore quantization prediction model for deployment with AIR and MINDIR.
    """
    supported_device = ["Ascend", "GPU"]
    supported_formats = ['AIR', 'MINDIR']
    quant_mode_formats = ['QUANT', 'NONQUANT']

    quant_mode = kwargs['quant_mode']
    if quant_mode not in quant_mode_formats:
        raise KeyError('The quant_mode input is wrong, please choose a valid quant_mode.')
    if quant_mode == 'NONQUANT':
        return network
    quant_net = copy.deepcopy(network)
    quant_net._create_time = int(time.time() * 1e9)

    mean = 127.5 if kwargs.get('mean', None) is None else kwargs['mean']
    std_dev = 127.5 if kwargs.get('std_dev', None) is None else kwargs['std_dev']
    mean = Validator.check_value_type("mean", mean, (int, float))
    std_dev = Validator.check_value_type("std_dev", std_dev, (int, float))

    if context.get_context('device_target') not in supported_device:
        raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))

    if file_format not in supported_formats:
        raise ValueError('Illegal file format {}.'.format(file_format))

    quant_net.set_train(False)
    if file_format == "MINDIR":
        exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs, is_mindir=True)
    else:
        exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs)
    deploy_net = exporter.run()
    return deploy_net

def parse_print(print_file_name):
    """
    Parse saved data generated by mindspore.ops.Print. Print is used to print data to the screen in graph mode.
    It can also be turned off by setting the parameter `print_file_path` in `context`, in which case the data
    will be saved in the file specified by `print_file_path`. parse_print is used to parse the saved file.
    For more information please refer to :func:`mindspore.context.set_context` and :class:`mindspore.ops.Print`.

    Args:
        print_file_name (str): The file name of saved print data.

    Returns:
        List, element of list is Tensor.

    Raises:
        ValueError: The print file may be empty, please make sure to enter the correct file name.

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> import mindspore.ops as ops
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor, context
        >>> context.set_context(mode=context.GRAPH_MODE, print_file_path='log.data')
        >>> class PrintInputTensor(nn.Cell):
        ...         def __init__(self):
        ...             super().__init__()
        ...             self.print = ops.Print()
        ...
        ...         def construct(self, input_pra):
        ...             self.print('print:', input_pra)
        ...             return input_pra
        ...
        >>> x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]).astype(np.float32)
        >>> input_pra = Tensor(x)
        >>> net = PrintInputTensor()
        >>> net(input_pra)
        >>>
        >>> data = mindspore.parse_print('./log.data')
        >>> print(data)
        ['print:', Tensor(shape=[2, 4], dtype=Float32, value=
        [[ 1.00000000e+00,  2.00000000e+00,  3.00000000e+00,  4.00000000e+00],
        [ 5.00000000e+00,  6.00000000e+00,  7.00000000e+00,  8.00000000e+00]])]
    """

    print_file_path = os.path.realpath(print_file_name)

    if os.path.getsize(print_file_path) == 0:
        raise ValueError("The print file may be empty, please make sure to enter the correct file name.")

    logger.info("Execute load print process.")
    print_list = Print()

    try:
        with open(print_file_path, "rb") as f:
            pb_content = f.read()
        print_list.ParseFromString(pb_content)
    except BaseException as e:
        logger.error("Failed to read the print file %s, please check the correctness of the file.", print_file_name)
        raise ValueError(e.__str__())

    tensor_list = []

    try:
        for print_ in print_list.value:
            # String type
            if print_.HasField("desc"):
                tensor_list.append(print_.desc)
            elif print_.HasField("tensor"):
                dims = print_.tensor.dims
                data_type = print_.tensor.tensor_type
                data = print_.tensor.tensor_content
                np_type = tensor_to_np_type[data_type]
                param_data = np.frombuffer(data, np_type)
                ms_type = tensor_to_ms_type[data_type]
                if dims and dims != [0]:
                    param_value = param_data.reshape(dims)
                    tensor_list.append(Tensor(param_value, ms_type))
                # Scalar type
                else:
                    data_type_ = data_type.lower()
                    if 'float' in data_type_:
                        param_data = float(param_data[0])
                    elif 'int' in data_type_:
                        param_data = int(param_data[0])
                    elif 'bool' in data_type_:
                        param_data = bool(param_data[0])
                    tensor_list.append(Tensor(param_data, ms_type))

    except BaseException as e:
        logger.error("Failed to load the print file %s.", print_file_name)
        raise RuntimeError(e.__str__())

    return tensor_list

def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
    """
    Merge data slices to one tensor with whole data when strategy is not None.

    Args:
        sliced_data (list[numpy.ndarray]): Data slices in order of rank_id.
        parameter_name (str): Name of parameter.
        strategy (dict): Parameter slice strategy.
        is_even (bool): Slice manner: True means the parameter is sliced evenly, False means unevenly.

    Returns:
        Tensor, the merged Tensor which has the whole data.

    Raises:
        ValueError: Failed to merge.
    """
    layout = strategy.get(parameter_name)
    try:
        dev_mat = list(layout.dev_matrix[0].dim)
        tensor_map = list(layout.tensor_map[0].dim)
        param_split_shape = list(layout.param_split_shape[0].dim)
        field_size = int(layout.field)
    except BaseException as e:
        raise ValueError(f"{e.__str__()}. Please make sure that strategy matches the node_strategy.proto.")

    device_count = 1
    for dim in dev_mat:
        device_count *= dim

    if len(sliced_data) != device_count:
        raise ValueError(f"The sliced_parameters length should be equal to device_count. "
                         f"The sliced_parameters length is {len(sliced_data)} but device_count is {device_count}.")

    if not param_split_shape:
        if not is_even:
            raise ValueError("The shape of every parameter in sliced_parameters should be the same "
                             "when slice manner is even.")

        all_gather_tensor = Tensor(np.concatenate(sliced_data))

        if field_size > 0:
            merged_tensor = _reshape_param_data_with_weight(all_gather_tensor, dev_mat, field_size)
        else:
            merged_tensor = _reshape_param_data(all_gather_tensor, dev_mat, tensor_map)

    else:
        tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)

        slice_count = 1
        for dim in tensor_strategy:
            slice_count *= dim

        if len(param_split_shape) != slice_count:
            raise ValueError(f"The param_split_shape length in strategy should be {slice_count}, "
                             f"but got {len(param_split_shape)}.")

        tensor_slices_new = list(range(slice_count))
        tensor_slices = sliced_data
        for i in range(device_count):
            slice_index = int(_get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, i))
            if tensor_slices[i].shape[0] != param_split_shape[slice_index]:
                raise ValueError(f"The slice {slice_index} should be {param_split_shape[slice_index]} at axis 0, "
                                 f"but got {tensor_slices[i].shape[0]}.")
            tensor_slices_new[slice_index] = np.array(tensor_slices[i])

        dim_len = len(tensor_strategy)
        for i in range(dim_len):
            ele_count = int(len(tensor_slices_new) / tensor_strategy[dim_len - 1 - i])
            tensor_slices_new_inner = []
            for j in range(ele_count):
                new_tensor = tensor_slices_new[j * tensor_strategy[dim_len - 1 - i]]
                for l in range(j * tensor_strategy[dim_len - 1 - i] + 1,
                               (j + 1) * tensor_strategy[dim_len - 1 - i]):
                    new_tensor = np.concatenate((new_tensor, tensor_slices_new[l]), axis=dim_len - 1 - i)
                tensor_slices_new_inner.insert(len(tensor_slices_new_inner), np.array(new_tensor))
            tensor_slices_new = tensor_slices_new_inner
        merged_tensor = Tensor(tensor_slices_new[0])

    return merged_tensor

1173def build_searched_strategy(strategy_filename):
1174    """
1175    Build strategy of every parameter in network. Used in the case of distributed inference.
1176    For details of merge_sliced_parameter, please check:
1177    `Enabling Graph-Accounting Convergence <https://www.mindspore.cn/docs/programming_guide
1178    /en/r1.5/save_load_model_hybrid_parallel.html>`_.
1179
1180    Args:
1181        strategy_filename (str): Name of strategy file.
1182
1183    Returns:
1184        Dict, whose key is parameter name and value is slice strategy of this parameter.
1185
1186    Raises:
1187        ValueError: Strategy file is incorrect.
1188        TypeError: strategy_filename is not str.
1189
1190    Examples:
1191        >>> strategy = build_searched_strategy("./strategy_train.ckpt")
1192    """
1193    if not isinstance(strategy_filename, str):
1194        raise TypeError(f"The strategy_filename should be str, but got {type(strategy_filename)}.")
1195
1196    if not os.path.isfile(strategy_filename):
1197        raise ValueError(f"No such strategy file: {strategy_filename}.")
1198
1199    if os.path.getsize(strategy_filename) == 0:
1200        raise ValueError("The strategy file should not be empty.")
1201
1202    parallel_strategy_map = ParallelStrategyMap()
1203
1204    with open(strategy_filename, 'rb') as f:
1205        pb_content = f.read()
1206    parallel_strategy_map.ParseFromString(pb_content)
1207
1208    layout_items = parallel_strategy_map.parallel_layout_item
1209    if not layout_items:
1210        raise ValueError("The strategy file has no sliced parameter.")
1211
1212    strategy = {}
1213    for layout_item in layout_items:
1214        parameter_name = layout_item.param_name
1215        layout = layout_item.parallel_layouts
1216        strategy[parameter_name] = layout
1217
1218    return strategy


def merge_sliced_parameter(sliced_parameters, strategy=None):
    """
    Merge parameter slices into one parameter. Used in the case of distributed inference.
    For details of merge_sliced_parameter, please check:
    `Saving and Loading Models in Hybrid Parallel Mode <https://www.mindspore.cn/docs/programming_guide
    /en/r1.5/save_load_model_hybrid_parallel.html>`_.

    Args:
        sliced_parameters (list[Parameter]): Parameter slices in order of rank_id.
        strategy (Optional[dict]): Parameter slice strategy, whose key is parameter name and
            value is slice strategy of this parameter. If strategy is None, just merge
            parameter slices in 0 axis order. Default: None.

    Returns:
        Parameter, the merged parameter which has the whole data.

    Raises:
        ValueError: Failed to merge.
        TypeError: The sliced_parameters is incorrect or strategy is not dict.
        KeyError: The parameter name is not in keys of strategy.

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor, merge_sliced_parameter, Parameter
        >>>
        >>> sliced_parameters = [
        ...                      Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])),
        ...                                "network.embedding_table"),
        ...                      Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])),
        ...                                "network.embedding_table"),
        ...                      Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])),
        ...                                "network.embedding_table"),
        ...                      Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])),
        ...                                "network.embedding_table")]
        >>> merged_parameter = merge_sliced_parameter(sliced_parameters)
        >>> print(merged_parameter)
        Parameter (name=network.embedding_table, shape=(12,), dtype=Float64, requires_grad=True)
    """
    if not isinstance(sliced_parameters, list):
        raise TypeError(f"The sliced_parameters should be list, but got {type(sliced_parameters)}.")

    if not sliced_parameters:
        raise ValueError("The sliced_parameters should not be empty.")

    if strategy and not isinstance(strategy, dict):
        raise TypeError(f"The strategy should be dict, but got {type(strategy)}.")

    try:
        parameter_name = sliced_parameters[0].name
        parameter_shape = sliced_parameters[0].data.shape
        parameter_shape_length = len(parameter_shape)
    except BaseException as e:
        raise TypeError(f"{e.__str__()}. The element in sliced_parameters should be Parameter.")

    is_even = True
    for index, parameter in enumerate(sliced_parameters):
        if not isinstance(parameter, Parameter):
            raise TypeError(f"The element in sliced_parameters should be Parameter, "
                            f"but got {type(parameter)} at index {index}.")

        if parameter.name != parameter_name \
                or len(parameter.data.shape) != parameter_shape_length \
                or parameter.data.shape[1:] != parameter_shape[1:]:
            raise ValueError("Please make sure that the elements in sliced_parameters have the same name, "
                             "dimension length and shape except 0 axis.")

        if parameter.data.shape != parameter_shape:
            is_even = False

    layerwise_parallel = sliced_parameters[0].layerwise_parallel
    requires_grad = sliced_parameters[0].requires_grad
    sliced_data = [parameter.data.asnumpy() for parameter in sliced_parameters]

    if not strategy:
        merged_tensor = Tensor(np.concatenate(sliced_data))
        merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)

    else:
        if parameter_name not in strategy.keys():
            raise KeyError(f"The parameter name should be one key of strategy. "
                           f"The parameter name is {parameter_name}.")
        merged_tensor = _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even)
        merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)

    return merged_parameter


def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None,
                                train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM'):
    """
    Load checkpoint into net for distributed prediction. Used in the case of distributed inference.
    For details of distributed inference, please check:
    `Distributed Inference <https://www.mindspore.cn/docs/programming_guide
    /en/r1.5/distributed_inference.html>`_.

    Args:
        network (Cell): Network for distributed prediction.
        checkpoint_filenames (list[str]): The name of checkpoint files in order of rank id.
        predict_strategy (dict): Strategy of the prediction process, whose key is parameter name, and value is a list
            or a tuple whose first four elements are [dev_matrix, tensor_map, param_split_shape, field]. If None,
            the prediction process uses a single device. Default: None.
        train_strategy_filename (str): Train strategy proto file name. Default: None.
        strict_load (bool): Whether to strictly load the parameters into net. If False, a parameter is loaded
                            into net when its name suffix in the checkpoint file matches a parameter in the
                            network, and type conversion is performed when the data types are inconsistent,
                            such as float32 to float16. Default: False.
        dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption
                                      is not required. Default: None.
        dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption
                        mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.

    Raises:
        TypeError: The types of inputs do not match the requirements.
        ValueError: Failed to load checkpoint into net.
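
    Examples:
        >>> # A minimal sketch: assumes `net` is the Cell built for prediction, and the
        >>> # file names below are illustrative placeholders for the real rank-ordered
        >>> # checkpoint files and the strategy file saved during training.
        >>> ckpt_files = ["./rank_{}/net.ckpt".format(i) for i in range(4)]
        >>> load_distributed_checkpoint(net, ckpt_files,
        ...                             train_strategy_filename="./train_strategy.ckpt")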
    """
    network = Validator.check_isinstance("network", network, nn.Cell)
    _check_checkpoint_file(checkpoint_filenames)
    _check_predict_strategy(predict_strategy)

    dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
    dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)

    if train_strategy_filename is None:
        train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
    _train_strategy = build_searched_strategy(train_strategy_filename)
    train_strategy = _convert_to_list(_train_strategy)

    train_dev_count = 1
    ckpt_file_len = len(checkpoint_filenames)
    for dim in train_strategy[list(train_strategy.keys())[0]][0]:
        train_dev_count *= dim
    if train_dev_count != ckpt_file_len:
        raise ValueError(
            f"The length of checkpoint_filenames should be equal to the device count of training process. "
            f"The length is {ckpt_file_len} but the device count is {train_dev_count}.")

    rank_list = _infer_rank_list(train_strategy, predict_strategy)
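    # rank_list maps each parameter name to a pair: the training ranks whose slices
    # are needed to rebuild the parameter, and a flag telling whether the
    # merge-and-split step below can be skipped for it.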

    param_total_dict = defaultdict(dict)
    for file_index, file_name in enumerate(checkpoint_filenames):
        ckpt_dict = load_checkpoint(file_name, dec_key=dec_key, dec_mode=dec_mode)
        for param_name, param in ckpt_dict.items():
            param_total_dict[param_name][file_index] = param

    param_dict = {}
    param_not_in_strategy = []
    param_not_in_ckpt = []
    for _, param in network.parameters_and_names():
        sliced_params = []
        if param.name not in rank_list.keys():
            param_not_in_strategy.append(param.name)
            continue
        if param.name not in param_total_dict:
            param_not_in_ckpt.append(param.name)
            continue

        param_rank = rank_list[param.name][0]
        skip_merge_split = rank_list[param.name][1]
        shard_stride = train_strategy[param.name][4]
        if train_strategy[param.name][5]:
            # Use integer division: shard_size is used as a slice length and a range
            # step below, which both require an int.
            shard_size = ckpt_file_len // shard_stride // train_strategy[param.name][5]
        else:
            shard_size = 0
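        # Indexes 4 and 5 of the converted train strategy are opt_weight_shard_step and
        # opt_weight_shard_size (see _convert_to_list). When optimizer weight sharding
        # is enabled, the slices of one shard group are gathered and concatenated below
        # before the regular merge-and-split step.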
        for rank in param_rank:
            param_total_list = list(range(0, ckpt_file_len))
            if shard_size > 0:
                shard_total_list = [param_total_list[i:i + shard_size] for i in
                                    range(0, ckpt_file_len, shard_size)]
                param_total_list = shard_total_list[rank // shard_size]
            if shard_stride > 0:
                param_stride = []
                # Collect the ranks on both sides of `rank`, stepping by the shard
                # stride, then deduplicate and sort them.
                param_index = param_total_list[0:param_total_list.index(rank) + 1][::-1][::shard_stride]
                param_index.extend(param_total_list[param_total_list.index(rank):][::shard_stride])
                param_index = list(set(param_index))
                param_index.sort()
                for rank_num in param_index:
                    param_stride.append(param_total_dict[param.name][rank_num].data.asnumpy())

                sliced_param = Parameter(Tensor(np.concatenate(param_stride)), name=param.name)
            else:
                sliced_param = param_total_dict[param.name][rank]

            sliced_params.append(sliced_param)
        if skip_merge_split:
            split_param = sliced_params[0]
        else:
            param_unique_strategy = _remove_repeated_slices(train_strategy[param.name])
            _param_unique_strategy = _convert_to_layout(param.name, param_unique_strategy)
            split_param = _merge_and_split(sliced_params, _param_unique_strategy, predict_strategy)
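        # Element 5 of the predict strategy, when present, names the optimizer shard
        # communication group: each rank in that group keeps only its own slice of the
        # parameter along axis 0.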
        opt_shard_group = predict_strategy[param.name][5] if predict_strategy else None
        if opt_shard_group:
            data = split_param.data.asnumpy()
            rank = get_rank(opt_shard_group)
            size = get_group_size(opt_shard_group)
            try:
                data_slice = np.split(data, size)[rank]
            except BaseException as e:
                logger.error("Failed to load opt shard slice in load distributed checkpoint for {}. Data shape is {}"
                             " and group is {}".format(param.name, split_param.data.shape, opt_shard_group))
                raise RuntimeError(e.__str__())
            split_param = Parameter(Tensor(data_slice), param.name,
                                    split_param.requires_grad, split_param.layerwise_parallel)
        param_dict[param.name] = split_param

    if param_not_in_strategy:
        logger.warning("{} parameters in network are not in the slice strategy.".format(param_not_in_strategy))
    if param_not_in_ckpt:
        logger.warning("{} parameters are in the slice strategy but not in the checkpoint file.".format(
            param_not_in_ckpt))

    load_param_into_net(network, param_dict, strict_load=strict_load)


def async_ckpt_thread_status():
    """
    Get the status of the asynchronous save checkpoint thread.

    When performing an asynchronous save checkpoint, you can get the thread state through this function
    to ensure that writing of the checkpoint file is completed.

    Returns:
        True, the asynchronous save checkpoint thread is still running.
        False, the asynchronous save checkpoint thread is not executing.
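
    Examples:
        >>> import time
        >>> # Illustrative polling loop: wait until the asynchronous checkpoint
        >>> # thread has finished writing before exiting the process.
        >>> while async_ckpt_thread_status():
        ...     time.sleep(1)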
    """
    thr_list = threading.enumerate()
    return any(thread.name == "asyn_save_ckpt" for thread in thr_list)


def _check_predict_strategy(predict_strategy):
    """Check predict strategy."""
    def _check_int_list(arg):
        if not isinstance(arg, list):
            return False
        for item in arg:
            if not isinstance(item, int):
                return False
        return True

    if predict_strategy is None:
        return

    flag = True
    predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict)
    for key in predict_strategy.keys():
        if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \
                or len(predict_strategy[key]) < 4:
            flag = False
            continue
        dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4]
        if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \
                not (_check_int_list(param_split_shape) or not param_split_shape) or \
                not (isinstance(field_size, int) and field_size == 0):
            flag = False

    if not flag:
        raise ValueError(f"Please make sure that the key of predict_strategy is str, "
                         f"and the value is a list or a tuple whose first four elements are "
                         f"dev_matrix (list[int]), tensor_map (list[int]), "
                         f"param_split_shape (list[int]) and field_size (zero).")


def _check_checkpoint_file(checkpoint_filenames):
    """Check checkpoint file name."""
    for index, filename in enumerate(checkpoint_filenames):
        if not isinstance(filename, str) or not os.path.exists(filename) \
                or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0:
            raise ValueError(f"Please make sure that the {filename} at index {index} is a valid checkpoint file.")


def _convert_to_list(strategy):
    """Convert ParallelLayouts object to specified list."""
    train_map = {}
    for param_name in strategy.keys():
        try:
            layout = strategy.get(param_name)
            dev_mat = list(layout.dev_matrix[0].dim)
            tensor_map = list(layout.tensor_map[0].dim)
            param_split_shape = list(layout.param_split_shape[0].dim)
            field_size = int(layout.field)
            shard_stride = int(layout.opt_weight_shard_step)
            shard_size = int(layout.opt_weight_shard_size)
            train_map[param_name] = [dev_mat, tensor_map, param_split_shape, field_size, shard_stride, shard_size]
        except BaseException as e:
            raise ValueError(f"{e.__str__()}. Please make sure that strategy matches the node_strategy.proto.")
    return train_map


def _convert_to_layout(param_name, tensor_layout):
    """Convert list to ParallelLayouts object."""
    strategy = {}
    try:
        layout = ParallelLayouts()
        layout.field = tensor_layout[3]

        dev_matrix = layout.dev_matrix.add()
        for item in tensor_layout[0]:
            dev_matrix.dim.append(item)

        tensor_map = layout.tensor_map.add()
        for item in tensor_layout[1]:
            tensor_map.dim.append(item)

        param_split_shape = layout.param_split_shape.add()
        for item in tensor_layout[2]:
            param_split_shape.dim.append(item)
    except BaseException as e:
        raise ValueError("Convert failed. " + e.__str__())

    strategy[param_name] = layout
    return strategy


def _merge_and_split(sliced_params, train_strategy, predict_strategy):
    """Merge sliced parameter and split it according to the predict strategy."""
    merged_param = merge_sliced_parameter(sliced_params, train_strategy)
    if predict_strategy is None:
        return merged_param
    param_name = merged_param.name
    tensor_layout = predict_strategy[param_name]
    split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1])
    requires_grad = merged_param.requires_grad
    layerwise_parallel = merged_param.layerwise_parallel
    split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
    return split_param


def _calculation_net_size(net):
    """Calculate the total size (KB) of the parameters in the network."""
    data_total = 0
    net_dict = net.parameters_dict()
    for name in net_dict:
        data_total += sys.getsizeof(net_dict[name].data.asnumpy().tobytes()) / 1024

    return data_total