# Owner(s): ["oncall: distributed"]

from itertools import chain

import torch
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed._tensor.placement_types import (
    DTensorSpec,
    Partial,
    Replicate,
    Shard,
    TensorMeta,
)
from torch.distributed.tensor._collective_utils import redistribute_cost
from torch.distributed.tensor._op_schema import OpSchema, OpStrategy, PlacementStrategy
from torch.distributed.tensor._ops._einsum_strategy import (
    EinsumDims,
    gen_einsum_strategies,
)
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase


class TestEinsumDims(TestCase):
    def test_batch_dims(self):
        equation = "abc,abc->abc"
        input_dims, output_dim = EinsumDims.parse_equation(equation)
        edims = EinsumDims.parse_dims(input_dims, output_dim)

        self.assertEqual(edims.batch_dims, ["a", "b", "c"])
        self.assertEqual(edims.contracting_dims, [])
        self.assertEqual(edims.lhs_out_only_dims, [])
        self.assertEqual(edims.rhs_out_only_dims, [])

    def test_mm_dims(self):
        equation = "mk,kn->mn"
        input_dims, output_dim = EinsumDims.parse_equation(equation)
        edims = EinsumDims.parse_dims(input_dims, output_dim)

        self.assertEqual(edims.batch_dims, [])
        self.assertEqual(edims.contracting_dims, ["k"])
        self.assertEqual(edims.lhs_out_only_dims, ["m"])
        self.assertEqual(edims.rhs_out_only_dims, ["n"])

    def test_bmm_dims(self):
        equation = "bmk,bkn->bmn"
        input_dims, output_dim = EinsumDims.parse_equation(equation)
        edims = EinsumDims.parse_dims(input_dims, output_dim)

        self.assertEqual(edims.batch_dims, ["b"])
        self.assertEqual(edims.contracting_dims, ["k"])
        self.assertEqual(edims.lhs_out_only_dims, ["m"])
        self.assertEqual(edims.rhs_out_only_dims, ["n"])

        equation = "bcmk,bckn->bcmn"
        input_dims, output_dim = EinsumDims.parse_equation(equation)
        edims = EinsumDims.parse_dims(input_dims, output_dim)

        self.assertEqual(edims.batch_dims, ["b", "c"])
        self.assertEqual(edims.contracting_dims, ["k"])
        self.assertEqual(edims.lhs_out_only_dims, ["m"])
        self.assertEqual(edims.rhs_out_only_dims, ["n"])

    def test_free_dims(self):
        equation = "abc,ab->abc"
        input_dims, output_dim = EinsumDims.parse_equation(equation)
        edims = EinsumDims.parse_dims(input_dims, output_dim)

        self.assertEqual(edims.batch_dims, ["a", "b"])
        self.assertEqual(edims.contracting_dims, [])
        self.assertEqual(edims.lhs_out_only_dims, ["c"])
        self.assertEqual(edims.rhs_out_only_dims, [])

        equation = "abd,bf->abfd"
        input_dims, output_dim = EinsumDims.parse_equation(equation)
        edims = EinsumDims.parse_dims(input_dims, output_dim)

        self.assertEqual(edims.batch_dims, ["b"])
        self.assertEqual(edims.contracting_dims, [])
        self.assertEqual(edims.lhs_out_only_dims, ["a", "d"])
        self.assertEqual(edims.rhs_out_only_dims, ["f"])


class TestEinsumStrategies(DTensorOpTestBase):
    @property
    def world_size(self) -> int:
        return 4

    def test_mm_1d_mesh(self):
        mesh = self.build_device_mesh()

        all_strats = gen_einsum_strategies("mk,kn->mn", mesh)
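        # 4 strategies on the 1D mesh; presumably one per shardable einsum dim
        # (shard m, shard n, or shard the contracting dim k with a Partial output)
        # plus full replication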
        self.assertEqual(len(all_strats.strategies), 4)

    def test_mm_2d_mesh(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))

        all_strats = gen_einsum_strategies("mk,kn->mn", mesh)
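        # 16 strategies: the 4 single-mesh-dim choices combined independently
        # across the two mesh dims (4 * 4)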
        self.assertEqual(len(all_strats.strategies), 16)

    def test_bmm_1d_mesh(self):
        mesh = self.build_device_mesh()

        all_strats = gen_einsum_strategies("bmk,bkn->bmn", mesh)
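        # 5 strategies on the 1D mesh; presumably shard the batch dim b, shard m,
        # shard n, shard the contracting dim k (Partial output), or replicate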
        self.assertEqual(len(all_strats.strategies), 5)

    def test_bmm_2d_mesh(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))

        all_strats = gen_einsum_strategies("bmk,bkn->bmn", mesh)
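        # 25 strategies: the 5 single-mesh-dim choices combined across the two
        # mesh dims (5 * 5)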
        self.assertEqual(len(all_strats.strategies), 25)

    def test_pointwise_1d_mesh(self):
        mesh = self.build_device_mesh()

        simple_strats = gen_einsum_strategies("abcd,abcd->abcd", mesh)
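        # 5 strategies: shard any one of the four output dims, or replicate; the
        # broadcast case below follows the same output dims, so the count is
        # unchanged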
        self.assertEqual(len(simple_strats.strategies), 5)

        broadcast_strats = gen_einsum_strategies("bcd,abcd->abcd", mesh)
        self.assertEqual(len(broadcast_strats.strategies), 5)

    def test_linearity_1d_mesh(self):
        mesh = self.build_device_mesh()

        all_strats = gen_einsum_strategies("abcd,abcd->abcd", mesh, linearity=True)
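        # linearity adds one strategy on top of the 5 pointwise ones (presumably
        # keeping every operand Partial), for a total of 6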
        self.assertEqual(len(all_strats.strategies), 6)


class TestCostModel(DTensorOpTestBase):
    def _extract_tensor_meta(self, t) -> TensorMeta:
        return TensorMeta(t.shape, t.stride(), t.dtype)

    @property
    def world_size(self) -> int:
        return 4

    def test_redistribute_cost_mesh_1d(self):
        mesh_1d = self.build_device_mesh()
        shard_placement = (Shard(0),)
        replica_placement = (Replicate(),)
        partial_placement = (Partial(),)

        global_tensor = torch.randn(10, 10)
        global_tensor_meta = self._extract_tensor_meta(global_tensor)

        # shard spec
        shard_spec = DTensorSpec(mesh_1d, shard_placement, global_tensor_meta)
        # replica spec
        replica_spec = DTensorSpec(mesh_1d, replica_placement, global_tensor_meta)
        # partial spec
        partial_spec = DTensorSpec(mesh_1d, partial_placement, global_tensor_meta)

        # make sure the reshard cost is 0 when redistributing to the same spec
        for spec in [shard_spec, replica_spec, partial_spec]:
            cost = redistribute_cost(spec, spec)
            self.assertEqual(cost, 0)

        # shard -> replicate
        allgather_cost = redistribute_cost(shard_spec, replica_spec)
        # partial -> shard
        reduce_scatter_cost = redistribute_cost(partial_spec, shard_spec)
        # partial -> replicate
        allreduce_cost = redistribute_cost(partial_spec, replica_spec)
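        # allgather and reduce-scatter move the same amount of data over the same
        # mesh dim, so their modeled costs should match, while a single allreduce
        # should be cheaper than doing both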
        self.assertEqual(allgather_cost, reduce_scatter_cost)
        self.assertTrue(allreduce_cost + 1 < allgather_cost + reduce_scatter_cost)
        # shard to partial
        cost = redistribute_cost(shard_spec, partial_spec)
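        # the cost model treats shard -> partial as an unsupported redistribution,
        # so it reports an infinite cost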
        self.assertEqual(cost, float("inf"))

    def test_redistribute_cost_latency(self):
        # test the cost model on the addmm op
        from torch.distributed.tensor._ops._matrix_ops import addmm_strategy

        mesh = self.build_device_mesh()
        shard0_placement = (Shard(0),)
        partial_placement = (Partial(),)
        shard1_placement = (Shard(1),)

        shard0_tensor_meta = self._extract_tensor_meta(torch.randn(8))
        partial_tensor_meta = self._extract_tensor_meta(torch.randn(50, 6))
        shard1_tensor_meta = self._extract_tensor_meta(torch.randn(6, 8))

        # shard0 spec for the bias (addmm input), sharded on dim 0
        shard0_spec = DTensorSpec(mesh, shard0_placement, shard0_tensor_meta)
        # partial spec for mat1
        partial_spec = DTensorSpec(mesh, partial_placement, partial_tensor_meta)
        # shard1 spec for mat2, sharded on dim 1
        shard1_spec = DTensorSpec(mesh, shard1_placement, shard1_tensor_meta)

        op_schema = OpSchema(
            torch.ops.aten.addmm.default,
            (
                OpStrategy([PlacementStrategy(shard0_spec)]),
                OpStrategy([PlacementStrategy(partial_spec)]),
                OpStrategy([PlacementStrategy(shard1_spec)]),
            ),
            {},
        )

        output_strategy = addmm_strategy(mesh, op_schema)
        strategy_costs = {}
        for strategy in output_strategy.strategies:
            cost = sum(chain.from_iterable(strategy.redistribute_cost))
            strategy_costs[str(strategy)] = cost

        # assert that the cost model accounts for collective latency (i.e. multiple communications are penalized)
        self.assertTrue(
            strategy_costs["(S(0), R, S(1)) -> S(1)"]
            < strategy_costs["(R, S(0), R) -> S(0)"]
        )
        # assert that a single allreduce is the cheapest strategy
        self.assertEqual(
            strategy_costs["(S(0), R, S(1)) -> S(1)"], min(strategy_costs.values())
        )

    def test_redistribute_cost_mesh_2d(self):
        mesh_2d = DeviceMesh(
            self.device_type, torch.arange(self.world_size).reshape(2, 2)
        )
        shard_placement = (Shard(0), Shard(0))
        replica_placement = (Replicate(), Replicate())
        partial_placement = (Partial(), Partial())

        global_tensor = torch.randn(8, 8)
        global_tensor_meta = self._extract_tensor_meta(global_tensor)

        # shard spec
        shard_spec = DTensorSpec(mesh_2d, shard_placement, global_tensor_meta)
        # replica spec
        replica_spec = DTensorSpec(mesh_2d, replica_placement, global_tensor_meta)
        # partial spec
        partial_spec = DTensorSpec(mesh_2d, partial_placement, global_tensor_meta)

        # make sure the reshard cost is 0 when redistributing to the same spec
        for spec in [shard_spec, replica_spec, partial_spec]:
            cost = redistribute_cost(spec, spec)
            self.assertEqual(cost, 0)

        # shard -> replicate
        allgather_cost = redistribute_cost(shard_spec, replica_spec)
        # partial -> replicate
        allreduce_cost = redistribute_cost(partial_spec, replica_spec)
        # partial -> shard
        reduce_scatter_cost = redistribute_cost(partial_spec, shard_spec)
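        # with placements across both mesh dims, the allreduce (partial -> replicate)
        # should be modeled as more expensive than either the allgather
        # (shard -> replicate) or the reduce-scatter (partial -> shard)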
        self.assertTrue(allreduce_cost > allgather_cost)
        self.assertTrue(allreduce_cost > reduce_scatter_cost)

    def test_mm_strategies(self):
        from torch.distributed.tensor._ops._matrix_ops import mm_strategy

        mesh = self.build_device_mesh()
        lhs_tensor = torch.randn(6, 8)
        rhs_tensor = torch.randn(8, 12)
        lhs_tensor_meta = self._extract_tensor_meta(lhs_tensor)
        rhs_tensor_meta = self._extract_tensor_meta(rhs_tensor)

        mm_combs = (
            (Shard(0), Replicate()),
            (Replicate(), Shard(1)),
            (Shard(1), Shard(0)),
            (Replicate(), Replicate()),
        )
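        # each lhs/rhs pair above can produce an mm output without redistribution:
        # S(0) @ R -> S(0), R @ S(1) -> S(1), S(1) @ S(0) -> Partial, R @ R -> R,
        # so the matching strategy is expected to carry zero redistribute cost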
        for lhs, rhs in mm_combs:
            lhs_spec = DTensorSpec(mesh, (lhs,), lhs_tensor_meta)
            rhs_spec = DTensorSpec(mesh, (rhs,), rhs_tensor_meta)

            op_schema = OpSchema(
                torch.ops.aten.mm.default,
                (
                    OpStrategy([PlacementStrategy(lhs_spec)]),
                    OpStrategy([PlacementStrategy(rhs_spec)]),
                ),
                {},
            )
            # test the strategy
            res_strategies = mm_strategy(mesh, op_schema)

            for strtgy in res_strategies.strategies:
                if strtgy.input_specs == (lhs_spec, rhs_spec):
                    self.assertEqual(strtgy.redistribute_cost, [[0.0], [0.0]])
                    break

            op_schema = OpSchema(
                torch.ops.aten.mm.default,
                (lhs_spec, rhs_spec),
                {},
            )
            # test sharding prop
            output_sharding = DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding_non_cached(
                op_schema
            )
            self.assertFalse(output_sharding.needs_redistribute)

    def test_bmm_strategies(self):
        from torch.distributed.tensor._ops._matrix_ops import bmm_strategy

        mesh = self.build_device_mesh()
        lhs_tensor = torch.randn(8, 6, 8)
        rhs_tensor = torch.randn(8, 8, 12)
        lhs_tensor_meta = self._extract_tensor_meta(lhs_tensor)
        rhs_tensor_meta = self._extract_tensor_meta(rhs_tensor)

        bmm_combs = (
            (Shard(0), Shard(0)),
            (Shard(1), Replicate()),
            (Replicate(), Shard(2)),
            (Shard(2), Shard(1)),
            (Replicate(), Replicate()),
        )
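        # the analogous zero-cost pairs for bmm: batch dim sharded on both sides,
        # a free dim sharded on one side, contracting-dim sharding that yields a
        # Partial output, or full replication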
        for lhs, rhs in bmm_combs:
            lhs_spec = DTensorSpec(mesh, (lhs,), lhs_tensor_meta)
            rhs_spec = DTensorSpec(mesh, (rhs,), rhs_tensor_meta)

            op_schema = OpSchema(
                torch.ops.aten.bmm.default,
                (
                    OpStrategy([PlacementStrategy(lhs_spec)]),
                    OpStrategy([PlacementStrategy(rhs_spec)]),
                ),
                {},
            )
            # test the strategy
            res_strategies = bmm_strategy(mesh, op_schema)

            for strtgy in res_strategies.strategies:
                if strtgy.input_specs == (lhs_spec, rhs_spec):
                    self.assertEqual(strtgy.redistribute_cost, [[0.0], [0.0]])
                    break

            op_schema = OpSchema(
                torch.ops.aten.bmm.default,
                (lhs_spec, rhs_spec),
                {},
            )
            # test sharding prop
            output_sharding = DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding_non_cached(
                op_schema
            )
            self.assertFalse(output_sharding.needs_redistribute)


if __name__ == "__main__":
    run_tests()