# Owner(s): ["oncall: distributed"]

import os
import sys

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch.nn as nn
from torch.distributed._tensor import init_device_mesh
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_state_dict,
    set_model_state_dict,
    set_state_dict,
    StateDictOptions,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


DIM = 500


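# Backbone model that is pretrained on its own and later reused, frozen, inside FineTuningModel.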
class PreTrainedModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer1 = nn.Linear(DIM, DIM)
        self.layer2 = nn.Linear(DIM, DIM)
        self.layer3 = nn.Linear(DIM, DIM)
        self.sequential = nn.Sequential(nn.Linear(DIM, DIM), nn.ReLU())
        self.module_list = nn.ModuleList([nn.Linear(DIM, DIM), nn.ReLU()])
        self.relu = nn.ReLU()

    def forward(self, batch):
        x = self.relu(self.layer1(batch))
        x = self.relu(self.layer2(x))
        x = self.relu(self.layer3(x))
        x = self.sequential(x)
        x = self.module_list[1](self.module_list[0](x))
        return x


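# Fine-tuning model: embeds the frozen pretrained backbone and adds new trainable layers on top.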
class FineTuningModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.pretrain = PreTrainedModel()
        for p in self.pretrain.parameters():
            p.requires_grad = False

        self.layer1 = nn.Linear(DIM, DIM)
        self.layer2 = nn.Linear(DIM, DIM)
        self.layer3 = nn.Linear(DIM, DIM)
        self.relu = nn.ReLU()

    def forward(self, batch):
        x = self.relu(self.pretrain(batch))
        x = self.relu(self.layer1(x))
        x = self.relu(self.layer2(x))
        x = self.relu(self.layer3(x))
        return x


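# Verifies the fine-tuning checkpoint flow: pretrain and save the backbone with
# distributed checkpoint (DCP), then fine-tune while reloading the frozen backbone
# and saving only the trainable state.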
class TestFineTuning(DTensorTestBase):
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    @property
    def backend(self):
        return "cpu:gloo,cuda:nccl"

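    # Train the backbone under FSDP for a few iterations and save the model and
    # optimizer state as a sharded DCP checkpoint.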
    def pretrain(self, pretrain_dir: str) -> None:
        device_mesh = init_device_mesh(self.device_type, (self.world_size,))

        model = PreTrainedModel().cuda()
        model = FSDP(model, device_mesh=device_mesh)
        optim = torch.optim.Adam(model.parameters(), lr=1e-3)

        # Training
        for i in range(3):
            batch = torch.rand(32, DIM, device="cuda")
            loss = model(batch).sum()
            loss.backward()
            optim.step()
            optim.zero_grad()

        # Save state_dict
        model_state_dict, optim_state_dict = get_state_dict(model, optimizers=optim)
        saved_state_dict = {"model": model_state_dict, "optim": optim_state_dict}
        dist_cp.save_state_dict(
            state_dict=saved_state_dict,
            storage_writer=dist_cp.FileSystemWriter(pretrain_dir),
        )

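    # Fine-tune under FSDP: load the frozen backbone from the pretrain checkpoint,
    # train the new layers, and save only the non-frozen state to finetune_dir.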
    def finetune(self, pretrain_dir: str, finetune_dir: str) -> None:
        device_mesh = init_device_mesh(self.device_type, (self.world_size,))

        model = FineTuningModel().cuda()
        # TODO: make the parallelism more complicated, e.g., using 2D + DDP.
        model = FSDP(model, use_orig_params=True, device_mesh=device_mesh)
        optim = torch.optim.Adam(model.parameters(), lr=1e-3)

        # Simulate a fine-tuning run that restarts after 3 iterations
        for i in range(2):
            # Load the pretrained submodule's weights from the pretrain checkpoint
            pretrain_state_dict = get_model_state_dict(
                model,
                submodules={model.pretrain},
                options=StateDictOptions(keep_submodule_prefixes=False),
            )
            dist_cp.load_state_dict(
                {"model": pretrain_state_dict},
                storage_reader=dist_cp.FileSystemReader(pretrain_dir),
            )
            set_model_state_dict(
                model,
                model_state_dict={model.pretrain: pretrain_state_dict},
                options=StateDictOptions(strict=False),
            )

            try:
                # Load the checkpoint for the trainable (non-frozen) parameters
                # and the optimizer state
                model_state_dict, optim_state_dict = get_state_dict(
                    model,
                    optimizers=optim,
                    options=StateDictOptions(ignore_frozen_params=True),
                )
                dist_cp.load_state_dict(
                    {"model": model_state_dict, "optim": optim_state_dict},
                    storage_reader=dist_cp.FileSystemReader(pretrain_dir),
                )
                set_state_dict(
                    model,
                    optimizers=optim,
                    model_state_dict=model_state_dict,
                    optim_state_dict=optim_state_dict,
                    options=StateDictOptions(strict=False),
                )
            except KeyError:
                # If this is the first round of fine-tuning, nothing has been saved yet.
                # If this is a restart of fine-tuning, the checkpoint should exist.
                self.assertEqual(i, 0)

            # Training
            for j in range(3):
                batch = torch.rand(32, DIM, device="cuda")
                loss = model(batch).sum()
                loss.backward()
                optim.step()
                optim.zero_grad()

            # Save state_dict
            model_state_dict, optim_state_dict = get_state_dict(
                model,
                optimizers=optim,
                options=StateDictOptions(ignore_frozen_params=True),
            )
            saved_state_dict = {"model": model_state_dict, "optim": optim_state_dict}
            dist_cp.save_state_dict(
                state_dict=saved_state_dict,
                storage_writer=dist_cp.FileSystemWriter(finetune_dir),
            )

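    # Rank 0 creates the checkpoint directories, then all ranks run pretraining
    # followed by fine-tuning.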
    @skip_if_lt_x_gpu(4)
    @with_comms
    @with_temp_dir
    def test_fine_tuning(self) -> None:
        self.assertTrue(os.path.exists(self.temp_dir))
        pretrain_dir = os.path.join(self.temp_dir, "pretrain")
        finetune_dir = os.path.join(self.temp_dir, "finetune")
        print(pretrain_dir, finetune_dir)
        if self.rank == 0:
            os.mkdir(pretrain_dir)
            os.mkdir(finetune_dir)
        dist.barrier()
        os.sync()
        self.assertTrue(os.path.exists(pretrain_dir))
        self.assertTrue(os.path.exists(finetune_dir))

        self.pretrain(pretrain_dir)
        self.finetune(pretrain_dir, finetune_dir)


if __name__ == "__main__":
    run_tests()