# mypy: allow-untyped-defs
# Owner(s): ["oncall: distributed"]

import os
import shutil
import traceback

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.checkpoint.state_dict import (
    _patch_model_state_dict,
    _patch_optimizer_state_dict,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.device_mesh import init_device_mesh


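# Training hyperparameters for the toy run: checkpoint asynchronously every
# SAVE_PERIOD epochs and inject a simulated failure every FAULT_PERIOD epochs;
# checkpoints are written under CHECKPOINT_DIR.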
DEVICE = "cuda"
NUM_EPOCHS = 1000
SAVE_PERIOD = 10
FAULT_PERIOD = 25
CHECKPOINT_DIR = f"~/{os.environ.get('LOGNAME', '')}/checkpoint"


class InjectedException(Exception):
    pass


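# Small feed-forward network used as a stand-in training workload.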
class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net1 = nn.Linear(8, 32)
        self.net2 = nn.Linear(32, 128)
        self.net3 = nn.Linear(128, 64)
        self.net4 = nn.Linear(64, 8)
        self.net5 = nn.Linear(8, 1)

    def forward(self, x):
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        x = F.relu(self.net4(x))
        x = torch.sigmoid(self.net5(x))
        return x


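# Build the model, shard it with FSDP over a 1-D device mesh, and patch the
# model/optimizer state_dict methods so DCP can save and load them directly.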
def _init_model(rank, world_size):
    device_mesh = init_device_mesh(DEVICE, (world_size,))

    # Create a dummy model and wrap it in FSDP
    model = Model().cuda()
    model = FSDP(model, device_mesh=device_mesh, use_orig_params=True)

    optim = torch.optim.Adam(model.parameters(), lr=0.0001)

    _patch_model_state_dict(model)
    _patch_optimizer_state_dict(model, optimizers=optim)

    return model, optim


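# Log from rank 0 only to keep the output readable.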
def _print(msg):
    if dist.get_rank() == 0:
        print(msg)


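# Generate a random batch; the label is 1.0 when the features sum to at least 4,
# giving the toy model a learnable binary classification task.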
def _input():
    x = torch.rand(128, 8, device="cuda")
    y = torch.zeros(128, 1, device="cuda")

    y[torch.sum(x, dim=1) >= 4] = 1.0

    return x, y


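# Per-process training loop: run a standard train step each epoch, save an
# asynchronous checkpoint every SAVE_PERIOD epochs, inject a fault every
# FAULT_PERIOD epochs, and recover by reloading the last checkpoint.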
def run(rank, world_size):
    # Set up world pg
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    dist.init_process_group("cpu:gloo,cuda:nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model, optim = _init_model(rank, world_size)
    state_dict = {"model": model, "optim": optim}
    loss_calc = torch.nn.BCELoss()

    f = None
    for epoch in range(NUM_EPOCHS):
        try:
            torch.manual_seed(epoch)
            x, y = _input()

            loss = loss_calc(model(x), y)

            _print(f"{epoch=} {loss=}")

            loss.backward()
            optim.step()
            optim.zero_grad()
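            # Every SAVE_PERIOD epochs, wait for any in-flight async save to
            # finish, then kick off a new one so training overlaps with
            # checkpoint I/O.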
            if epoch % SAVE_PERIOD == 0:
                if f is not None:
                    f.result()
                f = dcp.state_dict_saver.async_save(
                    state_dict, checkpoint_id=CHECKPOINT_DIR
                )

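            # Periodically raise an exception to simulate a failure and exercise
            # the recovery path in the handler below.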
            if FAULT_PERIOD > 0 and epoch % FAULT_PERIOD == 0:
                raise InjectedException("Fault injection!")

        except InjectedException as e:
            # Ensure every rank has reached the fault before starting recovery.
            dist.barrier()

            _print("Trainer encountered exception:")
            traceback.print_tb(e.__traceback__)

            _print("Reloading model from last checkpoint!")
            # Wait for any in-flight async save to complete before reading it back.
            if f is not None:
                f.result()
            dcp.load(state_dict, checkpoint_id=CHECKPOINT_DIR)


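# Entry point: start from a clean checkpoint directory and spawn one trainer
# process per visible GPU.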
if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running an example of Async Checkpointing on {world_size} devices.")
    shutil.rmtree(CHECKPOINT_DIR, ignore_errors=True)

    mp.spawn(
        run,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )