1 /*
2 * Copyright © 2020 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "util/u_dynarray.h"
25 #include "util/u_math.h"
26 #include "nir.h"
27 #include "nir_builder.h"
28 #include "nir_phi_builder.h"
29
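/* Moves load_shader_record_ptr / load_btd_local_arg_addr_intel intrinsics to
 * the top of the entrypoint. These values can't be re-materialized and
 * aren't preserved when moving to the continuation shader, so hoisting them
 * ensures the spilling pass below spills them as needed.
 */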
static bool
move_system_values_to_top(nir_shader *shader)
32 {
33 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
34
35 bool progress = false;
36 nir_foreach_block(block, impl) {
37 nir_foreach_instr_safe(instr, block) {
38 if (instr->type != nir_instr_type_intrinsic)
39 continue;
40
41 /* These intrinsics not only can't be re-materialized but aren't
42 * preserved when moving to the continuation shader. We have to move
43 * them to the top to ensure they get spilled as needed.
44 */
45 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
46 switch (intrin->intrinsic) {
47 case nir_intrinsic_load_shader_record_ptr:
48 case nir_intrinsic_load_btd_local_arg_addr_intel:
49 nir_instr_remove(instr);
50 nir_instr_insert(nir_before_impl(impl), instr);
51 progress = true;
52 break;
53
54 default:
55 break;
56 }
57 }
58 }
59
60 if (progress) {
61 nir_metadata_preserve(impl, nir_metadata_block_index |
62 nir_metadata_dominance);
63 } else {
64 nir_metadata_preserve(impl, nir_metadata_all);
65 }
66
67 return progress;
68 }
69
static bool
instr_is_shader_call(nir_instr *instr)
72 {
73 if (instr->type != nir_instr_type_intrinsic)
74 return false;
75
76 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
77 return intrin->intrinsic == nir_intrinsic_trace_ray ||
78 intrin->intrinsic == nir_intrinsic_report_ray_intersection ||
79 intrin->intrinsic == nir_intrinsic_execute_callable;
80 }
81
82 /* Previously named bitset, it had to be renamed as FreeBSD defines a struct
83 * named bitset in sys/_bitset.h required by pthread_np.h which is included
84 * from src/util/u_thread.h that is indirectly included by this file.
85 */
86 struct sized_bitset {
87 BITSET_WORD *set;
88 unsigned size;
89 };
90
static struct sized_bitset
bitset_create(void *mem_ctx, unsigned size)
93 {
94 return (struct sized_bitset){
95 .set = rzalloc_array(mem_ctx, BITSET_WORD, BITSET_WORDS(size)),
96 .size = size,
97 };
98 }
99
static bool
src_is_in_bitset(nir_src *src, void *_set)
102 {
103 struct sized_bitset *set = _set;
104
105 /* Any SSA values which were added after we generated liveness information
106 * are things generated by this pass and, while most of it is arithmetic
107 * which we could re-materialize, we don't need to because it's only used
108 * for a single load/store and so shouldn't cross any shader calls.
109 */
110 if (src->ssa->index >= set->size)
111 return false;
112
113 return BITSET_TEST(set->set, src->ssa->index);
114 }
115
static void
add_ssa_def_to_bitset(nir_def *def, struct sized_bitset *set)
118 {
119 if (def->index >= set->size)
120 return;
121
122 BITSET_SET(set->set, def->index);
123 }
124
static bool
can_remat_instr(nir_instr *instr, struct sized_bitset *remat)
127 {
128 /* Set of all values which are trivially re-materializable and we shouldn't
129 * ever spill them. This includes:
130 *
131 * - Undef values
132 * - Constants
133 * - Uniforms (UBO or push constant)
134 * - ALU combinations of any of the above
135 * - Derefs which are either complete or casts of any of the above
136 *
137 * Because this pass rewrites things in-order and phis are always turned
138 * into register writes, we can use "is it SSA?" to answer the question
139 * "can my source be re-materialized?". Register writes happen via
140 * non-rematerializable intrinsics.
141 */
142 switch (instr->type) {
143 case nir_instr_type_alu:
144 return nir_foreach_src(instr, src_is_in_bitset, remat);
145
146 case nir_instr_type_deref:
147 return nir_foreach_src(instr, src_is_in_bitset, remat);
148
149 case nir_instr_type_intrinsic: {
150 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
151 switch (intrin->intrinsic) {
152 case nir_intrinsic_load_uniform:
153 case nir_intrinsic_load_ubo:
154 case nir_intrinsic_vulkan_resource_index:
155 case nir_intrinsic_vulkan_resource_reindex:
156 case nir_intrinsic_load_vulkan_descriptor:
157 case nir_intrinsic_load_push_constant:
158 case nir_intrinsic_load_global_constant:
159 case nir_intrinsic_load_global_const_block_intel:
160 case nir_intrinsic_load_desc_set_address_intel:
161 /* These intrinsics don't need to be spilled as long as they don't
162 * depend on any spilled values.
163 */
164 return nir_foreach_src(instr, src_is_in_bitset, remat);
165
166 case nir_intrinsic_load_scratch_base_ptr:
167 case nir_intrinsic_load_ray_launch_id:
168 case nir_intrinsic_load_topology_id_intel:
169 case nir_intrinsic_load_btd_global_arg_addr_intel:
170 case nir_intrinsic_load_btd_resume_sbt_addr_intel:
171 case nir_intrinsic_load_ray_base_mem_addr_intel:
172 case nir_intrinsic_load_ray_hw_stack_size_intel:
173 case nir_intrinsic_load_ray_sw_stack_size_intel:
174 case nir_intrinsic_load_ray_num_dss_rt_stacks_intel:
175 case nir_intrinsic_load_ray_hit_sbt_addr_intel:
176 case nir_intrinsic_load_ray_hit_sbt_stride_intel:
177 case nir_intrinsic_load_ray_miss_sbt_addr_intel:
178 case nir_intrinsic_load_ray_miss_sbt_stride_intel:
179 case nir_intrinsic_load_callable_sbt_addr_intel:
180 case nir_intrinsic_load_callable_sbt_stride_intel:
181 case nir_intrinsic_load_reloc_const_intel:
182 case nir_intrinsic_load_ray_query_global_intel:
183 case nir_intrinsic_load_ray_launch_size:
184 /* Notably missing from the above list is btd_local_arg_addr_intel.
185 * This is because the resume shader will have a different local
186 * argument pointer because it has a different BSR. Any access of
187 * the original shader's local arguments needs to be preserved so
188 * that pointer has to be saved on the stack.
189 *
190 * TODO: There may be some system values we want to avoid
191 * re-materializing as well but we have to be very careful
192 * to ensure that it's a system value which cannot change
193 * across a shader call.
194 */
195 return true;
196
197 case nir_intrinsic_resource_intel:
198 return nir_foreach_src(instr, src_is_in_bitset, remat);
199
200 default:
201 return false;
202 }
203 }
204
205 case nir_instr_type_undef:
206 case nir_instr_type_load_const:
207 return true;
208
209 default:
210 return false;
211 }
212 }
213
static bool
can_remat_ssa_def(nir_def *def, struct sized_bitset *remat)
216 {
217 return can_remat_instr(def->parent_instr, remat);
218 }
219
220 struct add_instr_data {
221 struct util_dynarray *buf;
222 struct sized_bitset *remat;
223 };
224
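/* nir_foreach_src() callback which gathers the instructions feeding a value
 * we want to rematerialize. Sources that are already rematerializable or
 * already recorded are skipped; the walk is aborted (returns false) once the
 * chain buffer is full.
 */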
static bool
add_src_instr(nir_src *src, void *state)
227 {
228 struct add_instr_data *data = state;
229 if (BITSET_TEST(data->remat->set, src->ssa->index))
230 return true;
231
232 util_dynarray_foreach(data->buf, nir_instr *, instr_ptr) {
233 if (*instr_ptr == src->ssa->parent_instr)
234 return true;
235 }
236
237 /* Abort rematerializing an instruction chain if it is too long. */
238 if (data->buf->size >= data->buf->capacity)
239 return false;
240
241 util_dynarray_append(data->buf, nir_instr *, src->ssa->parent_instr);
242 return true;
243 }
244
static int
compare_instr_indexes(const void *_inst1, const void *_inst2)
247 {
248 const nir_instr *const *inst1 = _inst1;
249 const nir_instr *const *inst2 = _inst2;
250
251 return (*inst1)->index - (*inst2)->index;
252 }
253
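/* Checks whether def can be rematerialized by also rematerializing a small
 * chain of the instructions it depends on. On success, buf holds that chain
 * sorted by instruction index; on failure, buf is cleared.
 */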
static bool
can_remat_chain_ssa_def(nir_def *def, struct sized_bitset *remat,
                        struct util_dynarray *buf)
256 {
257 assert(util_dynarray_num_elements(buf, nir_instr *) == 0);
258
259 void *mem_ctx = ralloc_context(NULL);
260
   /* Add all the instructions involved in building this ssa_def */
262 util_dynarray_append(buf, nir_instr *, def->parent_instr);
263
264 unsigned idx = 0;
265 struct add_instr_data data = {
266 .buf = buf,
267 .remat = remat,
268 };
269 while (idx < util_dynarray_num_elements(buf, nir_instr *)) {
270 nir_instr *instr = *util_dynarray_element(buf, nir_instr *, idx++);
271 if (!nir_foreach_src(instr, add_src_instr, &data))
272 goto fail;
273 }
274
275 /* Sort instructions by index */
276 qsort(util_dynarray_begin(buf),
277 util_dynarray_num_elements(buf, nir_instr *),
278 sizeof(nir_instr *),
279 compare_instr_indexes);
280
281 /* Create a temporary bitset with all values already
282 * rematerialized/rematerializable. We'll add to this bit set as we go
283 * through values that might not be in that set but that we can
284 * rematerialize.
285 */
286 struct sized_bitset potential_remat = bitset_create(mem_ctx, remat->size);
287 memcpy(potential_remat.set, remat->set, BITSET_WORDS(remat->size) * sizeof(BITSET_WORD));
288
289 util_dynarray_foreach(buf, nir_instr *, instr_ptr) {
290 nir_def *instr_ssa_def = nir_instr_def(*instr_ptr);
291
      /* Already in the potential rematerializable set, nothing to do. */
293 if (BITSET_TEST(potential_remat.set, instr_ssa_def->index))
294 continue;
295
296 if (!can_remat_instr(*instr_ptr, &potential_remat))
297 goto fail;
298
299 /* All the sources are rematerializable and the instruction is also
300 * rematerializable, mark it as rematerializable too.
301 */
302 BITSET_SET(potential_remat.set, instr_ssa_def->index);
303 }
304
305 ralloc_free(mem_ctx);
306
307 return true;
308
309 fail:
310 util_dynarray_clear(buf);
311 ralloc_free(mem_ctx);
312 return false;
313 }
314
static nir_def *
remat_ssa_def(nir_builder *b, nir_def *def, struct hash_table *remap_table)
317 {
318 nir_instr *clone = nir_instr_clone_deep(b->shader, def->parent_instr, remap_table);
319 nir_builder_instr_insert(b, clone);
320 return nir_instr_def(clone);
321 }
322
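/* Clones every instruction in buf (a chain gathered by
 * can_remat_chain_ssa_def()) at the builder's cursor, recording each clone
 * in fill_defs/remap_table and marking it in the remat set. Returns the last
 * clone emitted, which corresponds to the def we set out to rematerialize.
 */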
static nir_def *
remat_chain_ssa_def(nir_builder *b, struct util_dynarray *buf,
                    struct sized_bitset *remat, nir_def ***fill_defs,
                    unsigned call_idx, struct hash_table *remap_table)
327 {
328 nir_def *last_def = NULL;
329
330 util_dynarray_foreach(buf, nir_instr *, instr_ptr) {
331 nir_def *instr_ssa_def = nir_instr_def(*instr_ptr);
332 unsigned ssa_index = instr_ssa_def->index;
333
334 if (fill_defs[ssa_index] != NULL &&
335 fill_defs[ssa_index][call_idx] != NULL)
336 continue;
337
338 /* Clone the instruction we want to rematerialize */
339 nir_def *clone_ssa_def = remat_ssa_def(b, instr_ssa_def, remap_table);
340
341 if (fill_defs[ssa_index] == NULL) {
342 fill_defs[ssa_index] =
343 rzalloc_array(fill_defs, nir_def *, remat->size);
344 }
345
346 /* Add the new ssa_def to the list fill_defs and flag it as
347 * rematerialized
348 */
349 fill_defs[ssa_index][call_idx] = last_def = clone_ssa_def;
350 BITSET_SET(remat->set, ssa_index);
351
352 _mesa_hash_table_insert(remap_table, instr_ssa_def, last_def);
353 }
354
355 return last_def;
356 }
357
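/* Maps an SSA def index to its phi builder value (NULL for defs that never
 * needed spilling or filling).
 */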
358 struct pbv_array {
359 struct nir_phi_builder_value **arr;
360 unsigned len;
361 };
362
static struct nir_phi_builder_value *
get_phi_builder_value_for_def(nir_def *def,
                              struct pbv_array *pbv_arr)
366 {
367 if (def->index >= pbv_arr->len)
368 return NULL;
369
370 return pbv_arr->arr[def->index];
371 }
372
static nir_def *
get_phi_builder_def_for_src(nir_src *src, struct pbv_array *pbv_arr,
                            nir_block *block)
376 {
377
378 struct nir_phi_builder_value *pbv =
379 get_phi_builder_value_for_def(src->ssa, pbv_arr);
380 if (pbv == NULL)
381 return NULL;
382
383 return nir_phi_builder_value_get_block_def(pbv, block);
384 }
385
static bool
rewrite_instr_src_from_phi_builder(nir_src *src, void *_pbv_arr)
388 {
389 nir_block *block;
390 if (nir_src_parent_instr(src)->type == nir_instr_type_phi) {
391 nir_phi_src *phi_src = exec_node_data(nir_phi_src, src, src);
392 block = phi_src->pred;
393 } else {
394 block = nir_src_parent_instr(src)->block;
395 }
396
397 nir_def *new_def = get_phi_builder_def_for_src(src, _pbv_arr, block);
398 if (new_def != NULL)
399 nir_src_rewrite(src, new_def);
400 return true;
401 }
402
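/* Spills def to the stack at the given offset before the call (using the
 * "before" builder) and reloads it after the call (using the "after"
 * builder). Returns the reloaded value.
 */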
static nir_def *
spill_fill(nir_builder *before, nir_builder *after, nir_def *def,
           unsigned value_id, unsigned call_idx,
           unsigned offset, unsigned stack_alignment)
407 {
408 const unsigned comp_size = def->bit_size / 8;
409
410 nir_store_stack(before, def,
411 .base = offset,
412 .call_idx = call_idx,
413 .align_mul = MIN2(comp_size, stack_alignment),
414 .value_id = value_id,
415 .write_mask = BITFIELD_MASK(def->num_components));
416 return nir_load_stack(after, def->num_components, def->bit_size,
417 .base = offset,
418 .call_idx = call_idx,
419 .value_id = value_id,
420 .align_mul = MIN2(comp_size, stack_alignment));
421 }
422
static bool
add_src_to_call_live_bitset(nir_src *src, void *state)
425 {
426 BITSET_WORD *call_live = state;
427
428 BITSET_SET(call_live, src->ssa->index);
429 return true;
430 }
431
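/* For each shader call, spills or rematerializes every SSA value live across
 * the call, lowers the call to rt_trace_ray/rt_execute_callable followed by
 * an rt_resume placeholder carrying the call index and stack size, and then
 * uses the phi builder to stitch the original and reloaded/rematerialized
 * values back into valid SSA form.
 */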
static void
spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls,
                                      const nir_lower_shader_calls_options *options)
435 {
436 /* TODO: If a SSA def is filled more than once, we probably want to just
437 * spill it at the LCM of the fill sites so we avoid unnecessary
438 * extra spills
439 *
440 * TODO: If a SSA def is defined outside a loop but live through some call
441 * inside the loop, we probably want to spill outside the loop. We
442 * may also want to fill outside the loop if it's not used in the
443 * loop.
444 *
445 * TODO: Right now, we only re-materialize things if their immediate
446 * sources are things which we filled. We probably want to expand
447 * that to re-materialize things whose sources are things we can
448 * re-materialize from things we filled. We may want some DAG depth
449 * heuristic on this.
450 */
451
452 /* This happens per-shader rather than per-impl because we mess with
453 * nir_shader::scratch_size.
454 */
455 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
456
457 nir_metadata_require(impl, nir_metadata_live_defs |
458 nir_metadata_dominance |
459 nir_metadata_block_index |
460 nir_metadata_instr_index);
461
462 void *mem_ctx = ralloc_context(shader);
463
464 const unsigned num_ssa_defs = impl->ssa_alloc;
465 const unsigned live_words = BITSET_WORDS(num_ssa_defs);
466 struct sized_bitset trivial_remat = bitset_create(mem_ctx, num_ssa_defs);
467
468 /* Array of all live SSA defs which are spill candidates */
469 nir_def **spill_defs =
470 rzalloc_array(mem_ctx, nir_def *, num_ssa_defs);
471
472 /* For each spill candidate, an array of every time it's defined by a fill,
473 * indexed by call instruction index.
474 */
475 nir_def ***fill_defs =
476 rzalloc_array(mem_ctx, nir_def **, num_ssa_defs);
477
478 /* For each call instruction, the liveness set at the call */
479 const BITSET_WORD **call_live =
480 rzalloc_array(mem_ctx, const BITSET_WORD *, num_calls);
481
482 /* For each call instruction, the block index of the block it lives in */
483 uint32_t *call_block_indices = rzalloc_array(mem_ctx, uint32_t, num_calls);
484
485 /* Remap table when rebuilding instructions out of fill operations */
486 struct hash_table *trivial_remap_table =
487 _mesa_pointer_hash_table_create(mem_ctx);
488
489 /* Walk the call instructions and fetch the liveness set and block index
490 * for each one. We need to do this before we start modifying the shader
491 * so that liveness doesn't complain that it's been invalidated. Don't
492 * worry, we'll be very careful with our live sets. :-)
493 */
494 unsigned call_idx = 0;
495 nir_foreach_block(block, impl) {
496 nir_foreach_instr(instr, block) {
497 if (!instr_is_shader_call(instr))
498 continue;
499
500 call_block_indices[call_idx] = block->index;
501
502 /* The objective here is to preserve values around shader call
503 * instructions. Therefore, we use the live set after the
504 * instruction as the set of things we want to preserve. Because
505 * none of our shader call intrinsics return anything, we don't have
506 * to worry about spilling over a return value.
507 *
508 * TODO: This isn't quite true for report_intersection.
509 */
510 call_live[call_idx] =
511 nir_get_live_defs(nir_after_instr(instr), mem_ctx);
512
513 call_idx++;
514 }
515 }
516
517 /* If a should_remat_callback is given, call it on each of the live values
518 * for each call site. If it returns true we need to rematerialize that
519 * instruction (instead of spill/fill). Therefore we need to add the
520 * sources as live values so that we can rematerialize on top of those
521 * spilled/filled sources.
522 */
523 if (options->should_remat_callback) {
524 BITSET_WORD **updated_call_live =
525 rzalloc_array(mem_ctx, BITSET_WORD *, num_calls);
526
527 nir_foreach_block(block, impl) {
528 nir_foreach_instr(instr, block) {
529 nir_def *def = nir_instr_def(instr);
530 if (def == NULL)
531 continue;
532
533 for (unsigned c = 0; c < num_calls; c++) {
534 if (!BITSET_TEST(call_live[c], def->index))
535 continue;
536
537 if (!options->should_remat_callback(def->parent_instr,
538 options->should_remat_data))
539 continue;
540
541 if (updated_call_live[c] == NULL) {
542 const unsigned bitset_words = BITSET_WORDS(impl->ssa_alloc);
543 updated_call_live[c] = ralloc_array(mem_ctx, BITSET_WORD, bitset_words);
544 memcpy(updated_call_live[c], call_live[c], bitset_words * sizeof(BITSET_WORD));
545 }
546
547 nir_foreach_src(instr, add_src_to_call_live_bitset, updated_call_live[c]);
548 }
549 }
550 }
551
552 for (unsigned c = 0; c < num_calls; c++) {
553 if (updated_call_live[c] != NULL)
554 call_live[c] = updated_call_live[c];
555 }
556 }
557
558 nir_builder before, after;
559 before = nir_builder_create(impl);
560 after = nir_builder_create(impl);
561
562 call_idx = 0;
563 unsigned max_scratch_size = shader->scratch_size;
564 nir_foreach_block(block, impl) {
565 nir_foreach_instr_safe(instr, block) {
566 nir_def *def = nir_instr_def(instr);
567 if (def != NULL) {
568 if (can_remat_ssa_def(def, &trivial_remat)) {
569 add_ssa_def_to_bitset(def, &trivial_remat);
570 _mesa_hash_table_insert(trivial_remap_table, def, def);
571 } else {
572 spill_defs[def->index] = def;
573 }
574 }
575
576 if (!instr_is_shader_call(instr))
577 continue;
578
579 const BITSET_WORD *live = call_live[call_idx];
580
581 struct hash_table *remap_table =
582 _mesa_hash_table_clone(trivial_remap_table, mem_ctx);
583
584 /* Make a copy of trivial_remat that we'll update as we crawl through
585 * the live SSA defs and unspill them.
586 */
587 struct sized_bitset remat = bitset_create(mem_ctx, num_ssa_defs);
588 memcpy(remat.set, trivial_remat.set, live_words * sizeof(BITSET_WORD));
589
         /* Because the two builders are always separated by the call
          * instruction, it won't break anything to have two of them.
          */
593 before.cursor = nir_before_instr(instr);
594 after.cursor = nir_after_instr(instr);
595
         /* Array used to hold all the values needed to rematerialize a live
          * value. The capacity is used to determine when we should abort
          * testing a remat chain. In practice, shaders can have chains with
          * more than 10k elements while only chains with fewer than 16
          * elements have a realistic chance of being rematerialized. There
          * also isn't any performance benefit in rematerializing extremely
          * long chains.
          */
603 nir_instr *remat_chain_instrs[16];
604 struct util_dynarray remat_chain;
605 util_dynarray_init_from_stack(&remat_chain, remat_chain_instrs, sizeof(remat_chain_instrs));
606
607 unsigned offset = shader->scratch_size;
608 for (unsigned w = 0; w < live_words; w++) {
609 BITSET_WORD spill_mask = live[w] & ~trivial_remat.set[w];
610 while (spill_mask) {
611 int i = u_bit_scan(&spill_mask);
612 assert(i >= 0);
613 unsigned index = w * BITSET_WORDBITS + i;
614 assert(index < num_ssa_defs);
615
616 def = spill_defs[index];
617 nir_def *original_def = def, *new_def;
618 if (can_remat_ssa_def(def, &remat)) {
619 /* If this SSA def is re-materializable or based on other
620 * things we've already spilled, re-materialize it rather
621 * than spilling and filling. Anything which is trivially
622 * re-materializable won't even get here because we take
623 * those into account in spill_mask above.
624 */
625 new_def = remat_ssa_def(&after, def, remap_table);
626 } else if (can_remat_chain_ssa_def(def, &remat, &remat_chain)) {
627 new_def = remat_chain_ssa_def(&after, &remat_chain, &remat,
628 fill_defs, call_idx,
629 remap_table);
630 util_dynarray_clear(&remat_chain);
631 } else {
632 bool is_bool = def->bit_size == 1;
633 if (is_bool)
634 def = nir_b2b32(&before, def);
635
636 const unsigned comp_size = def->bit_size / 8;
637 offset = ALIGN(offset, comp_size);
638
639 new_def = spill_fill(&before, &after, def,
640 index, call_idx,
641 offset, options->stack_alignment);
642
643 if (is_bool)
644 new_def = nir_b2b1(&after, new_def);
645
646 offset += def->num_components * comp_size;
647 }
648
649 /* Mark this SSA def as available in the remat set so that, if
650 * some other SSA def we need is computed based on it, we can
651 * just re-compute instead of fetching from memory.
652 */
653 BITSET_SET(remat.set, index);
654
655 /* For now, we just make a note of this new SSA def. We'll
656 * fix things up with the phi builder as a second pass.
657 */
658 if (fill_defs[index] == NULL) {
659 fill_defs[index] =
660 rzalloc_array(fill_defs, nir_def *, num_calls);
661 }
662 fill_defs[index][call_idx] = new_def;
663 _mesa_hash_table_insert(remap_table, original_def, new_def);
664 }
665 }
666
667 nir_builder *b = &before;
668
669 offset = ALIGN(offset, options->stack_alignment);
670 max_scratch_size = MAX2(max_scratch_size, offset);
671
672 /* First thing on the called shader's stack is the resume address
673 * followed by a pointer to the payload.
674 */
675 nir_intrinsic_instr *call = nir_instr_as_intrinsic(instr);
676
677 /* Lower to generic intrinsics with information about the stack & resume shader. */
678 switch (call->intrinsic) {
679 case nir_intrinsic_trace_ray: {
680 nir_rt_trace_ray(b, call->src[0].ssa, call->src[1].ssa,
681 call->src[2].ssa, call->src[3].ssa,
682 call->src[4].ssa, call->src[5].ssa,
683 call->src[6].ssa, call->src[7].ssa,
684 call->src[8].ssa, call->src[9].ssa,
685 call->src[10].ssa,
686 .call_idx = call_idx, .stack_size = offset);
687 break;
688 }
689
690 case nir_intrinsic_report_ray_intersection:
691 unreachable("Any-hit shaders must be inlined");
692
693 case nir_intrinsic_execute_callable: {
694 nir_rt_execute_callable(b, call->src[0].ssa, call->src[1].ssa, .call_idx = call_idx, .stack_size = offset);
695 break;
696 }
697
698 default:
699 unreachable("Invalid shader call instruction");
700 }
701
702 nir_rt_resume(b, .call_idx = call_idx, .stack_size = offset);
703
704 nir_instr_remove(&call->instr);
705
706 call_idx++;
707 }
708 }
709 assert(call_idx == num_calls);
710 shader->scratch_size = max_scratch_size;
711
712 struct nir_phi_builder *pb = nir_phi_builder_create(impl);
713 struct pbv_array pbv_arr = {
714 .arr = rzalloc_array(mem_ctx, struct nir_phi_builder_value *,
715 num_ssa_defs),
716 .len = num_ssa_defs,
717 };
718
719 const unsigned block_words = BITSET_WORDS(impl->num_blocks);
720 BITSET_WORD *def_blocks = ralloc_array(mem_ctx, BITSET_WORD, block_words);
721
722 /* Go through and set up phi builder values for each spillable value which
723 * we ever needed to spill at any point.
724 */
725 for (unsigned index = 0; index < num_ssa_defs; index++) {
726 if (fill_defs[index] == NULL)
727 continue;
728
729 nir_def *def = spill_defs[index];
730
731 memset(def_blocks, 0, block_words * sizeof(BITSET_WORD));
732 BITSET_SET(def_blocks, def->parent_instr->block->index);
733 for (unsigned call_idx = 0; call_idx < num_calls; call_idx++) {
734 if (fill_defs[index][call_idx] != NULL)
735 BITSET_SET(def_blocks, call_block_indices[call_idx]);
736 }
737
738 pbv_arr.arr[index] = nir_phi_builder_add_value(pb, def->num_components,
739 def->bit_size, def_blocks);
740 }
741
742 /* Walk the shader one more time and rewrite SSA defs as needed using the
743 * phi builder.
744 */
745 nir_foreach_block(block, impl) {
746 nir_foreach_instr_safe(instr, block) {
747 nir_def *def = nir_instr_def(instr);
748 if (def != NULL) {
749 struct nir_phi_builder_value *pbv =
750 get_phi_builder_value_for_def(def, &pbv_arr);
751 if (pbv != NULL)
752 nir_phi_builder_value_set_block_def(pbv, block, def);
753 }
754
755 if (instr->type == nir_instr_type_phi)
756 continue;
757
758 nir_foreach_src(instr, rewrite_instr_src_from_phi_builder, &pbv_arr);
759
760 if (instr->type != nir_instr_type_intrinsic)
761 continue;
762
763 nir_intrinsic_instr *resume = nir_instr_as_intrinsic(instr);
764 if (resume->intrinsic != nir_intrinsic_rt_resume)
765 continue;
766
767 call_idx = nir_intrinsic_call_idx(resume);
768
769 /* Technically, this is the wrong place to add the fill defs to the
770 * phi builder values because we haven't seen any of the load_scratch
771 * instructions for this call yet. However, we know based on how we
772 * emitted them that no value ever gets used until after the load
773 * instruction has been emitted so this should be safe. If we ever
 * fail validation due to this, it likely means a bug in our spilling
775 * code and not the phi re-construction code here.
776 */
777 for (unsigned index = 0; index < num_ssa_defs; index++) {
778 if (fill_defs[index] && fill_defs[index][call_idx]) {
779 nir_phi_builder_value_set_block_def(pbv_arr.arr[index], block,
780 fill_defs[index][call_idx]);
781 }
782 }
783 }
784
785 nir_if *following_if = nir_block_get_following_if(block);
786 if (following_if) {
787 nir_def *new_def =
788 get_phi_builder_def_for_src(&following_if->condition,
789 &pbv_arr, block);
790 if (new_def != NULL)
791 nir_src_rewrite(&following_if->condition, new_def);
792 }
793
794 /* Handle phi sources that source from this block. We have to do this
795 * as a separate pass because the phi builder assumes that uses and
796 * defs are processed in an order that respects dominance. When we have
797 * loops, a phi source may be a back-edge so we have to handle it as if
798 * it were one of the last instructions in the predecessor block.
799 */
800 nir_foreach_phi_src_leaving_block(block,
801 rewrite_instr_src_from_phi_builder,
802 &pbv_arr);
803 }
804
805 nir_phi_builder_finish(pb);
806
807 ralloc_free(mem_ctx);
808
809 nir_metadata_preserve(impl, nir_metadata_block_index |
810 nir_metadata_dominance);
811 }
812
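/* Returns the nir_intrinsic_rt_resume instruction matching call_idx. */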
static nir_instr *
find_resume_instr(nir_function_impl *impl, unsigned call_idx)
815 {
816 nir_foreach_block(block, impl) {
817 nir_foreach_instr(instr, block) {
818 if (instr->type != nir_instr_type_intrinsic)
819 continue;
820
821 nir_intrinsic_instr *resume = nir_instr_as_intrinsic(instr);
822 if (resume->intrinsic != nir_intrinsic_rt_resume)
823 continue;
824
825 if (nir_intrinsic_call_idx(resume) == call_idx)
826 return &resume->instr;
827 }
828 }
829 unreachable("Couldn't find resume instruction");
830 }
831
832 /* Walk the CF tree and duplicate the contents of every loop, one half runs on
833 * resume and the other half is for any post-resume loop iterations. We are
834 * careful in our duplication to ensure that resume_instr is in the resume
835 * half of the loop though a copy of resume_instr will remain in the other
836 * half as well in case the same shader call happens twice.
837 */
static bool
duplicate_loop_bodies(nir_function_impl *impl, nir_instr *resume_instr)
840 {
841 nir_def *resume_reg = NULL;
842 for (nir_cf_node *node = resume_instr->block->cf_node.parent;
843 node->type != nir_cf_node_function; node = node->parent) {
844 if (node->type != nir_cf_node_loop)
845 continue;
846
847 nir_loop *loop = nir_cf_node_as_loop(node);
848 assert(!nir_loop_has_continue_construct(loop));
849
850 nir_builder b = nir_builder_create(impl);
851
852 if (resume_reg == NULL) {
853 /* We only create resume_reg if we encounter a loop. This way we can
854 * avoid re-validating the shader and calling ssa_to_reg_intrinsics in
855 * the case where it's just if-ladders.
856 */
857 resume_reg = nir_decl_reg(&b, 1, 1, 0);
858
         /* Initialize resume to true at the start of the shader, right
          * after the register is declared.
          */
862 b.cursor = nir_after_instr(resume_reg->parent_instr);
863 nir_store_reg(&b, nir_imm_true(&b), resume_reg);
864
865 /* Set resume to false right after the resume instruction */
866 b.cursor = nir_after_instr(resume_instr);
867 nir_store_reg(&b, nir_imm_false(&b), resume_reg);
868 }
869
870 /* Before we go any further, make sure that everything which exits the
871 * loop or continues around to the top of the loop does so through
872 * registers. We're about to duplicate the loop body and we'll have
873 * serious trouble if we don't do this.
874 */
875 nir_convert_loop_to_lcssa(loop);
876 nir_lower_phis_to_regs_block(nir_loop_first_block(loop));
877 nir_lower_phis_to_regs_block(
878 nir_cf_node_as_block(nir_cf_node_next(&loop->cf_node)));
879
880 nir_cf_list cf_list;
881 nir_cf_list_extract(&cf_list, &loop->body);
882
883 nir_if *_if = nir_if_create(impl->function->shader);
884 b.cursor = nir_after_cf_list(&loop->body);
885 _if->condition = nir_src_for_ssa(nir_load_reg(&b, resume_reg));
886 nir_cf_node_insert(nir_after_cf_list(&loop->body), &_if->cf_node);
887
888 nir_cf_list clone;
889 nir_cf_list_clone(&clone, &cf_list, &loop->cf_node, NULL);
890
891 /* Insert the clone in the else and the original in the then so that
892 * the resume_instr remains valid even after the duplication.
893 */
894 nir_cf_reinsert(&cf_list, nir_before_cf_list(&_if->then_list));
895 nir_cf_reinsert(&clone, nir_before_cf_list(&_if->else_list));
896 }
897
898 if (resume_reg != NULL)
899 nir_metadata_preserve(impl, nir_metadata_none);
900
901 return resume_reg != NULL;
902 }
903
static bool
cf_node_contains_block(nir_cf_node *node, nir_block *block)
906 {
907 for (nir_cf_node *n = &block->cf_node; n != NULL; n = n->parent) {
908 if (n == node)
909 return true;
910 }
911
912 return false;
913 }
914
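/* Rewrites all uses of each phi in block to the value the phi takes when
 * entered from pred. Used when flattening an if whose other branch cannot
 * lie on the path to the resume instruction.
 */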
static void
rewrite_phis_to_pred(nir_block *block, nir_block *pred)
917 {
918 nir_foreach_phi(phi, block) {
919 ASSERTED bool found = false;
920 nir_foreach_phi_src(phi_src, phi) {
921 if (phi_src->pred == pred) {
922 found = true;
923 nir_def_rewrite_uses(&phi->def, phi_src->src.ssa);
924 break;
925 }
926 }
927 assert(found);
928 }
929 }
930
static bool
cursor_is_after_jump(nir_cursor cursor)
933 {
934 switch (cursor.option) {
935 case nir_cursor_before_instr:
936 case nir_cursor_before_block:
937 return false;
938 case nir_cursor_after_instr:
939 return cursor.instr->type == nir_instr_type_jump;
940 case nir_cursor_after_block:
941 return nir_block_ends_in_jump(cursor.block);
943 }
944 unreachable("Invalid cursor option");
945 }
946
947 /** Flattens if ladders leading up to a resume
948 *
949 * Given a resume_instr, this function flattens any if ladders leading to the
950 * resume instruction and deletes any code that cannot be encountered on a
951 * direct path to the resume instruction. This way we get, for the most part,
952 * straight-line control-flow up to the resume instruction.
953 *
954 * While we do this flattening, we also move any code which is in the remat
955 * set up to the top of the function or to the top of the resume portion of
956 * the current loop. We don't worry about control-flow as we do this because
957 * phis will never be in the remat set (see can_remat_instr) and so nothing
958 * control-dependent will ever need to be re-materialized. It is possible
959 * that this algorithm will preserve too many instructions by moving them to
960 * the top but we leave that for DCE to clean up. Any code not in the remat
961 * set is deleted because it's either unused in the continuation or else
962 * unspilled from a previous continuation and the unspill code is after the
963 * resume instruction.
964 *
965 * If, for instance, we have something like this:
966 *
967 * // block 0
968 * if (cond1) {
969 * // block 1
970 * } else {
971 * // block 2
972 * if (cond2) {
973 * // block 3
974 * resume;
975 * if (cond3) {
976 * // block 4
977 * }
978 * } else {
979 * // block 5
980 * }
981 * }
982 *
 * then we know, because we know the resume instruction had to be encountered,
984 * that cond1 = false and cond2 = true and we lower as follows:
985 *
986 * // block 0
987 * // block 2
988 * // block 3
989 * resume;
990 * if (cond3) {
991 * // block 4
992 * }
993 *
994 * As you can see, the code in blocks 1 and 5 was removed because there is no
 * path from the start of the shader to the resume instruction which executes
996 * blocks 1 or 5. Any remat code from blocks 0, 2, and 3 is preserved and
997 * moved to the top. If the resume instruction is inside a loop then we know
998 * a priori that it is of the form
999 *
1000 * loop {
1001 * if (resume) {
1002 * // Contents containing resume_instr
1003 * } else {
1004 * // Second copy of contents
1005 * }
1006 * }
1007 *
1008 * In this case, we only descend into the first half of the loop. The second
1009 * half is left alone as that portion is only ever executed after the resume
1010 * instruction.
1011 */
static bool
flatten_resume_if_ladder(nir_builder *b,
                         nir_cf_node *parent_node,
                         struct exec_list *child_list,
                         bool child_list_contains_cursor,
                         nir_instr *resume_instr,
                         struct sized_bitset *remat)
1019 {
1020 nir_cf_list cf_list;
1021
1022 /* If our child list contains the cursor instruction then we start out
1023 * before the cursor instruction. We need to know this so that we can skip
1024 * moving instructions which are already before the cursor.
1025 */
1026 bool before_cursor = child_list_contains_cursor;
1027
1028 nir_cf_node *resume_node = NULL;
1029 foreach_list_typed_safe(nir_cf_node, child, node, child_list) {
1030 switch (child->type) {
1031 case nir_cf_node_block: {
1032 nir_block *block = nir_cf_node_as_block(child);
1033 if (b->cursor.option == nir_cursor_before_block &&
1034 b->cursor.block == block) {
1035 assert(before_cursor);
1036 before_cursor = false;
1037 }
1038 nir_foreach_instr_safe(instr, block) {
1039 if ((b->cursor.option == nir_cursor_before_instr ||
1040 b->cursor.option == nir_cursor_after_instr) &&
1041 b->cursor.instr == instr) {
1042 assert(nir_cf_node_is_first(&block->cf_node));
1043 assert(before_cursor);
1044 before_cursor = false;
1045 continue;
1046 }
1047
1048 if (instr == resume_instr)
1049 goto found_resume;
1050
1051 if (!before_cursor && can_remat_instr(instr, remat)) {
1052 nir_instr_remove(instr);
1053 nir_instr_insert(b->cursor, instr);
1054 b->cursor = nir_after_instr(instr);
1055
1056 nir_def *def = nir_instr_def(instr);
1057 BITSET_SET(remat->set, def->index);
1058 }
1059 }
1060 if (b->cursor.option == nir_cursor_after_block &&
1061 b->cursor.block == block) {
1062 assert(before_cursor);
1063 before_cursor = false;
1064 }
1065 break;
1066 }
1067
1068 case nir_cf_node_if: {
1069 assert(!before_cursor);
1070 nir_if *_if = nir_cf_node_as_if(child);
1071 if (flatten_resume_if_ladder(b, &_if->cf_node, &_if->then_list,
1072 false, resume_instr, remat)) {
1073 resume_node = child;
1074 rewrite_phis_to_pred(nir_cf_node_as_block(nir_cf_node_next(child)),
1075 nir_if_last_then_block(_if));
1076 goto found_resume;
1077 }
1078
1079 if (flatten_resume_if_ladder(b, &_if->cf_node, &_if->else_list,
1080 false, resume_instr, remat)) {
1081 resume_node = child;
1082 rewrite_phis_to_pred(nir_cf_node_as_block(nir_cf_node_next(child)),
1083 nir_if_last_else_block(_if));
1084 goto found_resume;
1085 }
1086 break;
1087 }
1088
1089 case nir_cf_node_loop: {
1090 assert(!before_cursor);
1091 nir_loop *loop = nir_cf_node_as_loop(child);
1092 assert(!nir_loop_has_continue_construct(loop));
1093
1094 if (cf_node_contains_block(&loop->cf_node, resume_instr->block)) {
1095 /* Thanks to our loop body duplication pass, every level of loop
1096 * containing the resume instruction contains exactly three nodes:
1097 * two blocks and an if. We don't want to lower away this if
1098 * because it's the resume selection if. The resume half is
1099 * always the then_list so that's what we want to flatten.
1100 */
1101 nir_block *header = nir_loop_first_block(loop);
1102 nir_if *_if = nir_cf_node_as_if(nir_cf_node_next(&header->cf_node));
1103
1104 /* We want to place anything re-materialized from inside the loop
1105 * at the top of the resume half of the loop.
1106 */
1107 nir_builder bl = nir_builder_at(nir_before_cf_list(&_if->then_list));
1108
1109 ASSERTED bool found =
1110 flatten_resume_if_ladder(&bl, &_if->cf_node, &_if->then_list,
1111 true, resume_instr, remat);
1112 assert(found);
1113 resume_node = child;
1114 goto found_resume;
1115 } else {
1116 ASSERTED bool found =
1117 flatten_resume_if_ladder(b, &loop->cf_node, &loop->body,
1118 false, resume_instr, remat);
1119 assert(!found);
1120 }
1121 break;
1122 }
1123
1124 case nir_cf_node_function:
1125 unreachable("Unsupported CF node type");
1126 }
1127 }
1128 assert(!before_cursor);
1129
1130 /* If we got here, we didn't find the resume node or instruction. */
1131 return false;
1132
1133 found_resume:
1134 /* If we got here then we found either the resume node or the resume
1135 * instruction in this CF list.
1136 */
1137 if (resume_node) {
      /* If the resume instruction is buried inside one of our child CF
       * nodes, resume_node now points to that child.
       */
1141 if (resume_node->type == nir_cf_node_if) {
1142 /* Thanks to the recursive call, all of the interesting contents of
1143 * resume_node have been copied before the cursor. We just need to
1144 * copy the stuff after resume_node.
1145 */
1146 nir_cf_extract(&cf_list, nir_after_cf_node(resume_node),
1147 nir_after_cf_list(child_list));
1148 } else {
1149 /* The loop contains its own cursor and still has useful stuff in it.
1150 * We want to move everything after and including the loop to before
1151 * the cursor.
1152 */
1153 assert(resume_node->type == nir_cf_node_loop);
1154 nir_cf_extract(&cf_list, nir_before_cf_node(resume_node),
1155 nir_after_cf_list(child_list));
1156 }
1157 } else {
1158 /* If we found the resume instruction in one of our blocks, grab
1159 * everything after it in the entire list (not just the one block), and
1160 * place it before the cursor instr.
1161 */
1162 nir_cf_extract(&cf_list, nir_after_instr(resume_instr),
1163 nir_after_cf_list(child_list));
1164 }
1165
1166 /* If the resume instruction is in the first block of the child_list,
1167 * and the cursor is still before that block, the nir_cf_extract() may
1168 * extract the block object pointed by the cursor, and instead create
 * a new one for the code before the resume. In such a case the cursor
1170 * will be broken, as it will point to a block which is no longer
1171 * in a function.
1172 *
1173 * Luckily, in both cases when this is possible, the intended cursor
1174 * position is right before the child_list, so we can fix the cursor here.
1175 */
1176 if (child_list_contains_cursor &&
1177 b->cursor.option == nir_cursor_before_block &&
1178 b->cursor.block->cf_node.parent == NULL)
1179 b->cursor = nir_before_cf_list(child_list);
1180
1181 if (cursor_is_after_jump(b->cursor)) {
1182 /* If the resume instruction is in a loop, it's possible cf_list ends
1183 * in a break or continue instruction, in which case we don't want to
1184 * insert anything. It's also possible we have an early return if
1185 * someone hasn't lowered those yet. In either case, nothing after that
1186 * point executes in this context so we can delete it.
1187 */
1188 nir_cf_delete(&cf_list);
1189 } else {
1190 b->cursor = nir_cf_reinsert(&cf_list, b->cursor);
1191 }
1192
1193 if (!resume_node) {
1194 /* We want the resume to be the first "interesting" instruction */
1195 nir_instr_remove(resume_instr);
1196 nir_instr_insert(nir_before_impl(b->impl), resume_instr);
1197 }
1198
1199 /* We've copied everything interesting out of this CF list to before the
1200 * cursor. Delete everything else.
1201 */
1202 if (child_list_contains_cursor) {
1203 nir_cf_extract(&cf_list, b->cursor, nir_after_cf_list(child_list));
1204 } else {
1205 nir_cf_list_extract(&cf_list, child_list);
1206 }
1207 nir_cf_delete(&cf_list);
1208
1209 return true;
1210 }
1211
1212 typedef bool (*wrap_instr_callback)(nir_instr *instr);
1213
static bool
wrap_instr(nir_builder *b, nir_instr *instr, void *data)
1216 {
1217 wrap_instr_callback callback = data;
1218 if (!callback(instr))
1219 return false;
1220
1221 b->cursor = nir_before_instr(instr);
1222
1223 nir_if *_if = nir_push_if(b, nir_imm_true(b));
1224 nir_pop_if(b, NULL);
1225
1226 nir_cf_list cf_list;
1227 nir_cf_extract(&cf_list, nir_before_instr(instr), nir_after_instr(instr));
1228 nir_cf_reinsert(&cf_list, nir_before_block(nir_if_first_then_block(_if)));
1229
1230 return true;
1231 }
1232
/* This pass wraps jump instructions in a dummy if block so that, when
 * flatten_resume_if_ladder() does its job, it doesn't end up moving a jump
 * instruction directly in front of another instruction, something the NIR
 * control flow helpers do not allow.
 */
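/* For instance, a bare jump such as
 *
 *    break;
 *
 * is rewritten to
 *
 *    if (true) {
 *       break;
 *    }
 *
 * so the jump always remains the last instruction of its block.
 */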
static bool
wrap_instrs(nir_shader *shader, wrap_instr_callback callback)
1240 {
1241 return nir_shader_instructions_pass(shader, wrap_instr,
1242 nir_metadata_none, callback);
1243 }
1244
static bool
instr_is_jump(nir_instr *instr)
1247 {
1248 return instr->type == nir_instr_type_jump;
1249 }
1250
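/* Prepares the shader for resuming at the given call index: wraps jump
 * instructions, duplicates loop bodies containing the resume point, and
 * flattens the if ladders leading up to the rt_resume instruction so it
 * becomes the first "interesting" instruction of the shader. Returns the
 * resume instruction.
 */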
static nir_instr *
lower_resume(nir_shader *shader, int call_idx)
1253 {
1254 wrap_instrs(shader, instr_is_jump);
1255
1256 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1257 nir_instr *resume_instr = find_resume_instr(impl, call_idx);
1258
1259 if (duplicate_loop_bodies(impl, resume_instr)) {
1260 nir_validate_shader(shader, "after duplicate_loop_bodies in "
1261 "nir_lower_shader_calls");
1262 /* If we duplicated the bodies of any loops, run reg_intrinsics_to_ssa to
1263 * get rid of all those pesky registers we just added.
1264 */
1265 NIR_PASS_V(shader, nir_lower_reg_intrinsics_to_ssa);
1266 }
1267
1268 /* Re-index nir_def::index. We don't care about actual liveness in
1269 * this pass but, so we can use the same helpers as the spilling pass, we
1270 * need to make sure that live_index is something sane. It's used
1271 * constantly for determining if an SSA value has been added since the
1272 * start of the pass.
1273 */
1274 nir_index_ssa_defs(impl);
1275
1276 void *mem_ctx = ralloc_context(shader);
1277
1278 /* Used to track which things may have been assumed to be re-materialized
1279 * by the spilling pass and which we shouldn't delete.
1280 */
1281 struct sized_bitset remat = bitset_create(mem_ctx, impl->ssa_alloc);
1282
   /* Create a builder at the top of the function to use as a cursor as we
    * extract and re-insert stuff into the CFG.
    */
1286 nir_builder b = nir_builder_at(nir_before_impl(impl));
1287 ASSERTED bool found =
1288 flatten_resume_if_ladder(&b, &impl->cf_node, &impl->body,
1289 true, resume_instr, &remat);
1290 assert(found);
1291
1292 ralloc_free(mem_ctx);
1293
1294 nir_metadata_preserve(impl, nir_metadata_none);
1295
1296 nir_validate_shader(shader, "after flatten_resume_if_ladder in "
1297 "nir_lower_shader_calls");
1298
1299 return resume_instr;
1300 }
1301
static void
replace_resume_with_halt(nir_shader *shader, nir_instr *keep)
1304 {
1305 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1306
1307 nir_builder b = nir_builder_create(impl);
1308
1309 nir_foreach_block_safe(block, impl) {
1310 nir_foreach_instr_safe(instr, block) {
1311 if (instr == keep)
1312 continue;
1313
1314 if (instr->type != nir_instr_type_intrinsic)
1315 continue;
1316
1317 nir_intrinsic_instr *resume = nir_instr_as_intrinsic(instr);
1318 if (resume->intrinsic != nir_intrinsic_rt_resume)
1319 continue;
1320
1321 /* If this is some other resume, then we've kicked off a ray or
1322 * bindless thread and we don't want to go any further in this
1323 * shader. Insert a halt so that NIR will delete any instructions
1324 * dominated by this call instruction including the scratch_load
1325 * instructions we inserted.
1326 */
1327 nir_cf_list cf_list;
1328 nir_cf_extract(&cf_list, nir_after_instr(&resume->instr),
1329 nir_after_block(block));
1330 nir_cf_delete(&cf_list);
1331 b.cursor = nir_instr_remove(&resume->instr);
1332 nir_jump(&b, nir_jump_halt);
1333 break;
1334 }
1335 }
1336 }
1337
1338 struct lower_scratch_state {
1339 nir_address_format address_format;
1340 };
1341
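/* Lowers load_stack/store_stack intrinsics either to global memory accesses
 * relative to the scratch base pointer or to scratch load/store intrinsics,
 * depending on the requested address format.
 */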
static bool
lower_stack_instr_to_scratch(struct nir_builder *b, nir_instr *instr, void *data)
1344 {
1345 struct lower_scratch_state *state = data;
1346
1347 if (instr->type != nir_instr_type_intrinsic)
1348 return false;
1349
1350 nir_intrinsic_instr *stack = nir_instr_as_intrinsic(instr);
1351 switch (stack->intrinsic) {
1352 case nir_intrinsic_load_stack: {
1353 b->cursor = nir_instr_remove(instr);
1354 nir_def *data, *old_data = nir_instr_def(instr);
1355
1356 if (state->address_format == nir_address_format_64bit_global) {
1357 nir_def *addr = nir_iadd_imm(b,
1358 nir_load_scratch_base_ptr(b, 1, 64, 1),
1359 nir_intrinsic_base(stack));
1360 data = nir_build_load_global(b,
1361 stack->def.num_components,
1362 stack->def.bit_size,
1363 addr,
1364 .align_mul = nir_intrinsic_align_mul(stack),
1365 .align_offset = nir_intrinsic_align_offset(stack));
1366 } else {
1367 assert(state->address_format == nir_address_format_32bit_offset);
1368 data = nir_load_scratch(b,
1369 old_data->num_components,
1370 old_data->bit_size,
1371 nir_imm_int(b, nir_intrinsic_base(stack)),
1372 .align_mul = nir_intrinsic_align_mul(stack),
1373 .align_offset = nir_intrinsic_align_offset(stack));
1374 }
1375 nir_def_rewrite_uses(old_data, data);
1376 break;
1377 }
1378
1379 case nir_intrinsic_store_stack: {
1380 b->cursor = nir_instr_remove(instr);
1381 nir_def *data = stack->src[0].ssa;
1382
1383 if (state->address_format == nir_address_format_64bit_global) {
1384 nir_def *addr = nir_iadd_imm(b,
1385 nir_load_scratch_base_ptr(b, 1, 64, 1),
1386 nir_intrinsic_base(stack));
1387 nir_store_global(b, addr,
1388 nir_intrinsic_align_mul(stack),
1389 data,
1390 nir_component_mask(data->num_components));
1391 } else {
1392 assert(state->address_format == nir_address_format_32bit_offset);
1393 nir_store_scratch(b, data,
1394 nir_imm_int(b, nir_intrinsic_base(stack)),
1395 .align_mul = nir_intrinsic_align_mul(stack),
1396 .write_mask = BITFIELD_MASK(data->num_components));
1397 }
1398 break;
1399 }
1400
1401 default:
1402 return false;
1403 }
1404
1405 return true;
1406 }
1407
static bool
nir_lower_stack_to_scratch(nir_shader *shader,
                           nir_address_format address_format)
1411 {
1412 struct lower_scratch_state state = {
1413 .address_format = address_format,
1414 };
1415
1416 return nir_shader_instructions_pass(shader,
1417 lower_stack_instr_to_scratch,
1418 nir_metadata_block_index |
1419 nir_metadata_dominance,
1420 &state);
1421 }
1422
static bool
opt_remove_respills_instr(struct nir_builder *b,
                          nir_intrinsic_instr *store_intrin, void *data)
1426 {
1427 if (store_intrin->intrinsic != nir_intrinsic_store_stack)
1428 return false;
1429
1430 nir_instr *value_instr = store_intrin->src[0].ssa->parent_instr;
1431 if (value_instr->type != nir_instr_type_intrinsic)
1432 return false;
1433
1434 nir_intrinsic_instr *load_intrin = nir_instr_as_intrinsic(value_instr);
1435 if (load_intrin->intrinsic != nir_intrinsic_load_stack)
1436 return false;
1437
1438 if (nir_intrinsic_base(load_intrin) != nir_intrinsic_base(store_intrin))
1439 return false;
1440
1441 nir_instr_remove(&store_intrin->instr);
1442 return true;
1443 }
1444
/* After the shader has been split, look at stack load/store operations. If
 * we're loading and storing the same value at the same location, we can drop
 * the store instruction.
 */
static bool
nir_opt_remove_respills(nir_shader *shader)
1451 {
1452 return nir_shader_intrinsics_pass(shader, opt_remove_respills_instr,
1453 nir_metadata_block_index |
1454 nir_metadata_dominance,
1455 NULL);
1456 }
1457
static void
add_use_mask(struct hash_table_u64 *offset_to_mask,
             unsigned offset, unsigned mask)
1461 {
1462 uintptr_t old_mask = (uintptr_t)
1463 _mesa_hash_table_u64_search(offset_to_mask, offset);
1464
1465 _mesa_hash_table_u64_insert(offset_to_mask, offset,
1466 (void *)(uintptr_t)(old_mask | mask));
1467 }
1468
/* When splitting the shaders, we might have inserted stores & loads of
 * vec4s, because a live value has 4 components. But sometimes only some
 * components of that vec4 are actually used after the scratch load. This
 * pass removes the unused components of scratch loads/stores.
 */
static bool
nir_opt_trim_stack_values(nir_shader *shader)
1476 {
1477 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1478
1479 struct hash_table_u64 *value_id_to_mask = _mesa_hash_table_u64_create(NULL);
1480 bool progress = false;
1481
1482 /* Find all the loads and how their value is being used */
1483 nir_foreach_block_safe(block, impl) {
1484 nir_foreach_instr_safe(instr, block) {
1485 if (instr->type != nir_instr_type_intrinsic)
1486 continue;
1487
1488 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1489 if (intrin->intrinsic != nir_intrinsic_load_stack)
1490 continue;
1491
1492 const unsigned value_id = nir_intrinsic_value_id(intrin);
1493
1494 const unsigned mask =
1495 nir_def_components_read(nir_instr_def(instr));
1496 add_use_mask(value_id_to_mask, value_id, mask);
1497 }
1498 }
1499
1500 /* For each store, if it stores more than is being used, trim it.
1501 * Otherwise, remove it from the hash table.
1502 */
1503 nir_foreach_block_safe(block, impl) {
1504 nir_foreach_instr_safe(instr, block) {
1505 if (instr->type != nir_instr_type_intrinsic)
1506 continue;
1507
1508 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1509 if (intrin->intrinsic != nir_intrinsic_store_stack)
1510 continue;
1511
1512 const unsigned value_id = nir_intrinsic_value_id(intrin);
1513
1514 const unsigned write_mask = nir_intrinsic_write_mask(intrin);
1515 const unsigned read_mask = (uintptr_t)
1516 _mesa_hash_table_u64_search(value_id_to_mask, value_id);
1517
1518 /* Already removed from the table, nothing to do */
1519 if (read_mask == 0)
1520 continue;
1521
1522 /* Matching read/write mask, nothing to do, remove from the table. */
1523 if (write_mask == read_mask) {
1524 _mesa_hash_table_u64_remove(value_id_to_mask, value_id);
1525 continue;
1526 }
1527
1528 nir_builder b = nir_builder_at(nir_before_instr(instr));
1529
1530 nir_def *value = nir_channels(&b, intrin->src[0].ssa, read_mask);
1531 nir_src_rewrite(&intrin->src[0], value);
1532
1533 intrin->num_components = util_bitcount(read_mask);
1534 nir_intrinsic_set_write_mask(intrin, (1u << intrin->num_components) - 1);
1535
1536 progress = true;
1537 }
1538 }
1539
   /* For each load remaining in the hash table (only the ones we changed the
    * number of components of), apply trimming/reswizzling.
    */
1543 nir_foreach_block_safe(block, impl) {
1544 nir_foreach_instr_safe(instr, block) {
1545 if (instr->type != nir_instr_type_intrinsic)
1546 continue;
1547
1548 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1549 if (intrin->intrinsic != nir_intrinsic_load_stack)
1550 continue;
1551
1552 const unsigned value_id = nir_intrinsic_value_id(intrin);
1553
1554 unsigned read_mask = (uintptr_t)
1555 _mesa_hash_table_u64_search(value_id_to_mask, value_id);
1556 if (read_mask == 0)
1557 continue;
1558
1559 unsigned swiz_map[NIR_MAX_VEC_COMPONENTS] = {
1560 0,
1561 };
1562 unsigned swiz_count = 0;
1563 u_foreach_bit(idx, read_mask)
1564 swiz_map[idx] = swiz_count++;
1565
1566 nir_def *def = nir_instr_def(instr);
1567
1568 nir_foreach_use_safe(use_src, def) {
1569 if (nir_src_parent_instr(use_src)->type == nir_instr_type_alu) {
1570 nir_alu_instr *alu = nir_instr_as_alu(nir_src_parent_instr(use_src));
1571 nir_alu_src *alu_src = exec_node_data(nir_alu_src, use_src, src);
1572
1573 unsigned count = alu->def.num_components;
1574 for (unsigned idx = 0; idx < count; ++idx)
1575 alu_src->swizzle[idx] = swiz_map[alu_src->swizzle[idx]];
1576 } else if (nir_src_parent_instr(use_src)->type == nir_instr_type_intrinsic) {
1577 nir_intrinsic_instr *use_intrin =
1578 nir_instr_as_intrinsic(nir_src_parent_instr(use_src));
1579 assert(nir_intrinsic_has_write_mask(use_intrin));
1580 unsigned write_mask = nir_intrinsic_write_mask(use_intrin);
1581 unsigned new_write_mask = 0;
1582 u_foreach_bit(idx, write_mask)
1583 new_write_mask |= 1 << swiz_map[idx];
1584 nir_intrinsic_set_write_mask(use_intrin, new_write_mask);
1585 } else {
1586 unreachable("invalid instruction type");
1587 }
1588 }
1589
1590 intrin->def.num_components = intrin->num_components = swiz_count;
1591
1592 progress = true;
1593 }
1594 }
1595
1596 nir_metadata_preserve(impl,
1597 progress ? (nir_metadata_dominance |
1598 nir_metadata_block_index |
1599 nir_metadata_loop_analysis)
1600 : nir_metadata_all);
1601
1602 _mesa_hash_table_u64_destroy(value_id_to_mask);
1603
1604 return progress;
1605 }
1606
1607 struct scratch_item {
1608 unsigned old_offset;
1609 unsigned new_offset;
1610 unsigned bit_size;
1611 unsigned num_components;
1612 unsigned value;
1613 unsigned call_idx;
1614 };
1615
static int
sort_scratch_item_by_size_and_value_id(const void *_item1, const void *_item2)
1618 {
1619 const struct scratch_item *item1 = _item1;
1620 const struct scratch_item *item2 = _item2;
1621
1622 /* By ascending value_id */
1623 if (item1->bit_size == item2->bit_size)
1624 return (int)item1->value - (int)item2->value;
1625
1626 /* By descending size */
1627 return (int)item2->bit_size - (int)item1->bit_size;
1628 }
1629
1630 static bool
1631 nir_opt_sort_and_pack_stack(nir_shader *shader,
1632 unsigned start_call_scratch,
1633 unsigned stack_alignment,
1634 unsigned num_calls)
1635 {
1636 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1637
1638 void *mem_ctx = ralloc_context(NULL);
1639
1640 struct hash_table_u64 *value_id_to_item =
1641 _mesa_hash_table_u64_create(mem_ctx);
1642 struct util_dynarray ops;
1643 util_dynarray_init(&ops, mem_ctx);
1644
1645 for (unsigned call_idx = 0; call_idx < num_calls; call_idx++) {
1646 _mesa_hash_table_u64_clear(value_id_to_item);
1647 util_dynarray_clear(&ops);
1648
1649 /* Find all the stack loads and their offsets. */
1650 nir_foreach_block_safe(block, impl) {
1651 nir_foreach_instr_safe(instr, block) {
1652 if (instr->type != nir_instr_type_intrinsic)
1653 continue;
1654
1655 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1656 if (intrin->intrinsic != nir_intrinsic_load_stack)
1657 continue;
1658
1659 if (nir_intrinsic_call_idx(intrin) != call_idx)
1660 continue;
1661
1662 const unsigned value_id = nir_intrinsic_value_id(intrin);
1663 nir_def *def = nir_instr_def(instr);
1664
1665 assert(_mesa_hash_table_u64_search(value_id_to_item,
1666 value_id) == NULL);
1667
1668 struct scratch_item item = {
1669 .old_offset = nir_intrinsic_base(intrin),
1670 .bit_size = def->bit_size,
1671 .num_components = def->num_components,
1672 .value = value_id,
1673 };
1674
1675 util_dynarray_append(&ops, struct scratch_item, item);
1676 _mesa_hash_table_u64_insert(value_id_to_item, value_id, (void *)(uintptr_t) true);
1677 }
1678 }
1679
1680 /* Sort scratch items by bit size (descending), then by value ID. */
1681 if (util_dynarray_num_elements(&ops, struct scratch_item)) {
1682 qsort(util_dynarray_begin(&ops),
1683 util_dynarray_num_elements(&ops, struct scratch_item),
1684 sizeof(struct scratch_item),
1685 sort_scratch_item_by_size_and_value_id);
1686 }
1687
1688 /* Reorder things on the stack */
1689 _mesa_hash_table_u64_clear(value_id_to_item);
1690
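/* Walk the sorted items and assign tightly packed, naturally aligned offsets.
 * Illustrative example (assuming start_call_scratch = 12 and a 64-byte stack
 * alignment): a 2-component 64-bit item lands at ALIGN(12, 8) = 16 and moves
 * the cursor to 32; a following 32-bit scalar lands at 32 and moves it to 36;
 * the final scratch size is then ALIGN(36, 64) = 64.
 */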
1691 unsigned scratch_size = start_call_scratch;
1692 util_dynarray_foreach(&ops, struct scratch_item, item) {
1693 item->new_offset = ALIGN(scratch_size, item->bit_size / 8);
1694 scratch_size = item->new_offset + (item->bit_size * item->num_components) / 8;
1695 _mesa_hash_table_u64_insert(value_id_to_item, item->value, item);
1696 }
1697 shader->scratch_size = ALIGN(scratch_size, stack_alignment);
1698
1699 /* Update offsets in the instructions */
1700 nir_foreach_block_safe(block, impl) {
1701 nir_foreach_instr_safe(instr, block) {
1702 if (instr->type != nir_instr_type_intrinsic)
1703 continue;
1704
1705 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1706 switch (intrin->intrinsic) {
1707 case nir_intrinsic_load_stack:
1708 case nir_intrinsic_store_stack: {
1709 if (nir_intrinsic_call_idx(intrin) != call_idx)
1710 continue;
1711
1712 struct scratch_item *item =
1713 _mesa_hash_table_u64_search(value_id_to_item,
1714 nir_intrinsic_value_id(intrin));
1715 assert(item);
1716
1717 nir_intrinsic_set_base(intrin, item->new_offset);
1718 break;
1719 }
1720
1721 case nir_intrinsic_rt_trace_ray:
1722 case nir_intrinsic_rt_execute_callable:
1723 case nir_intrinsic_rt_resume:
1724 if (nir_intrinsic_call_idx(intrin) != call_idx)
1725 continue;
1726 nir_intrinsic_set_stack_size(intrin, shader->scratch_size);
1727 break;
1728
1729 default:
1730 break;
1731 }
1732 }
1733 }
1734 }
1735
1736 ralloc_free(mem_ctx);
1737
1738 nir_shader_preserve_all_metadata(shader);
1739
1740 return true;
1741 }
1742
1743 static unsigned
1744 nir_block_loop_depth(nir_block *block)
1745 {
1746 nir_cf_node *node = &block->cf_node;
1747 unsigned loop_depth = 0;
1748
1749 while (node != NULL) {
1750 if (node->type == nir_cf_node_loop)
1751 loop_depth++;
1752 node = node->parent;
1753 }
1754
1755 return loop_depth;
1756 }
1757
1758 /* Find the last block dominating all the uses of a SSA value. */
1759 static nir_block *
1760 find_last_dominant_use_block(nir_function_impl *impl, nir_def *value)
1761 {
1762 nir_block *old_block = value->parent_instr->block;
1763 unsigned old_block_loop_depth = nir_block_loop_depth(old_block);
1764
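/* Walk the blocks bottom-up so that the first block we find that dominates
 * every use is the latest legal placement for the value.
 */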
1765 nir_foreach_block_reverse_safe(block, impl) {
1766 bool fits = true;
1767
1768 /* We've reached the block defining the value, stop here. */
1769 if (block == old_block)
1770 return block;
1771
1772 /* Don't move instructions deeper into loops; that would generate more
1773 * memory traffic.
1774 */
1775 unsigned block_loop_depth = nir_block_loop_depth(block);
1776 if (block_loop_depth > old_block_loop_depth)
1777 continue;
1778
1779 nir_foreach_if_use(src, value) {
1780 nir_block *block_before_if =
1781 nir_cf_node_as_block(nir_cf_node_prev(&nir_src_parent_if(src)->cf_node));
1782 if (!nir_block_dominates(block, block_before_if)) {
1783 fits = false;
1784 break;
1785 }
1786 }
1787 if (!fits)
1788 continue;
1789
1790 nir_foreach_use(src, value) {
1791 if (nir_src_parent_instr(src)->type == nir_instr_type_phi &&
1792 block == nir_src_parent_instr(src)->block) {
1793 fits = false;
1794 break;
1795 }
1796
1797 if (!nir_block_dominates(block, nir_src_parent_instr(src)->block)) {
1798 fits = false;
1799 break;
1800 }
1801 }
1802 if (!fits)
1803 continue;
1804
1805 return block;
1806 }
1807 unreachable("Cannot find block");
1808 }
1809
1810 /* Put the scratch loads in the branches where they're needed. */
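/* For instance, a load_stack whose only uses sit inside one arm of an if
 * ladder can be sunk from the top of the resume shader into that arm, so the
 * other paths never pay for the scratch read.
 */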
1811 static bool
1812 nir_opt_stack_loads(nir_shader *shader)
1813 {
1814 bool progress = false;
1815
1816 nir_foreach_function_impl(impl, shader) {
1817 nir_metadata_require(impl, nir_metadata_dominance |
1818 nir_metadata_block_index);
1819
1820 bool func_progress = false;
1821 nir_foreach_block_safe(block, impl) {
1822 nir_foreach_instr_safe(instr, block) {
1823 if (instr->type != nir_instr_type_intrinsic)
1824 continue;
1825
1826 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1827 if (intrin->intrinsic != nir_intrinsic_load_stack)
1828 continue;
1829
1830 nir_def *value = &intrin->def;
1831 nir_block *new_block = find_last_dominant_use_block(impl, value);
1832 if (new_block == block)
1833 continue;
1834
1835 /* Move the scratch load into the new block, after the phis. */
1836 nir_instr_remove(instr);
1837 nir_instr_insert(nir_before_block_after_phis(new_block), instr);
1838
1839 func_progress = true;
1840 }
1841 }
1842
1843 nir_metadata_preserve(impl,
1844 func_progress ? (nir_metadata_block_index |
1845 nir_metadata_dominance |
1846 nir_metadata_loop_analysis)
1847 : nir_metadata_all);
1848
1849 progress |= func_progress;
1850 }
1851
1852 return progress;
1853 }
1854
1855 static bool
1856 split_stack_components_instr(struct nir_builder *b,
1857 nir_intrinsic_instr *intrin, void *data)
1858 {
1859 if (intrin->intrinsic != nir_intrinsic_load_stack &&
1860 intrin->intrinsic != nir_intrinsic_store_stack)
1861 return false;
1862
1863 if (intrin->intrinsic == nir_intrinsic_load_stack &&
1864 intrin->def.num_components == 1)
1865 return false;
1866
1867 if (intrin->intrinsic == nir_intrinsic_store_stack &&
1868 intrin->src[0].ssa->num_components == 1)
1869 return false;
1870
1871 b->cursor = nir_before_instr(&intrin->instr);
1872
1873 unsigned align_mul = nir_intrinsic_align_mul(intrin);
1874 unsigned align_offset = nir_intrinsic_align_offset(intrin);
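/* Each scalar copy keeps the original align_mul; its align_offset is shifted
 * by the per-component byte offset (modulo align_mul) so the alignment
 * information stays accurate for the vectorizer.
 */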
1875 if (intrin->intrinsic == nir_intrinsic_load_stack) {
1876 nir_def *components[NIR_MAX_VEC_COMPONENTS] = {
1877 0,
1878 };
1879 for (unsigned c = 0; c < intrin->def.num_components; c++) {
1880 unsigned offset = c * intrin->def.bit_size / 8;
1881 components[c] = nir_load_stack(b, 1, intrin->def.bit_size,
1882 .base = nir_intrinsic_base(intrin) + offset,
1883 .call_idx = nir_intrinsic_call_idx(intrin),
1884 .value_id = nir_intrinsic_value_id(intrin),
1885 .align_mul = align_mul,
1886 .align_offset = (align_offset + offset) % align_mul);
1887 }
1888
1889 nir_def_rewrite_uses(&intrin->def,
1890 nir_vec(b, components,
1891 intrin->def.num_components));
1892 } else {
1893 assert(intrin->intrinsic == nir_intrinsic_store_stack);
1894 for (unsigned c = 0; c < intrin->src[0].ssa->num_components; c++) {
1895 unsigned offset = c * intrin->src[0].ssa->bit_size / 8;
1896 nir_store_stack(b, nir_channel(b, intrin->src[0].ssa, c),
1897 .base = nir_intrinsic_base(intrin) + offset,
1898 .call_idx = nir_intrinsic_call_idx(intrin),
1899 .align_mul = align_mul,
1900 .align_offset = (align_offset + offset) % align_mul,
1901 .value_id = nir_intrinsic_value_id(intrin),
1902 .write_mask = 0x1);
1903 }
1904 }
1905
1906 nir_instr_remove(&intrin->instr);
1907
1908 return true;
1909 }
1910
1911 /* Break the load_stack/store_stack intrinsics into single components. This
1912 * helps the vectorizer to pack components.
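*
* For example, a 4-component 32-bit store_stack at base 32 becomes four
* scalar store_stack intrinsics at bases 32, 36, 40 and 44, which the
* load/store vectorizer can later re-pack (possibly together with
* neighboring values).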
1913 */
1914 static bool
1915 nir_split_stack_components(nir_shader *shader)
1916 {
1917 return nir_shader_intrinsics_pass(shader, split_stack_components_instr,
1918 nir_metadata_block_index |
1919 nir_metadata_dominance,
1920 NULL);
1921 }
1922
1923 struct stack_op_vectorizer_state {
1924 nir_should_vectorize_mem_func driver_callback;
1925 void *driver_data;
1926 };
1927
1928 static bool
1929 should_vectorize(unsigned align_mul,
1930 unsigned align_offset,
1931 unsigned bit_size,
1932 unsigned num_components,
1933 nir_intrinsic_instr *low, nir_intrinsic_instr *high,
1934 void *data)
1935 {
1936 /* We only care about stack load/store intrinsics. */
1937 if ((low->intrinsic != nir_intrinsic_load_stack &&
1938 low->intrinsic != nir_intrinsic_store_stack) ||
1939 (high->intrinsic != nir_intrinsic_load_stack &&
1940 high->intrinsic != nir_intrinsic_store_stack))
1941 return false;
1942
1943 struct stack_op_vectorizer_state *state = data;
1944
1945 return state->driver_callback(align_mul, align_offset,
1946 bit_size, num_components,
1947 low, high, state->driver_data);
1948 }
1949
1950 /** Lower shader call instructions to split shaders.
1951 *
1952 * Shader calls can be split into an initial shader and a series of "resume"
1953 * shaders. When the shader is first invoked, it is the initial shader which
1954 * is executed. At any point in the initial shader or any one of the resume
1955 * shaders, a shader call operation may be performed. The possible shader call
1956 * operations are:
1957 *
1958 * - trace_ray
1959 * - report_ray_intersection
1960 * - execute_callable
1961 *
1962 * When a shader call operation is performed, we push all live values to the
1963 * stack, call rt_trace_ray/rt_execute_callable, and then kill the shader. Once
1964 * the operation we invoked is complete, a callee shader will return execution
1965 * to the respective resume shader. The resume shader pops the contents off
1966 * the stack and picks up where the calling shader left off.
1967 *
1968 * Stack management is assumed to be done after this pass. Call
1969 * instructions and their resumes get annotated with stack information that
1970 * should be enough for the backend to implement proper stack management.
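*
* Roughly, for a shader of the form (illustrative pseudo-NIR):
*
*    a = ...
*    trace_ray(...)
*    use(a)
*
* this pass produces an initial shader that stores `a` to the stack, issues
* rt_trace_ray and halts, plus a resume shader that reloads `a` from the
* stack and continues with `use(a)`.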
1971 */
1972 bool
1973 nir_lower_shader_calls(nir_shader *shader,
1974 const nir_lower_shader_calls_options *options,
1975 nir_shader ***resume_shaders_out,
1976 uint32_t *num_resume_shaders_out,
1977 void *mem_ctx)
1978 {
1979 nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1980
1981 int num_calls = 0;
1982 nir_foreach_block(block, impl) {
1983 nir_foreach_instr_safe(instr, block) {
1984 if (instr_is_shader_call(instr))
1985 num_calls++;
1986 }
1987 }
1988
1989 if (num_calls == 0) {
1990 nir_shader_preserve_all_metadata(shader);
1991 *num_resume_shaders_out = 0;
1992 return false;
1993 }
1994
1995 /* Some intrinsics not only can't be re-materialized but aren't preserved
1996 * when moving to the continuation shader. We have to move them to the top
1997 * to ensure they get spilled as needed.
1998 */
1999 {
2000 bool progress = false;
2001 NIR_PASS(progress, shader, move_system_values_to_top);
2002 if (progress)
2003 NIR_PASS(progress, shader, nir_opt_cse);
2004 }
2005
2006 /* Deref chains contain metadata that is needed by other passes
2007 * after this one. If we don't rematerialize the derefs in the blocks where
2008 * they're used here, the following lowerings will insert phis which can
2009 * prevent other passes from chasing deref chains. Additionally, derefs need
2010 * to be rematerialized after shader call instructions to avoid spilling.
2011 */
2012 {
2013 bool progress = false;
2014 NIR_PASS(progress, shader, wrap_instrs, instr_is_shader_call);
2015
2016 nir_rematerialize_derefs_in_use_blocks_impl(impl);
2017
2018 if (progress)
2019 NIR_PASS(_, shader, nir_opt_dead_cf);
2020 }
2021
2022 /* Save the start point of the call stack in scratch */
2023 unsigned start_call_scratch = shader->scratch_size;
2024
2025 NIR_PASS_V(shader, spill_ssa_defs_and_lower_shader_calls,
2026 num_calls, options);
2027
2028 NIR_PASS_V(shader, nir_opt_remove_phis);
2029
2030 NIR_PASS_V(shader, nir_opt_trim_stack_values);
2031 NIR_PASS_V(shader, nir_opt_sort_and_pack_stack,
2032 start_call_scratch, options->stack_alignment, num_calls);
2033
2034 /* Make N copies of our shader */
2035 nir_shader **resume_shaders = ralloc_array(mem_ctx, nir_shader *, num_calls);
2036 for (unsigned i = 0; i < num_calls; i++) {
2037 resume_shaders[i] = nir_shader_clone(mem_ctx, shader);
2038
2039 /* Give them a recognizable name */
2040 resume_shaders[i]->info.name =
2041 ralloc_asprintf(mem_ctx, "%s%sresume_%u",
2042 shader->info.name ? shader->info.name : "",
2043 shader->info.name ? "-" : "",
2044 i);
2045 }
2046
2047 replace_resume_with_halt(shader, NULL);
2048 nir_opt_dce(shader);
2049 nir_opt_dead_cf(shader);
2050 for (unsigned i = 0; i < num_calls; i++) {
2051 nir_instr *resume_instr = lower_resume(resume_shaders[i], i);
2052 replace_resume_with_halt(resume_shaders[i], resume_instr);
2053 /* Remove CF after halt before nir_opt_if(). */
2054 nir_opt_dead_cf(resume_shaders[i]);
2055 /* Remove the dummy blocks added by flatten_resume_if_ladder() */
2056 nir_opt_if(resume_shaders[i], nir_opt_if_optimize_phi_true_false);
2057 nir_opt_dce(resume_shaders[i]);
2058 nir_opt_dead_cf(resume_shaders[i]);
2059 nir_opt_remove_phis(resume_shaders[i]);
2060 }
2061
2062 for (unsigned i = 0; i < num_calls; i++)
2063 NIR_PASS_V(resume_shaders[i], nir_opt_remove_respills);
2064
2065 if (options->localized_loads) {
2066 /* Once loads have been combined, we can try to put them closer to where
2067 * they're needed.
2068 */
2069 for (unsigned i = 0; i < num_calls; i++)
2070 NIR_PASS_V(resume_shaders[i], nir_opt_stack_loads);
2071 }
2072
2073 struct stack_op_vectorizer_state vectorizer_state = {
2074 .driver_callback = options->vectorizer_callback,
2075 .driver_data = options->vectorizer_data,
2076 };
2077 nir_load_store_vectorize_options vect_opts = {
2078 .modes = nir_var_shader_temp,
2079 .callback = should_vectorize,
2080 .cb_data = &vectorizer_state,
2081 };
2082
2083 if (options->vectorizer_callback != NULL) {
2084 NIR_PASS_V(shader, nir_split_stack_components);
2085 NIR_PASS_V(shader, nir_opt_load_store_vectorize, &vect_opts);
2086 }
2087 NIR_PASS_V(shader, nir_lower_stack_to_scratch, options->address_format);
2088 nir_opt_cse(shader);
2089 for (unsigned i = 0; i < num_calls; i++) {
2090 if (options->vectorizer_callback != NULL) {
2091 NIR_PASS_V(resume_shaders[i], nir_split_stack_components);
2092 NIR_PASS_V(resume_shaders[i], nir_opt_load_store_vectorize, &vect_opts);
2093 }
2094 NIR_PASS_V(resume_shaders[i], nir_lower_stack_to_scratch,
2095 options->address_format);
2096 nir_opt_cse(resume_shaders[i]);
2097 }
2098
2099 *resume_shaders_out = resume_shaders;
2100 *num_resume_shaders_out = num_calls;
2101
2102 return true;
2103 }
2104