1 /*
2  * Copyright © 2020 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "util/u_dynarray.h"
25 #include "util/u_math.h"
26 #include "nir.h"
27 #include "nir_builder.h"
28 #include "nir_phi_builder.h"
29 
30 static bool
31 move_system_values_to_top(nir_shader *shader)
32 {
33    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
34 
35    bool progress = false;
36    nir_foreach_block(block, impl) {
37       nir_foreach_instr_safe(instr, block) {
38          if (instr->type != nir_instr_type_intrinsic)
39             continue;
40 
41          /* These intrinsics not only can't be re-materialized but aren't
42           * preserved when moving to the continuation shader.  We have to move
43           * them to the top to ensure they get spilled as needed.
44           */
45          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
46          switch (intrin->intrinsic) {
47          case nir_intrinsic_load_shader_record_ptr:
48          case nir_intrinsic_load_btd_local_arg_addr_intel:
49             nir_instr_remove(instr);
50             nir_instr_insert(nir_before_impl(impl), instr);
51             progress = true;
52             break;
53 
54          default:
55             break;
56          }
57       }
58    }
59 
60    if (progress) {
61       nir_metadata_preserve(impl, nir_metadata_block_index |
62                                      nir_metadata_dominance);
63    } else {
64       nir_metadata_preserve(impl, nir_metadata_all);
65    }
66 
67    return progress;
68 }
69 
70 static bool
71 instr_is_shader_call(nir_instr *instr)
72 {
73    if (instr->type != nir_instr_type_intrinsic)
74       return false;
75 
76    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
77    return intrin->intrinsic == nir_intrinsic_trace_ray ||
78           intrin->intrinsic == nir_intrinsic_report_ray_intersection ||
79           intrin->intrinsic == nir_intrinsic_execute_callable;
80 }
81 
82 /* Previously named bitset, it had to be renamed as FreeBSD defines a struct
83  * named bitset in sys/_bitset.h required by pthread_np.h which is included
84  * from src/util/u_thread.h that is indirectly included by this file.
85  */
86 struct sized_bitset {
87    BITSET_WORD *set;
88    unsigned size;
89 };
90 
91 static struct sized_bitset
92 bitset_create(void *mem_ctx, unsigned size)
93 {
94    return (struct sized_bitset){
95       .set = rzalloc_array(mem_ctx, BITSET_WORD, BITSET_WORDS(size)),
96       .size = size,
97    };
98 }
99 
100 static bool
101 src_is_in_bitset(nir_src *src, void *_set)
102 {
103    struct sized_bitset *set = _set;
104 
105    /* Any SSA values which were added after we generated liveness information
106     * are things generated by this pass and, while most of it is arithmetic
107     * which we could re-materialize, we don't need to because it's only used
108     * for a single load/store and so shouldn't cross any shader calls.
109     */
110    if (src->ssa->index >= set->size)
111       return false;
112 
113    return BITSET_TEST(set->set, src->ssa->index);
114 }
115 
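/* Record a def in the bitset.  Defs created after liveness information was
 * generated (index >= set->size) are products of this pass and are not
 * tracked.
 */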
116 static void
117 add_ssa_def_to_bitset(nir_def *def, struct sized_bitset *set)
118 {
119    if (def->index >= set->size)
120       return;
121 
122    BITSET_SET(set->set, def->index);
123 }
124 
125 static bool
126 can_remat_instr(nir_instr *instr, struct sized_bitset *remat)
127 {
128    /* Set of all values which are trivially re-materializable and we shouldn't
129     * ever spill them.  This includes:
130     *
131     *   - Undef values
132     *   - Constants
133     *   - Uniforms (UBO or push constant)
134     *   - ALU combinations of any of the above
135     *   - Derefs which are either complete or casts of any of the above
136     *
137     * Because this pass rewrites things in-order and phis are always turned
138     * into register writes, we can use "is it SSA?" to answer the question
139     * "can my source be re-materialized?". Register writes happen via
140     * non-rematerializable intrinsics.
141     */
142    switch (instr->type) {
143    case nir_instr_type_alu:
144       return nir_foreach_src(instr, src_is_in_bitset, remat);
145 
146    case nir_instr_type_deref:
147       return nir_foreach_src(instr, src_is_in_bitset, remat);
148 
149    case nir_instr_type_intrinsic: {
150       nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
151       switch (intrin->intrinsic) {
152       case nir_intrinsic_load_uniform:
153       case nir_intrinsic_load_ubo:
154       case nir_intrinsic_vulkan_resource_index:
155       case nir_intrinsic_vulkan_resource_reindex:
156       case nir_intrinsic_load_vulkan_descriptor:
157       case nir_intrinsic_load_push_constant:
158       case nir_intrinsic_load_global_constant:
159       case nir_intrinsic_load_global_const_block_intel:
160       case nir_intrinsic_load_desc_set_address_intel:
161          /* These intrinsics don't need to be spilled as long as they don't
162           * depend on any spilled values.
163           */
164          return nir_foreach_src(instr, src_is_in_bitset, remat);
165 
166       case nir_intrinsic_load_scratch_base_ptr:
167       case nir_intrinsic_load_ray_launch_id:
168       case nir_intrinsic_load_topology_id_intel:
169       case nir_intrinsic_load_btd_global_arg_addr_intel:
170       case nir_intrinsic_load_btd_resume_sbt_addr_intel:
171       case nir_intrinsic_load_ray_base_mem_addr_intel:
172       case nir_intrinsic_load_ray_hw_stack_size_intel:
173       case nir_intrinsic_load_ray_sw_stack_size_intel:
174       case nir_intrinsic_load_ray_num_dss_rt_stacks_intel:
175       case nir_intrinsic_load_ray_hit_sbt_addr_intel:
176       case nir_intrinsic_load_ray_hit_sbt_stride_intel:
177       case nir_intrinsic_load_ray_miss_sbt_addr_intel:
178       case nir_intrinsic_load_ray_miss_sbt_stride_intel:
179       case nir_intrinsic_load_callable_sbt_addr_intel:
180       case nir_intrinsic_load_callable_sbt_stride_intel:
181       case nir_intrinsic_load_reloc_const_intel:
182       case nir_intrinsic_load_ray_query_global_intel:
183       case nir_intrinsic_load_ray_launch_size:
184          /* Notably missing from the above list is btd_local_arg_addr_intel.
185           * This is because the resume shader will have a different local
186           * argument pointer because it has a different BSR.  Any access of
187           * the original shader's local arguments needs to be preserved so
188           * that pointer has to be saved on the stack.
189           *
190           * TODO: There may be some system values we want to avoid
191           *       re-materializing as well but we have to be very careful
192           *       to ensure that it's a system value which cannot change
193           *       across a shader call.
194           */
195          return true;
196 
197       case nir_intrinsic_resource_intel:
198          return nir_foreach_src(instr, src_is_in_bitset, remat);
199 
200       default:
201          return false;
202       }
203    }
204 
205    case nir_instr_type_undef:
206    case nir_instr_type_load_const:
207       return true;
208 
209    default:
210       return false;
211    }
212 }
213 
214 static bool
215 can_remat_ssa_def(nir_def *def, struct sized_bitset *remat)
216 {
217    return can_remat_instr(def->parent_instr, remat);
218 }
219 
220 struct add_instr_data {
221    struct util_dynarray *buf;
222    struct sized_bitset *remat;
223 };
224 
225 static bool
226 add_src_instr(nir_src *src, void *state)
227 {
228    struct add_instr_data *data = state;
229    if (BITSET_TEST(data->remat->set, src->ssa->index))
230       return true;
231 
232    util_dynarray_foreach(data->buf, nir_instr *, instr_ptr) {
233       if (*instr_ptr == src->ssa->parent_instr)
234          return true;
235    }
236 
237    /* Abort rematerializing an instruction chain if it is too long. */
238    if (data->buf->size >= data->buf->capacity)
239       return false;
240 
241    util_dynarray_append(data->buf, nir_instr *, src->ssa->parent_instr);
242    return true;
243 }
244 
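/* qsort comparator ordering instructions by nir_instr::index, i.e. by their
 * position in the program.
 */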
245 static int
246 compare_instr_indexes(const void *_inst1, const void *_inst2)
247 {
248    const nir_instr *const *inst1 = _inst1;
249    const nir_instr *const *inst2 = _inst2;
250 
251    return (*inst1)->index - (*inst2)->index;
252 }
253 
254 static bool
255 can_remat_chain_ssa_def(nir_def *def, struct sized_bitset *remat, struct util_dynarray *buf)
256 {
257    assert(util_dynarray_num_elements(buf, nir_instr *) == 0);
258 
259    void *mem_ctx = ralloc_context(NULL);
260 
261    /* Add all the instructions involved in building this ssa_def */
262    util_dynarray_append(buf, nir_instr *, def->parent_instr);
263 
264    unsigned idx = 0;
265    struct add_instr_data data = {
266       .buf = buf,
267       .remat = remat,
268    };
269    while (idx < util_dynarray_num_elements(buf, nir_instr *)) {
270       nir_instr *instr = *util_dynarray_element(buf, nir_instr *, idx++);
271       if (!nir_foreach_src(instr, add_src_instr, &data))
272          goto fail;
273    }
274 
275    /* Sort instructions by index */
276    qsort(util_dynarray_begin(buf),
277          util_dynarray_num_elements(buf, nir_instr *),
278          sizeof(nir_instr *),
279          compare_instr_indexes);
280 
281    /* Create a temporary bitset with all values already
282     * rematerialized/rematerializable. We'll add to this bit set as we go
283     * through values that might not be in that set but that we can
284     * rematerialize.
285     */
286    struct sized_bitset potential_remat = bitset_create(mem_ctx, remat->size);
287    memcpy(potential_remat.set, remat->set, BITSET_WORDS(remat->size) * sizeof(BITSET_WORD));
288 
289    util_dynarray_foreach(buf, nir_instr *, instr_ptr) {
290       nir_def *instr_ssa_def = nir_instr_def(*instr_ptr);
291 
292       /* If it's already in the potential rematerializable set, nothing to do. */
293       if (BITSET_TEST(potential_remat.set, instr_ssa_def->index))
294          continue;
295 
296       if (!can_remat_instr(*instr_ptr, &potential_remat))
297          goto fail;
298 
299       /* All the sources are rematerializable and the instruction is also
300        * rematerializable, mark it as rematerializable too.
301        */
302       BITSET_SET(potential_remat.set, instr_ssa_def->index);
303    }
304 
305    ralloc_free(mem_ctx);
306 
307    return true;
308 
309 fail:
310    util_dynarray_clear(buf);
311    ralloc_free(mem_ctx);
312    return false;
313 }
314 
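/* Clone the instruction producing def at the builder's cursor, remapping its
 * sources through remap_table, and return the def of the clone.
 */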
315 static nir_def *
316 remat_ssa_def(nir_builder *b, nir_def *def, struct hash_table *remap_table)
317 {
318    nir_instr *clone = nir_instr_clone_deep(b->shader, def->parent_instr, remap_table);
319    nir_builder_instr_insert(b, clone);
320    return nir_instr_def(clone);
321 }
322 
323 static nir_def *
324 remat_chain_ssa_def(nir_builder *b, struct util_dynarray *buf,
325                     struct sized_bitset *remat, nir_def ***fill_defs,
326                     unsigned call_idx, struct hash_table *remap_table)
327 {
328    nir_def *last_def = NULL;
329 
330    util_dynarray_foreach(buf, nir_instr *, instr_ptr) {
331       nir_def *instr_ssa_def = nir_instr_def(*instr_ptr);
332       unsigned ssa_index = instr_ssa_def->index;
333 
334       if (fill_defs[ssa_index] != NULL &&
335           fill_defs[ssa_index][call_idx] != NULL)
336          continue;
337 
338       /* Clone the instruction we want to rematerialize */
339       nir_def *clone_ssa_def = remat_ssa_def(b, instr_ssa_def, remap_table);
340 
341       if (fill_defs[ssa_index] == NULL) {
342          fill_defs[ssa_index] =
343             rzalloc_array(fill_defs, nir_def *, remat->size);
344       }
345 
346       /* Add the new ssa_def to the list fill_defs and flag it as
347        * rematerialized
348        */
349       fill_defs[ssa_index][call_idx] = last_def = clone_ssa_def;
350       BITSET_SET(remat->set, ssa_index);
351 
352       _mesa_hash_table_insert(remap_table, instr_ssa_def, last_def);
353    }
354 
355    return last_def;
356 }
357 
358 struct pbv_array {
359    struct nir_phi_builder_value **arr;
360    unsigned len;
361 };
362 
363 static struct nir_phi_builder_value *
364 get_phi_builder_value_for_def(nir_def *def,
365                               struct pbv_array *pbv_arr)
366 {
367    if (def->index >= pbv_arr->len)
368       return NULL;
369 
370    return pbv_arr->arr[def->index];
371 }
372 
373 static nir_def *
374 get_phi_builder_def_for_src(nir_src *src, struct pbv_array *pbv_arr,
375                             nir_block *block)
376 {
377 
378    struct nir_phi_builder_value *pbv =
379       get_phi_builder_value_for_def(src->ssa, pbv_arr);
380    if (pbv == NULL)
381       return NULL;
382 
383    return nir_phi_builder_value_get_block_def(pbv, block);
384 }
385 
386 static bool
387 rewrite_instr_src_from_phi_builder(nir_src *src, void *_pbv_arr)
388 {
389    nir_block *block;
390    if (nir_src_parent_instr(src)->type == nir_instr_type_phi) {
391       nir_phi_src *phi_src = exec_node_data(nir_phi_src, src, src);
392       block = phi_src->pred;
393    } else {
394       block = nir_src_parent_instr(src)->block;
395    }
396 
397    nir_def *new_def = get_phi_builder_def_for_src(src, _pbv_arr, block);
398    if (new_def != NULL)
399       nir_src_rewrite(src, new_def);
400    return true;
401 }
402 
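/* Spill def to the stack right before the call and reload it right after,
 * returning the reloaded value.
 */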
403 static nir_def *
404 spill_fill(nir_builder *before, nir_builder *after, nir_def *def,
405            unsigned value_id, unsigned call_idx,
406            unsigned offset, unsigned stack_alignment)
407 {
408    const unsigned comp_size = def->bit_size / 8;
409 
410    nir_store_stack(before, def,
411                    .base = offset,
412                    .call_idx = call_idx,
413                    .align_mul = MIN2(comp_size, stack_alignment),
414                    .value_id = value_id,
415                    .write_mask = BITFIELD_MASK(def->num_components));
416    return nir_load_stack(after, def->num_components, def->bit_size,
417                          .base = offset,
418                          .call_idx = call_idx,
419                          .value_id = value_id,
420                          .align_mul = MIN2(comp_size, stack_alignment));
421 }
422 
423 static bool
424 add_src_to_call_live_bitset(nir_src *src, void *state)
425 {
426    BITSET_WORD *call_live = state;
427 
428    BITSET_SET(call_live, src->ssa->index);
429    return true;
430 }
431 
432 static void
433 spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls,
434                                       const nir_lower_shader_calls_options *options)
435 {
436    /* TODO: If a SSA def is filled more than once, we probably want to just
437     *       spill it at the LCM of the fill sites so we avoid unnecessary
438     *       extra spills
439     *
440     * TODO: If a SSA def is defined outside a loop but live through some call
441     *       inside the loop, we probably want to spill outside the loop.  We
442     *       may also want to fill outside the loop if it's not used in the
443     *       loop.
444     *
445     * TODO: Right now, we only re-materialize things if their immediate
446     *       sources are things which we filled.  We probably want to expand
447     *       that to re-materialize things whose sources are things we can
448     *       re-materialize from things we filled.  We may want some DAG depth
449     *       heuristic on this.
450     */
451 
452    /* This happens per-shader rather than per-impl because we mess with
453     * nir_shader::scratch_size.
454     */
455    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
456 
457    nir_metadata_require(impl, nir_metadata_live_defs |
458                                  nir_metadata_dominance |
459                                  nir_metadata_block_index |
460                                  nir_metadata_instr_index);
461 
462    void *mem_ctx = ralloc_context(shader);
463 
464    const unsigned num_ssa_defs = impl->ssa_alloc;
465    const unsigned live_words = BITSET_WORDS(num_ssa_defs);
466    struct sized_bitset trivial_remat = bitset_create(mem_ctx, num_ssa_defs);
467 
468    /* Array of all live SSA defs which are spill candidates */
469    nir_def **spill_defs =
470       rzalloc_array(mem_ctx, nir_def *, num_ssa_defs);
471 
472    /* For each spill candidate, an array of every time it's defined by a fill,
473     * indexed by call instruction index.
474     */
475    nir_def ***fill_defs =
476       rzalloc_array(mem_ctx, nir_def **, num_ssa_defs);
477 
478    /* For each call instruction, the liveness set at the call */
479    const BITSET_WORD **call_live =
480       rzalloc_array(mem_ctx, const BITSET_WORD *, num_calls);
481 
482    /* For each call instruction, the block index of the block it lives in */
483    uint32_t *call_block_indices = rzalloc_array(mem_ctx, uint32_t, num_calls);
484 
485    /* Remap table when rebuilding instructions out of fill operations */
486    struct hash_table *trivial_remap_table =
487       _mesa_pointer_hash_table_create(mem_ctx);
488 
489    /* Walk the call instructions and fetch the liveness set and block index
490     * for each one.  We need to do this before we start modifying the shader
491     * so that liveness doesn't complain that it's been invalidated.  Don't
492     * worry, we'll be very careful with our live sets. :-)
493     */
494    unsigned call_idx = 0;
495    nir_foreach_block(block, impl) {
496       nir_foreach_instr(instr, block) {
497          if (!instr_is_shader_call(instr))
498             continue;
499 
500          call_block_indices[call_idx] = block->index;
501 
502          /* The objective here is to preserve values around shader call
503           * instructions.  Therefore, we use the live set after the
504           * instruction as the set of things we want to preserve.  Because
505           * none of our shader call intrinsics return anything, we don't have
506           * to worry about spilling over a return value.
507           *
508           * TODO: This isn't quite true for report_intersection.
509           */
510          call_live[call_idx] =
511             nir_get_live_defs(nir_after_instr(instr), mem_ctx);
512 
513          call_idx++;
514       }
515    }
516 
517    /* If a should_remat_callback is given, call it on each of the live values
518     * for each call site. If it returns true we need to rematerialize that
519     * instruction (instead of spill/fill). Therefore we need to add the
520     * sources as live values so that we can rematerialize on top of those
521     * spilled/filled sources.
522     */
523    if (options->should_remat_callback) {
524       BITSET_WORD **updated_call_live =
525          rzalloc_array(mem_ctx, BITSET_WORD *, num_calls);
526 
527       nir_foreach_block(block, impl) {
528          nir_foreach_instr(instr, block) {
529             nir_def *def = nir_instr_def(instr);
530             if (def == NULL)
531                continue;
532 
533             for (unsigned c = 0; c < num_calls; c++) {
534                if (!BITSET_TEST(call_live[c], def->index))
535                   continue;
536 
537                if (!options->should_remat_callback(def->parent_instr,
538                                                    options->should_remat_data))
539                   continue;
540 
541                if (updated_call_live[c] == NULL) {
542                   const unsigned bitset_words = BITSET_WORDS(impl->ssa_alloc);
543                   updated_call_live[c] = ralloc_array(mem_ctx, BITSET_WORD, bitset_words);
544                   memcpy(updated_call_live[c], call_live[c], bitset_words * sizeof(BITSET_WORD));
545                }
546 
547                nir_foreach_src(instr, add_src_to_call_live_bitset, updated_call_live[c]);
548             }
549          }
550       }
551 
552       for (unsigned c = 0; c < num_calls; c++) {
553          if (updated_call_live[c] != NULL)
554             call_live[c] = updated_call_live[c];
555       }
556    }
557 
558    nir_builder before, after;
559    before = nir_builder_create(impl);
560    after = nir_builder_create(impl);
561 
562    call_idx = 0;
563    unsigned max_scratch_size = shader->scratch_size;
564    nir_foreach_block(block, impl) {
565       nir_foreach_instr_safe(instr, block) {
566          nir_def *def = nir_instr_def(instr);
567          if (def != NULL) {
568             if (can_remat_ssa_def(def, &trivial_remat)) {
569                add_ssa_def_to_bitset(def, &trivial_remat);
570                _mesa_hash_table_insert(trivial_remap_table, def, def);
571             } else {
572                spill_defs[def->index] = def;
573             }
574          }
575 
576          if (!instr_is_shader_call(instr))
577             continue;
578 
579          const BITSET_WORD *live = call_live[call_idx];
580 
581          struct hash_table *remap_table =
582             _mesa_hash_table_clone(trivial_remap_table, mem_ctx);
583 
584          /* Make a copy of trivial_remat that we'll update as we crawl through
585           * the live SSA defs and unspill them.
586           */
587          struct sized_bitset remat = bitset_create(mem_ctx, num_ssa_defs);
588          memcpy(remat.set, trivial_remat.set, live_words * sizeof(BITSET_WORD));
589 
590          /* Because the two builders are always separated by the call
591           * instruction, it won't break anything to have two of them.
592           */
593          before.cursor = nir_before_instr(instr);
594          after.cursor = nir_after_instr(instr);
595 
596          /* Array used to hold all the values needed to rematerialize a live
597           * value. The capacity is used to determine when we should abort testing
598           * a remat chain. In practice, shaders can have chains with more than
599           * 10k elements while only chains with less than 16 have realistic
600           * chances. There also isn't any performance benefit in rematerializing
601           * extremely long chains.
602           */
603          nir_instr *remat_chain_instrs[16];
604          struct util_dynarray remat_chain;
605          util_dynarray_init_from_stack(&remat_chain, remat_chain_instrs, sizeof(remat_chain_instrs));
606 
607          unsigned offset = shader->scratch_size;
608          for (unsigned w = 0; w < live_words; w++) {
609             BITSET_WORD spill_mask = live[w] & ~trivial_remat.set[w];
610             while (spill_mask) {
611                int i = u_bit_scan(&spill_mask);
612                assert(i >= 0);
613                unsigned index = w * BITSET_WORDBITS + i;
614                assert(index < num_ssa_defs);
615 
616                def = spill_defs[index];
617                nir_def *original_def = def, *new_def;
618                if (can_remat_ssa_def(def, &remat)) {
619                   /* If this SSA def is re-materializable or based on other
620                    * things we've already spilled, re-materialize it rather
621                    * than spilling and filling.  Anything which is trivially
622                    * re-materializable won't even get here because we take
623                    * those into account in spill_mask above.
624                    */
625                   new_def = remat_ssa_def(&after, def, remap_table);
626                } else if (can_remat_chain_ssa_def(def, &remat, &remat_chain)) {
627                   new_def = remat_chain_ssa_def(&after, &remat_chain, &remat,
628                                                 fill_defs, call_idx,
629                                                 remap_table);
630                   util_dynarray_clear(&remat_chain);
631                } else {
632                   bool is_bool = def->bit_size == 1;
633                   if (is_bool)
634                      def = nir_b2b32(&before, def);
635 
636                   const unsigned comp_size = def->bit_size / 8;
637                   offset = ALIGN(offset, comp_size);
638 
639                   new_def = spill_fill(&before, &after, def,
640                                        index, call_idx,
641                                        offset, options->stack_alignment);
642 
643                   if (is_bool)
644                      new_def = nir_b2b1(&after, new_def);
645 
646                   offset += def->num_components * comp_size;
647                }
648 
649                /* Mark this SSA def as available in the remat set so that, if
650                 * some other SSA def we need is computed based on it, we can
651                 * just re-compute instead of fetching from memory.
652                 */
653                BITSET_SET(remat.set, index);
654 
655                /* For now, we just make a note of this new SSA def.  We'll
656                 * fix things up with the phi builder as a second pass.
657                 */
658                if (fill_defs[index] == NULL) {
659                   fill_defs[index] =
660                      rzalloc_array(fill_defs, nir_def *, num_calls);
661                }
662                fill_defs[index][call_idx] = new_def;
663                _mesa_hash_table_insert(remap_table, original_def, new_def);
664             }
665          }
666 
667          nir_builder *b = &before;
668 
669          offset = ALIGN(offset, options->stack_alignment);
670          max_scratch_size = MAX2(max_scratch_size, offset);
671 
672          /* First thing on the called shader's stack is the resume address
673           * followed by a pointer to the payload.
674           */
675          nir_intrinsic_instr *call = nir_instr_as_intrinsic(instr);
676 
677          /* Lower to generic intrinsics with information about the stack & resume shader. */
678          switch (call->intrinsic) {
679          case nir_intrinsic_trace_ray: {
680             nir_rt_trace_ray(b, call->src[0].ssa, call->src[1].ssa,
681                              call->src[2].ssa, call->src[3].ssa,
682                              call->src[4].ssa, call->src[5].ssa,
683                              call->src[6].ssa, call->src[7].ssa,
684                              call->src[8].ssa, call->src[9].ssa,
685                              call->src[10].ssa,
686                              .call_idx = call_idx, .stack_size = offset);
687             break;
688          }
689 
690          case nir_intrinsic_report_ray_intersection:
691             unreachable("Any-hit shaders must be inlined");
692 
693          case nir_intrinsic_execute_callable: {
694             nir_rt_execute_callable(b, call->src[0].ssa, call->src[1].ssa, .call_idx = call_idx, .stack_size = offset);
695             break;
696          }
697 
698          default:
699             unreachable("Invalid shader call instruction");
700          }
701 
702          nir_rt_resume(b, .call_idx = call_idx, .stack_size = offset);
703 
704          nir_instr_remove(&call->instr);
705 
706          call_idx++;
707       }
708    }
709    assert(call_idx == num_calls);
710    shader->scratch_size = max_scratch_size;
711 
712    struct nir_phi_builder *pb = nir_phi_builder_create(impl);
713    struct pbv_array pbv_arr = {
714       .arr = rzalloc_array(mem_ctx, struct nir_phi_builder_value *,
715                            num_ssa_defs),
716       .len = num_ssa_defs,
717    };
718 
719    const unsigned block_words = BITSET_WORDS(impl->num_blocks);
720    BITSET_WORD *def_blocks = ralloc_array(mem_ctx, BITSET_WORD, block_words);
721 
722    /* Go through and set up phi builder values for each spillable value which
723     * we ever needed to spill at any point.
724     */
725    for (unsigned index = 0; index < num_ssa_defs; index++) {
726       if (fill_defs[index] == NULL)
727          continue;
728 
729       nir_def *def = spill_defs[index];
730 
731       memset(def_blocks, 0, block_words * sizeof(BITSET_WORD));
732       BITSET_SET(def_blocks, def->parent_instr->block->index);
733       for (unsigned call_idx = 0; call_idx < num_calls; call_idx++) {
734          if (fill_defs[index][call_idx] != NULL)
735             BITSET_SET(def_blocks, call_block_indices[call_idx]);
736       }
737 
738       pbv_arr.arr[index] = nir_phi_builder_add_value(pb, def->num_components,
739                                                      def->bit_size, def_blocks);
740    }
741 
742    /* Walk the shader one more time and rewrite SSA defs as needed using the
743     * phi builder.
744     */
745    nir_foreach_block(block, impl) {
746       nir_foreach_instr_safe(instr, block) {
747          nir_def *def = nir_instr_def(instr);
748          if (def != NULL) {
749             struct nir_phi_builder_value *pbv =
750                get_phi_builder_value_for_def(def, &pbv_arr);
751             if (pbv != NULL)
752                nir_phi_builder_value_set_block_def(pbv, block, def);
753          }
754 
755          if (instr->type == nir_instr_type_phi)
756             continue;
757 
758          nir_foreach_src(instr, rewrite_instr_src_from_phi_builder, &pbv_arr);
759 
760          if (instr->type != nir_instr_type_intrinsic)
761             continue;
762 
763          nir_intrinsic_instr *resume = nir_instr_as_intrinsic(instr);
764          if (resume->intrinsic != nir_intrinsic_rt_resume)
765             continue;
766 
767          call_idx = nir_intrinsic_call_idx(resume);
768 
769          /* Technically, this is the wrong place to add the fill defs to the
770           * phi builder values because we haven't seen any of the load_scratch
771           * instructions for this call yet.  However, we know based on how we
772           * emitted them that no value ever gets used until after the load
773           * instruction has been emitted so this should be safe.  If we ever
774           * fail validation due to this, it likely means a bug in our spilling
775           * code and not the phi re-construction code here.
776           */
777          for (unsigned index = 0; index < num_ssa_defs; index++) {
778             if (fill_defs[index] && fill_defs[index][call_idx]) {
779                nir_phi_builder_value_set_block_def(pbv_arr.arr[index], block,
780                                                    fill_defs[index][call_idx]);
781             }
782          }
783       }
784 
785       nir_if *following_if = nir_block_get_following_if(block);
786       if (following_if) {
787          nir_def *new_def =
788             get_phi_builder_def_for_src(&following_if->condition,
789                                         &pbv_arr, block);
790          if (new_def != NULL)
791             nir_src_rewrite(&following_if->condition, new_def);
792       }
793 
794       /* Handle phi sources that source from this block.  We have to do this
795        * as a separate pass because the phi builder assumes that uses and
796        * defs are processed in an order that respects dominance.  When we have
797        * loops, a phi source may be a back-edge so we have to handle it as if
798        * it were one of the last instructions in the predecessor block.
799        */
800       nir_foreach_phi_src_leaving_block(block,
801                                         rewrite_instr_src_from_phi_builder,
802                                         &pbv_arr);
803    }
804 
805    nir_phi_builder_finish(pb);
806 
807    ralloc_free(mem_ctx);
808 
809    nir_metadata_preserve(impl, nir_metadata_block_index |
810                                   nir_metadata_dominance);
811 }
812 
813 static nir_instr *
814 find_resume_instr(nir_function_impl *impl, unsigned call_idx)
815 {
816    nir_foreach_block(block, impl) {
817       nir_foreach_instr(instr, block) {
818          if (instr->type != nir_instr_type_intrinsic)
819             continue;
820 
821          nir_intrinsic_instr *resume = nir_instr_as_intrinsic(instr);
822          if (resume->intrinsic != nir_intrinsic_rt_resume)
823             continue;
824 
825          if (nir_intrinsic_call_idx(resume) == call_idx)
826             return &resume->instr;
827       }
828    }
829    unreachable("Couldn't find resume instruction");
830 }
831 
832 /* Walk the CF tree and duplicate the contents of every loop, one half runs on
833  * resume and the other half is for any post-resume loop iterations.  We are
834  * careful in our duplication to ensure that resume_instr is in the resume
835  * half of the loop, though a copy of resume_instr will remain in the other
836  * half as well in case the same shader call happens twice.
837  */
838 static bool
839 duplicate_loop_bodies(nir_function_impl *impl, nir_instr *resume_instr)
840 {
841    nir_def *resume_reg = NULL;
842    for (nir_cf_node *node = resume_instr->block->cf_node.parent;
843         node->type != nir_cf_node_function; node = node->parent) {
844       if (node->type != nir_cf_node_loop)
845          continue;
846 
847       nir_loop *loop = nir_cf_node_as_loop(node);
848       assert(!nir_loop_has_continue_construct(loop));
849 
850       nir_builder b = nir_builder_create(impl);
851 
852       if (resume_reg == NULL) {
853          /* We only create resume_reg if we encounter a loop.  This way we can
854           * avoid re-validating the shader and calling reg_intrinsics_to_ssa in
855           * the case where it's just if-ladders.
856           */
857          resume_reg = nir_decl_reg(&b, 1, 1, 0);
858 
859          /* Initialize resume to true at the start of the shader, right after
860           * the register is declared.
861           */
862          b.cursor = nir_after_instr(resume_reg->parent_instr);
863          nir_store_reg(&b, nir_imm_true(&b), resume_reg);
864 
865          /* Set resume to false right after the resume instruction */
866          b.cursor = nir_after_instr(resume_instr);
867          nir_store_reg(&b, nir_imm_false(&b), resume_reg);
868       }
869 
870       /* Before we go any further, make sure that everything which exits the
871        * loop or continues around to the top of the loop does so through
872        * registers.  We're about to duplicate the loop body and we'll have
873        * serious trouble if we don't do this.
874        */
875       nir_convert_loop_to_lcssa(loop);
876       nir_lower_phis_to_regs_block(nir_loop_first_block(loop));
877       nir_lower_phis_to_regs_block(
878          nir_cf_node_as_block(nir_cf_node_next(&loop->cf_node)));
879 
880       nir_cf_list cf_list;
881       nir_cf_list_extract(&cf_list, &loop->body);
882 
883       nir_if *_if = nir_if_create(impl->function->shader);
884       b.cursor = nir_after_cf_list(&loop->body);
885       _if->condition = nir_src_for_ssa(nir_load_reg(&b, resume_reg));
886       nir_cf_node_insert(nir_after_cf_list(&loop->body), &_if->cf_node);
887 
888       nir_cf_list clone;
889       nir_cf_list_clone(&clone, &cf_list, &loop->cf_node, NULL);
890 
891       /* Insert the clone in the else and the original in the then so that
892        * the resume_instr remains valid even after the duplication.
893        */
894       nir_cf_reinsert(&cf_list, nir_before_cf_list(&_if->then_list));
895       nir_cf_reinsert(&clone, nir_before_cf_list(&_if->else_list));
896    }
897 
898    if (resume_reg != NULL)
899       nir_metadata_preserve(impl, nir_metadata_none);
900 
901    return resume_reg != NULL;
902 }
903 
904 static bool
905 cf_node_contains_block(nir_cf_node *node, nir_block *block)
906 {
907    for (nir_cf_node *n = &block->cf_node; n != NULL; n = n->parent) {
908       if (n == node)
909          return true;
910    }
911 
912    return false;
913 }
914 
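/* Rewrite all uses of each phi in block to the value the phi takes when
 * entered from pred.  Used when flattening an if ladder where only the path
 * through pred can reach the resume instruction.
 */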
915 static void
916 rewrite_phis_to_pred(nir_block *block, nir_block *pred)
917 {
918    nir_foreach_phi(phi, block) {
919       ASSERTED bool found = false;
920       nir_foreach_phi_src(phi_src, phi) {
921          if (phi_src->pred == pred) {
922             found = true;
923             nir_def_rewrite_uses(&phi->def, phi_src->src.ssa);
924             break;
925          }
926       }
927       assert(found);
928    }
929 }
930 
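/* Returns true if the cursor points immediately after a jump instruction,
 * after which NIR does not allow further instructions to be inserted.
 */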
931 static bool
932 cursor_is_after_jump(nir_cursor cursor)
933 {
934    switch (cursor.option) {
935    case nir_cursor_before_instr:
936    case nir_cursor_before_block:
937       return false;
938    case nir_cursor_after_instr:
939       return cursor.instr->type == nir_instr_type_jump;
940    case nir_cursor_after_block:
941       return nir_block_ends_in_jump(cursor.block);
942       ;
943    }
944    unreachable("Invalid cursor option");
945 }
946 
947 /** Flattens if ladders leading up to a resume
948  *
949  * Given a resume_instr, this function flattens any if ladders leading to the
950  * resume instruction and deletes any code that cannot be encountered on a
951  * direct path to the resume instruction.  This way we get, for the most part,
952  * straight-line control-flow up to the resume instruction.
953  *
954  * While we do this flattening, we also move any code which is in the remat
955  * set up to the top of the function or to the top of the resume portion of
956  * the current loop.  We don't worry about control-flow as we do this because
957  * phis will never be in the remat set (see can_remat_instr) and so nothing
958  * control-dependent will ever need to be re-materialized.  It is possible
959  * that this algorithm will preserve too many instructions by moving them to
960  * the top but we leave that for DCE to clean up.  Any code not in the remat
961  * set is deleted because it's either unused in the continuation or else
962  * unspilled from a previous continuation and the unspill code is after the
963  * resume instruction.
964  *
965  * If, for instance, we have something like this:
966  *
967  *    // block 0
968  *    if (cond1) {
969  *       // block 1
970  *    } else {
971  *       // block 2
972  *       if (cond2) {
973  *          // block 3
974  *          resume;
975  *          if (cond3) {
976  *             // block 4
977  *          }
978  *       } else {
979  *          // block 5
980  *       }
981  *    }
982  *
983  * then we know, because we know the resume instruction had to be encountered,
984  * that cond1 = false and cond2 = true and we lower as follows:
985  *
986  *    // block 0
987  *    // block 2
988  *    // block 3
989  *    resume;
990  *    if (cond3) {
991  *       // block 4
992  *    }
993  *
994  * As you can see, the code in blocks 1 and 5 was removed because there is no
995  * path from the start of the shader to the resume instruction which executes
996  * blocks 1 or 5.  Any remat code from blocks 0, 2, and 3 is preserved and
997  * moved to the top.  If the resume instruction is inside a loop then we know
998  * a priori that it is of the form
999  *
1000  *    loop {
1001  *       if (resume) {
1002  *          // Contents containing resume_instr
1003  *       } else {
1004  *          // Second copy of contents
1005  *       }
1006  *    }
1007  *
1008  * In this case, we only descend into the first half of the loop.  The second
1009  * half is left alone as that portion is only ever executed after the resume
1010  * instruction.
1011  */
1012 static bool
1013 flatten_resume_if_ladder(nir_builder *b,
1014                          nir_cf_node *parent_node,
1015                          struct exec_list *child_list,
1016                          bool child_list_contains_cursor,
1017                          nir_instr *resume_instr,
1018                          struct sized_bitset *remat)
1019 {
1020    nir_cf_list cf_list;
1021 
1022    /* If our child list contains the cursor instruction then we start out
1023     * before the cursor instruction.  We need to know this so that we can skip
1024     * moving instructions which are already before the cursor.
1025     */
1026    bool before_cursor = child_list_contains_cursor;
1027 
1028    nir_cf_node *resume_node = NULL;
1029    foreach_list_typed_safe(nir_cf_node, child, node, child_list) {
1030       switch (child->type) {
1031       case nir_cf_node_block: {
1032          nir_block *block = nir_cf_node_as_block(child);
1033          if (b->cursor.option == nir_cursor_before_block &&
1034              b->cursor.block == block) {
1035             assert(before_cursor);
1036             before_cursor = false;
1037          }
1038          nir_foreach_instr_safe(instr, block) {
1039             if ((b->cursor.option == nir_cursor_before_instr ||
1040                  b->cursor.option == nir_cursor_after_instr) &&
1041                 b->cursor.instr == instr) {
1042                assert(nir_cf_node_is_first(&block->cf_node));
1043                assert(before_cursor);
1044                before_cursor = false;
1045                continue;
1046             }
1047 
1048             if (instr == resume_instr)
1049                goto found_resume;
1050 
1051             if (!before_cursor && can_remat_instr(instr, remat)) {
1052                nir_instr_remove(instr);
1053                nir_instr_insert(b->cursor, instr);
1054                b->cursor = nir_after_instr(instr);
1055 
1056                nir_def *def = nir_instr_def(instr);
1057                BITSET_SET(remat->set, def->index);
1058             }
1059          }
1060          if (b->cursor.option == nir_cursor_after_block &&
1061              b->cursor.block == block) {
1062             assert(before_cursor);
1063             before_cursor = false;
1064          }
1065          break;
1066       }
1067 
1068       case nir_cf_node_if: {
1069          assert(!before_cursor);
1070          nir_if *_if = nir_cf_node_as_if(child);
1071          if (flatten_resume_if_ladder(b, &_if->cf_node, &_if->then_list,
1072                                       false, resume_instr, remat)) {
1073             resume_node = child;
1074             rewrite_phis_to_pred(nir_cf_node_as_block(nir_cf_node_next(child)),
1075                                  nir_if_last_then_block(_if));
1076             goto found_resume;
1077          }
1078 
1079          if (flatten_resume_if_ladder(b, &_if->cf_node, &_if->else_list,
1080                                       false, resume_instr, remat)) {
1081             resume_node = child;
1082             rewrite_phis_to_pred(nir_cf_node_as_block(nir_cf_node_next(child)),
1083                                  nir_if_last_else_block(_if));
1084             goto found_resume;
1085          }
1086          break;
1087       }
1088 
1089       case nir_cf_node_loop: {
1090          assert(!before_cursor);
1091          nir_loop *loop = nir_cf_node_as_loop(child);
1092          assert(!nir_loop_has_continue_construct(loop));
1093 
1094          if (cf_node_contains_block(&loop->cf_node, resume_instr->block)) {
1095             /* Thanks to our loop body duplication pass, every level of loop
1096              * containing the resume instruction contains exactly three nodes:
1097              * two blocks and an if.  We don't want to lower away this if
1098              * because it's the resume selection if.  The resume half is
1099              * always the then_list so that's what we want to flatten.
1100              */
1101             nir_block *header = nir_loop_first_block(loop);
1102             nir_if *_if = nir_cf_node_as_if(nir_cf_node_next(&header->cf_node));
1103 
1104             /* We want to place anything re-materialized from inside the loop
1105              * at the top of the resume half of the loop.
1106              */
1107             nir_builder bl = nir_builder_at(nir_before_cf_list(&_if->then_list));
1108 
1109             ASSERTED bool found =
1110                flatten_resume_if_ladder(&bl, &_if->cf_node, &_if->then_list,
1111                                         true, resume_instr, remat);
1112             assert(found);
1113             resume_node = child;
1114             goto found_resume;
1115          } else {
1116             ASSERTED bool found =
1117                flatten_resume_if_ladder(b, &loop->cf_node, &loop->body,
1118                                         false, resume_instr, remat);
1119             assert(!found);
1120          }
1121          break;
1122       }
1123 
1124       case nir_cf_node_function:
1125          unreachable("Unsupported CF node type");
1126       }
1127    }
1128    assert(!before_cursor);
1129 
1130    /* If we got here, we didn't find the resume node or instruction. */
1131    return false;
1132 
1133 found_resume:
1134    /* If we got here then we found either the resume node or the resume
1135     * instruction in this CF list.
1136     */
1137    if (resume_node) {
1138       /* If the resume instruction is buried inside one of our child CF
1139        * nodes, resume_node now points to that child.
1140        */
1141       if (resume_node->type == nir_cf_node_if) {
1142          /* Thanks to the recursive call, all of the interesting contents of
1143           * resume_node have been copied before the cursor.  We just need to
1144           * copy the stuff after resume_node.
1145           */
1146          nir_cf_extract(&cf_list, nir_after_cf_node(resume_node),
1147                         nir_after_cf_list(child_list));
1148       } else {
1149          /* The loop contains its own cursor and still has useful stuff in it.
1150           * We want to move everything after and including the loop to before
1151           * the cursor.
1152           */
1153          assert(resume_node->type == nir_cf_node_loop);
1154          nir_cf_extract(&cf_list, nir_before_cf_node(resume_node),
1155                         nir_after_cf_list(child_list));
1156       }
1157    } else {
1158       /* If we found the resume instruction in one of our blocks, grab
1159        * everything after it in the entire list (not just the one block), and
1160        * place it before the cursor instr.
1161        */
1162       nir_cf_extract(&cf_list, nir_after_instr(resume_instr),
1163                      nir_after_cf_list(child_list));
1164    }
1165 
1166    /* If the resume instruction is in the first block of the child_list,
1167     * and the cursor is still before that block, the nir_cf_extract() may
1168     * extract the block object pointed to by the cursor, and instead create
1169     * a new one for the code before the resume. In such case the cursor
1170     * will be broken, as it will point to a block which is no longer
1171     * in a function.
1172     *
1173     * Luckily, in both cases when this is possible, the intended cursor
1174     * position is right before the child_list, so we can fix the cursor here.
1175     */
1176    if (child_list_contains_cursor &&
1177        b->cursor.option == nir_cursor_before_block &&
1178        b->cursor.block->cf_node.parent == NULL)
1179       b->cursor = nir_before_cf_list(child_list);
1180 
1181    if (cursor_is_after_jump(b->cursor)) {
1182       /* If the resume instruction is in a loop, it's possible cf_list ends
1183        * in a break or continue instruction, in which case we don't want to
1184        * insert anything.  It's also possible we have an early return if
1185        * someone hasn't lowered those yet.  In either case, nothing after that
1186        * point executes in this context so we can delete it.
1187        */
1188       nir_cf_delete(&cf_list);
1189    } else {
1190       b->cursor = nir_cf_reinsert(&cf_list, b->cursor);
1191    }
1192 
1193    if (!resume_node) {
1194       /* We want the resume to be the first "interesting" instruction */
1195       nir_instr_remove(resume_instr);
1196       nir_instr_insert(nir_before_impl(b->impl), resume_instr);
1197    }
1198 
1199    /* We've copied everything interesting out of this CF list to before the
1200     * cursor.  Delete everything else.
1201     */
1202    if (child_list_contains_cursor) {
1203       nir_cf_extract(&cf_list, b->cursor, nir_after_cf_list(child_list));
1204    } else {
1205       nir_cf_list_extract(&cf_list, child_list);
1206    }
1207    nir_cf_delete(&cf_list);
1208 
1209    return true;
1210 }
1211 
1212 typedef bool (*wrap_instr_callback)(nir_instr *instr);
1213 
1214 static bool
1215 wrap_instr(nir_builder *b, nir_instr *instr, void *data)
1216 {
1217    wrap_instr_callback callback = data;
1218    if (!callback(instr))
1219       return false;
1220 
1221    b->cursor = nir_before_instr(instr);
1222 
1223    nir_if *_if = nir_push_if(b, nir_imm_true(b));
1224    nir_pop_if(b, NULL);
1225 
1226    nir_cf_list cf_list;
1227    nir_cf_extract(&cf_list, nir_before_instr(instr), nir_after_instr(instr));
1228    nir_cf_reinsert(&cf_list, nir_before_block(nir_if_first_then_block(_if)));
1229 
1230    return true;
1231 }
1232 
1233 /* This pass wraps jump instructions in a dummy if block so that when
1234  * flatten_resume_if_ladder() does its job, it doesn't move a jump instruction
1235  * directly in front of another instruction which the NIR control flow helpers
1236  * do not allow.
1237  */
1238 static bool
1239 wrap_instrs(nir_shader *shader, wrap_instr_callback callback)
1240 {
1241    return nir_shader_instructions_pass(shader, wrap_instr,
1242                                        nir_metadata_none, callback);
1243 }
1244 
1245 static bool
1246 instr_is_jump(nir_instr *instr)
1247 {
1248    return instr->type == nir_instr_type_jump;
1249 }
1250 
1251 static nir_instr *
1252 lower_resume(nir_shader *shader, int call_idx)
1253 {
1254    wrap_instrs(shader, instr_is_jump);
1255 
1256    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1257    nir_instr *resume_instr = find_resume_instr(impl, call_idx);
1258 
1259    if (duplicate_loop_bodies(impl, resume_instr)) {
1260       nir_validate_shader(shader, "after duplicate_loop_bodies in "
1261                                   "nir_lower_shader_calls");
1262       /* If we duplicated the bodies of any loops, run reg_intrinsics_to_ssa to
1263        * get rid of all those pesky registers we just added.
1264        */
1265       NIR_PASS_V(shader, nir_lower_reg_intrinsics_to_ssa);
1266    }
1267 
1268    /* Re-index nir_def::index.  We don't care about actual liveness in
1269     * this pass, but so that we can use the same helpers as the spilling pass, we
1270     * need to make sure that live_index is something sane.  It's used
1271     * constantly for determining if an SSA value has been added since the
1272     * start of the pass.
1273     */
1274    nir_index_ssa_defs(impl);
1275 
1276    void *mem_ctx = ralloc_context(shader);
1277 
1278    /* Used to track which things may have been assumed to be re-materialized
1279     * by the spilling pass and which we shouldn't delete.
1280     */
1281    struct sized_bitset remat = bitset_create(mem_ctx, impl->ssa_alloc);
1282 
1283    /* Set up a builder to use as a cursor as we extract and re-insert
1284     * stuff into the CFG.
1285     */
1286    nir_builder b = nir_builder_at(nir_before_impl(impl));
1287    ASSERTED bool found =
1288       flatten_resume_if_ladder(&b, &impl->cf_node, &impl->body,
1289                                true, resume_instr, &remat);
1290    assert(found);
1291 
1292    ralloc_free(mem_ctx);
1293 
1294    nir_metadata_preserve(impl, nir_metadata_none);
1295 
1296    nir_validate_shader(shader, "after flatten_resume_if_ladder in "
1297                                "nir_lower_shader_calls");
1298 
1299    return resume_instr;
1300 }
1301 
1302 static void
1303 replace_resume_with_halt(nir_shader *shader, nir_instr *keep)
1304 {
1305    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1306 
1307    nir_builder b = nir_builder_create(impl);
1308 
1309    nir_foreach_block_safe(block, impl) {
1310       nir_foreach_instr_safe(instr, block) {
1311          if (instr == keep)
1312             continue;
1313 
1314          if (instr->type != nir_instr_type_intrinsic)
1315             continue;
1316 
1317          nir_intrinsic_instr *resume = nir_instr_as_intrinsic(instr);
1318          if (resume->intrinsic != nir_intrinsic_rt_resume)
1319             continue;
1320 
1321          /* If this is some other resume, then we've kicked off a ray or
1322           * bindless thread and we don't want to go any further in this
1323           * shader.  Insert a halt so that NIR will delete any instructions
1324           * dominated by this call instruction including the scratch_load
1325           * instructions we inserted.
1326           */
1327          nir_cf_list cf_list;
1328          nir_cf_extract(&cf_list, nir_after_instr(&resume->instr),
1329                         nir_after_block(block));
1330          nir_cf_delete(&cf_list);
1331          b.cursor = nir_instr_remove(&resume->instr);
1332          nir_jump(&b, nir_jump_halt);
1333          break;
1334       }
1335    }
1336 }
1337 
1338 struct lower_scratch_state {
1339    nir_address_format address_format;
1340 };
1341 
1342 static bool
1343 lower_stack_instr_to_scratch(struct nir_builder *b, nir_instr *instr, void *data)
1344 {
1345    struct lower_scratch_state *state = data;
1346 
1347    if (instr->type != nir_instr_type_intrinsic)
1348       return false;
1349 
1350    nir_intrinsic_instr *stack = nir_instr_as_intrinsic(instr);
1351    switch (stack->intrinsic) {
1352    case nir_intrinsic_load_stack: {
1353       b->cursor = nir_instr_remove(instr);
1354       nir_def *data, *old_data = nir_instr_def(instr);
1355 
1356       if (state->address_format == nir_address_format_64bit_global) {
1357          nir_def *addr = nir_iadd_imm(b,
1358                                       nir_load_scratch_base_ptr(b, 1, 64, 1),
1359                                       nir_intrinsic_base(stack));
1360          data = nir_build_load_global(b,
1361                                       stack->def.num_components,
1362                                       stack->def.bit_size,
1363                                       addr,
1364                                       .align_mul = nir_intrinsic_align_mul(stack),
1365                                       .align_offset = nir_intrinsic_align_offset(stack));
1366       } else {
1367          assert(state->address_format == nir_address_format_32bit_offset);
1368          data = nir_load_scratch(b,
1369                                  old_data->num_components,
1370                                  old_data->bit_size,
1371                                  nir_imm_int(b, nir_intrinsic_base(stack)),
1372                                  .align_mul = nir_intrinsic_align_mul(stack),
1373                                  .align_offset = nir_intrinsic_align_offset(stack));
1374       }
1375       nir_def_rewrite_uses(old_data, data);
1376       break;
1377    }
1378 
1379    case nir_intrinsic_store_stack: {
1380       b->cursor = nir_instr_remove(instr);
1381       nir_def *data = stack->src[0].ssa;
1382 
1383       if (state->address_format == nir_address_format_64bit_global) {
1384          nir_def *addr = nir_iadd_imm(b,
1385                                       nir_load_scratch_base_ptr(b, 1, 64, 1),
1386                                       nir_intrinsic_base(stack));
1387          nir_store_global(b, addr,
1388                           nir_intrinsic_align_mul(stack),
1389                           data,
1390                           nir_component_mask(data->num_components));
1391       } else {
1392          assert(state->address_format == nir_address_format_32bit_offset);
1393          nir_store_scratch(b, data,
1394                            nir_imm_int(b, nir_intrinsic_base(stack)),
1395                            .align_mul = nir_intrinsic_align_mul(stack),
1396                            .write_mask = BITFIELD_MASK(data->num_components));
1397       }
1398       break;
1399    }
1400 
1401    default:
1402       return false;
1403    }
1404 
1405    return true;
1406 }
1407 
1408 static bool
1409 nir_lower_stack_to_scratch(nir_shader *shader,
1410                            nir_address_format address_format)
1411 {
1412    struct lower_scratch_state state = {
1413       .address_format = address_format,
1414    };
1415 
1416    return nir_shader_instructions_pass(shader,
1417                                        lower_stack_instr_to_scratch,
1418                                        nir_metadata_block_index |
1419                                           nir_metadata_dominance,
1420                                        &state);
1421 }
1422 
1423 static bool
1424 opt_remove_respills_instr(struct nir_builder *b,
1425                           nir_intrinsic_instr *store_intrin, void *data)
1426 {
1427    if (store_intrin->intrinsic != nir_intrinsic_store_stack)
1428       return false;
1429 
1430    nir_instr *value_instr = store_intrin->src[0].ssa->parent_instr;
1431    if (value_instr->type != nir_instr_type_intrinsic)
1432       return false;
1433 
1434    nir_intrinsic_instr *load_intrin = nir_instr_as_intrinsic(value_instr);
1435    if (load_intrin->intrinsic != nir_intrinsic_load_stack)
1436       return false;
1437 
1438    if (nir_intrinsic_base(load_intrin) != nir_intrinsic_base(store_intrin))
1439       return false;
1440 
1441    nir_instr_remove(&store_intrin->instr);
1442    return true;
1443 }
1444 
1445 /* After shader split, look at stack load/store operations. If we're loading
1446  * and storing the same value at the same location, we can drop the store
1447  * instruction.
1448  */
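/* A rough example of the pattern (NIR-like pseudocode, assumed value_id/base):
 *
 *    ssa_5 = @load_stack () (base=16, value_id=3)
 *    @store_stack (ssa_5) (base=16, value_id=3)    <- redundant re-spill
 *
 * The store writes back exactly what was just reloaded from the same base
 * offset, so it is removed; the load itself is kept.
 */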
1449 static bool
1450 nir_opt_remove_respills(nir_shader *shader)
1451 {
1452    return nir_shader_intrinsics_pass(shader, opt_remove_respills_instr,
1453                                        nir_metadata_block_index |
1454                                           nir_metadata_dominance,
1455                                        NULL);
1456 }
1457 
1458 static void
1459 add_use_mask(struct hash_table_u64 *offset_to_mask,
1460              unsigned offset, unsigned mask)
1461 {
1462    uintptr_t old_mask = (uintptr_t)
1463       _mesa_hash_table_u64_search(offset_to_mask, offset);
1464 
1465    _mesa_hash_table_u64_insert(offset_to_mask, offset,
1466                                (void *)(uintptr_t)(old_mask | mask));
1467 }
1468 
1469 /* When splitting the shaders, we might have inserted stores & loads of vec4s,
1470  * because a live value is a 4-component vector. But sometimes only some
1471  * components of that vec4 are used after the scratch load. This pass removes
1472  * the unused components of scratch load/stores.
1473  */
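/* A hedged illustration (NIR-like pseudocode): if only components x and z of a
 * spilled vec4 are read after the reload,
 *
 *    @store_stack (ssa_7.xyzw) (value_id=4, write_mask=0xf)
 *    ssa_9 = @load_stack () (value_id=4)      ; uses read ssa_9.x and ssa_9.z
 *
 * is narrowed to a two-component spill with write_mask=0x3, and every use of
 * the reloaded value is reswizzled so old component z becomes new component y.
 */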
1474 static bool
1475 nir_opt_trim_stack_values(nir_shader *shader)
1476 {
1477    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1478 
1479    struct hash_table_u64 *value_id_to_mask = _mesa_hash_table_u64_create(NULL);
1480    bool progress = false;
1481 
1482    /* Find all the loads and how their value is being used */
1483    nir_foreach_block_safe(block, impl) {
1484       nir_foreach_instr_safe(instr, block) {
1485          if (instr->type != nir_instr_type_intrinsic)
1486             continue;
1487 
1488          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1489          if (intrin->intrinsic != nir_intrinsic_load_stack)
1490             continue;
1491 
1492          const unsigned value_id = nir_intrinsic_value_id(intrin);
1493 
1494          const unsigned mask =
1495             nir_def_components_read(nir_instr_def(instr));
1496          add_use_mask(value_id_to_mask, value_id, mask);
1497       }
1498    }
1499 
1500    /* For each store, if it stores more than is being used, trim it.
1501     * Otherwise, remove it from the hash table.
1502     */
1503    nir_foreach_block_safe(block, impl) {
1504       nir_foreach_instr_safe(instr, block) {
1505          if (instr->type != nir_instr_type_intrinsic)
1506             continue;
1507 
1508          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1509          if (intrin->intrinsic != nir_intrinsic_store_stack)
1510             continue;
1511 
1512          const unsigned value_id = nir_intrinsic_value_id(intrin);
1513 
1514          const unsigned write_mask = nir_intrinsic_write_mask(intrin);
1515          const unsigned read_mask = (uintptr_t)
1516             _mesa_hash_table_u64_search(value_id_to_mask, value_id);
1517 
1518          /* Already removed from the table, nothing to do */
1519          if (read_mask == 0)
1520             continue;
1521 
1522          /* Matching read/write mask, nothing to do, remove from the table. */
1523          if (write_mask == read_mask) {
1524             _mesa_hash_table_u64_remove(value_id_to_mask, value_id);
1525             continue;
1526          }
1527 
1528          nir_builder b = nir_builder_at(nir_before_instr(instr));
1529 
1530          nir_def *value = nir_channels(&b, intrin->src[0].ssa, read_mask);
1531          nir_src_rewrite(&intrin->src[0], value);
1532 
1533          intrin->num_components = util_bitcount(read_mask);
1534          nir_intrinsic_set_write_mask(intrin, (1u << intrin->num_components) - 1);
1535 
1536          progress = true;
1537       }
1538    }
1539 
1540    /* For each load remaining in the hash table (only the ones whose component
1541     * count we changed), apply trimming/reswizzling.
1542     */
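   /* For example (illustrative numbers): with read_mask = 0b0101 the loop
    * below produces swiz_map[0] = 0, swiz_map[2] = 1 and swiz_count = 2, so a
    * use of old component 2 is rewritten to component 1 of the trimmed value.
    */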
1543    nir_foreach_block_safe(block, impl) {
1544       nir_foreach_instr_safe(instr, block) {
1545          if (instr->type != nir_instr_type_intrinsic)
1546             continue;
1547 
1548          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1549          if (intrin->intrinsic != nir_intrinsic_load_stack)
1550             continue;
1551 
1552          const unsigned value_id = nir_intrinsic_value_id(intrin);
1553 
1554          unsigned read_mask = (uintptr_t)
1555             _mesa_hash_table_u64_search(value_id_to_mask, value_id);
1556          if (read_mask == 0)
1557             continue;
1558 
1559          unsigned swiz_map[NIR_MAX_VEC_COMPONENTS] = {
1560             0,
1561          };
1562          unsigned swiz_count = 0;
1563          u_foreach_bit(idx, read_mask)
1564             swiz_map[idx] = swiz_count++;
1565 
1566          nir_def *def = nir_instr_def(instr);
1567 
1568          nir_foreach_use_safe(use_src, def) {
1569             if (nir_src_parent_instr(use_src)->type == nir_instr_type_alu) {
1570                nir_alu_instr *alu = nir_instr_as_alu(nir_src_parent_instr(use_src));
1571                nir_alu_src *alu_src = exec_node_data(nir_alu_src, use_src, src);
1572 
1573                unsigned count = alu->def.num_components;
1574                for (unsigned idx = 0; idx < count; ++idx)
1575                   alu_src->swizzle[idx] = swiz_map[alu_src->swizzle[idx]];
1576             } else if (nir_src_parent_instr(use_src)->type == nir_instr_type_intrinsic) {
1577                nir_intrinsic_instr *use_intrin =
1578                   nir_instr_as_intrinsic(nir_src_parent_instr(use_src));
1579                assert(nir_intrinsic_has_write_mask(use_intrin));
1580                unsigned write_mask = nir_intrinsic_write_mask(use_intrin);
1581                unsigned new_write_mask = 0;
1582                u_foreach_bit(idx, write_mask)
1583                   new_write_mask |= 1 << swiz_map[idx];
1584                nir_intrinsic_set_write_mask(use_intrin, new_write_mask);
1585             } else {
1586                unreachable("invalid instruction type");
1587             }
1588          }
1589 
1590          intrin->def.num_components = intrin->num_components = swiz_count;
1591 
1592          progress = true;
1593       }
1594    }
1595 
1596    nir_metadata_preserve(impl,
1597                          progress ? (nir_metadata_dominance |
1598                                      nir_metadata_block_index |
1599                                      nir_metadata_loop_analysis)
1600                                   : nir_metadata_all);
1601 
1602    _mesa_hash_table_u64_destroy(value_id_to_mask);
1603 
1604    return progress;
1605 }
1606 
1607 struct scratch_item {
1608    unsigned old_offset;
1609    unsigned new_offset;
1610    unsigned bit_size;
1611    unsigned num_components;
1612    unsigned value;
1613    unsigned call_idx;
1614 };
1615 
1616 static int
1617 sort_scratch_item_by_size_and_value_id(const void *_item1, const void *_item2)
1618 {
1619    const struct scratch_item *item1 = _item1;
1620    const struct scratch_item *item2 = _item2;
1621 
1622    /* By ascending value_id */
1623    if (item1->bit_size == item2->bit_size)
1624       return (int)item1->value - (int)item2->value;
1625 
1626    /* By descending size */
1627    return (int)item2->bit_size - (int)item1->bit_size;
1628 }
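/* For instance (made-up items): {bit_size=64, value=5}, {bit_size=32, value=2}
 * and {bit_size=32, value=7} sort as {64,5}, {32,2}, {32,7}: wider values come
 * first and ties are broken by value_id, which keeps the resulting layout
 * deterministic even though qsort() itself is not guaranteed to be stable.
 */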
1629 
1630 static bool
1631 nir_opt_sort_and_pack_stack(nir_shader *shader,
1632                             unsigned start_call_scratch,
1633                             unsigned stack_alignment,
1634                             unsigned num_calls)
1635 {
1636    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1637 
1638    void *mem_ctx = ralloc_context(NULL);
1639 
1640    struct hash_table_u64 *value_id_to_item =
1641       _mesa_hash_table_u64_create(mem_ctx);
1642    struct util_dynarray ops;
1643    util_dynarray_init(&ops, mem_ctx);
1644 
1645    for (unsigned call_idx = 0; call_idx < num_calls; call_idx++) {
1646       _mesa_hash_table_u64_clear(value_id_to_item);
1647       util_dynarray_clear(&ops);
1648 
1649       /* Find all the stack loads and their offsets. */
1650       nir_foreach_block_safe(block, impl) {
1651          nir_foreach_instr_safe(instr, block) {
1652             if (instr->type != nir_instr_type_intrinsic)
1653                continue;
1654 
1655             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1656             if (intrin->intrinsic != nir_intrinsic_load_stack)
1657                continue;
1658 
1659             if (nir_intrinsic_call_idx(intrin) != call_idx)
1660                continue;
1661 
1662             const unsigned value_id = nir_intrinsic_value_id(intrin);
1663             nir_def *def = nir_instr_def(instr);
1664 
1665             assert(_mesa_hash_table_u64_search(value_id_to_item,
1666                                                value_id) == NULL);
1667 
1668             struct scratch_item item = {
1669                .old_offset = nir_intrinsic_base(intrin),
1670                .bit_size = def->bit_size,
1671                .num_components = def->num_components,
1672                .value = value_id,
1673             };
1674 
1675             util_dynarray_append(&ops, struct scratch_item, item);
1676             _mesa_hash_table_u64_insert(value_id_to_item, value_id, (void *)(uintptr_t) true);
1677          }
1678       }
1679 
1680       /* Sort scratch items by descending bit size, then by value_id. */
1681       if (util_dynarray_num_elements(&ops, struct scratch_item)) {
1682          qsort(util_dynarray_begin(&ops),
1683                util_dynarray_num_elements(&ops, struct scratch_item),
1684                sizeof(struct scratch_item),
1685                sort_scratch_item_by_size_and_value_id);
1686       }
1687 
1688       /* Reorder things on the stack */
1689       _mesa_hash_table_u64_clear(value_id_to_item);
1690 
1691       unsigned scratch_size = start_call_scratch;
1692       util_dynarray_foreach(&ops, struct scratch_item, item) {
1693          item->new_offset = ALIGN(scratch_size, item->bit_size / 8);
1694          scratch_size = item->new_offset + (item->bit_size * item->num_components) / 8;
1695          _mesa_hash_table_u64_insert(value_id_to_item, item->value, item);
1696       }
1697       shader->scratch_size = ALIGN(scratch_size, stack_alignment);
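      /* Worked example with assumed numbers: start_call_scratch = 72 and
       * stack_alignment = 64, with a 64-bit vec2 followed by a 32-bit scalar.
       * The vec2 gets new_offset = ALIGN(72, 8) = 72 and pushes scratch_size
       * to 88, the scalar gets ALIGN(88, 4) = 88 and pushes it to 92, and the
       * final scratch_size becomes ALIGN(92, 64) = 128.
       */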
1698 
1699       /* Update offsets in the instructions */
1700       nir_foreach_block_safe(block, impl) {
1701          nir_foreach_instr_safe(instr, block) {
1702             if (instr->type != nir_instr_type_intrinsic)
1703                continue;
1704 
1705             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1706             switch (intrin->intrinsic) {
1707             case nir_intrinsic_load_stack:
1708             case nir_intrinsic_store_stack: {
1709                if (nir_intrinsic_call_idx(intrin) != call_idx)
1710                   continue;
1711 
1712                struct scratch_item *item =
1713                   _mesa_hash_table_u64_search(value_id_to_item,
1714                                               nir_intrinsic_value_id(intrin));
1715                assert(item);
1716 
1717                nir_intrinsic_set_base(intrin, item->new_offset);
1718                break;
1719             }
1720 
1721             case nir_intrinsic_rt_trace_ray:
1722             case nir_intrinsic_rt_execute_callable:
1723             case nir_intrinsic_rt_resume:
1724                if (nir_intrinsic_call_idx(intrin) != call_idx)
1725                   continue;
1726                nir_intrinsic_set_stack_size(intrin, shader->scratch_size);
1727                break;
1728 
1729             default:
1730                break;
1731             }
1732          }
1733       }
1734    }
1735 
1736    ralloc_free(mem_ctx);
1737 
1738    nir_shader_preserve_all_metadata(shader);
1739 
1740    return true;
1741 }
1742 
1743 static unsigned
1744 nir_block_loop_depth(nir_block *block)
1745 {
1746    nir_cf_node *node = &block->cf_node;
1747    unsigned loop_depth = 0;
1748 
1749    while (node != NULL) {
1750       if (node->type == nir_cf_node_loop)
1751          loop_depth++;
1752       node = node->parent;
1753    }
1754 
1755    return loop_depth;
1756 }
1757 
1758 /* Find the last block dominating all the uses of an SSA value. */
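/* Rough sketch of the intent (hypothetical blocks): if the value is defined in
 * block A and used in blocks B and C, where A dominates B and B dominates C,
 * then B is the last block dominating every use, so the defining load can be
 * sunk from A to B but no further (C does not dominate the use in B).
 */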
1759 static nir_block *
1760 find_last_dominant_use_block(nir_function_impl *impl, nir_def *value)
1761 {
1762    nir_block *old_block = value->parent_instr->block;
1763    unsigned old_block_loop_depth = nir_block_loop_depth(old_block);
1764 
1765    nir_foreach_block_reverse_safe(block, impl) {
1766       bool fits = true;
1767 
1768       /* We've reached the value's current block; it stays where it is. */
1769       if (block == old_block)
1770          return block;
1771 
1772       /* Don't move instructions deeper into loops; this would generate more
1773        * memory traffic.
1774        */
1775       unsigned block_loop_depth = nir_block_loop_depth(block);
1776       if (block_loop_depth > old_block_loop_depth)
1777          continue;
1778 
1779       nir_foreach_if_use(src, value) {
1780          nir_block *block_before_if =
1781             nir_cf_node_as_block(nir_cf_node_prev(&nir_src_parent_if(src)->cf_node));
1782          if (!nir_block_dominates(block, block_before_if)) {
1783             fits = false;
1784             break;
1785          }
1786       }
1787       if (!fits)
1788          continue;
1789 
1790       nir_foreach_use(src, value) {
1791          if (nir_src_parent_instr(src)->type == nir_instr_type_phi &&
1792              block == nir_src_parent_instr(src)->block) {
1793             fits = false;
1794             break;
1795          }
1796 
1797          if (!nir_block_dominates(block, nir_src_parent_instr(src)->block)) {
1798             fits = false;
1799             break;
1800          }
1801       }
1802       if (!fits)
1803          continue;
1804 
1805       return block;
1806    }
1807    unreachable("Cannot find block");
1808 }
1809 
1810 /* Put the scratch loads in the branches where they're needed. */
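/* Sketch of the effect (illustrative): a load_stack emitted at the top of a
 * resume shader whose result is only consumed inside one arm of an if gets
 * moved into that arm, right after the phis of its new block, so the scratch
 * read is only paid on the path that actually needs the value.
 */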
1811 static bool
1812 nir_opt_stack_loads(nir_shader *shader)
1813 {
1814    bool progress = false;
1815 
1816    nir_foreach_function_impl(impl, shader) {
1817       nir_metadata_require(impl, nir_metadata_dominance |
1818                                     nir_metadata_block_index);
1819 
1820       bool func_progress = false;
1821       nir_foreach_block_safe(block, impl) {
1822          nir_foreach_instr_safe(instr, block) {
1823             if (instr->type != nir_instr_type_intrinsic)
1824                continue;
1825 
1826             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1827             if (intrin->intrinsic != nir_intrinsic_load_stack)
1828                continue;
1829 
1830             nir_def *value = &intrin->def;
1831             nir_block *new_block = find_last_dominant_use_block(impl, value);
1832             if (new_block == block)
1833                continue;
1834 
1835             /* Move the scratch load to the new block, after the phis. */
1836             nir_instr_remove(instr);
1837             nir_instr_insert(nir_before_block_after_phis(new_block), instr);
1838 
1839             func_progress = true;
1840          }
1841       }
1842 
1843       nir_metadata_preserve(impl,
1844                             func_progress ? (nir_metadata_block_index |
1845                                              nir_metadata_dominance |
1846                                              nir_metadata_loop_analysis)
1847                                           : nir_metadata_all);
1848 
1849       progress |= func_progress;
1850    }
1851 
1852    return progress;
1853 }
1854 
1855 static bool
1856 split_stack_components_instr(struct nir_builder *b,
1857                              nir_intrinsic_instr *intrin, void *data)
1858 {
1859    if (intrin->intrinsic != nir_intrinsic_load_stack &&
1860        intrin->intrinsic != nir_intrinsic_store_stack)
1861       return false;
1862 
1863    if (intrin->intrinsic == nir_intrinsic_load_stack &&
1864        intrin->def.num_components == 1)
1865       return false;
1866 
1867    if (intrin->intrinsic == nir_intrinsic_store_stack &&
1868        intrin->src[0].ssa->num_components == 1)
1869       return false;
1870 
1871    b->cursor = nir_before_instr(&intrin->instr);
1872 
1873    unsigned align_mul = nir_intrinsic_align_mul(intrin);
1874    unsigned align_offset = nir_intrinsic_align_offset(intrin);
1875    if (intrin->intrinsic == nir_intrinsic_load_stack) {
1876       nir_def *components[NIR_MAX_VEC_COMPONENTS] = {
1877          0,
1878       };
1879       for (unsigned c = 0; c < intrin->def.num_components; c++) {
1880          unsigned offset = c * intrin->def.bit_size / 8;
1881          components[c] = nir_load_stack(b, 1, intrin->def.bit_size,
1882                                         .base = nir_intrinsic_base(intrin) + offset,
1883                                         .call_idx = nir_intrinsic_call_idx(intrin),
1884                                         .value_id = nir_intrinsic_value_id(intrin),
1885                                         .align_mul = align_mul,
1886                                         .align_offset = (align_offset + offset) % align_mul);
1887       }
1888 
1889       nir_def_rewrite_uses(&intrin->def,
1890                            nir_vec(b, components,
1891                                    intrin->def.num_components));
1892    } else {
1893       assert(intrin->intrinsic == nir_intrinsic_store_stack);
1894       for (unsigned c = 0; c < intrin->src[0].ssa->num_components; c++) {
1895          unsigned offset = c * intrin->src[0].ssa->bit_size / 8;
1896          nir_store_stack(b, nir_channel(b, intrin->src[0].ssa, c),
1897                          .base = nir_intrinsic_base(intrin) + offset,
1898                          .call_idx = nir_intrinsic_call_idx(intrin),
1899                          .align_mul = align_mul,
1900                          .align_offset = (align_offset + offset) % align_mul,
1901                          .value_id = nir_intrinsic_value_id(intrin),
1902                          .write_mask = 0x1);
1903       }
1904    }
1905 
1906    nir_instr_remove(&intrin->instr);
1907 
1908    return true;
1909 }
1910 
1911 /* Break the load_stack/store_stack intrinsics into single components. This
1912  * helps the vectorizer to pack components.
1913  */
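/* Hedged example (NIR-like pseudocode): a single
 *
 *    ssa_4 = @load_stack () (base=32, num_components=3, bit_size=32)
 *
 * becomes three scalar loads at bases 32, 36 and 40 followed by a vec3 that
 * gathers them, and a vec3 store_stack is likewise split into three scalar
 * stores with write_mask=0x1.
 */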
1914 static bool
1915 nir_split_stack_components(nir_shader *shader)
1916 {
1917    return nir_shader_intrinsics_pass(shader, split_stack_components_instr,
1918                                        nir_metadata_block_index |
1919                                           nir_metadata_dominance,
1920                                        NULL);
1921 }
1922 
1923 struct stack_op_vectorizer_state {
1924    nir_should_vectorize_mem_func driver_callback;
1925    void *driver_data;
1926 };
1927 
1928 static bool
1929 should_vectorize(unsigned align_mul,
1930                  unsigned align_offset,
1931                  unsigned bit_size,
1932                  unsigned num_components,
1933                  nir_intrinsic_instr *low, nir_intrinsic_instr *high,
1934                  void *data)
1935 {
1936    /* We only care about those intrinsics */
1937    if ((low->intrinsic != nir_intrinsic_load_stack &&
1938         low->intrinsic != nir_intrinsic_store_stack) ||
1939        (high->intrinsic != nir_intrinsic_load_stack &&
1940         high->intrinsic != nir_intrinsic_store_stack))
1941       return false;
1942 
1943    struct stack_op_vectorizer_state *state = data;
1944 
1945    return state->driver_callback(align_mul, align_offset,
1946                                  bit_size, num_components,
1947                                  low, high, state->driver_data);
1948 }
1949 
1950 /** Lower shader call instructions to split shaders.
1951  *
1952  * Shader calls can be split into an initial shader and a series of "resume"
1953  * shaders.  When the shader is first invoked, it is the initial shader which
1954  * is executed.  At any point in the initial shader or any one of the resume
1955  * shaders, a shader call operation may be performed.  The possible shader call
1956  * operations are:
1957  *
1958  *  - trace_ray
1959  *  - report_ray_intersection
1960  *  - execute_callable
1961  *
1962  * When a shader call operation is performed, we push all live values to the
1963  * stack, call rt_trace_ray/rt_execute_callable and then kill the shader. Once
1964  * the operation we invoked is complete, a callee shader will return execution
1965  * to the respective resume shader. The resume shader pops the contents off
1966  * the stack and picks up where the calling shader left off.
1967  *
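 * As a rough illustration (not literal pass output), a shader of the form
 *
 *    foo = ...;
 *    trace_ray(...);
 *    use(foo);
 *
 * is split into an initial shader that spills foo, calls rt_trace_ray and then
 * halts, plus a resume shader that reloads foo from the stack and runs
 * use(foo).
 *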
1968  * Stack management is assumed to be done after this pass. Call
1969  * instructions and their resumes get annotated with stack information that
1970  * should be enough for the backend to implement proper stack management.
1971  */
1972 bool
1973 nir_lower_shader_calls(nir_shader *shader,
1974                        const nir_lower_shader_calls_options *options,
1975                        nir_shader ***resume_shaders_out,
1976                        uint32_t *num_resume_shaders_out,
1977                        void *mem_ctx)
1978 {
1979    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1980 
1981    int num_calls = 0;
1982    nir_foreach_block(block, impl) {
1983       nir_foreach_instr_safe(instr, block) {
1984          if (instr_is_shader_call(instr))
1985             num_calls++;
1986       }
1987    }
1988 
1989    if (num_calls == 0) {
1990       nir_shader_preserve_all_metadata(shader);
1991       *num_resume_shaders_out = 0;
1992       return false;
1993    }
1994 
1995    /* Some intrinsics not only can't be re-materialized but aren't preserved
1996     * when moving to the continuation shader.  We have to move them to the top
1997     * to ensure they get spilled as needed.
1998     */
1999    {
2000       bool progress = false;
2001       NIR_PASS(progress, shader, move_system_values_to_top);
2002       if (progress)
2003          NIR_PASS(progress, shader, nir_opt_cse);
2004    }
2005 
2006    /* Deref chains contain metadata that is needed by other passes
2007     * after this one. If we don't rematerialize the derefs in the blocks where
2008     * they're used here, the following lowerings will insert phis which can
2009     * prevent other passes from chasing deref chains. Additionally, derefs need
2010     * to be rematerialized after shader call instructions to avoid spilling.
2011     */
2012    {
2013       bool progress = false;
2014       NIR_PASS(progress, shader, wrap_instrs, instr_is_shader_call);
2015 
2016       nir_rematerialize_derefs_in_use_blocks_impl(impl);
2017 
2018       if (progress)
2019          NIR_PASS(_, shader, nir_opt_dead_cf);
2020    }
2021 
2022    /* Save the start point of the call stack in scratch */
2023    unsigned start_call_scratch = shader->scratch_size;
2024 
2025    NIR_PASS_V(shader, spill_ssa_defs_and_lower_shader_calls,
2026               num_calls, options);
2027 
2028    NIR_PASS_V(shader, nir_opt_remove_phis);
2029 
2030    NIR_PASS_V(shader, nir_opt_trim_stack_values);
2031    NIR_PASS_V(shader, nir_opt_sort_and_pack_stack,
2032               start_call_scratch, options->stack_alignment, num_calls);
2033 
2034    /* Make N copies of our shader */
2035    nir_shader **resume_shaders = ralloc_array(mem_ctx, nir_shader *, num_calls);
2036    for (unsigned i = 0; i < num_calls; i++) {
2037       resume_shaders[i] = nir_shader_clone(mem_ctx, shader);
2038 
2039       /* Give them a recognizable name */
2040       resume_shaders[i]->info.name =
2041          ralloc_asprintf(mem_ctx, "%s%sresume_%u",
2042                          shader->info.name ? shader->info.name : "",
2043                          shader->info.name ? "-" : "",
2044                          i);
2045    }
2046 
2047    replace_resume_with_halt(shader, NULL);
2048    nir_opt_dce(shader);
2049    nir_opt_dead_cf(shader);
2050    for (unsigned i = 0; i < num_calls; i++) {
2051       nir_instr *resume_instr = lower_resume(resume_shaders[i], i);
2052       replace_resume_with_halt(resume_shaders[i], resume_instr);
2053       /* Remove CF after halt before nir_opt_if(). */
2054       nir_opt_dead_cf(resume_shaders[i]);
2055       /* Remove the dummy blocks added by flatten_resume_if_ladder() */
2056       nir_opt_if(resume_shaders[i], nir_opt_if_optimize_phi_true_false);
2057       nir_opt_dce(resume_shaders[i]);
2058       nir_opt_dead_cf(resume_shaders[i]);
2059       nir_opt_remove_phis(resume_shaders[i]);
2060    }
2061 
2062    for (unsigned i = 0; i < num_calls; i++)
2063       NIR_PASS_V(resume_shaders[i], nir_opt_remove_respills);
2064 
2065    if (options->localized_loads) {
2066       /* Once loads have been combined we can try to put them closer to where
2067        * they're needed.
2068        */
2069       for (unsigned i = 0; i < num_calls; i++)
2070          NIR_PASS_V(resume_shaders[i], nir_opt_stack_loads);
2071    }
2072 
2073    struct stack_op_vectorizer_state vectorizer_state = {
2074       .driver_callback = options->vectorizer_callback,
2075       .driver_data = options->vectorizer_data,
2076    };
2077    nir_load_store_vectorize_options vect_opts = {
2078       .modes = nir_var_shader_temp,
2079       .callback = should_vectorize,
2080       .cb_data = &vectorizer_state,
2081    };
2082 
2083    if (options->vectorizer_callback != NULL) {
2084       NIR_PASS_V(shader, nir_split_stack_components);
2085       NIR_PASS_V(shader, nir_opt_load_store_vectorize, &vect_opts);
2086    }
2087    NIR_PASS_V(shader, nir_lower_stack_to_scratch, options->address_format);
2088    nir_opt_cse(shader);
2089    for (unsigned i = 0; i < num_calls; i++) {
2090       if (options->vectorizer_callback != NULL) {
2091          NIR_PASS_V(resume_shaders[i], nir_split_stack_components);
2092          NIR_PASS_V(resume_shaders[i], nir_opt_load_store_vectorize, &vect_opts);
2093       }
2094       NIR_PASS_V(resume_shaders[i], nir_lower_stack_to_scratch,
2095                  options->address_format);
2096       nir_opt_cse(resume_shaders[i]);
2097    }
2098 
2099    *resume_shaders_out = resume_shaders;
2100    *num_resume_shaders_out = num_calls;
2101 
2102    return true;
2103 }
2104