# Owner(s): ["module: inductor"]

import copy
import os
import unittest

import torch
from torch import nn
from torch._dynamo.utils import counters, same
from torch._inductor import metrics
from torch._inductor.runtime.benchmarking import benchmarker
from torch._inductor.test_case import TestCase
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"


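# These tests exercise Inductor's scatter-upon-const-tensor optimization:
# a scatter_ into a tensor freshly created by torch.full can be rewritten as a
# single pointwise kernel, so the constant tensor never has to be materialized
# and re-read. Accuracy is checked by comparing eager vs. torch.compile outputs;
# the expected memory traffic is checked via metrics.num_bytes_accessed.
# Perf measurements are opt-in via DO_PERF_TEST=1.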
class TestScatterOpt(TestCase):
    def setUp(self):
        super().setUp()
        metrics.reset()
        counters.clear()

    def check_metric(self, val=1):
        # The pattern matcher bumps this metric once per rewritten scatter.
        self.assertEqual(val, metrics.num_matches_for_scatter_upon_const_tensor)

    def do_acc_test(self, f, *args):
        # Check that the compiled function matches eager numerics.
        expect = f(*args)
        actual = torch.compile(f)(*args)
        self.assertTrue(same(expect, actual, tol=1e-3), f"{expect=}\n{actual=}\n")

    def test_3d_tensor(self):
        L, M, N = 2, 1024, 2048

        def f(x):
            y = torch.full([L, M, N], 3.14, dtype=torch.float)
            y.scatter_(2, x.unsqueeze(2), 2.718)
            return y

        x = torch.randint(0, N, (L, M), dtype=torch.int64)
        self.do_acc_test(f, x)
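        # With the optimization we expect one float write per output element
        # plus one int64 read per index element; the torch.full constant never
        # needs to be read back.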
        expected_num_bytes = (
            L * M * N * torch.float.itemsize + L * M * torch.int64.itemsize
        )
        self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes)

    def test_non_last_dim(self):
        """
        Test the case where the scatter dimension is not the last one.
        """
        M, N = 1024, 2048

        def f(x):
            y = torch.full([M, N], 3.14, dtype=torch.float)
            y.scatter_(0, x.unsqueeze(0), 2.718)
            return y

        x = torch.randint(0, M, (N,), dtype=torch.int64)
        self.do_acc_test(f, x)
        expected_num_bytes = M * N * torch.float.itemsize + N * torch.int64.itemsize
        self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes)

    def test_neg_scatter_dim(self):
        M, N = 1024, 2048

        def f(x):
            y = torch.full([M, N], 3.14, dtype=torch.float)
            y.scatter_(-1, x.unsqueeze(1), 2.718)
            return y

        x = torch.randint(0, N, (M,), dtype=torch.int64)
        self.do_acc_test(f, x)
        expected_num_bytes = M * N * torch.float.itemsize + M * torch.int64.itemsize
        self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes)

    def test_shorter_index_tensor(self):
        M, N = 1024, 2048

        def f(x):
            y = torch.full([M, N], 3.14, dtype=torch.float)
            y.scatter_(1, x.unsqueeze(1), 2.718)
            return y

        x = torch.randint(0, N, (M // 2,), dtype=torch.int64)
        self.do_acc_test(f, x)

        # No match since the index tensor is shorter. We may support this in the future.
        self.assertEqual(0, counters["inductor"]["pattern_matcher_count"])

    def test_nonzero_const_tensor(self):
        M, N = 1024, 2048

        def f(x):
            y = torch.full([M, N], 3.14, dtype=torch.float)
            y.scatter_(1, x.unsqueeze(1), 2.718)
            return y

        x = torch.randint(0, N, (M,), dtype=torch.int64)
        self.do_acc_test(f, x)
        expected_num_bytes = M * N * torch.float.itemsize + M * torch.int64.itemsize
        self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes)

    def test_can_not_optimize_due_to_dense(self):
        M, N = 1024, 2048

        def f(x):
            y = torch.full([M, N], 0, dtype=torch.float)
            y.scatter_(1, x, 0.618)
            return y

        x = torch.randint(0, N, (M, N // 2), dtype=torch.int64)
        self.do_acc_test(f, x)
        expected_num_bytes = M * N * torch.float.itemsize + M * (N // 2) * (
            torch.int64.itemsize + torch.float.itemsize
        )
        # Use assertGreaterEqual rather than assertEqual due to the StarDep issue
        # mentioned here: https://github.com/pytorch/pytorch/pull/129043#discussion_r1651699706
        self.assertGreaterEqual(metrics.num_bytes_accessed, expected_num_bytes)

    def test_can_not_optimize_due_to_non_const(self):
        M, N = 1024, 2048

        def f(x, y):
            y.scatter_(1, x, 0.618)
            return y

        x = torch.randint(0, N, (M, 1), dtype=torch.int64)
        y = torch.randn([M, N])
        self.do_acc_test(f, x, y)

        # The generated code is quite inefficient: there are 3 kernels
        #   1. copy from arg to buf
        #   2. scatter upon buf
        #   3. copy buf back to arg
        # Link to the wrapper: https://gist.github.com/shunting314/d43b74e680b3e5b514f7c28160c39f40
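        # Roughly: the two copy kernels each read and write M * N floats
        # (4 * M * N floats in total), and the scatter kernel reads M int64
        # indices and writes M floats.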
        expected_num_bytes = 4 * M * N * torch.float.itemsize + M * (
            torch.int64.itemsize + torch.float.itemsize
        )
        self.assertGreaterEqual(metrics.num_bytes_accessed, expected_num_bytes)

        # The second and third kernels are both mutation kernels, so we
        # overestimate the memory accessed. Update the test once the
        # overestimation is fixed.
        over_estimate = M * torch.float.itemsize + M * N * torch.float.itemsize
        self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes + over_estimate)

    def test_cross_entropy_loss(self):
        """
        Match full+scatter in cross-entropy loss and replace it with a
        pointwise kernel.

        Perf data on an A100 GPU:
        Without the scatter optimization:
          ms=47.340, peak_mem=10.524 GB
        With the scatter optimization:
          ms=42.768, peak_mem=7.227 GB
        """
        B, T, D, V = 32, 1024, 768, 50257
        if not DO_PERF_TEST:
            # use a smaller V if not doing a perf test, to avoid OOM in CI
            V = V // 100
        ref_model = nn.Linear(D, V).to(torch.bfloat16)
        opt_model = copy.deepcopy(ref_model)
        ce = nn.CrossEntropyLoss()

        def f(m, x, label):
            ce(m(x).view(-1, V), label.view(-1)).backward()

        opt_f = torch.compile(f)

        x = torch.randn(B, T, D).to(torch.bfloat16)
        label = torch.randint(0, V, (B, T)).to(torch.int64)

        f(ref_model, x, label)
        ref_grad = ref_model.weight.grad
        opt_f(opt_model, x, label)
        act_grad = opt_model.weight.grad
        assert torch.allclose(
            ref_grad, act_grad, atol=1e-3, rtol=1e-3
        ), f"{ref_grad=}\n{act_grad=}"

        self.check_metric()

        if DO_PERF_TEST:
            if GPU_TYPE == "xpu":
                raise unittest.SkipTest(
                    "torch.xpu.reset_peak_memory_stats not implemented."
                )
            torch.cuda.reset_peak_memory_stats()
            for _ in range(3):
                opt_f(opt_model, x, label)
            ms = benchmarker.benchmark_gpu(lambda: opt_f(opt_model, x, label))
            peak_mem = torch.cuda.max_memory_allocated() / 10**9
            print(f"{ms=:.3f}, {peak_mem=:.3f} GB")


if HAS_GPU:
    torch.set_default_device(GPU_TYPE)

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_GPU:
        run_tests()