# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils
from torch import distributed as dist
from torch.distributed.fsdp import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    NestedWrappedModule,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestPureFP16(FSDPTest):
    @property
    def world_size(self):
        # Test fails due to inaccuracies when using more than 4 GPUs
        return min(4, super().world_size)

    @skip_if_lt_x_gpu(2)
    def test_pure_fp16_training(self):
        """Tests pure FP16 training, including when the parameters' dtype is
        changed after FSDP initialization and before training."""
        self.run_subtests(
            {
                "cpu_offload": [
                    CPUOffload(offload_params=True),
                    CPUOffload(offload_params=False),
                ]
            },
            self._test_pure_fp16_training,
        )

    def _test_pure_fp16_training(self, cpu_offload: CPUOffload):
        self._test_fsdp_parity(
            NestedWrappedModule,
            FSDPInitMode.RECURSIVE,
            cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
            # Run only one iteration to avoid NaNs, since pure FP16 training
            # here does not use a gradient scaler
            num_iters=1,
            cpu_offload=cpu_offload,
            use_pure_fp16=True,
        )
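
    # For context, a minimal sketch of the "pure FP16" pattern exercised
    # above and in `test_fp16_dtypes` below (hypothetical user code, not run
    # by these tests):
    #
    #   model = MyModel().half()   # cast the model's parameters to FP16
    #   fsdp_model = FSDP(model)   # wrap the already-halved model
    #
    # This differs from passing a `MixedPrecision` config to FSDP, which
    # keeps the original parameters in full precision and casts only for
    # forward/backward compute.
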
    @skip_if_lt_x_gpu(2)
    def test_fp16_dtypes(self):
        """
        Tests that both user-facing parameter/gradient dtypes and internal
        saved dtype attributes are as expected when using an FP16 model,
        possibly with explicit mixed precision enabled.
        """
        self.run_subtests(
            {
                "to_half_before_fsdp_init": [False, True],
                "use_orig_params": [False, True],
                "mixed_precision": [
                    MixedPrecision(),
                    MixedPrecision(
                        param_dtype=torch.float16,
                        reduce_dtype=torch.float32,
                    ),
                    MixedPrecision(
                        param_dtype=torch.float32,
                    ),
                ],
            },
            self._test_fp16_dtypes,
        )

    def _test_fp16_dtypes(
        self,
        to_half_before_fsdp_init: bool,
        use_orig_params: bool,
        mixed_precision: MixedPrecision,
    ):
        model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_NEVER,
            {},
        )
        fsdp_kwargs = {
            "use_orig_params": use_orig_params,
            "device_id": torch.cuda.current_device(),
            "mixed_precision": mixed_precision,
        }
        if to_half_before_fsdp_init:
            model = model.half()
        fsdp_model = FSDP(model, **fsdp_kwargs)
        if not to_half_before_fsdp_init:
            fsdp_model = fsdp_model.half()
        for param in fsdp_model.parameters():
            self.assertEqual(param.dtype, torch.float16)
        inp = tuple(
            t.half() if torch.is_tensor(t) else t
            for t in fsdp_model.module.get_input(torch.device("cuda"))
        )
        out = fsdp_model(*inp)
        out.sum().backward()

        # Check the handles' saved dtype attributes
        for handle in traversal_utils._get_fsdp_handles(fsdp_model):
            self.assertEqual(handle.flat_param.dtype, torch.float16)
            self.assertEqual(handle.flat_param.grad.dtype, torch.float16)
            self.assertEqual(handle._orig_param_dtype, torch.float16)
            # Specifying `mixed_precision` takes precedence over the model
            # dtype for both `param_dtype` and `reduce_dtype`
            if mixed_precision.param_dtype is not None:
                self.assertEqual(
                    handle._fwd_bwd_param_dtype, mixed_precision.param_dtype
                )
            else:
                self.assertEqual(handle._fwd_bwd_param_dtype, torch.float16)
            if mixed_precision.reduce_dtype is not None:
                self.assertEqual(handle._reduce_dtype, mixed_precision.reduce_dtype)
            elif mixed_precision.param_dtype is not None:
                # Special case: infer the reduce dtype from the parameter
                # dtype when only the latter is specified
                self.assertEqual(handle._reduce_dtype, mixed_precision.param_dtype)
            else:
                self.assertEqual(handle._reduce_dtype, torch.float16)

        # Check the user-facing parameter/gradient dtypes
        for param in fsdp_model.parameters():
            self.assertEqual(param.dtype, torch.float16)
            if param.grad is not None:
                self.assertEqual(param.grad.dtype, torch.float16)


instantiate_parametrized_tests(TestPureFP16)

if __name__ == "__main__":
    run_tests()
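
# Note: each test in `TestPureFP16` spawns `world_size` processes (one per
# rank), and the `@skip_if_lt_x_gpu(2)` decorator skips the tests on machines
# with fewer than two GPUs. `run_tests()` above is the entry point when this
# file is launched as a script.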