# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import os
import unittest
from functools import wraps
from typing import Any, Callable, Dict, Tuple

import numpy as np

import torch
from torch import nn
from torch.distributed._tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    Replicate,
    Shard,
)
from torch.testing._internal.common_utils import run_tests, TestCase


# Wrapper to check XLA test requirements: skip the test when torch_xla is not
# installed and enable XLA SPMD mode for the duration of the wrapped test.
def with_xla(func: Callable) -> Callable:
    assert func is not None

    @wraps(func)  # pyre-ignore[6]
    def wrapper(
        self, *args: Tuple[object], **kwargs: Dict[str, Any]  # type: ignore[misc]
    ) -> None:
        # TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag.
        os.environ["XLA_USE_SPMD"] = "1"
        try:
            import torch_xla  # type:ignore[import]  # noqa: F401
        except ImportError as exc:
            raise unittest.SkipTest("torch_xla is not installed.") from exc
        self.device_type = "xla"
        try:
            func(self, *args, **kwargs)  # type: ignore[misc]
        finally:
            # Reset the flag even if the wrapped test fails.
            os.environ["XLA_USE_SPMD"] = "0"

    return wrapper


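# These tests exercise the DTensor XLA integration path: on an "xla"
# DeviceMesh, distribute_tensor/distribute_module currently return
# XLAShardedTensor-backed results rather than DTensor (see the TODOs below).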
class DTensorXLAIntegrationTest(TestCase):
    class SimpleLinear(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc1 = nn.Linear(128, 64)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(64, 1)

        def forward(self, x):
            y = self.relu(self.fc1(x))
            z = self.fc2(y)
            return z

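    # Sharding a (3 * device_count, 3) tensor along dim 0 over a 1-D mesh
    # should leave each device with a (3, 3) local shard while the global
    # shape stays unchanged.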
    @with_xla
    def test_xla_distribute_tensor_1d_shard(self):
        import torch_xla.runtime as xr  # type:ignore[import]

        device_count = xr.global_runtime_device_count()
        if device_count > 1:
            device_mesh = DeviceMesh("xla", list(range(device_count)))
            shard_spec = [Shard(0)]

            for requires_grad in [True, False]:
                tensor_to_shard = torch.randn(
                    3 * device_count, 3, requires_grad=requires_grad
                )
                dist_tensor = distribute_tensor(
                    tensor_to_shard, device_mesh, shard_spec
                )
                # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
                assert type(dist_tensor).__name__ == "XLAShardedTensor"
                global_tensor = dist_tensor.global_tensor  # type:ignore[attr-defined]
                self.assertEqual(
                    global_tensor.size(), torch.Size([3 * device_count, 3])
                )
                local_tensor = dist_tensor.local_shards[0].data
                self.assertEqual(local_tensor.size(), torch.Size([3, 3]))
                if requires_grad:
                    self.assertTrue(dist_tensor.global_tensor.requires_grad)
                    self.assertTrue(dist_tensor.is_leaf)

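    # With Replicate(), every device holds the full tensor, so the local
    # shard has the same shape as the global tensor.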
    @with_xla
    def test_xla_distribute_tensor_1d_replicate(self):
        import torch_xla.runtime as xr  # type:ignore[import]

        device_count = xr.global_runtime_device_count()
        device_mesh = DeviceMesh("xla", list(range(device_count)))
        shard_spec = [Replicate()]

        for requires_grad in [True, False]:
            tensor_to_shard = torch.randn(
                3 * device_count, 3, requires_grad=requires_grad
            )
            dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
            # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
            assert type(dist_tensor).__name__ == "XLAShardedTensor"
            global_tensor = dist_tensor.global_tensor  # type:ignore[attr-defined]
            self.assertEqual(global_tensor.size(), torch.Size([3 * device_count, 3]))
            local_tensor = dist_tensor.local_shards[0].data
            self.assertEqual(local_tensor.size(), torch.Size([3 * device_count, 3]))
            if requires_grad:
                self.assertTrue(dist_tensor.global_tensor.requires_grad)
                self.assertTrue(dist_tensor.is_leaf)

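    # On a (2, device_count // 2) mesh with [Replicate(), Shard(0)], the tensor
    # is replicated across the first mesh dimension and sharded along dim 0
    # across the second, so each local shard holds 3 of the
    # 3 * device_count // 2 rows.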
    @with_xla
    def test_xla_distribute_tensor_2d(self):
        import torch_xla.runtime as xr  # type:ignore[import]

        device_count = xr.global_runtime_device_count()
        if device_count > 1:
            device_mesh = DeviceMesh(
                "xla", np.array(range(device_count)).reshape(2, device_count // 2)
            )
            shard_spec = [Replicate(), Shard(0)]

            for requires_grad in [True, False]:
                tensor_to_shard = torch.randn(
                    3 * device_count // 2, 3, requires_grad=requires_grad
                )
                dist_tensor = distribute_tensor(
                    tensor_to_shard, device_mesh, shard_spec
                )
                # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
                assert type(dist_tensor).__name__ == "XLAShardedTensor"
                global_tensor = dist_tensor.global_tensor  # type:ignore[attr-defined]
                self.assertEqual(
                    global_tensor.size(), torch.Size([3 * device_count // 2, 3])
                )
                local_tensor = dist_tensor.local_shards[0].data
                self.assertEqual(local_tensor.size(), torch.Size([3, 3]))
                if requires_grad:
                    self.assertTrue(dist_tensor.global_tensor.requires_grad)
                    self.assertTrue(dist_tensor.is_leaf)

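    # distribute_module invokes shard_params on every submodule; afterwards the
    # Linear weights should carry a non-empty XLA sharding annotation.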
    @with_xla
    def test_xla_distribute_module(self):
        import torch_xla  # type:ignore[import]
        import torch_xla.core.xla_model as xm  # type:ignore[import]
        import torch_xla.runtime as xr  # type:ignore[import]

        model = self.SimpleLinear().to(xm.xla_device())

        device_count = xr.global_runtime_device_count()
        device_mesh = DeviceMesh("xla", list(range(device_count)))

        def shard_params(mod_name, mod, mesh):
            shard_spec = [Shard(0)]
            # annotate fc1 and fc2
            if isinstance(mod, nn.Linear):
                for name, param in mod.named_parameters():
                    # annotate the parameter tensors directly
                    distribute_tensor(param, mesh, shard_spec)

        sharded_model = distribute_module(model, device_mesh, shard_params)
        self.assertTrue(
            torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc1.weight) != ""
        )
        self.assertTrue(
            torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc2.weight) != ""
        )


if __name__ == "__main__":
    run_tests()