# Owner(s): ["module: c10d"] import threading import unittest from typing import List import torch import torch.distributed as dist import torch.distributed._functional_collectives as funcol from torch._C import FileCheck from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code from torch.distributed._functional_collectives import ( all_gather_into_tensor_coalesced, all_gather_tensor, all_reduce, all_reduce_coalesced, all_to_all_single, AsyncCollectiveTensor, reduce_scatter_tensor, reduce_scatter_tensor_coalesced, ) from torch.testing._internal.common_distributed import ( MultiProcessTestCase, requires_nccl, skip_if_lt_x_gpu, ) from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] run_tests, TestCase, ) from torch.testing._internal.distributed.fake_pg import FakeStore from torch.utils._triton import has_triton def load_test_module(name): import sys from importlib.machinery import SourceFileLoader from pathlib import Path from unittest import mock testdir = Path(__file__).absolute().parent.parent with mock.patch("sys.path", [*sys.path, str(testdir)]): return SourceFileLoader( name, str(testdir / f"{name.replace('.', '/')}.py") ).load_module() AOTIRunnerUtil = load_test_module("inductor.test_aot_inductor_utils").AOTIRunnerUtil import sys if not dist.is_available(): print("distributed package not available, skipping tests", file=sys.stderr) sys.exit(0) @requires_nccl() class TestWithNCCL(MultiProcessTestCase): def setUp(self) -> None: super().setUp() self._spawn_processes() @property def world_size(self) -> int: return 2 @property def ranks(self) -> List[int]: return list(range(self.world_size)) @property def device(self) -> torch.device: return torch.device(f"cuda:{self.rank}") def _init_process_group(self) -> None: # Allow testing aoti after torch.compile torch._inductor.config.triton.store_cubin = True torch._inductor.config.debug = True torch.cuda.set_device(self.device) store = dist.FileStore(self.file_name, self.world_size) dist.init_process_group( backend="nccl", world_size=self.world_size, rank=self.rank, store=store, ) torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD) @skip_if_lt_x_gpu(2) def test_all_reduce_single(self) -> None: self._init_process_group() input = torch.full((10, 10), float(self.rank), device=self.device) output = torch.ops._c10d_functional.all_reduce( input, "avg", "default", ) output = torch.ops._c10d_functional.wait_tensor(output) assert id(output) != id(input) expect = sum(self.ranks) / self.world_size assert output.eq(expect).all() # Test Python API and AsyncCollectiveTensor output = all_reduce( input, "avg", "default", ) assert isinstance(output, AsyncCollectiveTensor) assert not output.completed assert output.eq(expect).all() assert output.completed @skip_if_lt_x_gpu(2) def test_all_reduce_single_(self) -> None: self._init_process_group() input = torch.full((10, 10), float(self.rank), device=self.device) output = torch.ops._c10d_functional.all_reduce_( input, "avg", "default", ) output = torch.ops._c10d_functional.wait_tensor(output) assert id(output) == id(input) expect = sum(self.ranks) / self.world_size assert output.eq(expect).all() @skip_if_lt_x_gpu(2) def test_all_reduce_coalesced(self) -> None: self._init_process_group() inputs = [ torch.full((i, i), float(self.rank * i), device=self.device) for i in range(10) ] outputs = torch.ops._c10d_functional.all_reduce_coalesced( inputs, "avg", "default", ) for i, (output, input) in enumerate(zip(outputs, inputs)): output = 
torch.ops._c10d_functional.wait_tensor(output) assert id(output) != id(input) assert output.eq(sum(self.ranks) / self.world_size * i).all() # Test Python API and AsyncCollectiveTensor outputs = all_reduce_coalesced( inputs, "avg", "default", ) for i, (output, input) in enumerate(zip(outputs, inputs)): assert not output.completed assert output.eq(sum(self.ranks) / self.world_size * i).all() assert output.completed @skip_if_lt_x_gpu(2) def test_all_reduce_coalesced_(self) -> None: self._init_process_group() inputs = [ torch.full((i, i), float(self.rank * i), device=self.device) for i in range(10) ] outputs = torch.ops._c10d_functional.all_reduce_coalesced_( inputs, "avg", "default", ) for i, (output, input) in enumerate(zip(outputs, inputs)): output = torch.ops._c10d_functional.wait_tensor(output) assert id(output) == id(input) assert output.eq(sum(self.ranks) / self.world_size * i).all() @skip_if_lt_x_gpu(2) def test_all_gather_into_tensor_single(self) -> None: self._init_process_group() input = torch.full((10, 10), float(self.rank), device=self.device) output = torch.ops._c10d_functional.all_gather_into_tensor( input, self.world_size, "default", ) output = torch.ops._c10d_functional.wait_tensor(output) expect = torch.cat( [ torch.full((10, 10), float(rank), device=self.device) for rank in self.ranks ] ) assert torch.allclose(output, expect) assert output.eq(expect).all() # Test out-variant of all_gather_into_tensor output = torch.empty(expect.shape, device=self.device) output = torch.ops._c10d_functional.all_gather_into_tensor_out( input, self.world_size, "default", out=output, ) output = torch.ops._c10d_functional.wait_tensor(output) assert torch.allclose(output, expect) assert output.eq(expect).all() # Test Python API and AsyncCollectiveTensor output = all_gather_tensor( input, 0, "default", ) assert isinstance(output, AsyncCollectiveTensor) assert not output.completed assert output.eq(expect).all() assert output.completed @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) # https://github.com/pytorch/pytorch/issues/126338 def test_inductor_dtypeview_memory_leak(self): self._init_process_group() def func(arg: torch.Tensor) -> torch.Tensor: ag0 = torch.ops._c10d_functional.all_gather_into_tensor.default( arg, self.world_size, "default", ) ag0_view = torch.ops.aten.view.dtype(ag0, torch.int32) return funcol.wait_tensor(ag0_view) arg = torch.full( (10, 10), float(self.rank), device=self.device, dtype=torch.float32, ) compiled = torch.compile(func) mem_usage = {} # check if the aten.view.dtype is compiled to aten.view.dtype code = run_and_get_triton_code(compiled, arg) ( FileCheck() .check("torch.ops._c10d_functional.wait_tensor.default(aten.view.dtype") .run(code) ) # check memory leak for i in range(1, 10): mem_usage[i] = torch.cuda.max_memory_allocated() compiled(arg) assert mem_usage[9] == mem_usage[8] @skip_if_lt_x_gpu(2) def test_all_gather_into_tensor_coalesced(self) -> None: self._init_process_group() inputs = [ torch.full((10, 10), float(self.rank * i), device=self.device) for i in range(10) ] outputs = torch.ops._c10d_functional.all_gather_into_tensor_coalesced( inputs, self.world_size, "default", ) expect = [ torch.cat( [ torch.full((10, 10), float(rank) * i, device=self.device) for rank in self.ranks ] ) for i in range(10) ] for i, output in enumerate(outputs): output = torch.ops._c10d_functional.wait_tensor(output) assert output.eq(expect[i]).all() # Test Python API and AsyncCollectiveTensor outputs = 
all_gather_into_tensor_coalesced( inputs, "default", ) for i, output in enumerate(outputs): assert not output.completed assert output.eq(expect[i]).all() assert output.completed @skip_if_lt_x_gpu(2) def test_reduce_scatter_tensor_single(self) -> None: self._init_process_group() input = torch.tensor(self.ranks, device=self.device) output = torch.ops._c10d_functional.reduce_scatter_tensor( input, "avg", self.world_size, "default", ) output = torch.ops._c10d_functional.wait_tensor(output) assert output.eq(self.rank).all() # Test Python API and AsyncCollectiveTensor output = reduce_scatter_tensor( input, "avg", 0, "default", ) assert isinstance(output, AsyncCollectiveTensor) assert not output.completed assert output.eq(self.rank).all() assert output.completed @skip_if_lt_x_gpu(2) def test_reduce_scatter_tensor_coalesced(self) -> None: self._init_process_group() inputs = [torch.tensor(self.ranks, device=self.device) * i for i in range(10)] outputs = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced( inputs, "avg", self.world_size, "default", ) for i, output in enumerate(outputs): output = torch.ops._c10d_functional.wait_tensor(output) assert output.eq(self.rank * i).all() # Test Python API and AsyncCollectiveTensor outputs = reduce_scatter_tensor_coalesced( inputs, "avg", [0] * 10, "default", ) for i, output in enumerate(outputs): assert not output.completed assert output.eq(self.rank * i).all() assert output.completed @skip_if_lt_x_gpu(2) def test_all_to_all_single(self) -> None: self._init_process_group() torch.cuda.set_device(self.device) torch.manual_seed(42) send_sz_matrix = torch.randint(0, 20, (self.world_size, self.world_size)) input_split_sizes = send_sz_matrix[self.rank].tolist() output_split_sizes = send_sz_matrix[:, self.rank].tolist() input = torch.full((sum(input_split_sizes),), float(self.rank)).cuda() output = torch.ops._c10d_functional.all_to_all_single( input, output_split_sizes, input_split_sizes, "default", ) output = torch.ops._c10d_functional.wait_tensor(output) expect = torch.cat( [ torch.full((sz,), float(rank)).cuda() for rank, sz in enumerate(output_split_sizes) ] ) assert output.eq(expect).all() # Test Python API and AsyncCollectiveTensor output = all_to_all_single( input, output_split_sizes, input_split_sizes, "default" ) assert not output.completed assert output.eq(expect).all() assert output.completed @skip_if_lt_x_gpu(2) def test_broadcast(self) -> None: self._init_process_group() input = torch.full((10, 10), float(self.rank), device=self.device) output = torch.ops._c10d_functional.broadcast( input, 1, "default", ) output = torch.ops._c10d_functional.wait_tensor(output) assert id(output) != id(input) expect = 1 assert output.eq(expect).all() # Test Python API and AsyncCollectiveTensor output = funcol.broadcast( input, 1, "default", ) assert isinstance(output, AsyncCollectiveTensor) assert not output.completed assert output.eq(expect).all() assert output.completed @skip_if_lt_x_gpu(2) def test_unwaited(self) -> None: # Verify that the process can terminate gracefully # even with unwaited tensors self._init_process_group() input = torch.full((10, 10), float(self.rank), device=self.device) output = torch.ops._c10d_functional.all_reduce( input, "avg", "default", ) @skip_if_lt_x_gpu(2) def test_py_work(self) -> None: self._init_process_group() wait_called = False class MyWork(dist.Work): def wait(self, _): nonlocal wait_called wait_called = True tensor = torch.rand(2, 2) torch._C._distributed_c10d._register_work(tensor, MyWork()) 
        torch.ops._c10d_functional.wait_tensor(tensor)
        self.assertTrue(wait_called)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @fresh_inductor_cache()
    def test_threading(self):
        self._init_process_group()
        device = torch.device(f"cuda:{self.rank}")

        def func(arg: torch.Tensor) -> torch.Tensor:
            buf0 = arg + 42
            ar0 = funcol.all_reduce(buf0, "avg", "0")
            ar0 = funcol.wait_tensor(ar0)
            return ar0 + 1

        arg = torch.rand(4, 4, device=device)
        func(arg)

        compiled = torch.compile(func, fullgraph=True)
        code = run_and_get_triton_code(compiled, arg)
        FileCheck().check("all_reduce_.default(buf0, 'avg', '0')").run(code)

        # Unless explicitly specified (e.g. in a custom runtime), the process
        # group registry is shared among all threads in a process. Here we
        # verify that a process group registered in main thread can be resolved
        # in a different thread.
        class TestThread(threading.Thread):
            def run(self):
                self.exc = None
                try:
                    func(arg)
                    compiled(arg)
                except BaseException as exc:
                    self.exc = exc

            def join(self):
                threading.Thread.join(self)
                if self.exc:
                    raise self.exc

        t = TestThread()
        t.start()
        t.join()


class CompileTest(TestCase):
    def setUp(self):
        # Allow testing aoti after torch.compile
        torch._inductor.config.triton.store_cubin = True
        torch._inductor.config.debug = True

        self.rank = 0
        self.world_size = 2
        torch.cuda.set_device("cuda:0")

        store = FakeStore()
        dist.init_process_group(
            backend="fake",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )

    def tearDown(self):
        dist.destroy_process_group()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_all_reduce_single(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            buf0 = arg + 42
            # Expect in-place with inductor allocated buf
            ar0 = funcol.all_reduce(buf0, "avg", "0")
            ar0 = funcol.wait_tensor(ar0)
            # Expect no in-place with graph input
            ar1 = funcol.all_reduce(arg, "avg", "0")
            ar1 = funcol.wait_tensor(ar1)
            return ar0, ar1

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check("buf0 = empty")
            .check("buf7 = empty")
            # Expect in-place with inductor allocated buf
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no in-place with graph input (buf5 is a clone)
            .check("torch.ops._c10d_functional.all_reduce_.default(buf7")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf7")
            # Expect no extra copy on return
            .check("return (buf0, buf7, )")
            .run(code)
        )
        assert "= torch.ops._c10d_functional.wait_tensor.default" not in code

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_all_reduce_coalesced(self):
        def func(args: List[torch.Tensor]) -> torch.Tensor:
            bufs = [arg + 42 for arg in args]
            # Expect in-place with inductor allocated buf
            ar0 = funcol.all_reduce_coalesced(bufs, "avg", "0")
            ar0 = [funcol.wait_tensor(out) for out in ar0]
            # Expect no in-place with graph input
            ar1 = funcol.all_reduce_coalesced(args, "avg", "0")
            ar1 = [funcol.wait_tensor(out) for out in ar1]
            return ar0, ar1

        args = [torch.rand(4, 4, device="cuda") for _ in range(2)]
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, args)
        (
            FileCheck()
            .check("buf0 = empty")
            .check("buf5 = empty")
            .check("buf1 = empty")
            .check("buf6 = empty")
            # Expect in-place with inductor allocated buf
            .check(
                "torch.ops._c10d_functional.all_reduce_coalesced_"
                ".default([buf0, buf1]"
            )
            # Expect no in-place with graph input (buf5, buf6 are clones)
            .check(
                "torch.ops._c10d_functional.all_reduce_coalesced_"
                ".default([buf5, buf6]"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf1")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf5")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf6")
            # Expect no extra copy on return
            .check("return (buf0, buf1, buf5, buf6, )")
            .run(code)
        )
        assert "= torch.ops._c10d_functional.wait_tensor.default" not in code

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (args,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_inplace_op_on_view(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            buf0 = (arg + 10)[:2]
            ar0 = funcol.all_reduce(buf0, "avg", "0")
            ar0 = funcol.wait_tensor(ar0)
            return ar0

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check("buf0 = empty")
            # Ensure the all_reduce_ input is a view
            .check(
                "torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf0"
            )
            .check(
                "torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf0"
            )
            .check("return (reinterpret_tensor(buf0")
            .run(code)
        )

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_reuse_buffer_after_inplace_collective(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            # Expect allocation
            buf0 = arg + 42
            ar0 = funcol.all_reduce(buf0, "avg", "0")
            ar0 = funcol.wait_tensor(ar0)
            # Expect allocation
            buf1 = torch.mm(arg, ar0)
            # Expect buf0 to be reused
            buf2 = torch.mm(arg, buf1)
            return buf1, buf2

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            # Expect allocation
            .check("buf0 = empty")
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect allocation
            .check("buf7 = empty")
            .check("extern_kernels.mm(arg0_1, buf0, out=buf7")
            # Expect buf0 to be reused
            .check("buf8 = buf0; del buf0 # reuse")
            .check("extern_kernels.mm(arg0_1, buf7, out=buf8")
            # Expect no extra copy on return
            .check("return (buf7, buf8, )")
            .run(code)
        )
        assert "= torch.ops._c10d_functional.wait_tensor.default" not in code

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_all_gather_into_tensor_single(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            ag0 = funcol.all_gather_tensor(arg, 0, "0")
            ag0 = funcol.wait_tensor(ag0)
            return ag0

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(arg0_1"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no extra copy on return
            .check("return (buf0, )")
            .run(code)
        )
        assert "= torch.ops._c10d_functional.wait_tensor.default" not in code

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_all_gather_into_tensor_coalesced(self):
        def func(args: List[torch.Tensor]) -> torch.Tensor:
            ag0 = funcol.all_gather_into_tensor_coalesced(args, "0")
            ag0 = [funcol.wait_tensor(out) for out in ag0]
            return ag0

        args = [torch.rand(4, 4, device="cuda") for _ in range(4)]
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, args)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced"
                ".default([arg0_1, arg1_1, arg2_1, arg3_1]"
            )
            .check("buf1 = buf0[0]")
            .check("buf2 = buf0[1]")
            .check("buf3 = buf0[2]")
            .check("buf4 = buf0[3]")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf1")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf4")
            # Expect no extra copy on return
            .check("return (buf1, buf2, buf3, buf4, )")
            .run(code)
        )

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (args,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_reduce_scatter_tensor_single(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            rs0 = funcol.reduce_scatter_tensor(arg, "avg", 0, "0")
            rs0 = funcol.wait_tensor(rs0)
            return rs0

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.reduce_scatter_tensor.default(arg0_1"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no extra copy on return
            .check("return (buf0, )")
            .run(code)
        )

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_reduce_scatter_tensor_coalesced(self):
        def func(args: List[torch.Tensor]) -> torch.Tensor:
            rs0 = funcol.reduce_scatter_tensor_coalesced(
                args, "avg", [0] * len(args), "0"
            )
            rs0 = [funcol.wait_tensor(out) for out in rs0]
            return rs0

        args = [torch.rand(4, 4, device="cuda") for _ in range(4)]
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, args)
        (
            FileCheck()
            .check(
                "buf0 = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced"
                ".default([arg0_1, arg1_1, arg2_1, arg3_1]"
            )
            .check("buf1 = buf0[0]")
            .check("buf2 = buf0[1]")
            .check("buf3 = buf0[2]")
            .check("buf4 = buf0[3]")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf1")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf4")
            # Expect no extra copy on return
            .check("return (buf1, buf2, buf3, buf4, )")
            .run(code)
        )

        # Test aoti
        AOTIRunnerUtil.run("cuda", func, (args,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_all_to_all_single(self):
        def _tolist_with_constrain_as_size(tensor):
            lst = tensor.tolist()
            for elem in lst:
                torch._check_is_size(elem)
            return lst

        def func(
            input: torch.Tensor,
            output_split_sizes: torch.Tensor,
            input_split_sizes: torch.Tensor,
        ) -> torch.Tensor:
            output = funcol.all_to_all_single(
                input,
                _tolist_with_constrain_as_size(output_split_sizes),
                _tolist_with_constrain_as_size(input_split_sizes),
                "0",
            )
            return funcol.wait_tensor(output)

        torch.manual_seed(42)
        send_sz_matrix = torch.randint(0, 20, (self.world_size, self.world_size))

        input_split_sizes = send_sz_matrix[self.rank]
        output_split_sizes = send_sz_matrix[:, self.rank].contiguous()
        input = torch.full((input_split_sizes.sum().item(),), float(self.rank)).cuda()

        with torch._dynamo.config.patch(
            dynamic_shapes=True,
            capture_dynamic_output_shape_ops=True,
            capture_scalar_outputs=True,
        ):
            compiled = torch.compile(func, dynamic=True)
            code = run_and_get_triton_code(
                compiled, input, output_split_sizes, input_split_sizes
            )
        (
            FileCheck()
            .check_regex(
                "torch.ops._c10d_functional.all_to_all_single.default\\("
                "arg\\d+_\\d+, \\[u\\d+, u\\d+\\], \\[u\\d+, u\\d+\\]"
            )
            .check("torch.ops._c10d_functional.wait_tensor.default(")
            .run(code)
        )

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_inductor_broadcast(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            buf0 = arg + 42
            # Expect in-place with inductor allocated buf
            br0 = funcol.broadcast(buf0, 1, "0")
            br0 = funcol.wait_tensor(br0)
            # Expect no in-place with graph input
            br1 = funcol.broadcast(arg, 0, "0")
            br1 = funcol.wait_tensor(br1)
            return br0, br1

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, arg)
        (
            FileCheck()
            .check("buf0 = empty")
            .check("buf7 = empty")
            # Expect in-place with inductor allocated buf
            .check("torch.ops._c10d_functional.broadcast_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            # Expect no in-place with graph input (buf5 is a clone)
            .check("torch.ops._c10d_functional.broadcast_.default(buf7")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf7")
            # Expect no extra copy on return
            .check("return (buf0, buf7, )")
            .run(code)
        )

        # Test aoti
        out = AOTIRunnerUtil.run("cuda", func, (arg,))
        torch.cuda.synchronize()

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @fresh_inductor_cache()
    def test_ranks_and_tag(self):
        def func(arg: torch.Tensor) -> torch.Tensor:
            buf0 = arg + 42
            # Expect in-place with inductor allocated buf
            ar0 = funcol.all_reduce(buf0, "avg", [0, 1], "")
            ar0 = funcol.wait_tensor(ar0)
            # Expect no in-place with graph input
            ar1 = funcol.all_reduce(arg, "avg", [0, 1], "")
            ar1 = funcol.wait_tensor(ar1)
            return ar0, ar1

        arg = torch.rand(4, 4, device="cuda")
        compiled = torch.compile(func, fullgraph=True)
        code = run_and_get_triton_code(compiled, arg)
        (FileCheck().check("all_reduce_.default(buf0, 'avg', '0')").run(code))


if __name__ == "__main__":
    run_tests()