• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright © 2022 Advanced Micro Devices, Inc.
 *
 * SPDX-License-Identifier: MIT
 */

/* Convert 8-bit and 16-bit loads to 32 bits. This is for drivers that don't
 * support non-32-bit loads.
 *
 * This pass only transforms load intrinsics lowered by nir_lower_explicit_io,
 * so this pass should run after it.
 *
 * nir_opt_load_store_vectorize should be run before this because it analyzes
 * offset calculations and recomputes align_mul and align_offset.
 *
 * nir_opt_algebraic and (optionally) ALU scalarization are recommended to be
 * run after this.
 *
 * Running nir_opt_load_store_vectorize after this pass may lead to further
 * vectorization, e.g. adjacent 2x16-bit and 1x32-bit loads will become
 * 2x32-bit loads.
 */

#include "util/u_math.h"
#include "ac_nir.h"
/* Per-instruction callback for nir_shader_instructions_pass.
 *
 * Rewrites one 8-bit or 16-bit UBO/SSBO/global load into a 32-bit-per-channel
 * load followed by ALU code that extracts the originally requested bits.
 * "data" is the ac_nir_lower_subdword_options given to the pass.
 *
 * Returns true if "instr" was modified.
 */
static bool
lower_subdword_loads(nir_builder *b, nir_instr *instr, void *data)
{
   ac_nir_lower_subdword_options *options = data;

   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
   unsigned num_components = intr->num_components;
   /* Scalar and vector loads can be lowered for independently selected
    * sets of address spaces (see the two masks in the options struct).
    */
   nir_variable_mode modes =
      num_components == 1 ? options->modes_1_comp
                          : options->modes_N_comps;

   /* Only lower the load intrinsics whose memory mode is enabled. */
   switch (intr->intrinsic) {
   case nir_intrinsic_load_ubo:
      if (!(modes & nir_var_mem_ubo))
         return false;
      break;
   case nir_intrinsic_load_ssbo:
      if (!(modes & nir_var_mem_ssbo))
         return false;
      break;
   case nir_intrinsic_load_global:
      if (!(modes & nir_var_mem_global))
         return false;
      break;
   default:
      return false;
   }

   unsigned bit_size = intr->def.bit_size;
   if (bit_size >= 32)
      return false;

   assert(bit_size == 8 || bit_size == 16);

   unsigned component_size = bit_size / 8; /* bytes per source component */
   unsigned comp_per_dword = 4 / component_size; /* source components per dword */

   /* Get the offset alignment relative to the closest dword. */
   unsigned align_mul = MIN2(nir_intrinsic_align_mul(intr), 4);
   unsigned align_offset = nir_intrinsic_align_offset(intr) % align_mul;

   nir_src *src_offset = nir_get_io_offset_src(intr);
   nir_def *offset = src_offset->ssa;
   nir_def *result = &intr->def;

   /* Change the load to 32 bits per channel, update the channel count,
    * and increase the declared load alignment.
    */
   intr->def.bit_size = 32;

   if (align_mul == 4 && align_offset == 0) {
      intr->num_components = intr->def.num_components =
         DIV_ROUND_UP(num_components, comp_per_dword);

      /* Aligned loads. Just bitcast the vector and trim it if there are
       * trailing unused elements.
       */
      b->cursor = nir_after_instr(instr);
      result = nir_extract_bits(b, &result, 1, 0, num_components, bit_size);

      /* Redirect all users of the old def to the extracted bits, except
       * the extract itself (which reads the old def).
       */
      nir_def_rewrite_uses_after(&intr->def, result,
                                     result->parent_instr);
      return true;
   }

   /* Multi-component unaligned loads may straddle the dword boundary.
    * E.g. for 2 components, we need to load an extra dword, and so on.
    */
   intr->num_components = intr->def.num_components =
      DIV_ROUND_UP(4 - align_mul + align_offset + num_components * component_size, 4);

   /* The rewritten load starts on a dword boundary, so declare that. */
   nir_intrinsic_set_align(intr,
                           MAX2(nir_intrinsic_align_mul(intr), 4),
                           nir_intrinsic_align_offset(intr) & ~0x3);

   if (align_mul == 4) {
      /* Unaligned loads with an aligned non-constant base offset (which is
       * X * align_mul) and a constant added offset (align_offset).
       */
      assert(align_offset <= 3);
      assert(align_offset % component_size == 0);
      unsigned comp_offset = align_offset / component_size;

      /* There is a good probability that the offset is "iadd" adding
       * align_offset. Subtracting align_offset should eliminate it.
       */
      b->cursor = nir_before_instr(instr);
      nir_src_rewrite(src_offset, nir_iadd_imm(b, offset, -align_offset));

      b->cursor = nir_after_instr(instr);
      result = nir_extract_bits(b, &result, 1, comp_offset * bit_size,
                                num_components, bit_size);

      nir_def_rewrite_uses_after(&intr->def, result,
                                     result->parent_instr);
      return true;
   }

   /* Fully unaligned loads. We overfetch by up to 1 dword and then bitshift
    * the whole vector.
    */
   assert(align_mul <= 2 && align_offset <= 3);

   /* Round down by masking out the bits. */
   b->cursor = nir_before_instr(instr);
   nir_src_rewrite(src_offset, nir_iand_imm(b, offset, ~0x3));

   /* We need to shift bits in the loaded vector by this number. */
   b->cursor = nir_after_instr(instr);
   nir_def *shift = nir_ishl_imm(b, nir_iand_imm(b, offset, 0x3), 3);
   nir_def *rev_shift32 = nir_isub_imm(b, 32, shift);

   nir_def *elems[NIR_MAX_VEC_COMPONENTS];

   /* "shift" can be only be one of: 0, 8, 16, 24
    *
    * When we shift by (32 - shift) and shift is 0, resulting in a shift by 32,
    * which is the same as a shift by 0, we need to convert the shifted number
    * to u64 to get the shift by 32 that we want.
    *
    * The following algorithms are used to shift the vector.
    *
    * 64-bit variant (shr64 + shl64 + or32 per 2 elements):
    *    for (i = 0; i < num_components / 2 - 1; i++) {
    *       qword1 = pack(src[i * 2 + 0], src[i * 2 + 1]) >> shift;
    *       dword2 = u2u32(u2u64(src[i * 2 + 2]) << (32 - shift));
    *       dst[i * 2 + 0] = unpack_64_2x32_x(qword1);
    *       dst[i * 2 + 1] = unpack_64_2x32_y(qword1) | dword2;
    *    }
    *    i *= 2;
    *
    * 32-bit variant (shr32 + shl64 + or32 per element):
    *    for (; i < num_components - 1; i++)
    *       dst[i] = (src[i] >> shift) |
    *                u2u32(u2u64(src[i + 1]) << (32 - shift));
    */
   unsigned i = 0;

   if (intr->num_components >= 2) {
      /* Use the 64-bit algorithm as described above. */
      for (i = 0; i < intr->num_components / 2 - 1; i++) {
         nir_def *qword1, *dword2;

         qword1 = nir_pack_64_2x32_split(b,
                                         nir_channel(b, result, i * 2 + 0),
                                         nir_channel(b, result, i * 2 + 1));
         qword1 = nir_ushr(b, qword1, shift);
         dword2 = nir_ishl(b, nir_u2u64(b, nir_channel(b, result, i * 2 + 2)),
                           rev_shift32);
         dword2 = nir_u2u32(b, dword2);

         elems[i * 2 + 0] = nir_unpack_64_2x32_split_x(b, qword1);
         elems[i * 2 + 1] =
            nir_ior(b, nir_unpack_64_2x32_split_y(b, qword1), dword2);
      }
      i *= 2;

      /* Use the 32-bit algorithm for the remainder of the vector. */
      for (; i < intr->num_components - 1; i++) {
         elems[i] =
            nir_ior(b,
                    nir_ushr(b, nir_channel(b, result, i), shift),
                    nir_u2u32(b,
                        nir_ishl(b, nir_u2u64(b, nir_channel(b, result, i + 1)),
                                 rev_shift32)));
      }
   }

   /* Shift the last element. */
   elems[i] = nir_ushr(b, nir_channel(b, result, i), shift);

   result = nir_vec(b, elems, intr->num_components);
   /* Trim/bitcast the shifted dwords back to the original bit size and
    * component count.
    */
   result = nir_extract_bits(b, &result, 1, 0, num_components, bit_size);

   nir_def_rewrite_uses_after(&intr->def, result,
                                  result->parent_instr);
   return true;
}
208 
209 bool
ac_nir_lower_subdword_loads(nir_shader * nir,ac_nir_lower_subdword_options options)210 ac_nir_lower_subdword_loads(nir_shader *nir, ac_nir_lower_subdword_options options)
211 {
212    return nir_shader_instructions_pass(nir, lower_subdword_loads,
213                                        nir_metadata_dominance |
214                                        nir_metadata_block_index, &options);
215 }
216