# Copyright (c) Meta Platforms, Inc. and affiliates # Owner(s): ["oncall: distributed"] import copy import logging import os import sys import tempfile from model_registry import ModelWithKwargs, MultiMLP, MultiMLPWithDw from schedule_registry import ScheduleUnbalanced, ScheduleVShaped, ScheduleWithW import torch import torch.distributed as dist from torch.distributed.pipelining import ( _ScheduleForwardOnly, pipeline, PipelineStage, Schedule1F1B, ScheduleFlexibleInterleaved1F1B, ScheduleGPipe, ScheduleInterleaved1F1B, ScheduleInterleavedZeroBubble, ScheduleLoopedBFS, ) from torch.distributed.pipelining.schedules import _PipelineScheduleRuntime from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_distributed import ( MultiProcContinousTest, requires_nccl, ) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, skip_but_pass_in_sandcastle_if, ) logger = logging.getLogger(__name__) d_hid = 512 batch_size = 256 torch.manual_seed(0) class ScheduleTest(MultiProcContinousTest): @classmethod def backend_str(cls) -> str: # Testing with NCCL backend return "nccl" @classmethod def setUpClass(cls): """ Class-scope test fixture. Run once for entire test class, before any test starts. Set up the device. """ super().setUpClass() dev_id = cls.rank % torch.cuda.device_count() cls.device = torch.device(f"cuda:{dev_id}") @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [_ScheduleForwardOnly]) def test_forward_only(self, ScheduleClass): mod = MultiMLP(d_hid, n_layers=self.world_size) mod.to(self.device) mod_ref = copy.deepcopy(mod) x = torch.randn(batch_size, d_hid, device=self.device) x_clone = x.clone() num_microbatches = 4 x_mb = x.chunk(num_microbatches)[0] # Create a pipeline split_spec = mod.split_spec if hasattr(mod, "split_spec") else None pipe = pipeline( mod, mb_args=(x_mb,), split_spec=split_spec, ) stage = pipe.build_stage( self.rank, self.device, ) # Attach to a schedule schedule = ScheduleClass(stage, num_microbatches) # Run num_iters = 20 for _ in range(num_iters): if self.rank == 0: schedule.step(x) dist.recv(x, src=self.world_size - 1) elif self.rank == self.world_size - 1: out = schedule.step() dist.send(out, dst=0) else: schedule.step() # Validate pipelined output is the same as reference model if self.rank == self.world_size - 1: for _ in range(num_iters): x_clone = mod_ref(x_clone) torch.testing.assert_close(x_clone, out) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) def test_multi_iter(self, ScheduleClass): mod = MultiMLP(d_hid, n_layers=self.world_size) mod.to(self.device) x = torch.randn(batch_size, d_hid, device=self.device) target = torch.randn(batch_size, d_hid, device=self.device) loss_fn = torch.nn.MSELoss(reduction="sum") chunks = 4 x_mb = x.chunk(chunks)[0] # Create a pipeline split_spec = mod.split_spec if hasattr(mod, "split_spec") else None pipe = pipeline( mod, mb_args=(x_mb,), split_spec=split_spec, ) stage = pipe.build_stage( self.rank, self.device, ) # Attach to a schedule schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn) # Run for _ in range(20): if self.rank == 0: schedule.step(x) elif self.rank == self.world_size - 1: losses = [] out = schedule.step(target=target, losses=losses) else: schedule.step() @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) def test_kwargs_with_tracer(self, ScheduleClass): mod = ModelWithKwargs(d_hid) mod.to(self.device) x = torch.randn(batch_size, d_hid, device=self.device) y = torch.randn(batch_size, d_hid, device=self.device) target = torch.randn(batch_size, d_hid, device=self.device) loss_fn = torch.nn.MSELoss(reduction="sum") chunks = 4 x_mb = x.chunk(chunks)[0] y_mb = y.chunk(chunks)[0] pipe = pipeline( mod, mb_args=(x_mb,), mb_kwargs={"y": y_mb}, ) stage = pipe.build_stage( self.rank, self.device, ) # Attach to a schedule schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn) # Run if self.rank == 0: schedule.step(x, y=y) elif self.rank == self.world_size - 1: losses = [] out = schedule.step(target=target, losses=losses) else: schedule.step() dist.barrier() # Last rank checks result if self.rank == self.world_size - 1: ref_out = mod(x, y=y) ref_loss = loss_fn(ref_out, target) pipe_loss = sum(losses) torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=5e-3) torch.testing.assert_close(pipe_loss, ref_loss) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) @parametrize("ModelClass", [MultiMLP]) def test_grad_with_tracer(self, ScheduleClass, ModelClass): mod = ModelClass(d_hid) mod.to(self.device) ref_mod = copy.deepcopy(mod) x = torch.randn(batch_size, d_hid, device=self.device) with torch.no_grad(): y = ref_mod(x) # Add a small perturbation target = y + torch.randn(batch_size, d_hid, device=self.device) loss_fn = torch.nn.MSELoss(reduction="sum") # Run reference for _ in range(2): ref_mod.zero_grad() ref_out = ref_mod(x) ref_loss = loss_fn(ref_out, target) ref_loss.backward() # Create a pipeline chunks = 4 x_mb = x.chunk(chunks)[0] split_spec = mod.split_spec if hasattr(mod, "split_spec") else None pipe = pipeline( mod, mb_args=(x_mb,), split_spec=split_spec, ) stage = pipe.build_stage( self.rank, self.device, ) # Attach to a schedule schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn) # Run stage_module = pipe.get_stage_module(self.rank) for _ in range(2): # Zero gradients stage_module.zero_grad() if self.rank == 0: schedule.step(x) elif self.rank == self.world_size - 1: losses = [] out = schedule.step(target=target, losses=losses) else: schedule.step() dist.barrier() # Last rank checks result if self.rank == self.world_size - 1: # Check output torch.testing.assert_close(out, ref_out) # Check loss # Since the reduction used in the loss function above is "sum", we use # "sum" here to reduce microbatch losses into a single value too. pipe_loss = sum(losses) torch.testing.assert_close(pipe_loss, ref_loss) # Every rank checks gradients for name, p in stage_module.named_parameters(): ref_p = ref_mod.get_parameter(name) try: torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5) except AssertionError: print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}") raise @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) def test_grad_with_manual(self, ScheduleClass): full_mod = MultiMLP(d_hid, n_layers=self.world_size) full_mod.to(self.device) ref_mod = copy.deepcopy(full_mod) x = torch.randn(batch_size, d_hid, device=self.device) with torch.no_grad(): y = ref_mod(x) # Add a small perturbation target = y + torch.randn(batch_size, d_hid, device=self.device) loss_fn = torch.nn.MSELoss(reduction="sum") # Run reference for _ in range(2): ref_mod.zero_grad() ref_out = ref_mod(x) ref_loss = loss_fn(ref_out, target) ref_loss.backward() # Get a submodule, e.g. `layers.0` or `layers.1` submod_name = f"layers.{self.rank}" stage_module = full_mod.get_submodule(submod_name) chunks = 4 # Create a pipeline stage to wrap that submodule stage = PipelineStage( stage_module, self.rank, self.world_size, self.device, input_args=x.chunk(chunks)[0], ) # Attach to a schedule schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn) # Run for _ in range(2): # Zero gradients stage_module.zero_grad() if self.rank == 0: schedule.step(x) elif self.rank == self.world_size - 1: losses = [] out = schedule.step(target=target, losses=losses) else: schedule.step() dist.barrier() # Last rank checks result if self.rank == self.world_size - 1: # Check output torch.testing.assert_close(out, ref_out) # Check loss # Since the reduction used in the loss function above is "sum", we use # "sum" here to reduce microbatch losses into a single value too. pipe_loss = sum(losses) torch.testing.assert_close(pipe_loss, ref_loss) # Every rank checks gradients ref_submod = ref_mod.get_submodule(submod_name) for name, p in stage_module.named_parameters(): ref_p = ref_submod.get_parameter(name) try: torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5) except AssertionError: print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}") raise @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize( "ScheduleClass", [ScheduleInterleaved1F1B, ScheduleLoopedBFS, ScheduleInterleavedZeroBubble], ) @parametrize("use_new_runtime", [False, True]) def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime): stages_per_rank = 2 n_stages = stages_per_rank * self.world_size full_mod = MultiMLP(d_hid, n_layers=n_stages) full_mod.to(self.device) ref_mod = copy.deepcopy(full_mod) x = torch.randn(batch_size, d_hid, device=self.device) with torch.no_grad(): y = ref_mod(x) # Add a small perturbation target = y + torch.randn(batch_size, d_hid, device=self.device) loss_fn = torch.nn.MSELoss(reduction="sum") # Run reference for _ in range(2): ref_mod.zero_grad() ref_out = ref_mod(x) ref_loss = loss_fn(ref_out, target) ref_loss.backward() # Get a submodule, e.g. `layers.0` or `layers.1` stage_indices = [ self.rank + i * self.world_size for i in range(stages_per_rank) ] print(f"Rank {self.rank} stages: {stage_indices}") submod_names = [f"layers.{i}" for i in stage_indices] stage_modules = [ full_mod.get_submodule(submod_name) for submod_name in submod_names ] # Create a pipeline stage to wrap that submodule num_microbatches = ( ScheduleClass.num_microbatches if hasattr(ScheduleClass, "num_microbatches") else 8 ) input_args = x.chunk(num_microbatches)[0] stages = [ PipelineStage( stage_module, stage_idx, n_stages, self.device, input_args=input_args, ) for stage_module, stage_idx in zip(stage_modules, stage_indices) ] # Attach to a schedule schedule = ScheduleClass(stages, num_microbatches, loss_fn=loss_fn) if use_new_runtime: old_schedule = schedule tmp_schedule = _PipelineScheduleRuntime( stages, num_microbatches, loss_fn=loss_fn, stage_index_to_group_rank=old_schedule.stage_index_to_group_rank, use_full_backward=old_schedule.use_full_backward, ) tmp_schedule._load_actions(old_schedule.pipeline_order) # test that csv round-trip works for compute_comms schedule schedule = _PipelineScheduleRuntime( stages, num_microbatches, loss_fn=loss_fn, stage_index_to_group_rank=old_schedule.stage_index_to_group_rank, use_full_backward=old_schedule.use_full_backward, ) with tempfile.NamedTemporaryFile() as f: tmp_schedule._dump_csv(f.name) f.seek(0) schedule._load_csv(f.name, format="compute_comms") one_more_schedule = _PipelineScheduleRuntime( stages, num_microbatches, loss_fn=loss_fn, stage_index_to_group_rank=old_schedule.stage_index_to_group_rank, use_full_backward=old_schedule.use_full_backward, ) one_more_schedule._load_actions( schedule.pipeline_order_with_comms, format="compute_comms" ) self.assertEqual( len(schedule.pipeline_order_with_comms), len( one_more_schedule.pipeline_order_with_comms, ), ) for rank in schedule.pipeline_order_with_comms: self.assertEqual( len(schedule.pipeline_order_with_comms[rank]), len( one_more_schedule.pipeline_order_with_comms[rank], ), ) for a, b in zip( schedule.pipeline_order_with_comms[rank], one_more_schedule.pipeline_order_with_comms[rank], ): self.assertEqual(a, b) # Run for _ in range(2): # Zero gradients for stage_module in stage_modules: stage_module.zero_grad() if self.rank == 0: schedule.step(x) elif self.rank == self.world_size - 1: losses = [] out = schedule.step(target=target, losses=losses) else: schedule.step() dist.barrier() # Last rank checks result if self.rank == self.world_size - 1: # Check output torch.testing.assert_close(out, ref_out) # Check loss # Since the reduction used in the loss function above is "sum", we use # "sum" here to reduce microbatch losses into a single value too. pipe_loss = sum(losses) torch.testing.assert_close(pipe_loss, ref_loss) # Every rank checks gradients for stage_module, submod_name in zip(stage_modules, submod_names): # Get corresponding submodule from reference model ref_submod = ref_mod.get_submodule(submod_name) # Check gradients per parameter for name, p in stage_module.named_parameters(): ref_p = ref_submod.get_parameter(name) try: torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5) except AssertionError: print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}") raise @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleWithW, ScheduleFlexibleInterleaved1F1B]) def test_schedule_with_native_zero_bubble(self, ScheduleClass): print(ScheduleClass) if ScheduleClass is ScheduleFlexibleInterleaved1F1B: n_stages = 4 num_microbatches = 8 rank_stages = { 0: [0, 2], 1: [1, 3], } else: n_stages = ScheduleClass.n_stages num_microbatches = ScheduleClass.num_microbatches rank_stages = ScheduleClass.rank_stages num_steps = 4 full_mod = MultiMLP(d_hid, n_layers=n_stages) full_mod.to(self.device) ref_mod = copy.deepcopy(full_mod) x = torch.randn(batch_size, d_hid, device=self.device) # x = torch.randn(batch_size, d_hid, device=self.device, requires_grad=True) with torch.no_grad(): y = ref_mod(x) # Add a small perturbation target = y + torch.randn(batch_size, d_hid, device=self.device) loss_fn = torch.nn.MSELoss(reduction="sum") # Create a pipeline stage to wrap that submodule input_args = x.chunk(num_microbatches)[0] stage_indices = rank_stages[self.rank] print(f"Rank {self.rank} stages: {stage_indices}") submod_names = [f"layers.{i}" for i in stage_indices] stage_modules = [ full_mod.get_submodule(submod_name) for submod_name in submod_names ] stages = [ PipelineStage( stage_module, stage_idx, n_stages, self.device, input_args=input_args, ) for stage_module, stage_idx in zip(stage_modules, rank_stages[self.rank]) ] schedule = ScheduleClass( stages, num_microbatches, loss_fn=loss_fn, enable_zero_bubble=True ) # Run reference ref_x = x.clone().detach().requires_grad_(x.requires_grad) torch.testing.assert_close(x, ref_x) for _ in range(num_steps): ref_out = ref_mod(ref_x) ref_loss = loss_fn(ref_out, target) ref_loss.backward() # Run pipelined stages for _ in range(num_steps): if self.rank == 0: schedule.step(x) elif self.rank == self.world_size - 1: losses = [] out = schedule.step(target=target, losses=losses) else: schedule.step() # Every rank checks parameters compared with the reference model for stage_module, submod_name in zip(stage_modules, submod_names): # Get corresponding submodule from reference model ref_submod = ref_mod.get_submodule(submod_name) # Check gradients per parameter for name, p in stage_module.named_parameters(): ref_p = ref_submod.get_parameter(name) try: torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5) except AssertionError: print( f"Parameter test failed for {submod_name}.{name}: {p.grad} vs {ref_p.grad}" ) raise @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleVShaped, ScheduleUnbalanced]) def test_non_symmetric_stage_ids(self, ScheduleClass): n_stages = ScheduleClass.n_stages full_mod = MultiMLP(d_hid, n_layers=n_stages) full_mod.to(self.device) ref_mod = copy.deepcopy(full_mod) x = torch.randn(batch_size, d_hid, device=self.device) with torch.no_grad(): y = ref_mod(x) # Add a small perturbation target = y + torch.randn(batch_size, d_hid, device=self.device) loss_fn = torch.nn.MSELoss(reduction="sum") # Run reference for _ in range(2): ref_mod.zero_grad() ref_out = ref_mod(x) ref_loss = loss_fn(ref_out, target) ref_loss.backward() # Create a pipeline stage to wrap that submodule chunks = 1 input_args = x.chunk(chunks)[0] rank_stages = ScheduleClass.rank_stages stage_indices = rank_stages[self.rank] print(f"Rank {self.rank} stages: {stage_indices}") submod_names = [f"layers.{i}" for i in stage_indices] stage_modules = [ full_mod.get_submodule(submod_name) for submod_name in submod_names ] stages = [ PipelineStage( stage_module, stage_idx, n_stages, self.device, input_args=input_args, ) for stage_module, stage_idx in zip(stage_modules, rank_stages[self.rank]) ] # Attach to a schedule stage_index_to_group_rank = { value: key for key, values in rank_stages.items() for value in values } schedule = ScheduleClass( stages, chunks, stage_index_to_group_rank, loss_fn=loss_fn ) # Run # TODO how to better specify .step() when first and last stage are on rank 0... for _ in range(2): # Zero gradients for stage_module in stage_modules: stage_module.zero_grad() if self.rank == 0: losses = [] out = schedule.step(x, target=target, losses=losses) else: schedule.step() dist.barrier() # Last rank checks result if self.rank == 0: # Check output torch.testing.assert_close(out, ref_out) # Check loss # Since the reduction used in the loss function above is "sum", we use # "sum" here to reduce microbatch losses into a single value too. pipe_loss = sum(losses) torch.testing.assert_close(pipe_loss, ref_loss) # Every rank checks gradients for stage_module, submod_name in zip(stage_modules, submod_names): # Get corresponding submodule from reference model ref_submod = ref_mod.get_submodule(submod_name) # Check gradients per parameter for name, p in stage_module.named_parameters(): ref_p = ref_submod.get_parameter(name) try: torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5) except AssertionError: print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}") raise @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleFlexibleInterleaved1F1B]) def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass): stages_per_rank = 2 n_stages = stages_per_rank * self.world_size full_mod = MultiMLPWithDw(d_hid, n_layers=n_stages) full_mod.to(self.device) ref_mod = copy.deepcopy(full_mod) x = torch.randn(batch_size, d_hid, device=self.device) with torch.no_grad(): y = ref_mod(x) # Add a small perturbation target = y + torch.randn(batch_size, d_hid, device=self.device) ref_loss_fn = torch.nn.MSELoss(reduction="sum") full_loss_fn = torch.nn.MSELoss(reduction="sum") full_mod.toggle() # Get a submodule, e.g. `layers.0` or `layers.1` stage_indices = [ self.rank + i * self.world_size for i in range(stages_per_rank) ] submod_names = [f"layers.{i}" for i in stage_indices] stage_modules = [ full_mod.get_submodule(submod_name) for submod_name in submod_names ] # Run reference for _ in range(2): ref_stage_modules = [ ref_mod.get_submodule(submod_name) for submod_name in submod_names ] for stage_module in ref_stage_modules: stage_module.zero_grad() ref_mod.zero_grad() ref_out = ref_mod(x) ref_loss = ref_loss_fn(ref_out, target) ref_loss.backward() class CustomState: def __init__(self, stage_module, stage_idx, rank): self.i = 0 self.stage_module = stage_module self.stage_idx = stage_idx self.rank = rank def dw_builder(self): def dw_runner(): # This inner function would be called by PipelineStage during `backward_weight_one_chunk` self.i += 1 print( f"[Rank {self.rank}] dw_count={self.i} stage={self.stage_idx}" ) self.stage_module.compute_dW() return dw_runner cs = {} for stage_module, stage_idx in zip(stage_modules, stage_indices): cs[stage_idx] = CustomState(stage_module, stage_idx, self.rank) # Create a pipeline stage to wrap that submodule chunks = 2 input_args = x.chunk(chunks)[0] stages = [ PipelineStage( stage_module, stage_idx, n_stages, self.device, input_args=input_args, dw_builder=cs[stage_idx].dw_builder, ) for stage_module, stage_idx in zip(stage_modules, stage_indices) ] # Attach to a schedule schedule = ScheduleClass( stages, chunks, loss_fn=full_loss_fn, enable_zero_bubble=True ) for _ in range(2): # Zero gradients for stage_module in stage_modules: stage_module.zero_grad() if self.rank == 0: schedule.step(x) elif self.rank == self.world_size - 1: losses = [] out = schedule.step(target=target, losses=losses) else: schedule.step() dist.barrier() # Last rank checks result if self.rank == self.world_size - 1: # Check output torch.testing.assert_close(out, ref_out) # Check loss # Since the reduction used in the loss function above is "sum", we use # "sum" here to reduce microbatch losses into a single value too. pipe_loss = sum(losses) torch.testing.assert_close(pipe_loss, ref_loss) # Every rank checks gradients for stage_module, submod_name in zip(stage_modules, submod_names): # Get corresponding submodule from reference model ref_submod = ref_mod.get_submodule(submod_name) # Check gradients per parameter for name, p in stage_module.named_parameters(): ref_p = ref_submod.get_parameter(name) torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5) instantiate_parametrized_tests(ScheduleTest) if __name__ == "__main__": # Check if GPU and NCCL are available if not ( dist.is_available() and dist.is_nccl_available() and torch.cuda.device_count() > 1 ): print( "c10d NCCL not available or not enough GPUs, skipping tests", file=sys.stderr, ) sys.exit(0) rank = int(os.getenv("RANK", -1)) world_size = int(os.getenv("WORLD_SIZE", 2)) if rank != -1: # Launched with torchrun or other multi-proc launchers. Directly run the test. ScheduleTest.run_rank(rank, world_size) else: # Launched as a single process. Spawn subprocess to run the tests. # Also need a rendezvous file for `init_process_group` purpose. rdvz_file = tempfile.NamedTemporaryFile(delete=False).name torch.multiprocessing.spawn( ScheduleTest.run_rank, nprocs=world_size, args=(world_size, rdvz_file), )