# mypy: allow-untyped-defs
import copy
import itertools
import math
from typing import Optional

import torch
import torch.distributed as dist
from torch._utils import _get_device_module
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard


def _get_remote_device_str(rank, device_type, num_devices_per_node):
    if device_type.lower() == "cpu":
        return f"rank:{rank}/{device_type}"
    elif device_type.lower() == "hpu":
        return f"rank:{rank}/{device_type}:{_get_device_module(device_type).current_device()}"
    else:
        return f"rank:{rank}/{device_type}:{rank % num_devices_per_node}"


def _create_chunk_sharded_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
    device: Optional[torch.device] = None,
) -> ShardedTensor:
    """
    Shard a tensor into chunks along the first dimension. Each rank gets its
    corresponding chunk as the local shard to create a ShardedTensor.
    """
    chunks = tensor.chunk(world_size, dim=0)
    if len(chunks) > rank:
        local_shard = chunks[rank].clone()
        offsets = [0 for _ in tensor.size()]
        offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank
        local_shards = [Shard.from_tensor_and_offsets(local_shard, offsets, rank)]
    else:
        local_shards = []

    # Create a ShardedTensor without invoking communication.
    chunk_sizes = [list(chunk.size()) for chunk in chunks]
    dim0_offsets = [0] + list(
        itertools.accumulate([chunk_size[0] for chunk_size in chunk_sizes])
    )[:-1]
    offsets = [0] * (len(chunk_sizes[0]) - 1)
    chunk_offsets = [[d0] + offsets for d0 in dim0_offsets]
    device_type = (
        distributed_c10d._get_pg_default_device(pg).type
        if device is None
        else device.type
    )
    placements = [
        _get_remote_device_str(
            dist.get_global_rank(pg, r),
            device_type,
            num_devices_per_node,
        )
        for r in range(len(chunk_sizes))
    ]
    assert len(chunk_sizes) == len(chunk_offsets) == len(placements)
    shard_metadata = [
        ShardMetadata(offset, size, placement)
        for offset, size, placement in zip(chunk_offsets, chunk_sizes, placements)
    ]
    sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=shard_metadata,
        size=tensor.size(),
        tensor_properties=TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=False,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        ),
    )
    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=pg
    )
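

# NOTE: The function below is an illustrative sketch added for clarity, not
# part of the original module. It shows how a caller might shard a parameter
# with _create_chunk_sharded_tensor on the default process group. The demo
# name, the tensor shape, and the CPU setup are assumptions.
def _demo_create_chunk_sharded_tensor() -> None:
    # Assumes every rank has already run, e.g.,
    # dist.init_process_group("gloo") before calling this sketch.
    pg = distributed_c10d._get_default_group()
    sharded = _create_chunk_sharded_tensor(
        torch.randn(8, 4),  # dim 0 is chunked across the ranks of `pg`
        rank=dist.get_rank(pg),
        world_size=dist.get_world_size(pg),
        num_devices_per_node=max(torch.cuda.device_count(), 1),
        pg=pg,
    )
    # The metadata describes the full 8 x 4 tensor; each rank holds at most
    # one local shard of it, and no communication was issued.
    assert sharded.size() == torch.Size([8, 4])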


def _create_chunk_dtensor(
    tensor: torch.Tensor,
    rank: int,
    device_mesh: DeviceMesh,
) -> DTensor:
    """
    Shard a tensor into chunks along the first dimension. Each rank gets its
    corresponding chunk as the local tensor to create a DTensor.
    """
    # We need to explicitly call .detach() to return a new tensor detached
    # from the current graph.
    tensor = tensor.clone().detach()

    # FSDP placements: [Shard(0)]
    # HSDP placements: [Replicate(), Shard(0)]
    replicate_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements[-1] = DShard(0)  # type: ignore[call-overload]

    return DTensor.from_local(
        tensor, device_mesh, replicate_placements, run_check=False
    ).redistribute(
        placements=shard_placements,
    )


def _all_gather_dtensor(
    tensor: DTensor,
    root_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
    """
    All-gather a DTensor along its sharded dimension and return the local tensor.
    """
    assert (
        root_mesh == tensor.device_mesh
    ), "The device mesh of a tensor should be a root mesh."

    placements = list(copy.deepcopy(tensor.placements))
    # FSDP placements: [Shard(0)] -> [Replicate()]
    # HSDP placements: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
    placements[-1] = Replicate()
    tensor = tensor.redistribute(
        device_mesh=tensor.device_mesh,
        placements=placements,
    )

    return tensor.to_local()
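

# NOTE: The function below is an illustrative sketch added for clarity, not
# part of the original module. It round-trips a tensor through
# _create_chunk_dtensor and _all_gather_dtensor on a 1-D mesh, which models
# plain FSDP ([Shard(0)]) per the placement comments above. The demo name,
# mesh size, and CPU/gloo setup are assumptions.
def _demo_dtensor_round_trip() -> None:
    from torch.distributed.device_mesh import init_device_mesh

    # Assumes dist.init_process_group("gloo") has been run on every rank.
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))
    # Seed so every rank materializes the same "full" tensor; from_local with
    # run_check=False trusts that the local tensors really are replicated.
    torch.manual_seed(0)
    full = torch.randn(8, 4)
    dtensor = _create_chunk_dtensor(full, dist.get_rank(), mesh)
    # Redistributing Shard(0) -> Replicate() all-gathers the chunks, so each
    # rank recovers the full tensor locally.
    gathered = _all_gather_dtensor(dtensor, mesh)
    assert torch.equal(gathered, full)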