1# Owner(s): ["module: inductor"] 2 3import copy 4import os 5import unittest 6 7import torch 8from torch import nn 9from torch._dynamo.utils import counters, same 10from torch._inductor import metrics 11from torch._inductor.runtime.benchmarking import benchmarker 12from torch._inductor.test_case import TestCase 13from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU 14 15 16DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1" 17 18 19class TestScatterOpt(TestCase): 20 def setUp(self): 21 super().setUp() 22 metrics.reset() 23 counters.clear() 24 25 def check_metric(self, val=1): 26 self.assertEqual(val, metrics.num_matches_for_scatter_upon_const_tensor) 27 28 def do_acc_test(self, f, *args): 29 expect = f(*args) 30 actual = torch.compile(f)(*args) 31 self.assertTrue(same(expect, actual, tol=1e-3), f"{expect=}\n{actual=}\n") 32 33 def test_3d_tensor(self): 34 L, M, N = 2, 1024, 2048 35 36 def f(x): 37 y = torch.full([L, M, N], 3.14, dtype=torch.float) 38 y.scatter_(2, x.unsqueeze(2), 2.718) 39 return y 40 41 x = torch.randint(0, N, (L, M), dtype=torch.int64) 42 self.do_acc_test(f, x) 43 expected_num_bytes = ( 44 L * M * N * torch.float.itemsize + L * M * torch.int64.itemsize 45 ) 46 self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes) 47 48 def test_non_last_dim(self): 49 """ 50 Test the case that the scatter dimension is not the last one. 51 """ 52 M, N = 1024, 2048 53 54 def f(x): 55 y = torch.full([M, N], 3.14, dtype=torch.float) 56 y.scatter_(0, x.unsqueeze(0), 2.718) 57 return y 58 59 x = torch.randint(0, M, (N,), dtype=torch.int64) 60 self.do_acc_test(f, x) 61 expected_num_bytes = M * N * torch.float.itemsize + N * torch.int64.itemsize 62 self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes) 63 64 def test_neg_scatter_dim(self): 65 M, N = 1024, 2048 66 67 def f(x): 68 y = torch.full([M, N], 3.14, dtype=torch.float) 69 y.scatter_(-1, x.unsqueeze(1), 2.718) 70 return y 71 72 x = torch.randint(0, N, (M,), dtype=torch.int64) 73 self.do_acc_test(f, x) 74 expected_num_bytes = M * N * torch.float.itemsize + M * torch.int64.itemsize 75 self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes) 76 77 def test_shorter_index_tensor(self): 78 M, N = 1024, 2048 79 80 def f(x): 81 y = torch.full([M, N], 3.14, dtype=torch.float) 82 y.scatter_(1, x.unsqueeze(1), 2.718) 83 return y 84 85 x = torch.randint(0, N, (M // 2,), dtype=torch.int64) 86 self.do_acc_test(f, x) 87 88 # no match since the index tensor is shorter. May support it in future. 89 self.assertEqual(0, counters["inductor"]["pattern_matcher_count"]) 90 91 def test_nonzero_const_tensor(self): 92 M, N = 1024, 2048 93 94 def f(x): 95 y = torch.full([M, N], 3.14, dtype=torch.float) 96 y.scatter_(1, x.unsqueeze(1), 2.718) 97 return y 98 99 x = torch.randint(0, N, (M,), dtype=torch.int64) 100 self.do_acc_test(f, x) 101 expected_num_bytes = M * N * torch.float.itemsize + M * torch.int64.itemsize 102 self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes) 103 104 def test_can_not_optimize_due_to_dense(self): 105 M, N = 1024, 2048 106 107 def f(x): 108 y = torch.full([M, N], 0, dtype=torch.float) 109 y.scatter_(1, x, 0.618) 110 return y 111 112 x = torch.randint(0, N, (M, N // 2), dtype=torch.int64) 113 self.do_acc_test(f, x) 114 expected_num_bytes = M * N * torch.float.itemsize + M * (N // 2) * ( 115 torch.int64.itemsize + torch.float.itemsize 116 ) 117 # Use assertGreaterEqual rather than assertEqual due to the issue related 118 # to StarDep mentioned here: https://github.com/pytorch/pytorch/pull/129043#discussion_r1651699706 119 self.assertGreaterEqual(metrics.num_bytes_accessed, expected_num_bytes) 120 121 def test_can_not_optimize_due_to_non_const(self): 122 M, N = 1024, 2048 123 124 def f(x, y): 125 y.scatter_(1, x, 0.618) 126 return y 127 128 x = torch.randint(0, N, (M, 1), dtype=torch.int64) 129 y = torch.randn([M, N]) 130 self.do_acc_test(f, x, y) 131 132 # The generated code is quite in-efficient. 133 # There are 3 kernels 134 # 1. copy from arg to buf 135 # 2. scatter upon buf 136 # 3. copy buf back to arg 137 # Link to the wrapper: https://gist.github.com/shunting314/d43b74e680b3e5b514f7c28160c39f40 138 expected_num_bytes = 4 * M * N * torch.float.itemsize + M * ( 139 torch.int64.itemsize + torch.float.itemsize 140 ) 141 self.assertGreaterEqual(metrics.num_bytes_accessed, expected_num_bytes) 142 143 # the second kernel and third kernel are both mutation kernel. So we 144 # overestimated the memory accessed 145 # Update the test once the overestimiation is fixed. 146 over_estimate = M * torch.float.itemsize + M * N * torch.float.itemsize 147 self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes + over_estimate) 148 149 def test_cross_entropy_loss(self): 150 """ 151 Match full+scatter in CEL and replaces it with a pointwise. 152 153 Perf data on an A100 GPU: 154 Without the scatter optimization: 155 ms=47.340, peak_mem=10.524 GB 156 With the scatter optimization: 157 ms=42.768, peak_mem=7.227 GB 158 """ 159 B, T, D, V = 32, 1024, 768, 50257 160 if not DO_PERF_TEST: 161 # use a smaller V if not doing perf test to avoid OOM 162 # in CI 163 V = V // 100 164 ref_model = nn.Linear(D, V).to(torch.bfloat16) 165 opt_model = copy.deepcopy(ref_model) 166 ce = nn.CrossEntropyLoss() 167 168 def f(m, x, label): 169 ce(m(x).view(-1, V), label.view(-1)).backward() 170 171 opt_f = torch.compile(f) 172 173 x = torch.randn(B, T, D).to(torch.bfloat16) 174 label = torch.randint(0, V, (B, T)).to(torch.int64) 175 176 f(ref_model, x, label) 177 ref_grad = ref_model.weight.grad 178 opt_f(opt_model, x, label) 179 act_grad = opt_model.weight.grad 180 assert torch.allclose( 181 ref_grad, act_grad, atol=1e-3, rtol=1e-3 182 ), f"{ref_grad=}\n{act_grad=}" 183 184 self.check_metric() 185 186 if DO_PERF_TEST: 187 if GPU_TYPE == "xpu": 188 raise unittest.SkipTest( 189 "torch.xpu.reset_peak_memory_stats not implemented." 190 ) 191 torch.cuda.reset_peak_memory_stats() 192 for _ in range(3): 193 opt_f(opt_model, x, label) 194 ms = benchmarker.benchmark_gpu(lambda: opt_f(opt_model, x, label)) 195 peak_mem = torch.cuda.max_memory_allocated() / 10**9 196 print(f"{ms=:.3f}, {peak_mem=:.3f} GB") 197 198 199if HAS_GPU: 200 torch.set_default_device(GPU_TYPE) 201 202if __name__ == "__main__": 203 from torch._inductor.test_case import run_tests 204 205 if HAS_GPU: 206 run_tests() 207