# Owner(s): ["oncall: distributed"] import os import sys import torch import torch.distributed as dist import torch.distributed.checkpoint as dcp import torch.nn as nn from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) if TEST_WITH_DEV_DBG_ASAN: print( "Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr, ) sys.exit(0) DIM = 500 class PipelineModel(nn.Module): def __init__(self) -> None: super().__init__() self.layer1 = nn.Linear(DIM, DIM) self.layer2 = nn.Linear(DIM, DIM) self.layer3 = nn.Linear(DIM, DIM) self.layer4 = nn.Linear(DIM, DIM) self.relu = nn.ReLU() def forward(self, batch): x = self.relu(self.layer1(batch)) x = self.relu(self.layer2(x)) x = self.relu(self.layer3(x)) x = self.relu(self.layer4(x)) return x class TestPipeline(FSDPTest): @property def world_size(self) -> int: return min(4, torch.cuda.device_count()) def save_with_pipeline(self, pipeline_dir: str) -> None: with torch.device("meta"): model = PipelineModel() pipeline_modules = [model.layer1, model.layer2, model.layer3, model.layer4] # Materialize the model submodule = pipeline_modules[self.rank] submodule.to_empty(device=torch.device("cuda")) # submodule.reset_parameters() optim = torch.optim.Adam(submodule.parameters(), lr=1e-3) # Ignore the training as we don't have a real pipeline parallelism. # Save state_dict model_state_dict, optim_state_dict = get_state_dict(model, optimizers=optim) saved_state_dict = {"model": model_state_dict, "optim": optim_state_dict} dcp.save( state_dict=saved_state_dict, storage_writer=dcp.FileSystemWriter(pipeline_dir), ) def load_with_fsdp(self, pipeline_dir: str) -> None: model = FSDP(PipelineModel().cuda()) optim = torch.optim.Adam(model.parameters(), lr=1e-3) # Load the checkpoint model_state_dict, optim_state_dict = get_state_dict(model, optimizers=optim) dcp.load( {"model": model_state_dict, "optim": optim_state_dict}, storage_reader=dcp.FileSystemReader(pipeline_dir), ) set_state_dict( model, optimizers=optim, model_state_dict=model_state_dict, optim_state_dict=optim_state_dict, ) @skip_if_lt_x_gpu(4) @with_temp_dir def test_pipeline(self) -> None: self.assertTrue(os.path.exists(self.temp_dir)) pipeline_dir = os.path.join(self.temp_dir, "pipeline") if self.rank == 0: os.mkdir(pipeline_dir) os.sync() dist.barrier() self.assertTrue(os.path.exists(pipeline_dir)) self.save_with_pipeline(pipeline_dir) self.load_with_fsdp(pipeline_dir) if __name__ == "__main__": run_tests()