# Owner(s): ["module: inductor"]
import unittest
from unittest.mock import patch

import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case

# for some reason importing functional collectives after dynamo breaks collectives handling!
import torch.distributed._functional_collectives as _functional_collectives
from torch._C import FileCheck
from torch._dynamo.utils import same
from torch._inductor import ir, scheduler
from torch._inductor.comm_analysis import (
    baseLat,
    hwLat,
    llMaxBws,
    NCCL_ALGO,
    NCCL_HW,
    NCCL_PROTO,
    NVIDIA_GPU_TYPE,
)
from torch._inductor.utils import run_and_get_triton_code
from torch.testing._internal.common_distributed import (
    _dynamo_dist_per_rank_init,
    at_least_x_gpu,
    DynamoDistributedMultiProcTestCase,
    requires_nccl,
)
from torch.utils._triton import has_triton


def get_snode_runtime_for_reorder_compute_test(snode):
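    """
    Patched in below as `torch._inductor.config.estimate_op_runtime`; the returned
    values are relative costs in arbitrary units, not measured runtimes.
    """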
    # NOTE: custom cost model to show that the compute reordering algorithm is working
    # Collective kernels
    if isinstance(snode.node, ir._CollectiveKernel):
        return 100
    elif isinstance(snode.node, ir._WaitKernel):
        return 0
    # High-arithmetic-intensity compute kernels
    elif isinstance(snode.node, ir.ExternKernel):
        return 5
    # All other kernels
    return 1


def create_grouped_node_for_allreduce_and_its_deps(snodes):
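    """
    Patched in below as `torch._inductor.config._pre_fusion_custom_pass`: wraps the
    all_reduce_ snode and its single input dependency into one GroupedSchedulerNode
    and moves that group to the front of the schedule order.
    """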
    name_to_snode = {snode.node.name: snode for snode in snodes}
    all_reduce_snodes = [
        snode
        for snode in snodes
        if isinstance(snode.node, ir._CollectiveKernel)
        and snode.node.op_overload == torch.ops._c10d_functional.all_reduce_.default
    ]
    assert len(all_reduce_snodes) == 1
    all_reduce_snode = all_reduce_snodes[0]
    all_reduce_dep_snodes = [
        name_to_snode[node.name] for node in all_reduce_snode.node.inputs
    ]
    assert len(all_reduce_dep_snodes) == 1
    all_reduce_dep_snode = all_reduce_dep_snodes[0]

    grouped_snode = scheduler.GroupedSchedulerNode.create(
        [all_reduce_dep_snode, all_reduce_snode]
    )
    new_snode_order = []
    new_snode_order.append(grouped_snode)
    for snode in snodes:
        if snode in grouped_snode.snodes:
            continue
        new_snode_order.append(snode)
    return new_snode_order


@requires_nccl()
class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
    """
    Run correctness checks in the multi-proc runner; each test is marked with the
    minimum # of GPUs it needs to run under.
    """

    def get_world_trs(self):
        return {
            "tag": "",
            "ranks": list(range(self.world_size)),
            "group_size": self.world_size,
        }

    @property
    def world_size(self) -> int:
        # hack: no matter whether we have 2, 3, or 4 GPUs, just run on 2;
        # works around an issue with skipif<2 and workers with an unpredictable # of GPUs
        return 2

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_locality", False)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "sink_waits",
        ],
    )
    def test_sink_waits(self):
        def func(a):
            ar = _functional_collectives.all_reduce(a, "sum", "0")
            b = torch.matmul(a, a)
            return torch.matmul(ar, b)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs)
            # Verify that the wait_tensor is sunk below the 1st matmul but
            # above the 2nd matmul.
            (
                FileCheck()
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs)
            correct = func(inputs)
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_locality", False)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "raise_comms",
        ],
    )
    def test_raise_comms(self):
        def func(a):
            b = torch.matmul(a, a)
            c = torch.relu(b)
            d = torch.matmul(c, c)
            e = _functional_collectives.all_reduce(b, "sum", "0")
            return torch.matmul(d, e)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs)
            print(code)
            # Verify that the all_reduce_ has been raised above the 2nd matmul
            # but below the 1st matmul. Note that the all_reduce_ directly
            # writes to the output buffer of the 1st matmul, which is an input
            # to the first relu. Therefore, the all_reduce_ should be scheduled
            # after the first relu.
            (
                FileCheck()
                .check("extern_kernels.mm")
                .check("triton_poi_fused_relu")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs)
            correct = func(inputs)
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "sink_waits",
            "raise_comms",
        ],
    )
    def test_sink_waits_raise_comms(self):
        def func(a, *, tag, ranks, group_size):
            b = torch.matmul(a, a)
            c = torch.relu(b)
            d = torch.matmul(c, c)
            e = _functional_collectives.all_reduce(b, "sum", "0")
            f = torch.relu(d)
            g = torch.matmul(f, f)
            return torch.mm(e, g)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # Things to verify:
            # - The clone prologue of the all_reduce_ should not be fused with
            #   any relus.
            # - The all_reduce_ and its prologue should be raised above the 2nd
            #   matmul but below the 1st matmul.
            # - The wait_tensor should be sunk below the 3rd matmul but above
            #   the 4th matmul.
            (
                FileCheck()
                .check("extern_kernels.mm")
                .check("triton_poi_fused_all_reduce_0")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "reorder_compute_for_overlap",
        ],
    )
    def test_reorder_compute_for_overlap(self):
        def func(a, *, tag, ranks, group_size):
            ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
            g = torch.matmul(a, a)
            c = torch.relu(a)
            d = torch.matmul(c, c)
            f = d * c * ar
            fr = _functional_collectives.all_reduce(f, "sum", ranks, tag)
            e = torch.matmul(d + ar + fr, g)
            return (e,)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # NOTE: after scheduling the first all_reduce:
            # 1. we first schedule the ops (c and d) that ARE required for the second all_reduce but DO NOT depend on the first all_reduce.
            # 2. then, we schedule the ops (g) that ARE NOT required for the second all_reduce and DO NOT depend on the first all_reduce.
            # 3. then, we schedule the ops (f) that ARE required for the second all_reduce and DO depend on the first all_reduce.
            # 4. finally, we schedule the second all_reduce, followed by all ops that depend on it.
            (
                FileCheck()
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_mul")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_add")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "reorder_compute_for_overlap",
        ],
    )
    @patch.object(
        torch._inductor.config,
        "estimate_op_runtime",
        get_snode_runtime_for_reorder_compute_test,
    )
    def test_reorder_compute_for_overlap_custom_runtime_estimation(self):
        def func(a, *, tag, ranks, group_size):
            ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
            g = torch.matmul(a, a)
            c = torch.relu(a)
            d = torch.matmul(c, c)
            f = d * c * ar
            fr = _functional_collectives.all_reduce(f, "sum", ranks, tag)
            e = torch.matmul(d + ar + fr, g)
            return (e,)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # NOTE: after scheduling the first all_reduce:
            # 1. we first schedule the ops (c and d) that ARE required for the second all_reduce but DO NOT depend on the first all_reduce.
            # 2. then, we schedule the ops (g) that ARE NOT required for the second all_reduce and DO NOT depend on the first all_reduce.
            # 3. then, we schedule the ops (f) that ARE required for the second all_reduce and DO depend on the first all_reduce.
            # 4. finally, we schedule the second all_reduce, followed by all ops that depend on it.
            (
                FileCheck()
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_mul")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_add")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(
        torch._inductor.config,
        "_pre_fusion_custom_pass",
        create_grouped_node_for_allreduce_and_its_deps,
    )
    def test_grouped_scheduler_node(self):
        def func(a, *, tag, ranks, group_size):
            add = a + a
            div = add / a
            ar = _functional_collectives.all_reduce(div, "sum", ranks, tag)
            # Normally, we would fuse `add = a + a`, `div = add / a` and `mul = a * a` together into a single fused op,
            # but in this unit test we intentionally put the `add`, `div` and `ar` computations
            # into a GroupedSchedulerNode, which prevents them from being fused with any other ops.
            mul = a * a
            mm = torch.matmul(mul, ar)
            return (mm,)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # Expectations:
            # 1. `add = a + a` and `div = add / a` are still fused, which means fusion
            #    still happens among nodes within a GroupedSchedulerNode.
            # 2. `mul = a * a` is not fused with `add` or `div`, because the latter two are within
            #    GroupedSchedulerNode and thus are prevented from being fused with any outside ops.
            FileCheck().check("triton_poi_fused_add_div_0.").check(
                "_c10d_functional.all_reduce_."
            ).check("triton_poi_fused_mul_1.").run(code)
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    def test_nccl_heuristics(self):
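        # Sanity-check that the NCCL latency/bandwidth tables in
        # torch._inductor.comm_analysis stay in sync with the enums that index them.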
        assert len(baseLat) == len(NCCL_ALGO)
        assert all(len(x) == len(NCCL_PROTO) for x in baseLat)

        assert len(hwLat) == len(NCCL_HW)
        assert all(len(x) == len(NCCL_ALGO) for x in hwLat)
        assert all(len(y) == len(NCCL_PROTO) for x in hwLat for y in x)

        assert len(llMaxBws) == len(NVIDIA_GPU_TYPE)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()