# Copyright (c) Meta Platforms, Inc. and affiliates

import copy
from typing import TYPE_CHECKING

import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor, ShardMetadata
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.remote_device import _remote_device

from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict
from .utils import _element_wise_add, _normalize_device_info


if TYPE_CHECKING:
    from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata


# TODO: We need to refactor this code.
def _flatten_sharded_tensors(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
    r"""
    Transform ``state_dict`` by flattening all nested ShardedTensor instances found.

    The resulting ShardedTensor instances are only correct regarding the local shard and
    MUST not be used for any other purpose but checkpointing, as no operator will work with them.

    This function should be used in conjunction with a state_dict produced by FSDP's
    StateDictType.SHARDED_STATE_DICT methods.
    """
    new_state_dict: STATE_DICT_TYPE = {}

    def rewrite_dict(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
        if not isinstance(value, ShardedTensor):
            set_element(new_state_dict, path, value)
            return
        shards = value.local_shards()

        if len(shards) == 0:
            return
        if len(shards) != 1:
            set_element(new_state_dict, path, value)
            return

        outer_shard = shards[0]

        inner_st = outer_shard.tensor
        if not isinstance(inner_st, ShardedTensor):
            set_element(new_state_dict, path, value)
            return

        if len(inner_st.local_shards()) != 1:
            raise ValueError("Cannot handle inner tensor with more than 1 shard")
        inner_shard = inner_st.local_shards()[0]

        local_shards = [
            Shard(
                tensor=inner_shard.tensor,
                metadata=ShardMetadata(
                    shard_offsets=_element_wise_add(
                        outer_shard.metadata.shard_offsets,
                        inner_shard.metadata.shard_offsets,
                    ),
                    shard_sizes=inner_shard.metadata.shard_sizes,
                    placement=f"rank:{dist.get_rank()}/{inner_shard.tensor.device}",
                ),
            )
        ]

        st_meta: ShardedTensorMetadata = copy.deepcopy(value.metadata())
        other_rank = 0 if dist.get_rank() > 0 else 1
        device_info = _normalize_device_info(inner_shard.tensor.device.type, 0)

        # Remove the outer ST shard the inner ST covers
        for i, shard_md in enumerate(st_meta.shards_metadata):
            if shard_md.shard_offsets == outer_shard.metadata.shard_offsets:
                st_meta.shards_metadata.pop(i)
                break

        # Attribute other rank for the other shards
        for shard_md in st_meta.shards_metadata:
            shard_md.placement = _remote_device(f"rank:{other_rank}/{device_info}")

        # Add other inner shards from the inner tensor
        for inner_md in inner_st.metadata().shards_metadata:
            if inner_md.shard_offsets != inner_shard.metadata.shard_offsets:
                st_meta.shards_metadata.append(
                    ShardMetadata(
                        shard_offsets=_element_wise_add(
                            outer_shard.metadata.shard_offsets,
                            inner_md.shard_offsets,
                        ),
                        shard_sizes=inner_md.shard_sizes,
                        placement=f"rank:{other_rank}/{device_info}",
                    )
                )

        # Finally add this shard
        st_meta.shards_metadata.append(local_shards[0].metadata)

        st = ShardedTensor._init_from_local_shards_and_global_metadata(
            local_shards=local_shards,
            sharded_tensor_metadata=st_meta,
        )
        set_element(new_state_dict, path, st)

    traverse_state_dict(state_dict, rewrite_dict)
    return new_state_dict
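

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of this module; included only to show
# where _flatten_sharded_tensors sits in the FSDP checkpointing flow). It is
# an assumption-laden example: it expects to be launched with torchrun, uses
# a toy ``nn.Linear`` model and the "gloo" backend as placeholders, and relies
# on FSDP's StateDictType.SHARDED_STATE_DICT producing ShardedTensor values.
# Nested ShardedTensors (the case this helper actually flattens) typically
# arise when FSDP is combined with ShardedTensor-based tensor parallelism;
# with plain FSDP the values simply pass through unchanged. The default
# checkpoint save planner normally applies this flattening itself, so a
# direct call like the one below is mostly relevant for custom planners.
if __name__ == "__main__":
    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

    dist.init_process_group(backend="gloo")  # torchrun provides rank/world size

    model = FSDP(torch.nn.Linear(8, 8))  # hypothetical toy model

    # Ask FSDP for a sharded state_dict; its values are ShardedTensor instances.
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        sharded_sd = model.state_dict()

    # Flatten any nested ShardedTensors so the result is safe to checkpoint.
    flat_sd = _flatten_sharded_tensors(sharded_sd)

    dist.destroy_process_group()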