# Owner(s): ["module: inductor"] import contextlib from unittest import skipIf import torch import torch.distributed as dist from torch._inductor import config, metrics from torch._inductor.comm_analysis import estimate_nccl_collective_runtime from torch._inductor.compile_fx import compile_fx, compile_fx_inner from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import is_collective from torch.testing._internal.inductor_utils import HAS_CUDA aten = torch.ops.aten c10d = torch.ops.c10d_functional _c10d = torch.ops._c10d_functional def compile_but_use_eager(gm, example_inputs): def inner_compile(gm, *args, **kwargs): compile_fx_inner(gm, *args, **kwargs) return gm return compile_fx(gm, example_inputs, inner_compile=inner_compile) def calculate_runtime(f, *args) -> float: """ Assumes all inputs are fp32 """ metrics.reset() torch.compile(f, backend=compile_but_use_eager)(*args) print(metrics.node_runtimes) ret = 0.0 for pair in metrics.node_runtimes: ret += pair[1] return ret DEVICE = "cuda" def T(*size, dtype=torch.float32, device=DEVICE, grad=False) -> torch.Tensor: return torch.randn(size, dtype=dtype, device=device, requires_grad=grad) class TestCase(InductorTestCase): device = DEVICE """ Helper methods to compare runtime estimate against 0. Since this estimate is hardware dependent, stronger comparisons may fail dependending on the host's specs. atol/rtol must be provided explicitly with each call, since precision/rel_tol overrides are not always utilized """ def setUp(self): super().setUp() # These tests check metrics.node_runtimes and we don't save / restore # those in the FX graph cache. self._test_snode_stack = contextlib.ExitStack() self._test_snode_stack.enter_context( config.patch({"fx_graph_remote_cache": False}) ) def tearDown(self): self._test_snode_stack.close() super().tearDown() def assertZero(self, x: float): assert isinstance(x, float) super().assertEqual(x, 0.0, atol=0, rtol=0) def assertNotZero(self, x): assert isinstance(x, float) super().assertNotEqual(x, 0.0, atol=0, rtol=0) class UnsupportedTests(TestCase): def test_no_op(self): def f(a): return a inp = (T(10, 10),) self.assertZero(calculate_runtime(f, *inp)) def test_no_cuda(self): def f(a): return a inp = (torch.randn((10, 10), device="cpu"),) self.assertZero(calculate_runtime(f, *inp)) class ComputeBoundedTests(TestCase): def test_conv1d(self): def f(x, y): return torch.nn.functional.conv1d(x, y) inp = (T(33, 16, 30), T(20, 16, 5)) self.assertNotZero(calculate_runtime(f, *inp)) def test_conv2d(self): def f(x, y): return torch.nn.functional.conv2d(x, y, padding=1) inp = (T(8, 4, 3, 3), T(1, 4, 5, 5)) self.assertNotZero(calculate_runtime(f, *inp)) def test_conv2d_transpose(self): def f(x, y): return torch.nn.functional.conv_transpose2d(x, y, padding=1) inp = (T(8, 1, 1, 1), T(1, 4, 5, 5)) self.assertNotZero(calculate_runtime(f, *inp)) def test_conv3d(self): def f(x, y): return torch.nn.functional.conv3d(x, y) inp = (T(20, 16, 50, 10, 20), T(33, 16, 3, 3, 3)) self.assertNotZero(calculate_runtime(f, *inp)) def test_mm(self): def f(a, b): return torch.mm(a, b) inp = ( T(10, 10), T(10, 10), ) self.assertNotZero(calculate_runtime(f, *inp)) def test_addmm(self): def f(a, b, c): return torch.addmm(a, b, c) inp = ( T(10, 10), T(10, 10), T(10, 10), ) self.assertNotZero(calculate_runtime(f, *inp)) def test_bmm(self): def f(a, b): return torch.bmm(a, b) inp = ( T(10, 10, 10), T(10, 10, 10), ) self.assertNotZero(calculate_runtime(f, *inp)) class MemoryBoundedTests(TestCase): def 
    def test_relu(self):
        def f(a):
            return torch.nn.functional.relu(a)

        inp = (T(10, 10),)
        self.assertNotZero(calculate_runtime(f, *inp))

    def test_horizontal_reduction_pointwise(self):
        def f(a):
            b = a.sum(dim=1)
            c = a.cos()
            return b, c

        inp = (T(10, 10),)
        self.assertNotZero(calculate_runtime(f, *inp))

    def test_pointwise(self):
        def f(x):
            return x.cos()

        inp = (T(10),)
        self.assertNotZero(calculate_runtime(f, *inp))

    @torch._dynamo.config.patch(assume_static_by_default=False)
    def test_dynamic(self):
        def f(x):
            return x.cos()

        inp = (T(10),)
        self.assertNotZero(calculate_runtime(f, *inp))


@skipIf(not dist.is_available(), "requires distributed")
class TestCommAnalysis(TestCase):
    WORLD_SIZE: int = 8
    RANKS = list(range(8))

    def _verify_runtime_estimation(self, fn, inps):
        from torch.testing._internal.distributed.fake_pg import FakeStore

        store = FakeStore()
        dist.init_process_group(
            backend="fake", rank=0, world_size=self.WORLD_SIZE, store=store
        )
        try:
            metrics.reset()
            torch.compile(fn)(*inps)
            found_collective = False
            for snode, runtime in metrics.node_runtimes:
                if not is_collective(snode.node):
                    continue
                found_collective = True
                # Inductor swallows errors from snode runtime estimations.
                # We call estimate_nccl_collective_runtime in a white-box
                # fashion here so potential issues can be surfaced in tests.
                est = estimate_nccl_collective_runtime(snode.node)
                self.assertNotZero(est)
                # Also make sure estimate_nccl_collective_runtime works
                # correctly in inductor.
                self.assertNotZero(runtime)
            # Make sure a collective kernel is found in graph
            self.assertTrue(found_collective)
        finally:
            dist.destroy_process_group()

    def test_legacy_all_reduce(self):
        def fn(x):
            r = c10d.all_reduce(x, "sum", "", self.RANKS, self.WORLD_SIZE)
            return c10d.wait_tensor(r)

        inp = T(10, 10)
        self._verify_runtime_estimation(fn, (inp,))

    def test_legacy_all_reduce_coalesced(self):
        def fn(x):
            rs = c10d.all_reduce_coalesced(x, "sum", "", self.RANKS, self.WORLD_SIZE)
            return [c10d.wait_tensor(r) for r in rs]

        inp = [T(10, 10), T(15, 15)]
        self._verify_runtime_estimation(fn, (inp,))

    def test_legacy_all_gather_into_tensor_coalesced(self):
        def fn(x):
            rs = c10d.all_gather_into_tensor_coalesced(
                x,
                "",
                self.RANKS,
                self.WORLD_SIZE,
            )
            return [c10d.wait_tensor(r) for r in rs]

        inp = [T(10, 10), T(15, 15)]
        self._verify_runtime_estimation(fn, (inp,))

    def test_all_reduce(self):
        def fn(x):
            r = _c10d.all_reduce(x, "sum", "0")
            return _c10d.wait_tensor(r)

        inp = T(10, 10)
        self._verify_runtime_estimation(fn, (inp,))

    def test_all_reduce_coalesced(self):
        def fn(x):
            rs = _c10d.all_reduce_coalesced(x, "sum", "0")
            return [_c10d.wait_tensor(r) for r in rs]

        inp = [T(10, 10), T(15, 15)]
        self._verify_runtime_estimation(fn, (inp,))

    def test_all_gather_into_tensor(self):
        def fn(x):
            rs = _c10d.all_gather_into_tensor(
                x,
                self.WORLD_SIZE,
                "0",
            )
            return [_c10d.wait_tensor(r) for r in rs]

        inp = T(10, 10)
        self._verify_runtime_estimation(fn, (inp,))

    def test_all_gather_into_tensor_coalesced(self):
        def fn(x):
            rs = _c10d.all_gather_into_tensor_coalesced(
                x,
                self.WORLD_SIZE,
                "0",
            )
            return [_c10d.wait_tensor(r) for r in rs]

        inp = [T(10, 10), T(15, 15)]
        self._verify_runtime_estimation(fn, (inp,))

    def test_reduce_scatter_tensor(self):
        def fn(x):
            rs = _c10d.reduce_scatter_tensor(
                x,
                "sum",
                self.WORLD_SIZE,
                "0",
            )
            return [_c10d.wait_tensor(r) for r in rs]

        inp = T(self.WORLD_SIZE, 10)
        self._verify_runtime_estimation(fn, (inp,))

    def test_reduce_scatter_tensor_coalesced(self):
        def fn(x):
            rs = _c10d.reduce_scatter_tensor_coalesced(
                x,
                "sum",
                self.WORLD_SIZE,
                "0",
            )
            return [_c10d.wait_tensor(r) for r in rs]
        inp = [T(self.WORLD_SIZE, 10), T(self.WORLD_SIZE, 15)]
        self._verify_runtime_estimation(fn, (inp,))


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CUDA:
        run_tests(needs="filelock")