• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2import logging
3from typing import List, TYPE_CHECKING
4
5from ..select_algorithm import autotune_select_algorithm, TritonTemplate
6from .mm_common import mm_args, mm_configs, mm_grid, mm_options
7
8
9if TYPE_CHECKING:
10    from ..ir import ChoiceCaller
11
12log = logging.getLogger(__name__)
13
# Triton template for a mixed-dtype matmul in which operand ``B`` packs two
# unsigned 4-bit values into each stored element (hence "uint4x2").
#
# What the kernel source below does:
#   * ``M`` and ``K`` are taken from A's sizes, ``N`` from B's second size.
#   * B is addressed with ``rk // 2`` along K, so two consecutive logical k
#     indices read the same stored element of B; the B pointer likewise
#     advances by ``BLOCK_K // 2`` rows per loop iteration while A advances by
#     ``BLOCK_K`` columns.
#   * ``b_shifts = 4 * (rk % 2)`` selects bits 0-3 for even k and bits 4-7 for
#     odd k.  After ``& 0xF`` the value lies in 0..15 and ``- 8`` recentres it
#     to -8..7 before the cast to ``B_PROLOGUE_CAST_TYPE``.
#   * When ``EVEN_K`` is false both loads are masked with ``rk < k`` and
#     ``other=0``.  A masked-out B element therefore unpacks to -8, but the
#     matching A entries are loaded as 0.
#   * ``b_subs`` is computed but never read by this kernel.
#   * The program id is regrouped by ``GROUP_M`` before tile offsets are
#     derived (the source labels this "better L2 performance").
#   * The final store is masked to ``idx_m < M`` and ``idx_n < N``.
#
# NOTE(review): BLOCK_M/BLOCK_N/BLOCK_K, GROUP_M, EVEN_K, ALLOW_TF32, ACC_TYPE
# and B_PROLOGUE_CAST_TYPE are not defined in this file; presumably they are
# supplied through the kwargs produced by ``mm_options`` - confirm there.
uint4x2_mixed_mm_template = TritonTemplate(
    name="uint4x2_mixed_mm",
    grid=mm_grid,
    source=r"""
{{def_kernel("A", "B")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K = {{size("A", 1)}}
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None]//2 * stride_bk + rbn[None, :] * stride_bn)
    b_shifts = 4*(rk%2)
    b_subs = 8*(1-(rk%2))

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        b = ((b >> b_shifts[:, None]) & 0xF) - 8
        b = b.to(B_PROLOGUE_CAST_TYPE)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K//2 * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)
74
75
def tuned_uint4x2_mixed_mm(mat1, mat2, mat2_mm_shape, mat2_dtype):
    """Collect uint4x2 mixed-mm template choices and hand them to the autotuner.

    Args:
        mat1: Left matmul operand, forwarded to ``mm_args``.
        mat2: Right matmul operand, forwarded to ``mm_args`` with
            ``use_4x2_dim=True``.
        mat2_mm_shape: Accepted for interface compatibility; not read here.
        mat2_dtype: Dtype whose string form, with any ``"torch."`` removed and
            ``"tl."`` prepended, is passed to ``mm_options`` as the cast type
            for the kernel's B prologue.

    Returns:
        Whatever ``autotune_select_algorithm`` returns for the gathered
        choices.
    """
    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None, use_4x2_dim=True)

    # e.g. a dtype printing as "torch.<name>" becomes "tl.<name>".
    dtype_name = f"{mat2_dtype}".replace("torch.", "")
    b_prologue_cast_type = f"tl.{dtype_name}"

    candidates: List[ChoiceCaller] = []
    for cfg in mm_configs(m, n, k):
        template_kwargs = mm_options(cfg, m, n, k, layout, b_prologue_cast_type)
        # maybe_append_choice adds to ``candidates`` in place when it accepts
        # the configuration.
        uint4x2_mixed_mm_template.maybe_append_choice(
            candidates,
            input_nodes=(mat1, mat2),
            layout=layout,
            **template_kwargs,
        )

    return autotune_select_algorithm(
        "uint4x2_mixed_mm", candidates, [mat1, mat2], layout
    )
88