# Copyright (c) Meta Platforms, Inc. and affiliates

import copy
from typing import TYPE_CHECKING

import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor, ShardMetadata
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.remote_device import _remote_device

from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict
from .utils import _element_wise_add, _normalize_device_info


if TYPE_CHECKING:
    from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata


# TODO: We need to refactor this code.
def _flatten_sharded_tensors(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
    r"""
    Transform ``state_dict`` by flattening all nested ShardedTensor instances found.

    The resulting ShardedTensor instances are only correct with respect to their
    local shard and MUST NOT be used for anything other than checkpointing, as no
    operator will work with them.

    This function should be used in conjunction with a state_dict produced by FSDP's
    StateDictType.SHARDED_STATE_DICT methods.
    """
    new_state_dict: STATE_DICT_TYPE = {}

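    # Visit every leaf of ``state_dict``: a ShardedTensor whose single local
    # shard wraps another ShardedTensor gets flattened; everything else is
    # copied through unchanged.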
    def rewrite_dict(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
        if not isinstance(value, ShardedTensor):
            set_element(new_state_dict, path, value)
            return
        shards = value.local_shards()

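        # No local shards means this rank holds no data for this tensor, so
        # the entry is dropped; more than one local shard is not the nested
        # layout this pass targets, so the value is kept as-is.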
        if len(shards) == 0:
            return
        if len(shards) != 1:
            set_element(new_state_dict, path, value)
            return

        outer_shard = shards[0]

        inner_st = outer_shard.tensor
        if not isinstance(inner_st, ShardedTensor):
            set_element(new_state_dict, path, value)
            return

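        # Only the nested case remains: the outer shard's tensor is itself a
        # ShardedTensor, which in turn must hold exactly one local shard.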
        if len(inner_st.local_shards()) != 1:
            raise ValueError("Cannot handle inner tensor with more than 1 shard")
        inner_shard = inner_st.local_shards()[0]

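        # Build the flattened local shard: it reuses the innermost tensor, and
        # its offset within the global tensor is the outer shard's offset plus
        # the inner shard's offset inside the outer shard.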
        local_shards = [
            Shard(
                tensor=inner_shard.tensor,
                metadata=ShardMetadata(
                    shard_offsets=_element_wise_add(
                        outer_shard.metadata.shard_offsets,
                        inner_shard.metadata.shard_offsets,
                    ),
                    shard_sizes=inner_shard.metadata.shard_sizes,
                    placement=f"rank:{dist.get_rank()}/{inner_shard.tensor.device}",
                ),
            )
        ]

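        # Rebuild the global metadata around the flattened shard. Only the
        # local shard needs to be exact (see the docstring); shards owned by
        # other ranks merely need a valid placement, so they are all
        # attributed to an arbitrary rank that is not the current one.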
        st_meta: ShardedTensorMetadata = copy.deepcopy(value.metadata())
        other_rank = 0 if dist.get_rank() > 0 else 1
        device_info = _normalize_device_info(inner_shard.tensor.device.type, 0)

        # Remove the outer shard that the inner ShardedTensor covers
        for i, shard_md in enumerate(st_meta.shards_metadata):
            if shard_md.shard_offsets == outer_shard.metadata.shard_offsets:
                st_meta.shards_metadata.pop(i)
                break

        # Attribute the remaining outer shards to the other rank
        for shard_md in st_meta.shards_metadata:
            shard_md.placement = _remote_device(f"rank:{other_rank}/{device_info}")

        # Add the other shards of the inner ShardedTensor, translated into
        # global offsets by adding the outer shard's offsets
        for inner_md in inner_st.metadata().shards_metadata:
            if inner_md.shard_offsets != inner_shard.metadata.shard_offsets:
                st_meta.shards_metadata.append(
                    ShardMetadata(
                        shard_offsets=_element_wise_add(
                            outer_shard.metadata.shard_offsets,
                            inner_md.shard_offsets,
                        ),
                        shard_sizes=inner_md.shard_sizes,
                        placement=f"rank:{other_rank}/{device_info}",
                    )
                )

        # Finally, add this rank's flattened local shard
        st_meta.shards_metadata.append(local_shards[0].metadata)

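        # Construct a checkpoint-only ShardedTensor from the flattened local
        # shard and the rebuilt global metadata.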
        st = ShardedTensor._init_from_local_shards_and_global_metadata(
            local_shards=local_shards,
            sharded_tensor_metadata=st_meta,
        )
        set_element(new_state_dict, path, st)

    traverse_state_dict(state_dict, rewrite_dict)
    return new_state_dict
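
# Example (illustrative sketch only, not part of this module's API): the
# names below assume an FSDP-wrapped ``model`` and an initialized process
# group. A state_dict produced under StateDictType.SHARDED_STATE_DICT could
# be flattened before checkpointing roughly like so:
#
#     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#     from torch.distributed.fsdp import StateDictType
#
#     with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
#         sharded_sd = model.state_dict()
#     flat_sd = _flatten_sharded_tensors(sharded_sd)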