/*
 * Copyright 2022 Alyssa Rosenzweig
 * SPDX-License-Identifier: MIT
 */

#include <stdint.h>
#include "compiler/nir/nir_builder.h"
#include "util/format/u_format.h"
#include "util/u_math.h"
#include "agx_nir.h"
#include "nir.h"
#include "nir_intrinsics.h"
#include "nir_opcodes.h"

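/*
 * A matched address computation of the form base + (offset << shift), where
 * sign_extend says whether the offset is sign- or zero-extended before being
 * added to the base. If offset.def is NULL, no offset was matched and the
 * access uses a zero offset.
 */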
struct match {
   nir_scalar base, offset;
   bool sign_extend;
   uint8_t shift;
};

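/* Single-channel unsigned integer format describing an access of the given bit
 * size to the hardware.
 */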
static enum pipe_format
format_for_bitsize(unsigned bitsize)
{
   switch (bitsize) {
   case 8:
      return PIPE_FORMAT_R8_UINT;
   case 16:
      return PIPE_FORMAT_R16_UINT;
   case 32:
      return PIPE_FORMAT_R32_UINT;
   default:
      unreachable("should have been lowered");
   }
}

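/*
 * Rewrite a global load/store/atomic so that as much of its address arithmetic
 * as possible is folded into the hardware addressing mode, replacing the
 * generic intrinsic with the AGX-specific variant that takes a separate base,
 * offset, shift, and sign-extend flag. Returns false for intrinsics this pass
 * does not handle.
 */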
static bool
pass(struct nir_builder *b, nir_intrinsic_instr *intr, void *data)
{
   if (intr->intrinsic != nir_intrinsic_load_global &&
       intr->intrinsic != nir_intrinsic_load_global_constant &&
       intr->intrinsic != nir_intrinsic_global_atomic &&
       intr->intrinsic != nir_intrinsic_global_atomic_swap &&
       intr->intrinsic != nir_intrinsic_store_global)
      return false;

   b->cursor = nir_before_instr(&intr->instr);

   unsigned bitsize = intr->intrinsic == nir_intrinsic_store_global
                         ? nir_src_bit_size(intr->src[0])
                         : intr->def.bit_size;
   enum pipe_format format = format_for_bitsize(bitsize);
   unsigned format_shift = util_logbase2(util_format_get_blocksize(format));

   nir_src *orig_offset = nir_get_io_offset_src(intr);
   nir_scalar base = nir_scalar_resolved(orig_offset->ssa, 0);
   struct match match = {.base = base};
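   /* Atomics do not get a shift immediate (the atomic intrinsics emitted below
    * take none), so the matched shift must equal the format shift exactly.
    * Loads and stores can absorb up to 2 extra bits of shift.
    */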
   bool shift_must_match =
      (intr->intrinsic == nir_intrinsic_global_atomic) ||
      (intr->intrinsic == nir_intrinsic_global_atomic_swap);
   unsigned max_shift = format_shift + (shift_must_match ? 0 : 2);

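   /* Try to split the address into base + (offset << shift): either peel apart
    * a previously fused lea, or turn an iadd with an immediate addend into a
    * constant offset.
    */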
   if (nir_scalar_is_alu(base)) {
      nir_op op = nir_scalar_alu_op(base);
      if (op == nir_op_ulea_agx || op == nir_op_ilea_agx) {
         unsigned shift = nir_scalar_as_uint(nir_scalar_chase_alu_src(base, 2));
         if (shift >= format_shift && shift <= max_shift) {
            match = (struct match){
               .base = nir_scalar_chase_alu_src(base, 0),
               .offset = nir_scalar_chase_alu_src(base, 1),
               .shift = shift - format_shift,
               .sign_extend = (op == nir_op_ilea_agx),
            };
         }
      } else if (op == nir_op_iadd) {
         for (unsigned i = 0; i < 2; ++i) {
            nir_scalar const_scalar = nir_scalar_chase_alu_src(base, i);
            if (!nir_scalar_is_const(const_scalar))
               continue;

            /* Put scalar into form (k*2^n), clamping n at the maximum hardware
             * shift.
             */
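            /* For example, a byte offset of 12 with a 32-bit format
             * (format_shift = 2) becomes k = 3 and shift = 2, i.e. element
             * offset 3 with no residual shift.
             */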
            int64_t raw_scalar = nir_scalar_as_uint(const_scalar);

            /* Skip zero addends: there is nothing to fold and ctz(0) is
             * undefined.
             */
            if (raw_scalar == 0)
               break;

            uint32_t shift = MIN2(__builtin_ctzll(raw_scalar), max_shift);
            int64_t k = raw_scalar >> shift;

            /* Bail if the constant is not aligned to the format size, since
             * the residual shift cannot be negative.
             */
            if (shift < format_shift)
               break;

            /* Bail if the reduced constant does not fit in a sign-extended
             * 32-bit offset.
             */
            if (k > INT32_MAX || k < INT32_MIN)
               break;

            /* Match the constant */
            match = (struct match){
               .base = nir_scalar_chase_alu_src(base, 1 - i),
               .offset = nir_get_scalar(nir_imm_int(b, k), 0),
               .shift = shift - format_shift,
               .sign_extend = true,
            };

            break;
         }
      }
   }

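   /* If nothing matched, fall back to the original address as the base with a
    * zero offset.
    */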
   nir_def *offset = match.offset.def != NULL
                        ? nir_channel(b, match.offset.def, match.offset.comp)
                        : nir_imm_int(b, 0);

   nir_def *new_base = nir_channel(b, match.base.def, match.base.comp);

   nir_def *repl = NULL;
   bool has_dest = (intr->intrinsic != nir_intrinsic_store_global);
   unsigned num_components = has_dest ? intr->def.num_components : 0;
   unsigned bit_size = has_dest ? intr->def.bit_size : 0;

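   /* Emit the AGX variant of the access, carrying the matched shift, format,
    * and sign-extension behaviour alongside the split base/offset.
    */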
   if (intr->intrinsic == nir_intrinsic_load_global) {
      repl =
         nir_load_agx(b, num_components, bit_size, new_base, offset,
                      .access = nir_intrinsic_access(intr), .base = match.shift,
                      .format = format, .sign_extend = match.sign_extend);
   } else if (intr->intrinsic == nir_intrinsic_load_global_constant) {
      repl = nir_load_constant_agx(b, num_components, bit_size, new_base,
                                   offset, .access = nir_intrinsic_access(intr),
                                   .base = match.shift, .format = format,
                                   .sign_extend = match.sign_extend);
   } else if (intr->intrinsic == nir_intrinsic_global_atomic) {
      repl =
         nir_global_atomic_agx(b, bit_size, new_base, offset, intr->src[1].ssa,
                               .atomic_op = nir_intrinsic_atomic_op(intr),
                               .sign_extend = match.sign_extend);
   } else if (intr->intrinsic == nir_intrinsic_global_atomic_swap) {
      repl = nir_global_atomic_swap_agx(
         b, bit_size, new_base, offset, intr->src[1].ssa, intr->src[2].ssa,
         .atomic_op = nir_intrinsic_atomic_op(intr),
         .sign_extend = match.sign_extend);
   } else {
      nir_store_agx(b, intr->src[0].ssa, new_base, offset,
                    .access = nir_intrinsic_access(intr), .base = match.shift,
                    .format = format, .sign_extend = match.sign_extend);
   }

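   /* Stores produce no value, so there is nothing to rewrite in that case. */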
   if (repl)
      nir_def_rewrite_uses(&intr->def, repl);

   nir_instr_remove(&intr->instr);
   return true;
}

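/*
 * Fuse address arithmetic into AGX addressing modes: clean up the address
 * math, fuse lea instructions, fold them into the memory access intrinsics,
 * and lower any leas that did not get fused back to plain ALU.
 */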
bool
agx_nir_lower_address(nir_shader *nir)
{
   bool progress = false;

   /* First, clean up as much as possible. This will make fusing more effective.
    */
   do {
      progress = false;
      NIR_PASS(progress, nir, agx_nir_cleanup_amul);
      NIR_PASS(progress, nir, nir_opt_constant_folding);
      NIR_PASS(progress, nir, nir_opt_dce);
   } while (progress);

   /* Then, fuse as many leas as possible */
   NIR_PASS(progress, nir, agx_nir_fuse_lea);

   /* Next, lower load/store using the leas */
   NIR_PASS(progress, nir, nir_shader_intrinsics_pass, pass,
            nir_metadata_control_flow, NULL);

   /* Finally, lower any leftover lea instructions back to ALU to let
    * nir_opt_algebraic simplify them from here.
    */
   NIR_PASS(progress, nir, agx_nir_lower_lea);
   NIR_PASS(progress, nir, nir_opt_dce);

   return progress;
}