• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright © 2020 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "nir.h"
25 #include "nir_builder.h"
26 #include "nir_phi_builder.h"
27 #include "util/u_math.h"
28 
29 static bool
move_system_values_to_top(nir_shader * shader)30 move_system_values_to_top(nir_shader *shader)
31 {
32    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
33 
34    bool progress = false;
35    nir_foreach_block(block, impl) {
36       nir_foreach_instr_safe(instr, block) {
37          if (instr->type != nir_instr_type_intrinsic)
38             continue;
39 
40          /* These intrinsics not only can't be re-materialized but aren't
41           * preserved when moving to the continuation shader.  We have to move
42           * them to the top to ensure they get spilled as needed.
43           */
44          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
45          switch (intrin->intrinsic) {
46          case nir_intrinsic_load_shader_record_ptr:
47          case nir_intrinsic_load_btd_local_arg_addr_intel:
48             nir_instr_remove(instr);
49             nir_instr_insert(nir_before_cf_list(&impl->body), instr);
50             progress = true;
51             break;
52 
53          default:
54             break;
55          }
56       }
57    }
58 
59    if (progress) {
60       nir_metadata_preserve(impl, nir_metadata_block_index |
61                                   nir_metadata_dominance);
62    } else {
63       nir_metadata_preserve(impl, nir_metadata_all);
64    }
65 
66    return progress;
67 }
68 
69 static bool
instr_is_shader_call(nir_instr * instr)70 instr_is_shader_call(nir_instr *instr)
71 {
72    if (instr->type != nir_instr_type_intrinsic)
73       return false;
74 
75    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
76    return intrin->intrinsic == nir_intrinsic_trace_ray ||
77           intrin->intrinsic == nir_intrinsic_report_ray_intersection ||
78           intrin->intrinsic == nir_intrinsic_execute_callable;
79 }
80 
81 /* Previously named bitset, it had to be renamed as FreeBSD defines a struct
82  * named bitset in sys/_bitset.h required by pthread_np.h which is included
83  * from src/util/u_thread.h that is indirectly included by this file.
84  */
85 struct brw_bitset {
86    BITSET_WORD *set;
87    unsigned size;
88 };
89 
90 static struct brw_bitset
bitset_create(void * mem_ctx,unsigned size)91 bitset_create(void *mem_ctx, unsigned size)
92 {
93    return (struct brw_bitset) {
94       .set = rzalloc_array(mem_ctx, BITSET_WORD, BITSET_WORDS(size)),
95       .size = size,
96    };
97 }
98 
99 static bool
src_is_in_bitset(nir_src * src,void * _set)100 src_is_in_bitset(nir_src *src, void *_set)
101 {
102    struct brw_bitset *set = _set;
103    assert(src->is_ssa);
104 
105    /* Any SSA values which were added after we generated liveness information
106     * are things generated by this pass and, while most of it is arithmetic
107     * which we could re-materialize, we don't need to because it's only used
108     * for a single load/store and so shouldn't cross any shader calls.
109     */
110    if (src->ssa->index >= set->size)
111       return false;
112 
113    return BITSET_TEST(set->set, src->ssa->index);
114 }
115 
116 static void
add_ssa_def_to_bitset(nir_ssa_def * def,struct brw_bitset * set)117 add_ssa_def_to_bitset(nir_ssa_def *def, struct brw_bitset *set)
118 {
119    if (def->index >= set->size)
120       return;
121 
122    BITSET_SET(set->set, def->index);
123 }
124 
125 static bool
can_remat_instr(nir_instr * instr,struct brw_bitset * remat)126 can_remat_instr(nir_instr *instr, struct brw_bitset *remat)
127 {
128    /* Set of all values which are trivially re-materializable and we shouldn't
129     * ever spill them.  This includes:
130     *
131     *   - Undef values
132     *   - Constants
133     *   - Uniforms (UBO or push constant)
134     *   - ALU combinations of any of the above
135     *   - Derefs which are either complete or casts of any of the above
136     *
137     * Because this pass rewrites things in-order and phis are always turned
138     * into register writes, We can use "is it SSA?" to answer the question
139     * "can my source be re-materialized?".
140     */
141    switch (instr->type) {
142    case nir_instr_type_alu:
143       if (!nir_instr_as_alu(instr)->dest.dest.is_ssa)
144          return false;
145 
146       return nir_foreach_src(instr, src_is_in_bitset, remat);
147 
148    case nir_instr_type_deref:
149       return nir_foreach_src(instr, src_is_in_bitset, remat);
150 
151    case nir_instr_type_intrinsic: {
152       nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
153       switch (intrin->intrinsic) {
154       case nir_intrinsic_load_ubo:
155       case nir_intrinsic_vulkan_resource_index:
156       case nir_intrinsic_vulkan_resource_reindex:
157       case nir_intrinsic_load_vulkan_descriptor:
158       case nir_intrinsic_load_push_constant:
159          /* These intrinsics don't need to be spilled as long as they don't
160           * depend on any spilled values.
161           */
162          return nir_foreach_src(instr, src_is_in_bitset, remat);
163 
164       case nir_intrinsic_load_scratch_base_ptr:
165       case nir_intrinsic_load_ray_launch_id:
166       case nir_intrinsic_load_btd_dss_id_intel:
167       case nir_intrinsic_load_btd_global_arg_addr_intel:
168       case nir_intrinsic_load_btd_resume_sbt_addr_intel:
169       case nir_intrinsic_load_ray_base_mem_addr_intel:
170       case nir_intrinsic_load_ray_hw_stack_size_intel:
171       case nir_intrinsic_load_ray_sw_stack_size_intel:
172       case nir_intrinsic_load_ray_num_dss_rt_stacks_intel:
173       case nir_intrinsic_load_ray_hit_sbt_addr_intel:
174       case nir_intrinsic_load_ray_hit_sbt_stride_intel:
175       case nir_intrinsic_load_ray_miss_sbt_addr_intel:
176       case nir_intrinsic_load_ray_miss_sbt_stride_intel:
177       case nir_intrinsic_load_callable_sbt_addr_intel:
178       case nir_intrinsic_load_callable_sbt_stride_intel:
179          /* Notably missing from the above list is btd_local_arg_addr_intel.
180           * This is because the resume shader will have a different local
181           * argument pointer because it has a different BSR.  Any access of
182           * the original shader's local arguments needs to be preserved so
183           * that pointer has to be saved on the stack.
184           *
185           * TODO: There may be some system values we want to avoid
186           *       re-materializing as well but we have to be very careful
187           *       to ensure that it's a system value which cannot change
188           *       across a shader call.
189           */
190          return true;
191 
192       default:
193          return false;
194       }
195    }
196 
197    case nir_instr_type_ssa_undef:
198    case nir_instr_type_load_const:
199       return true;
200 
201    default:
202       return false;
203    }
204 }
205 
206 static bool
can_remat_ssa_def(nir_ssa_def * def,struct brw_bitset * remat)207 can_remat_ssa_def(nir_ssa_def *def, struct brw_bitset *remat)
208 {
209    return can_remat_instr(def->parent_instr, remat);
210 }
211 
212 static nir_ssa_def *
remat_ssa_def(nir_builder * b,nir_ssa_def * def)213 remat_ssa_def(nir_builder *b, nir_ssa_def *def)
214 {
215    nir_instr *clone = nir_instr_clone(b->shader, def->parent_instr);
216    nir_builder_instr_insert(b, clone);
217    return nir_instr_ssa_def(clone);
218 }
219 
/* Lookup table from SSA def index to its phi-builder value; an entry is NULL
 * when no phi-builder value exists for that def.
 */
struct pbv_array {
   struct nir_phi_builder_value **arr;   /* `len` entries */
   unsigned len;
};
224 
225 static struct nir_phi_builder_value *
get_phi_builder_value_for_def(nir_ssa_def * def,struct pbv_array * pbv_arr)226 get_phi_builder_value_for_def(nir_ssa_def *def,
227                               struct pbv_array *pbv_arr)
228 {
229    if (def->index >= pbv_arr->len)
230       return NULL;
231 
232    return pbv_arr->arr[def->index];
233 }
234 
235 static nir_ssa_def *
get_phi_builder_def_for_src(nir_src * src,struct pbv_array * pbv_arr,nir_block * block)236 get_phi_builder_def_for_src(nir_src *src, struct pbv_array *pbv_arr,
237                             nir_block *block)
238 {
239    assert(src->is_ssa);
240 
241    struct nir_phi_builder_value *pbv =
242       get_phi_builder_value_for_def(src->ssa, pbv_arr);
243    if (pbv == NULL)
244       return NULL;
245 
246    return nir_phi_builder_value_get_block_def(pbv, block);
247 }
248 
249 static bool
rewrite_instr_src_from_phi_builder(nir_src * src,void * _pbv_arr)250 rewrite_instr_src_from_phi_builder(nir_src *src, void *_pbv_arr)
251 {
252    nir_block *block;
253    if (src->parent_instr->type == nir_instr_type_phi) {
254       nir_phi_src *phi_src = exec_node_data(nir_phi_src, src, src);
255       block = phi_src->pred;
256    } else {
257       block = src->parent_instr->block;
258    }
259 
260    nir_ssa_def *new_def = get_phi_builder_def_for_src(src, _pbv_arr, block);
261    if (new_def != NULL)
262       nir_instr_rewrite_src(src->parent_instr, src, nir_src_for_ssa(new_def));
263    return true;
264 }
265 
266 static nir_ssa_def *
spill_fill(nir_builder * before,nir_builder * after,nir_ssa_def * def,unsigned offset,nir_address_format address_format,unsigned stack_alignment)267 spill_fill(nir_builder *before, nir_builder *after, nir_ssa_def *def, unsigned offset,
268            nir_address_format address_format, unsigned stack_alignment)
269 {
270    const unsigned comp_size = def->bit_size / 8;
271 
272    switch(address_format) {
273    case nir_address_format_32bit_offset:
274       nir_store_scratch(before, def, nir_imm_int(before, offset),
275                         .align_mul = MIN2(comp_size, stack_alignment), .write_mask = ~0);
276       def = nir_load_scratch(after, def->num_components, def->bit_size,
277                              nir_imm_int(after, offset), .align_mul = MIN2(comp_size, stack_alignment));
278       break;
279    case nir_address_format_64bit_global: {
280       nir_ssa_def *addr = nir_iadd_imm(before, nir_load_scratch_base_ptr(before, 1, 64, 1), offset);
281       nir_store_global(before, addr, MIN2(comp_size, stack_alignment), def, ~0);
282       addr = nir_iadd_imm(after, nir_load_scratch_base_ptr(after, 1, 64, 1), offset);
283       def = nir_load_global(after, addr, MIN2(comp_size, stack_alignment),
284                             def->num_components, def->bit_size);
285       break;
286    }
287    default:
288       unreachable("Unimplemented address format");
289    }
290    return def;
291 }
292 
293 static void
spill_ssa_defs_and_lower_shader_calls(nir_shader * shader,uint32_t num_calls,nir_address_format address_format,unsigned stack_alignment)294 spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls,
295                                       nir_address_format address_format,
296                                       unsigned stack_alignment)
297 {
298    /* TODO: If a SSA def is filled more than once, we probably want to just
299     *       spill it at the LCM of the fill sites so we avoid unnecessary
300     *       extra spills
301     *
302     * TODO: If a SSA def is defined outside a loop but live through some call
303     *       inside the loop, we probably want to spill outside the loop.  We
304     *       may also want to fill outside the loop if it's not used in the
305     *       loop.
306     *
307     * TODO: Right now, we only re-materialize things if their immediate
308     *       sources are things which we filled.  We probably want to expand
309     *       that to re-materialize things whose sources are things we can
310     *       re-materialize from things we filled.  We may want some DAG depth
311     *       heuristic on this.
312     */
313 
314    /* This happens per-shader rather than per-impl because we mess with
315     * nir_shader::scratch_size.
316     */
317    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
318 
319    nir_metadata_require(impl, nir_metadata_live_ssa_defs |
320                               nir_metadata_dominance |
321                               nir_metadata_block_index);
322 
323    void *mem_ctx = ralloc_context(shader);
324 
325    const unsigned num_ssa_defs = impl->ssa_alloc;
326    const unsigned live_words = BITSET_WORDS(num_ssa_defs);
327    struct brw_bitset trivial_remat = bitset_create(mem_ctx, num_ssa_defs);
328 
329    /* Array of all live SSA defs which are spill candidates */
330    nir_ssa_def **spill_defs =
331       rzalloc_array(mem_ctx, nir_ssa_def *, num_ssa_defs);
332 
333    /* For each spill candidate, an array of every time it's defined by a fill,
334     * indexed by call instruction index.
335     */
336    nir_ssa_def ***fill_defs =
337       rzalloc_array(mem_ctx, nir_ssa_def **, num_ssa_defs);
338 
339    /* For each call instruction, the liveness set at the call */
340    const BITSET_WORD **call_live =
341       rzalloc_array(mem_ctx, const BITSET_WORD *, num_calls);
342 
343    /* For each call instruction, the block index of the block it lives in */
344    uint32_t *call_block_indices = rzalloc_array(mem_ctx, uint32_t, num_calls);
345 
346    /* Walk the call instructions and fetch the liveness set and block index
347     * for each one.  We need to do this before we start modifying the shader
348     * so that liveness doesn't complain that it's been invalidated.  Don't
349     * worry, we'll be very careful with our live sets. :-)
350     */
351    unsigned call_idx = 0;
352    nir_foreach_block(block, impl) {
353       nir_foreach_instr(instr, block) {
354          if (!instr_is_shader_call(instr))
355             continue;
356 
357          call_block_indices[call_idx] = block->index;
358 
359          /* The objective here is to preserve values around shader call
360           * instructions.  Therefore, we use the live set after the
361           * instruction as the set of things we want to preserve.  Because
362           * none of our shader call intrinsics return anything, we don't have
363           * to worry about spilling over a return value.
364           *
365           * TODO: This isn't quite true for report_intersection.
366           */
367          call_live[call_idx] =
368             nir_get_live_ssa_defs(nir_after_instr(instr), mem_ctx);
369 
370          call_idx++;
371       }
372    }
373 
374    nir_builder before, after;
375    nir_builder_init(&before, impl);
376    nir_builder_init(&after, impl);
377 
378    call_idx = 0;
379    unsigned max_scratch_size = shader->scratch_size;
380    nir_foreach_block(block, impl) {
381       nir_foreach_instr_safe(instr, block) {
382          nir_ssa_def *def = nir_instr_ssa_def(instr);
383          if (def != NULL) {
384             if (can_remat_ssa_def(def, &trivial_remat)) {
385                add_ssa_def_to_bitset(def, &trivial_remat);
386             } else {
387                spill_defs[def->index] = def;
388             }
389          }
390 
391          if (!instr_is_shader_call(instr))
392             continue;
393 
394          const BITSET_WORD *live = call_live[call_idx];
395 
396          /* Make a copy of trivial_remat that we'll update as we crawl through
397           * the live SSA defs and unspill them.
398           */
399          struct brw_bitset remat = bitset_create(mem_ctx, num_ssa_defs);
400          memcpy(remat.set, trivial_remat.set, live_words * sizeof(BITSET_WORD));
401 
402          /* Before the two builders are always separated by the call
403           * instruction, it won't break anything to have two of them.
404           */
405          before.cursor = nir_before_instr(instr);
406          after.cursor = nir_after_instr(instr);
407 
408          unsigned offset = shader->scratch_size;
409          for (unsigned w = 0; w < live_words; w++) {
410             BITSET_WORD spill_mask = live[w] & ~trivial_remat.set[w];
411             while (spill_mask) {
412                int i = u_bit_scan(&spill_mask);
413                assert(i >= 0);
414                unsigned index = w * BITSET_WORDBITS + i;
415                assert(index < num_ssa_defs);
416 
417                nir_ssa_def *def = spill_defs[index];
418                if (can_remat_ssa_def(def, &remat)) {
419                   /* If this SSA def is re-materializable or based on other
420                    * things we've already spilled, re-materialize it rather
421                    * than spilling and filling.  Anything which is trivially
422                    * re-materializable won't even get here because we take
423                    * those into account in spill_mask above.
424                    */
425                   def = remat_ssa_def(&after, def);
426                } else {
427                   bool is_bool = def->bit_size == 1;
428                   if (is_bool)
429                      def = nir_b2b32(&before, def);
430 
431                   const unsigned comp_size = def->bit_size / 8;
432                   offset = ALIGN(offset, comp_size);
433 
434                   def = spill_fill(&before, &after, def, offset,
435                                    address_format,stack_alignment);
436 
437                   if (is_bool)
438                      def = nir_b2b1(&after, def);
439 
440                   offset += def->num_components * comp_size;
441                }
442 
443                /* Mark this SSA def as available in the remat set so that, if
444                 * some other SSA def we need is computed based on it, we can
445                 * just re-compute instead of fetching from memory.
446                 */
447                BITSET_SET(remat.set, index);
448 
449                /* For now, we just make a note of this new SSA def.  We'll
450                 * fix things up with the phi builder as a second pass.
451                 */
452                if (fill_defs[index] == NULL) {
453                   fill_defs[index] =
454                      rzalloc_array(mem_ctx, nir_ssa_def *, num_calls);
455                }
456                fill_defs[index][call_idx] = def;
457             }
458          }
459 
460          nir_builder *b = &before;
461 
462          offset = ALIGN(offset, stack_alignment);
463          max_scratch_size = MAX2(max_scratch_size, offset);
464 
465          /* First thing on the called shader's stack is the resume address
466           * followed by a pointer to the payload.
467           */
468          nir_intrinsic_instr *call = nir_instr_as_intrinsic(instr);
469 
470          /* Lower to generic intrinsics with information about the stack & resume shader. */
471          switch (call->intrinsic) {
472          case nir_intrinsic_trace_ray: {
473             nir_rt_trace_ray(b, call->src[0].ssa, call->src[1].ssa,
474                               call->src[2].ssa, call->src[3].ssa,
475                               call->src[4].ssa, call->src[5].ssa,
476                               call->src[6].ssa, call->src[7].ssa,
477                               call->src[8].ssa, call->src[9].ssa,
478                               call->src[10].ssa,
479                               .call_idx = call_idx, .stack_size = offset);
480             break;
481          }
482 
483          case nir_intrinsic_report_ray_intersection:
484             unreachable("Any-hit shaders must be inlined");
485 
486          case nir_intrinsic_execute_callable: {
487             nir_rt_execute_callable(b, call->src[0].ssa, call->src[1].ssa, .call_idx = call_idx, .stack_size = offset);
488             break;
489          }
490 
491          default:
492             unreachable("Invalid shader call instruction");
493          }
494 
495          nir_rt_resume(b, .call_idx = call_idx, .stack_size = offset);
496 
497          nir_instr_remove(&call->instr);
498 
499          call_idx++;
500       }
501    }
502    assert(call_idx == num_calls);
503    shader->scratch_size = max_scratch_size;
504 
505    struct nir_phi_builder *pb = nir_phi_builder_create(impl);
506    struct pbv_array pbv_arr = {
507       .arr = rzalloc_array(mem_ctx, struct nir_phi_builder_value *,
508                            num_ssa_defs),
509       .len = num_ssa_defs,
510    };
511 
512    const unsigned block_words = BITSET_WORDS(impl->num_blocks);
513    BITSET_WORD *def_blocks = ralloc_array(mem_ctx, BITSET_WORD, block_words);
514 
515    /* Go through and set up phi builder values for each spillable value which
516     * we ever needed to spill at any point.
517     */
518    for (unsigned index = 0; index < num_ssa_defs; index++) {
519       if (fill_defs[index] == NULL)
520          continue;
521 
522       nir_ssa_def *def = spill_defs[index];
523 
524       memset(def_blocks, 0, block_words * sizeof(BITSET_WORD));
525       BITSET_SET(def_blocks, def->parent_instr->block->index);
526       for (unsigned call_idx = 0; call_idx < num_calls; call_idx++) {
527          if (fill_defs[index][call_idx] != NULL)
528             BITSET_SET(def_blocks, call_block_indices[call_idx]);
529       }
530 
531       pbv_arr.arr[index] = nir_phi_builder_add_value(pb, def->num_components,
532                                                      def->bit_size, def_blocks);
533    }
534 
535    /* Walk the shader one more time and rewrite SSA defs as needed using the
536     * phi builder.
537     */
538    nir_foreach_block(block, impl) {
539       nir_foreach_instr_safe(instr, block) {
540          nir_ssa_def *def = nir_instr_ssa_def(instr);
541          if (def != NULL) {
542             struct nir_phi_builder_value *pbv =
543                get_phi_builder_value_for_def(def, &pbv_arr);
544             if (pbv != NULL)
545                nir_phi_builder_value_set_block_def(pbv, block, def);
546          }
547 
548          if (instr->type == nir_instr_type_phi)
549             continue;
550 
551          nir_foreach_src(instr, rewrite_instr_src_from_phi_builder, &pbv_arr);
552 
553          if (instr->type != nir_instr_type_intrinsic)
554             continue;
555 
556          nir_intrinsic_instr *resume = nir_instr_as_intrinsic(instr);
557          if (resume->intrinsic != nir_intrinsic_rt_resume)
558             continue;
559 
560          call_idx = nir_intrinsic_call_idx(resume);
561 
562          /* Technically, this is the wrong place to add the fill defs to the
563           * phi builder values because we haven't seen any of the load_scratch
564           * instructions for this call yet.  However, we know based on how we
565           * emitted them that no value ever gets used until after the load
566           * instruction has been emitted so this should be safe.  If we ever
567           * fail validation due this it likely means a bug in our spilling
568           * code and not the phi re-construction code here.
569           */
570          for (unsigned index = 0; index < num_ssa_defs; index++) {
571             if (fill_defs[index] && fill_defs[index][call_idx]) {
572                nir_phi_builder_value_set_block_def(pbv_arr.arr[index], block,
573                                                    fill_defs[index][call_idx]);
574             }
575          }
576       }
577 
578       nir_if *following_if = nir_block_get_following_if(block);
579       if (following_if) {
580          nir_ssa_def *new_def =
581             get_phi_builder_def_for_src(&following_if->condition,
582                                         &pbv_arr, block);
583          if (new_def != NULL)
584             nir_if_rewrite_condition(following_if, nir_src_for_ssa(new_def));
585       }
586 
587       /* Handle phi sources that source from this block.  We have to do this
588        * as a separate pass because the phi builder assumes that uses and
589        * defs are processed in an order that respects dominance.  When we have
590        * loops, a phi source may be a back-edge so we have to handle it as if
591        * it were one of the last instructions in the predecessor block.
592        */
593       nir_foreach_phi_src_leaving_block(block,
594                                         rewrite_instr_src_from_phi_builder,
595                                         &pbv_arr);
596    }
597 
598    nir_phi_builder_finish(pb);
599 
600    ralloc_free(mem_ctx);
601 
602    nir_metadata_preserve(impl, nir_metadata_block_index |
603                                nir_metadata_dominance);
604 }
605 
606 static nir_instr *
find_resume_instr(nir_function_impl * impl,unsigned call_idx)607 find_resume_instr(nir_function_impl *impl, unsigned call_idx)
608 {
609    nir_foreach_block(block, impl) {
610       nir_foreach_instr(instr, block) {
611          if (instr->type != nir_instr_type_intrinsic)
612             continue;
613 
614          nir_intrinsic_instr *resume = nir_instr_as_intrinsic(instr);
615          if (resume->intrinsic != nir_intrinsic_rt_resume)
616             continue;
617 
618          if (nir_intrinsic_call_idx(resume) == call_idx)
619             return &resume->instr;
620       }
621    }
622    unreachable("Couldn't find resume instruction");
623 }
624 
625 /* Walk the CF tree and duplicate the contents of every loop, one half runs on
626  * resume and the other half is for any post-resume loop iterations.  We are
627  * careful in our duplication to ensure that resume_instr is in the resume
628  * half of the loop though a copy of resume_instr will remain in the other
629  * half as well in case the same shader call happens twice.
630  */
631 static bool
duplicate_loop_bodies(nir_function_impl * impl,nir_instr * resume_instr)632 duplicate_loop_bodies(nir_function_impl *impl, nir_instr *resume_instr)
633 {
634    nir_register *resume_reg = NULL;
635    for (nir_cf_node *node = resume_instr->block->cf_node.parent;
636         node->type != nir_cf_node_function; node = node->parent) {
637       if (node->type != nir_cf_node_loop)
638          continue;
639 
640       nir_loop *loop = nir_cf_node_as_loop(node);
641 
642       if (resume_reg == NULL) {
643          /* We only create resume_reg if we encounter a loop.  This way we can
644           * avoid re-validating the shader and calling ssa_to_regs in the case
645           * where it's just if-ladders.
646           */
647          resume_reg = nir_local_reg_create(impl);
648          resume_reg->num_components = 1;
649          resume_reg->bit_size = 1;
650 
651          nir_builder b;
652          nir_builder_init(&b, impl);
653 
654          /* Initialize resume to true */
655          b.cursor = nir_before_cf_list(&impl->body);
656          nir_store_reg(&b, resume_reg, nir_imm_true(&b), 1);
657 
658          /* Set resume to false right after the resume instruction */
659          b.cursor = nir_after_instr(resume_instr);
660          nir_store_reg(&b, resume_reg, nir_imm_false(&b), 1);
661       }
662 
663       /* Before we go any further, make sure that everything which exits the
664        * loop or continues around to the top of the loop does so through
665        * registers.  We're about to duplicate the loop body and we'll have
666        * serious trouble if we don't do this.
667        */
668       nir_convert_loop_to_lcssa(loop);
669       nir_lower_phis_to_regs_block(nir_loop_first_block(loop));
670       nir_lower_phis_to_regs_block(
671          nir_cf_node_as_block(nir_cf_node_next(&loop->cf_node)));
672 
673       nir_cf_list cf_list;
674       nir_cf_list_extract(&cf_list, &loop->body);
675 
676       nir_if *_if = nir_if_create(impl->function->shader);
677       _if->condition = nir_src_for_reg(resume_reg);
678       nir_cf_node_insert(nir_after_cf_list(&loop->body), &_if->cf_node);
679 
680       nir_cf_list clone;
681       nir_cf_list_clone(&clone, &cf_list, &loop->cf_node, NULL);
682 
683       /* Insert the clone in the else and the original in the then so that
684        * the resume_instr remains valid even after the duplication.
685        */
686       nir_cf_reinsert(&cf_list, nir_before_cf_list(&_if->then_list));
687       nir_cf_reinsert(&clone, nir_before_cf_list(&_if->else_list));
688    }
689 
690    if (resume_reg != NULL)
691       nir_metadata_preserve(impl, nir_metadata_none);
692 
693    return resume_reg != NULL;
694 }
695 
696 static bool
cf_node_contains_instr(nir_cf_node * node,nir_instr * instr)697 cf_node_contains_instr(nir_cf_node *node, nir_instr *instr)
698 {
699    for (nir_cf_node *n = &instr->block->cf_node; n != NULL; n = n->parent) {
700       if (n == node)
701          return true;
702    }
703 
704    return false;
705 }
706 
707 static void
rewrite_phis_to_pred(nir_block * block,nir_block * pred)708 rewrite_phis_to_pred(nir_block *block, nir_block *pred)
709 {
710    nir_foreach_instr(instr, block) {
711       if (instr->type != nir_instr_type_phi)
712          break;
713 
714       nir_phi_instr *phi = nir_instr_as_phi(instr);
715 
716       ASSERTED bool found = false;
717       nir_foreach_phi_src(phi_src, phi) {
718          if (phi_src->pred == pred) {
719             found = true;
720             assert(phi_src->src.is_ssa);
721             nir_ssa_def_rewrite_uses(&phi->dest.ssa, phi_src->src.ssa);
722             break;
723          }
724       }
725       assert(found);
726    }
727 }
728 
729 /** Flattens if ladders leading up to a resume
730  *
731  * Given a resume_instr, this function flattens any if ladders leading to the
732  * resume instruction and deletes any code that cannot be encountered on a
733  * direct path to the resume instruction.  This way we get, for the most part,
734  * straight-line control-flow up to the resume instruction.
735  *
736  * While we do this flattening, we also move any code which is in the remat
737  * set up to the top of the function or to the top of the resume portion of
738  * the current loop.  We don't worry about control-flow as we do this because
739  * phis will never be in the remat set (see can_remat_instr) and so nothing
740  * control-dependent will ever need to be re-materialized.  It is possible
741  * that this algorithm will preserve too many instructions by moving them to
742  * the top but we leave that for DCE to clean up.  Any code not in the remat
743  * set is deleted because it's either unused in the continuation or else
744  * unspilled from a previous continuation and the unspill code is after the
745  * resume instruction.
746  *
747  * If, for instance, we have something like this:
748  *
749  *    // block 0
750  *    if (cond1) {
751  *       // block 1
752  *    } else {
753  *       // block 2
754  *       if (cond2) {
755  *          // block 3
756  *          resume;
757  *          if (cond3) {
758  *             // block 4
759  *          }
760  *       } else {
761  *          // block 5
762  *       }
763  *    }
764  *
765  * then we know, because we know the resume instruction had to be encoutered,
766  * that cond1 = false and cond2 = true and we lower as follows:
767  *
768  *    // block 0
769  *    // block 2
770  *    // block 3
771  *    resume;
772  *    if (cond3) {
773  *       // block 4
774  *    }
775  *
776  * As you can see, the code in blocks 1 and 5 was removed because there is no
777  * path from the start of the shader to the resume instruction which execute
778  * blocks 1 or 5.  Any remat code from blocks 0, 2, and 3 is preserved and
779  * moved to the top.  If the resume instruction is inside a loop then we know
780  * a priori that it is of the form
781  *
782  *    loop {
783  *       if (resume) {
784  *          // Contents containing resume_instr
785  *       } else {
786  *          // Second copy of contents
787  *       }
788  *    }
789  *
790  * In this case, we only descend into the first half of the loop.  The second
791  * half is left alone as that portion is only ever executed after the resume
792  * instruction.
793  */
794 static bool
flatten_resume_if_ladder(nir_function_impl * impl,nir_instr * cursor,struct exec_list * child_list,bool child_list_contains_cursor,nir_instr * resume_instr,struct brw_bitset * remat)795 flatten_resume_if_ladder(nir_function_impl *impl,
796                          nir_instr *cursor,
797                          struct exec_list *child_list,
798                          bool child_list_contains_cursor,
799                          nir_instr *resume_instr,
800                          struct brw_bitset *remat)
801 {
802    nir_shader *shader = impl->function->shader;
803    nir_cf_list cf_list;
804 
805    /* If our child list contains the cursor instruction then we start out
806     * before the cursor instruction.  We need to know this so that we can skip
807     * moving instructions which are already before the cursor.
808     */
809    bool before_cursor = child_list_contains_cursor;
810 
811    nir_cf_node *resume_node = NULL;
812    foreach_list_typed_safe(nir_cf_node, child, node, child_list) {
813       switch (child->type) {
814       case nir_cf_node_block: {
815          nir_block *block = nir_cf_node_as_block(child);
816          nir_foreach_instr_safe(instr, block) {
817             if (instr == cursor) {
818                assert(nir_cf_node_is_first(&block->cf_node));
819                assert(before_cursor);
820                before_cursor = false;
821                continue;
822             }
823 
824             if (instr == resume_instr)
825                goto found_resume;
826 
827             if (!before_cursor && can_remat_instr(instr, remat)) {
828                nir_instr_remove(instr);
829                nir_instr_insert(nir_before_instr(cursor), instr);
830 
831                nir_ssa_def *def = nir_instr_ssa_def(instr);
832                BITSET_SET(remat->set, def->index);
833             }
834          }
835          break;
836       }
837 
838       case nir_cf_node_if: {
839          assert(!before_cursor);
840          nir_if *_if = nir_cf_node_as_if(child);
841          if (flatten_resume_if_ladder(impl, cursor, &_if->then_list,
842                                       false, resume_instr, remat)) {
843             resume_node = child;
844             rewrite_phis_to_pred(nir_cf_node_as_block(nir_cf_node_next(child)),
845                                  nir_if_last_then_block(_if));
846             goto found_resume;
847          }
848 
849          if (flatten_resume_if_ladder(impl, cursor, &_if->else_list,
850                                       false, resume_instr, remat)) {
851             resume_node = child;
852             rewrite_phis_to_pred(nir_cf_node_as_block(nir_cf_node_next(child)),
853                                  nir_if_last_else_block(_if));
854             goto found_resume;
855          }
856          break;
857       }
858 
859       case nir_cf_node_loop: {
860          assert(!before_cursor);
861          nir_loop *loop = nir_cf_node_as_loop(child);
862 
863          if (cf_node_contains_instr(&loop->cf_node, resume_instr)) {
864             /* Thanks to our loop body duplication pass, every level of loop
865              * containing the resume instruction contains exactly three nodes:
866              * two blocks and an if.  We don't want to lower away this if
867              * because it's the resume selection if.  The resume half is
868              * always the then_list so that's what we want to flatten.
869              */
870             nir_block *header = nir_loop_first_block(loop);
871             nir_if *_if = nir_cf_node_as_if(nir_cf_node_next(&header->cf_node));
872 
873             /* We want to place anything re-materialized from inside the loop
874              * at the top of the resume half of the loop.
875              */
876             nir_instr *loop_cursor =
877                &nir_intrinsic_instr_create(shader, nir_intrinsic_nop)->instr;
878             nir_instr_insert(nir_before_cf_list(&_if->then_list), loop_cursor);
879 
880             ASSERTED bool found =
881                flatten_resume_if_ladder(impl, loop_cursor, &_if->then_list,
882                                         true, resume_instr, remat);
883             assert(found);
884             resume_node = child;
885             goto found_resume;
886          } else {
887             ASSERTED bool found =
888                flatten_resume_if_ladder(impl, cursor, &loop->body,
889                                         false, resume_instr, remat);
890             assert(!found);
891          }
892          break;
893       }
894 
895       case nir_cf_node_function:
896          unreachable("Unsupported CF node type");
897       }
898    }
899    assert(!before_cursor);
900 
901    /* If we got here, we didn't find the resume node or instruction. */
902    return false;
903 
904 found_resume:
905    /* If we got here then we found either the resume node or the resume
906     * instruction in this CF list.
907     */
908    if (resume_node) {
909       /* If the resume instruction is buried in side one of our children CF
910        * nodes, resume_node now points to that child.
911        */
912       if (resume_node->type == nir_cf_node_if) {
913          /* Thanks to the recursive call, all of the interesting contents of
914           * resume_node have been copied before the cursor.  We just need to
915           * copy the stuff after resume_node.
916           */
917          nir_cf_extract(&cf_list, nir_after_cf_node(resume_node),
918                                   nir_after_cf_list(child_list));
919       } else {
920          /* The loop contains its own cursor and still has useful stuff in it.
921           * We want to move everything after and including the loop to before
922           * the cursor.
923           */
924          assert(resume_node->type == nir_cf_node_loop);
925          nir_cf_extract(&cf_list, nir_before_cf_node(resume_node),
926                                   nir_after_cf_list(child_list));
927       }
928    } else {
929       /* If we found the resume instruction in one of our blocks, grab
930        * everything after it in the entire list (not just the one block), and
931        * place it before the cursor instr.
932        */
933       nir_cf_extract(&cf_list, nir_after_instr(resume_instr),
934                                nir_after_cf_list(child_list));
935    }
936    nir_cf_reinsert(&cf_list, nir_before_instr(cursor));
937 
938    if (!resume_node) {
939       /* We want the resume to be the first "interesting" instruction */
940       nir_instr_remove(resume_instr);
941       nir_instr_insert(nir_before_cf_list(&impl->body), resume_instr);
942    }
943 
944    /* We've copied everything interesting out of this CF list to before the
945     * cursor.  Delete everything else.
946     */
947    if (child_list_contains_cursor) {
948       nir_cf_extract(&cf_list, nir_after_instr(cursor),
949                                nir_after_cf_list(child_list));
950    } else {
951       nir_cf_list_extract(&cf_list, child_list);
952    }
953    nir_cf_delete(&cf_list);
954 
955    return true;
956 }
957 
958 static nir_instr *
lower_resume(nir_shader * shader,int call_idx)959 lower_resume(nir_shader *shader, int call_idx)
960 {
961    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
962 
963    nir_instr *resume_instr = find_resume_instr(impl, call_idx);
964 
965    if (duplicate_loop_bodies(impl, resume_instr)) {
966       nir_validate_shader(shader, "after duplicate_loop_bodies in "
967                                   "brw_nir_lower_shader_calls");
968       /* If we duplicated the bodies of any loops, run regs_to_ssa to get rid
969        * of all those pesky registers we just added.
970        */
971       NIR_PASS_V(shader, nir_lower_regs_to_ssa);
972    }
973 
974    /* Re-index nir_ssa_def::index.  We don't care about actual liveness in
975     * this pass but, so we can use the same helpers as the spilling pass, we
976     * need to make sure that live_index is something sane.  It's used
977     * constantly for determining if an SSA value has been added since the
978     * start of the pass.
979     */
980    nir_index_ssa_defs(impl);
981 
982    void *mem_ctx = ralloc_context(shader);
983 
984    /* Used to track which things may have been assumed to be re-materialized
985     * by the spilling pass and which we shouldn't delete.
986     */
987    struct brw_bitset remat = bitset_create(mem_ctx, impl->ssa_alloc);
988 
989    /* Create a nop instruction to use as a cursor as we extract and re-insert
990     * stuff into the CFG.
991     */
992    nir_instr *cursor =
993       &nir_intrinsic_instr_create(shader, nir_intrinsic_nop)->instr;
994    nir_instr_insert(nir_before_cf_list(&impl->body), cursor);
995 
996    ASSERTED bool found =
997       flatten_resume_if_ladder(impl, cursor, &impl->body,
998                                true, resume_instr, &remat);
999    assert(found);
1000 
1001    ralloc_free(mem_ctx);
1002 
1003    nir_validate_shader(shader, "after flatten_resume_if_ladder in "
1004                                "brw_nir_lower_shader_calls");
1005 
1006    nir_metadata_preserve(impl, nir_metadata_none);
1007 
1008    return resume_instr;
1009 }
1010 
1011 static void
replace_resume_with_halt(nir_shader * shader,nir_instr * keep)1012 replace_resume_with_halt(nir_shader *shader, nir_instr *keep)
1013 {
1014    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1015 
1016    nir_builder b;
1017    nir_builder_init(&b, impl);
1018 
1019    nir_foreach_block_safe(block, impl) {
1020       nir_foreach_instr_safe(instr, block) {
1021          if (instr == keep)
1022             continue;
1023 
1024          if (instr->type != nir_instr_type_intrinsic)
1025             continue;
1026 
1027          nir_intrinsic_instr *resume = nir_instr_as_intrinsic(instr);
1028          if (resume->intrinsic != nir_intrinsic_rt_resume)
1029             continue;
1030 
1031          /* If this is some other resume, then we've kicked off a ray or
1032           * bindless thread and we don't want to go any further in this
1033           * shader.  Insert a halt so that NIR will delete any instructions
1034           * dominated by this call instruction including the scratch_load
1035           * instructions we inserted.
1036           */
1037          nir_cf_list cf_list;
1038          nir_cf_extract(&cf_list, nir_after_instr(&resume->instr),
1039                                   nir_after_block(block));
1040          nir_cf_delete(&cf_list);
1041          b.cursor = nir_instr_remove(&resume->instr);
1042          nir_jump(&b, nir_jump_halt);
1043          break;
1044       }
1045    }
1046 }
1047 
1048 /** Lower shader call instructions to split shaders.
1049  *
1050  * Shader calls can be split into an initial shader and a series of "resume"
1051  * shaders.   When the shader is first invoked, it is the initial shader which
1052  * is executed.  At any point in the initial shader or any one of the resume
1053  * shaders, a shader call operation may be performed.  The possible shader call
1054  * operations are:
1055  *
1056  *  - trace_ray
1057  *  - report_ray_intersection
1058  *  - execute_callable
1059  *
1060  * When a shader call operation is performed, we push all live values to the
1061  * stack,call rt_trace_ray/rt_execute_callable and then kill the shader. Once
1062  * the operation we invoked is complete, a callee shader will return execution
1063  * to the respective resume shader. The resume shader pops the contents off
1064  * the stack and picks up where the calling shader left off.
1065  *
1066  * Stack management is assumed to be done after this pass. Call
1067  * instructions and their resumes get annotated with stack information that
1068  * should be enough for the backend to implement proper stack management.
1069  */
1070 bool
nir_lower_shader_calls(nir_shader * shader,nir_address_format address_format,unsigned stack_alignment,nir_shader *** resume_shaders_out,uint32_t * num_resume_shaders_out,void * mem_ctx)1071 nir_lower_shader_calls(nir_shader *shader,
1072                        nir_address_format address_format,
1073                        unsigned stack_alignment,
1074                        nir_shader ***resume_shaders_out,
1075                        uint32_t *num_resume_shaders_out,
1076                        void *mem_ctx)
1077 {
1078    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1079 
1080    nir_builder b;
1081    nir_builder_init(&b, impl);
1082 
1083    int num_calls = 0;
1084    nir_foreach_block(block, impl) {
1085       nir_foreach_instr_safe(instr, block) {
1086          if (instr_is_shader_call(instr))
1087             num_calls++;
1088       }
1089    }
1090 
1091    if (num_calls == 0) {
1092       nir_shader_preserve_all_metadata(shader);
1093       *num_resume_shaders_out = 0;
1094       return false;
1095    }
1096 
1097    /* Some intrinsics not only can't be re-materialized but aren't preserved
1098     * when moving to the continuation shader.  We have to move them to the top
1099     * to ensure they get spilled as needed.
1100     */
1101    {
1102       bool progress = false;
1103       NIR_PASS(progress, shader, move_system_values_to_top);
1104       if (progress)
1105          NIR_PASS(progress, shader, nir_opt_cse);
1106    }
1107 
1108    NIR_PASS_V(shader, spill_ssa_defs_and_lower_shader_calls,
1109               num_calls, address_format, stack_alignment);
1110 
1111    nir_opt_remove_phis(shader);
1112 
1113    /* Make N copies of our shader */
1114    nir_shader **resume_shaders = ralloc_array(mem_ctx, nir_shader *, num_calls);
1115    for (unsigned i = 0; i < num_calls; i++)
1116       resume_shaders[i] = nir_shader_clone(mem_ctx, shader);
1117 
1118    replace_resume_with_halt(shader, NULL);
1119    for (unsigned i = 0; i < num_calls; i++) {
1120       nir_instr *resume_instr = lower_resume(resume_shaders[i], i);
1121       replace_resume_with_halt(resume_shaders[i], resume_instr);
1122       nir_opt_remove_phis(resume_shaders[i]);
1123    }
1124 
1125    *resume_shaders_out = resume_shaders;
1126    *num_resume_shaders_out = num_calls;
1127 
1128    return true;
1129 }
1130