# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""parallel serialization"""
from __future__ import absolute_import

import os
import json
import numpy as np
import mindspore as ms
from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_tensor_layout, \
    _get_needed_rank_list_by_layouts, _get_needed_rank_transform_operator_map_by_layouts, \
    _generate_transform_operator_stack, _apply_tensor_transform_operators, _construct_tensor_layout_for_opt_shard, \
    _extract_layout_item


MAX_PATH_LENGTH = 1024


def _convert_to_list(strategy, rank_id=None):
    """Convert ParallelLayouts object to specified list."""
    train_map = {}
    for param_name in strategy.keys():
        try:
            layout = strategy.get(param_name)
            dev_mat = list(layout.dev_matrix[0].dim)
            tensor_map = list(layout.tensor_map[0].dim)
            param_split_shape = list(layout.param_split_shape[0].dim)
            pipeline_stage = 0
            origin_param_name = param_name
            if "-" in param_name:
                pipeline_stage, origin_param_name = param_name.split("-")
                pipeline_stage = int(pipeline_stage)
            if origin_param_name not in train_map:
                train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape, int(layout.field),
                                                int(layout.opt_weight_shard_step), int(layout.opt_weight_shard_size),
                                                [pipeline_stage]]
            else:
                update_pipeline_stage_list = train_map.get(origin_param_name)[6] + [pipeline_stage]
                if rank_id is not None:
                    stage_device_num = np.prod(dev_mat)
                    is_device0_and_pipeline0 = ((rank_id // stage_device_num) == 0) and (pipeline_stage == 0)
                    not_device0_nor_pipeline0 = ((rank_id // stage_device_num) > 0) and (pipeline_stage > 0)
                    if is_device0_and_pipeline0 or not_device0_nor_pipeline0:
                        train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape,
                                                        int(layout.field), int(layout.opt_weight_shard_step),
                                                        int(layout.opt_weight_shard_size), update_pipeline_stage_list]
                    else:
                        train_map.get(origin_param_name)[6] = update_pipeline_stage_list
                else:
                    if np.all(pipeline_stage <= np.array(update_pipeline_stage_list)):
                        train_map[origin_param_name] = [dev_mat, tensor_map, param_split_shape,
                                                        int(layout.field), int(layout.opt_weight_shard_step),
                                                        int(layout.opt_weight_shard_size), update_pipeline_stage_list]
                    else:
                        train_map.get(origin_param_name)[6] = update_pipeline_stage_list
        except BaseException as e:
            raise ValueError(f"{e.__str__()}. Convert layout strategy to list "
                             f"failed, please make sure that strategy matches the node_strategy.proto, you can "
                             f"check whether 'train_strategy_filename' is correct.") from e
    return train_map
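
# A minimal usage sketch, assuming a strategy checkpoint exists at the hypothetical path
# "./src_strategy.ckpt" (the helper name below is illustrative, not part of the module's API):
# flatten a saved strategy into the per-parameter layout lists consumed by the helpers below.
def _sketch_flatten_strategy(strategy_file="./src_strategy.ckpt", rank_id=0):
    """Hypothetical sketch: build a strategy dict and flatten it to layout lists."""
    strategy = _build_searched_strategy(strategy_file)  # {param_name: ParallelLayouts}
    layout_map = _convert_to_list(strategy, rank_id)
    # Each value: [dev_matrix, tensor_map, param_split_shape, field,
    #              opt_weight_shard_step, opt_weight_shard_size, [pipeline_stage, ...]]
    return layout_map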


def _convert_to_layout(param_name, tensor_layout):
    """Convert list to ParallelLayouts object."""
    strategy = {}
    try:
        layout = ms.train.node_strategy_pb2.ParallelLayouts()
        layout.field = tensor_layout[3]

        dev_matrix = layout.dev_matrix.add()
        for item in tensor_layout[0]:
            dev_matrix.dim.append(item)

        tensor_map = layout.tensor_map.add()
        for item in tensor_layout[1]:
            tensor_map.dim.append(item)

        param_split_shape = layout.param_split_shape.add()
        for item in tensor_layout[2]:
            param_split_shape.dim.append(item)
    except BaseException as e:
        raise ValueError(f"{e.__str__()}. For 'load_distributed_checkpoint', convert list to layout strategy failed, "
                         f"you can check whether your input list is correct.") from e

    strategy[param_name] = layout
    return strategy
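
# A minimal round-trip sketch with made-up layout values: a flattened list of
# [dev_matrix, tensor_map, param_split_shape, field] can be packed back into a
# ParallelLayouts protobuf entry for a single parameter.
def _sketch_pack_layout():
    """Hypothetical sketch: rebuild a ParallelLayouts entry from plain lists."""
    tensor_layout = [[4, 2], [1, 0], [], 0]  # dev_matrix, tensor_map, param_split_shape, field
    return _convert_to_layout("network.embedding.weight", tensor_layout)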


def _check_strategy_file(strategy_filename):
    """Check the parallel strategy file."""
    if not isinstance(strategy_filename, str):
        raise TypeError(f"For 'build_searched_strategy', the argument 'strategy_filename' should be string, "
                        f"but got {type(strategy_filename)}.")

    if not os.path.isfile(strategy_filename):
        raise ValueError(f"For 'build_searched_strategy', no such strategy file: {strategy_filename}. "
                         f"Please check whether the 'strategy_filename' exists.")

    if os.path.getsize(strategy_filename) == 0:
        raise ValueError(f"For 'build_searched_strategy', the strategy file {strategy_filename} should not "
                         f"be empty. Please check whether the 'strategy_filename' is correct.")


def _load_protobuf_strategy(strategy_filename):
    """load strategy from protobuf file"""
    parallel_strategy_map = ms.train.node_strategy_pb2.ParallelStrategyMap()
    with open(strategy_filename, 'rb') as f:
        pb_content = f.read()
    try:
        parallel_strategy_map.ParseFromString(pb_content)
    except BaseException as e:
        raise TypeError("The strategy file type should be one of json or protobuf. "
                        "When the file name extension is not '.json', "
                        "the file is considered as a protobuf file.") from e
    return parallel_strategy_map


def _build_protobuf_strategy(strategy_filename):
    """build strategy from protobuf file"""
    parallel_strategy_map = _load_protobuf_strategy(strategy_filename)
    layout_items = parallel_strategy_map.parallel_layout_item
    if not layout_items:
        raise ValueError(f"For 'build_searched_strategy', the strategy file {strategy_filename} has no sliced "
                         f"parameter, please check whether the 'strategy_filename' is correct.")

    strategy = {}
    for layout_item in layout_items:
        parameter_name = layout_item.param_name
        layout = layout_item.parallel_layouts
        strategy[parameter_name] = layout
    return strategy


def _build_json_strategy(strategy_filename):
    """build strategy from json file"""
    with open(strategy_filename, 'r') as f:
        json_content = json.load(f)
    layout_items = json_content.get("parallel_layout_item")
    strategy = {}
    for parameter_name, layout_item in layout_items.items():
        layout = ms.train.node_strategy_pb2.ParallelLayouts()
        layout.field = layout_item.get("field")
        layout.opt_weight_shard_size = layout_item.get("opt_weight_shard_size")
        layout.opt_weight_shard_step = layout_item.get("opt_weight_shard_step")
        dev_matrix = layout.dev_matrix.add()
        for item in layout_item.get("dev_matrix"):
            dev_matrix.dim.append(item)
        tensor_map = layout.tensor_map.add()
        for item in layout_item.get("tensor_map"):
            tensor_map.dim.append(item)
        param_split_shape = layout.param_split_shape.add()
        if "param_split_shape" in layout_item:
            for item in layout_item.get("param_split_shape"):
                param_split_shape.dim.append(item)
        indices_offset = layout.indices_offset.add()
        if "indices_offset" in layout_item:
            for item in layout_item.get("indices_offset"):
                indices_offset.dim.append(item)
        strategy[parameter_name] = layout
    return strategy
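
# A minimal sketch of the JSON structure this parser expects (the parameter name and values
# are made-up; "param_split_shape" and "indices_offset" are optional, as the checks above show):
#
#     {
#         "parallel_layout_item": {
#             "network.embedding.weight": {
#                 "field": 0,
#                 "opt_weight_shard_step": 0,
#                 "opt_weight_shard_size": 0,
#                 "dev_matrix": [4, 2],
#                 "tensor_map": [1, 0],
#                 "param_split_shape": []
#             }
#         }
#     }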


def _build_searched_strategy(strategy_filename):
    """build searched strategy"""
    _check_strategy_file(strategy_filename)
    if strategy_filename[-5:] != ".json":
        return _build_protobuf_strategy(strategy_filename)
    return _build_json_strategy(strategy_filename)


def _merge_protobuf_strategy(src_strategy_files, dst_strategy_file):
    """merge protobuf strategy"""
    dst_parallel_strategy_map = ms.train.node_strategy_pb2.ParallelStrategyMap()
    merged_stage = []
    for src_strategy_file in src_strategy_files:
        src_parallel_strategy_map = _load_protobuf_strategy(src_strategy_file)
        strategy_items = src_parallel_strategy_map.parallel_strategy_item
        layout_items = src_parallel_strategy_map.parallel_layout_item
        if not strategy_items or not layout_items:
            raise ValueError("The strategy file {} is empty".format(src_strategy_file))
        pipeline_stage = strategy_items[0].parallel_strategys.stage
        if pipeline_stage in merged_stage:
            continue
        for layout_item in layout_items:
            layout_item.param_name = "-".join([str(pipeline_stage), layout_item.param_name])
        dst_parallel_strategy_map.parallel_strategy_item.extend(strategy_items)
        dst_parallel_strategy_map.parallel_layout_item.extend(layout_items)
        merged_stage.append(pipeline_stage)
    dst_parallel_strategy_map.current_stage = 1
    with open(dst_strategy_file, "wb") as f:
        f.write(dst_parallel_strategy_map.SerializeToString())


def _merge_json_strategy(src_strategy_files, dst_strategy_file):
    """merge json strategy"""
    dst_parallel_strategy_map = {"current_stage": 1, "parallel_strategy_item": {}, "parallel_layout_item": {}}
    merged_stage = []
    for src_strategy_file in src_strategy_files:
        with open(src_strategy_file, 'r') as f:
            json_content = json.load(f)
        layout_items = json_content.get("parallel_layout_item")
        strategy_items = json_content.get("parallel_strategy_item")
        if not strategy_items or not layout_items:
            raise ValueError("The strategy file {} is empty".format(src_strategy_file))
        pipeline_stage = strategy_items.get(list(strategy_items.keys())[0]).get('stage')
        if pipeline_stage in merged_stage:
            continue
        for param_name, layout_item in layout_items.items():
            new_layout_item = {}
            new_param_name = "-".join([str(pipeline_stage), param_name])
            new_layout_item[new_param_name] = layout_item
            dst_parallel_strategy_map.get("parallel_layout_item").update(new_layout_item)
        dst_parallel_strategy_map.get("parallel_strategy_item").update(strategy_items)
        merged_stage.append(pipeline_stage)
    with open(dst_strategy_file, "w") as f:
        json.dump(dst_parallel_strategy_map, f)
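
# A minimal usage sketch (the file paths and helper name are hypothetical): merge the
# per-stage strategy files produced by pipeline-parallel training into one file, choosing the
# protobuf or JSON merger by extension, mirroring the dispatch in _build_searched_strategy.
def _sketch_merge_stage_strategies(src_files=("./strategy_stage0.ckpt", "./strategy_stage1.ckpt"),
                                   dst_file="./strategy_merged.ckpt"):
    """Hypothetical sketch: merge per-pipeline-stage strategy files."""
    if dst_file[-5:] == ".json":
        _merge_json_strategy(src_files, dst_file)
    else:
        _merge_protobuf_strategy(src_files, dst_file)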


def _parameter_not_in_local_stage(param_name, origin_strategy_list, strategy_list):
    """Check whether the parameter is absent from the local stage's strategy."""
    if origin_strategy_list is None or strategy_list is None:
        return True
    return param_name in origin_strategy_list and param_name not in strategy_list


def _extract_layout_map(strategy_file, rank_id=None):
    """Extract layout map"""
    layout_map = None
    if strategy_file is not None:
        src_strategy = _build_searched_strategy(strategy_file)
        layout_map = _convert_to_list(src_strategy, rank_id)
    return layout_map


def _extract_pipeline_stage_num(strategy_file):
    """extract pipeline stage num"""
    pipeline_stage_num = 1
    if strategy_file is not None:
        src_strategy = _build_searched_strategy(strategy_file)
        layout_map = _convert_to_list(src_strategy)
        pipeline_stage_set = set()
        for _, layout in layout_map.items():
            pipeline_stage_set.update(layout[6])
        pipeline_stage_num = len(pipeline_stage_set)
        if list(pipeline_stage_set) != list(range(pipeline_stage_num)):
            raise ValueError("The strategy file for pipeline parallel does not contain all stages.")
    return pipeline_stage_num


def _extract_src_dst_layout_map_by_src(src_strategy_file=None, dst_strategy_file=None):
    """Extract strategy list by src strategy"""
    src_layout_map = _extract_layout_map(src_strategy_file)
    dst_layout_map = _extract_layout_map(dst_strategy_file)
    if dst_layout_map is None:
        return src_layout_map, dst_layout_map
    for param_name in list(dst_layout_map.keys()):
        if param_name in src_layout_map.keys():
            continue
        dst_layout_map.pop(param_name)
    stage_id = 0
    if src_strategy_file[-5:] == ".json":
        with open(src_strategy_file, 'r') as f:
            json_content = json.load(f)
        strategy_items = json_content.get("parallel_strategy_item")
        if not strategy_items:
            raise ValueError("The strategy file {} is empty.".format(src_strategy_file))
        stage_id = strategy_items.get(list(strategy_items.keys())[0]).get('stage')
    else:
        src_parallel_strategy_map = _load_protobuf_strategy(src_strategy_file)
        strategy_items = src_parallel_strategy_map.parallel_strategy_item
        if not strategy_items:
            raise ValueError("The strategy file {} is empty.".format(src_strategy_file))
        stage_id = strategy_items[0].parallel_strategys.stage
    return src_layout_map, dst_layout_map, stage_id


def _extract_src_dst_layout_map(rank_id, src_strategy_file=None, dst_strategy_file=None):
    """Extract strategy list"""
    src_layout_map = _extract_layout_map(src_strategy_file, None)
    dst_layout_map = _extract_layout_map(dst_strategy_file, rank_id)
    if dst_layout_map is None:
        return src_layout_map, dst_layout_map
    dst_stage_device_num = np.prod(dst_layout_map.get(list(dst_layout_map.keys())[0])[0])
    dst_stage_id = rank_id // dst_stage_device_num
    # Prune the source and destination layouts, keeping only the parameters in the destination stage.
    for param_name in list(dst_layout_map.keys()):
        if dst_stage_id in dst_layout_map.get(param_name)[6]:
            continue
        dst_layout_map.pop(param_name)
        if src_layout_map is not None and param_name in src_layout_map:
            src_layout_map.pop(param_name)
    return src_layout_map, dst_layout_map
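
# A minimal usage sketch (the strategy file paths are hypothetical): for a given rank, obtain
# the source and destination layout maps already pruned to the parameters belonging to that
# rank's destination pipeline stage.
def _sketch_layout_maps_for_rank(rank_id=0):
    """Hypothetical sketch: extract pruned src/dst layout maps for one rank."""
    return _extract_src_dst_layout_map(rank_id,
                                       src_strategy_file="./src_strategy.ckpt",
                                       dst_strategy_file="./dst_strategy.ckpt")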


def _restore_group_info_list(group_info_file_name):
    """restore group info"""
    parallel_group_map = ms.train.node_strategy_pb2.ParallelGroupMap()

    with open(group_info_file_name, 'rb') as f:
        pb_content = f.read()
    parallel_group_map.ParseFromString(pb_content)

    restore_list = parallel_group_map.ckpt_restore_rank_list
    if not restore_list:
        raise ValueError("For 'restore_group_info_list', the group information file has no restore rank list.")

    return [rank for rank in restore_list.dim]


def _get_device_num_from_strategy(strategy_file=None):
    """Get device num from strategy file"""
    if strategy_file is None:
        return 1
    src_strategy = _build_searched_strategy(strategy_file)
    strategy_list = _convert_to_list(src_strategy)
    device_mat = list(strategy_list.values())[0][0]
    return np.prod(device_mat)
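
# A minimal sketch (the path and helper name are hypothetical, and it assumes every pipeline
# stage spans the same number of devices): the per-stage device count from the device matrix
# times the number of pipeline stages gives the total device count of the training job.
def _sketch_total_device_num(strategy_file="./src_strategy.ckpt"):
    """Hypothetical sketch: estimate the total device number recorded by a strategy file."""
    return _get_device_num_from_strategy(strategy_file) * _extract_pipeline_stage_num(strategy_file)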


def _rank_list_for_transform_parallel_checkpoint(rank_id, src_strategy_list, dst_strategy_list):
    """
    Get the rank list needed to transform the model parallel dimension of the checkpoint.
    """
    result_list = set()
    handled_layout = []
    for param_name, _ in src_strategy_list.items():
        if dst_strategy_list is not None and param_name not in dst_strategy_list:
            continue
        from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size = _extract_layout_item(
            src_strategy_list.get(param_name))
        from_device_num = np.prod(from_dev_matrix)
        fake_tensor_shape = [8] * len(from_tensor_map)
        to_dev_matrix = [1]
        to_tensor_map = [-1] * len(fake_tensor_shape)
        to_opt_shard_step = 0
        to_opt_shard_size = 0
        if dst_strategy_list is not None:
            to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
                dst_strategy_list.get(param_name))
        handled_key = (from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size,
                       to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size)
        if handled_key in handled_layout:
            continue
        handled_layout.append(handled_key)
        param_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
        origin_tensor_shape = ()
        for i, item in enumerate(fake_tensor_shape):
            if i == 0 and from_opt_shard_size > 0:
                origin_tensor_shape += (item * param_strategy[i] * from_opt_shard_size,)
                continue
            origin_tensor_shape += (item * param_strategy[i],)

        from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
            from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
        to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
            to_dev_matrix, to_tensor_map, to_opt_shard_step, to_opt_shard_size, origin_tensor_shape)
        # Convert tensor layout to same device num
        from_tensor_layout, to_tensor_layout = _construct_from_to_tensor_layout(from_full_tensor_shape, from_dev_matrix,
                                                                                from_tensor_map, to_full_tensor_shape,
                                                                                to_dev_matrix, to_tensor_map)
        device_list = list(range(0, np.prod(from_tensor_layout[0])))
        param_rank_list = _get_needed_rank_list_by_layouts(from_tensor_layout, to_tensor_layout, device_list, rank_id)
        param_rank_list_new = [rank % from_device_num for rank in param_rank_list]
        param_rank_set_new = set(param_rank_list_new)
        result_list.update(param_rank_set_new)
    return list(result_list)
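
# A minimal usage sketch (the strategy file paths are hypothetical): before transforming the
# checkpoint for one rank, determine which source ranks' checkpoint files have to be loaded.
def _sketch_needed_source_ranks(rank_id=0):
    """Hypothetical sketch: list the source ranks required for one destination rank."""
    src_layout = _extract_layout_map("./src_strategy.ckpt")
    dst_layout = _extract_layout_map("./dst_strategy.ckpt", rank_id)
    return _rank_list_for_transform_parallel_checkpoint(rank_id, src_layout, dst_layout)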


def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, src_strategy_list,
                                   dst_strategy_list, param_type_dict):
    """
    Transform model parallel dimension for distributed checkpoint files.
    """
    transform_param_dict = {}
    device_num = -1
    for param_name, _ in param_total_dict.items():
        tensor_shape = list(param_total_dict[param_name].values())[0].shape
        from_dev_matrix = [1]
        from_tensor_map = [-1] * len(tensor_shape)
        from_opt_shard_step = 0
        from_opt_shard_size = 0
        if src_strategy_list is not None:
            if param_name not in src_strategy_list:
                ms.log.warning("The parameter {} is not in src_strategy.".format(param_name))
                continue
            from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size = _extract_layout_item(
                src_strategy_list.get(param_name))
        to_dev_matrix_origin = [1]
        to_tensor_map_origin = [-1] * len(tensor_shape)
        to_opt_shard_step = 0
        to_opt_shard_size = 0
        if dst_strategy_list is not None:
            if param_name not in dst_strategy_list:
                ms.log.warning("The parameter {} is not in dst_strategy.".format(param_name))
                continue
            to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
                dst_strategy_list.get(param_name))
        # Add optimizer sharding dim for tensor layout
        device_num = np.prod(from_dev_matrix)
        param_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
        origin_tensor_shape = ()
        for i, item in enumerate(tensor_shape):
            if i == 0 and from_opt_shard_size > 0:
                origin_tensor_shape += (item * param_strategy[i] * from_opt_shard_size,)
                continue
            origin_tensor_shape += (item * param_strategy[i],)

        from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
            from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
        to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
            to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, origin_tensor_shape)
        # Convert tensor layout to same device num
        from_tensor_layout, to_tensor_layout = _construct_from_to_tensor_layout(from_full_tensor_shape, from_dev_matrix,
                                                                                from_tensor_map, to_full_tensor_shape,
                                                                                to_dev_matrix, to_tensor_map)

        # When the from_layout uses fewer devices, the checkpoint map for map[device_num] should use map[0].
        device_list = list(range(0, np.prod(from_tensor_layout[0])))
        if rank_id % device_num not in param_attr_dict[param_name]:
            raise ValueError("The checkpoint of rank {} is missing.".format(rank_id % device_num))
        param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout,
                                                                            device_list, rank_id)

        from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
        to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
        _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
        transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id)
        param_total_dict_copy = param_total_dict[param_name].copy()
        _apply_tensor_transform_operators(transform_operator_stack, param_total_dict_copy, device_num)
        transform_tensor = ms.Tensor(param_total_dict_copy[rank_id % device_num])
        requires_grad = param_attr_dict[param_name][rank_id % device_num][0]
        layerwise_parallel = param_attr_dict[param_name][rank_id % device_num][1]
        transform_para = ms.Parameter(transform_tensor, param_name, requires_grad, layerwise_parallel)
        if param_type_dict[param_name][rank_id % device_num] == "BFloat16":
            transform_para.set_dtype(ms.bfloat16)
        transform_param_dict[param_name] = transform_para
    if device_num < 1:
        raise ValueError("None of the parameters in checkpoint file are in either src strategy or "
                         "dst strategy. Please check correctness of strategy files.")

    # Handle parameters such as learning_rate and global_step that are not in the strategy file.
    for param_name, _ in param_total_dict.items():
        if param_name not in transform_param_dict:
            transform_para = ms.Parameter(
                ms.Tensor(param_total_dict[param_name][rank_id % device_num]), param_name,
                param_attr_dict[param_name][rank_id % device_num][0],
                param_attr_dict[param_name][rank_id % device_num][1])
            if param_type_dict[param_name][rank_id % device_num] == "BFloat16":
                transform_para.set_dtype(ms.bfloat16)
            transform_param_dict[param_name] = transform_para

    transform_param_list = [{"name": param_name, "data": param_data}
                            for param_name, param_data in transform_param_dict.items()]
    return transform_param_list
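
# A minimal input-layout sketch (parameter name, shapes and attributes are made-up; both
# strategies are left as None here, which should make the transform essentially a pass-through):
# param_total_dict[name][rank] holds a rank's numpy slice, param_attr_dict[name][rank] holds
# (requires_grad, layerwise_parallel), and param_type_dict[name][rank] holds the dtype string.
def _sketch_transform_single_param(rank_id=0):
    """Hypothetical sketch: transform one parameter with trivial src/dst strategies."""
    param_total_dict = {"network.embedding.weight": {0: np.zeros((8, 8), dtype=np.float32)}}
    param_attr_dict = {"network.embedding.weight": {0: (True, False)}}
    param_type_dict = {"network.embedding.weight": {0: "Float32"}}
    return _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict,
                                          None, None, param_type_dict)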


def _make_dir(path, arg_name):
    """Make directory."""
    if not isinstance(path, str):
        ms.log.critical("The %s is invalid, the type should be string.", arg_name)
        raise TypeError("The {} is invalid, the type should be string.".format(arg_name))
    if path.strip() == "":
        ms.log.critical("The %s is invalid, it should be non-blank.", arg_name)
        raise ValueError("The {} is invalid, it should be non-blank.".format(arg_name))

    path = os.path.realpath(path)

    if len(path) > MAX_PATH_LENGTH:
        ms.log.critical("The %s length is too long, it should be limited to %s.", arg_name, MAX_PATH_LENGTH)
        raise ValueError("The {} length is too long, it should be limited to {}.".format(arg_name, MAX_PATH_LENGTH))

    ms.log.debug("The abs path is %r", path)

    if os.path.exists(path):
        if not os.path.isdir(path):
            ms.log.critical("The path(%r) is a file path, it should be a directory path.", path)
            raise NotADirectoryError("The path({}) is a file path, it should be a directory path.".format(path))
        real_path = path
    else:
        ms.log.debug("The directory(%s) doesn't exist, will create it", path)
        try:
            permissions = os.R_OK | os.W_OK | os.X_OK
            os.umask(permissions << 3 | permissions)
            mode = permissions << 6
            os.makedirs(path, mode=mode, exist_ok=True)
            real_path = path
        except PermissionError as e:
            ms.log.critical("No write permission on the directory(%r), error = %r", path, e)
            raise TypeError("No write permission on the directory.") from e
    return real_path


def _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple):
    """insert opt_shard op reshape"""
    from_opt_shard_size = from_info_tuple[0]
    from_dev_matrix = from_info_tuple[1]
    from_tensor_map = from_info_tuple[2]
    from_full_tensor_shape = from_info_tuple[3]
    to_opt_shard_size = to_info_tuple[0]
    to_dev_matrix_origin = to_info_tuple[1]
    to_tensor_map_origin = to_info_tuple[2]
    origin_tensor_shape = to_info_tuple[3]
    for param_rank, _ in param_rank_map.items():
        if from_opt_shard_size > 0:
            from_tensor_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
            from_slice_tensor_shape = ()
            for i, item in enumerate(from_full_tensor_shape):
                from_slice_tensor_shape += (item // from_tensor_strategy[i],)
            param_rank_map.get(param_rank).insert(0, ('Reshape', list(from_slice_tensor_shape)))
        if to_opt_shard_size > 0:
            to_tensor_strategy = _get_tensor_strategy(to_dev_matrix_origin, to_tensor_map_origin)
            to_slice_tensor_shape = ()
            for i, item in enumerate(origin_tensor_shape):
                if i == 0 and to_opt_shard_size > 0:
                    to_slice_tensor_shape += (item // (to_tensor_strategy[i] * to_opt_shard_size),)
                    continue
                to_slice_tensor_shape += (item // to_tensor_strategy[i],)
            param_rank_map.get(param_rank).append(('Reshape', list(to_slice_tensor_shape)))