# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)
from torch.utils.checkpoint import checkpoint


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


def get_cur_mem(rank, result, prefix):
    """Record the current CUDA memory allocated (in MB) under ``prefix`` in ``result``."""
    torch._C._cuda_clearCublasWorkspaces()
    result[prefix] = round(torch.cuda.memory_allocated() / 1024 / 1024)


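# Convolutional model used for the memory measurements. When ``with_fsdp`` is True,
# each BatchNorm2d is wrapped in its own (nested) FSDP unit; when ``with_checkpoint``
# is True, the ``blocks`` stage runs under activation checkpointing in forward.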
class Model(nn.Module):
    def __init__(self, hidden_dim, with_fsdp=False, with_checkpoint=False):
        super().__init__()
        if with_fsdp:
            self.stem = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3),
                FSDP(nn.BatchNorm2d(64)),
                nn.ReLU(inplace=True),
            )
        else:
            self.stem = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
            )
        if with_fsdp:
            self.blocks = nn.Sequential(
                nn.Conv2d(64, hidden_dim, kernel_size=5, padding=2),
                FSDP(nn.BatchNorm2d(hidden_dim)),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                FSDP(nn.BatchNorm2d(hidden_dim)),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                FSDP(nn.BatchNorm2d(hidden_dim)),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten(),
            )
        else:
            self.blocks = nn.Sequential(
                nn.Conv2d(64, hidden_dim, kernel_size=5, padding=2),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten(),
            )

        self.head = nn.Linear(hidden_dim, 10)
        self.with_checkpoint = with_checkpoint

    def forward(self, x):
        if self.with_checkpoint:
            return self.head(checkpoint(self.blocks, self.stem(x), use_reentrant=True))
        else:
            return self.head(self.blocks(self.stem(x)))


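# Build the model with a fixed seed and, if requested, wrap each top-level stage in
# its own FSDP unit (the caller additionally wraps the whole model in an outer FSDP).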
def create_model(with_fsdp, with_checkpoint, model_hidden_dim):
    torch.manual_seed(0)
    model = Model(model_hidden_dim, with_fsdp, with_checkpoint)
    if with_fsdp:
        model.stem = FSDP(model.stem)
        model.blocks = FSDP(model.blocks)
        model.head = FSDP(model.head)

    return model


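# Measures CUDA memory allocated at fixed points of an FSDP training loop and compares
# it (to within 1 MB) against expected values, with and without activation checkpointing.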
class TestFSDPMemory(FSDPTest):
    @property
    def world_size(self):
        return 2

    def _dist_train(self, with_checkpoint, expected, model_hidden_dim, iterations):
        gpu_id = self.rank
        world_size = self.world_size

        batch = torch.randn(size=(2, 3, 224, 224)).cuda()

        model = create_model(
            with_fsdp=True,
            with_checkpoint=with_checkpoint,
            model_hidden_dim=model_hidden_dim,
        )
        model = model.cuda()
        model = FSDP(model)

        # We enable momentum so that, after the first iteration, the optimizer state
        # (momentum buffers) is added to the total memory used.
        criterion = nn.MSELoss()
        optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
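        # The loop below records allocated memory at six checkpoints per iteration:
        # start, after forward, after loss, after backward, after the optimizer step,
        # and after gradients are cleared.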
        results = {}  # results of memory stats
        for iteration in range(iterations):
            get_cur_mem(gpu_id, results, f"iter {iteration}: start")

            out = model(batch)
            get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

            out = sum(o.sum() for o in out[0])
            fake_loss = criterion(out, torch.tensor(0.0).cuda())
            get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

            fake_loss.backward()
            get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

            optimizer.step()
            get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

            # Use `set_to_none=True` (rather than optimizer.zero_grad()) so that gradient
            # memory is actually freed instead of being zeroed in place.
            model.zero_grad(set_to_none=True)
            get_cur_mem(gpu_id, results, f"iter {iteration}: done")

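        # Return a report of every entry whose measured memory deviates from the
        # expected value; an empty string means all entries matched.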
        def cmp(results, expected):
            ret = ""
            self.assertEqual(results.keys(), expected.keys())
            for k, v in results.items():
                exp = expected[k]
                if abs(exp - v) > 1:  # allow 1MB rounding differences
                    ret += f"{k}: got {v}, expected {exp}\n"
            return ret

        output = cmp(results, expected)
        self.assertEqual(output, "")

    @skip_if_lt_x_gpu(2)
    @parametrize("ckpt", ["no_ckpt", "ckpt"])
    def test_fsdp_memory(self, ckpt):
        # hidden_dim 128: model size ~4MB
        model_hidden_dim = 128

        model = create_model(
            with_fsdp=False, with_checkpoint=False, model_hidden_dim=model_hidden_dim
        ).cuda()
        model_size_mb = round(torch.cuda.memory_allocated() / 1024 / 1024)
        del model

        sharded_model_size_mb = int(model_size_mb / self.world_size)

        # We have observed that the 4th iteration can sometimes fail after the 3rd (not
        # in this test, but in much larger-scale tests), so we run 4 iterations here
        # just in case.
        iterations = 4

        expected = {}

        for iteration in range(iterations):
            if iteration == 0:
                # sharded model size + 1MB temp memory
                expected[f"iter {iteration}: start"] = sharded_model_size_mb + 1
                # these activation sizes are hard to compute analytically; the values
                # below were taken from observed memory usage
                if ckpt == "ckpt":
                    expected[f"iter {iteration}: after fwd"] = 51
                    expected[f"iter {iteration}: after loss"] = 51
                else:
                    expected[f"iter {iteration}: after fwd"] = 340
                    expected[f"iter {iteration}: after loss"] = 340
                # sharded model size + sharded grad size + 1MB temp memory
                expected[f"iter {iteration}: after bwd"] = 2 * sharded_model_size_mb + 1
            else:
                # after the optimizer step in the first iteration, memory usage increases
                # by sharded_model_size_mb because the optimizer state (momentum buffers)
                # is now allocated
                expected[f"iter {iteration}: start"] = 2 * sharded_model_size_mb + 1
                if ckpt == "ckpt":
                    expected[f"iter {iteration}: after fwd"] = (
                        51 + sharded_model_size_mb
                    )
                    expected[f"iter {iteration}: after loss"] = (
                        51 + sharded_model_size_mb
                    )
                else:
                    expected[f"iter {iteration}: after fwd"] = (
                        340 + sharded_model_size_mb
                    )
                    expected[f"iter {iteration}: after loss"] = (
                        340 + sharded_model_size_mb
                    )
                expected[f"iter {iteration}: after bwd"] = 3 * sharded_model_size_mb + 1

            # sharded model size + sharded grad size + optimizer states + 1MB temp memory
            expected[f"iter {iteration}: after step"] = 3 * sharded_model_size_mb + 1
            # grad memory is reclaimed after setting grads to None
            # sharded model size + optimizer states + 1MB temp memory
            expected[f"iter {iteration}: done"] = 2 * sharded_model_size_mb + 1

        # Translate the parametrized string into the checkpoint flag.
        with_ckpt = ckpt == "ckpt"

        self._dist_train(
            with_ckpt,
            expected,
            model_hidden_dim,
            iterations,
        )


instantiate_parametrized_tests(TestFSDPMemory)


if __name__ == "__main__":
    run_tests()