# Copyright (c) Meta Platforms, Inc. and affiliates
from functools import partial
from typing import Any, Optional, Tuple

import torch
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard


__all__ = [
    "input_reshard",
]


def input_reshard(
    module: torch.nn.Module,
    tp_device_mesh: DeviceMesh,
    input_reshard_dim: Optional[int] = None,
) -> torch.nn.Module:
    """
    Register hooks to an nn.Module for input resharding, enabling sharding and restoration during backward computation.

    Register hooks to an nn.Module with input resharding so that we can shard
    the input per the given `tp_device_mesh` and `input_reshard_dim` and restore
    the input back when recomputing the activations in the backward. The reason
    we can do this is that for Tensor Parallel (TP), the input is the same
    across all TP ranks.

    Args:
        module (:class:`nn.Module`):
            Module to be registered with input resharding.
        tp_device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology
            of devices for Tensor Parallel.
        input_reshard_dim (Optional[int]):
            The dimension along which the input is sharded.
            If set to None, the input is not sharded.
            Default: None.

    Return:
        A :class:`nn.Module` object registered with TP input resharding.
    """
    cx: Optional[torch.autograd.graph.saved_tensors_hooks] = None

    def input_reshard_forward_pre_hook(_: torch.nn.Module, _i: Tuple[Any, ...]) -> None:
        # Enter a saved-tensors-hooks context so that tensors saved for backward
        # during this module's forward are packed (sharded) and later unpacked
        # (restored to replicated) by the TP reshard hooks below.
        saved_tensor_hooks = torch.autograd.graph.saved_tensors_hooks(
            partial(_pack_hook_tp, tp_device_mesh, input_reshard_dim),
            partial(_unpack_hook_tp, tp_device_mesh, input_reshard_dim),
        )
        saved_tensor_hooks.__enter__()
        nonlocal cx
        cx = saved_tensor_hooks  # type: ignore[name-defined]

    def input_reshard_backward_hook(
        _: torch.nn.Module, _i: Tuple[Any, ...], _o: Any
    ) -> Any:
        # Exit the saved-tensors-hooks context once this module's forward is
        # done; the pack/unpack hooks remain attached to tensors already saved.
        nonlocal cx
        cx.__exit__()  # type: ignore[name-defined, union-attr]

    if input_reshard_dim is None:
        return module
    module.register_forward_pre_hook(input_reshard_forward_pre_hook)
    module.register_forward_hook(input_reshard_backward_hook)
    return module


def _pack_hook_tp(
    mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor
) -> Any:  # noqa: D401
    """Hook function called during FWD when a tensor is saved, to shard the input."""
    if isinstance(x, DTensor) and all(p.is_replicate() for p in x._spec.placements):
        return x.redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)])
    elif (
        not isinstance(x, DTensor)
        and isinstance(x, torch.Tensor)
        and x.numel() >= mesh.size()
    ):
        return (
            DTensor.from_local(x, device_mesh=mesh)
            .redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)])
            .to_local()
        )
    else:
        return x


def _unpack_hook_tp(
    mesh: DeviceMesh, input_reshard_dim: int, x: Any
) -> torch.Tensor:  # noqa: D401
    """Hook function called before activation recomputing in BWD to restore input."""
    if (
        isinstance(x, DTensor)
        and len(x._spec.placements) == 1
        and x._spec.placements[0].is_shard()
    ):
        return x.redistribute(device_mesh=mesh, placements=[Replicate()])
    elif (
        not isinstance(x, DTensor)
        and isinstance(x, torch.Tensor)
        and x.numel() >= mesh.size()
    ):
        return (
            DTensor.from_local(
                x, device_mesh=mesh, placements=[Shard(input_reshard_dim)]
            )
            .redistribute(device_mesh=mesh, placements=[Replicate()])
            .to_local()
        )
    else:
        return x
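

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module; the module, mesh size, and
# dimensions below are hypothetical). It illustrates how ``input_reshard``
# could be applied to a block whose input is replicated across TP ranks, so
# that the input saved for backward is stored sharded along dim 0 and
# re-replicated only when activations are recomputed:
#
#     import torch.nn as nn
#     from torch.distributed.device_mesh import init_device_mesh
#     from torch.distributed.tensor.parallel import input_reshard
#
#     tp_mesh = init_device_mesh("cuda", (8,))  # assumes 8 ranks, 1-D TP mesh
#     block = nn.Sequential(
#         nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)
#     ).cuda()
#     block = input_reshard(block, tp_mesh, input_reshard_dim=0)
#
# With ``input_reshard_dim=None`` (the default) the module is returned
# unchanged and no hooks are registered.
# ---------------------------------------------------------------------------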