# Owner(s): ["module: inductor"]

import contextlib
from unittest import skipIf

import torch
import torch.distributed as dist
from torch._inductor import config, metrics
from torch._inductor.comm_analysis import estimate_nccl_collective_runtime
from torch._inductor.compile_fx import compile_fx, compile_fx_inner
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import is_collective
from torch.testing._internal.inductor_utils import HAS_CUDA


aten = torch.ops.aten
c10d = torch.ops.c10d_functional
_c10d = torch.ops._c10d_functional


def compile_but_use_eager(gm, example_inputs):
    def inner_compile(gm, *args, **kwargs):
        # Run Inductor compilation so that metrics.node_runtimes gets
        # populated, but discard the compiled artifact and return the
        # original graph module so execution stays eager.
        compile_fx_inner(gm, *args, **kwargs)
        return gm

    return compile_fx(gm, example_inputs, inner_compile=inner_compile)


def calculate_runtime(f, *args) -> float:
    """
    Returns the total estimated runtime of all nodes in the compiled graph.

    Assumes all inputs are fp32.
    """
    metrics.reset()
    torch.compile(f, backend=compile_but_use_eager)(*args)
    print(metrics.node_runtimes)

    # Each entry of metrics.node_runtimes is a (scheduler_node, estimated_runtime) pair.
    ret = 0.0
    for pair in metrics.node_runtimes:
        ret += pair[1]

    return ret


DEVICE = "cuda"


def T(*size, dtype=torch.float32, device=DEVICE, grad=False) -> torch.Tensor:
    return torch.randn(size, dtype=dtype, device=device, requires_grad=grad)
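
# Illustrative sketch (not executed as a test) of how the helpers above
# compose; the exact estimate is hardware dependent, so the tests below only
# compare it against zero:
#
#   def f(x):
#       return x.cos()
#
#   estimated = calculate_runtime(f, T(10))  # strictly positive on CUDA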

class TestCase(InductorTestCase):
    """
    Helper methods to compare a runtime estimate against 0. Since the estimate
    is hardware dependent, stronger comparisons may fail depending on the
    host's specs.

    atol/rtol must be provided explicitly with each call, since
    precision/rel_tol overrides are not always honored.
    """

    device = DEVICE

    def setUp(self):
        super().setUp()
        # These tests check metrics.node_runtimes, and we don't save/restore
        # those in the FX graph cache.
        self._test_snode_stack = contextlib.ExitStack()
        self._test_snode_stack.enter_context(
            config.patch({"fx_graph_remote_cache": False})
        )

    def tearDown(self):
        self._test_snode_stack.close()
        super().tearDown()

    def assertZero(self, x: float):
        assert isinstance(x, float)
        super().assertEqual(x, 0.0, atol=0, rtol=0)

    def assertNotZero(self, x: float):
        assert isinstance(x, float)
        super().assertNotEqual(x, 0.0, atol=0, rtol=0)


class UnsupportedTests(TestCase):
    def test_no_op(self):
        def f(a):
            return a

        inp = (T(10, 10),)
        self.assertZero(calculate_runtime(f, *inp))

    def test_no_cuda(self):
        def f(a):
            return a

        inp = (torch.randn((10, 10), device="cpu"),)
        self.assertZero(calculate_runtime(f, *inp))


class ComputeBoundedTests(TestCase):
    def test_conv1d(self):
        def f(x, y):
            return torch.nn.functional.conv1d(x, y)

        inp = (T(33, 16, 30), T(20, 16, 5))
        self.assertNotZero(calculate_runtime(f, *inp))

    def test_conv2d(self):
        def f(x, y):
            return torch.nn.functional.conv2d(x, y, padding=1)

        inp = (T(8, 4, 3, 3), T(1, 4, 5, 5))
        self.assertNotZero(calculate_runtime(f, *inp))

    def test_conv2d_transpose(self):
        def f(x, y):
            return torch.nn.functional.conv_transpose2d(x, y, padding=1)

        inp = (T(8, 1, 1, 1), T(1, 4, 5, 5))
        self.assertNotZero(calculate_runtime(f, *inp))

    def test_conv3d(self):
        def f(x, y):
            return torch.nn.functional.conv3d(x, y)

        inp = (T(20, 16, 50, 10, 20), T(33, 16, 3, 3, 3))
        self.assertNotZero(calculate_runtime(f, *inp))

    def test_mm(self):
        def f(a, b):
            return torch.mm(a, b)

        inp = (
            T(10, 10),
            T(10, 10),
        )
        self.assertNotZero(calculate_runtime(f, *inp))

    def test_addmm(self):
        def f(a, b, c):
            return torch.addmm(a, b, c)

        inp = (
            T(10, 10),
            T(10, 10),
            T(10, 10),
        )
        self.assertNotZero(calculate_runtime(f, *inp))

    def test_bmm(self):
        def f(a, b):
            return torch.bmm(a, b)

        inp = (
            T(10, 10, 10),
            T(10, 10, 10),
        )
        self.assertNotZero(calculate_runtime(f, *inp))


class MemoryBoundedTests(TestCase):
    def test_relu(self):
        def f(a):
            return torch.nn.functional.relu(a)

        inp = (T(10, 10),)
        self.assertNotZero(calculate_runtime(f, *inp))

    def test_horizontal_reduction_pointwise(self):
        def f(a):
            b = a.sum(dim=1)
            c = a.cos()
            return b, c

        inp = (T(10, 10),)
        self.assertNotZero(calculate_runtime(f, *inp))

    def test_pointwise(self):
        def f(x):
            return x.cos()

        inp = (T(10),)
        self.assertNotZero(calculate_runtime(f, *inp))

    @torch._dynamo.config.patch(assume_static_by_default=False)
    def test_dynamic(self):
        def f(x):
            return x.cos()

        inp = (T(10),)
        self.assertNotZero(calculate_runtime(f, *inp))
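
# The collective tests below run in a single process using the "fake" process
# group backend (FakeStore): collective ops are traced and compiled, but no
# real communication is performed, so estimate_nccl_collective_runtime can be
# exercised without a multi-GPU setup.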

@skipIf(not dist.is_available(), "requires distributed")
class TestCommAnalysis(TestCase):
    WORLD_SIZE: int = 8
    RANKS = list(range(WORLD_SIZE))

    def _verify_runtime_estimation(self, fn, inps):
        from torch.testing._internal.distributed.fake_pg import FakeStore

        store = FakeStore()
        dist.init_process_group(
            backend="fake", rank=0, world_size=self.WORLD_SIZE, store=store
        )
        try:
            metrics.reset()
            torch.compile(fn)(*inps)
            found_collective = False
            for snode, runtime in metrics.node_runtimes:
                if not is_collective(snode.node):
                    continue
                found_collective = True
                # Inductor swallows errors from snode runtime estimation.
                # We call estimate_nccl_collective_runtime in a white-box
                # fashion here so that potential issues surface in tests.
                est = estimate_nccl_collective_runtime(snode.node)
                self.assertNotZero(est)
                # Also make sure estimate_nccl_collective_runtime works
                # correctly inside Inductor.
                self.assertNotZero(runtime)
            # Make sure a collective kernel was found in the graph.
            self.assertTrue(found_collective)
        finally:
            dist.destroy_process_group()

    def test_legacy_all_reduce(self):
        def fn(x):
            r = c10d.all_reduce(x, "sum", "", self.RANKS, self.WORLD_SIZE)
            return c10d.wait_tensor(r)

        inp = T(10, 10)
        self._verify_runtime_estimation(fn, (inp,))

    def test_legacy_all_reduce_coalesced(self):
        def fn(x):
            rs = c10d.all_reduce_coalesced(x, "sum", "", self.RANKS, self.WORLD_SIZE)
            return [c10d.wait_tensor(r) for r in rs]

        inp = [T(10, 10), T(15, 15)]
        self._verify_runtime_estimation(fn, (inp,))

    def test_legacy_all_gather_into_tensor_coalesced(self):
        def fn(x):
            rs = c10d.all_gather_into_tensor_coalesced(
                x,
                "",
                self.RANKS,
                self.WORLD_SIZE,
            )
            return [c10d.wait_tensor(r) for r in rs]

        inp = [T(10, 10), T(15, 15)]
        self._verify_runtime_estimation(fn, (inp,))

    def test_all_reduce(self):
        def fn(x):
            r = _c10d.all_reduce(x, "sum", "0")
            return _c10d.wait_tensor(r)

        inp = T(10, 10)
        self._verify_runtime_estimation(fn, (inp,))

    def test_all_reduce_coalesced(self):
        def fn(x):
            rs = _c10d.all_reduce_coalesced(x, "sum", "0")
            return [_c10d.wait_tensor(r) for r in rs]

        inp = [T(10, 10), T(15, 15)]
        self._verify_runtime_estimation(fn, (inp,))

    def test_all_gather_into_tensor(self):
        def fn(x):
            # Unlike the coalesced variant, all_gather_into_tensor returns a
            # single tensor, so wait on it directly.
            r = _c10d.all_gather_into_tensor(
                x,
                self.WORLD_SIZE,
                "0",
            )
            return _c10d.wait_tensor(r)

        inp = T(10, 10)
        self._verify_runtime_estimation(fn, (inp,))

    def test_all_gather_into_tensor_coalesced(self):
        def fn(x):
            rs = _c10d.all_gather_into_tensor_coalesced(
                x,
                self.WORLD_SIZE,
                "0",
            )
            return [_c10d.wait_tensor(r) for r in rs]

        inp = [T(10, 10), T(15, 15)]
        self._verify_runtime_estimation(fn, (inp,))

    def test_reduce_scatter_tensor(self):
        def fn(x):
            # reduce_scatter_tensor also returns a single tensor.
            r = _c10d.reduce_scatter_tensor(
                x,
                "sum",
                self.WORLD_SIZE,
                "0",
            )
            return _c10d.wait_tensor(r)

        inp = T(self.WORLD_SIZE, 10)
        self._verify_runtime_estimation(fn, (inp,))

    def test_reduce_scatter_tensor_coalesced(self):
        def fn(x):
            rs = _c10d.reduce_scatter_tensor_coalesced(
                x,
                "sum",
                self.WORLD_SIZE,
                "0",
            )
            return [_c10d.wait_tensor(r) for r in rs]

        inp = [T(self.WORLD_SIZE, 10), T(self.WORLD_SIZE, 15)]
        self._verify_runtime_estimation(fn, (inp,))


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CUDA:
        run_tests(needs="filelock")