# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""parallel serialization"""
from __future__ import absolute_import

import os
import json
import numpy as np
import mindspore as ms
from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_tensor_layout, \
    _get_needed_rank_list_by_layouts, _get_needed_rank_transform_operator_map_by_layouts, \
    _generate_transform_operator_stack, _apply_tensor_transform_operators, _construct_tensor_layout_for_opt_shard, \
    _extract_layout_item


MAX_PATH_LENGTH = 1024


def _convert_to_list(strategy, rank_id=None):
    """Convert ParallelLayouts object to specified list."""
    train_map = {}
    for param_name in strategy.keys():
        try:
            layout = strategy.get(param_name)
            dev_mat = list(layout.dev_matrix[0].dim)
            tensor_map = list(layout.tensor_map[0].dim)
            param_split_shape = list(layout.param_split_shape[0].dim)
            pipeline_stage = 0
            origin_param_name = param_name
            if "-" in param_name:
                pipeline_stage, origin_param_name = param_name.split("-")
                pipeline_stage = int(pipeline_stage)
            if origin_param_name not in train_map:
                train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape, int(layout.field),
                                                int(layout.opt_weight_shard_step), int(layout.opt_weight_shard_size),
                                                [pipeline_stage]]
            else:
                update_pipeline_stage_list = train_map.get(origin_param_name)[6] + [pipeline_stage]
                if rank_id is not None:
                    stage_device_num = np.prod(dev_mat)
                    is_device0_and_pipeline0 = ((rank_id // stage_device_num) == 0) and (pipeline_stage == 0)
                    not_device0_nor_pipeline0 = ((rank_id // stage_device_num) > 0) and (pipeline_stage > 0)
                    if is_device0_and_pipeline0 or not_device0_nor_pipeline0:
                        train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape,
                                                        int(layout.field), int(layout.opt_weight_shard_step),
                                                        int(layout.opt_weight_shard_size), update_pipeline_stage_list]
                    else:
                        train_map.get(origin_param_name)[6] = update_pipeline_stage_list
                else:
                    if np.all(pipeline_stage <= np.array(update_pipeline_stage_list)):
                        train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape,
                                                        int(layout.field), int(layout.opt_weight_shard_step),
                                                        int(layout.opt_weight_shard_size), update_pipeline_stage_list]
                    else:
                        train_map.get(origin_param_name)[6] = update_pipeline_stage_list
        except BaseException as e:
            raise ValueError(f"{e.__str__()}. Convert layout strategy to list "
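# Each value in the train_map returned by _convert_to_list above is a list laid out as:
#   [0] dev_matrix            - device matrix of the parameter layout
#   [1] tensor_map            - tensor map of the parameter layout
#   [2] param_split_shape     - parameter split shape
#   [3] field                 - field size
#   [4] opt_weight_shard_step - optimizer weight shard step
#   [5] opt_weight_shard_size - optimizer weight shard size
#   [6] pipeline_stage_list   - pipeline stages that hold this parameter

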
                             f"failed, please make sure that strategy matches the node_strategy.proto, you can "
                             f"check whether 'train_strategy_filename' is correct.") from e
    return train_map


def _convert_to_layout(param_name, tensor_layout):
    """Convert list to ParallelLayouts object."""
    strategy = {}
    try:
        layout = ms.train.node_strategy_pb2.ParallelLayouts()
        layout.field = tensor_layout[3]

        dev_matrix = layout.dev_matrix.add()
        for item in tensor_layout[0]:
            dev_matrix.dim.append(item)

        tensor_map = layout.tensor_map.add()
        for item in tensor_layout[1]:
            tensor_map.dim.append(item)

        param_split_shape = layout.param_split_shape.add()
        for item in tensor_layout[2]:
            param_split_shape.dim.append(item)
    except BaseException as e:
        raise ValueError(f"{e.__str__()}. For 'load_distributed_checkpoint', convert list to layout strategy failed, "
                         f"you can check whether your input list is correct.") from e

    strategy[param_name] = layout
    return strategy


def _check_strategy_file(strategy_filename):
    """check parallel strategy file"""
    if not isinstance(strategy_filename, str):
        raise TypeError(f"For 'build_searched_strategy', the argument 'strategy_filename' should be string, "
                        f"but got {type(strategy_filename)}.")

    if not os.path.isfile(strategy_filename):
        raise ValueError(f"For 'build_searched_strategy', no such strategy file: {strategy_filename}. "
                         f"Please check whether the 'strategy_filename' exists.")

    if os.path.getsize(strategy_filename) == 0:
        raise ValueError(f"For 'build_searched_strategy', the strategy file {strategy_filename} should not "
                         f"be empty. Please check whether the 'strategy_filename' is correct.")


def _load_protobuf_strategy(strategy_filename):
    """load strategy from protobuf file"""
    parallel_strategy_map = ms.train.node_strategy_pb2.ParallelStrategyMap()
    with open(strategy_filename, 'rb') as f:
        pb_content = f.read()
    try:
        parallel_strategy_map.ParseFromString(pb_content)
    except BaseException as e:
        raise TypeError("The strategy file type should be one of json or protobuf. "
" 125 "When the file name extension is not '.json', " 126 "the file is considered as a protobuf file.") from e 127 return parallel_strategy_map 128 129 130def _build_protobuf_strategy(strategy_filename): 131 """build strategy from protobuf file""" 132 parallel_strategy_map = _load_protobuf_strategy(strategy_filename) 133 layout_items = parallel_strategy_map.parallel_layout_item 134 if not layout_items: 135 raise ValueError(f"For 'build_searched_strategy', the strategy file {strategy_filename} has no sliced " 136 f"parameter, please check whether the 'strategy_filename' is correct.") 137 138 strategy = {} 139 for layout_item in layout_items: 140 parameter_name = layout_item.param_name 141 layout = layout_item.parallel_layouts 142 strategy[parameter_name] = layout 143 return strategy 144 145 146def _build_json_strategy(strategy_filename): 147 """build strategy from json file""" 148 with open(strategy_filename, 'r') as f: 149 json_content = json.load(f) 150 layout_items = json_content.get("parallel_layout_item") 151 strategy = {} 152 for parameter_name, layout_item in layout_items.items(): 153 layout = ms.train.node_strategy_pb2.ParallelLayouts() 154 layout.field = layout_item.get("field") 155 layout.opt_weight_shard_size = layout_item.get("opt_weight_shard_size") 156 layout.opt_weight_shard_step = layout_item.get("opt_weight_shard_step") 157 dev_matrix = layout.dev_matrix.add() 158 for item in layout_item.get("dev_matrix"): 159 dev_matrix.dim.append(item) 160 tensor_map = layout.tensor_map.add() 161 for item in layout_item.get("tensor_map"): 162 tensor_map.dim.append(item) 163 param_split_shape = layout.param_split_shape.add() 164 if "param_split_shape" in layout_item: 165 for item in layout_item.get("param_split_shape"): 166 param_split_shape.dim.append(item) 167 indices_offset = layout.indices_offset.add() 168 if "indices_offset" in layout_item: 169 for item in layout_item.get("indices_offset"): 170 indices_offset.dim.append(item) 171 strategy[parameter_name] = layout 172 return strategy 173 174 175def _build_searched_strategy(strategy_filename): 176 """build searched strategy""" 177 _check_strategy_file(strategy_filename) 178 if strategy_filename[-5:] != ".json": 179 return _build_protobuf_strategy(strategy_filename) 180 return _build_json_strategy(strategy_filename) 181 182 183def _merge_protobuf_strategy(src_strategy_files, dst_strategy_file): 184 """merge protobuf strategy""" 185 dst_parallel_strategy_map = ms.train.node_strategy_pb2.ParallelStrategyMap() 186 merged_stage = [] 187 for src_strategy_file in src_strategy_files: 188 src_parallel_strategy_map = _load_protobuf_strategy(src_strategy_file) 189 strategy_items = src_parallel_strategy_map.parallel_strategy_item 190 layout_items = src_parallel_strategy_map.parallel_layout_item 191 if not strategy_items or not layout_items: 192 raise ValueError("The strategy file {} is empty".format(src_strategy_file)) 193 pipeline_stage = strategy_items[0].parallel_strategys.stage 194 if pipeline_stage in merged_stage: 195 continue 196 for layout_item in layout_items: 197 layout_item.param_name = "-".join([str(pipeline_stage), layout_item.param_name]) 198 dst_parallel_strategy_map.parallel_strategy_item.extend(strategy_items) 199 dst_parallel_strategy_map.parallel_layout_item.extend(layout_items) 200 merged_stage.append(pipeline_stage) 201 dst_parallel_strategy_map.current_stage = 1 202 with open(dst_strategy_file, "wb") as f: 203 f.write(dst_parallel_strategy_map.SerializeToString()) 204 205 206def _merge_json_strategy(src_strategy_files, 
def _merge_protobuf_strategy(src_strategy_files, dst_strategy_file):
    """merge protobuf strategy"""
    dst_parallel_strategy_map = ms.train.node_strategy_pb2.ParallelStrategyMap()
    merged_stage = []
    for src_strategy_file in src_strategy_files:
        src_parallel_strategy_map = _load_protobuf_strategy(src_strategy_file)
        strategy_items = src_parallel_strategy_map.parallel_strategy_item
        layout_items = src_parallel_strategy_map.parallel_layout_item
        if not strategy_items or not layout_items:
            raise ValueError("The strategy file {} is empty".format(src_strategy_file))
        pipeline_stage = strategy_items[0].parallel_strategys.stage
        if pipeline_stage in merged_stage:
            continue
        for layout_item in layout_items:
            layout_item.param_name = "-".join([str(pipeline_stage), layout_item.param_name])
        dst_parallel_strategy_map.parallel_strategy_item.extend(strategy_items)
        dst_parallel_strategy_map.parallel_layout_item.extend(layout_items)
        merged_stage.append(pipeline_stage)
    dst_parallel_strategy_map.current_stage = 1
    with open(dst_strategy_file, "wb") as f:
        f.write(dst_parallel_strategy_map.SerializeToString())


def _merge_json_strategy(src_strategy_files, dst_strategy_file):
    """merge json strategy"""
    dst_parallel_strategy_map = {"current_stage": 1, "parallel_strategy_item": {}, "parallel_layout_item": {}}
    merged_stage = []
    for src_strategy_file in src_strategy_files:
        with open(src_strategy_file, 'r') as f:
            json_content = json.load(f)
        layout_items = json_content.get("parallel_layout_item")
        strategy_items = json_content.get("parallel_strategy_item")
        if not strategy_items or not layout_items:
            raise ValueError("The strategy file {} is empty".format(src_strategy_file))
        pipeline_stage = strategy_items.get(list(strategy_items.keys())[0]).get('stage')
        if pipeline_stage in merged_stage:
            continue
        for param_name, layout_item in layout_items.items():
            new_layout_item = {}
            new_param_name = "-".join([str(pipeline_stage), param_name])
            new_layout_item[new_param_name] = layout_item
            dst_parallel_strategy_map.get("parallel_layout_item").update(new_layout_item)
        dst_parallel_strategy_map.get("parallel_strategy_item").update(strategy_items)
        merged_stage.append(pipeline_stage)
    with open(dst_strategy_file, "w") as f:
        json.dump(dst_parallel_strategy_map, f)


def _parameter_not_in_local_stage(param_name, origin_strategy_list, strategy_list):
    """check whether the parameter is not in the local stage"""
    if origin_strategy_list is None or strategy_list is None:
        return True
    return param_name in origin_strategy_list and param_name not in strategy_list


def _extract_layout_map(strategy_file, rank_id=None):
    """Extract layout map"""
    layout_map = None
    if strategy_file is not None:
        src_strategy = _build_searched_strategy(strategy_file)
        layout_map = _convert_to_list(src_strategy, rank_id)
    return layout_map


def _extract_pipeline_stage_num(strategy_file):
    """extract pipeline stage num"""
    pipeline_stage_num = 1
    if strategy_file is not None:
        src_strategy = _build_searched_strategy(strategy_file)
        layout_map = _convert_to_list(src_strategy)
        pipeline_stage_set = set()
        for _, layout in layout_map.items():
            pipeline_stage_set.update(layout[6])
        pipeline_stage_num = len(pipeline_stage_set)
        if list(pipeline_stage_set) != list(range(pipeline_stage_num)):
            raise ValueError("The strategy file for pipeline parallel does not contain all stages.")
    return pipeline_stage_num


def _extract_src_dst_layout_map_by_src(src_strategy_file=None, dst_strategy_file=None):
    """Extract strategy list by src strategy"""
    src_layout_map = _extract_layout_map(src_strategy_file)
    dst_layout_map = _extract_layout_map(dst_strategy_file)
    if dst_layout_map is None:
        return src_layout_map, dst_layout_map
    for param_name in list(dst_layout_map.keys()):
        if param_name in src_layout_map.keys():
            continue
        dst_layout_map.pop(param_name)
    stage_id = 0
    if src_strategy_file[-5:] == ".json":
        with open(src_strategy_file, 'r') as f:
            json_content = json.load(f)
        strategy_items = json_content.get("parallel_strategy_item")
        if not strategy_items:
            raise ValueError("The strategy file {} is empty.".format(src_strategy_file))
        stage_id = strategy_items.get(list(strategy_items.keys())[0]).get('stage')
    else:
        src_parallel_strategy_map = _load_protobuf_strategy(src_strategy_file)
        strategy_items = src_parallel_strategy_map.parallel_strategy_item
        if not strategy_items:
            raise ValueError("The strategy file {} is empty.".format(src_strategy_file))
        stage_id = strategy_items[0].parallel_strategys.stage
    return src_layout_map, dst_layout_map, stage_id


def _extract_src_dst_layout_map(rank_id, src_strategy_file=None, dst_strategy_file=None):
    """Extract strategy list"""
    src_layout_map = _extract_layout_map(src_strategy_file, None)
    dst_layout_map = _extract_layout_map(dst_strategy_file, rank_id)
    if dst_layout_map is None:
        return src_layout_map, dst_layout_map
    dst_stage_device_num = np.prod(dst_layout_map.get(list(dst_layout_map.keys())[0])[0])
    dst_stage_id = rank_id // dst_stage_device_num
    # cut the source and destination layouts, keeping only the parameters that belong to the dst_stage
    for param_name in list(dst_layout_map.keys()):
        if dst_stage_id in dst_layout_map.get(param_name)[6]:
            continue
        dst_layout_map.pop(param_name)
        if src_layout_map is not None and param_name in src_layout_map:
            src_layout_map.pop(param_name)
    return src_layout_map, dst_layout_map


def _restore_group_info_list(group_info_file_name):
    """restore group info"""
    parallel_group_map = ms.train.node_strategy_pb2.ParallelGroupMap()

    with open(group_info_file_name, 'rb') as f:
        pb_content = f.read()
    parallel_group_map.ParseFromString(pb_content)

    restore_list = parallel_group_map.ckpt_restore_rank_list
    if not restore_list:
        raise ValueError("For 'restore_group_info_list', the group information file has no restore rank list.")

    return [rank for rank in restore_list.dim]


def _get_device_num_from_strategy(strategy_file=None):
    """Get device num from strategy file"""
    if strategy_file is None:
        return 1
    src_strategy = _build_searched_strategy(strategy_file)
    strategy_list = _convert_to_list(src_strategy)
    device_mat = list(strategy_list.values())[0][0]
    return np.prod(device_mat)


335 """ 336 result_list = set() 337 handled_layout = [] 338 for param_name, _ in src_strategy_list.items(): 339 if dst_strategy_list is not None and param_name not in dst_strategy_list: 340 continue 341 from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size = _extract_layout_item( 342 src_strategy_list.get(param_name)) 343 from_device_num = np.prod(from_dev_matrix) 344 fake_tensor_shape = [8] * len(from_tensor_map) 345 to_dev_matrix = [1] 346 to_tensor_map = [-1] * len(fake_tensor_shape) 347 to_opt_shard_step = 0 348 to_opt_shard_size = 0 349 if dst_strategy_list is not None: 350 to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size = _extract_layout_item( 351 dst_strategy_list.get(param_name)) 352 handled_key = (from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, 353 to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size) 354 if handled_key in handled_layout: 355 continue 356 handled_layout.append(handled_key) 357 param_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map) 358 origin_tensor_shape = () 359 for i, item in enumerate(fake_tensor_shape): 360 if i == 0 and from_opt_shard_size > 0: 361 origin_tensor_shape += (item * param_strategy[i] * from_opt_shard_size,) 362 continue 363 origin_tensor_shape += (item * param_strategy[i],) 364 365 from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard( 366 from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape) 367 to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard( 368 to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size, origin_tensor_shape) 369 # Convert tensor layout to same device num 370 from_tensor_layout, to_tensor_layout = _construct_from_to_tensor_layout(from_full_tensor_shape, from_dev_matrix, 371 from_tensor_map, to_full_tensor_shape, 372 to_dev_matrix, to_tensor_map) 373 device_list = list(range(0, np.prod(from_tensor_layout[0]))) 374 param_rank_list = _get_needed_rank_list_by_layouts(from_tensor_layout, to_tensor_layout, device_list, rank_id) 375 param_rank_list_new = [rank % from_device_num for rank in param_rank_list] 376 param_rank_set_new = set(param_rank_list_new) 377 result_list.update(param_rank_set_new) 378 return list(result_list) 379 380 381def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, src_strategy_list, 382 dst_strategy_list, param_type_dict): 383 """ 384 Transform model parallel dimension for distributed checkpoint files. 
385 """ 386 transform_param_dict = {} 387 device_num = -1 388 for param_name, _ in param_total_dict.items(): 389 tensor_shape = list(param_total_dict[param_name].values())[0].shape 390 from_dev_matrix = [1] 391 from_tensor_map = [-1] * len(tensor_shape) 392 from_opt_shard_step = 0 393 from_opt_shard_size = 0 394 if src_strategy_list is not None: 395 if param_name not in src_strategy_list: 396 ms.log.warning("The parameter {} is not in src_strategy.".format(param_name)) 397 continue 398 from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size = _extract_layout_item( 399 src_strategy_list.get(param_name)) 400 to_dev_matrix_origin = [1] 401 to_tensor_map_origin = [-1] * len(tensor_shape) 402 to_opt_shard_step = 0 403 to_opt_shard_size = 0 404 if dst_strategy_list is not None: 405 if param_name not in dst_strategy_list: 406 ms.log.warning("The parameter {} is not in dst_strategy.".format(param_name)) 407 continue 408 to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size = _extract_layout_item( 409 dst_strategy_list.get(param_name)) 410 # Add optimizer sharding dim for tensor layout 411 device_num = np.prod(from_dev_matrix) 412 param_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map) 413 origin_tensor_shape = () 414 for i, item in enumerate(tensor_shape): 415 if i == 0 and from_opt_shard_size > 0: 416 origin_tensor_shape += (item * param_strategy[i] * from_opt_shard_size,) 417 continue 418 origin_tensor_shape += (item * param_strategy[i],) 419 420 from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard( 421 from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape) 422 to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard( 423 to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, origin_tensor_shape) 424 # Convert tensor layout to same device num 425 from_tensor_layout, to_tensor_layout = _construct_from_to_tensor_layout(from_full_tensor_shape, from_dev_matrix, 426 from_tensor_map, to_full_tensor_shape, 427 to_dev_matrix, to_tensor_map) 428 429 # when the from_layout is less devices, the checkpoint_map for map[device_num] should using map[0] 430 device_list = list(range(0, np.prod(from_tensor_layout[0]))) 431 if rank_id % device_num not in param_attr_dict[param_name]: 432 raise ValueError("The checkpoint of rank {} is missing.".format(rank_id % device_num)) 433 param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout, 434 device_list, rank_id) 435 436 437 from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape) 438 to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape) 439 _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple) 440 transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id) 441 param_total_dict_copy = param_total_dict[param_name].copy() 442 _apply_tensor_transform_operators(transform_operator_stack, param_total_dict_copy, device_num) 443 transform_tensor = ms.Tensor(param_total_dict_copy[rank_id % device_num]) 444 requires_grad = param_attr_dict[param_name][rank_id % device_num][0] 445 layerwise_parallel = param_attr_dict[param_name][rank_id % device_num][1] 446 transform_para = ms.Parameter(transform_tensor, param_name, requires_grad, layerwise_parallel) 447 if param_type_dict[param_name][rank_id % 
device_num] == "BFloat16": 448 transform_para.set_dtype(ms.bfloat16) 449 transform_param_dict[param_name] = transform_para 450 if device_num < 1: 451 raise ValueError("None of the parameters in checkpoint file are in either src strategy or " 452 "dst strategy. Please check correctness of strategy files.") 453 454 # Handle those parameter like learning_rate, global_step which not in strategy_file. 455 for param_name, _ in param_total_dict.items(): 456 if param_name not in transform_param_dict: 457 transform_para = ms.Parameter( 458 ms.Tensor(param_total_dict[param_name][rank_id % device_num]), param_name, 459 param_attr_dict[param_name][rank_id % device_num][0], 460 param_attr_dict[param_name][rank_id % device_num][1]) 461 if param_type_dict[param_name][rank_id % device_num] == "BFloat16": 462 transform_para.set_dtype(ms.bfloat16) 463 transform_param_dict[param_name] = transform_para 464 465 transform_param_list = [{"name": param_name, "data": param_data} 466 for param_name, param_data in transform_param_dict.items()] 467 return transform_param_list 468 469 470def _make_dir(path, arg_name): 471 """Make directory.""" 472 if not isinstance(path, str): 473 ms.log.critical("The %s is invalid, the type should be string.", arg_name) 474 raise TypeError("The {} is invalid, the type should be string.".format(arg_name)) 475 if path.strip() == "": 476 ms.log.critical("The %s is invalid, it should be non-blank.", arg_name) 477 raise ValueError("The {} is invalid, it should be non-blank.".format(arg_name)) 478 479 path = os.path.realpath(path) 480 481 if len(path) > MAX_PATH_LENGTH: 482 ms.log.critical("The %s length is too long, it should be limited in %s.", arg_name, MAX_PATH_LENGTH) 483 raise ValueError("The {} length is too long, it should be limited in {}.".format(arg_name, MAX_PATH_LENGTH)) 484 485 ms.log.debug("The abs path is %r", path) 486 487 if os.path.exists(path): 488 if not os.path.isdir(path): 489 ms.log.critical("The path(%r) is a file path, it should be a directory path.", path) 490 raise NotADirectoryError("The path({}) is a file path, it should be a directory path.".format(path)) 491 real_path = path 492 else: 493 ms.log.debug("The directory(%s) doesn't exist, will create it", path) 494 try: 495 permissions = os.R_OK | os.W_OK | os.X_OK 496 os.umask(permissions << 3 | permissions) 497 mode = permissions << 6 498 os.makedirs(path, mode=mode, exist_ok=True) 499 real_path = path 500 except PermissionError as e: 501 ms.log.critical("No write permission on the directory(%r), error = %r", path, e) 502 raise TypeError("No write permission on the directory.") from e 503 finally: 504 pass 505 return real_path 506 507 508def _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple): 509 """insert opt_shard op reshape""" 510 from_opt_shard_size = from_info_tuple[0] 511 from_dev_matrix = from_info_tuple[1] 512 from_tensor_map = from_info_tuple[2] 513 from_full_tensor_shape = from_info_tuple[3] 514 to_opt_shard_size = to_info_tuple[0] 515 to_dev_matrix_origin = to_info_tuple[1] 516 to_tensor_map_origin = to_info_tuple[2] 517 origin_tensor_shape = to_info_tuple[3] 518 for param_rank, _ in param_rank_map.items(): 519 if from_opt_shard_size > 0: 520 from_tensor_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map) 521 from_slice_tensor_shape = () 522 for i, item in enumerate(from_full_tensor_shape): 523 from_slice_tensor_shape += (item // from_tensor_strategy[i],) 524 param_rank_map.get(param_rank).insert(0, ('Reshape', list(from_slice_tensor_shape))) 525 if 
    for param_rank, _ in param_rank_map.items():
        if from_opt_shard_size > 0:
            from_tensor_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
            from_slice_tensor_shape = ()
            for i, item in enumerate(from_full_tensor_shape):
                from_slice_tensor_shape += (item // from_tensor_strategy[i],)
            param_rank_map.get(param_rank).insert(0, ('Reshape', list(from_slice_tensor_shape)))
        if to_opt_shard_size > 0:
            to_tensor_strategy = _get_tensor_strategy(to_dev_matrix_origin, to_tensor_map_origin)
            to_slice_tensor_shape = ()
            for i, item in enumerate(origin_tensor_shape):
                if i == 0 and to_opt_shard_size > 0:
                    to_slice_tensor_shape += (item // (to_tensor_strategy[i] * to_opt_shard_size),)
                    continue
                to_slice_tensor_shape += (item // to_tensor_strategy[i],)
            param_rank_map.get(param_rank).append(('Reshape', list(to_slice_tensor_shape)))