# Owner(s): ["oncall: distributed"]

from itertools import chain

import torch
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed._tensor.placement_types import (
    DTensorSpec,
    Partial,
    Replicate,
    Shard,
    TensorMeta,
)
from torch.distributed.tensor._collective_utils import redistribute_cost
from torch.distributed.tensor._op_schema import OpSchema, OpStrategy, PlacementStrategy
from torch.distributed.tensor._ops._einsum_strategy import (
    EinsumDims,
    gen_einsum_strategies,
)
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase


class TestEinsumDims(TestCase):
    def test_batch_dims(self):
        equation = "abc,abc->abc"
        input_dims, output_dim = EinsumDims.parse_equation(equation)
        edims = EinsumDims.parse_dims(input_dims, output_dim)

        self.assertEqual(edims.batch_dims, ["a", "b", "c"])
        self.assertEqual(edims.contracting_dims, [])
        self.assertEqual(edims.lhs_out_only_dims, [])
        self.assertEqual(edims.rhs_out_only_dims, [])

    def test_mm_dims(self):
        equation = "mk,kn->mn"
        input_dims, output_dim = EinsumDims.parse_equation(equation)
        edims = EinsumDims.parse_dims(input_dims, output_dim)

        self.assertEqual(edims.batch_dims, [])
        self.assertEqual(edims.contracting_dims, ["k"])
        self.assertEqual(edims.lhs_out_only_dims, ["m"])
        self.assertEqual(edims.rhs_out_only_dims, ["n"])

    def test_bmm_dims(self):
        equation = "bmk,bkn->bmn"
        input_dims, output_dim = EinsumDims.parse_equation(equation)
        edims = EinsumDims.parse_dims(input_dims, output_dim)

        self.assertEqual(edims.batch_dims, ["b"])
        self.assertEqual(edims.contracting_dims, ["k"])
        self.assertEqual(edims.lhs_out_only_dims, ["m"])
        self.assertEqual(edims.rhs_out_only_dims, ["n"])

        equation = "bcmk,bckn->bcmn"
        input_dims, output_dim = EinsumDims.parse_equation(equation)
        edims = EinsumDims.parse_dims(input_dims, output_dim)

        self.assertEqual(edims.batch_dims, ["b", "c"])
        self.assertEqual(edims.contracting_dims, ["k"])
        self.assertEqual(edims.lhs_out_only_dims, ["m"])
        self.assertEqual(edims.rhs_out_only_dims, ["n"])

    def test_free_dims(self):
        equation = "abc,ab->abc"
        input_dims, output_dim = EinsumDims.parse_equation(equation)
        edims = EinsumDims.parse_dims(input_dims, output_dim)

        self.assertEqual(edims.batch_dims, ["a", "b"])
        self.assertEqual(edims.contracting_dims, [])
        self.assertEqual(edims.lhs_out_only_dims, ["c"])
        self.assertEqual(edims.rhs_out_only_dims, [])

        equation = "abd,bf->abfd"
        input_dims, output_dim = EinsumDims.parse_equation(equation)
        edims = EinsumDims.parse_dims(input_dims, output_dim)

        self.assertEqual(edims.batch_dims, ["b"])
        self.assertEqual(edims.contracting_dims, [])
        self.assertEqual(edims.lhs_out_only_dims, ["a", "d"])
        self.assertEqual(edims.rhs_out_only_dims, ["f"])


class TestEinsumStrategies(DTensorOpTestBase):
    @property
    def world_size(self) -> int:
        return 4

    def test_mm_1d_mesh(self):
        mesh = self.build_device_mesh()

        all_strats = gen_einsum_strategies("mk,kn->mn", mesh)
        self.assertEqual(len(all_strats.strategies), 4)

    def test_mm_2d_mesh(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))

        all_strats = gen_einsum_strategies("mk,kn->mn", mesh)
        self.assertEqual(len(all_strats.strategies), 16)

    def test_bmm_1d_mesh(self):
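        # per mesh dim, an einsum strategy can shard the batch dim, either free dim,
        # or the contracting dim (yielding a Partial output), or fully replicate,
        # so a 1-D mesh gives 5 candidate strategies for bmm (and 25 on a 2-D mesh)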
        mesh = self.build_device_mesh()

        all_strats = gen_einsum_strategies("bmk,bkn->bmn", mesh)
        self.assertEqual(len(all_strats.strategies), 5)

    def test_bmm_2d_mesh(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))

        all_strats = gen_einsum_strategies("bmk,bkn->bmn", mesh)
        self.assertEqual(len(all_strats.strategies), 25)

    def test_pointwise_1d_mesh(self):
        mesh = self.build_device_mesh()

        simple_strats = gen_einsum_strategies("abcd,abcd->abcd", mesh)
        self.assertEqual(len(simple_strats.strategies), 5)

        broadcast_strats = gen_einsum_strategies("bcd,abcd->abcd", mesh)
        self.assertEqual(len(broadcast_strats.strategies), 5)

    def test_linearity_1d_mesh(self):
        mesh = self.build_device_mesh()

        all_strats = gen_einsum_strategies("abcd,abcd->abcd", mesh, linearity=True)
        self.assertEqual(len(all_strats.strategies), 6)


class TestCostModel(DTensorOpTestBase):
    def _extract_tensor_meta(self, t) -> TensorMeta:
        return TensorMeta(t.shape, t.stride(), t.dtype)

    @property
    def world_size(self) -> int:
        return 4

    def test_redistribute_cost_mesh_1d(self):
        mesh_1d = self.build_device_mesh()
        shard_placement = (Shard(0),)
        replica_placement = (Replicate(),)
        partial_placement = (Partial(),)

        global_tensor = torch.randn(10, 10)
        global_tensor_meta = self._extract_tensor_meta(global_tensor)

        # shard spec
        shard_spec = DTensorSpec(mesh_1d, shard_placement, global_tensor_meta)
        # replica spec
        replica_spec = DTensorSpec(mesh_1d, replica_placement, global_tensor_meta)
        # partial spec
        partial_spec = DTensorSpec(mesh_1d, partial_placement, global_tensor_meta)

        # make sure the reshard cost is 0 when redistributing to the same spec
        for spec in [shard_spec, replica_spec, partial_spec]:
            cost = redistribute_cost(spec, spec)
            self.assertEqual(cost, 0)

        # shard -> replicate
        allgather_cost = redistribute_cost(shard_spec, replica_spec)
        # partial -> shard
        reduce_scatter_cost = redistribute_cost(partial_spec, shard_spec)
        # partial -> replicate
        allreduce_cost = redistribute_cost(partial_spec, replica_spec)
        self.assertEqual(allgather_cost, reduce_scatter_cost)
        self.assertTrue(allreduce_cost + 1 < allgather_cost + reduce_scatter_cost)
        # shard -> partial
        cost = redistribute_cost(shard_spec, partial_spec)
        self.assertEqual(cost, float("inf"))

    def test_redistribute_cost_latency(self):
        # test the cost model on the addmm op
        from torch.distributed.tensor._ops._matrix_ops import addmm_strategy

        mesh = self.build_device_mesh()
        shard0_placement = (Shard(0),)
        partial_placement = (Partial(),)
        shard1_placement = (Shard(1),)

        shard0_tensor_meta = self._extract_tensor_meta(torch.randn(8))
        partial_tensor_meta = self._extract_tensor_meta(torch.randn(50, 6))
        shard1_tensor_meta = self._extract_tensor_meta(torch.randn(6, 8))

        # shard(0) spec for the 1-D bias argument of addmm
        shard0_spec = DTensorSpec(mesh, shard0_placement, shard0_tensor_meta)
        # partial spec for mat1
        partial_spec = DTensorSpec(mesh, partial_placement, partial_tensor_meta)
        # shard(1) spec for mat2
        shard1_spec = DTensorSpec(mesh, shard1_placement, shard1_tensor_meta)

        op_schema = OpSchema(
            torch.ops.aten.addmm.default,
            (
                OpStrategy([PlacementStrategy(shard0_spec)]),
                OpStrategy([PlacementStrategy(partial_spec)]),
                OpStrategy([PlacementStrategy(shard1_spec)]),
            ),
            {},
        )

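        # each candidate strategy returned below records, per input, the collective
        # cost of redistributing that input to the placement the strategy requires;
        # summing those costs per strategy exercises the latency term of the cost model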
        output_strategy = addmm_strategy(mesh, op_schema)
        strategy_costs = {}
        for strategy in output_strategy.strategies:
            # total comm cost of redistributing all inputs for this strategy
            total_redistribute_cost = sum(
                chain.from_iterable(strategy.redistribute_cost)
            )
            strategy_costs[str(strategy)] = total_redistribute_cost

        # assert that the cost model accounts for collective latency
        # (i.e. multiple communications are penalized)
        self.assertTrue(
            strategy_costs["(S(0), R, S(1)) -> S(1)"]
            < strategy_costs["(R, S(0), R) -> S(0)"]
        )
        # assert that the single-allreduce strategy is the cheapest
        self.assertEqual(
            strategy_costs["(S(0), R, S(1)) -> S(1)"], min(strategy_costs.values())
        )

    def test_redistribute_cost_mesh_2d(self):
        mesh_2d = DeviceMesh(
            self.device_type, torch.arange(self.world_size).reshape(2, 2)
        )
        shard_placement = (Shard(0), Shard(0))
        replica_placement = (Replicate(), Replicate())
        partial_placement = (Partial(), Partial())

        global_tensor = torch.randn(8, 8)
        global_tensor_meta = self._extract_tensor_meta(global_tensor)

        # shard spec
        shard_spec = DTensorSpec(mesh_2d, shard_placement, global_tensor_meta)
        # replica spec
        replica_spec = DTensorSpec(mesh_2d, replica_placement, global_tensor_meta)
        # partial spec
        partial_spec = DTensorSpec(mesh_2d, partial_placement, global_tensor_meta)

        # make sure the reshard cost is 0 when redistributing to the same spec
        for spec in [shard_spec, replica_spec, partial_spec]:
            cost = redistribute_cost(spec, spec)
            self.assertEqual(cost, 0)

        # shard -> replicate
        allgather_cost = redistribute_cost(shard_spec, replica_spec)
        # partial -> replicate
        allreduce_cost = redistribute_cost(partial_spec, replica_spec)
        # partial -> shard
        reduce_scatter_cost = redistribute_cost(partial_spec, shard_spec)
        self.assertTrue(allreduce_cost > allgather_cost)
        self.assertTrue(allreduce_cost > reduce_scatter_cost)

    def test_mm_strategies(self):
        from torch.distributed.tensor._ops._matrix_ops import mm_strategy

        mesh = self.build_device_mesh()
        lhs_tensor = torch.randn(6, 8)
        rhs_tensor = torch.randn(8, 12)
        lhs_tensor_meta = self._extract_tensor_meta(lhs_tensor)
        rhs_tensor_meta = self._extract_tensor_meta(rhs_tensor)

        mm_combs = (
            (Shard(0), Replicate()),
            (Replicate(), Shard(1)),
            (Shard(1), Shard(0)),
            (Replicate(), Replicate()),
        )
        for lhs, rhs in mm_combs:
            lhs_spec = DTensorSpec(mesh, (lhs,), lhs_tensor_meta)
            rhs_spec = DTensorSpec(mesh, (rhs,), rhs_tensor_meta)

            op_schema = OpSchema(
                torch.ops.aten.mm.default,
                (
                    OpStrategy([PlacementStrategy(lhs_spec)]),
                    OpStrategy([PlacementStrategy(rhs_spec)]),
                ),
                {},
            )
            # test the strategy
            res_strategies = mm_strategy(mesh, op_schema)

            for strtgy in res_strategies.strategies:
                if strtgy.input_specs == (lhs_spec, rhs_spec):
                    self.assertEqual(strtgy.redistribute_cost, [[0.0], [0.0]])
                    break

            op_schema = OpSchema(
                torch.ops.aten.mm.default,
                (lhs_spec, rhs_spec),
                {},
            )
            # test sharding prop
            output_sharding = DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding_non_cached(
                op_schema
            )
            self.assertFalse(output_sharding.needs_redistribute)

    def test_bmm_strategies(self):
        from torch.distributed.tensor._ops._matrix_ops import bmm_strategy

        mesh = self.build_device_mesh()
        lhs_tensor = torch.randn(8, 6, 8)
        rhs_tensor = torch.randn(8, 8, 12)
        lhs_tensor_meta = self._extract_tensor_meta(lhs_tensor)
        rhs_tensor_meta = self._extract_tensor_meta(rhs_tensor)

        bmm_combs = (
            (Shard(0), Shard(0)),
            (Shard(1), Replicate()),
            (Replicate(), Shard(2)),
            (Shard(2), Shard(1)),
            (Replicate(), Replicate()),
        )
        for lhs, rhs in bmm_combs:
            lhs_spec = DTensorSpec(mesh, (lhs,), lhs_tensor_meta)
            rhs_spec = DTensorSpec(mesh, (rhs,), rhs_tensor_meta)

            op_schema = OpSchema(
                torch.ops.aten.bmm.default,
                (
                    OpStrategy([PlacementStrategy(lhs_spec)]),
                    OpStrategy([PlacementStrategy(rhs_spec)]),
                ),
                {},
            )
            # test the strategy
            res_strategies = bmm_strategy(mesh, op_schema)

            for strtgy in res_strategies.strategies:
                if strtgy.input_specs == (lhs_spec, rhs_spec):
                    self.assertEqual(strtgy.redistribute_cost, [[0.0], [0.0]])
                    break

            op_schema = OpSchema(
                torch.ops.aten.bmm.default,
                (lhs_spec, rhs_spec),
                {},
            )
            # test sharding prop
            output_sharding = DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding_non_cached(
                op_schema
            )
            self.assertFalse(output_sharding.needs_redistribute)


if __name__ == "__main__":
    run_tests()