# Owner(s): ["oncall: distributed"]

import os
import sys

import torch
import torch.cuda
import torch.distributed as dist
import torch.distributed.algorithms._quantization.quantization as quant
from torch.distributed.algorithms._quantization.quantization import DQuantType
from torch.testing._internal.common_distributed import (
    init_multigpu_helper,
    MultiProcessTestCase,
    requires_gloo,
    requires_nccl,
    skip_if_lt_x_gpu,
    skip_if_rocm_multiprocess,
)
from torch.testing._internal.common_utils import (
    NO_MULTIPROCESSING_SPAWN,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_DEV_DBG_ASAN,
)


torch.backends.cuda.matmul.allow_tf32 = False

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)


def _build_tensor(size, value=None, dtype=torch.float, device_id=None):
    # Return a tensor of shape `size` filled with `value` (defaults to `size`),
    # optionally moved to the given CUDA device.
    if value is None:
        value = size
    if device_id is None:
        return torch.empty(size, dtype=dtype).fill_(value)
    else:
        return torch.empty(size, dtype=dtype).fill_(value).cuda(device_id)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

if NO_MULTIPROCESSING_SPAWN:
    print("Spawn not available, skipping tests.", file=sys.stderr)
    sys.exit(0)

BACKEND = os.environ["BACKEND"]

if BACKEND == "gloo" or BACKEND == "nccl":

    class DistQuantizationTests(MultiProcessTestCase):
        def setUp(self):
            super().setUp()
            self._spawn_processes()
            torch.backends.cudnn.flags(enabled=True, allow_tf32=False).__enter__()

        def tearDown(self):
            super().tearDown()
            try:
                os.remove(self.file_name)
            except OSError:
                pass

        @property
        def op_timeout_sec(self):
            return 1

        @property
        def world_size(self):
            return int(os.environ["WORLD_SIZE"])

        @requires_gloo()
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "gloo", "Only gloo backend supports all_gather_fp16"
        )
        def test_all_gather_fp16(self):
            store = dist.FileStore(self.file_name, self.world_size)
            dist.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="gloo"
            )
            device = torch.device(f"cuda:{self.rank}")
            group = list(range(0, self.world_size))
            group_id = dist.group.WORLD
            self._test_all_gather(
                group, group_id, self.rank, dtype=torch.float32, qtype=DQuantType.FP16
            )

        @requires_gloo()
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "gloo", "Only gloo backend supports all_gather_bfp16"
        )
        def test_all_gather_bfp16(self):
            store = dist.FileStore(self.file_name, self.world_size)
            dist.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="gloo"
            )
            device = torch.device(f"cuda:{self.rank}")
            group = list(range(0, self.world_size))
            group_id = dist.group.WORLD
            self._test_all_gather(
                group, group_id, self.rank, dtype=torch.float32, qtype=DQuantType.BFP16
            )

        @requires_nccl()
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only nccl backend supports all_to_all_fp16"
        )
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        @skip_if_rocm_multiprocess
        def test_all_to_all_fp16(self):
            store = dist.FileStore(self.file_name, self.world_size)
            dist.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="nccl"
            )
            device = torch.device(f"cuda:{self.rank}")
            group = list(range(0, self.world_size))
            group_id = dist.new_group(range(self.world_size))
            rank_to_GPU = init_multigpu_helper(self.world_size, BACKEND)
            self._test_all_to_all(
                group,
                group_id,
                self.rank,
                cuda=True,
                rank_to_GPU=rank_to_GPU,
                dtype=torch.float32,
                qtype=DQuantType.FP16,
            )

        @requires_nccl()
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only nccl backend supports all_to_all_bfp16"
        )
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        @skip_if_rocm_multiprocess
        def test_all_to_all_bfp16(self):
            store = dist.FileStore(self.file_name, self.world_size)
            dist.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="nccl"
            )
            device = torch.device(f"cuda:{self.rank}")
            group = list(range(0, self.world_size))
            group_id = dist.new_group(range(self.world_size))
            rank_to_GPU = init_multigpu_helper(self.world_size, BACKEND)
            self._test_all_to_all(
                group,
                group_id,
                self.rank,
                cuda=True,
                rank_to_GPU=rank_to_GPU,
                dtype=torch.float32,
                qtype=DQuantType.BFP16,
            )

        @requires_nccl()
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only nccl backend supports all_to_all_single_fp16"
        )
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        def test_all_to_all_single_fp16(self):
            store = dist.FileStore(self.file_name, self.world_size)
            dist.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="nccl"
            )
            device = torch.device(f"cuda:{self.rank}")
            group = list(range(0, self.world_size))
            group_id = dist.new_group(range(self.world_size))
            rank_to_GPU = init_multigpu_helper(self.world_size, BACKEND)
            self._test_all_to_all_single(
                group,
                group_id,
                self.rank,
                cuda=True,
                rank_to_GPU=rank_to_GPU,
                dtype=torch.float32,
                qtype=DQuantType.FP16,
            )

        @requires_nccl()
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only nccl backend supports all_to_all_single_bfp16"
        )
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        def test_all_to_all_single_bfp16(self):
            store = dist.FileStore(self.file_name, self.world_size)
            dist.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="nccl"
            )
            device = torch.device(f"cuda:{self.rank}")
            group = list(range(0, self.world_size))
            group_id = dist.new_group(range(self.world_size))
            rank_to_GPU = init_multigpu_helper(self.world_size, BACKEND)
            self._test_all_to_all_single(
                group,
                group_id,
                self.rank,
                cuda=True,
                rank_to_GPU=rank_to_GPU,
                dtype=torch.float32,
                qtype=DQuantType.BFP16,
            )

        def _test_all_gather(
            self,
            group,
            group_id,
            rank,
            cuda=False,
            rank_to_GPU=None,
            dtype=torch.float,
            qtype=None,
        ):
            # Each rank contributes a tensor filled with its own rank; the quantized
            # all_gather must reproduce the same result as the unquantized collective.
            for dest in group:
                tensor = _build_tensor([dest + 1, dest + 1], rank, dtype=dtype)
                tensors = [
                    _build_tensor([dest + 1, dest + 1], -1, dtype=dtype) for i in group
                ]
                expected_tensors = [
                    _build_tensor([dest + 1, dest + 1], i, dtype=dtype) for i in group
                ]
                if cuda:
                    tensor = tensor.cuda(rank_to_GPU[rank][0])
                    tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
                if tensors[0].dtype == torch.complex64:
                    tensor_shapes = [torch.view_as_real(tensors[0]).shape]
                else:
                    tensor_shapes = [tensors[0].shape]
                # auto_quantize wraps the collective so inputs are quantized to
                # `qtype` before communication and dequantized on the way out.
                allgather = quant.auto_quantize(dist.all_gather, qtype, quant_loss=None)
                allgather(tensors, tensor, group=group_id, async_op=False)
                for t1, t2 in zip(tensors, expected_tensors):
                    self.assertEqual(t1, t2)

        def _test_all_to_all(
            self,
            group,
            group_id,
            rank,
            cuda=False,
            rank_to_GPU=None,
            dtype=torch.float,
            qtype=None,
        ):
            # Rank r sends chunks filled with r; after the quantized all_to_all each
            # output chunk must be filled with the sending peer's rank.
            if group_id is not None:
                size = len(group)
                in_splits = [i + 1 for i in group]
                in_tensors = [
                    torch.ones([in_splits[i], size], dtype=dtype) * rank
                    for i, _ in enumerate(group)
                ]
                out_tensors = [
                    torch.ones([(rank + 1), size], dtype=dtype) for _ in group
                ]
                expected_tensors = [
                    torch.ones([rank + 1, size], dtype=dtype) * i for i in group
                ]
                if cuda:
                    in_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in in_tensors]
                    expected_tensors = [
                        t.cuda(rank_to_GPU[rank][0]) for t in expected_tensors
                    ]
                    out_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in out_tensors]
                quantize_alltoall = quant.auto_quantize(
                    dist.all_to_all, qtype,
                    quant_loss=None
                )
                quantize_alltoall(out_tensors, in_tensors, group=group_id)
                for t1, t2 in zip(out_tensors, expected_tensors):
                    self.assertEqual(t1, t2)

        def _test_all_to_all_single(
            self,
            group,
            group_id,
            rank,
            cuda=False,
            rank_to_GPU=None,
            dtype=torch.float,
            qtype=DQuantType.FP16,
        ):
            # Same exchange as _test_all_to_all, but through all_to_all_single with
            # explicit input/output split sizes on a single flattened tensor.
            if group_id is not None:
                size = len(group)
                in_splits = [i + 1 for i in group]
                out_splits = [rank + 1 for _ in group]
                in_tensor = torch.ones([sum(in_splits), size], dtype=dtype) * rank
                out_tensor = torch.ones([(rank + 1) * size, size], dtype=dtype)
                expected_tensor = torch.cat(
                    [torch.ones([rank + 1, size], dtype=dtype) * i for i in group]
                )
                if cuda:
                    rank_to_GPU = rank_to_GPU[rank][0]
                    in_tensor = in_tensor.cuda(rank_to_GPU)
                    expected_tensor = expected_tensor.cuda(rank_to_GPU)
                    out_tensor = out_tensor.cuda(rank_to_GPU)
                quantize_alltoall_single = quant.auto_quantize(
                    dist.all_to_all_single, qtype, quant_loss=None
                )
                quantize_alltoall_single(
                    out_tensor,
                    in_tensor,
                    out_splits=out_splits,
                    in_splits=in_splits,
                    group=group_id,
                )
                self.assertEqual(out_tensor, expected_tensor)


if __name__ == "__main__":
    run_tests()