# Owner(s): ["oncall: distributed"] from typing import Dict, Union import torch import torch.distributed as dist import torch.distributed.checkpoint as dist_cp from torch.distributed._tensor import ( DeviceMesh, distribute_tensor, DTensor, Replicate, Shard, zeros, ) from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, skip_if_lt_x_gpu, with_comms, ) from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir SUBMESH_TENSOR_SIZE = 6 class MyTestModule(torch.nn.Module): def __init__( self, sdt: DTensor, rdt: DTensor, submesh_sdt: DTensor, submesh_rdt: DTensor, extra_state: int = 1, extra_state_tensor: torch.Tensor = torch.zeros(1), ) -> None: super().__init__() self.sdt = torch.nn.Parameter(sdt) self.rdt = torch.nn.Parameter(rdt) self.submesh_sdt = torch.nn.Parameter(submesh_sdt) self.submesh_rdt = torch.nn.Parameter(submesh_rdt) self._extra_state = extra_state self._extra_state_tensor = extra_state_tensor @property def extra_state(self) -> int: return self._extra_state @extra_state.setter def extra_state(self, new_extra_state: int) -> None: self._extra_state = new_extra_state @property def extra_state_tensor(self) -> torch.Tensor: return self._extra_state_tensor @extra_state_tensor.setter def extra_state_tensor(self, new_extra_state_tensor: torch.Tensor) -> None: self._extra_state_tensor = new_extra_state_tensor def get_extra_state(self) -> Dict[str, Union[int, torch._tensor.Tensor]]: return { "extra_state": self._extra_state, "extra_state_tensor": self._extra_state_tensor, } def set_extra_state( self, state: Dict[str, Union[int, torch._tensor.Tensor]] ) -> None: self._extra_state = state["extra_state"] # pyre-ignore[8] self._extra_state_tensor = state["extra_state_tensor"] # pyre-ignore[8] class DTensorPlanner(DTensorTestBase): def create_dtensor_model( self, tensor_to_shard: torch.tensor, tensor_to_replicate: torch.tensor, ) -> torch.nn.Module: mesh = DeviceMesh( device_type=self.device_type, mesh=range(dist.get_world_size()), ) sharded_dt = distribute_tensor(tensor_to_shard, mesh, placements=[Shard(0)]) replicated_dt = distribute_tensor( tensor_to_replicate, mesh, placements=[Replicate()] ) # Only even rank will be part of the mesh. submesh = DeviceMesh( device_type=self.device_type, mesh=[i for i in range(dist.get_world_size()) if i % 2 == 0], ) submesh_tensor_size = [SUBMESH_TENSOR_SIZE] submesh_sharded_dt = zeros( submesh_tensor_size, device_mesh=submesh, placements=[Shard(0)], ) submesh_replicated_dt = zeros( submesh_tensor_size, device_mesh=submesh, placements=[Replicate()] ) model = MyTestModule( sharded_dt, replicated_dt, submesh_sharded_dt, submesh_replicated_dt, ).cuda() return ( model, sharded_dt, replicated_dt, ) @with_comms @with_temp_dir @skip_if_lt_x_gpu(2) def test_distributed_tensor_planner(self) -> None: CHECKPOINT_DIR = self.temp_dir local_tensor = torch.arange(0, 4, dtype=torch.float32) local_tensor_2 = torch.arange(4, 8, dtype=torch.float32) (model, sharded_dt, replicated_dt) = self.create_dtensor_model( local_tensor, local_tensor_2 ) state_dict = model.state_dict() """ When the model is initialized, the state_dict on rank 0 are as follows when there are 4 GPUs. rank 0: OrderedDict( [ ( 'rdt', DTensor( local_tensor=tensor([4., 5., 6., 7.], device='cuda:0'), device_mesh=DeviceMesh:([0, 1, 2, 3]), placements=[Replicate()] ) ), ( 'sdt', DTensor( local_tensor=tensor([0.], device='cuda:0'), device_mesh=DeviceMesh:([0, 1, 2, 3]), placements=[Shard(dim=0)]) ), ), ( 'submesh_sdt', DTensor( local_tensor=tensor([8., 9.], device='cuda:0'), device_mesh=DeviceMesh:([0, 2]), placements=[Shard(dim=0)] ), ), ( 'submesh_rdt', DTensor( local_tensor=tensor([12., 13., 14., 15.], device='cuda:0'), device_mesh=DeviceMesh:([0, 2]), placements=[Replicate()] ) ), ( '_extra_state', {'extra_state': 1, 'extra_state_tensor': tensor([0.])} ) ] ) """ dist_cp.save_state_dict( state_dict=state_dict, storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR), planner=dist_cp.DefaultSavePlanner(), ) model, _, _ = self.create_dtensor_model(local_tensor * 10, local_tensor_2 * 10) state_dict = model.state_dict() """ When the model is re-initialized, we have changed the params in state_dict. The updated values are as follows, when there are 4 GPUs: rank 0: OrderedDict( [ ( 'rdt', DTensor( local_tensor=tensor([40., 50., 60., 70.], device='cuda:0'), device_mesh=DeviceMesh:([0, 1, 2, 3]), placements=[Replicate()], ) ), ( 'sdt', DTensor( local_tensor=tensor([0.], device='cuda:0'), device_mesh=DeviceMesh:([0, 1, 2, 3]), placements=[Shard(dim=0)], ) ), ( 'submesh_sdt', DTensor( local_tensor=tensor([80., 90.], device='cuda:0'), device_mesh=DeviceMesh:([0, 2]), placements=[Shard(dim=0)] ) ), ('submesh_rdt', DTensor( local_tensor=tensor([120., 130., 140., 150.], device='cuda:0'), device_mesh=DeviceMesh:([0, 2]), placements=[Replicate()] ) ), ( '_extra_state', {'extra_state': 10, 'extra_state_tensor': tensor([10.])} ) ] ) """ dist_cp.load_state_dict( state_dict=state_dict, storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR), planner=dist_cp.DefaultLoadPlanner(), ) """ After loading the model from the checkpoint, we want to make sure that the values in state_dict match the values that are originally saved to the checkpoint. """ for k, v in state_dict.items(): if k == "sdt": self.assertEqual(sharded_dt.to_local(), v.to_local()) if k == "rdt": self.assertEqual(replicated_dt.to_local(), v.to_local()) if k == "submesh_sdt": if self.rank % 2 == 0: shard_size = int(SUBMESH_TENSOR_SIZE / v.device_mesh.size()) self.assertEqual(v.to_local().size(), torch.Size([shard_size])) self.assertEqual(v.to_local(), torch.zeros([shard_size])) else: self.assertEqual(v.to_local().size(), torch.Size([0])) self.assertEqual(v.to_local(), torch.tensor([])) if k == "submesh_rdt": if self.rank % 2 == 0: shard_size = SUBMESH_TENSOR_SIZE self.assertEqual(v.to_local().size(), torch.Size([shard_size])) self.assertEqual(v.to_local(), torch.zeros([shard_size])) else: self.assertEqual(v.to_local().size(), torch.Size([0])) self.assertEqual(v.to_local(), torch.tensor([])) if k == "_extra_state": self.assertEqual(1, v["extra_state"]) self.assertEqual(torch.tensor([0.0]), v["extra_state_tensor"]) if __name__ == "__main__": run_tests()