# Owner(s): ["module: inductor"]
import unittest
from unittest.mock import patch

import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case

# for some reason importing functional collectives after dynamo breaks collectives handling!
import torch.distributed._functional_collectives as _functional_collectives
from torch._C import FileCheck
from torch._dynamo.utils import same
from torch._inductor import ir, scheduler
from torch._inductor.comm_analysis import (
    baseLat,
    hwLat,
    llMaxBws,
    NCCL_ALGO,
    NCCL_HW,
    NCCL_PROTO,
    NVIDIA_GPU_TYPE,
)
from torch._inductor.utils import run_and_get_triton_code
from torch.testing._internal.common_distributed import (
    _dynamo_dist_per_rank_init,
    at_least_x_gpu,
    DynamoDistributedMultiProcTestCase,
    requires_nccl,
)
from torch.utils._triton import has_triton


def get_snode_runtime_for_reorder_compute_test(snode):
    # NOTE: custom cost model to show that the compute reordering algorithm is working
    # Collective kernels
    if isinstance(snode.node, ir._CollectiveKernel):
        return 100
    elif isinstance(snode.node, ir._WaitKernel):
        return 0
    # High-arithmetic-intensity compute kernels
    elif isinstance(snode.node, ir.ExternKernel):
        return 5
    # All other kernels
    return 1


def create_grouped_node_for_allreduce_and_its_deps(snodes):
    name_to_snode = {snode.node.name: snode for snode in snodes}
    all_reduce_snodes = [
        snode
        for snode in snodes
        if isinstance(snode.node, ir._CollectiveKernel)
        and snode.node.op_overload == torch.ops._c10d_functional.all_reduce_.default
    ]
    assert len(all_reduce_snodes) == 1
    all_reduce_snode = all_reduce_snodes[0]
    all_reduce_dep_snodes = [
        name_to_snode[node.name] for node in all_reduce_snode.node.inputs
    ]
    assert len(all_reduce_dep_snodes) == 1
    all_reduce_dep_snode = all_reduce_dep_snodes[0]

    grouped_snode = scheduler.GroupedSchedulerNode.create(
        [all_reduce_dep_snode, all_reduce_snode]
    )
    new_snode_order = []
    new_snode_order.append(grouped_snode)
    for snode in snodes:
        if snode in grouped_snode.snodes:
            continue
        new_snode_order.append(snode)
    return new_snode_order

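# NOTE: the two helpers above are wired into Inductor through config hooks that
# the tests below install with `unittest.mock.patch.object`:
#   - `get_snode_runtime_for_reorder_compute_test` is installed as
#     `torch._inductor.config.estimate_op_runtime`, the per-node cost model
#     consulted by the "reorder_compute_for_overlap" pass.
#   - `create_grouped_node_for_allreduce_and_its_deps` is installed as
#     `torch._inductor.config._pre_fusion_custom_pass`, which rewrites the
#     scheduler node order before fusion.
# As a rough sketch (not exercised here), the reordering passes under test can
# also be enabled directly through the same config knobs:
#
#   import torch._inductor.config as inductor_config
#   inductor_config.reorder_for_compute_comm_overlap = True
#   inductor_config.reorder_for_compute_comm_overlap_passes = [
#       "sink_waits",
#       "raise_comms",
#       "reorder_compute_for_overlap",
#   ]
#   compiled = torch.compile(func)
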
@requires_nccl()
class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
    """
    Run correctness checks in a multi-proc runner; mark tests with the minimum # of GPUs to run under.
    """

    def get_world_trs(self):
        return {
            "tag": "",
            "ranks": list(range(self.world_size)),
            "group_size": self.world_size,
        }

    @property
    def world_size(self) -> int:
        # hack: no matter whether we have 2 or 3 or 4 gpus, just run on 2
        # works around issue with skipif<2 and workers with an unpredictable # of GPUs
        return 2

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_locality", False)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "sink_waits",
        ],
    )
    def test_sink_waits(self):
        def func(a):
            ar = _functional_collectives.all_reduce(a, "sum", "0")
            b = torch.matmul(a, a)
            return torch.matmul(ar, b)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs)
            # Verify that the wait_tensor is sunk below the 1st matmul but
            # above the 2nd matmul.
            (
                FileCheck()
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs)
            correct = func(inputs)
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_locality", False)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "raise_comms",
        ],
    )
    def test_raise_comms(self):
        def func(a):
            b = torch.matmul(a, a)
            c = torch.relu(b)
            d = torch.matmul(c, c)
            e = _functional_collectives.all_reduce(b, "sum", "0")
            return torch.matmul(d, e)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs)
            print(code)
            # Verify that the all_reduce_ has been raised above the 2nd matmul
            # but below the 1st matmul. Note that the all_reduce_ directly
            # writes to the output buffer of the 1st matmul, which is an input
            # to the first relu. Therefore, the all_reduce_ should be scheduled
            # after the first relu.
            (
                FileCheck()
                .check("extern_kernels.mm")
                .check("triton_poi_fused_relu")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs)
            correct = func(inputs)
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "sink_waits",
            "raise_comms",
        ],
    )
    def test_sink_waits_raise_comms(self):
        def func(a, *, tag, ranks, group_size):
            b = torch.matmul(a, a)
            c = torch.relu(b)
            d = torch.matmul(c, c)
            e = _functional_collectives.all_reduce(b, "sum", "0")
            f = torch.relu(d)
            g = torch.matmul(f, f)
            return torch.mm(e, g)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # Things to verify:
            # - The clone prologue of the all_reduce_ should not be fused with
            #   any relus.
            # - The all_reduce_ and its prologue should be raised above the 2nd
            #   matmul but below the 1st matmul.
            # - The wait_tensor should be sunk below the 3rd matmul but above
            #   the 4th matmul.
            (
                FileCheck()
                .check("extern_kernels.mm")
                .check("triton_poi_fused_all_reduce_0")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "reorder_compute_for_overlap",
        ],
    )
    def test_reorder_compute_for_overlap(self):
        def func(a, *, tag, ranks, group_size):
            ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
            g = torch.matmul(a, a)
            c = torch.relu(a)
            d = torch.matmul(c, c)
            f = d * c * ar
            fr = _functional_collectives.all_reduce(f, "sum", ranks, tag)
            e = torch.matmul(d + ar + fr, g)
            return (e,)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # NOTE: after scheduling the first all_reduce:
            # 1. we first schedule the ops (c and d) that ARE required for the second
            #    all_reduce but DO NOT depend on the first all_reduce.
            # 2. then, we schedule the ops (g) that ARE NOT required for the second
            #    all_reduce and DO NOT depend on the first all_reduce.
            # 3. then, we schedule the ops (f) that ARE required for the second
            #    all_reduce and DO depend on the first all_reduce.
            # 4. finally, we schedule the second all_reduce, followed by all ops that
            #    depend on it.
            (
                FileCheck()
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_mul")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_add")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

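    # NOTE: the test below runs the same graph and expects the same schedule as
    # `test_reorder_compute_for_overlap`, but additionally swaps in
    # `get_snode_runtime_for_reorder_compute_test` as
    # `torch._inductor.config.estimate_op_runtime`, checking that a user-provided
    # cost model drives the "reorder_compute_for_overlap" pass to the same
    # expected ordering.
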
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "reorder_compute_for_overlap",
        ],
    )
    @patch.object(
        torch._inductor.config,
        "estimate_op_runtime",
        get_snode_runtime_for_reorder_compute_test,
    )
    def test_reorder_compute_for_overlap_custom_runtime_estimation(self):
        def func(a, *, tag, ranks, group_size):
            ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
            g = torch.matmul(a, a)
            c = torch.relu(a)
            d = torch.matmul(c, c)
            f = d * c * ar
            fr = _functional_collectives.all_reduce(f, "sum", ranks, tag)
            e = torch.matmul(d + ar + fr, g)
            return (e,)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # NOTE: after scheduling the first all_reduce:
            # 1. we first schedule the ops (c and d) that ARE required for the second
            #    all_reduce but DO NOT depend on the first all_reduce.
            # 2. then, we schedule the ops (g) that ARE NOT required for the second
            #    all_reduce and DO NOT depend on the first all_reduce.
            # 3. then, we schedule the ops (f) that ARE required for the second
            #    all_reduce and DO depend on the first all_reduce.
            # 4. finally, we schedule the second all_reduce, followed by all ops that
            #    depend on it.
            (
                FileCheck()
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_mul")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_add")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(
        torch._inductor.config,
        "_pre_fusion_custom_pass",
        create_grouped_node_for_allreduce_and_its_deps,
    )
    def test_grouped_scheduler_node(self):
        def func(a, *, tag, ranks, group_size):
            add = a + a
            div = add / a
            ar = _functional_collectives.all_reduce(div, "sum", ranks, tag)
            # Normally, we would fuse `add = a + a`, `div = add / a` and `mul = a * a`
            # together into a single fused op, but here in this unit test we
            # intentionally put the `add`, `div` and `ar` computation into a
            # GroupedSchedulerNode, which prevents them from being fused with any
            # other ops.
            mul = a * a
            mm = torch.matmul(mul, ar)
            return (mm,)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # Expectations:
            # 1. `add = a + a` and `div = add / a` are still fused, which means fusion
            #    still happens among nodes within a GroupedSchedulerNode.
            # 2. `mul = a * a` is not fused with `add` or `div`, because the latter
            #    two are within the GroupedSchedulerNode and thus are prevented from
            #    being fused with any outside ops.
            FileCheck().check("triton_poi_fused_add_div_0.").check(
                "_c10d_functional.all_reduce_."
            ).check("triton_poi_fused_mul_1.").run(code)
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    def test_nccl_heuristics(self):
        assert len(baseLat) == len(NCCL_ALGO)
        assert all(len(x) == len(NCCL_PROTO) for x in baseLat)

        assert len(hwLat) == len(NCCL_HW)
        assert all(len(x) == len(NCCL_ALGO) for x in hwLat)
        assert all(len(y) == len(NCCL_PROTO) for x in hwLat for y in x)

        assert len(llMaxBws) == len(NVIDIA_GPU_TYPE)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()