# mypy: allow-untyped-defs
# Owner(s): ["oncall: distributed"]

"""
Example of asynchronous checkpointing with torch.distributed.checkpoint (DCP).

Every SAVE_PERIOD epochs the trainer kicks off an async checkpoint save; every
FAULT_PERIOD epochs a fault is injected, and the trainer recovers by waiting on
the in-flight save and reloading the latest checkpoint.
"""

import os
import shutil
import traceback

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.checkpoint.state_dict import (
    _patch_model_state_dict,
    _patch_optimizer_state_dict,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.device_mesh import init_device_mesh


DEVICE = "cuda"
NUM_EPOCHS = 1000
SAVE_PERIOD = 10
FAULT_PERIOD = 25
# Expand "~" explicitly, since the file APIs used by DCP will not do it for us.
CHECKPOINT_DIR = os.path.expanduser(f"~/{os.environ.get('LOGNAME', '')}/checkpoint")


class InjectedException(Exception):
    pass


class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net1 = nn.Linear(8, 32)
        self.net2 = nn.Linear(32, 128)
        self.net3 = nn.Linear(128, 64)
        self.net4 = nn.Linear(64, 8)
        self.net5 = nn.Linear(8, 1)

    def forward(self, x):
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        x = F.relu(self.net4(x))
        x = torch.sigmoid(self.net5(x))
        return x


def _init_model(rank, world_size):
    # Create a dummy model and wrap it in FSDP over a 1-D device mesh
    model = Model().cuda()
    device_mesh = init_device_mesh(DEVICE, (world_size,))
    model = FSDP(model, device_mesh=device_mesh, use_orig_params=True)

    optim = torch.optim.Adam(model.parameters(), lr=0.0001)

    # Patch state_dict/load_state_dict so the model and optimizer produce and
    # consume the DCP-friendly representation directly.
    _patch_model_state_dict(model)
    _patch_optimizer_state_dict(model, optimizers=optim)

    return model, optim


def _print(msg):
    if dist.get_rank() == 0:
        print(msg)


def _input():
    # Synthetic binary-classification batch: the label is 1 when the sum of
    # the features is at least 4.
    x = torch.rand(128, 8, device="cuda")
    y = torch.zeros(128, 1, device="cuda")

    y[torch.sum(x, dim=1) >= 4] = 1.0

    return x, y


def run(rank, world_size):
    # Set up world pg
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    dist.init_process_group("cpu:gloo,cuda:nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model, optim = _init_model(rank, world_size)
    state_dict = {"model": model, "optim": optim}
    loss_calc = torch.nn.BCELoss()

    f = None
    for epoch in range(NUM_EPOCHS):
        try:
            torch.manual_seed(epoch)
            x, y = _input()

            loss = loss_calc(model(x), y)

            _print(f"{epoch=} {loss=}")

            loss.backward()
            optim.step()
            optim.zero_grad()

            if epoch % SAVE_PERIOD == 0:
                # Wait for the previous async save (if any) to finish before
                # kicking off the next one.
                if f is not None:
                    f.result()
                f = dcp.state_dict_saver.async_save(
                    state_dict, checkpoint_id=CHECKPOINT_DIR
                )

            if FAULT_PERIOD > 0 and epoch % FAULT_PERIOD == 0:
                raise InjectedException("Fault injection!")

        except InjectedException as e:
            dist.barrier()

            _print("Trainer encountered exception:")
            traceback.print_tb(e.__traceback__)

            _print("Reloading model from last checkpoint!")
            # Make sure the in-flight checkpoint is fully persisted before
            # loading it back.
            if f is not None:
                f.result()
            dcp.load(state_dict, checkpoint_id=CHECKPOINT_DIR)


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running an example of Async Checkpointing on {world_size} devices.")
    shutil.rmtree(CHECKPOINT_DIR, ignore_errors=True)

    mp.spawn(
        run,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )