• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Owner(s): ["module: inductor"]
2
3import contextlib
4from unittest import skipIf
5
6import torch
7import torch.distributed as dist
8from torch._inductor import config, metrics
9from torch._inductor.comm_analysis import estimate_nccl_collective_runtime
10from torch._inductor.compile_fx import compile_fx, compile_fx_inner
11from torch._inductor.test_case import TestCase as InductorTestCase
12from torch._inductor.utils import is_collective
13from torch.testing._internal.inductor_utils import HAS_CUDA
14
15
# Shorthand handles for op namespaces: ATen, the legacy functional
# collectives (``c10d_functional``, used by the ``test_legacy_*`` tests) and
# the current functional collectives (``_c10d_functional``).
aten, c10d, _c10d = (
    torch.ops.aten,
    torch.ops.c10d_functional,
    torch.ops._c10d_functional,
)
19
20
def compile_but_use_eager(gm, example_inputs):
    """Dynamo backend: run Inductor's compile pipeline, then execute eagerly.

    ``compile_fx`` is driven with a custom ``inner_compile`` that still calls
    ``compile_fx_inner`` on each graph, so its side effects (such as the
    ``metrics.node_runtimes`` read by ``calculate_runtime``) happen. The
    compiled artifact is then thrown away and the graph module itself is
    returned, so the graph runs eagerly.
    """

    def _compile_for_side_effects(gm, *args, **kwargs):
        # The result of compile_fx_inner is deliberately discarded.
        compile_fx_inner(gm, *args, **kwargs)
        return gm

    return compile_fx(
        gm,
        example_inputs,
        inner_compile=_compile_for_side_effects,
    )
27
28
def calculate_runtime(f, *args) -> float:
    """Compile ``f`` and return Inductor's summed runtime estimate for it.

    Assumes all inputs are fp32.

    ``f`` is compiled with ``compile_but_use_eager`` and called once on
    ``args``. The result is the sum of the per-node runtime estimates that
    were recorded in ``metrics.node_runtimes`` (reset beforehand). If nothing
    was recorded, the result is ``0.0``.
    """
    metrics.reset()
    torch.compile(f, backend=compile_but_use_eager)(*args)
    # Entries are (node, runtime) pairs. Starting the sum at 0.0 keeps the
    # result a float even when no runtimes were recorded, which the
    # assertZero/assertNotZero helpers check for.
    return sum((runtime for _, runtime in metrics.node_runtimes), 0.0)
42
43
# Device on which ``T`` allocates tensors unless told otherwise.
DEVICE = "cuda"


def T(*size, dtype=torch.float32, device=DEVICE, grad=False) -> torch.Tensor:
    """Return a standard-normal random tensor with shape ``size``.

    ``dtype`` and ``device`` are forwarded to ``torch.randn``. ``grad``
    controls ``requires_grad`` on the result.
    """
    options = dict(dtype=dtype, device=device, requires_grad=grad)
    return torch.randn(size, **options)
49
50
class TestCase(InductorTestCase):
    """Base class for the runtime-estimation tests in this file.

    Provides helpers that compare a runtime estimate against 0. The estimate
    is hardware dependent, so stronger comparisons may fail depending on the
    host's specs.

    The helpers pass ``atol=0`` and ``rtol=0`` explicitly on every
    comparison, because the test case's precision / rel_tol overrides are not
    always applied.
    """

    device = DEVICE

    def setUp(self):
        super().setUp()
        # These tests check metrics.node_runtimes and we don't save / restore
        # those in the FX graph cache, so the remote FX graph cache is
        # disabled for the duration of each test.
        self._test_snode_stack = contextlib.ExitStack()
        self._test_snode_stack.enter_context(
            config.patch({"fx_graph_remote_cache": False})
        )

    def tearDown(self):
        # Undo the config patch applied in setUp.
        self._test_snode_stack.close()
        super().tearDown()

    def assertZero(self, x: float):
        """Assert that ``x`` is a float exactly equal to 0.0."""
        # assertIsInstance (unlike a bare assert) is not stripped under -O.
        self.assertIsInstance(x, float)
        super().assertEqual(x, 0.0, atol=0, rtol=0)

    def assertNotZero(self, x: float):
        """Assert that ``x`` is a float that is not equal to 0.0."""
        self.assertIsInstance(x, float)
        super().assertNotEqual(x, 0.0, atol=0, rtol=0)
81
82
class UnsupportedTests(TestCase):
    """Graphs whose runtime estimate is expected to be exactly zero."""

    def test_no_op(self):
        # An identity function on a tensor created by T.
        def identity(a):
            return a

        estimate = calculate_runtime(identity, T(10, 10))
        self.assertZero(estimate)

    def test_no_cuda(self):
        # The same identity function, but on a CPU tensor.
        def identity(a):
            return a

        cpu_input = torch.randn((10, 10), device="cpu")
        estimate = calculate_runtime(identity, cpu_input)
        self.assertZero(estimate)
97
98
class ComputeBoundedTests(TestCase):
    """Convolutions and matmuls must receive a non-zero runtime estimate."""

    def test_conv1d(self):
        def conv(inp, weight):
            return torch.nn.functional.conv1d(inp, weight)

        estimate = calculate_runtime(conv, T(33, 16, 30), T(20, 16, 5))
        self.assertNotZero(estimate)

    def test_conv2d(self):
        def conv(inp, weight):
            return torch.nn.functional.conv2d(inp, weight, padding=1)

        estimate = calculate_runtime(conv, T(8, 4, 3, 3), T(1, 4, 5, 5))
        self.assertNotZero(estimate)

    def test_conv2d_transpose(self):
        def deconv(inp, weight):
            return torch.nn.functional.conv_transpose2d(inp, weight, padding=1)

        estimate = calculate_runtime(deconv, T(8, 1, 1, 1), T(1, 4, 5, 5))
        self.assertNotZero(estimate)

    def test_conv3d(self):
        def conv(inp, weight):
            return torch.nn.functional.conv3d(inp, weight)

        estimate = calculate_runtime(
            conv, T(20, 16, 50, 10, 20), T(33, 16, 3, 3, 3)
        )
        self.assertNotZero(estimate)

    def test_mm(self):
        def matmul(lhs, rhs):
            return torch.mm(lhs, rhs)

        estimate = calculate_runtime(matmul, T(10, 10), T(10, 10))
        self.assertNotZero(estimate)

    def test_addmm(self):
        def matmul_add(bias, lhs, rhs):
            return torch.addmm(bias, lhs, rhs)

        estimate = calculate_runtime(matmul_add, T(10, 10), T(10, 10), T(10, 10))
        self.assertNotZero(estimate)

    def test_bmm(self):
        def batched_matmul(lhs, rhs):
            return torch.bmm(lhs, rhs)

        estimate = calculate_runtime(batched_matmul, T(10, 10, 10), T(10, 10, 10))
        self.assertNotZero(estimate)
158
159
class MemoryBoundedTests(TestCase):
    """Pointwise and reduction kernels must receive a non-zero estimate."""

    def test_relu(self):
        def relu(a):
            return torch.nn.functional.relu(a)

        estimate = calculate_runtime(relu, T(10, 10))
        self.assertNotZero(estimate)

    def test_horizontal_reduction_pointwise(self):
        # A row-wise sum and an elementwise cos of the same input.
        def sum_and_cos(a):
            return a.sum(dim=1), a.cos()

        estimate = calculate_runtime(sum_and_cos, T(10, 10))
        self.assertNotZero(estimate)

    def test_pointwise(self):
        def cosine(x):
            return x.cos()

        estimate = calculate_runtime(cosine, T(10))
        self.assertNotZero(estimate)

    # Same as test_pointwise, but dynamo no longer assumes static shapes.
    @torch._dynamo.config.patch(assume_static_by_default=False)
    def test_dynamic(self):
        def cosine(x):
            return x.cos()

        estimate = calculate_runtime(cosine, T(10))
        self.assertNotZero(estimate)
191
192
@skipIf(not dist.is_available(), "requires distributed")
class TestCommAnalysis(TestCase):
    """Check that collectives receive non-zero runtime estimates.

    Each test compiles a function that issues a functional collective, using
    either the legacy ``c10d_functional`` ops or the ``_c10d_functional``
    ops. The collective runs under a fake process group of ``WORLD_SIZE``
    ranks. The ``_c10d_functional`` ops address the group by the name "0".
    Presumably that is the name given to the default group created in
    ``_verify_runtime_estimation``; confirm if the group setup changes.
    """

    WORLD_SIZE: int = 8
    # Derived from WORLD_SIZE so the two values can never disagree.
    RANKS = list(range(WORLD_SIZE))

    def _verify_runtime_estimation(self, fn, inps):
        """Compile and run ``fn(*inps)``, then check every collective estimate.

        A fake process group (rank 0 of ``WORLD_SIZE``) exists only for the
        duration of this call and is always destroyed afterwards. The check
        fails if no collective node appears in ``metrics.node_runtimes``, or
        if any collective's estimate is zero.
        """
        from torch.testing._internal.distributed.fake_pg import FakeStore

        store = FakeStore()
        dist.init_process_group(
            backend="fake", rank=0, world_size=self.WORLD_SIZE, store=store
        )
        try:
            metrics.reset()
            torch.compile(fn)(*inps)
            found_collective = False
            for snode, runtime in metrics.node_runtimes:
                if not is_collective(snode.node):
                    continue
                found_collective = True
                # Inductor swallows errors from snode runtime estimations.
                # We call estimate_nccl_collective_runtime in a white-box
                # fashion here so potential issues can be surfaced in tests.
                est = estimate_nccl_collective_runtime(snode.node)
                self.assertNotZero(est)
                # Also make sure estimate_nccl_collective_runtime works
                # correctly in inductor.
                self.assertNotZero(runtime)
            # Make sure a collective kernel is found in graph
            self.assertTrue(
                found_collective, "no collective node found in metrics.node_runtimes"
            )
        finally:
            dist.destroy_process_group()

    def test_legacy_all_reduce(self):
        def fn(x):
            r = c10d.all_reduce(x, "sum", "", self.RANKS, self.WORLD_SIZE)
            return c10d.wait_tensor(r)

        inp = T(10, 10)
        self._verify_runtime_estimation(fn, (inp,))

    def test_legacy_all_reduce_coalesced(self):
        def fn(x):
            # Coalesced ops return one tensor per input, so wait on each.
            rs = c10d.all_reduce_coalesced(x, "sum", "", self.RANKS, self.WORLD_SIZE)
            return [c10d.wait_tensor(r) for r in rs]

        inp = [T(10, 10), T(15, 15)]
        self._verify_runtime_estimation(fn, (inp,))

    def test_legacy_all_gather_into_tensor_coalesced(self):
        def fn(x):
            rs = c10d.all_gather_into_tensor_coalesced(
                x,
                "",
                self.RANKS,
                self.WORLD_SIZE,
            )
            return [c10d.wait_tensor(r) for r in rs]

        inp = [T(10, 10), T(15, 15)]
        self._verify_runtime_estimation(fn, (inp,))

    def test_all_reduce(self):
        def fn(x):
            r = _c10d.all_reduce(x, "sum", "0")
            return _c10d.wait_tensor(r)

        inp = T(10, 10)
        self._verify_runtime_estimation(fn, (inp,))

    def test_all_reduce_coalesced(self):
        def fn(x):
            rs = _c10d.all_reduce_coalesced(x, "sum", "0")
            return [_c10d.wait_tensor(r) for r in rs]

        inp = [T(10, 10), T(15, 15)]
        self._verify_runtime_estimation(fn, (inp,))

    def test_all_gather_into_tensor(self):
        def fn(x):
            # all_gather_into_tensor returns a single tensor: wait on it
            # directly rather than iterating over (and waiting on) its rows.
            r = _c10d.all_gather_into_tensor(
                x,
                self.WORLD_SIZE,
                "0",
            )
            return _c10d.wait_tensor(r)

        inp = T(10, 10)
        self._verify_runtime_estimation(fn, (inp,))

    def test_all_gather_into_tensor_coalesced(self):
        def fn(x):
            rs = _c10d.all_gather_into_tensor_coalesced(
                x,
                self.WORLD_SIZE,
                "0",
            )
            return [_c10d.wait_tensor(r) for r in rs]

        inp = [T(10, 10), T(15, 15)]
        self._verify_runtime_estimation(fn, (inp,))

    def test_reduce_scatter_tensor(self):
        def fn(x):
            # reduce_scatter_tensor returns a single tensor: wait on it
            # directly rather than iterating over its rows.
            r = _c10d.reduce_scatter_tensor(
                x,
                "sum",
                self.WORLD_SIZE,
                "0",
            )
            return _c10d.wait_tensor(r)

        inp = T(self.WORLD_SIZE, 10)
        self._verify_runtime_estimation(fn, (inp,))

    def test_reduce_scatter_tensor_coalesced(self):
        def fn(x):
            rs = _c10d.reduce_scatter_tensor_coalesced(
                x,
                "sum",
                self.WORLD_SIZE,
                "0",
            )
            return [_c10d.wait_tensor(r) for r in rs]

        inp = [T(self.WORLD_SIZE, 10), T(self.WORLD_SIZE, 15)]
        self._verify_runtime_estimation(fn, (inp,))
320
321
if __name__ == "__main__":
    # Most tests allocate their inputs on DEVICE ("cuda"), so the suite is
    # only run when CUDA is available.
    if HAS_CUDA:
        from torch._inductor.test_case import run_tests

        run_tests(needs="filelock")
327