• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Owner(s): ["oncall: distributed"]
2import copy
3import io
4
5import torch
6import torch.distributed as dist
7import torch.distributed._functional_collectives as funcol
8from torch.distributed._state_dict_utils import (
9    _check_state_dict_similarity,
10    _copy_state_dict,
11    _create_cpu_state_dict,
12    _distribute_tensors,
13    _gather_state_dict,
14    _offload_state_dict_to_cpu,
15)
16from torch.distributed._tensor import (
17    distribute_tensor,
18    DTensor,
19    init_device_mesh,
20    Shard,
21)
22from torch.testing._internal.common_utils import run_tests
23from torch.testing._internal.distributed._tensor.common_dtensor import (
24    DTensorTestBase,
25    skip_if_lt_x_gpu,
26    with_comms,
27)
28
29
class TestStateDictUtils(DTensorTestBase):
    """Multi-process tests for the private helpers in ``torch.distributed._state_dict_utils``."""

    @property
    def world_size(self):
        # One rank per visible CUDA device, capped at 4 ranks.
        return min(4, torch.cuda.device_count())

    def _create_sharded_dtensor(self) -> DTensor:
        """Wrap a rank-seeded ``randn(3, 3, 3)`` local tensor as a ``Shard(0)`` DTensor.

        The mesh comes from ``build_device_mesh()``. Seeding the RNG with the
        rank makes every rank's local shard different yet reproducible.
        """
        device_mesh = self.build_device_mesh()
        torch.random.manual_seed(dist.get_rank())
        local_tensor = torch.randn(3, 3, 3)
        return DTensor.from_local(local_tensor, device_mesh, [Shard(0)])

    @staticmethod
    def _all_gather_shards(dist_tensor: DTensor) -> torch.Tensor:
        """All-gather ``dist_tensor``'s local shards along dim 0.

        This is the reference value the gather helpers are compared against.
        """
        return funcol.all_gather_tensor(
            dist_tensor.to_local(),
            gather_dim=0,
            group=(dist_tensor.device_mesh, 0),
        )

    @with_comms
    @skip_if_lt_x_gpu(2)
    def test_gather_state_dict_dtensor(self):
        """``_gather_state_dict`` should yield the all-gathered tensor, kept on GPU."""
        dist_tensor = self._create_sharded_dtensor()
        state_dict = {"dtensor": dist_tensor}

        gathered_state_dict = _gather_state_dict(state_dict)
        expected_gathered_dtensor = self._all_gather_shards(dist_tensor)
        self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"])
        self.assertTrue(gathered_state_dict["dtensor"].is_cuda)

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_gather_with_cpu_and_ranks_only(self):
        """With ``cpu_offload`` and ``ranks_only``, only the listed ranks get CPU results.

        Every other rank is expected to receive an empty dict.
        """
        dist_tensor = self._create_sharded_dtensor()
        state_dict = {"dtensor": dist_tensor}

        gathered_state_dict = _gather_state_dict(
            state_dict, cpu_offload=True, ranks_only=(0, 2)
        )
        expected_gathered_dtensor = self._all_gather_shards(dist_tensor)
        if dist.get_rank() in (0, 2):
            self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"])
            self.assertFalse(gathered_state_dict["dtensor"].is_cuda)
        else:
            self.assertEqual(gathered_state_dict, {})

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_cpu_and_ranks_only(self):
        """``_offload_state_dict_to_cpu`` copies plain tensors to CPU on the listed ranks only."""
        device = torch.device("cuda")
        state_dict = {
            "tensor1": torch.arange(10, device=device),
            "tensor2": torch.ones(10, device=device),
        }

        cpu_state_dict = _offload_state_dict_to_cpu(state_dict, ranks_only=(0, 2))
        if dist.get_rank() in (0, 2):
            for v in cpu_state_dict.values():
                self.assertFalse(v.is_cuda)
            self.assertEqual(cpu_state_dict["tensor1"], torch.arange(10))
            self.assertEqual(cpu_state_dict["tensor2"], torch.ones(10))
        else:
            self.assertEqual(cpu_state_dict, {})

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_complicated_dict(self):
        """``_gather_state_dict`` handles DTensors nested in lists, mixed with plain tensors."""

        def create_dtensor():
            # Returns (reference all-gathered tensor, sharded DTensor).
            dist_tensor = self._create_sharded_dtensor()
            return self._all_gather_shards(dist_tensor), dist_tensor

        ltensor, ldtensor = [], []
        for _ in range(10):
            tensor, dtensor = create_dtensor()
            ltensor.append(tensor)
            ltensor.append(torch.ones(10, device=torch.device("cuda")))
            ldtensor.append(dtensor)
            ldtensor.append(torch.ones(10, device=torch.device("cuda")))

        tensor, dtensor = create_dtensor()
        dist_state_dict = {
            "local": dtensor,
            "list": ldtensor,
            "arange": torch.arange(10, device=torch.device("cuda")),
        }
        state_dict = {
            "local": tensor,
            "list": ltensor,
            "arange": torch.arange(10, device=torch.device("cuda")),
        }
        self.assertEqual(state_dict, _gather_state_dict(dist_state_dict))

    @skip_if_lt_x_gpu(2)
    def test_create_cpu_state_dict(self):
        """``_create_cpu_state_dict`` + ``_copy_state_dict`` round-trip tensors and non-tensors.

        The round trip is exercised for each pin_memory / share_memory combination.
        """
        device = torch.device("cuda")
        buffer = io.BytesIO()
        torch.save(torch.ones(10), buffer)
        buffer.seek(0)
        state_dict = {
            "tensor1": torch.arange(10, device=device),
            "tensor2": torch.ones(10, device=device),
            "non_tensor_bytes_io": copy.deepcopy(buffer),
            # Reads from the (already rewound) buffer after it was deep-copied above.
            "non_tensor_bytes": buffer.read(),
            "step": torch.tensor(7, dtype=torch.float),
            "lr": 1.5,
            "nested": {"list": [1, 2, 3, 4]},
        }

        def _verify(cpu_state_dict):
            # Verify the correctness of _check_state_dict_similarity(): a tensor
            # with a different size must make the similarity check fail.
            self.assertTrue(_check_state_dict_similarity(state_dict, cpu_state_dict))
            tensor1 = cpu_state_dict["tensor1"]
            cpu_state_dict["tensor1"] = torch.arange(11)
            self.assertFalse(_check_state_dict_similarity(state_dict, cpu_state_dict))
            cpu_state_dict["tensor1"] = tensor1

            _copy_state_dict(state_dict, cpu_state_dict)

            # Verify if _copy_state_dict works
            for v in cpu_state_dict.values():
                if isinstance(v, torch.Tensor):
                    self.assertFalse(v.is_cuda)
            self.assertEqual(cpu_state_dict["tensor1"], torch.arange(10))
            self.assertEqual(cpu_state_dict["tensor2"], torch.ones(10))
            buffer.seek(0)
            cpu_state_dict["non_tensor_bytes_io"].seek(0)
            self.assertEqual(
                cpu_state_dict["non_tensor_bytes_io"].read(), buffer.read()
            )
            buffer.seek(0)
            self.assertEqual(cpu_state_dict["non_tensor_bytes"], buffer.read())
            self.assertEqual(cpu_state_dict["lr"], 1.5)
            self.assertEqual(cpu_state_dict["step"], 7)
            self.assertEqual(cpu_state_dict["nested"], {"list": [1, 2, 3, 4]})

        cpu_state_dict = _create_cpu_state_dict(state_dict, pin_memory=True)
        _verify(cpu_state_dict)
        cpu_state_dict = _create_cpu_state_dict(state_dict, share_memory=True)
        _verify(cpu_state_dict)
        cpu_state_dict = _create_cpu_state_dict(
            state_dict, share_memory=True, pin_memory=True
        )
        _verify(cpu_state_dict)

    @with_comms
    @skip_if_lt_x_gpu(2)
    def test_state_dict_util_distribute_tensors(self):
        """``_distribute_tensors`` rebuilds each listed DTensor from its paired plain tensor."""
        even_tensor = torch.randn(self.world_size, 2)
        uneven_tensor = torch.randn(1, 2)

        mesh = init_device_mesh("cuda", mesh_shape=(self.world_size,))
        even_dtensor = distribute_tensor(
            torch.randn(self.world_size, 2), mesh, [Shard(0)]
        )
        uneven_dtensor = distribute_tensor(torch.randn(1, 2), mesh, [Shard(0)])

        # the dtensor and tensor are different before _distribute_tensors is called.
        local_state_dict = {
            "even": [even_dtensor, even_tensor],
            "uneven": [uneven_dtensor, uneven_tensor],
        }
        ref_local_state_dict = copy.deepcopy(local_state_dict)
        keys = ["even", "uneven"]

        _distribute_tensors(local_state_dict, keys, self.device_type)
        for local_v, ref_v in zip(
            local_state_dict.values(), ref_local_state_dict.values()
        ):
            ref_dtensor, ref_tensor = ref_v
            # The resulting DTensor keeps the original DTensor's layout...
            self.assertEqual(local_v.size(), ref_dtensor.size())
            self.assertEqual(local_v.stride(), ref_dtensor.stride())
            # ...but its contents now come from the paired plain tensor.
            local_v_full_tensor = local_v.full_tensor()
            self.assertNotEqual(local_v_full_tensor, ref_dtensor.full_tensor())
            self.assertEqual(local_v_full_tensor, ref_tensor)
210
if __name__ == "__main__":
    # Script entry point: delegate test discovery and execution to PyTorch's
    # shared test runner (nothing runs when this module is merely imported).
    run_tests()
213