• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Owner(s): ["oncall: distributed"]
2import torch
3import torch.distributed.checkpoint as dist_cp
4from torch.distributed._shard.sharded_tensor import ShardedTensor
5from torch.distributed._state_dict_utils import _all_gather_sharded_tensor
6from torch.distributed._tensor import DTensor, init_device_mesh, Replicate
7from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
8from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
9from torch.distributed.tensor.parallel import (
10    ColwiseParallel,
11    parallelize_module,
12    RowwiseParallel,
13)
14from torch.testing._internal.common_utils import run_tests
15from torch.testing._internal.distributed._tensor.common_dtensor import (
16    DTensorTestBase,
17    MLPModule,
18    skip_if_lt_x_gpu,
19    with_comms,
20)
21from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
22
23
# TODO: modularize this test and add tests for checkpoint conversion in both directions.
class TestFsdpTpCheckpointConversion(DTensorTestBase):
    """Checkpoint interoperability between FSDP and tensor-parallel (TP) models.

    An FSDP ``SHARDED_STATE_DICT`` checkpoint is saved with
    ``torch.distributed.checkpoint`` and loaded into a model parallelized with
    ``ColwiseParallel`` / ``RowwiseParallel``. After loading, the TP parameters
    must match the FSDP ones.
    """

    def _compare_fsdp_tp_state_dicts(
        self, fsdp_state_dict, tp_state_dict, device_mesh, *, expect_equal
    ):
        """Compare an FSDP state dict with a TP state dict entry by entry.

        Keys must be identical and in the same order. For each entry where
        the FSDP value is a ``ShardedTensor`` and the TP value is a
        ``DTensor``, both are gathered into full local tensors. They must be
        equal when ``expect_equal`` is True and different otherwise. Entries
        of any other type are checked by key only.
        """
        # Compare the complete key lists up front. Pairing the dicts with
        # zip() would stop at the shorter one and hide missing/extra entries.
        self.assertEqual(list(fsdp_state_dict), list(tp_state_dict))

        for key, fsdp_value in fsdp_state_dict.items():
            tp_value = tp_state_dict[key]
            if not (
                isinstance(fsdp_value, ShardedTensor)
                and isinstance(tp_value, DTensor)
            ):
                continue
            # Both gathers are collectives. Every rank iterates the same keys
            # in the same order, so the calls line up across ranks.
            fsdp_full = _all_gather_sharded_tensor(fsdp_value)
            tp_full = tp_value.redistribute(
                device_mesh, placements=[Replicate()]
            ).to_local()
            if expect_equal:
                self.assertEqual(fsdp_full, tp_full)
            else:
                self.assertNotEqual(fsdp_full, tp_full)

    @with_comms
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def test_fsdp_to_tp(self):
        """Save an FSDP sharded checkpoint and load it into a TP model."""
        checkpoint_dir = self.temp_dir

        # Create an FSDP-wrapped model and save its sharded state dict.
        model = MLPModule(self.device_type).cuda(self.rank)
        fsdp_model = FSDP(model, use_orig_params=True)
        FSDP.set_state_dict_type(
            fsdp_model,
            StateDictType.SHARDED_STATE_DICT,
        )
        fsdp_state_dict = fsdp_model.state_dict()
        dist_cp.save(
            state_dict=fsdp_state_dict,
            storage_writer=dist_cp.FileSystemWriter(checkpoint_dir),
        )

        # Create a TP model over a 1-D mesh spanning all ranks:
        # net1 is split column-wise and net2 row-wise.
        mesh_shape = (self.world_size,)
        device_mesh = init_device_mesh(self.device_type, mesh_shape)
        model = MLPModule(self.device_type).cuda(self.rank)
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        tp_model = parallelize_module(model, device_mesh, parallelize_plan)
        optimizer = torch.optim.SGD(tp_model.parameters(), lr=0.25)

        # Take one optimizer step so the TP parameters differ from the saved
        # FSDP ones. This makes the post-load equality check meaningful.
        torch.manual_seed(0)
        inp = torch.rand(20, 10).cuda(self.rank)
        output = tp_model(inp)
        output.sum().backward()
        optimizer.step()
        tp_state_dict = tp_model.state_dict()

        # Parameters must differ before loading.
        self._compare_fsdp_tp_state_dicts(
            fsdp_state_dict, tp_state_dict, device_mesh, expect_equal=False
        )

        # Load the FSDP checkpoint into the TP state dict in place, then apply
        # it to the model.
        dist_cp.load(
            state_dict=tp_state_dict,
            storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
        )
        tp_model.load_state_dict(tp_state_dict)

        # Re-read the state dict from the model, not from the dict we just
        # loaded into, so the check covers what the model actually holds.
        tp_state_dict_after_load = tp_model.state_dict()
        self._compare_fsdp_tp_state_dicts(
            fsdp_state_dict,
            tp_state_dict_after_load,
            device_mesh,
            expect_equal=True,
        )
103
if __name__ == "__main__":
    # Runs only when this file is executed directly, not when it is imported;
    # delegates to PyTorch's shared test entry point.
    run_tests()
106