# Copyright 2022 Alyssa Rosenzweig
# Copyright 2021 Collabora, Ltd.
# Copyright 2016 Intel Corporation
# SPDX-License-Identifier: MIT

import argparse
import sys
import math

a = 'a'
b = 'b'
c = 'c'
d = 'd'
e = 'e'

lower_sm5_shift = []

# Our shifts differ from SM5 for the upper bits. Mask to match the NIR
# behaviour. Because this happens as a late lowering, NIR won't optimize the
# masking back out (that happens in the main nir_opt_algebraic).
for s in [8, 16, 32, 64]:
    for shift in ["ishl", "ishr", "ushr"]:
        lower_sm5_shift += [((shift, f'a@{s}', b),
                             (shift, a, ('iand', b, s - 1)))]
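# For illustration only (not an extra rule, just what the loop above expands
# to): with s = 32 and shift = "ishl", the generated rule is
#   (('ishl', 'a@32', b), ('ishl', a, ('iand', b, 31)))
# i.e. the shift count is masked to the low five bits to match NIR behaviour.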

lower_pack = [
    (('pack_half_2x16_split', a, b),
     ('pack_32_2x16_split', ('f2f16', a), ('f2f16', b))),

    # We don't have 8-bit ALU, so we need to lower this. But if we lower it like
    # this, we can at least coalesce the pack_32_2x16_split and only pay the
    # cost of the iors and ishl. (u2u16 of 8-bit is assumed free.)
    (('pack_32_4x8_split', a, b, c, d),
     ('pack_32_2x16_split', ('ior', ('u2u16', a), ('ishl', ('u2u16', b), 8)),
                            ('ior', ('u2u16', c), ('ishl', ('u2u16', d), 8)))),

    (('unpack_half_2x16_split_x', a), ('f2f32', ('unpack_32_2x16_split_x', a))),
    (('unpack_half_2x16_split_y', a), ('f2f32', ('unpack_32_2x16_split_y', a))),

    (('extract_u16', 'a@32', 0), ('u2u32', ('unpack_32_2x16_split_x', a))),
    (('extract_u16', 'a@32', 1), ('u2u32', ('unpack_32_2x16_split_y', a))),
    (('extract_i16', 'a@32', 0), ('i2i32', ('unpack_32_2x16_split_x', a))),
    (('extract_i16', 'a@32', 1), ('i2i32', ('unpack_32_2x16_split_y', a))),

    # For optimizing extract->convert sequences for unpack/pack norm
    (('u2f32', ('u2u32', a)), ('u2f32', a)),
    (('i2f32', ('i2i32', a)), ('i2f32', a)),

    # Chew through some 8-bit before the backend has to deal with it
    (('f2u8', a), ('u2u8', ('f2u16', a))),
    (('f2i8', a), ('i2i8', ('f2i16', a))),

    # Duplicated from nir_opt_algebraic since this pattern is generated by our
    # bounds checking optimization which needs to run relatively late.
    (('unpack_64_2x32_split_x', ('pack_64_2x32_split', a, b)), a),
    (('unpack_64_2x32_split_y', ('pack_64_2x32_split', a, b)), b),

    # Based on the VIR lowering
    (('f2f16_rtz', 'a@32'),
     ('bcsel', ('flt', ('fabs', a), ('fabs', ('f2f32', ('f2f16_rtne', a)))),
      ('isub', ('f2f16_rtne', a), 1), ('f2f16_rtne', a))),

    # These are based on the lowerings from nir_opt_algebraic, but conditioned
    # on the number of bits not being constant. If the bit count is constant
    # (the happy path) we can use our native instruction instead.
    (('ibitfield_extract', 'value', 'offset', 'bits(is_not_const)'),
     ('bcsel', ('ieq', 0, 'bits'),
      0,
      ('ishr',
       ('ishl', 'value', ('isub', ('isub', 32, 'bits'), 'offset')),
       ('isub', 32, 'bits')))),

    (('ubitfield_extract', 'value', 'offset', 'bits(is_not_const)'),
     ('iand',
      ('ushr', 'value', 'offset'),
      ('bcsel', ('ieq', 'bits', 32),
       0xffffffff,
       ('isub', ('ishl', 1, 'bits'), 1)))),

    # Codegen depends on this trivial case being optimized out.
    (('ubitfield_extract', 'value', 'offset', 0), 0),
    (('ibitfield_extract', 'value', 'offset', 0), 0),

    # At this point, the bit counts of bitfield extracts are constant. We can
    # only do constant unsigned bitfield extract, so lower signed to unsigned +
    # sign extend.
    (('ibitfield_extract', a, b, '#bits'),
     ('ishr', ('ishl', ('ubitfield_extract', a, b, 'bits'), ('isub', 32, 'bits')),
      ('isub', 32, 'bits'))),
]

lower_selects = []
for T, sizes, one in [('f', [16, 32], 1.0),
                      ('i', [8, 16, 32], 1),
                      ('b', [16, 32], -1)]:
    for size in sizes:
        lower_selects.extend([
            ((f'b2{T}{size}', ('inot', 'a@1')), ('bcsel', a, 0, one)),
            ((f'b2{T}{size}', 'a@1'), ('bcsel', a, one, 0)),
        ])
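# For illustration only: one rule generated by the loop above, for T = 'f' and
# size = 32, is
#   (('b2f32', 'a@1'), ('bcsel', a, 1.0, 0))
# so boolean-to-value conversions become selects between the two constants.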

# Rewriting bcsel(a || b, ...) in terms of bcsel(a, ...) and bcsel(b, ...) lets
# our compare-and-select fusion rules do a better job, assuming that a and b
# are comparisons themselves.
#
# This needs to be a separate pass that runs after lower_selects, in order to
# pick up patterns like b2f32(iand(...)).
opt_selects = [
        (('bcsel', ('ior(is_used_once)', a, b), c, d),
         ('bcsel', a, c, ('bcsel', b, c, d))),

        (('bcsel', ('iand(is_used_once)', a, b), c, d),
         ('bcsel', a, ('bcsel', b, c, d), d)),
]

# When the ior/iand is used multiple times, we can instead fuse the other way.
opt_selects.extend([
        (('iand', ('inot', 'a@1'), b), ('bcsel', a, False, b)),
        (('iand', 'a@1', b), ('bcsel', a, b, False)),

        (('ior', ('inot', 'a@1'), b), ('bcsel', a, b, True)),
        (('ior', 'a@1', b), ('bcsel', a, True, b)),
])

fuse_extr = []
for start in range(32):
    fuse_extr.extend([
        (('ior', ('ushr', 'a@32', start), ('ishl', 'b@32', 32 - start)),
         ('extr_agx', a, b, start, 0)),
    ])
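# For illustration only: with start = 8, the fused pattern is
#   (a >> 8) | (b << 24)  ->  extr_agx(a, b, 8, 0)
# i.e. a 32-bit window taken at bit offset 8 of the 64-bit concatenation b:a.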

fuse_ubfe = []
for bits in range(1, 32):
    fuse_ubfe.extend([
        (('iand', ('ushr', 'a@32', b), (1 << bits) - 1),
         ('ubitfield_extract', a, b, bits))
    ])
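# For illustration only: with bits = 8, the rule above fuses
#   (a >> b) & 0xff  ->  ubitfield_extract(a, b, 8)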

# (x * y) + z = (x * y) + (z << 0)
def imad(x, y, z):
    return ('imadshl_agx', x, y, z, 0)

# (x * y) - z = (x * y) - (z << 0)
def imsub(x, y, z):
    return ('imsubshl_agx', x, y, z, 0)

# x + (y << s) = (x * 1) + (y << s)
def iaddshl(x, y, s):
    return ('imadshl_agx', x, 1, y, s)

# x - (y << s) = (x * 1) - (y << s)
def isubshl(x, y, s):
    return ('imsubshl_agx', x, 1, y, s)

fuse_imad = [
    # Reassociate imul+iadd chain in order to fuse imads. This pattern comes up
    # in compute shader lowering.
    (('iadd', ('iadd(is_used_once)', ('imul(is_used_once)', a, b),
              ('imul(is_used_once)', c, d)), e),
     imad(a, b, imad(c, d, e))),

    # Fuse regular imad
    (('iadd', ('imul(is_used_once)', a, b), c), imad(a, b, c)),
    (('isub', ('imul(is_used_once)', a, b), c), imsub(a, b, c)),
]

for s in range(1, 5):
    fuse_imad += [
        # Definitions
        (('iadd', a, ('ishl(is_used_once)', b, s)), iaddshl(a, b, s)),
        (('isub', a, ('ishl(is_used_once)', b, s)), isubshl(a, b, s)),

        # ineg(x) is 0 - x
        (('ineg', ('ishl(is_used_once)', b, s)), isubshl(0, b, s)),

        # Definitions
        (imad(a, b, ('ishl(is_used_once)', c, s)), ('imadshl_agx', a, b, c, s)),
        (imsub(a, b, ('ishl(is_used_once)', c, s)), ('imsubshl_agx', a, b, c, s)),

        # The above but after the below shift lowering
        (imad(a, b, ('imadshl_agx(is_used_once)', 0, 1, c, s)), ('imadshl_agx', a, b, c, s)),
        (imsub(a, b, ('imadshl_agx(is_used_once)', 0, 1, c, s)), ('imsubshl_agx', a, b, c, s)),

        # a + (a << s) = a + a * (1 << s) = a * (1 + (1 << s))
        (('imul', a, 1 + (1 << s)), iaddshl(a, a, s)),

        # a - (a << s) = a - a * (1 << s) = a * (1 - (1 << s))
        (('imul', a, 1 - (1 << s)), isubshl(a, a, s)),

        # a - (a << s) = a * (1 - (1 << s)) = -(a * ((1 << s) - 1))
        (('ineg', ('imul(is_used_once)', a, (1 << s) - 1)), isubshl(a, a, s)),

        # iadd is SCIB, general shift is IC (slower)
        (('ishl', a, s), iaddshl(0, a, s)),
    ]
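# For illustration only: with s = 1 the multiply rules above rewrite
#   ('imul', a, 3)   ->  iaddshl(a, a, 1), i.e. a + (a << 1)
#   ('imul', a, -1)  ->  isubshl(a, a, 1), i.e. a - (a << 1)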

# If the above rules failed, we have a large constant shift on the IC unit.
# Might as well fuse an add to form an imad, if we're on the IC anyway.
fuse_imad += [
    (('iadd', a, ('ishl(is_used_once)', b, '#c')), imad(b, ('ishl', 1, c), a)),
]
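# For illustration only: with a constant shift the earlier rules did not fuse,
# say c = 7, the rule above rewrites a + (b << 7) as imad(b, 1 << 7, a), i.e.
# b * 128 + a, so the add rides along with the multiply already on the IC.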

# Discard lowering generates this pattern, clean it up
ixor_bcsel = [
   (('ixor', ('bcsel', a, '#b', '#c'), '#d'),
    ('bcsel', a, ('ixor', b, d), ('ixor', c, d))),
]

# The main NIR optimizer works on imul, not amul. We need just enough patterns
# for amul to let us fuse lea.
cleanup_amul = [
   # Neither operation overflows so we can keep the amul.
   (('amul', ('amul', a, '#b'), '#c'), ('amul', a, ('imul', b, c))),

   # Result of u2u64 has zero in upper half, so the shift doesn't overflow, so
   # neither multiplication overflows.
   (('amul', ('ishl', ('u2u64', 'a@32'), '#b(is_ult_32)'), '#c'),
    ('amul', ('u2u64', a), ('ishl', c, b))),
]
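# For illustration only: the first rule folds the two constants, e.g.
#   amul(amul(a, 4), 6)      ->  amul(a, 24)
# and the second pulls a small shift out of the zero-extension, e.g.
#   amul(u2u64(a) << 2, c)   ->  amul(u2u64(a), c << 2)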

fuse_lea = []

# Handle 64-bit address arithmetic (OpenCL)
for s in range(1, 5):
    pot = 1 << s

    fuse_lea += [
        # A + (#b + c) 2^s = (A + c 2^s) + #b 2^s
        (('iadd', 'a@64', ('amul', pot, ('iadd', '#b(is_upper_half_zero)', ('u2u64', 'c@32')))),
         ('ulea_agx', ('ulea_agx', a, c, s), ('u2u32', b), s)),

        # A + (B + c) 2^s = (A + B 2^s) + c 2^s
        (('iadd', 'a@64', ('amul', ('iadd', 'b@64', ('i2i64', 'c@32')), pot)),
         ('ilea_agx', ('iadd', a, ('ishl', b, s)), c, s)),

        # A + 2^s (B + (C + d)) = (A + (B + C)2^s) + d 2^s
        (('iadd', 'a@64', ('amul', ('iadd', 'b@64',
                                   ('iadd', 'c@64', ('u2u64', 'd@32'))), pot)),
         ('ulea_agx', ('iadd', a, ('ishl', ('iadd', b, c), s)), d, s)),
    ]

    for sgn in ["u", "i"]:
        upconv = f'{sgn}2{sgn}64'
        lea = f'{sgn}lea_agx'

        fuse_lea += [
            # Basic pattern match
            (('iadd', 'a@64', ('amul', (upconv, 'b@32'), pot)), (lea, a, b, s)),
            (('iadd', 'a@64', ('ishl', (upconv, 'b@32'), s)), (lea, a, b, s)),
        ]
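        # For illustration only: with s = 2 and sgn = "u" this generates, e.g.,
        #   (('iadd', 'a@64', ('amul', ('u2u64', 'b@32'), 4)),
        #    ('ulea_agx', a, b, 2))
        # i.e. a 64-bit base plus a zero-extended 32-bit index scaled by 4.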

# Handle relaxed 32-bit address arithmetic (OpenGL, Vulkan)
for s_ in range(1, 5):
    # Iterate backwards
    s = 5 - s_

    v = 1 << s
    is_mult = f'(is_unsigned_multiple_of_{v})'

    fuse_lea += [
        # A + (b 2^s) = A + b 2^s with relaxed multiply
        (('iadd', 'a@64', ('u2u64', ('amul', 'b@32', v))),
         ('ulea_agx', a, b, s)),

        # A + (b * c 2^s) = A + (b * c) 2^s with relaxed multiply
        (('iadd', 'a@64', ('u2u64', ('amul', 'b@32', f'#c{is_mult}'))),
         ('ulea_agx', a, ('imul', b, ('ushr', c, s)), s)),

        # A + (b 2^s + c d 2^s) = A + (b + cd) 2^s with relaxation.
        #
        # amul is bounded by the buffer size by definition, and both the GL & VK
        # limit UBOs and SSBOs to INT32_MAX bytes. Therefore, amul has no signed
        # wrap.
        #
        # Further, because we are zero-extending the 32-bit result, the 32-bit
        # sum must be nonnegative -- if it were negative, it would represent an
        # offset above INT32_MAX which would be invalid given the amul and
        # max buffer size. Thus with signed math
        #
        #   0 <= b 2^s + cd 2^s < INT32_MAX
        #
        # ..and hence
        #
        #   0 <= b + cd < INT32_MAX
        #
        # Those bounds together with distributivity mean that
        #
        #   (b 2^s + cd 2^s) mod 2^32 = 2^s ((b + cd) mod 2^32)
        #
        # ...which is exactly what we need to factor out the shift.
        (('iadd', 'a@64', ('u2u64', ('iadd', f'#b{is_mult}',
                                     ('amul', 'c@32', f'#d{is_mult}')))),
         ('ulea_agx', a, ('iadd', ('ishr', b, s),
                                  ('amul', 'c@32', ('ishr', d, s))), s)),
    ]
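    # For illustration only, a worked instance of the last rule with s = 2
    # (v = 4): taking b = 8 and d = 12, both multiples of 4,
    #   A + u2u64(8 + c * 12)  ->  ulea_agx(A, 2 + c * 3, 2)
    # since 8 >> 2 = 2 and 12 >> 2 = 3; the shared factor of 4 becomes the lea
    # shift.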

# Zero-shift rules are lowest precedence since we really like to fuse shifts
fuse_lea += [
    (('iadd', a, ('u2u64', 'b@32')), ('ulea_agx', a, b, 0)),
    (('iadd', a, ('i2i64', 'b@32')), ('ilea_agx', a, b, 0)),

    (('iadd', a, ('iadd', ('u2u64', 'b@32'), c)),
     ('ulea_agx', ('iadd', a, c), b, 0)),
    (('iadd', a, ('iadd', ('i2i64', 'b@32'), c)),
     ('ilea_agx', ('iadd', a, c), b, 0)),
]

# After lowering address arithmetic, the various address arithmetic opcodes are
# no longer useful. Lower them to regular arithmetic to let nir_opt_algebraic
# take over.
lower_lea = [
    (('amul', a, b), ('imul', a, b)),
    (('ulea_agx', a, b, c), ('iadd', a, ('ishl', ('u2u64', b), c))),
    (('ilea_agx', a, b, c), ('iadd', a, ('ishl', ('i2i64', b), c))),
]

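# Sketch of how this generator is typically driven (the exact invocation comes
# from the build system): pass the directory containing nir_algebraic.py via
# -p/--import-path and capture stdout as the generated C source, e.g.
#   python3 <this script> -p <path to the NIR python modules> > generated.c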
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--import-path', required=True)
    args = parser.parse_args()
    sys.path.insert(0, args.import_path)
    run()

def run():
    import nir_algebraic  # pylint: disable=import-error

    print('#include "agx_nir.h"')

    print(nir_algebraic.AlgebraicPass("agx_nir_cleanup_amul", cleanup_amul).render())
    print(nir_algebraic.AlgebraicPass("agx_nir_fuse_lea", fuse_lea).render())
    print(nir_algebraic.AlgebraicPass("agx_nir_lower_lea", lower_lea).render())

    print(nir_algebraic.AlgebraicPass("agx_nir_lower_algebraic_late",
                                      lower_sm5_shift + lower_pack +
                                      lower_selects).render())
    print(nir_algebraic.AlgebraicPass("agx_nir_fuse_selects",
                                      opt_selects).render())
    print(nir_algebraic.AlgebraicPass("agx_nir_fuse_algebraic_late",
                                      fuse_extr + fuse_ubfe +
                                      fuse_imad + ixor_bcsel).render())


if __name__ == '__main__':
    main()