/*
 * Copyright © 2022 Advanced Micro Devices, Inc.
 *
 * SPDX-License-Identifier: MIT
 */

/* Convert 8-bit and 16-bit loads to 32 bits. This is for drivers that don't
 * support non-32-bit loads.
 *
 * This pass only transforms load intrinsics lowered by nir_lower_explicit_io,
 * so it should run after that pass.
 *
 * nir_opt_load_store_vectorize should be run before this because it analyzes
 * offset calculations and recomputes align_mul and align_offset.
 *
 * nir_opt_algebraic and (optionally) ALU scalarization are recommended to be
 * run after this.
 *
 * Running nir_opt_load_store_vectorize again after this pass may lead to
 * further vectorization, e.g. adjacent 2x16-bit and 1x32-bit loads will
 * become 2x32-bit loads.
 */
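
/* Illustrative usage sketch: the mode masks below and the "progress" variable
 * are assumptions made for the example, not requirements of this file. A
 * driver that needs all buffer loads widened might call the pass like this:
 *
 *    ac_nir_lower_subdword_options opts = {
 *       .modes_1_comp = nir_var_mem_ubo | nir_var_mem_ssbo,
 *       .modes_N_comps = nir_var_mem_ubo | nir_var_mem_ssbo |
 *                        nir_var_mem_global,
 *    };
 *    progress |= ac_nir_lower_subdword_loads(nir, opts);
 */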

#include "util/u_math.h"
#include "ac_nir.h"

static bool
lower_subdword_loads(nir_builder *b, nir_instr *instr, void *data)
{
   ac_nir_lower_subdword_options *options = data;

   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
   unsigned num_components = intr->num_components;
   nir_variable_mode modes =
      num_components == 1 ? options->modes_1_comp
                          : options->modes_N_comps;

   switch (intr->intrinsic) {
   case nir_intrinsic_load_ubo:
      if (!(modes & nir_var_mem_ubo))
         return false;
      break;
   case nir_intrinsic_load_ssbo:
      if (!(modes & nir_var_mem_ssbo))
         return false;
      break;
   case nir_intrinsic_load_global:
      if (!(modes & nir_var_mem_global))
         return false;
      break;
   default:
      return false;
   }

   unsigned bit_size = intr->def.bit_size;
   if (bit_size >= 32)
      return false;

   assert(bit_size == 8 || bit_size == 16);

   unsigned component_size = bit_size / 8;
   unsigned comp_per_dword = 4 / component_size;

   /* Get the offset alignment relative to the closest dword. */
   unsigned align_mul = MIN2(nir_intrinsic_align_mul(intr), 4);
   unsigned align_offset = nir_intrinsic_align_offset(intr) % align_mul;
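
   /* For example (illustrative values): a load declared with align_mul=16 and
    * align_offset=6 becomes align_mul=4, align_offset=2 here, i.e. it is
    * known to start 2 bytes past a dword boundary.
    */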

   nir_src *src_offset = nir_get_io_offset_src(intr);
   nir_def *offset = src_offset->ssa;
   nir_def *result = &intr->def;

   /* Change the load to 32 bits per channel, update the channel count,
    * and increase the declared load alignment.
    */
   intr->def.bit_size = 32;

   if (align_mul == 4 && align_offset == 0) {
      intr->num_components = intr->def.num_components =
         DIV_ROUND_UP(num_components, comp_per_dword);

      /* Aligned loads. Just bitcast the vector and trim it if there are
       * trailing unused elements.
       */
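      /* Example (illustrative): a dword-aligned 3x16-bit load becomes a
       * 2x32-bit load (DIV_ROUND_UP(3, 2) = 2), and nir_extract_bits below
       * reinterprets the first 48 bits of that result as a 3x16-bit vector.
       */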
      b->cursor = nir_after_instr(instr);
      result = nir_extract_bits(b, &result, 1, 0, num_components, bit_size);

      nir_def_rewrite_uses_after(&intr->def, result,
                                 result->parent_instr);
      return true;
   }

   /* Multi-component unaligned loads may straddle a dword boundary.
    * E.g. for 2 components, we need to load an extra dword, and so on.
    */
   intr->num_components = intr->def.num_components =
      DIV_ROUND_UP(4 - align_mul + align_offset + num_components * component_size, 4);
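
   /* Worked example (illustrative): for a 2x16-bit load with align_mul=2 and
    * align_offset=0, the data can end up to 4 - 2 + 0 + 2*2 = 6 bytes past
    * the start of its first dword, so we load DIV_ROUND_UP(6, 4) = 2 dwords.
    */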

   nir_intrinsic_set_align(intr,
                           MAX2(nir_intrinsic_align_mul(intr), 4),
                           nir_intrinsic_align_offset(intr) & ~0x3);

   if (align_mul == 4) {
      /* Unaligned loads with an aligned non-constant base offset (which is
       * X * align_mul) and a constant added offset (align_offset).
       */
      assert(align_offset <= 3);
      assert(align_offset % component_size == 0);
      unsigned comp_offset = align_offset / component_size;
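      /* Example (illustrative): a 1x16-bit load at offset x*4 + 2 has
       * comp_offset = 1; we load one dword from offset x*4 and extract
       * bits 16..31 from it below.
       */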

      /* There is a good chance the offset is an "iadd" that adds
       * align_offset; subtracting align_offset should eliminate it.
       */
      b->cursor = nir_before_instr(instr);
      nir_src_rewrite(src_offset, nir_iadd_imm(b, offset, -align_offset));

      b->cursor = nir_after_instr(instr);
      result = nir_extract_bits(b, &result, 1, comp_offset * bit_size,
                                num_components, bit_size);

      nir_def_rewrite_uses_after(&intr->def, result,
                                 result->parent_instr);
      return true;
   }

   /* Fully unaligned loads. We overfetch by up to 1 dword and then bitshift
    * the whole vector.
    */
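   /* Worked example (illustrative): a 2x16-bit load with align_mul=2,
    * align_offset=0 and a runtime offset of 6 loads 2 dwords from offset 4
    * (6 & ~0x3), computes shift = (6 & 0x3) * 8 = 16, and produces
    * dst[0] = (src[0] >> 16) | (src[1] << 16) before the result is trimmed
    * back to 2x16 bits.
    */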
   assert(align_mul <= 2 && align_offset <= 3);

   /* Round the offset down to a dword boundary by masking out the low bits. */
   b->cursor = nir_before_instr(instr);
   nir_src_rewrite(src_offset, nir_iand_imm(b, offset, ~0x3));

   /* The number of bits we need to shift the loaded vector by. */
   b->cursor = nir_after_instr(instr);
   nir_def *shift = nir_ishl_imm(b, nir_iand_imm(b, offset, 0x3), 3);
   nir_def *rev_shift32 = nir_isub_imm(b, 32, shift);

   nir_def *elems[NIR_MAX_VEC_COMPONENTS];

   /* "shift" can only be one of: 0, 8, 16, 24.
    *
    * When shift is 0, the (32 - shift) amount becomes 32, and a 32-bit shift
    * by 32 behaves like a shift by 0, so we convert the shifted number to u64
    * to get the real shift by 32 that we want.
    *
    * The following algorithms are used to shift the vector.
    *
    * 64-bit variant (shr64 + shl64 + or32 per 2 elements):
    *    for (i = 0; i < num_components / 2 - 1; i++) {
    *       qword1 = pack(src[i * 2 + 0], src[i * 2 + 1]) >> shift;
    *       dword2 = u2u32(u2u64(src[i * 2 + 2]) << (32 - shift));
    *       dst[i * 2 + 0] = unpack_64_2x32_x(qword1);
    *       dst[i * 2 + 1] = unpack_64_2x32_y(qword1) | dword2;
    *    }
    *    i *= 2;
    *
    * 32-bit variant (shr32 + shl64 + or32 per element):
    *    for (; i < num_components - 1; i++)
    *       dst[i] = (src[i] >> shift) |
    *                u2u32(u2u64(src[i + 1]) << (32 - shift));
    */
   unsigned i = 0;

   if (intr->num_components >= 2) {
      /* Use the 64-bit algorithm as described above. */
      for (i = 0; i < intr->num_components / 2 - 1; i++) {
         nir_def *qword1, *dword2;

         qword1 = nir_pack_64_2x32_split(b,
                                         nir_channel(b, result, i * 2 + 0),
                                         nir_channel(b, result, i * 2 + 1));
         qword1 = nir_ushr(b, qword1, shift);
         dword2 = nir_ishl(b, nir_u2u64(b, nir_channel(b, result, i * 2 + 2)),
                           rev_shift32);
         dword2 = nir_u2u32(b, dword2);

         elems[i * 2 + 0] = nir_unpack_64_2x32_split_x(b, qword1);
         elems[i * 2 + 1] =
            nir_ior(b, nir_unpack_64_2x32_split_y(b, qword1), dword2);
      }
      i *= 2;

      /* Use the 32-bit algorithm for the remainder of the vector. */
      for (; i < intr->num_components - 1; i++) {
         elems[i] =
            nir_ior(b,
                    nir_ushr(b, nir_channel(b, result, i), shift),
                    nir_u2u32(b,
                              nir_ishl(b, nir_u2u64(b, nir_channel(b, result, i + 1)),
                                       rev_shift32)));
      }
   }

   /* Shift the last element. */
   elems[i] = nir_ushr(b, nir_channel(b, result, i), shift);

   result = nir_vec(b, elems, intr->num_components);
   result = nir_extract_bits(b, &result, 1, 0, num_components, bit_size);

   nir_def_rewrite_uses_after(&intr->def, result,
                              result->parent_instr);
   return true;
}

bool
ac_nir_lower_subdword_loads(nir_shader *nir, ac_nir_lower_subdword_options options)
{
   return nir_shader_instructions_pass(nir, lower_subdword_loads,
                                       nir_metadata_dominance |
                                       nir_metadata_block_index, &options);
}