/*
 * Copyright © 2016 Bas Nieuwenhuizen
 *
 * SPDX-License-Identifier: MIT
 */

#include "ac_gpu_info.h"
#include "ac_nir.h"
#include "ac_nir_helpers.h"
#include "nir_builder.h"

/* Set NIR options shared by ACO, LLVM, RADV, and radeonsi. */
void ac_nir_set_options(struct radeon_info *info, bool use_llvm,
                        nir_shader_compiler_options *options)
{
   /*        |---------------------------------- Performance & Availability --------------------------------|
    *        |MAD/MAC/MADAK/MADMK|MAD_LEGACY|MAC_LEGACY|    FMA     |FMAC/FMAAK/FMAMK|FMA_LEGACY|PK_FMA_F16,|Best choice
    * Arch   |    F32,F16,F64    | F32,F16  | F32,F16  |F32,F16,F64 |    F32,F16     |   F32    |PK_FMAC_F16|F16,F32,F64
    * ------------------------------------------------------------------------------------------------------------------
    * gfx6,7 |     1 , - , -     |  1 , -   |  1 , -   |1/4, - ,1/16|     - , -      |    -     |   - , -   | - ,MAD,FMA
    * gfx8   |     1 , 1 , -     |  1 , -   |  - , -   |1/4, 1 ,1/16|     - , -      |    -     |   - , -   |MAD,MAD,FMA
    * gfx9   |     1 ,1|0, -     |  1 , -   |  - , -   | 1 , 1 ,1/16|    0|1, -      |    -     |   2 , -   |FMA,MAD,FMA
    * gfx10  |     1 , - , -     |  1 , -   |  1 , -   | 1 , 1 ,1/16|     1 , 1      |    -     |   2 , 2   |FMA,MAD,FMA
    * gfx10.3|     - , - , -     |  - , -   |  - , -   | 1 , 1 ,1/16|     1 , 1      |    1     |   2 , 2   |  all FMA
    * gfx11  |     - , - , -     |  - , -   |  - , -   | 2 , 2 ,1/16|     2 , 2      |    2     |   2 , 2   |  all FMA
    *
    * Tahiti, Hawaii, Carrizo, Vega20: FMA_F32 is full rate, FMA_F64 is 1/4
    * gfx9 supports MAD_F16 only on Vega10, Raven, Raven2, Renoir.
    * gfx9 supports FMAC_F32 only on Vega20, but doesn't support FMAAK and FMAMK.
    *
    * gfx8 prefers MAD for F16 because of MAC/MADAK/MADMK.
    * gfx9 and newer prefer FMA for F16 because of the packed instruction.
    * gfx10 and older prefer MAD for F32 because of the legacy instruction.
    */

   memset(options, 0, sizeof(*options));
   options->vertex_id_zero_based = true;
   options->lower_scmp = true;
   options->lower_flrp16 = true;
   options->lower_flrp32 = true;
   options->lower_flrp64 = true;
   options->lower_device_index_to_zero = true;
   options->lower_fdiv = true;
   options->lower_fmod = true;
   options->lower_ineg = true;
   options->lower_bitfield_insert = true;
   options->lower_bitfield_extract = true;
   options->lower_pack_snorm_4x8 = true;
   options->lower_pack_unorm_4x8 = true;
   options->lower_pack_half_2x16 = true;
   options->lower_pack_64_2x32 = true;
   options->lower_pack_64_4x16 = true;
   options->lower_pack_32_2x16 = true;
   options->lower_unpack_snorm_2x16 = true;
   options->lower_unpack_snorm_4x8 = true;
   options->lower_unpack_unorm_2x16 = true;
   options->lower_unpack_unorm_4x8 = true;
   options->lower_unpack_half_2x16 = true;
   options->lower_fpow = true;
   options->lower_mul_2x32_64 = true;
   options->lower_iadd_sat = info->gfx_level <= GFX8;
   options->lower_hadd = true;
   options->lower_mul_32x16 = true;
   options->has_bfe = true;
   options->has_bfm = true;
   options->has_bitfield_select = true;
   options->has_fneo_fcmpu = true;
   options->has_ford_funord = true;
   options->has_fsub = true;
   options->has_isub = true;
   options->has_sdot_4x8 = info->has_accelerated_dot_product;
   options->has_sudot_4x8 = info->has_accelerated_dot_product && info->gfx_level >= GFX11;
   options->has_udot_4x8 = info->has_accelerated_dot_product;
   options->has_sdot_4x8_sat = info->has_accelerated_dot_product;
   options->has_sudot_4x8_sat = info->has_accelerated_dot_product && info->gfx_level >= GFX11;
   options->has_udot_4x8_sat = info->has_accelerated_dot_product;
   options->has_dot_2x16 = info->has_accelerated_dot_product && info->gfx_level < GFX11;
   options->has_find_msb_rev = true;
   options->has_pack_32_4x8 = true;
   options->has_pack_half_2x16_rtz = true;
   options->has_bit_test = !use_llvm;
   options->has_fmulz = true;
   options->has_msad = true;
   options->has_shfr32 = true;
   options->lower_int64_options = nir_lower_imul64 | nir_lower_imul_high64 | nir_lower_imul_2x32_64 | nir_lower_divmod64 |
                                  nir_lower_minmax64 | nir_lower_iabs64 | nir_lower_iadd_sat64 | nir_lower_conv64;
   options->divergence_analysis_options = nir_divergence_view_index_uniform;
   options->optimize_quad_vote_to_reduce = !use_llvm;
   options->lower_fisnormal = true;
   options->support_16bit_alu = info->gfx_level >= GFX8;
   options->vectorize_vec2_16bit = info->has_packed_math_16bit;
   options->discard_is_demote = true;
   options->optimize_sample_mask_in = true;
   options->optimize_load_front_face_fsign = true;
   options->io_options = nir_io_has_flexible_input_interpolation_except_flat |
                         (info->gfx_level >= GFX8 ? nir_io_16bit_input_output_support : 0) |
                         nir_io_prefer_scalar_fs_inputs |
                         nir_io_mix_convergent_flat_with_interpolated |
                         nir_io_vectorizer_ignores_types |
                         nir_io_compaction_rotates_color_channels;
   options->lower_layer_fs_input_to_sysval = true;
   options->scalarize_ddx = true;
   options->skip_lower_packing_ops =
      BITFIELD_BIT(nir_lower_packing_op_unpack_64_2x32) |
      BITFIELD_BIT(nir_lower_packing_op_unpack_64_4x16) |
      BITFIELD_BIT(nir_lower_packing_op_unpack_32_2x16) |
      BITFIELD_BIT(nir_lower_packing_op_pack_32_4x8) |
      BITFIELD_BIT(nir_lower_packing_op_unpack_32_4x8);
}

/* Sleep for the given number of clock cycles. */
void
ac_nir_sleep(nir_builder *b, unsigned num_cycles)
{
   /* s_sleep can only sleep for N*64 cycles. */
   if (num_cycles >= 64) {
      nir_sleep_amd(b, num_cycles / 64);
      num_cycles &= 63;
   }

   /* Use s_nop to sleep for the remaining cycles. */
   while (num_cycles) {
      unsigned nop_cycles = MIN2(num_cycles, 16);

      nir_nop_amd(b, nop_cycles - 1);
      num_cycles -= nop_cycles;
   }
}
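
/* Worked example: ac_nir_sleep(b, 1000) emits nir_sleep_amd(b, 15) to cover
 * 15 * 64 = 960 cycles, then the loop emits nir_nop_amd(b, 15), nir_nop_amd(b, 15)
 * and nir_nop_amd(b, 7) for the remaining 1000 & 63 = 40 cycles (16 + 16 + 8).
 */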

/* Load the argument at index arg.arg_index + relative_index. */
nir_def *
ac_nir_load_arg_at_offset(nir_builder *b, const struct ac_shader_args *ac_args,
                          struct ac_arg arg, unsigned relative_index)
{
   unsigned arg_index = arg.arg_index + relative_index;
   unsigned num_components = ac_args->args[arg_index].size;

   if (ac_args->args[arg_index].skip)
      return nir_undef(b, num_components, 32);

   if (ac_args->args[arg_index].file == AC_ARG_SGPR)
      return nir_load_scalar_arg_amd(b, num_components, .base = arg_index);
   else
      return nir_load_vector_arg_amd(b, num_components, .base = arg_index);
}

nir_def *
ac_nir_load_arg(nir_builder *b, const struct ac_shader_args *ac_args, struct ac_arg arg)
{
   return ac_nir_load_arg_at_offset(b, ac_args, arg, 0);
}

nir_def *
ac_nir_load_arg_upper_bound(nir_builder *b, const struct ac_shader_args *ac_args, struct ac_arg arg,
                            unsigned upper_bound)
{
   nir_def *value = ac_nir_load_arg_at_offset(b, ac_args, arg, 0);
   nir_intrinsic_set_arg_upper_bound_u32_amd(nir_instr_as_intrinsic(value->parent_instr),
                                             upper_bound);
   return value;
}

void
ac_nir_store_arg(nir_builder *b, const struct ac_shader_args *ac_args, struct ac_arg arg,
                 nir_def *val)
{
   assert(nir_cursor_current_block(b->cursor)->cf_node.parent->type == nir_cf_node_function);

   if (ac_args->args[arg.arg_index].file == AC_ARG_SGPR)
      nir_store_scalar_arg_amd(b, val, .base = arg.arg_index);
   else
      nir_store_vector_arg_amd(b, val, .base = arg.arg_index);
}

nir_def *
ac_nir_unpack_value(nir_builder *b, nir_def *value, unsigned rshift, unsigned bitwidth)
{
   if (rshift == 0 && bitwidth == 32)
      return value;
   else if (rshift == 0)
      return nir_iand_imm(b, value, BITFIELD_MASK(bitwidth));
   else if ((32 - rshift) <= bitwidth)
      return nir_ushr_imm(b, value, rshift);
   else
      return nir_ubfe_imm(b, value, rshift, bitwidth);
}
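
/* Worked example: ac_nir_unpack_value(b, v, 16, 8) extracts bits [16..23]; since
 * 32 - 16 = 16 is larger than the bitwidth, it emits nir_ubfe_imm(b, v, 16, 8).
 * With rshift = 24 and bitwidth = 8 the field reaches the top bit, so a plain
 * nir_ushr_imm(b, v, 24) suffices, and with rshift = 0 only the nir_iand_imm
 * mask is needed.
 */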

nir_def *
ac_nir_unpack_arg(nir_builder *b, const struct ac_shader_args *ac_args, struct ac_arg arg,
                  unsigned rshift, unsigned bitwidth)
{
   nir_def *value = ac_nir_load_arg(b, ac_args, arg);
   return ac_nir_unpack_value(b, value, rshift, bitwidth);
}

bool
ac_nir_lower_indirect_derefs(nir_shader *shader,
                             enum amd_gfx_level gfx_level)
{
   bool progress = false;

   /* TODO: Don't lower convergent VGPR indexing because the hw can do it. */

   /* Lower large variables to scratch first so that we won't bloat the
    * shader by generating large if ladders for them.
    */
   NIR_PASS(progress, shader, nir_lower_vars_to_scratch, nir_var_function_temp, 256,
            glsl_get_natural_size_align_bytes, glsl_get_natural_size_align_bytes);

   /* This lowers indirect indexing to if-else ladders. */
   NIR_PASS(progress, shader, nir_lower_indirect_derefs, nir_var_function_temp, UINT32_MAX);
   return progress;
}

/* Shader logging function for printing nir_def values. The driver prints this after
 * command submission.
 *
 * Ring buffer layout: {uint32_t num_dwords; vec4; vec4; vec4; ... }
 * - The buffer size must be 2^N * 16 + 4
 * - num_dwords is incremented atomically and the ring wraps around, removing
 *   the oldest entries.
 */
void
ac_nir_store_debug_log_amd(nir_builder *b, nir_def *uvec4)
{
   nir_def *buf = nir_load_debug_log_desc_amd(b);
   nir_def *zero = nir_imm_int(b, 0);

   nir_def *max_index =
      nir_iadd_imm(b, nir_ushr_imm(b, nir_iadd_imm(b, nir_channel(b, buf, 2), -4), 4), -1);
   nir_def *index = nir_ssbo_atomic(b, 32, buf, zero, nir_imm_int(b, 1),
                                    .atomic_op = nir_atomic_op_iadd);
   index = nir_iand(b, index, max_index);
   nir_def *offset = nir_iadd_imm(b, nir_imul_imm(b, index, 16), 4);
   nir_store_buffer_amd(b, uvec4, buf, offset, zero, zero);
}
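
/* Worked example: for a 260-byte ring buffer (2^4 * 16 + 4), the size taken from
 * channel 2 of the descriptor gives max_index = (260 - 4) / 16 - 1 = 15, so the
 * atomically incremented counter wraps through slots 0..15 and each uvec4 is
 * stored at byte offset index * 16 + 4, just past the leading counter dword.
 */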

nir_def *
ac_average_samples(nir_builder *b, nir_def **samples, unsigned num_samples)
{
   /* This works like add-reduce by computing the sum of each pair independently, and then
    * computing the sum of each pair of sums, and so on, to get better instruction-level
    * parallelism.
    */
   if (num_samples == 16) {
      for (unsigned i = 0; i < 8; i++)
         samples[i] = nir_fadd(b, samples[i * 2], samples[i * 2 + 1]);
   }
   if (num_samples >= 8) {
      for (unsigned i = 0; i < 4; i++)
         samples[i] = nir_fadd(b, samples[i * 2], samples[i * 2 + 1]);
   }
   if (num_samples >= 4) {
      for (unsigned i = 0; i < 2; i++)
         samples[i] = nir_fadd(b, samples[i * 2], samples[i * 2 + 1]);
   }
   if (num_samples >= 2)
      samples[0] = nir_fadd(b, samples[0], samples[1]);

   return nir_fmul_imm(b, samples[0], 1.0 / num_samples); /* average the sum */
}
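
/* Worked example: with num_samples = 4, the first round computes s0 + s1 and
 * s2 + s3 independently, the second round adds those two partial sums, and the
 * result is multiplied by 1/4. The pairwise tree keeps the adds independent of
 * each other, which is what provides the instruction-level parallelism.
 */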

void
ac_optimization_barrier_vgpr_array(const struct radeon_info *info, nir_builder *b,
                                   nir_def **array, unsigned num_elements,
                                   unsigned num_components)
{
   /* We use the optimization barrier to force LLVM to form VMEM clauses by constraining its
    * instruction scheduling options.
    *
    * VMEM clauses are supported since GFX10. It's not recommended to use the optimization
    * barrier in the compute blit for GFX6-8 because the lack of A16 combined with optimization
    * barriers would unnecessarily increase VGPR usage for MSAA resources.
    */
   if (!b->shader->info.use_aco_amd && info->gfx_level >= GFX10) {
      for (unsigned i = 0; i < num_elements; i++) {
         unsigned prev_num = array[i]->num_components;
         array[i] = nir_trim_vector(b, array[i], num_components);
         array[i] = nir_optimization_barrier_vgpr_amd(b, array[i]->bit_size, array[i]);
         array[i] = nir_pad_vector(b, array[i], prev_num);
      }
   }
}

nir_def *
ac_get_global_ids(nir_builder *b, unsigned num_components, unsigned bit_size)
{
   unsigned mask = BITFIELD_MASK(num_components);

   nir_def *local_ids = nir_channels(b, nir_load_local_invocation_id(b), mask);
   nir_def *block_ids = nir_channels(b, nir_load_workgroup_id(b), mask);
   nir_def *block_size = nir_channels(b, nir_load_workgroup_size(b), mask);

   assert(bit_size == 32 || bit_size == 16);
   if (bit_size == 16) {
      local_ids = nir_i2iN(b, local_ids, bit_size);
      block_ids = nir_i2iN(b, block_ids, bit_size);
      block_size = nir_i2iN(b, block_size, bit_size);
   }

   return nir_iadd(b, nir_imul(b, block_ids, block_size), local_ids);
}

unsigned
ac_nir_varying_expression_max_cost(nir_shader *producer, nir_shader *consumer)
{
   switch (consumer->info.stage) {
   case MESA_SHADER_TESS_CTRL:
      /* VS->TCS
       * Non-amplifying shaders can always have their varying expressions
       * moved into later shaders.
       */
      return UINT_MAX;

   case MESA_SHADER_GEOMETRY:
      /* VS->GS, TES->GS */
      return consumer->info.gs.vertices_in == 1 ? UINT_MAX :
             consumer->info.gs.vertices_in == 2 ? 20 : 14;

   case MESA_SHADER_TESS_EVAL:
      /* TCS->TES and VS->TES (OpenGL only) */
   case MESA_SHADER_FRAGMENT:
      /* Up to 3 uniforms and 5 ALUs. */
      return 12;

   default:
      unreachable("unexpected shader stage");
   }
}

bool
ac_nir_optimize_uniform_atomics(nir_shader *nir)
{
   bool progress = false;
   NIR_PASS(progress, nir, ac_nir_opt_shared_append);

   nir_divergence_analysis(nir);
   NIR_PASS(progress, nir, nir_opt_uniform_atomics, false);

   return progress;
}

unsigned
ac_nir_lower_bit_size_callback(const nir_instr *instr, void *data)
{
   enum amd_gfx_level chip = *(enum amd_gfx_level *)data;

   if (instr->type != nir_instr_type_alu)
      return 0;
   nir_alu_instr *alu = nir_instr_as_alu(instr);

   /* If an instruction is not scalarized by this point,
    * it can be emitted as a packed instruction. */
   if (alu->def.num_components > 1)
      return 0;

   if (alu->def.bit_size & (8 | 16)) {
      unsigned bit_size = alu->def.bit_size;
      switch (alu->op) {
      case nir_op_bitfield_select:
      case nir_op_imul_high:
      case nir_op_umul_high:
      case nir_op_uadd_carry:
      case nir_op_usub_borrow:
         return 32;
      case nir_op_iabs:
      case nir_op_imax:
      case nir_op_umax:
      case nir_op_imin:
      case nir_op_umin:
      case nir_op_ishr:
      case nir_op_ushr:
      case nir_op_ishl:
      case nir_op_isign:
      case nir_op_uadd_sat:
      case nir_op_usub_sat:
         return (bit_size == 8 || !(chip >= GFX8 && alu->def.divergent)) ? 32 : 0;
      case nir_op_iadd_sat:
      case nir_op_isub_sat:
         return bit_size == 8 || !alu->def.divergent ? 32 : 0;

      default:
         return 0;
      }
   }

   if (nir_src_bit_size(alu->src[0].src) & (8 | 16)) {
      unsigned bit_size = nir_src_bit_size(alu->src[0].src);
      switch (alu->op) {
      case nir_op_bit_count:
      case nir_op_find_lsb:
      case nir_op_ufind_msb:
         return 32;
      case nir_op_ilt:
      case nir_op_ige:
      case nir_op_ieq:
      case nir_op_ine:
      case nir_op_ult:
      case nir_op_uge:
      case nir_op_bitz:
      case nir_op_bitnz:
         return (bit_size == 8 || !(chip >= GFX8 && alu->def.divergent)) ? 32 : 0;
      default:
         return 0;
      }
   }

   return 0;
}
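
/* Worked example: on GFX9, a divergent 16-bit nir_op_imax returns 0 here and is
 * kept at 16 bits, while the same op with a uniform or 8-bit destination returns
 * 32 and is widened by the lowering pass that uses this callback. Comparisons
 * such as nir_op_ilt follow the same rule, but based on the source bit size.
 */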

static unsigned
align_load_store_size(enum amd_gfx_level gfx_level, unsigned size, bool uses_smem, bool is_shared)
{
   /* LDS can't overfetch because accesses that are partially out of range would be dropped
    * entirely, so all unaligned LDS accesses are always split.
    */
   if (is_shared)
      return size;

   /* Align the size to what the hw supports. Out of range access due to alignment is OK because
    * range checking is per dword for untyped instructions. This assumes that the compiler backend
    * overfetches due to load size alignment instead of splitting the load.
    *
    * GFX6-11 don't have 96-bit SMEM loads.
    * GFX6 doesn't have 96-bit untyped VMEM loads.
    */
   if (gfx_level >= (uses_smem ? GFX12 : GFX7) && size == 96)
      return size;
   else
      return util_next_power_of_two(size);
}
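
/* Worked example: a 96-bit untyped VMEM load stays 96 bits on GFX7+ but is
 * padded to 128 bits on GFX6, and a 96-bit SMEM load is padded to 128 bits on
 * anything older than GFX12. A 48-bit access is padded to the next power of two
 * (64 bits), unless it targets LDS, where the size is returned unchanged.
 */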

bool
ac_nir_mem_vectorize_callback(unsigned align_mul, unsigned align_offset, unsigned bit_size,
                              unsigned num_components, int64_t hole_size, nir_intrinsic_instr *low,
                              nir_intrinsic_instr *high, void *data)
{
   struct ac_nir_config *config = (struct ac_nir_config *)data;
   bool uses_smem = (nir_intrinsic_has_access(low) &&
                     nir_intrinsic_access(low) & ACCESS_SMEM_AMD) ||
                    /* These don't have the "access" field. */
                    low->intrinsic == nir_intrinsic_load_smem_amd ||
                    low->intrinsic == nir_intrinsic_load_push_constant;
   bool is_store = !nir_intrinsic_infos[low->intrinsic].has_dest;
   bool is_scratch = low->intrinsic == nir_intrinsic_load_stack ||
                     low->intrinsic == nir_intrinsic_store_stack ||
                     low->intrinsic == nir_intrinsic_load_scratch ||
                     low->intrinsic == nir_intrinsic_store_scratch;
   bool is_shared = low->intrinsic == nir_intrinsic_load_shared ||
                    low->intrinsic == nir_intrinsic_store_shared ||
                    low->intrinsic == nir_intrinsic_load_deref ||
                    low->intrinsic == nir_intrinsic_store_deref;

   assert(!is_store || hole_size <= 0);

   /* If we get derefs here, only shared memory derefs are expected. */
   assert((low->intrinsic != nir_intrinsic_load_deref &&
           low->intrinsic != nir_intrinsic_store_deref) ||
          nir_deref_mode_is(nir_src_as_deref(low->src[0]), nir_var_mem_shared));

   /* Don't vectorize descriptor loads for LLVM due to excessive SGPR and VGPR spilling. */
   if (!config->uses_aco && low->intrinsic == nir_intrinsic_load_smem_amd)
      return false;

   /* Reject opcodes we don't vectorize. */
   switch (low->intrinsic) {
   case nir_intrinsic_load_smem_amd:
   case nir_intrinsic_load_push_constant:
   case nir_intrinsic_load_ubo:
   case nir_intrinsic_load_stack:
   case nir_intrinsic_store_stack:
   case nir_intrinsic_load_scratch:
   case nir_intrinsic_store_scratch:
   case nir_intrinsic_load_global_constant:
   case nir_intrinsic_load_global:
   case nir_intrinsic_store_global:
   case nir_intrinsic_load_ssbo:
   case nir_intrinsic_store_ssbo:
   case nir_intrinsic_load_deref:
   case nir_intrinsic_store_deref:
   case nir_intrinsic_load_shared:
   case nir_intrinsic_store_shared:
      break;
   default:
      return false;
   }

   /* Align the size to what the hw supports. */
   unsigned unaligned_new_size = num_components * bit_size;
   unsigned aligned_new_size = align_load_store_size(config->gfx_level, unaligned_new_size,
                                                     uses_smem, is_shared);

   if (uses_smem) {
      /* Maximize SMEM vectorization except for LLVM, which suffers from SGPR and VGPR spilling.
       * GFX6-7 have fewer hw SGPRs, so merge only up to 128 bits to limit SGPR usage.
       */
      if (aligned_new_size > (config->gfx_level >= GFX8 ? (config->uses_aco ? 512 : 256) : 128))
         return false;
   } else {
      if (aligned_new_size > 128)
         return false;

      /* GFX6-8 only support 32-bit scratch loads/stores. */
      if (config->gfx_level <= GFX8 && is_scratch && aligned_new_size > 32)
         return false;
   }

   if (!is_store) {
      /* Non-descriptor loads. */
      if (low->intrinsic != nir_intrinsic_load_ubo &&
          low->intrinsic != nir_intrinsic_load_ssbo) {
         /* Only increase the size of loads if doing so doesn't extend into a new page.
          * Here we set alignment to MAX because we don't know the alignment of global
          * pointers before adding the offset.
          */
         uint32_t resource_align = low->intrinsic == nir_intrinsic_load_global_constant ||
                                   low->intrinsic == nir_intrinsic_load_global ? NIR_ALIGN_MUL_MAX : 4;
         uint32_t page_size = 4096;
         uint32_t mul = MIN3(align_mul, page_size, resource_align);
         unsigned end = (align_offset + unaligned_new_size / 8u) & (mul - 1);
         if ((aligned_new_size - unaligned_new_size) / 8u > (mul - end))
            return false;
      }

      /* Only allow SMEM loads to overfetch by 32 bits:
       *
       * Examples (the hole is indicated by parentheses, the numbers are in bytes, the maximum
       * overfetch size is 4):
       *    4  | (4) | 4   ->  hw loads 12  : ALLOWED    (4 over)
       *    4  | (4) | 4   ->  hw loads 16  : DISALLOWED (8 over)
       *    4  |  4  | 4   ->  hw loads 16  : ALLOWED    (4 over)
       *    4  | (4) | 8   ->  hw loads 16  : ALLOWED    (4 over)
       *    16 |  4        ->  hw loads 32  : DISALLOWED (12 over)
       *    16 |  8        ->  hw loads 32  : DISALLOWED (8 over)
       *    16 | 12        ->  hw loads 32  : ALLOWED    (4 over)
       *    16 | (4) | 12  ->  hw loads 32  : ALLOWED    (4 over)
       *    32 | 16        ->  hw loads 64  : DISALLOWED (16 over)
       *    32 | 28        ->  hw loads 64  : ALLOWED    (4 over)
       *    32 | (4) | 28  ->  hw loads 64  : ALLOWED    (4 over)
       *
       * Note that we can overfetch by more than 4 bytes if we merge more than 2 loads, e.g.:
       *    4  | (4) | 8 | (4) | 12  ->  hw loads 32  : ALLOWED (4 + 4 over)
       *
       * That's because this callback is called twice in that case, each time allowing only 4 over.
       *
       * This is only enabled for ACO. LLVM spills SGPRs and VGPRs too much.
       */
      unsigned overfetch_size = 0;

      if (config->uses_aco && uses_smem && aligned_new_size >= 128)
         overfetch_size = 32;

      int64_t aligned_unvectorized_size =
         align_load_store_size(config->gfx_level, low->num_components * low->def.bit_size,
                               uses_smem, is_shared) +
         align_load_store_size(config->gfx_level, high->num_components * high->def.bit_size,
                               uses_smem, is_shared);

      if (aligned_new_size > aligned_unvectorized_size + overfetch_size)
         return false;
   }

   uint32_t align;
   if (align_offset)
      align = 1 << (ffs(align_offset) - 1);
   else
      align = align_mul;

   /* Validate the alignment and number of components. */
   if (!is_shared) {
      unsigned max_components;
      if (align % 4 == 0)
         max_components = NIR_MAX_VEC_COMPONENTS;
      else if (align % 2 == 0)
         max_components = 16u / bit_size;
      else
         max_components = 8u / bit_size;
      return (align % (bit_size / 8u)) == 0 && num_components <= max_components;
   } else {
      if (bit_size * num_components == 96) { /* 96 bit loads require 128 bit alignment and are split otherwise */
         return align % 16 == 0;
      } else if (bit_size == 16 && (align % 4)) {
         /* AMD hardware can't do 2-byte aligned f16vec2 loads, but they are useful for ALU
          * vectorization, because our vectorizer requires the scalar IR to already contain vectors.
          */
         return (align % 2 == 0) && num_components <= 2;
      } else {
         if (num_components == 3) {
            /* AMD hardware can't do 3-component loads except for 96-bit loads, handled above. */
            return false;
         }
         unsigned req = bit_size * num_components;
         if (req == 64 || req == 128) /* 64-bit and 128-bit loads can use ds_read2_b{32,64} */
            req /= 2u;
         return align % (req / 8u) == 0;
      }
   }
   return false;
}

bool ac_nir_scalarize_overfetching_loads_callback(const nir_instr *instr, const void *data)
{
   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

   /* Reject opcodes we don't scalarize. */
   switch (intr->intrinsic) {
   case nir_intrinsic_load_ubo:
   case nir_intrinsic_load_ssbo:
   case nir_intrinsic_load_global:
   case nir_intrinsic_load_global_constant:
   case nir_intrinsic_load_shared:
      break;
   default:
      return false;
   }

   bool uses_smem = nir_intrinsic_has_access(intr) &&
                    nir_intrinsic_access(intr) & ACCESS_SMEM_AMD;
   bool is_shared = intr->intrinsic == nir_intrinsic_load_shared;

   enum amd_gfx_level gfx_level = *(enum amd_gfx_level *)data;
   unsigned comp_size = intr->def.bit_size / 8;
   unsigned load_size = intr->def.num_components * comp_size;
   unsigned used_load_size = util_bitcount(nir_def_components_read(&intr->def)) * comp_size;

   /* Scalarize if the load overfetches. That includes loads that overfetch due to load size
    * alignment, e.g. when only a power-of-two load is available. The scalarized loads are expected
    * to be later vectorized to optimal sizes.
    */
   return used_load_size < align_load_store_size(gfx_level, load_size, uses_smem, is_shared);
}

/* Get chip-agnostic memory instruction access flags (as opposed to chip-specific GLC/DLC/SLC)
 * from a NIR memory intrinsic.
 */
enum gl_access_qualifier ac_nir_get_mem_access_flags(const nir_intrinsic_instr *instr)
{
   enum gl_access_qualifier access =
      nir_intrinsic_has_access(instr) ? nir_intrinsic_access(instr) : 0;

   /* Determine ACCESS_MAY_STORE_SUBDWORD. (for the GFX6 TC L1 bug workaround) */
   if (!nir_intrinsic_infos[instr->intrinsic].has_dest) {
      switch (instr->intrinsic) {
      case nir_intrinsic_bindless_image_store:
         access |= ACCESS_MAY_STORE_SUBDWORD;
         break;

      case nir_intrinsic_store_ssbo:
      case nir_intrinsic_store_buffer_amd:
      case nir_intrinsic_store_global:
      case nir_intrinsic_store_global_amd:
         if (access & ACCESS_USES_FORMAT_AMD ||
             (nir_intrinsic_has_align_offset(instr) && nir_intrinsic_align(instr) % 4 != 0) ||
             ((instr->src[0].ssa->bit_size / 8) * instr->src[0].ssa->num_components) % 4 != 0)
            access |= ACCESS_MAY_STORE_SUBDWORD;
         break;

      default:
         unreachable("unexpected store instruction");
      }
   }

   return access;
}

/**
 * Computes a horizontal sum of 8-bit packed values loaded from LDS.
 *
 * Each lane N will sum packed bytes 0 to N.
 * We only care about the results from up to wave_id lanes.
 * (Other lanes are not deactivated but their calculation is not used.)
 */
static nir_def *
summarize_repack(nir_builder *b, nir_def *packed_counts, bool mask_lane_id, unsigned num_lds_dwords)
{
   /* We'll use shift to filter out the bytes not needed by the current lane.
    *
    * For each row:
    * Need to shift by: `num_lds_dwords * 4 - 1 - lane_id_in_row` (in bytes)
    * in order to implement an inclusive scan.
    *
    * When v_dot4_u32_u8 is available, we right-shift a series of 0x01 bytes.
    * This will yield 0x01 at wanted byte positions and 0x00 at unwanted positions,
    * therefore v_dot can get rid of the unneeded values.
    *
    * If the v_dot instruction can't be used, we left-shift the packed bytes
    * in order to shift out the unneeded bytes and shift in zeroes instead,
    * then we sum them using v_msad_u8.
    */

   nir_def *lane_id = nir_load_subgroup_invocation(b);

   /* Mask lane ID so that lanes 16...31 also have the ID 0...15,
    * in order to perform a second horizontal sum in parallel when needed.
    */
   if (mask_lane_id)
      lane_id = nir_iand_imm(b, lane_id, 0xf);

   nir_def *shift = nir_iadd_imm(b, nir_imul_imm(b, lane_id, -8u), num_lds_dwords * 32 - 8);
   assert(b->shader->options->has_msad || b->shader->options->has_udot_4x8);
   bool use_dot = b->shader->options->has_udot_4x8;

   if (num_lds_dwords == 1) {
      /* Broadcast the packed data we read from LDS
       * (to the first 16 lanes of the row, but we only care up to num_waves).
       */
      nir_def *packed = nir_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0));

      /* Horizontally add the packed bytes. */
      if (use_dot) {
         nir_def *dot_op = nir_ushr(b, nir_imm_int(b, 0x01010101), shift);
         return nir_udot_4x8_uadd(b, packed, dot_op, nir_imm_int(b, 0));
      } else {
         nir_def *sad_op = nir_ishl(b, packed, shift);
         return nir_msad_4x8(b, sad_op, nir_imm_int(b, 0), nir_imm_int(b, 0));
      }
   } else if (num_lds_dwords == 2) {
      /* Broadcast the packed data we read from LDS
       * (to the first 16 lanes of the row, but we only care up to num_waves).
       */
      nir_def *packed_dw0 = nir_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
      nir_def *packed_dw1 = nir_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));

      /* Horizontally add the packed bytes. */
      if (use_dot) {
         nir_def *dot_op = nir_ushr(b, nir_imm_int64(b, 0x0101010101010101), shift);
         nir_def *sum = nir_udot_4x8_uadd(b, packed_dw0, nir_unpack_64_2x32_split_x(b, dot_op), nir_imm_int(b, 0));
         return nir_udot_4x8_uadd(b, packed_dw1, nir_unpack_64_2x32_split_y(b, dot_op), sum);
      } else {
         nir_def *sad_op = nir_ishl(b, nir_pack_64_2x32_split(b, packed_dw0, packed_dw1), shift);
         nir_def *sum = nir_msad_4x8(b, nir_unpack_64_2x32_split_x(b, sad_op), nir_imm_int(b, 0), nir_imm_int(b, 0));
         return nir_msad_4x8(b, nir_unpack_64_2x32_split_y(b, sad_op), nir_imm_int(b, 0), sum);
      }
   } else {
      unreachable("Unimplemented NGG wave count");
   }
}
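
/* Worked example: with num_lds_dwords = 1, lane 2 of a row computes
 * shift = 32 - 8 - 2 * 8 = 8, so the dot path evaluates
 * 0x01010101 >> 8 = 0x00010101 and nir_udot_4x8_uadd adds bytes 0..2 of the
 * packed counts, i.e. the inclusive prefix sum of the surviving-invocation
 * counts of waves 0..2. The msad path reaches the same sum by shifting the
 * packed bytes left by 8 and taking the sum of absolute differences against 0.
 */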

/**
 * Repacks invocations in the current workgroup to eliminate gaps between them.
 *
 * Uses 1 dword of LDS per 4 waves (1 byte of LDS per wave) for each repack.
 * Assumes that all invocations in the workgroup are active (exec = -1).
 */
void
ac_nir_repack_invocations_in_workgroup(nir_builder *b, nir_def **input_bool,
                                       ac_nir_wg_repack_result *results, const unsigned num_repacks,
                                       nir_def *lds_addr_base, unsigned max_num_waves,
                                       unsigned wave_size)
{
   /* We can currently only do up to 2 repacks at a time. */
   assert(num_repacks <= 2);

   /* STEP 1. Count surviving invocations in the current wave.
    *
    * Implemented by a scalar instruction that simply counts the number of bits set in a 32/64-bit mask.
    */

   nir_def *input_mask[2];
   nir_def *surviving_invocations_in_current_wave[2];

   for (unsigned i = 0; i < num_repacks; ++i) {
      /* Input should be boolean: 1 if the current invocation should survive the repack. */
      assert(input_bool[i]->bit_size == 1);

      input_mask[i] = nir_ballot(b, 1, wave_size, input_bool[i]);
      surviving_invocations_in_current_wave[i] = nir_bit_count(b, input_mask[i]);
   }

   /* If we know at compile time that the workgroup has only 1 wave, no further steps are necessary. */
   if (max_num_waves == 1) {
      for (unsigned i = 0; i < num_repacks; ++i) {
         results[i].num_repacked_invocations = surviving_invocations_in_current_wave[i];
         results[i].repacked_invocation_index = nir_mbcnt_amd(b, input_mask[i], nir_imm_int(b, 0));
      }
      return;
   }

   /* STEP 2. Waves tell each other their number of surviving invocations.
    *
    * Row 0 (lanes 0-15) performs the first repack, and Row 1 (lanes 16-31) the second in parallel.
    * Each wave activates only its first lane per row, which stores the number of surviving
    * invocations in that wave into the LDS for that repack, then reads the numbers from every wave.
    *
    * The workgroup size of NGG shaders is at most 256, which means
    * the maximum number of waves is 4 in Wave64 mode and 8 in Wave32 mode.
    * For each repack:
    * Each wave writes 1 byte, so it's up to 8 bytes, so at most 2 dwords are necessary.
    * (The maximum is 4 dwords for 2 repacks in Wave32 mode.)
    */

   const unsigned num_lds_dwords = DIV_ROUND_UP(max_num_waves, 4);
   assert(num_lds_dwords <= 2);

   /* The first lane of each row (per repack) needs to access the LDS. */
   const unsigned ballot = num_repacks == 1 ? 1 : 0x10001;

   nir_def *wave_id = nir_load_subgroup_id(b);
   nir_def *dont_care = nir_undef(b, 1, num_lds_dwords * 32);
   nir_def *packed_counts = NULL;

   nir_if *if_use_lds = nir_push_if(b, nir_inverse_ballot(b, 1, nir_imm_intN_t(b, ballot, wave_size)));
   {
      nir_def *store_val = surviving_invocations_in_current_wave[0];

      if (num_repacks == 2) {
         nir_def *lane_id_0 = nir_inverse_ballot(b, 1, nir_imm_intN_t(b, 1, wave_size));
         nir_def *off = nir_bcsel(b, lane_id_0, nir_imm_int(b, 0), nir_imm_int(b, num_lds_dwords * 4));
         lds_addr_base = nir_iadd_nuw(b, lds_addr_base, off);
         store_val = nir_bcsel(b, lane_id_0, store_val, surviving_invocations_in_current_wave[1]);
      }

      nir_def *store_byte = nir_u2u8(b, store_val);
      nir_def *lds_offset = nir_iadd(b, lds_addr_base, wave_id);
      nir_store_shared(b, store_byte, lds_offset);

      nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, .memory_scope = SCOPE_WORKGROUP,
                     .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);

      packed_counts = nir_load_shared(b, 1, num_lds_dwords * 32, lds_addr_base, .align_mul = 8u);
   }
   nir_pop_if(b, if_use_lds);

   packed_counts = nir_if_phi(b, packed_counts, dont_care);

   /* STEP 3. Compute the repacked invocation index and the total number of surviving invocations.
    *
    * By now, every wave knows the number of surviving invocations in all waves.
    * Each number is 1 byte, and they are packed into up to 2 dwords.
    *
    * For each row (of 16 lanes):
    * Each lane N (in the row) will sum the number of surviving invocations inclusively from waves 0 to N.
    * If the workgroup has M waves, then each row will use only its first M lanes for this.
    * (Other lanes are not deactivated but their calculation is not used.)
    *
    * - We read the sum from the lane whose id (in the row) is the current wave's id,
    *   and subtract the number of its own surviving invocations.
    *   Add the masked bitcount to this, and we get the repacked invocation index.
    * - We read the sum from the lane whose id (in the row) is the number of waves in the workgroup minus 1.
    *   This is the total number of surviving invocations in the workgroup.
    */

   nir_def *num_waves = nir_load_num_subgroups(b);
   nir_def *sum = summarize_repack(b, packed_counts, num_repacks == 2, num_lds_dwords);

   for (unsigned i = 0; i < num_repacks; ++i) {
      nir_def *index_base_lane = nir_iadd_imm_nuw(b, wave_id, i * 16);
      nir_def *num_invocations_lane = nir_iadd_imm(b, num_waves, i * 16 - 1);
      nir_def *wg_repacked_index_base =
         nir_isub(b, nir_read_invocation(b, sum, index_base_lane), surviving_invocations_in_current_wave[i]);
      results[i].num_repacked_invocations =
         nir_read_invocation(b, sum, num_invocations_lane);
      results[i].repacked_invocation_index =
         nir_mbcnt_amd(b, input_mask[i], wg_repacked_index_base);
   }
}
857