/*
 * Copyright 2022 Alyssa Rosenzweig
 * SPDX-License-Identifier: MIT
 */
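
/*
 * Lower NIR global memory intrinsics (load_global, load_global_constant,
 * store_global, global_atomic, global_atomic_swap) to their AGX-specific
 * forms, fusing address arithmetic into the addressing mode
 *
 *    base + (extend(index) << shift)
 *
 * where the 32-bit index is sign- or zero-extended and implicitly scaled by
 * the access format's block size. For instance, a 32-bit load of address
 * (base + (i << 4)) becomes a load_agx with format R32_UINT (implicit << 2),
 * index i, and an extra shift of 2 encoded in the instruction.
 */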

#include <stdint.h>
#include "compiler/nir/nir_builder.h"
#include "agx_nir.h"
#include "nir.h"
#include "nir_intrinsics.h"
#include "nir_opcodes.h"

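/* A matched address computation: a 64-bit base pointer, an optional 32-bit
 * index that is sign- or zero-extended, and the shift applied to the index on
 * top of the format's implicit scale.
 */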
struct match {
   nir_scalar base, offset;
   bool sign_extend;
   uint8_t shift;
};

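/* Accesses are expressed with a single-channel UINT format whose block size
 * equals the access size in bytes; the log2 of that block size gives the
 * implicit scaling applied to the index.
 */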
static enum pipe_format
format_for_bitsize(unsigned bitsize)
{
   switch (bitsize) {
   case 8:
      return PIPE_FORMAT_R8_UINT;
   case 16:
      return PIPE_FORMAT_R16_UINT;
   case 32:
      return PIPE_FORMAT_R32_UINT;
   default:
      unreachable("should have been lowered");
   }
}

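/* Lower a single global memory intrinsic to its AGX equivalent, splitting the
 * address into base + (extend(index) << shift) when the address arithmetic
 * matches.
 */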
static bool
pass(struct nir_builder *b, nir_intrinsic_instr *intr, void *data)
{
   if (intr->intrinsic != nir_intrinsic_load_global &&
       intr->intrinsic != nir_intrinsic_load_global_constant &&
       intr->intrinsic != nir_intrinsic_global_atomic &&
       intr->intrinsic != nir_intrinsic_global_atomic_swap &&
       intr->intrinsic != nir_intrinsic_store_global)
      return false;

   b->cursor = nir_before_instr(&intr->instr);

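   /* The format implicitly scales the index by its block size, so pick the
    * format from the access size and remember the log2 of that block size.
    */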
   unsigned bitsize = intr->intrinsic == nir_intrinsic_store_global
                         ? nir_src_bit_size(intr->src[0])
                         : intr->def.bit_size;
   enum pipe_format format = format_for_bitsize(bitsize);
   unsigned format_shift = util_logbase2(util_format_get_blocksize(format));

   nir_src *orig_offset = nir_get_io_offset_src(intr);
   nir_scalar base = nir_scalar_resolved(orig_offset->ssa, 0);
   struct match match = {.base = base};
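   /* Atomics must use the format's natural scale exactly; loads and stores can
    * absorb up to two extra bits of shift into the instruction.
    */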
   bool shift_must_match =
      (intr->intrinsic == nir_intrinsic_global_atomic) ||
      (intr->intrinsic == nir_intrinsic_global_atomic_swap);
   unsigned max_shift = format_shift + (shift_must_match ? 0 : 2);

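   /* Try to split the address. A lea already has the required shape, so fold
    * it whenever its shift covers at least the format's implicit scale and
    * stays within the maximum hardware shift.
    */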
   if (nir_scalar_is_alu(base)) {
      nir_op op = nir_scalar_alu_op(base);
      if (op == nir_op_ulea_agx || op == nir_op_ilea_agx) {
         unsigned shift = nir_scalar_as_uint(nir_scalar_chase_alu_src(base, 2));
         if (shift >= format_shift && shift <= max_shift) {
            match = (struct match){
               .base = nir_scalar_chase_alu_src(base, 0),
               .offset = nir_scalar_chase_alu_src(base, 1),
               .shift = shift - format_shift,
               .sign_extend = (op == nir_op_ilea_agx),
            };
         }
      } else if (op == nir_op_iadd) {
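         /* For (base + constant), peel powers of two out of the constant so
          * the remaining factor can serve as a sign-extended 32-bit index.
          */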
         for (unsigned i = 0; i < 2; ++i) {
            nir_scalar const_scalar = nir_scalar_chase_alu_src(base, i);
            if (!nir_scalar_is_const(const_scalar))
               continue;

            /* Put scalar into form (k*2^n), clamping n at the maximum hardware
             * shift.
             */
            int64_t raw_scalar = nir_scalar_as_uint(const_scalar);
            uint32_t shift = MIN2(__builtin_ctz(raw_scalar), max_shift);
            int64_t k = raw_scalar >> shift;

            /* See if the reduced scalar is from a sign extension. */
            if (k > INT32_MAX || k < INT32_MIN)
               break;

            /* Match the constant */
            match = (struct match){
               .base = nir_scalar_chase_alu_src(base, 1 - i),
               .offset = nir_get_scalar(nir_imm_int(b, k), 0),
               .shift = shift - format_shift,
               .sign_extend = true,
            };

            break;
         }
      }
   }

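   /* Rebuild the access from the match: a missing index becomes a zero
    * immediate, and the base/index scalars are materialized as SSA defs.
    */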
   nir_def *offset = match.offset.def != NULL
                        ? nir_channel(b, match.offset.def, match.offset.comp)
                        : nir_imm_int(b, 0);

   nir_def *new_base = nir_channel(b, match.base.def, match.base.comp);

   nir_def *repl = NULL;
   bool has_dest = (intr->intrinsic != nir_intrinsic_store_global);
   unsigned num_components = has_dest ? intr->def.num_components : 0;
   unsigned bit_size = has_dest ? intr->def.bit_size : 0;

   if (intr->intrinsic == nir_intrinsic_load_global) {
      repl =
         nir_load_agx(b, num_components, bit_size, new_base, offset,
                      .access = nir_intrinsic_access(intr), .base = match.shift,
                      .format = format, .sign_extend = match.sign_extend);
   } else if (intr->intrinsic == nir_intrinsic_load_global_constant) {
      repl = nir_load_constant_agx(b, num_components, bit_size, new_base,
                                   offset, .access = nir_intrinsic_access(intr),
                                   .base = match.shift, .format = format,
                                   .sign_extend = match.sign_extend);
   } else if (intr->intrinsic == nir_intrinsic_global_atomic) {
      repl =
         nir_global_atomic_agx(b, bit_size, new_base, offset, intr->src[1].ssa,
                               .atomic_op = nir_intrinsic_atomic_op(intr),
                               .sign_extend = match.sign_extend);
   } else if (intr->intrinsic == nir_intrinsic_global_atomic_swap) {
      repl = nir_global_atomic_swap_agx(
         b, bit_size, new_base, offset, intr->src[1].ssa, intr->src[2].ssa,
         .atomic_op = nir_intrinsic_atomic_op(intr),
         .sign_extend = match.sign_extend);
   } else {
      nir_store_agx(b, intr->src[0].ssa, new_base, offset,
                    .access = nir_intrinsic_access(intr), .base = match.shift,
                    .format = format, .sign_extend = match.sign_extend);
   }

   if (repl)
      nir_def_rewrite_uses(&intr->def, repl);

   nir_instr_remove(&intr->instr);
   return true;
}

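/* Entry point: clean up address math, fuse what we can into lea instructions,
 * lower the global memory intrinsics against those leas, then turn any
 * leftover leas back into plain ALU for nir_opt_algebraic to simplify.
 */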
bool
agx_nir_lower_address(nir_shader *nir)
{
   bool progress = false;

   /* First, clean up as much as possible. This will make fusing more
    * effective.
    */
   do {
      progress = false;
      NIR_PASS(progress, nir, agx_nir_cleanup_amul);
      NIR_PASS(progress, nir, nir_opt_constant_folding);
      NIR_PASS(progress, nir, nir_opt_dce);
   } while (progress);

   /* Then, fuse as many lea instructions as possible */
   NIR_PASS(progress, nir, agx_nir_fuse_lea);

   /* Next, lower loads/stores using those leas */
   NIR_PASS(progress, nir, nir_shader_intrinsics_pass, pass,
            nir_metadata_control_flow, NULL);

   /* Finally, lower any leftover lea instructions back to ALU to let
    * nir_opt_algebraic simplify them from here.
    */
   NIR_PASS(progress, nir, agx_nir_lower_lea);
   NIR_PASS(progress, nir, nir_opt_dce);

   return progress;
}