# Owner(s): ["module: unknown"]
import functools
import gc
from typing import Union

import torch
import torch.nn as nn
from torch.distributed._composable import checkpoint
from torch.distributed._composable.fsdp import (
    CPUOffloadPolicy,
    fully_shard,
    MixedPrecisionPolicy,
    OffloadPolicy,
)
from torch.distributed._tensor import init_device_mesh
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    CheckpointWrapper,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, MLP
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
    TransformerBlock,
)

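# Run one tiny matmul forward/backward so that cuBLAS allocates its workspace up
# front; this one-time allocation is then excluded from the memory measurements below.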
def _init_cublas_workspace(dev: torch.device):
    lin = torch.nn.Linear(768, 768, device=dev)
    inp = torch.randn(1, 768, device=dev)
    lin(inp).sum().backward()
    del lin
    del inp

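# Clear the caching allocator and reset CUDA memory counters so that each test
# starts from a clean baseline.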
def _reset_mem_stats(dev: torch.device):
    torch.cuda.empty_cache()
    torch.cuda.reset_accumulated_memory_stats(dev)
    torch.cuda.reset_peak_memory_stats(dev)


class TestTrackerFullyShard1DTrainingCore(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    @skip_if_lt_x_gpu(2)
    def test_tracker_multi_group_eager(self):
        """
        Tests tracker accuracy when using multiple parameter groups for
        communication (to overlap communication with computation and reduce
        memory) and different offload and mixed precision policies.
        """
        self.run_subtests(
            {
                "reshard_after_forward": [True, False],
                "offload_policy": [
                    CPUOffloadPolicy(pin_memory=False),
                    OffloadPolicy(),
                ],
                "mp_policy": [
                    MixedPrecisionPolicy(
                        param_dtype=torch.float16, reduce_dtype=torch.float32
                    ),
                ],
            },
            self._test_tracker_multi_group,
        )

    def _test_tracker_multi_group(
        self,
        reshard_after_forward: Union[bool, int],
        offload_policy: OffloadPolicy,
        mp_policy: MixedPrecisionPolicy,
    ):
        debug = False
        dev = torch.device(torch.cuda.current_device())
        _init_cublas_workspace(dev)
        gc.collect()
        _reset_mem_stats(dev)
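        # Record the currently active bytes as a baseline so that pre-existing
        # allocations are subtracted out of the measured CUDA peak below.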
        mem_stats = torch.cuda.memory_stats(dev)
        pre_cuda_active = mem_stats["active_bytes.all.current"]
        torch.manual_seed(42)
        lin_dim, bsz = 2048, 8192
        with torch.device(dev):
            model = nn.Sequential(*[MLP(dim=lin_dim, device=dev) for _ in range(4)])
        mesh = init_device_mesh("cuda", (self.world_size,))
        fully_shard_fn = functools.partial(
            fully_shard,
            mesh=mesh,
            reshard_after_forward=reshard_after_forward,
            offload_policy=offload_policy,
            mp_policy=mp_policy,
        )
        for mlp in model:
            fully_shard_fn(mlp)
        fully_shard_fn(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        inp = torch.randn((bsz, lin_dim), device=dev)
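        # FSDPMemTracker estimates per-device peak memory for the fully_shard (FSDP2)
        # model; registering the inputs lets it account for their memory as well.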
        fmt = FSDPMemTracker(model, optim)
        fmt.track_inputs((inp,))
        with fmt:
            for iter_idx in range(2):
                loss = model(inp).sum()
                loss.backward()
                optim.step()
                optim.zero_grad()
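                # The first iteration includes one-time allocations (e.g., Adam's
                # lazily created optimizer states), so reset per-module stats after it.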
                if iter_idx == 0:
                    fmt.reset_mod_stats()
        mem_stats = torch.cuda.memory_stats()
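        # The tracker's peak estimate should match the CUDA allocator's measured peak
        # (relative to the pre-test baseline) to within 10%.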
        tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"]
        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
        accuracy = tracker_max / cuda_max
        if self.rank == 0 and debug:
            print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}")
        self.assertAlmostEqual(
            accuracy,
            1.0,
            delta=0.1,
            msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}",
        )
        del model
        del inp
        del optim

    @skip_if_lt_x_gpu(2)
    def test_tracker_non_root_forward_backward(self):
        """
        Tests tracker accuracy when running forward/backward through a non-root
        module.
        """
        debug = False
        dev = torch.device(torch.cuda.current_device())
        _init_cublas_workspace(dev)
        gc.collect()
        _reset_mem_stats(dev)
        mem_stats = torch.cuda.memory_stats(dev)
        pre_cuda_active = mem_stats["active_bytes.all.current"]
        torch.manual_seed(42)
        lin_dim, bsz = 2048, 8
        model = nn.Sequential(*[MLP(lin_dim, dev) for _ in range(3)])
        for mlp in model:
            fully_shard(mlp)
        fully_shard(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
        torch.manual_seed(42 + self.rank)
        inp = torch.randn((bsz, lin_dim), device=dev)
        fmt = FSDPMemTracker(model, optim)
        fmt.track_inputs((inp,))
        with fmt:
            for iter_idx in range(2):
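                # Exercise only the first MLP, a non-root FSDP module, rather than
                # the root `model`.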
                nonroot_loss = model[0](inp).sum()
                nonroot_loss.backward()
                optim.step()
                optim.zero_grad()
                if iter_idx == 0:
                    fmt.reset_mod_stats()
        mem_stats = torch.cuda.memory_stats()
        tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"]
        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
        accuracy = tracker_max / cuda_max
        if self.rank == 0 and debug:
            print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}")
        self.assertAlmostEqual(
            accuracy,
            1.0,
            delta=0.1,
            msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}",
        )
        del inp
        del model
        del optim


class TestTrackerFullyShard1DTrainingCompose(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(torch.cuda.device_count(), 4)

    @skip_if_lt_x_gpu(2)
    def test_tracker_with_activation_checkpointing(self):
        """
        Tests tracker accuracy when composing with activation checkpointing.
        """
        self.run_subtests(
            {
                "reshard_after_forward": [True, False],
                "checkpoint_impl": ["composable", "wrapper"],
            },
            self._test_tracker_with_activation_checkpointing,
        )

    def _test_tracker_with_activation_checkpointing(
        self, reshard_after_forward: Union[bool, int], checkpoint_impl: str
    ):
        assert checkpoint_impl in ("composable", "wrapper")
        debug = False
        dev = torch.device(torch.cuda.current_device())
        _init_cublas_workspace(dev)
        gc.collect()
        _reset_mem_stats(dev)
        mem_stats = torch.cuda.memory_stats(dev)
        pre_cuda_active = mem_stats["active_bytes.all.current"]
        torch.manual_seed(42)
        vocab_size = 8192
        bsz, seq_len = 16, 512
        with torch.device(dev):
            model_args = ModelArgs(
                n_layers=4,
                n_heads=4,
                vocab_size=vocab_size,
                max_seq_len=seq_len,
                dropout_p=0.1,
            )
            model = Transformer(model_args)
        foreach = False
        fully_shard_fn = functools.partial(
            fully_shard,
            reshard_after_forward=reshard_after_forward,
        )
        if checkpoint_impl == "wrapper":
            apply_activation_checkpointing(
                model, check_fn=lambda m: isinstance(m, TransformerBlock)
            )
            for module in model.modules():
                # Apply to `CheckpointWrapper`, which wraps `TransformerBlock`
                if isinstance(module, CheckpointWrapper):
                    fully_shard_fn(module)
        else:
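            # Composable activation checkpointing: apply `checkpoint` directly to each
            # `TransformerBlock` before sharding it.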
            for module in model.modules():
                if isinstance(module, TransformerBlock):
                    if checkpoint_impl == "composable":
                        checkpoint(module)
                    fully_shard_fn(module)
        fully_shard_fn(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=foreach)

        torch.manual_seed(42 + self.rank)
        inp = torch.randint(0, vocab_size, (bsz, seq_len), device=dev)
        fmt = FSDPMemTracker(model, optim)
        fmt.track_inputs((inp,))
        with fmt:
            for iter_idx in range(2):
                loss = model(inp).sum()
                loss.backward()
                optim.step()
                optim.zero_grad()
                if iter_idx == 0:
                    fmt.reset_mod_stats()
        mem_stats = torch.cuda.memory_stats()
        tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"]
        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
        accuracy = tracker_max / cuda_max
        if self.rank == 0 and debug:
            print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}")
        self.assertAlmostEqual(
            accuracy,
            1.0,
            delta=0.1,
            msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}",
        )
        del inp
        del model
        del optim


if __name__ == "__main__":
    run_tests()