"""Microbenchmark: eager torch.mm vs. Triton matmul vs. Inductor-compiled mm.

For each ``(a_shape, b_shape)`` pair, ``a @ b`` is computed on CUDA in float16
four ways:

* ``torch.mm`` run eagerly,
* ``triton.ops.matmul`` run eagerly,
* ``torch.mm`` compiled by Inductor while ``config.triton.mm == "aten"``,
* ``torch.mm`` compiled by Inductor while ``config.triton.mm == "triton"``.

``test_total_time`` reports per-call time measured by
``benchmark_helper.time_with_torch_timer`` (its ``.mean`` is scaled by 1000,
presumably seconds -> ms). ``test_GPU_time`` reports per-call GPU time from
Inductor's ``benchmarker.benchmark_gpu`` (milliseconds).
"""
import triton
from benchmark_helper import time_with_torch_timer

import torch
import torch._dynamo
import torch._dynamo.config
import torch._inductor.config as config
from torch._inductor.runtime.benchmarking import benchmarker


# Allow TF32 on CUDA matmuls. Note: PyTorch's default for this flag is False
# (since 1.12), so this line deliberately overrides the default.
torch.backends.cuda.matmul.allow_tf32 = True
# Allow TF32 in cuDNN. This flag already defaults to True.
torch.backends.cudnn.allow_tf32 = True
# NOTE(review): TF32 only applies to float32 math; the benchmarks below use
# float16 inputs, so these flags presumably do not affect the measured kernels.

# Column header shared by both benchmark passes.
_HEADER = "shape; torch mm; triton mm; inductor aten mm; inductor triton mm"


# The two compiled functions below have identical bodies on purpose: each
# decorated function keeps its own compiled artifact, so one can be compiled
# under config.triton.mm = "aten" and the other under "triton".
@torch._dynamo.optimize("inductor", nopython=True)
def inductor_aten_mm(a, b):
    """Return ``torch.mm(a, b)`` compiled by Inductor (meant for the "aten" mm setting)."""
    return torch.mm(a, b)


@torch._dynamo.optimize("inductor", nopython=True)
def inductor_triton_mm(a, b):
    """Return ``torch.mm(a, b)`` compiled by Inductor (meant for the "triton" mm setting)."""
    return torch.mm(a, b)


def torch_mm(a, b):
    """Return ``torch.mm(a, b)`` run eagerly (baseline)."""
    return torch.mm(a, b)


def triton_mm(a, b):
    """Return ``a @ b`` computed by Triton's bundled matmul op."""
    # NOTE(review): ``triton.ops`` was removed in Triton 3.x; this call needs
    # an older Triton release — confirm against the installed version.
    return triton.ops.matmul(a, b)


def _prepare_inputs(a_shape, b_shape):
    """Print the shape prefix, build fp16 CUDA inputs, and warm up both compiled fns.

    Each compiled function is called once under the ``config.triton.mm``
    setting it is meant to use, so compilation happens here rather than
    inside the timed region.

    Returns:
        The ``(a, b)`` input tensors.
    """
    print(a_shape, "x", b_shape, end="; ")
    a = torch.randn(a_shape, device="cuda", dtype=torch.float16)
    b = torch.randn(b_shape, device="cuda", dtype=a.dtype)

    # NOTE(review): ``config.triton.mm`` may not exist in newer Inductor
    # versions (mm backend selection moved to max-autotune settings) — verify.
    config.triton.mm = "aten"
    inductor_aten_mm(a, b)

    config.triton.mm = "triton"
    inductor_triton_mm(a, b)
    return a, b


def test_total_time(shapes):
    """Print per-call time (ms) of each mm variant for every shape pair.

    Args:
        shapes: iterable of ``(a_shape, b_shape)`` pairs.
    """
    print(_HEADER)
    for a_shape, b_shape in shapes:
        a, b = _prepare_inputs(a_shape, b_shape)

        torch_ms = time_with_torch_timer(torch_mm, (a, b)).mean * 1000
        triton_ms = time_with_torch_timer(triton_mm, (a, b)).mean * 1000
        config.triton.mm = "aten"
        ind_aten_ms = time_with_torch_timer(inductor_aten_mm, (a, b)).mean * 1000
        config.triton.mm = "triton"
        ind_triton_ms = time_with_torch_timer(inductor_triton_mm, (a, b)).mean * 1000
        print(torch_ms, triton_ms, ind_aten_ms, ind_triton_ms, sep="; ")
        # Drop compiled graphs so the next shape pair is compiled from scratch.
        torch._dynamo.reset()


def test_GPU_time(shapes):
    """Print GPU time per call (ms) of each mm variant for every shape pair.

    Args:
        shapes: iterable of ``(a_shape, b_shape)`` pairs.
    """
    print(_HEADER)
    for a_shape, b_shape in shapes:
        a, b = _prepare_inputs(a_shape, b_shape)

        # benchmark_gpu returns a single float (milliseconds), not a tuple.
        # The lambdas are invoked immediately, so capturing the loop's a/b is safe.
        torch_ms = benchmarker.benchmark_gpu(lambda: torch_mm(a, b))
        triton_ms = benchmarker.benchmark_gpu(lambda: triton_mm(a, b))
        # Mirror test_total_time: set the matching mm backend before each run.
        config.triton.mm = "aten"
        ind_aten_ms = benchmarker.benchmark_gpu(lambda: inductor_aten_mm(a, b))
        config.triton.mm = "triton"
        ind_triton_ms = benchmarker.benchmark_gpu(lambda: inductor_triton_mm(a, b))
        print(torch_ms, triton_ms, ind_aten_ms, ind_triton_ms, sep="; ")
        # Drop compiled graphs so the next shape pair is compiled from scratch.
        torch._dynamo.reset()


if __name__ == "__main__":
    shapes = [
        # alexnet
        ([128, 9216], [9216, 4096]),
        ([128, 4096], [4096, 4096]),
        ([128, 4096], [4096, 1000]),
        # BERT
        ([2048, 768], [768, 768]),
        ([2048, 768], [768, 3072]),
        ([2048, 3072], [3072, 768]),
        # hf_GPT2
        ([1024, 768], [768, 768]),
        ([1024, 768], [768, 3072]),
        ([1024, 3072], [3072, 768]),
        ([1024, 768], [768, 2304]),
    ]
    print("test total time")
    test_total_time(shapes)

    print("test GPU time")
    test_GPU_time(shapes)


# Results Preview on AWS AI cluster
"""
test total time
shape; torch mm; triton mm; inductor aten mm; inductor triton mm
[128, 9216] x [9216, 4096]; 0.07240759208798409; 0.10885953903198242; 0.20063146017491817; 0.20054904278367758
[128, 4096] x [4096, 4096]; 0.03640300128608942; 0.10960095096379519; 0.09948539081960917; 0.0996188772842288
[128, 4096] x [4096, 1000]; 0.02215010579675436; 0.12592008337378502; 0.031120930798351765; 0.0370654184371233
[2048, 768] x [768, 768]; 0.023501068353652954; 0.10804693214595318; 0.03004650119692087; 0.0276932492852211
[2048, 768] x [768, 3072]; 0.045639658346772194; 0.10883208829909563; 0.062736920081079; 0.06480381824076176
[2048, 3072] x [3072, 768]; 0.054093082435429096; 0.10804777964949608; 0.08744294755160809; 0.07766005117446184
[1024, 768] x [768, 768]; 0.021525858901441097; 0.10909941978752613; 0.02656651195138693; 0.02683836966753006
[1024, 768] x [768, 3072]; 0.027319076471030712; 0.10825308971107006; 0.040118801407516; 0.039282338693737984
[1024, 3072] x [3072, 768]; 0.034132059663534164; 0.10594133753329515; 0.05069758277386427; 0.04572632722556591
[1024, 768] x [768, 2304]; 0.02529360819607973; 0.10486091021448374; 0.03724239766597748; 0.036449190229177475

test GPU time
shape; torch mm; triton mm; inductor aten mm; inductor triton mm
[128, 9216] x [9216, 4096]; 0.09113600105047226; 0.09011200070381165; 0.21606400609016418; 0.21606400609016418
[128, 4096] x [4096, 4096]; 0.053247999399900436; 0.05222399905323982; 0.1157120019197464; 0.1157120019197464
[128, 4096] x [4096, 1000]; 0.026623999699950218; 0.02969600073993206; 0.04710400104522705; 0.05222399905323982
[2048, 768] x [768, 768]; 0.02457600086927414; 0.020479999482631683; 0.04095999896526337; 0.03993599861860275
[2048, 768] x [768, 3072]; 0.05119999870657921; 0.05222399905323982; 0.07475200295448303; 0.07577600330114365
[2048, 3072] x [3072, 768]; 0.05939200147986412; 0.05222399905323982; 0.09830400347709656; 0.0870399996638298
[1024, 768] x [768, 768]; 0.01945599913597107; 0.016383999958634377; 0.03276799991726875; 0.03276799991726875
[1024, 768] x [768, 3072]; 0.03174399957060814; 0.03276799991726875; 0.053247999399900436; 0.053247999399900436
[1024, 3072] x [3072, 768]; 0.04403200000524521; 0.03379200026392937; 0.06860800087451935; 0.062463998794555664
[1024, 768] x [768, 2304]; 0.02969600073993206; 0.02969600073993206; 0.04915200173854828; 0.048128001391887665
"""