1# Owner(s): ["module: inductor"] 2 3import logging 4 5import torch 6import torch._inductor 7from torch._dynamo.utils import counters 8from torch._inductor.test_case import run_tests, TestCase 9from torch._inductor.utils import run_and_get_code 10from torch.testing import FileCheck 11from torch.testing._internal.common_utils import ( 12 instantiate_parametrized_tests, 13 parametrize, 14) 15from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA 16from torch.testing._internal.triton_utils import requires_gpu 17 18 19class MyModule(torch.nn.Module): 20 def __init__( 21 self, n_input: int, n_output: int, has_bias: bool, device=GPU_TYPE 22 ) -> None: 23 super().__init__() 24 self.linear = torch.nn.Linear(n_input, n_output, bias=has_bias) 25 26 def forward(self, x: torch.Tensor) -> torch.Tensor: 27 return self.linear(x) 28 29 30class MyModule2(torch.nn.Module): 31 def __init__(self) -> None: 32 super().__init__() 33 34 def forward(self, input1, input2): 35 output = torch.bmm(input1, input2) 36 return output 37 38 39class MyModule3(torch.nn.Module): 40 def __init__(self) -> None: 41 super().__init__() 42 43 def forward(self, input1, input2): 44 output = torch.mm(input1, input2) 45 return output 46 47 48@requires_gpu 49@torch._inductor.config.patch( 50 post_grad_fusion_options={ 51 "decompose_mm_pass": {}, 52 } 53) 54@instantiate_parametrized_tests 55class TestDecomposeMemMM(TestCase): 56 def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3): 57 if len(set(ref_dict.keys())) != len(set(res_dict.keys())): 58 return False 59 for key1 in ref_dict.keys(): 60 key2 = "_orig_mod." + key1 61 assert key2 in res_dict, f"{key1} does not exist in traced module" 62 if not torch.allclose(ref_dict[key1], res_dict[key2], rtol=rtol, atol=atol): 63 return False 64 return True 65 66 def compare_pred(self, module, traced, input, rtol=1e-3, atol=1e-3): 67 ref = module(*input) 68 res = traced(*input) 69 self.assertEqual(ref, res, rtol=rtol, atol=atol) 70 71 def compare_parameters(self, module, traced, rtol=1e-3, atol=1e-3): 72 ref_params = dict(module.named_parameters()) 73 res_params = dict(traced.named_parameters()) 74 self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol, atol)) 75 76 def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3): 77 ref_grad = {key: param.grad for key, param in module.named_parameters()} 78 res_grad = {key: param.grad for key, param in traced.named_parameters()} 79 self.assertTrue( 80 self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol) 81 ) 82 83 @parametrize( 84 "b,m,k,n,should_decompose", 85 [(10240, 2, 2, 2, True), (10240, 2, 32, 32, False), (2000, 2, 2, 2, False)], 86 ) 87 def test_decompose_bmm(self, b, m, n, k, should_decompose): 88 torch._logging.set_logs(inductor=logging.DEBUG) 89 mat1 = torch.randn(b, m, k, device=GPU_TYPE).requires_grad_(True) 90 mat2 = torch.randn(b, k, n, device=GPU_TYPE).requires_grad_(True) 91 92 counters.clear() 93 94 module = MyModule2().to(GPU_TYPE) 95 traced = torch.compile(module) 96 input = [mat1, mat2] 97 ref = module(*input) 98 res = traced(*input) 99 100 self.compare_pred(module, traced, input) 101 102 expected_val = 1 if should_decompose and HAS_CUDA else 0 103 self.assertEqual( 104 counters["inductor"]["decompose_bmm"], 105 expected_val, 106 ) 107 108 ref.sum().backward() 109 res.sum().backward() 110 self.compare_parameters(module, traced) 111 self.compare_gradients(module, traced) 112 113 expected_val = 3 if should_decompose and HAS_CUDA else 0 114 self.assertEqual( 115 counters["inductor"]["decompose_bmm"], 116 expected_val, 117 ) 118 counters.clear() 119 120 @parametrize( 121 "m,k,n, should_decompose", 122 [(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)], 123 ) 124 @parametrize("has_bias", [True, False]) 125 def test_decompose_linear(self, m, n, k, has_bias, should_decompose): 126 torch._logging.set_logs(inductor=logging.DEBUG) 127 input = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True) 128 129 counters.clear() 130 131 module = MyModule(k, n, has_bias).to(GPU_TYPE) 132 traced = torch.compile(module) 133 input = [input] 134 ref = module(*input) 135 res = traced(*input) 136 137 self.compare_pred(module, traced, input) 138 139 expected_val = 1 if should_decompose and HAS_CUDA else 0 140 if has_bias: 141 self.assertEqual( 142 counters["inductor"]["decompose_addmm"], 143 expected_val, 144 ) 145 else: 146 self.assertEqual( 147 counters["inductor"]["decompose_mm"], 148 expected_val, 149 ) 150 decompose_mm_fwd = counters["inductor"]["decompose_mm"] 151 152 ref.sum().backward() 153 res.sum().backward() 154 155 self.compare_parameters(module, traced) 156 self.compare_gradients(module, traced) 157 158 self.assertEqual( 159 counters["inductor"]["decompose_mm"] - decompose_mm_fwd, 160 expected_val, 161 ) 162 counters.clear() 163 164 @parametrize( 165 "m,k,n, should_decompose", 166 [(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)], 167 ) 168 @parametrize("has_bias", [True, False]) 169 def test_decompose_linear_mixed_precision( 170 self, m, n, k, has_bias, should_decompose 171 ): 172 with torch.amp.autocast(device_type=GPU_TYPE, dtype=torch.bfloat16): 173 torch._logging.set_logs(inductor=logging.DEBUG) 174 input = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True) 175 176 counters.clear() 177 178 module = MyModule(k, n, has_bias).to(GPU_TYPE) 179 traced = torch.compile(module) 180 input = [input] 181 ref = module(*input) 182 res = traced(*input) 183 184 self.compare_pred(module, traced, input) 185 186 expected_val = 1 if should_decompose and HAS_CUDA else 0 187 if has_bias: 188 self.assertEqual( 189 counters["inductor"]["decompose_addmm"], 190 expected_val, 191 ) 192 else: 193 self.assertEqual( 194 counters["inductor"]["decompose_mm"], 195 expected_val, 196 ) 197 decompose_mm_fwd = counters["inductor"]["decompose_mm"] 198 199 ref.sum().backward() 200 res.sum().backward() 201 202 self.compare_parameters(module, traced) 203 self.compare_gradients(module, traced) 204 205 self.assertEqual( 206 counters["inductor"]["decompose_mm"] - decompose_mm_fwd, 207 expected_val, 208 ) 209 counters.clear() 210 211 @parametrize( 212 "m,k,n, should_decompose", 213 [(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)], 214 ) 215 @parametrize("has_bias", [True, False]) 216 def test_decompose_mm(self, m, n, k, has_bias, should_decompose): 217 torch._logging.set_logs(inductor=logging.DEBUG) 218 mat1 = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True) 219 mat2 = torch.randn(k, n, device=GPU_TYPE).requires_grad_(True) 220 221 counters.clear() 222 223 module = MyModule3().to(GPU_TYPE) 224 traced = torch.compile(module) 225 input = [mat1, mat2] 226 ref = module(*input) 227 res = traced(*input) 228 229 self.compare_pred(module, traced, input) 230 231 expected_val = 1 if should_decompose and HAS_CUDA else 0 232 self.assertEqual( 233 counters["inductor"]["decompose_mm"], 234 expected_val, 235 ) 236 decompose_mm_fwd = counters["inductor"]["decompose_mm"] 237 238 ref.sum().backward() 239 res.sum().backward() 240 self.compare_parameters(module, traced) 241 self.compare_gradients(module, traced) 242 243 expected_val = 1 if should_decompose and HAS_CUDA else 0 244 self.assertEqual( 245 counters["inductor"]["decompose_mm"] - decompose_mm_fwd, 246 expected_val, 247 ) 248 counters.clear() 249 250 @parametrize( 251 "m,k,n, should_decompose", 252 [(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)], 253 ) 254 @parametrize("has_bias", [True, False]) 255 def test_decompose_mm_mixed_precision(self, m, n, k, has_bias, should_decompose): 256 with torch.amp.autocast(device_type=GPU_TYPE, dtype=torch.bfloat16): 257 torch._logging.set_logs(inductor=logging.DEBUG) 258 mat1 = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True) 259 mat2 = torch.randn(k, n, device=GPU_TYPE).requires_grad_(True) 260 261 counters.clear() 262 263 module = MyModule3().to(GPU_TYPE) 264 traced = torch.compile(module) 265 input = [mat1, mat2] 266 ref = module(*input) 267 res = traced(*input) 268 269 self.compare_pred(module, traced, input) 270 271 expected_val = 1 if should_decompose and HAS_CUDA else 0 272 self.assertEqual( 273 counters["inductor"]["decompose_mm"], 274 expected_val, 275 ) 276 decompose_mm_fwd = counters["inductor"]["decompose_mm"] 277 278 ref.sum().backward() 279 res.sum().backward() 280 self.compare_parameters(module, traced) 281 self.compare_gradients(module, traced) 282 283 expected_val = 1 if should_decompose and HAS_CUDA else 0 284 self.assertEqual( 285 counters["inductor"]["decompose_mm"] - decompose_mm_fwd, 286 expected_val, 287 ) 288 counters.clear() 289 290 @parametrize("m,k,n, should_decompose", [(20480, 5, 2, True)]) 291 @parametrize("has_bias", [True, False]) 292 def test_dynamic_shape(self, m, n, k, has_bias, should_decompose): 293 torch._logging.set_logs(inductor=logging.DEBUG) 294 input = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True) 295 296 counters.clear() 297 298 module = MyModule(k, n, has_bias).to(GPU_TYPE) 299 traced = torch.compile(module, dynamic=True) 300 input = [input] 301 ref = module(*input) 302 res = traced(*input) 303 304 self.compare_pred(module, traced, input) 305 306 expected_val = 1 if should_decompose and HAS_CUDA else 0 307 if has_bias: 308 self.assertEqual( 309 counters["inductor"]["decompose_addmm"], 310 expected_val, 311 ) 312 313 ref.sum().backward() 314 res.sum().backward() 315 316 self.compare_parameters(module, traced) 317 self.compare_gradients(module, traced) 318 319 expected_val = 0 320 if HAS_CUDA: 321 expected_val = 1 if has_bias else 2 322 323 self.assertEqual( 324 counters["inductor"]["decompose_mm"], 325 expected_val, 326 ) 327 counters.clear() 328 329 def test_realize_input(self): 330 m = 20480 331 k = 5 332 n = 2 333 torch._logging.set_logs(inductor=logging.DEBUG) 334 input1 = torch.randn(m, k, device=GPU_TYPE).T.contiguous() 335 input2 = torch.randn(k, n, device=GPU_TYPE) 336 337 @torch.compile() 338 def foo(x, y): 339 return x.T.contiguous() @ y 340 341 out, code = run_and_get_code(foo, input1, input2) 342 343 if GPU_TYPE == "xpu": 344 # only 1 kernel generated on the XPU stack 345 FileCheck().check_count(".run(", 1, exactly=True).run(code[0]) 346 else: 347 # two kernels generated 348 FileCheck().check_count(".run(", 2, exactly=True).run(code[0]) 349 350 351if __name__ == "__main__": 352 run_tests() 353