# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""Transform distributed checkpoint"""
16from __future__ import absolute_import
17
18import os
19import glob
20import copy
21from collections import defaultdict
22import numpy as np
23import mindspore as ms
24from mindspore.common import dtype as mstype
25from mindspore.parallel._utils import _is_in_auto_parallel_mode
26from mindspore.parallel._parallel_serialization import _rank_list_for_transform_parallel_checkpoint, \
27    _transform_parallel_checkpoint, _get_device_num_from_strategy, _make_dir, \
28    _extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
29    _merge_protobuf_strategy, _merge_json_strategy, _extract_src_dst_layout_map_by_src
30
31
32__all__ = ["merge_pipeline_strategys", "rank_list_for_transform", "transform_checkpoint_by_rank",
33           "transform_checkpoints", "sync_pipeline_shared_parameters", "load_segmented_checkpoints"]
34
35
def merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file):
    """
    Merge the parallel strategy files of all pipeline stages in pipeline parallel mode.
    For more details about converting distributed checkpoints, please refer to
    `Model Transformation <https://www.mindspore.cn/tutorials/experts/en/master/parallel/model_transformation.html>`_.

    Note:
        The strategy file of every pipeline stage should be included in `src_strategy_dirs`.

    Args:
        src_strategy_dirs (str): The directory containing the strategy files of all pipeline stages, which are
                                 saved by `mindspore.set_auto_parallel_context(strategy_ckpt_save_file)`.
        dst_strategy_file (str): The file to which the merged strategy is saved.

    Raises:
        NotADirectoryError: `src_strategy_dirs` is not a directory.

    Examples:
        >>> import mindspore as ms
        >>> # src_strategy_dir/stra0.ckpt, src_strategy_dir/stra1.ckpt ... src_strategy_dir/stra127.ckpt
        >>> ms.merge_pipeline_strategys("./src_strategy_dir", "./dst_strategy.ckpt")

    """
    dst_strategy_dir, _ = os.path.split(dst_strategy_file)
    if not os.path.exists(dst_strategy_dir):
        _make_dir(dst_strategy_dir, "path")
    if not os.path.isdir(src_strategy_dirs):
        raise NotADirectoryError("src_strategy_dirs {} is not a directory.".format(src_strategy_dirs))
    src_strategy_files_protobuf = glob.glob(os.path.join(src_strategy_dirs, "*.ckpt"))
    src_strategy_files_json = glob.glob(os.path.join(src_strategy_dirs, "*.json"))
    if src_strategy_files_protobuf and src_strategy_files_json:
        raise ValueError("The strategy files should be all '.ckpt' or all '.json'.")
    is_protobuf = len(src_strategy_files_protobuf) > 0
    if is_protobuf:
        _merge_protobuf_strategy(src_strategy_files_protobuf, dst_strategy_file)
    else:
        _merge_json_strategy(src_strategy_files_json, dst_strategy_file)


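# Illustrative sketch: when converting between two pipeline-parallel trainings, the per-stage
# strategy files of both the source and the destination configuration are usually merged first,
# so that the conversion functions below can see the complete layouts. The directory and file
# names used here are hypothetical placeholders.
def _example_merge_pipeline_strategys():
    """Sketch: merge the per-stage strategy files of the source and destination trainings."""
    ms.merge_pipeline_strategys("./src_pipeline_strategys", "./src_strategy.ckpt")
    ms.merge_pipeline_strategys("./dst_pipeline_strategys", "./dst_strategy.ckpt")

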
def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=None):
    """
    List the ranks of the original distributed checkpoints that are needed to obtain the target checkpoint of
    `rank_id` during distributed checkpoint conversion. For more details about converting distributed checkpoints,
    please refer to
    `Model Transformation <https://www.mindspore.cn/tutorials/experts/en/master/parallel/model_transformation.html>`_.

    Args:
        rank_id (int): The rank for which the distributed checkpoint needs to be obtained after conversion.
        src_strategy_file (str): Name of the source sharding strategy file, which is saved by
                                 `mindspore.set_auto_parallel_context(strategy_ckpt_save_file)`.
                                 When `src_strategy_file` is ``None``, it means that the source sharding strategy is
                                 without any sharding for each parameter. Default: ``None``.
        dst_strategy_file (str): Name of the destination sharding strategy file, which is saved by
                                 `mindspore.set_auto_parallel_context(strategy_ckpt_save_file)`.
                                 When `dst_strategy_file` is ``None``, it means that the destination sharding strategy
                                 is without any sharding for each parameter. Default: ``None``.

    Returns:
        List, the rank list required for converting the distributed checkpoint of `rank_id`.

    Raises:
        ValueError: `src_strategy_file` or `dst_strategy_file` is incorrect.
        TypeError: `src_strategy_file` or `dst_strategy_file` is not a string.
        TypeError: `rank_id` is not an int.

    Examples:
        >>> import mindspore as ms
        >>> rank_id = 0
        >>> rank_list = ms.rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
        >>> checkpoint_files_map = {}
        >>> for rank in rank_list:
        ...     checkpoint_files_map[rank] = "./pangu{}-100_2.ckpt".format(rank)

    """
    if not isinstance(rank_id, int):
        raise TypeError("The rank_id should be an int.")
    if src_strategy_file is None:
        return [0]
    src_strategy_list, dst_strategy_list = _extract_src_dst_layout_map(rank_id, src_strategy_file, dst_strategy_file)
    src_stage_device_num = np.prod(src_strategy_list.get(list(src_strategy_list.keys())[0])[0]) \
        if src_strategy_list is not None else 1
    dst_stage_device_num = np.prod(dst_strategy_list.get(list(dst_strategy_list.keys())[0])[0]) \
        if dst_strategy_list is not None else 1

    if not src_strategy_list:
        raise ValueError("The src_strategy_file is empty.")
    local_rank_id = rank_id % dst_stage_device_num if dst_stage_device_num > 1 else rank_id
    needed_rank_list_in_local_stage = _rank_list_for_transform_parallel_checkpoint(local_rank_id,
                                                                                   src_strategy_list, dst_strategy_list)
    result_set = set()
    handled_pipeline_stage = []
    for _, layout in src_strategy_list.items():
        for src_pipeline_stage_id in layout[6]:
            if src_pipeline_stage_id in handled_pipeline_stage:
                continue
            src_rank_id_start = src_pipeline_stage_id * src_stage_device_num
            result_set.update([src_rank_id_start + rank for rank in needed_rank_list_in_local_stage])
            handled_pipeline_stage.append(src_pipeline_stage_id)
    return list(result_set)


def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name,
                                 src_strategy_file=None, dst_strategy_file=None):
    """
    Transform a distributed checkpoint from the source sharding strategy to the destination sharding strategy by rank
    for a network. For more details about converting distributed checkpoints, please refer to
    `Model Transformation <https://www.mindspore.cn/tutorials/experts/en/master/parallel/model_transformation.html>`_.

    Args:
        rank_id (int): The rank for which the distributed checkpoint needs to be obtained after conversion.
        checkpoint_files_map (dict): The checkpoint files map whose key is the rank id and the value is
                                     the checkpoint file name.
        save_checkpoint_file_name (str): The file name to save the converted checkpoint.
        src_strategy_file (str): Name of the source sharding strategy file, which is saved by
                                 `mindspore.set_auto_parallel_context(strategy_ckpt_save_file)`.
                                 When `src_strategy_file` is ``None``, it means that the source sharding strategy is
                                 without any sharding for each parameter. Default: ``None``.
        dst_strategy_file (str): Name of the destination sharding strategy file, which is saved by
                                 `mindspore.set_auto_parallel_context(strategy_ckpt_save_file)`.
                                 When `dst_strategy_file` is ``None``, it means that the destination sharding strategy
                                 is without any sharding for each parameter. Default: ``None``.

    Raises:
        ValueError: `src_strategy_file` or `dst_strategy_file` is incorrect.
        ValueError: An item in `checkpoint_files_map` is incorrect.
        ValueError: `save_checkpoint_file_name` does not end with ".ckpt".
        TypeError: `checkpoint_files_map` is not a dict.
        TypeError: `src_strategy_file` or `dst_strategy_file` is not a string.
        TypeError: `rank_id` is not an int.
        TypeError: `save_checkpoint_file_name` is not a string.

    Examples:
        >>> import mindspore as ms
        >>> dst_device_num = 8
        >>> for rank_id in range(dst_device_num):
        ...     rank_list = ms.rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
        ...     checkpoint_files_map = {}
        ...     for rank in rank_list:
        ...         checkpoint_files_map[rank] = "./origin_checkpoint_rank{0}/pangu{0}-100_2.ckpt".format(rank)
        ...     save_checkpoint_file_name = "./new_checkpoint_rank{0}/pangu{0}-100_2.ckpt".format(rank_id)
        ...     ms.transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name,
        ...                                     "./src_strategy.ckpt", "./dst_strategy.ckpt")

    """
    if not isinstance(checkpoint_files_map, dict):
        raise TypeError("The checkpoint_files_map should be a dict.")
    if not isinstance(rank_id, int):
        raise TypeError("The rank_id should be an int.")
    if not isinstance(save_checkpoint_file_name, str):
        raise TypeError("The save_checkpoint_file_name should be a str.")
    if not save_checkpoint_file_name.endswith(".ckpt"):
        raise ValueError("The save_checkpoint_file_name {} should end with .ckpt".format(save_checkpoint_file_name))
    if dst_strategy_file and os.path.dirname(dst_strategy_file) and not os.path.exists(
            os.path.dirname(dst_strategy_file)):
        raise ValueError("The directory of dst_strategy_file: {} does not exist.".
                         format(os.path.dirname(dst_strategy_file)))
    for rank, local_file in checkpoint_files_map.items():
        if not os.path.exists(local_file):
            raise ValueError("Checkpoint file {} in rank {} does not exist.".format(local_file, rank))
    param_total_dict = defaultdict(dict)
    param_attr_dict = defaultdict(dict)
    param_type_dict = defaultdict(dict)
    src_strategy_list, dst_strategy_list = _extract_src_dst_layout_map(rank_id, src_strategy_file, dst_strategy_file)
    # src rank => local rank inside pipeline stage
    src_stage_device_num = np.prod(src_strategy_list.get(list(src_strategy_list.keys())[0])[0]) \
        if src_strategy_list is not None else 1
    dst_stage_device_num = np.prod(dst_strategy_list.get(list(dst_strategy_list.keys())[0])[0]) \
        if dst_strategy_list is not None else 1
    origin_dst_strategy_list = _extract_layout_map(dst_strategy_file)
    origin_src_strategy_list = _extract_layout_map(src_strategy_file)
    for rank, file_name in checkpoint_files_map.items():
        ckpt_dict = ms.load_checkpoint(file_name)
        for param_name, param in ckpt_dict.items():
            # skip the parameters that do not belong to the local pipeline stage.
            if _parameter_not_in_local_stage(param_name, origin_src_strategy_list, src_strategy_list) \
                    and _parameter_not_in_local_stage(param_name, origin_dst_strategy_list, dst_strategy_list):
                continue
            src_rank = rank % src_stage_device_num
            param_type_dict[param_name][src_rank] = str(param.data.dtype)
            if param.data.dtype == mstype.bfloat16:
                param.set_dtype(mstype.float32)
            param_total_dict[param_name][src_rank] = param.data.asnumpy()
            param_attr_dict[param_name][src_rank] = (param.requires_grad, param.layerwise_parallel)
    local_rank_id = rank_id % dst_stage_device_num
    transform_param_list = _transform_parallel_checkpoint(local_rank_id, param_total_dict,
                                                          param_attr_dict, src_strategy_list, dst_strategy_list,
                                                          param_type_dict)
    ms.save_checkpoint(transform_param_list, save_checkpoint_file_name)


def _transform_checkpoint_by_stage(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix, src_strategy_file,
                                   dst_strategy_file=None):
    """Transform checkpoint for stage in src_strategy_file"""
    param_total_dict = defaultdict(dict)
    param_attr_dict = defaultdict(dict)
    param_type_dict = defaultdict(dict)
    src_strategy_list, dst_strategy_list, stage_id = _extract_src_dst_layout_map_by_src(src_strategy_file,
                                                                                        dst_strategy_file)
    src_stage_device_num = np.prod(src_strategy_list.get(list(src_strategy_list.keys())[0])[0]) \
        if src_strategy_list is not None else 1
    dst_stage_device_num = np.prod(dst_strategy_list.get(list(dst_strategy_list.keys())[0])[0]) \
        if dst_strategy_list is not None else 1
    origin_dst_strategy_list = _extract_layout_map(dst_strategy_file)
    origin_src_strategy_list = _extract_layout_map(src_strategy_file)
    checkpoint_files_map = {}
    src_rank_id_start = stage_id * src_stage_device_num
    for local_rank in range(src_stage_device_num):
        rank_id = src_rank_id_start + local_rank
        checkpoint_file_name = os.path.join(src_checkpoints_dir, "rank_{}".format(rank_id), "*.ckpt")
        rank_ckpts = glob.glob(checkpoint_file_name)
        rank_ckpts.sort()
        for checkpoint_file in rank_ckpts:
            if not os.path.isfile(checkpoint_file):
                ms.log.warning("{} is not a checkpoint file.".format(checkpoint_file))
                continue
            checkpoint_files_map[rank_id] = checkpoint_file
    for rank, local_file in checkpoint_files_map.items():
        if not os.path.exists(local_file):
            raise ValueError("Checkpoint file {} in rank {} does not exist.".format(local_file, rank))
    for rank, file_name in checkpoint_files_map.items():
        ckpt_dict = ms.load_checkpoint(file_name)
        for param_name, param in ckpt_dict.items():
            # skip the parameters that do not belong to the local pipeline stage.
            if _parameter_not_in_local_stage(param_name, origin_src_strategy_list, src_strategy_list) \
                    and _parameter_not_in_local_stage(param_name, origin_dst_strategy_list, dst_strategy_list):
                continue
            src_rank = rank % src_stage_device_num
            param_type_dict[param_name][src_rank] = str(param.data.dtype)
            if param.data.dtype == mstype.bfloat16:
                param.set_dtype(mstype.float32)
            param_total_dict[param_name][src_rank] = param.data.asnumpy()
            param_attr_dict[param_name][src_rank] = (param.requires_grad, param.layerwise_parallel)
    for local_rank_id in range(dst_stage_device_num):
        transform_param_list = _transform_parallel_checkpoint(local_rank_id, param_total_dict,
                                                              param_attr_dict, src_strategy_list, dst_strategy_list,
                                                              param_type_dict)
        save_checkpoint_file = "{}{}_part{}.ckpt".format(ckpt_prefix, local_rank_id, stage_id)
        save_checkpoint_file_dir = os.path.join(dst_checkpoints_dir, "rank_{}".format(local_rank_id))
        if not os.path.exists(save_checkpoint_file_dir):
            _make_dir(save_checkpoint_file_dir, "path")
        save_checkpoint_file_name = os.path.join(save_checkpoint_file_dir, save_checkpoint_file)
        ms.save_checkpoint(transform_param_list, save_checkpoint_file_name)


def _transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix, src_strategy_file=None,
                           dst_strategy_file=None):
    """Transform checkpoints for all stages in src_strategy_file"""
    checkpoints_rank_dir_list = os.path.join(src_checkpoints_dir, "rank_[0-9]*")
    all_checkpoint_files_map = {}
    for checkpoint_dir in glob.glob(checkpoints_rank_dir_list):
        if not os.path.isdir(checkpoint_dir):
            ms.log.warning("{} is not a directory.".format(checkpoint_dir))
            continue
        rank_id_str = checkpoint_dir.split('rank_')[-1]
        if not rank_id_str.isdigit():
            ms.log.warning("{} is not an expected directory, the directory name should end with rank_0, rank_1, ...".
                           format(checkpoint_dir))
            continue
        rank_id = int(rank_id_str)
        checkpoint_file_name = os.path.join(checkpoint_dir, "*.ckpt")
        rank_ckpts = glob.glob(checkpoint_file_name)
        rank_ckpts.sort()
        for checkpoint_file in rank_ckpts:
            if not os.path.isfile(checkpoint_file):
                ms.log.warning("{} is not a checkpoint file.".format(checkpoint_file))
                continue
            all_checkpoint_files_map[rank_id] = checkpoint_file

    needed_rank_list_map = defaultdict(list)
    dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_file)
    src_stage_device_num = _get_device_num_from_strategy(src_strategy_file)
    dst_stage_num = _extract_pipeline_stage_num(dst_strategy_file)
    dst_device_num = dst_stage_device_num * dst_stage_num
    origin_src_strategy_list = _extract_layout_map(src_strategy_file)
    origin_dst_strategy_list = _extract_layout_map(dst_strategy_file)
    for rank in range(dst_device_num):
        needed_rank_list = rank_list_for_transform(rank, src_strategy_file, dst_strategy_file)
        for needed_rank in needed_rank_list:
            if needed_rank not in all_checkpoint_files_map:
                raise ValueError("The checkpoint file of rank{} is needed for converting rank{}'s checkpoint, "
                                 "but it is missing.".format(needed_rank, rank))
        needed_rank_list_key = "-".join([str(r) for r in needed_rank_list])
        needed_rank_list_map[needed_rank_list_key].append(rank)
    for needed_rank_list_key, transform_rank_list in needed_rank_list_map.items():
        param_total_dict = defaultdict(dict)
        param_attr_dict = defaultdict(dict)
        param_type_dict = defaultdict(dict)
        needed_rank_list = needed_rank_list_key.split("-")
        for needed_rank in needed_rank_list:
            ckpt_dict = ms.load_checkpoint(all_checkpoint_files_map.get(int(needed_rank)))
            for param_name, param in ckpt_dict.items():
                src_rank = int(needed_rank) % src_stage_device_num
                param_type_dict[param_name][src_rank] = str(param.data.dtype)
                if param.data.dtype == mstype.bfloat16:
                    param.set_dtype(mstype.float32)
                param_total_dict[param_name][src_rank] = param.data.asnumpy()
                param_attr_dict[param_name][src_rank] = (param.requires_grad, param.layerwise_parallel)
        for transform_rank in transform_rank_list:
            param_total_dict_copy = copy.deepcopy(param_total_dict)
            src_strategy_list, dst_strategy_list = _extract_src_dst_layout_map(transform_rank, src_strategy_file,
                                                                               dst_strategy_file)
            # drop the parameters that do not belong to the local pipeline stage.
            for param in list(param_total_dict_copy.keys()):
                if _parameter_not_in_local_stage(param, origin_src_strategy_list, src_strategy_list) \
                        and _parameter_not_in_local_stage(param, origin_dst_strategy_list, dst_strategy_list):
                    param_total_dict_copy.pop(param)

            local_rank_id = transform_rank % dst_stage_device_num
            transform_param_list = _transform_parallel_checkpoint(local_rank_id, param_total_dict_copy,
                                                                  param_attr_dict, src_strategy_list,
                                                                  dst_strategy_list, param_type_dict)
            save_checkpoint_file = "{}{}.ckpt".format(ckpt_prefix, transform_rank)
            save_checkpoint_file_dir = os.path.join(dst_checkpoints_dir, "rank_{}".format(transform_rank))
            if not os.path.exists(save_checkpoint_file_dir):
                _make_dir(save_checkpoint_file_dir, "path")
            save_checkpoint_file_name = os.path.join(save_checkpoint_file_dir, save_checkpoint_file)
            ms.save_checkpoint(transform_param_list, save_checkpoint_file_name)
            del param_total_dict_copy
        del param_total_dict


def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix, src_strategy_file=None,
                          dst_strategy_file=None):
    """
    Transform distributed checkpoints from the source sharding strategy to the destination sharding strategy for a
    network. For more details about converting distributed checkpoints, please refer to
    `Model Transformation <https://www.mindspore.cn/tutorials/experts/en/master/parallel/model_transformation.html>`_.

    Note:
        The `src_checkpoints_dir` directory structure should be organized like "src_checkpoints_dir/rank_0/a.ckpt";
        the rank number should be a subdirectory name and the checkpoint files should be stored in that subdirectory.
        If multiple files exist in a rank directory, the last file in lexicographic order is selected.

    Args:
        src_checkpoints_dir (str): The source checkpoints directory.
        dst_checkpoints_dir (str): The destination checkpoints directory to save the converted checkpoints.
        ckpt_prefix (str): The destination checkpoint name prefix.
        src_strategy_file (str): Name of the source sharding strategy file, which is saved by
                                 `mindspore.set_auto_parallel_context(strategy_ckpt_save_file)`.
                                 When `src_strategy_file` is ``None``, it means that the source sharding strategy is
                                 without any sharding for each parameter. Default: ``None``.
        dst_strategy_file (str): Name of the destination sharding strategy file, which is saved by
                                 `mindspore.set_auto_parallel_context(strategy_ckpt_save_file)`.
                                 When `dst_strategy_file` is ``None``, it means that the destination sharding strategy
                                 is without any sharding for each parameter. Default: ``None``.

    Raises:
        ValueError: `src_strategy_file` or `dst_strategy_file` is incorrect.
        NotADirectoryError: `src_checkpoints_dir` or `dst_checkpoints_dir` is not a directory.
        ValueError: A checkpoint file is missing in `src_checkpoints_dir`.
        TypeError: `src_strategy_file` or `dst_strategy_file` is not a string.

    Examples:
        >>> import mindspore as ms
        >>> ms.transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, "dst_checkpoint",
        ...                          "./src_strategy.ckpt", "./dst_strategy.ckpt")

    """
    if not os.path.isdir(src_checkpoints_dir):
        raise NotADirectoryError("src_checkpoints_dir {} is not a directory.".format(src_checkpoints_dir))
    _make_dir(dst_checkpoints_dir, "path")
    if not isinstance(ckpt_prefix, str):
        raise TypeError("The ckpt_prefix should be a str.")
    if src_strategy_file and os.path.dirname(src_strategy_file) and not os.path.exists(
            os.path.dirname(src_strategy_file)):
        raise ValueError("The directory of src_strategy_file: {} does not exist.".
                         format(os.path.dirname(src_strategy_file)))
    if dst_strategy_file and os.path.dirname(dst_strategy_file) and not os.path.exists(
            os.path.dirname(dst_strategy_file)):
        raise ValueError("The directory of dst_strategy_file: {} does not exist.".
                         format(os.path.dirname(dst_strategy_file)))
    src_layout_map = _extract_layout_map(src_strategy_file)
    dst_layout_map = _extract_layout_map(dst_strategy_file)
    pipeline_stage_num = _extract_pipeline_stage_num(src_strategy_file)
    dst_stage_num = _extract_pipeline_stage_num(dst_strategy_file)
    if src_layout_map:
        src_param_keys = {param_name for param_name in src_layout_map if
                          not param_name.startswith(("accu_grads", "adam_v", "adam_m"))}
    if dst_layout_map:
        dst_param_keys = {param_name for param_name in dst_layout_map if
                          not param_name.startswith(("accu_grads", "adam_v", "adam_m"))}
    layout_is_passed = src_layout_map and dst_layout_map

    if layout_is_passed and pipeline_stage_num == 1 and dst_stage_num == 1 and \
            src_param_keys.issubset(dst_param_keys) and len(src_param_keys) < len(dst_param_keys):
        ms.log.info("Transform checkpoint by every pipeline stage.")
        _transform_checkpoint_by_stage(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
                                       src_strategy_file, dst_strategy_file)
    else:
        ms.log.info("Transform checkpoints by all pipeline stages.")
        _transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
                               src_strategy_file, dst_strategy_file)


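# Illustrative end-to-end sketch: convert an entire distributed checkpoint offline, then load
# the converted slice that belongs to the local rank. All paths and the prefix "transformed"
# are hypothetical placeholders; get_rank() assumes the communication framework is initialized.
# (The file name pattern below matches the all-stage conversion path; the per-stage path
# appends "_part{stage}" to the file name.)
def _example_offline_conversion_and_load(net):
    """Sketch: transform all checkpoints, then load the piece belonging to this rank."""
    from mindspore.communication import get_rank
    ms.transform_checkpoints("./src_checkpoints", "./dst_checkpoints", "transformed",
                             "./src_strategy.ckpt", "./dst_strategy.ckpt")
    rank_id = get_rank()
    ckpt_file = os.path.join("./dst_checkpoints", "rank_{}".format(rank_id),
                             "transformed{}.ckpt".format(rank_id))
    param_dict = ms.load_checkpoint(ckpt_file)
    ms.load_param_into_net(net, param_dict)

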
def _sync_params(name, param, layout):
    """synchronize single parameter"""
    if len(layout) < 10:
        ms.log.warning("The layout dict does not contain the pipeline_shared_param info %s", name)
        return

    pipeline_shared = layout[8]
    if not pipeline_shared:
        return

    is_send = layout[9]
    peer_rank = layout[10]
    sr_tag = layout[11]

    class SharedParameterSyncCell(ms.nn.Cell):
        """synchronize cell"""
        def __init__(self, param, is_send, peer_rank, sr_tag):
            super().__init__()
            self.param = param
            self.is_send = is_send
            self.ret = ms.Tensor([0])

            from mindspore.ops import Send, Receive
            if self.is_send:
                self.send = Send(sr_tag=sr_tag, dest_rank=peer_rank)
            else:
                self.receive = Receive(sr_tag=sr_tag, src_rank=peer_rank, shape=param.shape, dtype=param.dtype)

        def construct(self):
            if self.is_send:
                out = self.send(self.param)
                return ms.ops.functional.depend(self.ret, out)

            self.param = self.receive(self.ret)
            return ms.ops.functional.depend(self.ret, self.param)

    sync_net = SharedParameterSyncCell(param, is_send, peer_rank, sr_tag)
    sync_net()


def sync_pipeline_shared_parameters(net):
    """Synchronize the shared parameters between pipeline parallel stages.

    Parameters may be shared between different stages. For example, the `embedding table` is
    shared by the `WordEmbedding` layer and the `LMHead` layer, which are usually split into different stages.
    It is necessary to perform synchronization after the `embedding table` changes.

    Note:
        The network should be compiled before synchronizing the pipeline parallel stage shared parameters.

    Args:
        net (nn.Cell): The inference network.

    Supported Platforms:
        ``Ascend``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For the Ascend device, users need to write a dynamic cluster startup script; please see the
            `Dynamic Cluster Startup <https://www.mindspore.cn/tutorials/experts/en/master/parallel/dynamic_cluster.html>`_ .

        >>> import numpy as np
        >>> import mindspore as ms
        >>> import mindspore.communication.management as D
        >>> from mindspore import lazy_inline, context, nn, ops, Parameter, Tensor
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> class Embedding(nn.Cell):
        ...     def __init__(self, shape):
        ...         super().__init__()
        ...         self.w = Parameter(Tensor(np.ones(shape), ms.float32), name='w')
        ...         self.matmul = ops.MatMul().shard(((1, 1), (1, 1)))
        ...     def construct(self, x):
        ...         return self.matmul(x, self.w), self.w
        ...
        >>> class LMHead(nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.matmul = ops.MatMul(transpose_b=True).shard(((1, 1), (1, 1)))
        ...     def construct(self, x, w):
        ...         return self.matmul(x, w)
        ...
        >>> class Network(nn.Cell):
        ...     @lazy_inline
        ...     def __init__(self):
        ...         super().__init__()
        ...         shape = (4, 4)
        ...         self.word_embedding = Embedding(shape)
        ...         self.lm_head = LMHead()
        ...         self.word_embedding.pipeline_stage = 0
        ...         self.lm_head.pipeline_stage = 1
        ...     def construct(self, x):
        ...         x, embed = self.word_embedding(x)
        ...         return self.lm_head(x, embed)
        ...
        >>> class PipelineCellInference(nn.Cell):
        ...     def __init__(self, network, micro_batch_num):
        ...         super().__init__()
        ...         self.network = network
        ...         self.micro_batch_num = micro_batch_num
        ...         self.concat = ops.Concat()
        ...     def construct(self, x):
        ...         ret = ()
        ...         for i in range(self.micro_batch_num):
        ...             micro_batch_size = x.shape[0] // self.micro_batch_num
        ...             start = micro_batch_size * i
        ...             end = micro_batch_size * (i + 1)
        ...             micro_input = x[start:end]
        ...             y = self.network(micro_input)
        ...             ret = ret + (y,)
        ...         ret = self.concat(ret)
        ...         return ret
        >>> D.init()
        >>> context.set_auto_parallel_context(parallel_mode='semi_auto_parallel', full_batch=True, pipeline_stages=2)
        >>> net = Network()
        >>> net = PipelineCellInference(net, 2)
        >>> net.set_train(False)
        >>> x = Tensor(np.ones((2, 4)), ms.float32)
        >>> net.compile(x)
        >>> ms.sync_pipeline_shared_parameters(net)
        >>> print(net.network.word_embedding.w.asnumpy())
        [[1. 1. 1. 1.]
         [1. 1. 1. 1.]
         [1. 1. 1. 1.]
         [1. 1. 1. 1.]]
    """

    if not isinstance(net, ms.nn.Cell):
        ms.log.critical("Failed to synchronize pipeline shared parameters.")
        msg = ("For 'sync_pipeline_shared_parameters', the argument 'net' should be a Cell, "
               "but got {}.".format(type(net)))
        raise TypeError(msg)

    layout_dict = net.parameter_layout_dict
    if _is_in_auto_parallel_mode() and not layout_dict:
        from mindspore.common.api import _get_parameter_layout
        layout_dict = _get_parameter_layout()

    # switch to standalone mode
    parallel_mode = ms.context.get_auto_parallel_context("parallel_mode")
    full_batch = ms.context.get_auto_parallel_context("full_batch")
    ms.context.set_auto_parallel_context(parallel_mode="stand_alone", full_batch=False)

    # synchronize shared parameters
    for name, param in net.parameters_and_names():
        if name in layout_dict:
            _sync_params(name, param, layout_dict[name])

    # restore parallel context
    ms.context.set_auto_parallel_context(parallel_mode=parallel_mode, full_batch=full_batch)


def load_segmented_checkpoints(ckpt_file_dir, net=None, strict_load=False, filter_prefix=None,
                               dec_key=None, dec_mode="AES-GCM", specify_prefix=None, choice_func=None):
    """
    Load checkpoint info from a specified directory. If the specified ckpt_file_dir path contains multiple
    checkpoint files, all checkpoint files will be loaded one by one and the combined dictionary will be returned.

    Note:
        - `specify_prefix` and `filter_prefix` do not affect each other.
        - If none of the parameters are loaded from the checkpoint files, a ValueError will be raised.
        - `specify_prefix` and `filter_prefix` are in the process of being deprecated;
          `choice_func` is recommended instead. Using either of those two arguments will override `choice_func`.

    Args:
        ckpt_file_dir (str): Checkpoint file directory.
        net (Cell): The network where the parameters will be loaded. Default: ``None`` .
        strict_load (bool): Whether to strictly load the parameters into the net. If ``False`` , it will load a
                            parameter into the net when the parameter name's suffix in the checkpoint file is the
                            same as the parameter in the network. When the types are inconsistent, it performs type
                            conversion on the parameters of the same type, such as float32 to float16.
                            Default: ``False`` .
        filter_prefix (Union[str, list[str], tuple[str]]): Deprecated (see `choice_func`). Parameters starting with
            the filter_prefix will not be loaded. Default: ``None`` .
        dec_key (Union[None, bytes]): Byte-type key used for decryption. If the value is ``None`` , decryption
                                      is not required. Default: ``None`` .
        dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption
                        mode; currently supports ``"AES-GCM"`` , ``"AES-CBC"`` and ``"SM4-CBC"`` .
                        Default: ``"AES-GCM"`` .
        specify_prefix (Union[str, list[str], tuple[str]]): Deprecated (see `choice_func`). Parameters starting with
            the specify_prefix will be loaded. Default: ``None`` .
        choice_func (Union[None, function]): Input value of the function is a Parameter name of type string,
            and the return value is a bool. If it returns ``True`` , the Parameter
            that matches the custom condition will be loaded. If it returns ``False`` , the Parameter that
            matches the custom condition will be removed. Default: ``None`` .

    Returns:
        Dict, key is parameter name, value is a Parameter or string. When the `append_dict` parameter of
        :func:`mindspore.save_checkpoint` and the `append_info` parameter of :class:`mindspore.train.CheckpointConfig`
        are used to save the checkpoint, `append_dict` and `append_info` are dict types and their values are strings;
        in that case the return value obtained by loading checkpoint is string, and in other cases the return value
        is Parameter.

    Raises:
        TypeError: Input ckpt_file_dir is not a string.
        ValueError: The checkpoint file directory doesn't exist or is not a directory.
        ValueError: Checkpoint file's format is incorrect.
        ValueError: Parameter's dict is None after load checkpoint file.
        TypeError: The type of `specify_prefix` or `filter_prefix` is incorrect.
    """
    if not isinstance(ckpt_file_dir, str):
        raise TypeError("The ckpt_file_dir should be a str.")
    if not os.path.isdir(ckpt_file_dir):
        raise ValueError("The ckpt_file_dir: {} doesn't exist or is not a directory.".
                         format(ckpt_file_dir))
    checkpoint_file_name = os.path.join(ckpt_file_dir, "*.ckpt")
    rank_ckpts = glob.glob(checkpoint_file_name)
    parameter_dict = {}
    for checkpoint_file in rank_ckpts:
        parameter_dict.update(ms.load_checkpoint(checkpoint_file, net, strict_load, filter_prefix, dec_key,
                                                 dec_mode, specify_prefix, choice_func))
    return parameter_dict
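

# Illustrative usage sketch for load_segmented_checkpoints: the directory name is a hypothetical
# placeholder and the network argument is optional, mirroring ms.load_checkpoint.
def _example_load_segmented_checkpoints(net=None):
    """Sketch: combine all checkpoint segments in one directory into a single parameter dict."""
    param_dict = load_segmented_checkpoints("./segmented_checkpoints", net=net)
    if net is not None:
        ms.load_param_into_net(net, param_dict)
    return param_dict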