# Owner(s): ["oncall: distributed"]
import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed._tensor import (
    distribute_tensor,
    init_device_mesh,
    Replicate,
    Shard,
    zeros,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


CHECKPOINT_DIR = "checkpoint"

ONE_D_PLACEMENTS = [
    [Shard(0)],
    [Replicate()],
]
ONE_D_TO_ONE_D_PLACEMENTS = [
    ([Replicate()], [Shard(0)]),
    ([Shard(0)], [Replicate()]),
]

TWO_D_PLACEMENTS = [
    [Replicate(), Replicate()],
    [Replicate(), Shard(0)],
    [Shard(0), Replicate()],
    [Shard(0), Shard(0)],
]
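# Every ordered pair of distinct 2-D placements, so each reshard direction is exercised.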
TWO_D_TO_TWO_D_PLACEMENTS = []
for p1 in TWO_D_PLACEMENTS:
    for p2 in TWO_D_PLACEMENTS:
        if p1 != p2:
            TWO_D_TO_TWO_D_PLACEMENTS.append((p1, p2))


class TestDTensorReshardPlacementChange(DTensorTestBase):
    """
    Test DCP resharding for DTensor when placements change while the world size and device mesh stay the same.
    """

    @with_comms
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def test_1d_to_1d_reshard_placement_change(self) -> None:
        CHECKPOINT_DIR = self.temp_dir

        for one_d_to_one_d_placements in ONE_D_TO_ONE_D_PLACEMENTS:
            original_placement, new_placement = one_d_to_one_d_placements

            global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
            mesh_shape = (self.world_size,)
            device_mesh = init_device_mesh(self.device_type, mesh_shape)
            dtensor = distribute_tensor(
                global_tensor, device_mesh, placements=original_placement
            )
            state_dict_to_save = {"dtensor": dtensor}

            dist_cp.save_state_dict(
                state_dict=state_dict_to_save,
                storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
                planner=dist_cp.DefaultSavePlanner(),
            )

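            # Load into a zero-initialized DTensor that uses the new placement, so DCP
            # must reshard the saved data on load instead of copying it shard-for-shard.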
            zero_dtensor = zeros(
                [4, 4], device_mesh=device_mesh, placements=new_placement
            )
            state_dict_to_load = {"dtensor": zero_dtensor}

            dist_cp.load_state_dict(
                state_dict=state_dict_to_load,
                storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
                planner=dist_cp.DefaultLoadPlanner(),
            )

            # materialize the whole tensor to compare with the original global_tensor
            state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
                device_mesh,
                placements=[Replicate()],
            )
            self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local())

            # redistribute the tensor back to its original placement for comparison.
            state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
                device_mesh,
                placements=original_placement,
            )
            self.assertEqual(
                state_dict_to_save["dtensor"].to_local(),
                state_dict_to_load["dtensor"].to_local(),
            )

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_2d_to_2d_reshard_placement_change(self) -> None:
        CHECKPOINT_DIR = self.temp_dir
        for two_d_to_two_d_placements in TWO_D_TO_TWO_D_PLACEMENTS:
            original_placement, new_placement = two_d_to_two_d_placements

            global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
            mesh_shape = (2, self.world_size // 2)
            mesh_2d = init_device_mesh(self.device_type, mesh_shape)
            dtensor = distribute_tensor(
                global_tensor,
                mesh_2d,
                placements=original_placement,
            )
            state_dict_to_save = {"dtensor": dtensor}

            dist_cp.save_state_dict(
                state_dict=state_dict_to_save,
                storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
                planner=dist_cp.DefaultSavePlanner(),
            )

            zero_dtensor = zeros([4, 4], device_mesh=mesh_2d, placements=new_placement)
            state_dict_to_load = {"dtensor": zero_dtensor}

            dist_cp.load_state_dict(
                state_dict=state_dict_to_load,
                storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
                planner=dist_cp.DefaultLoadPlanner(),
            )

            state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
                mesh_2d,
                placements=[Replicate(), Replicate()],
            )
            self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local())

            state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
                mesh_2d,
                placements=original_placement,
            )
            self.assertEqual(
                state_dict_to_save["dtensor"].to_local(),
                state_dict_to_load["dtensor"].to_local(),
            )


class TestDTensorReshardMeshChange(DTensorTestBase):
    """
    Test DCP resharding for DTensor when both the placements and the device mesh change.
    """

    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(2)
    def test_1d_to_2d_reshard_mesh_change(self) -> None:
        CHECKPOINT_DIR = self.temp_dir
        for placements_1d in ONE_D_PLACEMENTS:
            global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
            mesh_shape = (self.world_size,)
            mesh_1d = init_device_mesh(self.device_type, mesh_shape)
            dtensor = distribute_tensor(
                global_tensor, mesh_1d, placements=placements_1d
            )
            state_dict_to_save = {"dtensor": dtensor}

            dist_cp.save_state_dict(
                state_dict=state_dict_to_save,
                storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
                planner=dist_cp.DefaultSavePlanner(),
            )

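            # Load the checkpoint saved from the 1-D mesh back onto a 2-D mesh,
            # trying every 2-D placement.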
            for placements_2d in TWO_D_PLACEMENTS:
                mesh_shape = (2, self.world_size // 2)
                mesh_2d = init_device_mesh(self.device_type, mesh_shape)

                zero_dtensor = zeros(
                    [4, 4], device_mesh=mesh_2d, placements=placements_2d
                )
                state_dict_to_load = {"dtensor": zero_dtensor}

                dist_cp.load_state_dict(
                    state_dict=state_dict_to_load,
                    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
                    planner=dist_cp.DefaultLoadPlanner(),
                )

                # materialize the whole tensor to compare with the original global_tensor
                state_dict_to_load["dtensor"] = state_dict_to_load[
                    "dtensor"
                ].redistribute(
                    mesh_2d,
                    placements=[Replicate(), Replicate()],
                )
                self.assertEqual(
                    global_tensor, state_dict_to_load["dtensor"].to_local()
                )

    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(4)
    def test_2d_to_1d_reshard_mesh_change(self) -> None:
        CHECKPOINT_DIR = self.temp_dir
        for placements_2d in TWO_D_PLACEMENTS:
            global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
            mesh_shape = (2, self.world_size // 2)
            mesh_2d = init_device_mesh(self.device_type, mesh_shape)
            dtensor = distribute_tensor(
                global_tensor, mesh_2d, placements=placements_2d
            )
            state_dict_to_save = {"dtensor": dtensor}

            dist_cp.save_state_dict(
                state_dict=state_dict_to_save,
                storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
                planner=dist_cp.DefaultSavePlanner(),
            )

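            # Load the checkpoint saved from the 2-D mesh back onto a 1-D mesh,
            # trying every 1-D placement.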
            for placements_1d in ONE_D_PLACEMENTS:
                mesh_shape = (self.world_size,)
                mesh_1d = init_device_mesh(self.device_type, mesh_shape)

                zero_dtensor = zeros(
                    [4, 4], device_mesh=mesh_1d, placements=placements_1d
                )
                state_dict_to_load = {"dtensor": zero_dtensor}

                dist_cp.load_state_dict(
                    state_dict=state_dict_to_load,
                    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
                    planner=dist_cp.DefaultLoadPlanner(),
                )

                # materialize the whole tensor to compare with the original global_tensor
                state_dict_to_load["dtensor"] = state_dict_to_load[
                    "dtensor"
                ].redistribute(
                    mesh_1d,
                    placements=[Replicate()],
                )
                self.assertEqual(
                    global_tensor, state_dict_to_load["dtensor"].to_local()
                )

    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(2)
    def test_dtensor_checkpoint_resharding_with_empty_shard(self):
        """
        Test dtensor checkpoint resharding with dtensor containing empty shards.
        """
        tensor = torch.rand(1).cuda()
        mesh = init_device_mesh(self.device_type, (self.world_size,))
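        # Sharding a single-element tensor on dim 0 across world_size ranks leaves an
        # empty local shard on every rank except one.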
        dtensor = distribute_tensor(tensor, mesh, [Shard(0)])
        ref_state_dict = {"dtensor": dtensor}

        dist_cp.save_state_dict(
            state_dict=ref_state_dict,
            storage_writer=dist_cp.FileSystemWriter(path=self.temp_dir),
        )

        tensor = torch.rand(1).cuda()
        mesh_2 = init_device_mesh(self.device_type, (2, self.world_size // 2))
        dtensor = distribute_tensor(tensor, mesh_2, [Shard(0), Shard(0)])
        state_dict = {"dtensor": dtensor}
        dist_cp.load_state_dict(
            state_dict=state_dict,
            storage_reader=dist_cp.FileSystemReader(self.temp_dir),
        )

    # TODO: Add an assertEqual for ref_state_dict["dtensor"].full_tensor()
    # and state_dict["dtensor"].full_tensor() after we fix the size mismatch
    # issue for uneven sharding dtensor.


# TODO: Add dtensor resharding test when world size changes.
if __name__ == "__main__":
    run_tests()