# Copyright 2022 Alyssa Rosenzweig
# Copyright 2021 Collabora, Ltd.
# Copyright 2016 Intel Corporation
# SPDX-License-Identifier: MIT

import argparse
import sys
import math

a = 'a'
b = 'b'
c = 'c'
d = 'd'
e = 'e'

lower_sm5_shift = []

# Our shifts differ from SM5 for the upper bits. Mask to match the NIR
# behaviour. Because this happens as a late lowering, NIR won't optimize the
# masking back out (that happens in the main nir_opt_algebraic).
for s in [8, 16, 32, 64]:
    for shift in ["ishl", "ishr", "ushr"]:
        lower_sm5_shift += [((shift, f'a@{s}', b),
                             (shift, a, ('iand', b, s - 1)))]

lower_pack = [
    (('pack_half_2x16_split', a, b),
     ('pack_32_2x16_split', ('f2f16', a), ('f2f16', b))),

    # We don't have 8-bit ALU, so we need to lower this. But if we lower it
    # like this, we can at least coalesce the pack_32_2x16_split and only pay
    # the cost of the iors and ishl. (u2u16 of 8-bit is assumed free.)
    (('pack_32_4x8_split', a, b, c, d),
     ('pack_32_2x16_split', ('ior', ('u2u16', a), ('ishl', ('u2u16', b), 8)),
                            ('ior', ('u2u16', c), ('ishl', ('u2u16', d), 8)))),

    (('unpack_half_2x16_split_x', a), ('f2f32', ('unpack_32_2x16_split_x', a))),
    (('unpack_half_2x16_split_y', a), ('f2f32', ('unpack_32_2x16_split_y', a))),

    (('extract_u16', 'a@32', 0), ('u2u32', ('unpack_32_2x16_split_x', a))),
    (('extract_u16', 'a@32', 1), ('u2u32', ('unpack_32_2x16_split_y', a))),
    (('extract_i16', 'a@32', 0), ('i2i32', ('unpack_32_2x16_split_x', a))),
    (('extract_i16', 'a@32', 1), ('i2i32', ('unpack_32_2x16_split_y', a))),

    # For optimizing extract->convert sequences for unpack/pack norm
    (('u2f32', ('u2u32', a)), ('u2f32', a)),
    (('i2f32', ('i2i32', a)), ('i2f32', a)),

    # Chew through some 8-bit before the backend has to deal with it
    (('f2u8', a), ('u2u8', ('f2u16', a))),
    (('f2i8', a), ('i2i8', ('f2i16', a))),

    # Duplicated from nir_opt_algebraic since this pattern is generated by our
    # bounds checking optimization which needs to run relatively late.
    (('unpack_64_2x32_split_x', ('pack_64_2x32_split', a, b)), a),
    (('unpack_64_2x32_split_y', ('pack_64_2x32_split', a, b)), b),

    # Based on the VIR lowering
    (('f2f16_rtz', 'a@32'),
     ('bcsel', ('flt', ('fabs', a), ('fabs', ('f2f32', ('f2f16_rtne', a)))),
      ('isub', ('f2f16_rtne', a), 1), ('f2f16_rtne', a))),

    # These are based on the lowerings from nir_opt_algebraic, but conditioned
    # on the number of bits not being constant. If the bit count is constant
    # (the happy path) we can use our native instruction instead.
    (('ibitfield_extract', 'value', 'offset', 'bits(is_not_const)'),
     ('bcsel', ('ieq', 0, 'bits'),
      0,
      ('ishr',
       ('ishl', 'value', ('isub', ('isub', 32, 'bits'), 'offset')),
       ('isub', 32, 'bits')))),

    (('ubitfield_extract', 'value', 'offset', 'bits(is_not_const)'),
     ('iand',
      ('ushr', 'value', 'offset'),
      ('bcsel', ('ieq', 'bits', 32),
       0xffffffff,
       ('isub', ('ishl', 1, 'bits'), 1)))),

    # Codegen depends on this trivial case being optimized out.
    (('ubitfield_extract', 'value', 'offset', 0), 0),
    (('ibitfield_extract', 'value', 'offset', 0), 0),

    # At this point, bitfield extracts are constant. We can only do constant
    # unsigned bitfield extract, so lower signed to unsigned + sign extend.
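    # Added illustration (the bit count 8 is a hypothetical example, not from
    # the original comment): for bits = 8, ubitfield_extract leaves the 8-bit
    # field in the low bits with zeros above; shifting left by 32 - 8 = 24
    # moves the field's sign bit to bit 31, and the arithmetic shift right by
    # 24 then replicates it downward, which is exactly the sign extension the
    # rule below performs.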
    (('ibitfield_extract', a, b, '#bits'),
     ('ishr', ('ishl', ('ubitfield_extract', a, b, 'bits'), ('isub', 32, 'bits')),
      ('isub', 32, 'bits'))),
]

lower_selects = []
for T, sizes, one in [('f', [16, 32], 1.0),
                      ('i', [8, 16, 32], 1),
                      ('b', [16, 32], -1)]:
    for size in sizes:
        lower_selects.extend([
            ((f'b2{T}{size}', ('inot', 'a@1')), ('bcsel', a, 0, one)),
            ((f'b2{T}{size}', 'a@1'), ('bcsel', a, one, 0)),
        ])

# Rewriting bcsel(a || b, ...) in terms of bcsel(a, ...) and bcsel(b, ...) lets
# our rules to fuse compare-and-select do a better job, assuming that a and b
# are comparisons themselves.
#
# This needs to be a separate pass that runs after lower_selects, in order to
# pick up patterns like b2f32(iand(...))
opt_selects = [
    (('bcsel', ('ior(is_used_once)', a, b), c, d),
     ('bcsel', a, c, ('bcsel', b, c, d))),

    (('bcsel', ('iand(is_used_once)', a, b), c, d),
     ('bcsel', a, ('bcsel', b, c, d), d)),
]

# When the ior/iand is used multiple times, we can instead fuse the other way.
opt_selects.extend([
    (('iand', ('inot', 'a@1'), b), ('bcsel', a, False, b)),
    (('iand', 'a@1', b), ('bcsel', a, b, False)),

    (('ior', ('inot', 'a@1'), b), ('bcsel', a, b, True)),
    (('ior', 'a@1', b), ('bcsel', a, True, b)),
])

fuse_extr = []
for start in range(32):
    fuse_extr.extend([
        (('ior', ('ushr', 'a@32', start), ('ishl', 'b@32', 32 - start)),
         ('extr_agx', a, b, start, 0)),
    ])

fuse_ubfe = []
for bits in range(1, 32):
    fuse_ubfe.extend([
        (('iand', ('ushr', 'a@32', b), (1 << bits) - 1),
         ('ubitfield_extract', a, b, bits))
    ])

# (x * y) + z = (x * y) + (z << 0)
def imad(x, y, z):
    return ('imadshl_agx', x, y, z, 0)

# (x * y) - z = (x * y) - (z << 0)
def imsub(x, y, z):
    return ('imsubshl_agx', x, y, z, 0)

# x + (y << s) = (x * 1) + (y << s)
def iaddshl(x, y, s):
    return ('imadshl_agx', x, 1, y, s)

# x - (y << s) = (x * 1) - (y << s)
def isubshl(x, y, s):
    return ('imsubshl_agx', x, 1, y, s)

fuse_imad = [
    # Reassociate imul+iadd chain in order to fuse imads. This pattern comes up
    # in compute shader lowering.
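    # Added illustration (the operand names are assumed, not from the original
    # source): a linearized index of the form (z * A + y * B) + x matches this
    # pattern and becomes imad(z, A, imad(y, B, x)), so both multiplies end up
    # fused into multiply-adds.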
    (('iadd', ('iadd(is_used_once)', ('imul(is_used_once)', a, b),
                                     ('imul(is_used_once)', c, d)), e),
     imad(a, b, imad(c, d, e))),

    # Fuse regular imad
    (('iadd', ('imul(is_used_once)', a, b), c), imad(a, b, c)),
    (('isub', ('imul(is_used_once)', a, b), c), imsub(a, b, c)),
]

for s in range(1, 5):
    fuse_imad += [
        # Definitions
        (('iadd', a, ('ishl(is_used_once)', b, s)), iaddshl(a, b, s)),
        (('isub', a, ('ishl(is_used_once)', b, s)), isubshl(a, b, s)),

        # ineg(x) is 0 - x
        (('ineg', ('ishl(is_used_once)', b, s)), isubshl(0, b, s)),

        # Definitions
        (imad(a, b, ('ishl(is_used_once)', c, s)), ('imadshl_agx', a, b, c, s)),
        (imsub(a, b, ('ishl(is_used_once)', c, s)), ('imsubshl_agx', a, b, c, s)),

        # The above but after the below shift lowering
        (imad(a, b, ('imadshl_agx(is_used_once)', 0, 1, c, s)), ('imadshl_agx', a, b, c, s)),
        (imsub(a, b, ('imadshl_agx(is_used_once)', 0, 1, c, s)), ('imsubshl_agx', a, b, c, s)),

        # a + (a << s) = a + a * (1 << s) = a * (1 + (1 << s))
        (('imul', a, 1 + (1 << s)), iaddshl(a, a, s)),

        # a - (a << s) = a - a * (1 << s) = a * (1 - (1 << s))
        (('imul', a, 1 - (1 << s)), isubshl(a, a, s)),

        # a - (a << s) = a * (1 - (1 << s)) = -(a * ((1 << s) - 1))
        (('ineg', ('imul(is_used_once)', a, (1 << s) - 1)), isubshl(a, a, s)),

        # iadd is SCIB, general shift is IC (slower)
        (('ishl', a, s), iaddshl(0, a, s)),
    ]

# If the above rules failed, we have a large constant shift on the IC unit.
# Might as well fuse an add to form an imad, if we're on the IC anyway.
fuse_imad += [
    (('iadd', a, ('ishl(is_used_once)', b, '#c')), imad(b, ('ishl', 1, c), a)),
]

# Discard lowering generates this pattern, clean it up
ixor_bcsel = [
    (('ixor', ('bcsel', a, '#b', '#c'), '#d'),
     ('bcsel', a, ('ixor', b, d), ('ixor', c, d))),
]

# The main NIR optimizer works on imul, not amul. We need just enough patterns
# for amul to let us fuse lea.
cleanup_amul = [
    # Neither operation overflows so we can keep the amul.
    (('amul', ('amul', a, '#b'), '#c'), ('amul', a, ('imul', b, c))),

    # Result of u2u64 has zero in upper half, so the shift doesn't overflow, so
    # neither multiplication overflows.
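    # Added illustration (the concrete constants are assumed): e.g.
    # (u2u64(a) << 4) * 3 is rewritten as u2u64(a) * (3 << 4) == u2u64(a) * 48,
    # moving the shift onto the constant so the amul keeps the zero-extended
    # source directly.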
    (('amul', ('ishl', ('u2u64', 'a@32'), '#b(is_ult_32)'), '#c'),
     ('amul', ('u2u64', a), ('ishl', c, b))),
]

fuse_lea = []

# Handle 64-bit address arithmetic (OpenCL)
for s in range(1, 5):
    pot = 1 << s

    fuse_lea += [
        # A + (#b + c) 2^s = (A + c 2^s) + #b 2^s
        (('iadd', 'a@64', ('amul', pot, ('iadd', '#b(is_upper_half_zero)', ('u2u64', 'c@32')))),
         ('ulea_agx', ('ulea_agx', a, c, s), ('u2u32', b), s)),

        # A + (B + c) 2^s = (A + B 2^s) + c 2^s
        (('iadd', 'a@64', ('amul', ('iadd', 'b@64', ('i2i64', 'c@32')), pot)),
         ('ilea_agx', ('iadd', a, ('ishl', b, s)), c, s)),

        # A + 2^s (B + (C + d)) = (A + (B + C) 2^s) + d 2^s
        (('iadd', 'a@64', ('amul', ('iadd', 'b@64',
                                    ('iadd', 'c@64', ('u2u64', 'd@32'))), pot)),
         ('ulea_agx', ('iadd', a, ('ishl', ('iadd', b, c), s)), d, s)),
    ]

    for sgn in ["u", "i"]:
        upconv = f'{sgn}2{sgn}64'
        lea = f'{sgn}lea_agx'

        fuse_lea += [
            # Basic pattern match
            (('iadd', 'a@64', ('amul', (upconv, 'b@32'), pot)), (lea, a, b, s)),
            (('iadd', 'a@64', ('ishl', (upconv, 'b@32'), s)), (lea, a, b, s)),
        ]

# Handle relaxed 32-bit address arithmetic (OpenGL, Vulkan)
for s_ in range(1, 5):
    # Iterate backwards
    s = 5 - s_

    v = 1 << s
    is_mult = f'(is_unsigned_multiple_of_{v})'

    fuse_lea += [
        # A + b 2^s = A + B 2^s with relaxed multiply
        (('iadd', 'a@64', ('u2u64', ('amul', 'b@32', v))),
         ('ulea_agx', a, b, s)),

        # A + (b * c 2^s) = A + (b * c) 2^s with relaxed multiply
        (('iadd', 'a@64', ('u2u64', ('amul', 'b@32', f'#c{is_mult}'))),
         ('ulea_agx', a, ('imul', b, ('ushr', c, s)), s)),

        # A + (b 2^s + c d 2^s) = A + (b + cd) 2^s with relaxation.
        #
        # amul is bounded by the buffer size by definition, and both GL and VK
        # limit UBOs and SSBOs to INT32_MAX bytes. Therefore, amul has no
        # signed wrap.
        #
        # Further, because we are zero-extending the 32-bit result, the 32-bit
        # sum must be nonnegative -- if it were negative, it would represent an
        # offset above INT32_MAX which would be invalid given the amul and
        # max buffer size. Thus with signed math
        #
        #    0 <= b 2^s + cd 2^s < INT32_MAX
        #
        # ...and hence
        #
        #    0 <= b + cd < INT32_MAX
        #
        # Those bounds together with distributivity mean that
        #
        #    (b 2^s + cd 2^s) mod 2^32 = 2^s ((b + cd) mod 2^32)
        #
        # ...which is exactly what we need to factor out the shift.
        (('iadd', 'a@64', ('u2u64', ('iadd', f'#b{is_mult}',
                                     ('amul', 'c@32', f'#d{is_mult}')))),
         ('ulea_agx', a, ('iadd', ('ishr', b, s),
                          ('amul', 'c@32', ('ishr', d, s))), s)),
    ]

# Shift-by-zero rules are lowest precedence since we really like to fuse shifts.
fuse_lea += [
    (('iadd', a, ('u2u64', 'b@32')), ('ulea_agx', a, b, 0)),
    (('iadd', a, ('i2i64', 'b@32')), ('ilea_agx', a, b, 0)),

    (('iadd', a, ('iadd', ('u2u64', 'b@32'), c)),
     ('ulea_agx', ('iadd', a, c), b, 0)),
    (('iadd', a, ('iadd', ('i2i64', 'b@32'), c)),
     ('ilea_agx', ('iadd', a, c), b, 0)),
]

# After lowering address arithmetic, the various address arithmetic opcodes are
# no longer useful. Lower them to regular arithmetic to let nir_opt_algebraic
# take over.
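# Added illustration (operand names assumed): per the rules below,
# ulea_agx(a, b, 2) lowers to a + (u2u64(b) << 2), while ilea_agx sign-extends
# b before the shift.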
lower_lea = [
    (('amul', a, b), ('imul', a, b)),
    (('ulea_agx', a, b, c), ('iadd', a, ('ishl', ('u2u64', b), c))),
    (('ilea_agx', a, b, c), ('iadd', a, ('ishl', ('i2i64', b), c))),
]

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--import-path', required=True)
    args = parser.parse_args()
    sys.path.insert(0, args.import_path)
    run()

def run():
    import nir_algebraic  # pylint: disable=import-error

    print('#include "agx_nir.h"')

    print(nir_algebraic.AlgebraicPass("agx_nir_cleanup_amul", cleanup_amul).render())
    print(nir_algebraic.AlgebraicPass("agx_nir_fuse_lea", fuse_lea).render())
    print(nir_algebraic.AlgebraicPass("agx_nir_lower_lea", lower_lea).render())

    print(nir_algebraic.AlgebraicPass("agx_nir_lower_algebraic_late",
                                      lower_sm5_shift + lower_pack +
                                      lower_selects).render())
    print(nir_algebraic.AlgebraicPass("agx_nir_fuse_selects",
                                      opt_selects).render())
    print(nir_algebraic.AlgebraicPass("agx_nir_fuse_algebraic_late",
                                      fuse_extr + fuse_ubfe +
                                      fuse_imad + ixor_bcsel).render())


if __name__ == '__main__':
    main()
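# Hedged usage sketch (not part of the original file): the generated C is
# written to stdout, so a manual run would look roughly like
#
#    python3 <this script> -p <directory containing nir_algebraic.py> > out.c
#
# In practice the build system drives this; the invocation above is only
# illustrative.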