# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import os
import unittest
from functools import wraps
from typing import Any, Callable, Dict, Tuple

import numpy as np

import torch
from torch import nn
from torch.distributed._tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    Replicate,
    Shard,
)
from torch.testing._internal.common_utils import run_tests, TestCase


# wrapper to check xla test requirements
def with_xla(func: Callable) -> Callable:
    assert func is not None

    @wraps(func)  # pyre-ignore[6]
    def wrapper(
        self, *args: Tuple[object], **kwargs: Dict[str, Any]  # type: ignore[misc]
    ) -> None:
        # TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag.
        os.environ["XLA_USE_SPMD"] = "1"
        try:
            import torch_xla  # type:ignore[import]  # noqa: F401
        except ImportError as exc:
            raise unittest.SkipTest("torch_xla is not installed.") from exc
        self.device_type = "xla"
        func(self, *args, **kwargs)  # type: ignore[misc]
        os.environ["XLA_USE_SPMD"] = "0"

    return wrapper


class DTensorXLAIntegrationTest(TestCase):
    class SimpleLinear(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc1 = nn.Linear(128, 64)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(64, 1)

        def forward(self, x):
            y = self.relu(self.fc1(x))
            z = self.fc2(y)
            return z

    @with_xla
    def test_xla_distribute_tensor_1d_shard(self):
        import torch_xla.runtime as xr  # type:ignore[import]

        device_count = xr.global_runtime_device_count()
        if device_count > 1:
            device_mesh = DeviceMesh("xla", list(range(device_count)))
            shard_spec = [Shard(0)]

            for requires_grad in [True, False]:
                tensor_to_shard = torch.randn(
                    3 * device_count, 3, requires_grad=requires_grad
                )
                dist_tensor = distribute_tensor(
                    tensor_to_shard, device_mesh, shard_spec
                )
                # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
                assert type(dist_tensor).__name__ == "XLAShardedTensor"
                global_tensor = dist_tensor.global_tensor  # type:ignore[attr-defined]
                self.assertEqual(
                    global_tensor.size(), torch.Size([3 * device_count, 3])
                )
                local_tensor = dist_tensor.local_shards[0].data
                self.assertEqual(local_tensor.size(), torch.Size([3, 3]))
                if requires_grad:
                    self.assertTrue(dist_tensor.global_tensor.requires_grad)
                    self.assertTrue(dist_tensor.is_leaf)

    @with_xla
    def test_xla_distribute_tensor_1d_replicate(self):
        import torch_xla.runtime as xr  # type:ignore[import]

        device_count = xr.global_runtime_device_count()
        device_mesh = DeviceMesh("xla", list(range(device_count)))
        shard_spec = [Replicate()]

        for requires_grad in [True, False]:
            tensor_to_shard = torch.randn(
                3 * device_count, 3, requires_grad=requires_grad
            )
            dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
            # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
            assert type(dist_tensor).__name__ == "XLAShardedTensor"
            global_tensor = dist_tensor.global_tensor  # type:ignore[attr-defined]
            self.assertEqual(global_tensor.size(), torch.Size([3 * device_count, 3]))
            local_tensor = dist_tensor.local_shards[0].data
            self.assertEqual(local_tensor.size(), torch.Size([3 * device_count, 3]))
            if requires_grad:
                self.assertTrue(dist_tensor.global_tensor.requires_grad)
                self.assertTrue(dist_tensor.is_leaf)

    @with_xla
    def test_xla_distribute_tensor_2d(self):
        import torch_xla.runtime as xr  # type:ignore[import]

        device_count = xr.global_runtime_device_count()
        if device_count > 1:
            device_mesh = DeviceMesh(
                "xla", np.array(range(device_count)).reshape(2, device_count // 2)
            )
            shard_spec = [Replicate(), Shard(0)]

            for requires_grad in [True, False]:
                tensor_to_shard = torch.randn(
                    3 * device_count // 2, 3, requires_grad=requires_grad
                )
                dist_tensor = distribute_tensor(
                    tensor_to_shard, device_mesh, shard_spec
                )
                # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
                assert type(dist_tensor).__name__ == "XLAShardedTensor"
                global_tensor = dist_tensor.global_tensor  # type:ignore[attr-defined]
                self.assertEqual(
                    global_tensor.size(), torch.Size([3 * device_count // 2, 3])
                )
                local_tensor = dist_tensor.local_shards[0].data
                self.assertEqual(local_tensor.size(), torch.Size([3, 3]))
                if requires_grad:
                    self.assertTrue(dist_tensor.global_tensor.requires_grad)
                    self.assertTrue(dist_tensor.is_leaf)

    @with_xla
    def test_xla_distribute_module(self):
        import torch_xla  # type:ignore[import]
        import torch_xla.core.xla_model as xm  # type:ignore[import]
        import torch_xla.runtime as xr  # type:ignore[import]

        model = self.SimpleLinear().to(xm.xla_device())

        device_count = xr.global_runtime_device_count()
        device_mesh = DeviceMesh("xla", list(range(device_count)))

        def shard_params(mod_name, mod, mesh):
            shard_spec = [Shard(0)]
            # annotate fc1 and fc2
            if isinstance(mod, nn.Linear):
                for name, param in mod.named_parameters():
                    # annotate the parameter tensors directly
                    distribute_tensor(param, mesh, shard_spec)

        sharded_model = distribute_module(model, device_mesh, shard_params)
        self.assertTrue(
            torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc1.weight) != ""
        )
        self.assertTrue(
            torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc2.weight) != ""
        )


if __name__ == "__main__":
    run_tests()