# Copyright (c) Meta Platforms, Inc. and affiliates
from functools import partial
from typing import Any, Optional, Tuple

import torch
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard


__all__ = [
    "input_reshard",
]


def input_reshard(
    module: torch.nn.Module,
    tp_device_mesh: DeviceMesh,
    input_reshard_dim: Optional[int] = None,
) -> torch.nn.Module:
    """
    Register hooks on an nn.Module for input resharding, sharding the input in forward and restoring it during backward recomputation.

    Register hooks to an nn.Module with input resharding so that we can shard
    the input per the given `tp_device_mesh` and `input_reshard_dim` and restore
    it when recomputing the activations in the backward pass. This works
    because for Tensor Parallel (TP) the input is the same across all TP ranks.

    Args:
        module (:class:`nn.Module`):
            Module to be registered with input resharding.
        tp_device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology
            of devices for Tensor Parallel.
        input_reshard_dim (Optional[int]):
            The dimension along which the input is sharded.
            If set to None, the input is not sharded.
            Default: None

    Return:
        A :class:`nn.Module` object registered with TP input resharding.
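
    .. note::
        The snippet below is an illustrative sketch rather than code from this
        file: it assumes a 1-D tensor parallel mesh of 8 GPUs and a hypothetical
        ``block`` module whose input is resharded on dim 0.

    Example::

        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>> block = input_reshard(block, tp_mesh, input_reshard_dim=0)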
    """
    cx: Optional[torch.autograd.graph.saved_tensors_hooks] = None

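    # Forward pre-hook: enter a saved_tensors_hooks context so that every
    # activation saved for backward is packed (sharded) by _pack_hook_tp and
    # unpacked (restored) by _unpack_hook_tp. The context object is kept in
    # ``cx`` so the post-forward hook below can exit it.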
    def input_reshard_forward_pre_hook(_: torch.nn.Module, _i: Tuple[Any, ...]) -> None:
        saved_tensor_hooks = torch.autograd.graph.saved_tensors_hooks(
            partial(_pack_hook_tp, tp_device_mesh, input_reshard_dim),
            partial(_unpack_hook_tp, tp_device_mesh, input_reshard_dim),
        )
        saved_tensor_hooks.__enter__()
        nonlocal cx
        cx = saved_tensor_hooks  # type: ignore[name-defined]

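    # Forward (post) hook: exit the saved_tensors_hooks context once this
    # module's forward has finished, so the pack/unpack hooks only apply to
    # tensors saved by this module.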
    def input_reshard_backward_hook(
        _: torch.nn.Module, _i: Tuple[Any, ...], _o: Any
    ) -> Any:
        nonlocal cx
        cx.__exit__()  # type: ignore[name-defined, union-attr]

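    # Only register the hooks when a reshard dim is given; otherwise return
    # the module unchanged.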
    if input_reshard_dim is None:
        return module
    module.register_forward_pre_hook(input_reshard_forward_pre_hook)
    module.register_forward_hook(input_reshard_backward_hook)
    return module


def _pack_hook_tp(
    mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor
) -> Any:  # noqa: D401
    """Pack hook called when a tensor is saved for backward; shards the input."""
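    # A replicated DTensor (or a plain tensor large enough to be split across
    # the mesh) is saved sharded along ``input_reshard_dim``, so each TP rank
    # stores only 1/mesh.size() of the activation; anything else is saved as-is.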
    if isinstance(x, DTensor) and all(p.is_replicate() for p in x._spec.placements):
        return x.redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)])
    elif (
        not isinstance(x, DTensor)
        and isinstance(x, torch.Tensor)
        and x.numel() >= mesh.size()
    ):
        return (
            DTensor.from_local(x, device_mesh=mesh)
            .redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)])
            .to_local()
        )
    else:
        return x


def _unpack_hook_tp(
    mesh: DeviceMesh, input_reshard_dim: int, x: Any
) -> torch.Tensor:  # noqa: D401
    """Unpack hook called before activation recomputation in backward; restores the input."""
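    # Undo what _pack_hook_tp did: a sharded DTensor (or a plain tensor that
    # was saved as a local shard) is redistributed back to Replicate so the
    # recomputation sees the full input; anything else is returned unchanged.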
    if (
        isinstance(x, DTensor)
        and len(x._spec.placements) == 1
        and x._spec.placements[0].is_shard()
    ):
        return x.redistribute(device_mesh=mesh, placements=[Replicate()])
    elif (
        not isinstance(x, DTensor)
        and isinstance(x, torch.Tensor)
        and x.numel() >= mesh.size()
    ):
        return (
            DTensor.from_local(
                x, device_mesh=mesh, placements=[Shard(input_reshard_dim)]
            )
            .redistribute(device_mesh=mesh, placements=[Replicate()])
            .to_local()
        )
    else:
        return x