1# Owner(s): ["oncall: distributed"] 2import copy 3import io 4 5import torch 6import torch.distributed as dist 7import torch.distributed._functional_collectives as funcol 8from torch.distributed._state_dict_utils import ( 9 _check_state_dict_similarity, 10 _copy_state_dict, 11 _create_cpu_state_dict, 12 _distribute_tensors, 13 _gather_state_dict, 14 _offload_state_dict_to_cpu, 15) 16from torch.distributed._tensor import ( 17 distribute_tensor, 18 DTensor, 19 init_device_mesh, 20 Shard, 21) 22from torch.testing._internal.common_utils import run_tests 23from torch.testing._internal.distributed._tensor.common_dtensor import ( 24 DTensorTestBase, 25 skip_if_lt_x_gpu, 26 with_comms, 27) 28 29 30class TestStateDictUtils(DTensorTestBase): 31 @property 32 def world_size(self): 33 return min(4, torch.cuda.device_count()) 34 35 @with_comms 36 @skip_if_lt_x_gpu(2) 37 def test_gather_state_dict_dtensor(self): 38 device_mesh = self.build_device_mesh() 39 shard_spec = [Shard(0)] 40 torch.random.manual_seed(dist.get_rank()) 41 local_tensor = torch.randn(3, 3, 3) 42 dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec) 43 state_dict = {"dtensor": dist_tensor} 44 45 gathered_state_dict = _gather_state_dict(state_dict) 46 expected_gathered_dtensor = funcol.all_gather_tensor( 47 dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0) 48 ) 49 self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"]) 50 self.assertTrue(gathered_state_dict["dtensor"].is_cuda) 51 52 @with_comms 53 @skip_if_lt_x_gpu(4) 54 def test_gather_with_cpu_and_ranks_only(self): 55 device_mesh = self.build_device_mesh() 56 shard_spec = [Shard(0)] 57 torch.random.manual_seed(dist.get_rank()) 58 local_tensor = torch.randn(3, 3, 3) 59 dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec) 60 state_dict = {"dtensor": dist_tensor} 61 62 gathered_state_dict = _gather_state_dict( 63 state_dict, cpu_offload=True, ranks_only=(0, 2) 64 ) 65 expected_gathered_dtensor = funcol.all_gather_tensor( 66 dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0) 67 ) 68 if dist.get_rank() in (0, 2): 69 self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"]) 70 self.assertFalse(gathered_state_dict["dtensor"].is_cuda) 71 else: 72 self.assertEqual(gathered_state_dict, {}) 73 74 @with_comms 75 @skip_if_lt_x_gpu(4) 76 def test_cpu_and_ranks_only(self): 77 device = torch.device("cuda") 78 state_dict = { 79 "tensor1": torch.arange(10, device=device), 80 "tensor2": torch.ones(10, device=device), 81 } 82 83 cpu_state_dict = _offload_state_dict_to_cpu(state_dict, ranks_only=(0, 2)) 84 if dist.get_rank() in (0, 2): 85 for v in cpu_state_dict.values(): 86 self.assertFalse(v.is_cuda) 87 self.assertEqual(cpu_state_dict["tensor1"], torch.arange(10)) 88 self.assertEqual(cpu_state_dict["tensor2"], torch.ones(10)) 89 else: 90 self.assertEqual(cpu_state_dict, {}) 91 92 @with_comms 93 @skip_if_lt_x_gpu(4) 94 def test_complicated_dict(self): 95 def create_dtensor(): 96 device_mesh = self.build_device_mesh() 97 shard_spec = [Shard(0)] 98 torch.random.manual_seed(dist.get_rank()) 99 local_tensor = torch.randn(3, 3, 3) 100 dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec) 101 tensor = funcol.all_gather_tensor( 102 dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0) 103 ) 104 return tensor, dist_tensor 105 106 ltensor, ldtensor = [], [] 107 for i in range(10): 108 tensor, dtensor = create_dtensor() 109 ltensor.append(tensor) 110 ltensor.append(torch.ones(10, device=torch.device("cuda"))) 111 ldtensor.append(dtensor) 112 ldtensor.append(torch.ones(10, device=torch.device("cuda"))) 113 114 tensor, dtensor = create_dtensor() 115 dist_state_dict = { 116 "local": dtensor, 117 "list": ldtensor, 118 "arange": torch.arange(10, device=torch.device("cuda")), 119 } 120 state_dict = { 121 "local": tensor, 122 "list": ltensor, 123 "arange": torch.arange(10, device=torch.device("cuda")), 124 } 125 self.assertEqual(state_dict, _gather_state_dict(dist_state_dict)) 126 127 @skip_if_lt_x_gpu(2) 128 def test_create_cpu_state_dict(self): 129 device = torch.device("cuda") 130 buffer = io.BytesIO() 131 torch.save(torch.ones(10), buffer) 132 buffer.seek(0) 133 state_dict = { 134 "tensor1": torch.arange(10, device=device), 135 "tensor2": torch.ones(10, device=device), 136 "non_tensor_bytes_io": copy.deepcopy(buffer), 137 "non_tensor_bytes": buffer.read(), 138 "step": torch.tensor(7, dtype=torch.float), 139 "lr": 1.5, 140 "nested": {"list": [1, 2, 3, 4]}, 141 } 142 143 def _verify(cpu_state_dict): 144 # Verify the correctness of _check_state_dict_similarity() 145 self.assertTrue(_check_state_dict_similarity(state_dict, cpu_state_dict)) 146 tensor1 = cpu_state_dict["tensor1"] 147 cpu_state_dict["tensor1"] = torch.arange(11) 148 self.assertFalse(_check_state_dict_similarity(state_dict, cpu_state_dict)) 149 cpu_state_dict["tensor1"] = tensor1 150 151 _copy_state_dict(state_dict, cpu_state_dict) 152 153 # Verify if _copy_state_dict works 154 for v in cpu_state_dict.values(): 155 if isinstance(v, torch.Tensor): 156 self.assertFalse(v.is_cuda) 157 self.assertEqual(cpu_state_dict["tensor1"], torch.arange(10)) 158 self.assertEqual(cpu_state_dict["tensor2"], torch.ones(10)) 159 buffer.seek(0) 160 cpu_state_dict["non_tensor_bytes_io"].seek(0) 161 self.assertEqual( 162 cpu_state_dict["non_tensor_bytes_io"].read(), buffer.read() 163 ) 164 buffer.seek(0) 165 self.assertEqual(cpu_state_dict["non_tensor_bytes"], buffer.read()) 166 self.assertEqual(cpu_state_dict["lr"], 1.5) 167 self.assertEqual(cpu_state_dict["step"], 7) 168 self.assertEqual(cpu_state_dict["nested"], {"list": [1, 2, 3, 4]}) 169 170 cpu_state_dict = _create_cpu_state_dict(state_dict, pin_memory=True) 171 _verify(cpu_state_dict) 172 cpu_state_dict = _create_cpu_state_dict(state_dict, share_memory=True) 173 _verify(cpu_state_dict) 174 cpu_state_dict = _create_cpu_state_dict( 175 state_dict, share_memory=True, pin_memory=True 176 ) 177 _verify(cpu_state_dict) 178 179 @with_comms 180 @skip_if_lt_x_gpu(2) 181 def test_state_dict_util_distribute_tensors(self): 182 even_tensor = torch.randn(self.world_size, 2) 183 uneven_tensor = torch.randn(1, 2) 184 185 mesh = init_device_mesh("cuda", mesh_shape=(self.world_size,)) 186 even_dtensor = distribute_tensor( 187 torch.randn(self.world_size, 2), mesh, [Shard(0)] 188 ) 189 uneven_dtensor = distribute_tensor(torch.randn(1, 2), mesh, [Shard(0)]) 190 191 # the dtensor and tensor are different before _distribute_tensors is called. 192 local_state_dict = { 193 "even": [even_dtensor, even_tensor], 194 "uneven": [uneven_dtensor, uneven_tensor], 195 } 196 ref_local_state_dict = copy.deepcopy(local_state_dict) 197 keys = ["even", "uneven"] 198 199 _distribute_tensors(local_state_dict, keys, self.device_type) 200 for local_v, ref_v in zip( 201 local_state_dict.values(), ref_local_state_dict.values() 202 ): 203 self.assertEqual(local_v.size(), ref_v[0].size()) 204 self.assertEqual(local_v.stride(), ref_v[0].stride()) 205 self.assertNotEqual( 206 local_v_full_tensor := local_v.full_tensor(), ref_v[0].full_tensor() 207 ) 208 self.assertEqual(local_v_full_tensor, ref_v[1]) 209 210 211if __name__ == "__main__": 212 run_tests() 213