import logging
from typing import Any, Dict, List, Optional, Tuple

import sympy

import torch

from .. import config as inductor_config
from ..ir import ChoiceCaller, Layout, StorageBox, TensorBox
from ..lowering import add_layout_constraint, constrain_to_fx_strides, register_lowering
from ..select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
    NoValidChoicesError,
    realize_inputs,
    TritonTemplate,
)
from ..utils import use_aten_gemm_kernels, use_triton_template
from .mm import _is_static_problem  # TODO(yangsiyu) move to mm_common
from .mm_common import mm_args, mm_grid, scaled_mm_configs


log = logging.getLogger(__name__)
aten = torch.ops.aten


scaled_mm_template = TritonTemplate(
    name="scaled_mm",
    grid=mm_grid,
    source=r"""
{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K = {{size("A", 1)}}
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        if USE_FAST_ACCUM:
            acc = tl.dot(a, b, acc, out_dtype=ACC_TYPE)
        else:
            acc += tl.dot(a, b, out_dtype=ACC_TYPE)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    if SCALING_ROWWISE:
        inv_a_scale_row = tl.load(A_inverse_scale + rm, mask=rm < M)
        inv_b_scale_row = tl.load(B_inverse_scale + rn, mask=rn < N)
        inv_scale_row = inv_a_scale_row[:, None] * inv_b_scale_row[None, :]
        acc *= inv_scale_row
    else:
        # for tensor-wise scaling, the scales are scalars
        inv_a_scale = tl.load(A_inverse_scale)
        inv_b_scale = tl.load(B_inverse_scale)
        inv_scale = inv_a_scale * inv_b_scale
        acc *= inv_scale

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)


# Inductor does not allow optional tensor input arguments currently (pass None as an
# input node to template choices), but since for _scaled_mm there is only one such arg
# (bias), work around by having a second template when bias is provided.
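# The bias template below is otherwise identical to scaled_mm_template: the only
# differences are the extra "bias_ptr" kernel argument and the bias load/add applied
# to the accumulator just before the output store.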
scaled_mm_bias_template = TritonTemplate(
    name="scaled_mm_bias",
    grid=mm_grid,
    source=r"""
{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale", "bias_ptr")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K = {{size("A", 1)}}
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        if USE_FAST_ACCUM:
            acc = tl.dot(a, b, acc, out_dtype=ACC_TYPE)
        else:
            acc += tl.dot(a, b, out_dtype=ACC_TYPE)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    if SCALING_ROWWISE:
        inv_a_scale_row = tl.load(A_inverse_scale + rm, mask=rm < M)
        inv_b_scale_row = tl.load(B_inverse_scale + rn, mask=rn < N)
        inv_scale_row = inv_a_scale_row[:, None] * inv_b_scale_row[None, :]
        acc *= inv_scale_row
    else:
        # for tensor-wise scaling, the scales are scalars
        inv_a_scale = tl.load(A_inverse_scale)
        inv_b_scale = tl.load(B_inverse_scale)
        inv_scale = inv_a_scale * inv_b_scale
        acc *= inv_scale

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # bias
    bias = tl.load(bias_ptr + rn, mask=rn < N)
    acc += bias

    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)


aten__fp8_mm = ExternKernelChoice(torch._scaled_mm, "at::_scaled_mm")


def are_compatible_scales(size_a: List[int], size_b: List[int]) -> bool:
    """Return True if the two scale shapes are an allowed combination for _scaled_mm."""
    # Same-sized scales are compatible
    if len(size_a) == len(size_b):
        return True

    # Otherwise both must be scalars (0-D) or 1-D tensors
    if len(size_a) <= 1 and len(size_b) <= 1:
        return True

    return False


def scaled_mm_options(  # type: ignore[no-untyped-def]
    config,  # triton.Config
    sym_m: sympy.core.numbers.Integer,
    sym_n: sympy.core.numbers.Integer,
    sym_k: sympy.core.numbers.Integer,
    layout: Layout,
    scale_a: StorageBox,
    scale_b: StorageBox,
    use_fast_accum: bool,
    b_prologue_cast_type: Optional[str] = None,
) -> Dict[str, Any]:
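    """Build the keyword arguments for one scaled_mm Triton template choice.

    Merges the autotuning config (block sizes via ``config.kwargs``, plus
    ``num_stages``/``num_warps``) with template-level flags such as ``EVEN_K``,
    ``USE_FAST_ACCUM`` and ``SCALING_ROWWISE``.
    """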
    even_k_symbolic = (
        sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"]
    )

    size_a, size_b = scale_a.get_size(), scale_b.get_size()
    assert are_compatible_scales(size_a, size_b), (
        "Expect scale_a and scale_b to be either both scalars (including single-element tensors) "
        f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}."
    )
    return dict(
        GROUP_M=8,
        EVEN_K=even_k_symbolic,
        ACC_TYPE="tl.float32",
        B_PROLOGUE_CAST_TYPE=b_prologue_cast_type,
        USE_FAST_ACCUM=use_fast_accum,
        num_stages=config.num_stages,
        num_warps=config.num_warps,
        # 2-D scales indicate row-wise scaling; scalar (0-D/1-D) scales mean tensor-wise scaling
        SCALING_ROWWISE=len(scale_a.get_size()) == 2,
        **config.kwargs,
    )


add_layout_constraint(aten._scaled_mm.default, constrain_to_fx_strides)


@register_lowering(aten._scaled_mm.default, type_promotion_kind=None)  # type: ignore[misc]
def tuned_scaled_mm(
    mat_a: TensorBox,
    mat_b: TensorBox,
    scale_a: TensorBox,
    scale_b: TensorBox,
    bias: Optional[TensorBox] = None,
    scale_result: Optional[TensorBox] = None,
    out_dtype: Optional[torch.dtype] = None,
    use_fast_accum: bool = False,
    layout: Optional[Layout] = None,
) -> TensorBox:
    """Lowering for aten._scaled_mm: autotunes over the ATen kernel and the
    Triton templates defined above, falling back to ATen when configured."""
    m, n, k, layout, mat_a, mat_b = mm_args(
        mat_a, mat_b, layout=layout, out_dtype=out_dtype
    )
    scale_a, scale_b = realize_inputs(scale_a, scale_b)

    input_nodes: Tuple[Any, ...]
    # workaround for Inductor not supporting optional tensor input arguments
    if bias is None:
        input_nodes = (mat_a, mat_b, scale_a, scale_b)
        triton_template = scaled_mm_template
    else:
        bias = realize_inputs(bias)
        input_nodes = (mat_a, mat_b, scale_a, scale_b, bias)
        triton_template = scaled_mm_bias_template

    aten_choice = aten__fp8_mm.bind(
        input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum
    )

    choices: List[ChoiceCaller] = []
    if use_aten_gemm_kernels():
        choices.append(aten_choice)

    static_shape, is_nonzero = _is_static_problem([mat_a, mat_b], layout)
    if is_nonzero and use_triton_template(layout, enable_float8=True):
        for config in scaled_mm_configs(m, n, k):
            if k == 16 and config.kwargs["BLOCK_M"] >= 64:
                continue  # Triton crashes in this case
            kwargs = scaled_mm_options(
                config, m, n, k, layout, scale_a, scale_b, use_fast_accum
            )
            # possibly appends a TritonTemplateCaller to choices
            triton_template.maybe_append_choice(
                choices,
                input_nodes=input_nodes,
                layout=layout,
                **kwargs,
            )

    if (
        len(choices) == 0
        and not use_aten_gemm_kernels()
        and inductor_config.autotune_fallback_to_aten
    ):
        log.warning("No choices for scaled_mm, using ATen backend as fallback")
        return aten_choice.output_node()

    try:
        return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout)
    except NoValidChoicesError:
        if not inductor_config.autotune_fallback_to_aten:
            raise
        log.warning(
            "All choices for scaled_mm were invalid, using ATen backend as fallback"
        )
        return aten_choice.output_node()
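

# A minimal usage sketch (kept as a comment; not part of this module's API). It shows
# the kind of user code that reaches this lowering under torch.compile with
# max-autotune. The shapes, dtypes and layout shown here are assumptions: the exact
# requirements (e.g. float8_e4m3fn inputs, a column-major second operand, supported
# GPU architectures) come from torch._scaled_mm itself and may differ across PyTorch
# versions.
#
#     import torch
#
#     @torch.compile(mode="max-autotune")
#     def fp8_mm(a, b, scale_a, scale_b):
#         return torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
#
#     a = torch.randn(256, 128, device="cuda").to(torch.float8_e4m3fn)
#     b = torch.randn(512, 128, device="cuda").to(torch.float8_e4m3fn).t()
#     scale_a = torch.tensor(1.0, device="cuda")
#     scale_b = torch.tensor(1.0, device="cuda")
#     out = fp8_mm(a, b, scale_a, scale_b)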