1# Copyright 2023 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""Transform distributed checkpoint""" 16from __future__ import absolute_import 17 18import os 19import glob 20import copy 21from collections import defaultdict 22import numpy as np 23import mindspore as ms 24from mindspore.common import dtype as mstype 25from mindspore.parallel._utils import _is_in_auto_parallel_mode 26from mindspore.parallel._parallel_serialization import _rank_list_for_transform_parallel_checkpoint, \ 27 _transform_parallel_checkpoint, _get_device_num_from_strategy, _make_dir, \ 28 _extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \ 29 _merge_protobuf_strategy, _merge_json_strategy, _extract_src_dst_layout_map_by_src 30 31 32__all__ = ["merge_pipeline_strategys", "rank_list_for_transform", "transform_checkpoint_by_rank", 33 "transform_checkpoints", "sync_pipeline_shared_parameters", "load_segmented_checkpoints"] 34 35 36def merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file): 37 """ 38 Merge parallel strategy between all pipeline stages in pipeline parallel mode. 39 For more details about converting distributed Checkpoint, please refer to 40 `Model Transformation <https://www.mindspore.cn/tutorials/experts/en/master/parallel/model_transformation.html>`_. 41 42 Note: 43 Strategy file of each pipeline stage should be included in src_strategy_dirs. 44 45 Args: 46 src_strategy_dirs (str): The directory of strategy files including all pipeline stage which is saved by 47 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'. 48 dst_strategy_file (str): The file merged strategy to save. 49 50 Raises: 51 NotADirectoryError: `src_strategy_dirs` is not a directory. 52 53 Examples: 54 >>> import mindspore as ms 55 >>> # src_strategy_dir/stra0.ckpt, src_strategy_dir/stra1.ckpt ... src_strategy_dir/stra127.ckpt 56 >>> ms.merge_pipeline_strategys("./src_strategy_dir", "./dst_strategy.ckpt") 57 58 """ 59 dst_strategy_dir, _ = os.path.split(dst_strategy_file) 60 if not os.path.exists(dst_strategy_dir): 61 _make_dir(dst_strategy_dir, "path") 62 if not os.path.isdir(src_strategy_dirs): 63 raise NotADirectoryError("src_strategy_dirs {} is not a directory.".format(src_strategy_dirs)) 64 src_strategy_files_protobuf = glob.glob(os.path.join(src_strategy_dirs, "*.ckpt")) 65 src_strategy_files_json = glob.glob(os.path.join(src_strategy_dirs, "*.json")) 66 if src_strategy_files_protobuf and src_strategy_files_json: 67 raise ValueError("The strategys format should be all '.ckpt' or all '.json'") 68 is_protobuf = len(src_strategy_files_protobuf) > 0 69 if is_protobuf: 70 _merge_protobuf_strategy(src_strategy_files_protobuf, dst_strategy_file) 71 else: 72 _merge_json_strategy(src_strategy_files_json, dst_strategy_file) 73 74 75 76def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=None): 77 """ 78 List of original distributed checkpoint rank index for obtaining the target checkpoint of a rank_id during the 79 distributed checkpoint conversion. For more details about converting distributed Checkpoint, please refer to 80 `Model Transformation <https://www.mindspore.cn/tutorials/experts/en/master/parallel/model_transformation.html>`_. 81 82 Args: 83 rank_id (int): The rank of which distributed checkpoint needs to be obtained after conversion. 84 src_strategy_file (str): Name of source sharding strategy file which saved by 85 `mindspore.set_auto_parallel_context(strategy_ckpt_save_file)`. 86 when the `src_strategy_file` is ``None``, it means that the source sharding strategy is 87 without any sharing for each parameter. Default: ``None``. 88 dst_strategy_file (str): Name of destination sharding strategy file which saved by 89 `mindspore.set_auto_parallel_context(strategy_ckpt_save_file)`. 90 when the `dst_strategy_file` is ``None``, 91 it means that the destination sharding strategy 92 is without any sharing for each parameter. Default: ``None``. 93 94 Returns: 95 List, the rank list required for converting the distributed checkpoint of rank_id. 96 97 Raises: 98 ValueError: `src_strategy_file` or `dst_strategy_file` is incorrect. 99 TypeError: `src_strategy_file` or `dst_strategy_file` is not a string. 100 TypeError: `rank_id` is not an int. 101 102 Examples: 103 >>> import mindspore as ms 104 >>> rank_id = 0 105 >>> rank_list = ms.rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt") 106 >>> checkpoint_files_map = {} 107 >>> for rank in rank_list: 108 ... checkpoint_files_map[rank] = "./pangu{}-100_2.ckpt".format(rank) 109 110 """ 111 if not isinstance(rank_id, int): 112 raise TypeError("The rank_id should be a int.") 113 if src_strategy_file is None: 114 return [0] 115 src_strategy_list, dst_strategy_list = _extract_src_dst_layout_map(rank_id, src_strategy_file, dst_strategy_file) 116 src_stage_device_num = np.prod(src_strategy_list.get(list(src_strategy_list.keys())[0])[0]) if src_strategy_list \ 117 is not None else 1 118 dst_stage_device_num = np.prod(dst_strategy_list.get(list(dst_strategy_list.keys())[0])[0]) if dst_strategy_list \ 119 is not None else 1 120 121 if not src_strategy_list: 122 raise ValueError("The src_strategy_file is empty.") 123 local_rank_id = rank_id % dst_stage_device_num if dst_stage_device_num > 1 else rank_id 124 needed_rank_list_in_local_stage = _rank_list_for_transform_parallel_checkpoint(local_rank_id, 125 src_strategy_list, dst_strategy_list) 126 result_set = set() 127 handled_pipeline_stage = [] 128 for _, layout in src_strategy_list.items(): 129 for src_pipeline_stage_id in layout[6]: 130 if src_pipeline_stage_id in handled_pipeline_stage: 131 continue 132 src_rank_id_start = src_pipeline_stage_id * src_stage_device_num 133 result_set.update([src_rank_id_start + rank for rank in needed_rank_list_in_local_stage]) 134 handled_pipeline_stage.append(src_pipeline_stage_id) 135 return list(result_set) 136 137 138def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name, 139 src_strategy_file=None, dst_strategy_file=None): 140 """ 141 Transform distributed checkpoint from source sharding strategy to destination sharding strategy by rank 142 for a network. For more details about converting distributed Checkpoint, please refer to 143 `Model Transformation <https://www.mindspore.cn/tutorials/experts/en/master/parallel/model_transformation.html>`_. 144 145 Args: 146 rank_id (int): The rank of which distributed checkpoint needs to be obtained after conversion. 147 checkpoint_files_map (dict): The checkpoint files map whose key is the rank id and the value is 148 the checkpoint file name. 149 save_checkpoint_file_name (str): The file name to save the converted checkpoint. 150 src_strategy_file (str): Name of source sharding strategy file which saved by 151 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'. 152 when the `src_strategy_file` is None, it means that the source sharding strategy is 153 without any sharing for each parameter. Default: ``None``. 154 dst_strategy_file (str): Name of destination sharding strategy file which saved by 155 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'. 156 when the `dst_strategy_file` is ``None``, 157 it means that the destination sharding strategy 158 is without any sharing for each parameter. Default: ``None``. 159 160 Raises: 161 ValueError: `src_strategy_file` or `dst_strategy_file` is incorrect. 162 ValueError: item in `checkpoint_files_map` is incorrect. 163 ValueError: `save_checkpoint_file_name` is not end with ".ckpt". 164 TypeError: `checkpoint_files_map` is not a dict. 165 TypeError: `src_strategy_file` or `dst_strategy_file` is not a string. 166 TypeError: `rank_id` is not an int. 167 TypeError: `save_checkpoint_file_name` is not a string. 168 169 Examples: 170 >>> import mindspore as ms 171 >>> dst_device_num = 8 172 >>> for rank_id in range(dst_device_num): 173 ... rank_list = ms.rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt") 174 ... checkpoint_files_map = {} 175 ... for rank in rank_list: 176 ... checkpoint_files_map[rank] = "./origin_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank) 177 ... save_checkpoint_file_name = "./new_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank_id) 178 ... ms.transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name, 179 ... "./src_strategy.ckpt", "./dst_strategy.ckpt") 180 181 """ 182 if not isinstance(checkpoint_files_map, dict): 183 raise TypeError("The checkpoint_files_map should be a dict.") 184 if not isinstance(rank_id, int): 185 raise TypeError("The rank_id should be a int.") 186 if not isinstance(save_checkpoint_file_name, str): 187 raise TypeError("The save_checkpoint_file_name should be a str.") 188 if save_checkpoint_file_name[-5:] != ".ckpt": 189 raise ValueError("The save_checkpoint_file_name {} should end with .ckpt".format(save_checkpoint_file_name)) 190 if dst_strategy_file and os.path.dirname(dst_strategy_file) and not os.path.exists( 191 os.path.dirname(dst_strategy_file)): 192 raise ValueError("The director of dst_strategy_file: {} is not exists.". 193 format(os.path.dirname(dst_strategy_file))) 194 for rank, local_file in checkpoint_files_map.items(): 195 if not os.path.exists(local_file): 196 raise ValueError("Checkpoint file {} in rank {} not exits: ".format(local_file, rank)) 197 param_total_dict = defaultdict(dict) 198 param_attr_dict = defaultdict(dict) 199 param_type_dict = defaultdict(dict) 200 src_strategy_list, dst_strategy_list = _extract_src_dst_layout_map(rank_id, src_strategy_file, dst_strategy_file) 201 # src rank => local rank inside pipeline stage 202 src_stage_device_num = np.prod(src_strategy_list.get(list(src_strategy_list.keys())[0])[0]) if src_strategy_list \ 203 is not None else 1 204 dst_stage_device_num = np.prod(dst_strategy_list.get(list(dst_strategy_list.keys())[0])[0]) if dst_strategy_list \ 205 is not None else 1 206 origin_dst_strategy_list = _extract_layout_map(dst_strategy_file) 207 origin_src_strategy_list = _extract_layout_map(src_strategy_file) 208 for rank, file_name in checkpoint_files_map.items(): 209 ckpt_dict = ms.load_checkpoint(file_name) 210 for param_name, param in ckpt_dict.items(): 211 # cut the parameter not in the pipeline stage. 212 if _parameter_not_in_local_stage(param_name, origin_src_strategy_list, src_strategy_list) \ 213 and _parameter_not_in_local_stage(param_name, origin_dst_strategy_list, dst_strategy_list): 214 continue 215 src_rank = rank % src_stage_device_num 216 param_type_dict[param_name][src_rank] = str(param.data.dtype) 217 if param.data.dtype == mstype.bfloat16: 218 param.set_dtype(mstype.float32) 219 param_total_dict[param_name][src_rank] = param.data.asnumpy() 220 param_attr_dict[param_name][src_rank] = (param.requires_grad, param.layerwise_parallel) 221 local_rank_id = rank_id % dst_stage_device_num 222 transform_param_list = _transform_parallel_checkpoint(local_rank_id, param_total_dict, 223 param_attr_dict, src_strategy_list, dst_strategy_list, 224 param_type_dict) 225 ms.save_checkpoint(transform_param_list, save_checkpoint_file_name) 226 227 228def _transform_checkpoint_by_stage(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix, src_strategy_file, 229 dst_strategy_file=None): 230 """Transform checkpoint for stage in src_strategy_file""" 231 param_total_dict = defaultdict(dict) 232 param_attr_dict = defaultdict(dict) 233 param_type_dict = defaultdict(dict) 234 src_strategy_list, dst_strategy_list, stage_id = _extract_src_dst_layout_map_by_src(src_strategy_file, \ 235 dst_strategy_file) 236 src_stage_device_num = np.prod(src_strategy_list.get(list(src_strategy_list.keys())[0])[0]) if src_strategy_list \ 237 is not None else 1 238 dst_stage_device_num = np.prod(dst_strategy_list.get(list(dst_strategy_list.keys())[0])[0]) if dst_strategy_list \ 239 is not None else 1 240 origin_dst_strategy_list = _extract_layout_map(dst_strategy_file) 241 origin_src_strategy_list = _extract_layout_map(src_strategy_file) 242 checkpoint_files_map = {} 243 src_rank_id_start = stage_id * src_stage_device_num 244 for local_rank in range(src_stage_device_num): 245 rank_id = src_rank_id_start + local_rank 246 checkpoint_file_name = os.path.join(src_checkpoints_dir, "rank_{}".format(rank_id), "*.ckpt") 247 rank_ckpts = glob.glob(checkpoint_file_name) 248 rank_ckpts.sort() 249 for checkpoint_file in rank_ckpts: 250 if not os.path.isfile(checkpoint_file): 251 ms.log.warning("{} is not a checkpoint file.".format(checkpoint_file)) 252 continue 253 checkpoint_files_map[rank_id] = checkpoint_file 254 for rank, local_file in checkpoint_files_map.items(): 255 if not os.path.exists(local_file): 256 raise ValueError("Checkpoint file {} in rank {} not exits: ".format(local_file, rank)) 257 for rank, file_name in checkpoint_files_map.items(): 258 ckpt_dict = ms.load_checkpoint(file_name) 259 for param_name, param in ckpt_dict.items(): 260 # cut the parameter not in the pipeline stage. 261 if _parameter_not_in_local_stage(param_name, origin_src_strategy_list, src_strategy_list) \ 262 and _parameter_not_in_local_stage(param_name, origin_dst_strategy_list, dst_strategy_list): 263 continue 264 src_rank = rank % src_stage_device_num 265 param_type_dict[param_name][src_rank] = str(param.data.dtype) 266 if param.data.dtype == mstype.bfloat16: 267 param.set_dtype(mstype.float32) 268 param_total_dict[param_name][src_rank] = param.data.asnumpy() 269 param_attr_dict[param_name][src_rank] = (param.requires_grad, param.layerwise_parallel) 270 for local_rank_id in range(dst_stage_device_num): 271 transform_param_list = _transform_parallel_checkpoint(local_rank_id, param_total_dict, 272 param_attr_dict, src_strategy_list, dst_strategy_list, 273 param_type_dict) 274 save_checkpoint_file = "{}{}_part{}.ckpt".format(ckpt_prefix, local_rank_id, stage_id) 275 save_checkpoint_file_dir = os.path.join(dst_checkpoints_dir, "rank_{}".format(local_rank_id)) 276 if not os.path.exists(save_checkpoint_file_dir): 277 _make_dir(save_checkpoint_file_dir, "path") 278 save_checkpoint_file_name = os.path.join(save_checkpoint_file_dir, save_checkpoint_file) 279 ms.save_checkpoint(transform_param_list, save_checkpoint_file_name) 280 281 282def _transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix, src_strategy_file=None, 283 dst_strategy_file=None): 284 """Transform checkpoints for all stages in src_strategy_file""" 285 checkpoints_rank_dir_list = os.path.join(src_checkpoints_dir, "rank_[0-9]*") 286 all_checkpoint_files_map = {} 287 for checkpoint_dir in glob.glob(checkpoints_rank_dir_list): 288 if not os.path.isdir(checkpoint_dir): 289 ms.log.warning("{} is not a directory.".format(checkpoint_dir)) 290 continue 291 rank_id_str = checkpoint_dir.split('rank_')[-1] 292 if not rank_id_str.isdigit(): 293 ms.log.warning("{} is not a expected directory, the directory should end with rank_0/rank_1.....". 294 format(checkpoint_dir)) 295 continue 296 rank_id = int(rank_id_str) 297 checkpoint_file_name = os.path.join(checkpoint_dir, "*.ckpt") 298 rank_ckpts = glob.glob(checkpoint_file_name) 299 rank_ckpts.sort() 300 for checkpoint_file in rank_ckpts: 301 if not os.path.isfile(checkpoint_file): 302 ms.log.warning("{} is not a checkpoint file.".format(checkpoint_file)) 303 continue 304 all_checkpoint_files_map[rank_id] = checkpoint_file 305 306 needed_rank_list_map = defaultdict(list) 307 dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_file) 308 src_stage_device_num = _get_device_num_from_strategy(src_strategy_file) 309 dst_stage_num = _extract_pipeline_stage_num(dst_strategy_file) 310 dst_device_num = dst_stage_device_num * dst_stage_num 311 origin_src_strategy_list = _extract_layout_map(src_strategy_file) 312 origin_dst_strategy_list = _extract_layout_map(dst_strategy_file) 313 for rank in range(dst_device_num): 314 needed_rank_list = rank_list_for_transform(rank, src_strategy_file, dst_strategy_file) 315 for needed_rank in needed_rank_list: 316 if needed_rank not in all_checkpoint_files_map: 317 raise ValueError("The checkpoint file of rank{} is needed for converting rank{}'s checkpoint, " 318 "but it is missing.".format(needed_rank, rank)) 319 needed_rank_list_key = "-".join([str(r) for r in needed_rank_list]) 320 needed_rank_list_map[needed_rank_list_key].append(rank) 321 for needed_rank_list_key, transform_rank_list in needed_rank_list_map.items(): 322 param_total_dict = defaultdict(dict) 323 param_attr_dict = defaultdict(dict) 324 param_type_dict = defaultdict(dict) 325 needed_rank_list = needed_rank_list_key.split("-") 326 for needed_rank in needed_rank_list: 327 ckpt_dict = ms.load_checkpoint(all_checkpoint_files_map.get(int(needed_rank))) 328 for param_name, param in ckpt_dict.items(): 329 src_rank = int(needed_rank) % src_stage_device_num 330 param_type_dict[param_name][src_rank] = str(param.data.dtype) 331 if param.data.dtype == mstype.bfloat16: 332 param.set_dtype(mstype.float32) 333 param_total_dict[param_name][src_rank] = param.data.asnumpy() 334 param_attr_dict[param_name][src_rank] = (param.requires_grad, param.layerwise_parallel) 335 for transform_rank in transform_rank_list: 336 param_total_dict_copy = copy.deepcopy(param_total_dict) 337 src_strategy_list, dst_strategy_list = _extract_src_dst_layout_map(transform_rank, src_strategy_file, 338 dst_strategy_file) 339 # cut the parameter not in the pipeline stage. 340 for param in list(param_total_dict_copy.keys()): 341 if _parameter_not_in_local_stage(param, origin_src_strategy_list, src_strategy_list) \ 342 and _parameter_not_in_local_stage(param, origin_dst_strategy_list, dst_strategy_list): 343 param_total_dict_copy.pop(param) 344 345 local_rank_id = transform_rank % dst_stage_device_num 346 transform_param_list = _transform_parallel_checkpoint(local_rank_id, param_total_dict_copy, 347 param_attr_dict, src_strategy_list, dst_strategy_list, 348 param_type_dict) 349 save_checkpoint_file = "{}{}.ckpt".format(ckpt_prefix, transform_rank) 350 save_checkpoint_file_dir = os.path.join(dst_checkpoints_dir, "rank_{}".format(transform_rank)) 351 if not os.path.exists(save_checkpoint_file_dir): 352 _make_dir(save_checkpoint_file_dir, "path") 353 save_checkpoint_file_name = os.path.join(save_checkpoint_file_dir, save_checkpoint_file) 354 ms.save_checkpoint(transform_param_list, save_checkpoint_file_name) 355 del param_total_dict_copy 356 del param_total_dict 357 358 359def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix, src_strategy_file=None, 360 dst_strategy_file=None): 361 """ 362 Transform distributed checkpoint from source sharding strategy to destination sharding strategy for a rank. 363 For more details about converting distributed Checkpoint, please refer to 364 `Model Transformation <https://www.mindspore.cn/tutorials/experts/en/master/parallel/model_transformation.html>`_. 365 366 Note: 367 The `src_checkpoints_dir` directory structure should be organized like "src_checkpoints_dir/rank_0/a.ckpt", the 368 rank number should be set to a subdirectory and the checkpoint file is stored in this subdirectory. If multiple 369 files exist in a rank directory, the last file in the lexicgraphic order would be selected. 370 371 Args: 372 src_checkpoints_dir (str): The source checkpoints directory. 373 dst_checkpoints_dir (str): The destination checkpoints directory to save the converted checkpoints. 374 ckpt_prefix (str): The destination checkpoint name prefix. 375 src_strategy_file (str): Name of source sharding strategy file which saved by 376 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'. 377 when the 'src_strategy_file' is None, it means that the source sharding strategy is 378 without any sharing for each parameter. Default:None. 379 dst_strategy_file (str): Name of destination sharding strategy file which saved by 380 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'. 381 when the 'dst_strategy_file' is None, it means that the destination sharding strategy 382 is without any sharing for each parameter. Default:None. 383 384 Raises: 385 ValueError: `src_strategy_file` or `dst_strategy_file` is incorrect. 386 NotADirectoryError: `src_checkpoints_dir` or `dst_checkpoints_dir` is not a directory. 387 ValueError: The checkpoint file is missing in `src_checkpoints_dir`. 388 TypeError: `src_strategy_file` or `dst_strategy_file` is not a string. 389 390 Examples: 391 >>> import mindspore as ms 392 >>> ms.transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, "dst_checkpoint", 393 ... "./src_strategy.ckpt", "./dst_strategy.ckpt") 394 395 """ 396 if not os.path.isdir(src_checkpoints_dir): 397 raise NotADirectoryError("src_checkpoints_dir {} is not a directory.".format(src_checkpoints_dir)) 398 _make_dir(dst_checkpoints_dir, "path") 399 if not isinstance(ckpt_prefix, str): 400 raise TypeError("The ckpt_prefix should be a str.") 401 if src_strategy_file and os.path.dirname(src_strategy_file) and not os.path.exists( 402 os.path.dirname(src_strategy_file)): 403 raise ValueError("The director of src_strategy_file: {} is not exists.". 404 format(os.path.dirname(src_strategy_file))) 405 if dst_strategy_file and os.path.dirname(dst_strategy_file) and not os.path.exists( 406 os.path.dirname(dst_strategy_file)): 407 raise ValueError("The director of dst_strategy_file: {} is not exists.". 408 format(os.path.dirname(dst_strategy_file))) 409 src_layout_map = _extract_layout_map(src_strategy_file) 410 dst_layout_map = _extract_layout_map(dst_strategy_file) 411 pipeline_stage_num = _extract_pipeline_stage_num(src_strategy_file) 412 dst_stage_num = _extract_pipeline_stage_num(dst_strategy_file) 413 if src_layout_map: 414 src_param_keys = {param_name for param_name in src_layout_map if 415 not param_name.startswith(("accu_grads", "adam_v", "adam_m"))} 416 if dst_layout_map: 417 dst_param_keys = {param_name for param_name in dst_layout_map if 418 not param_name.startswith(("accu_grads", "adam_v", "adam_m"))} 419 layout_is_passed = src_layout_map and dst_layout_map 420 421 if layout_is_passed and pipeline_stage_num == 1 and dst_stage_num == 1 and \ 422 src_param_keys.issubset(dst_param_keys) and len(src_param_keys) < len(dst_param_keys): 423 ms.log.info("Transform checkpoint by every pipeline stage.") 424 _transform_checkpoint_by_stage(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix, 425 src_strategy_file, dst_strategy_file) 426 else: 427 ms.log.info("Transform checkpoints by all pipeline stage.") 428 _transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix, 429 src_strategy_file, dst_strategy_file) 430 431 432def _sync_params(name, param, layout): 433 """synchronize single parameter""" 434 if len(layout) < 10: 435 ms.log.warning("The layout dict does not contain the pipeline_shared_param info %s", name) 436 return 437 438 pipeline_shared = layout[8] 439 if not pipeline_shared: 440 return 441 442 is_send = layout[9] 443 peer_rank = layout[10] 444 sr_tag = layout[11] 445 446 class SharedParameterSyncCell(ms.nn.Cell): 447 """synchronize cell""" 448 def __init__(self, param, is_send, peer_rank, sr_tag): 449 super().__init__() 450 self.param = param 451 self.is_send = is_send 452 self.ret = ms.Tensor([0]) 453 454 from mindspore.ops import Send, Receive 455 if self.is_send: 456 self.send = Send(sr_tag=sr_tag, dest_rank=peer_rank) 457 else: 458 self.receive = Receive(sr_tag=sr_tag, src_rank=peer_rank, shape=param.shape, dtype=param.dtype) 459 460 def construct(self): 461 if self.is_send: 462 out = self.send(self.param) 463 return ms.ops.functional.depend(self.ret, out) 464 465 self.param = self.receive(self.ret) 466 return ms.ops.functional.depend(self.ret, self.param) 467 468 sync_net = SharedParameterSyncCell(param, is_send, peer_rank, sr_tag) 469 sync_net() 470 471 472def sync_pipeline_shared_parameters(net): 473 """synchronize pipeline parallel stage shared parameters. 474 Parameters may be shared between different stages. For example, `embedding table` is 475 shared by `WordEmbedding` layer and `LMHead` layer, which are usually split into different stages. It is necessary 476 to perform synchronization after `embedding table` changes. 477 478 Note: 479 The network should be compiled before synchronize pipeline parallel stage shared parameters. 480 481 Args: 482 net (nn.Cell): the inference network. 483 484 Supported Platforms: 485 ``Ascend`` 486 487 Examples: 488 .. note:: 489 Before running the following examples, you need to configure the communication environment variables. 490 491 For the Ascend device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster 492 Startup <https://www.mindspore.cn/tutorials/experts/en/master/parallel/dynamic_cluster.html>`_ . 493 494 >>> import numpy as np 495 >>> import mindspore as ms 496 >>> import mindspore.communication.management as D 497 >>> from mindspore import lazy_inline, context, nn, ops, Parameter, Tensor 498 >>> context.set_context(mode=context.GRAPH_MODE) 499 >>> class Embedding(nn.Cell): 500 ... def __init__(self, shape): 501 ... super().__init__() 502 ... self.w = Parameter(Tensor(np.ones(shape), ms.float32), name='w') 503 ... self.matmul = ops.MatMul().shard(((1, 1), (1, 1))) 504 ... def construct(self, x): 505 ... return self.matmul(x, self.w), self.w 506 ... 507 >>> class LMHead(nn.Cell): 508 ... def __init__(self): 509 ... super().__init__() 510 ... self.matmul = ops.MatMul(transpose_b=True).shard(((1, 1), (1, 1))) 511 ... def construct(self, x, w): 512 ... return self.matmul(x, w) 513 ... 514 >>> class Network(nn.Cell): 515 ... @lazy_inline 516 ... def __init__(self): 517 ... super().__init__() 518 ... shape = (4, 4) 519 ... self.word_embedding = Embedding(shape) 520 ... self.lm_head = LMHead() 521 ... self.word_embedding.pipeline_stage = 0 522 ... self.lm_head.pipeline_stage = 1 523 ... def construct(self, x): 524 ... x, embed = self.word_embedding(x) 525 ... return self.lm_head(x, embed) 526 ... 527 >>> class PipelineCellInference(nn.Cell): 528 ... def __init__(self, network, micro_batch_num): 529 ... super().__init__() 530 ... self.network = network 531 ... self.micro_batch_num = micro_batch_num 532 ... self.concat = ops.Concat() 533 ... def construct(self, x): 534 ... ret = () 535 ... for i in range(self.micro_batch_num): 536 ... micro_batch_size = x.shape[0] // self.micro_batch_num 537 ... start = micro_batch_size * i 538 ... end = micro_batch_size * (i + 1) 539 ... micro_input = x[start:end] 540 ... y = self.network(micro_input) 541 ... ret = ret + (y,) 542 ... ret = self.concat(ret) 543 ... return ret 544 >>> D.init() 545 >>> context.set_auto_parallel_context(parallel_mode='semi_auto_parallel', full_batch=True, pipeline_stages=2) 546 >>> net = Network() 547 >>> net = PipelineCellInference(net, 2) 548 >>> net.set_train(False) 549 >>> x = Tensor(np.ones((2, 4)), ms.float32) 550 >>> net.compile(x) 551 >>> ms.sync_pipeline_shared_parameters(net) 552 >>> print(net.network.word_embedding.w.asnumpy()) 553 [[1. 1. 1. 1.] 554 [1. 1. 1. 1.] 555 [1. 1. 1. 1.] 556 [1. 1. 1. 1.]] 557 """ 558 559 if not isinstance(net, ms.nn.Cell): 560 ms.log.critical("Failed to synchronize pipeline shared parameters.") 561 msg = ("For 'sync_pipeline_shared_parameters', the argument 'net' should be a Cell, " 562 "but got {}.".format(type(net))) 563 raise TypeError(msg) 564 565 layout_dict = net.parameter_layout_dict 566 if _is_in_auto_parallel_mode() and not layout_dict: 567 from mindspore.common.api import _get_parameter_layout 568 layout_dict = _get_parameter_layout() 569 570 # switch to standalone mode 571 parallel_mode = ms.context.get_auto_parallel_context("parallel_mode") 572 full_batch = ms.context.get_auto_parallel_context("full_batch") 573 ms.context.set_auto_parallel_context(parallel_mode="stand_alone", full_batch=False) 574 575 # synchronize shared parameter 576 for name, param in net.parameters_and_names(): 577 if name in layout_dict: 578 _sync_params(name, param, layout_dict[name]) 579 580 # restore parallel context 581 ms.context.set_auto_parallel_context(parallel_mode=parallel_mode, full_batch=full_batch) 582 583 584def load_segmented_checkpoints(ckpt_file_dir, net=None, strict_load=False, filter_prefix=None, 585 dec_key=None, dec_mode="AES-GCM", specify_prefix=None, choice_func=None): 586 """ 587 Load checkpoint info from a specified file. If the specified ckpt_file_dir path contains multiple 588 checkpoint files, all checkpoint files will be loaded one by one and the combined dictionary will be return. 589 590 Note: 591 - `specify_prefix` and `filter_prefix` do not affect each other. 592 - If none of the parameters are loaded from checkpoint file, it will throw ValueError. 593 - `specify_prefix` and `filter_prefix` are in the process of being deprecated, 594 `choice_func` is recommended instead. 595 And using either of those two args will override `choice_func` at the same time. 596 597 Args: 598 ckpt_file_dir (str): Checkpoint file directory. 599 net (Cell): The network where the parameters will be loaded. Default: ``None`` . 600 strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter 601 into net when parameter name's suffix in checkpoint file is the same as the 602 parameter in the network. When the types are inconsistent perform type conversion 603 on the parameters of the same type, such as float32 to float16. Default: ``False`` . 604 filter_prefix (Union[str, list[str], tuple[str]]): Deprecated(see `choice_func`). Parameters starting with the 605 filter_prefix will not be loaded. Default: ``None`` . 606 dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption 607 is not required. Default: ``None`` . 608 dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption 609 mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` and ``"SM4-CBC"`` . 610 Default: ``"AES-GCM"`` . 611 specify_prefix (Union[str, list[str], tuple[str]]): Deprecated(see `choice_func`). Parameters starting with the 612 specify_prefix will be loaded. Default: ``None`` . 613 choice_func (Union[None, function]) : Input value of the function is a Parameter name of type string, 614 and the return value is a bool. If returns ``True`` , the Parameter 615 that matches the custom condition will be loaded. If returns ``False`` , the Parameter that 616 matches the custom condition will be removed. Default: ``None`` . 617 618 Returns: 619 Dict, key is parameter name, value is a Parameter or string. When the `append_dict` parameter of 620 :func:`mindspore.save_checkpoint` and the `append_info` parameter of :class:`mindspore.train.CheckpointConfig` 621 are used to save the checkpoint, `append_dict` and `append_info` are dict types, and their value are string, 622 then the return value obtained by loading checkpoint is string, and in other cases the return value is 623 Parameter. 624 625 Raises: 626 TypeError: Input ckpt_file_dir is not a string. 627 ValueError: Checkpoint file directory doesn't exist. Or it's not a directory 628 ValueError: Checkpoint file's format is incorrect. 629 ValueError: Parameter's dict is None after load checkpoint file. 630 TypeError: The type of `specify_prefix` or `filter_prefix` is incorrect. 631 """ 632 if not isinstance(ckpt_file_dir, str): 633 raise TypeError("The ckpt_file_dir should be a str.") 634 if not os.path.isdir(ckpt_file_dir): 635 raise ValueError("The dst_strategy_file: {} doesn't exist. Or it's not a directory". 636 format(ckpt_file_dir)) 637 checkpoint_file_name = os.path.join(ckpt_file_dir, "*.ckpt") 638 rank_ckpts = glob.glob(checkpoint_file_name) 639 parameter_dict = {} 640 for checkpoint_file in rank_ckpts: 641 parameter_dict.update(ms.load_checkpoint(checkpoint_file, net, strict_load, filter_prefix, dec_key, 642 dec_mode, specify_prefix, choice_func)) 643 return parameter_dict 644