# Owner(s): ["oncall: distributed"]

from copy import deepcopy

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed._tensor import init_device_mesh
from torch.distributed.checkpoint.default_planner import (
    DefaultLoadPlanner,
    DefaultSavePlanner,
)
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    skip_if_lt_x_gpu,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


class UnevenShardedModel(torch.nn.Module):
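    """A small model whose layer sizes do not divide evenly across tensor-parallel
    ranks, used to exercise checkpointing of unevenly sharded parameters."""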
    def __init__(self, device):
        super().__init__()
        torch.manual_seed(5)
        self.net1 = torch.nn.Linear(5, 10, device=device)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 15, device=device)
        self.net3 = torch.nn.Linear(15, 1, device=device)

    def forward(self, x):
        return self.net3(self.net2(self.relu(self.net1(x))))


class TestTpCheckpoint(DTensorTestBase):
    @with_comms
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def test_tp_checkpoint(self):
        CHECKPOINT_DIR = self.temp_dir
        mesh_shape = (self.world_size,)
        tp_mesh = init_device_mesh(self.device_type, mesh_shape)

        # Create the model and move it to the GPU for this rank.
        model = MLPModule(self.device_type).cuda(self.rank)
        # Parallelize the module according to the given parallel styles.
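        # ColwiseParallel shards net1 along its output dimension and RowwiseParallel
        # shards net2 along its input dimension, so the intermediate activation stays
        # sharded across the tensor-parallel ranks.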
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        model = parallelize_module(model, tp_mesh, parallelize_plan)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.25)
        original_state_dict = deepcopy(model.state_dict())

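        # dcp.save is collective: each rank writes only its local DTensor shards to
        # CHECKPOINT_DIR. DefaultSavePlanner is also the default, so passing it is explicit.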
        dcp.save(
            state_dict=original_state_dict,
            storage_writer=dcp.FileSystemWriter(CHECKPOINT_DIR),
            planner=DefaultSavePlanner(),
        )

        # Update the parameters so model.state_dict() will be different from original_state_dict.
        torch.manual_seed(0)
        inp = torch.rand(20, 10).cuda(self.rank)
        output = model(inp)
        output.sum().backward()
        optimizer.step()
        state_dict = model.state_dict()

        # Ensure the current model parameters are different from original_state_dict before loading from the checkpoint.
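        # to_local() returns each DTensor's rank-local shard, so the comparison is done shard-by-shard.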
        for param1, param2 in zip(original_state_dict.values(), state_dict.values()):
            self.assertNotEqual(param1.to_local(), param2.to_local())

        dcp.load(
            state_dict=state_dict,
            storage_reader=dcp.FileSystemReader(CHECKPOINT_DIR),
            planner=DefaultLoadPlanner(),
        )

        # After loading from the checkpoint, the current model parameters should match original_state_dict again.
        for param1, param2 in zip(original_state_dict.values(), state_dict.values()):
            self.assertEqual(param1.to_local(), param2.to_local())

    @with_comms
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def test_tp_checkpoint_load_on_meta_device(self):
        CHECKPOINT_DIR = self.temp_dir
        mesh_shape = (self.world_size,)
        tp_mesh = init_device_mesh(self.device_type, mesh_shape)

        # Create the model and move it to the GPU for this rank.
        model = UnevenShardedModel(self.device_type).cuda(self.rank)
        # Parallelize the module according to the given parallel styles.
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
            "net3": ColwiseParallel(),
        }
        model = parallelize_module(model, tp_mesh, parallelize_plan=parallelize_plan)
        original_state_dict = {
            "model": model.state_dict(),
        }

        dcp.save(
            state_dict=original_state_dict,
            storage_writer=dcp.FileSystemWriter(CHECKPOINT_DIR),
        )

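        # Build a second copy of the model on the meta device (its parameters have no storage);
        # dcp.load should materialize real, non-meta DTensors into its state dict in place.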
        model2 = parallelize_module(
            UnevenShardedModel("meta"), tp_mesh, parallelize_plan=parallelize_plan
        )
        model2_sd_before_load = model2.state_dict()
        state_dict_to_load = {"model": model2_sd_before_load}

        dcp.load(
            state_dict=state_dict_to_load,
            storage_reader=dcp.FileSystemReader(CHECKPOINT_DIR),
        )
        # Loading is in-place, so state_dict_to_load["model"] must still be the very same
        # object as model2_sd_before_load.
        self.assertTrue(state_dict_to_load["model"] is model2_sd_before_load)

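        # assign=True makes load_state_dict assign the loaded tensors to the module instead of
        # copying into the existing meta parameters, which have no storage to copy into.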
        model2.load_state_dict(state_dict_to_load["model"], assign=True)
        state_dict_after_load = {"model": model2.state_dict()}

        self.assertEqual(
            len(original_state_dict["model"]), len(state_dict_to_load["model"])
        )
        self.assertEqual(
            len(original_state_dict["model"]), len(state_dict_after_load["model"])
        )

        for name, param in original_state_dict["model"].items():
            param_to_load = state_dict_to_load["model"][name]
            param_after_load = state_dict_after_load["model"][name]

            # We explicitly check that the device is not meta, since assertEqual
            # currently doesn't handle DTensors on the meta device.
            self.assertTrue(not param_to_load.is_meta)
            self.assertTrue(not param_after_load.is_meta)
            self.assertEqual(param.to_local(), param_to_load.to_local())
            self.assertEqual(param.to_local(), param_after_load.to_local())


if __name__ == "__main__":
    run_tests()