# mypy: allow-untyped-defs
# Owner(s): ["oncall: distributed"]

# pyre-unsafe


import os
import shutil

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    _patch_model_state_dict,
    _patch_optimizer_state_dict,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


# "~" is not expanded automatically in file paths, so expand it explicitly;
# otherwise the checkpoint would land in a literal "~" directory under the CWD.
CHECKPOINT_DIR = os.path.expanduser(f"~/{os.environ['LOGNAME']}/checkpoint")


class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        torch.manual_seed(0)
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

    def forward(self, x):
        return self.net4(self.net3(self.net2(self.net1(x))))

    def get_input(self):
        return torch.rand(8, 8, device="cuda")


def _make_stateful(model, optim):
    # Patch state_dict/load_state_dict on the model and optimizer so that they
    # produce and consume DCP-friendly state dicts, letting the objects be
    # passed to dcp.save/dcp.load directly.
    _patch_model_state_dict(model)
    _patch_optimizer_state_dict(model, optimizers=optim)


def _train(model, optim, train_steps=1):
    # Run a few training steps on random input and return the last loss.
    torch.manual_seed(0)
    loss = None
    for _ in range(train_steps):
        loss = model(model.get_input()).sum()
        loss.backward()
        optim.step()
        optim.zero_grad()

    return loss


def _init_model(device, world_size):
    # Build a 1-D device mesh over all ranks and shard the model with FSDP.
    device_mesh = init_device_mesh(device, (world_size,))
    model = Model().cuda()
    model = FSDP(
        model,
        device_mesh=device_mesh,
        use_orig_params=True,
    )
    optim = torch.optim.Adam(model.parameters(), lr=0.1)
    _make_stateful(model, optim)

    return model, optim


def run(rank, world_size, device="cuda"):
    # Set up the world process group.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    dist.init_process_group("cpu:gloo,cuda:nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model, optim = _init_model(device, world_size)
    _train(model, optim, train_steps=2)

    # Since the model and optimizer have been patched to be Stateful, they can
    # be passed to dcp.save as-is. The save is a collective: every rank writes
    # its own shard of the checkpoint.
    dcp.save(
        state_dict={"model": model, "optimizer": optim},
        checkpoint_id=CHECKPOINT_DIR,
    )

    # At some later point (e.g. after a restart), re-initialize the model and
    # optimizer and restore their state from the checkpoint.
    model, optim = _init_model(device, world_size)
    dcp.load(
        state_dict={"model": model, "optimizer": optim},
        checkpoint_id=CHECKPOINT_DIR,
    )
    _train(model, optim, train_steps=2)

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running stateful checkpoint example on {world_size} devices.")
    # Start from a clean checkpoint directory.
    shutil.rmtree(CHECKPOINT_DIR, ignore_errors=True)
    mp.spawn(
        run,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
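

# ---------------------------------------------------------------------------
# Aside: _patch_model_state_dict and _patch_optimizer_state_dict are private
# helpers. The sketch below shows the same save/load round trip using only the
# public torch.distributed.checkpoint.state_dict APIs (get_state_dict /
# set_state_dict). It is a minimal illustration under the same FSDP setup as
# above and is not called anywhere in this example.
# ---------------------------------------------------------------------------
def _roundtrip_with_public_api(model, optim):
    # Hypothetical helper, for illustration only.
    from torch.distributed.checkpoint.state_dict import (
        get_state_dict,
        set_state_dict,
    )

    # get_state_dict returns DCP-friendly (FQN-keyed, unflattened) model and
    # optimizer state dicts for the FSDP-wrapped module.
    model_sd, optim_sd = get_state_dict(model, optim)
    dcp.save(
        state_dict={"model": model_sd, "optimizer": optim_sd},
        checkpoint_id=CHECKPOINT_DIR,
    )

    # dcp.load restores tensors in place into the provided dicts;
    # set_state_dict then applies them back to the module and optimizer.
    model_sd, optim_sd = get_state_dict(model, optim)
    dcp.load(
        state_dict={"model": model_sd, "optimizer": optim_sd},
        checkpoint_id=CHECKPOINT_DIR,
    )
    set_state_dict(
        model,
        optim,
        model_state_dict=model_sd,
        optim_state_dict=optim_sd,
    )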