/*
 * Copyright (c) 2021 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "brw_nir_rt.h"
#include "brw_nir_rt_builder.h"

#include "nir_deref.h"

#include "util/macros.h"

struct lowering_state {
   const struct intel_device_info *devinfo;

   struct hash_table *queries;
   uint32_t n_queries;

   struct brw_nir_rt_globals_defs globals;
   nir_ssa_def *rq_globals;

   uint32_t state_scratch_base_offset;
};

struct brw_ray_query {
   nir_variable *opaque_var;
   uint32_t id;
};

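/* Each ray query keeps a single dword of per-query state in scratch memory:
 * the trace-ray control value in bits 31:2 and the current BVH level in
 * bits 1:0 (see update_trace_ctrl_level()).
 */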
#define SIZEOF_QUERY_STATE (sizeof(uint32_t))

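/* With a single ray query, the query can operate directly on the HW
 * ray-query stack. With more than one query, each query keeps a shadow copy
 * in memory that is filled into the HW stack before a trace and spilled back
 * afterwards (see fill_query()/spill_query()).
 */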
static bool
need_spill_fill(struct lowering_state *state)
{
   return state->n_queries > 1;
}

/**
 * This pass lowers the opaque RayQuery structures coming from SPIR-V. Each
 * query (including every element of arrays of queries) gets an id and is
 * lowered to:
 *
 *  - a shadow copy of the HW ray-query stack in global memory, used when the
 *    shader holds more than one query at a time, and
 *
 *  - a dword of scratch holding the packed trace control / BVH level state
 *    used to resume traversal on each nir_intrinsic_rq_proceed.
 */

static bool
maybe_create_brw_var(nir_instr *instr, struct lowering_state *state)
{
   if (instr->type != nir_instr_type_deref)
      return false;

   nir_deref_instr *deref = nir_instr_as_deref(instr);
   if (deref->deref_type != nir_deref_type_var &&
       deref->deref_type != nir_deref_type_array)
      return false;

   nir_variable *opaque_var = nir_deref_instr_get_variable(deref);
   if (!opaque_var || !opaque_var->data.ray_query)
      return false;

   struct hash_entry *entry = _mesa_hash_table_search(state->queries, opaque_var);
   if (entry)
      return false;

   struct brw_ray_query *rq = rzalloc(state->queries, struct brw_ray_query);
   rq->opaque_var = opaque_var;
   rq->id = state->n_queries;

   _mesa_hash_table_insert(state->queries, opaque_var, rq);

   unsigned aoa_size = glsl_get_aoa_size(opaque_var->type);
   state->n_queries += MAX2(1, aoa_size);

   return true;
}

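/* Computes the scratch offset of the packed ctrl/level state for the query
 * referenced by the deref (returned through out_state_offset) and, when
 * spill/fill is needed, the address of the query's shadow copy in global
 * memory. Returns NULL when the query can operate directly on the HW stack.
 */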
static nir_ssa_def *
get_ray_query_shadow_addr(nir_builder *b,
                          nir_deref_instr *deref,
                          struct lowering_state *state,
                          nir_ssa_def **out_state_offset)
{
   nir_deref_path path;
   nir_deref_path_init(&path, deref, NULL);
   assert(path.path[0]->deref_type == nir_deref_type_var);

   nir_variable *opaque_var = nir_deref_instr_get_variable(path.path[0]);
   struct hash_entry *entry = _mesa_hash_table_search(state->queries, opaque_var);
   assert(entry);

   struct brw_ray_query *rq = entry->data;

   /* Base address in shadow memory of the storage associated with this ray
    * query variable.
    */
   nir_ssa_def *base_addr =
      nir_iadd_imm(b, state->globals.resume_sbt_addr,
                   brw_rt_ray_queries_shadow_stack_size(state->devinfo) * rq->id);

   bool spill_fill = need_spill_fill(state);
   *out_state_offset = nir_imm_int(b, state->state_scratch_base_offset +
                                      SIZEOF_QUERY_STATE * rq->id);

   if (!spill_fill)
      return NULL;

   /* Just emit code and let constant-folding go to town */
   nir_deref_instr **p = &path.path[1];
   for (; *p; p++) {
      if ((*p)->deref_type == nir_deref_type_array) {
         nir_ssa_def *index = nir_ssa_for_src(b, (*p)->arr.index, 1);

         /* Stride in scratch state for one element at this array level */
         uint32_t local_state_offset = SIZEOF_QUERY_STATE *
                                       MAX2(1, glsl_get_aoa_size((*p)->type));
         *out_state_offset =
            nir_iadd(b, *out_state_offset,
                        nir_imul_imm(b, index, local_state_offset));

         /* Stride in shadow memory for one element at this array level */
         uint64_t size = MAX2(1, glsl_get_aoa_size((*p)->type)) *
            brw_rt_ray_queries_shadow_stack_size(state->devinfo);

         nir_ssa_def *mul = nir_amul_imm(b, nir_i2i64(b, index), size);

         base_addr = nir_iadd(b, base_addr, mul);
      } else {
         unreachable("Unsupported deref type");
      }
   }

   nir_deref_path_finish(&path);

   /* Add the lane offset to the shadow memory address */
   nir_ssa_def *lane_offset =
      nir_imul_imm(
         b,
         nir_iadd(
            b,
            nir_imul(
               b,
               brw_load_btd_dss_id(b),
               brw_nir_rt_load_num_simd_lanes_per_dss(b, state->devinfo)),
            brw_nir_rt_sync_stack_id(b)),
         BRW_RT_SIZEOF_SHADOW_RAY_QUERY);

   return nir_iadd(b, base_addr, nir_i2i64(b, lane_offset));
}

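/* Reads, and optionally rewrites, the dword of scratch holding a query's
 * packed state: the trace-ray control value in bits 31:2 and the BVH level
 * in bits 1:0.
 */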
static void
update_trace_ctrl_level(nir_builder *b,
                        nir_ssa_def *state_scratch_offset,
                        nir_ssa_def **out_old_ctrl,
                        nir_ssa_def **out_old_level,
                        nir_ssa_def *new_ctrl,
                        nir_ssa_def *new_level)
{
   nir_ssa_def *old_value = nir_load_scratch(b, 1, 32, state_scratch_offset, 4);
   nir_ssa_def *old_ctrl = nir_ishr_imm(b, old_value, 2);
   nir_ssa_def *old_level = nir_iand_imm(b, old_value, 0x3);

   if (out_old_ctrl)
      *out_old_ctrl = old_ctrl;
   if (out_old_level)
      *out_old_level = old_level;

   if (new_ctrl || new_level) {
      if (!new_ctrl)
         new_ctrl = old_ctrl;
      if (!new_level)
         new_level = old_level;

      nir_ssa_def *new_value = nir_ior(b, nir_ishl_imm(b, new_ctrl, 2), new_level);
      nir_store_scratch(b, new_value, state_scratch_offset, 4, 0x1);
   }
}

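/* Copies a query's shadow copy into the HW ray-query stack before handing
 * the query over to the HW.
 */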
static void
fill_query(nir_builder *b,
           nir_ssa_def *hw_stack_addr,
           nir_ssa_def *shadow_stack_addr,
           nir_ssa_def *ctrl)
{
   brw_nir_memcpy_global(b, hw_stack_addr, 64, shadow_stack_addr, 64,
                         BRW_RT_SIZEOF_RAY_QUERY);
}

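/* Copies the HW ray-query stack back into the query's shadow copy once the
 * HW is done with it.
 */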
static void
spill_query(nir_builder *b,
            nir_ssa_def *hw_stack_addr,
            nir_ssa_def *shadow_stack_addr)
{
   brw_nir_memcpy_global(b, shadow_stack_addr, 64, hw_stack_addr, 64,
                         BRW_RT_SIZEOF_RAY_QUERY);
}


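/* Lowers a single ray-query intrinsic. All operations work on stack_addr,
 * which is either the query's shadow copy (when several queries require
 * spill/fill) or the HW ray-query stack itself.
 */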
static void
lower_ray_query_intrinsic(nir_builder *b,
                          nir_intrinsic_instr *intrin,
                          struct lowering_state *state)
{
   nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);

   b->cursor = nir_instr_remove(&intrin->instr);

   nir_ssa_def *ctrl_level_addr;
   nir_ssa_def *shadow_stack_addr =
      get_ray_query_shadow_addr(b, deref, state, &ctrl_level_addr);
   nir_ssa_def *hw_stack_addr =
      brw_nir_rt_sync_stack_addr(b, state->globals.base_mem_addr, state->devinfo);
   nir_ssa_def *stack_addr = shadow_stack_addr ? shadow_stack_addr : hw_stack_addr;

   switch (intrin->intrinsic) {
   case nir_intrinsic_rq_initialize: {
      nir_ssa_def *as_addr = intrin->src[1].ssa;
      nir_ssa_def *ray_flags = intrin->src[2].ssa;
      /* From the SPIR-V spec:
       *
       *    "Only the 8 least-significant bits of Cull Mask are used by
       *    this instruction - other bits are ignored.
       *
       *    Only the 16 least-significant bits of Miss Index are used by
       *    this instruction - other bits are ignored."
       */
      nir_ssa_def *cull_mask = nir_iand_imm(b, intrin->src[3].ssa, 0xff);
      nir_ssa_def *ray_orig = intrin->src[4].ssa;
      nir_ssa_def *ray_t_min = intrin->src[5].ssa;
      nir_ssa_def *ray_dir = intrin->src[6].ssa;
      nir_ssa_def *ray_t_max = intrin->src[7].ssa;

      nir_ssa_def *root_node_ptr =
         brw_nir_rt_acceleration_structure_to_root_node(b, as_addr);

      struct brw_nir_rt_mem_ray_defs ray_defs = {
         .root_node_ptr = root_node_ptr,
         .ray_flags = nir_u2u16(b, ray_flags),
         .ray_mask = cull_mask,
         .orig = ray_orig,
         .t_near = ray_t_min,
         .dir = ray_dir,
         .t_far = ray_t_max,
      };

      nir_ssa_def *ray_addr =
         brw_nir_rt_mem_ray_addr(b, stack_addr, BRW_RT_BVH_LEVEL_WORLD);

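      /* Initialize the query mark and both the potential and committed hit
       * structures (seeded with ray_t_max), then store the ray at the world
       * BVH level.
       */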
      brw_nir_rt_query_mark_init(b, stack_addr);
      brw_nir_rt_init_mem_hit_at_addr(b, stack_addr, false, ray_t_max);
      brw_nir_rt_init_mem_hit_at_addr(b, stack_addr, true, ray_t_max);
      brw_nir_rt_store_mem_ray_query_at_addr(b, ray_addr, &ray_defs);

      update_trace_ctrl_level(b, ctrl_level_addr,
                              NULL, NULL,
                              nir_imm_int(b, GEN_RT_TRACE_RAY_INITAL),
                              nir_imm_int(b, BRW_RT_BVH_LEVEL_WORLD));
      break;
   }

   case nir_intrinsic_rq_proceed: {
      nir_ssa_def *not_done =
         nir_inot(b, brw_nir_rt_query_done(b, stack_addr));
      nir_ssa_def *not_done_then, *not_done_else;

      nir_push_if(b, not_done);
      {
         nir_ssa_def *ctrl, *level;
         update_trace_ctrl_level(b, ctrl_level_addr,
                                 &ctrl, &level,
                                 NULL,
                                 NULL);

         /* Mark the query as done because we're handing it over to the HW
          * for processing. If the HW makes any progress, it will write back
          * some data and, as a side effect, clear the "done" bit. If no
          * progress is made, the HW does not write anything back and we can
          * use this bit to detect that.
          */
         brw_nir_rt_query_mark_done(b, stack_addr);

         if (shadow_stack_addr)
            fill_query(b, hw_stack_addr, shadow_stack_addr, ctrl);

         nir_trace_ray_intel(b, state->rq_globals, level, ctrl, .synchronous = true);

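         /* Read back the potential hit to see whether the HW finished the
          * traversal (done bit) and which BVH level it stopped at.
          */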
         struct brw_nir_rt_mem_hit_defs hit_in = {};
         brw_nir_rt_load_mem_hit_from_addr(b, &hit_in, hw_stack_addr, false);

         if (shadow_stack_addr)
            spill_query(b, hw_stack_addr, shadow_stack_addr);

         update_trace_ctrl_level(b, ctrl_level_addr,
                                 NULL, NULL,
                                 nir_imm_int(b, GEN_RT_TRACE_RAY_CONTINUE),
                                 hit_in.bvh_level);

         not_done_then = nir_inot(b, hit_in.done);
      }
      nir_push_else(b, NULL);
      {
         not_done_else = nir_imm_false(b);
      }
      nir_pop_if(b, NULL);
      not_done = nir_if_phi(b, not_done_then, not_done_else);
      nir_ssa_def_rewrite_uses(&intrin->dest.ssa, not_done);
      break;
   }

   case nir_intrinsic_rq_confirm_intersection: {
      brw_nir_memcpy_global(b,
                            brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, true), 16,
                            brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, false), 16,
                            BRW_RT_SIZEOF_HIT_INFO);
      update_trace_ctrl_level(b, ctrl_level_addr,
                              NULL, NULL,
                              nir_imm_int(b, GEN_RT_TRACE_RAY_COMMIT),
                              nir_imm_int(b, BRW_RT_BVH_LEVEL_OBJECT));
      break;
   }

   case nir_intrinsic_rq_generate_intersection: {
      brw_nir_rt_generate_hit_addr(b, stack_addr, intrin->src[1].ssa);
      update_trace_ctrl_level(b, ctrl_level_addr,
                              NULL, NULL,
                              nir_imm_int(b, GEN_RT_TRACE_RAY_COMMIT),
                              nir_imm_int(b, BRW_RT_BVH_LEVEL_OBJECT));
      break;
   }

   case nir_intrinsic_rq_terminate: {
      brw_nir_rt_query_mark_done(b, stack_addr);
      break;
   }

   case nir_intrinsic_rq_load: {
      const bool committed = nir_src_as_bool(intrin->src[1]);

      struct brw_nir_rt_mem_ray_defs world_ray_in = {};
      struct brw_nir_rt_mem_ray_defs object_ray_in = {};
      struct brw_nir_rt_mem_hit_defs hit_in = {};
      brw_nir_rt_load_mem_ray_from_addr(b, &world_ray_in, stack_addr,
                                        BRW_RT_BVH_LEVEL_WORLD);
      brw_nir_rt_load_mem_ray_from_addr(b, &object_ray_in, stack_addr,
                                        BRW_RT_BVH_LEVEL_OBJECT);
      brw_nir_rt_load_mem_hit_from_addr(b, &hit_in, stack_addr, committed);

      nir_ssa_def *sysval = NULL;
      switch (nir_intrinsic_base(intrin)) {
      case nir_ray_query_value_intersection_type:
         if (committed) {
            /* Values we want to generate:
             *
             * RayQueryCommittedIntersectionNoneEXT = 0U        <= hit_in.valid == false
             * RayQueryCommittedIntersectionTriangleEXT = 1U    <= hit_in.leaf_type == BRW_RT_BVH_NODE_TYPE_QUAD (4)
             * RayQueryCommittedIntersectionGeneratedEXT = 2U   <= hit_in.leaf_type == BRW_RT_BVH_NODE_TYPE_PROCEDURAL (3)
             */
            sysval =
               nir_bcsel(b, nir_ieq(b, hit_in.leaf_type, nir_imm_int(b, 4)),
                         nir_imm_int(b, 1), nir_imm_int(b, 2));
            sysval =
               nir_bcsel(b, hit_in.valid,
                         sysval, nir_imm_int(b, 0));
         } else {
            /* 0 -> triangle, 1 -> AABB */
            sysval =
               nir_b2i32(b,
                         nir_ieq(b, hit_in.leaf_type,
                                    nir_imm_int(b, BRW_RT_BVH_NODE_TYPE_PROCEDURAL)));
         }
         break;

      case nir_ray_query_value_intersection_t:
         sysval = hit_in.t;
         break;

      case nir_ray_query_value_intersection_instance_custom_index: {
         struct brw_nir_rt_bvh_instance_leaf_defs leaf;
         brw_nir_rt_load_bvh_instance_leaf(b, &leaf, hit_in.inst_leaf_ptr);
         sysval = leaf.instance_id;
         break;
      }

      case nir_ray_query_value_intersection_instance_id: {
         struct brw_nir_rt_bvh_instance_leaf_defs leaf;
         brw_nir_rt_load_bvh_instance_leaf(b, &leaf, hit_in.inst_leaf_ptr);
         sysval = leaf.instance_index;
         break;
      }

      case nir_ray_query_value_intersection_instance_sbt_index: {
         struct brw_nir_rt_bvh_instance_leaf_defs leaf;
         brw_nir_rt_load_bvh_instance_leaf(b, &leaf, hit_in.inst_leaf_ptr);
         sysval = leaf.contribution_to_hit_group_index;
         break;
      }

      case nir_ray_query_value_intersection_geometry_index: {
         nir_ssa_def *geometry_index_dw =
            nir_load_global(b, nir_iadd_imm(b, hit_in.prim_leaf_ptr, 4), 4,
                            1, 32);
         sysval = nir_iand_imm(b, geometry_index_dw, BITFIELD_MASK(29));
         break;
      }

      case nir_ray_query_value_intersection_primitive_index:
         sysval = brw_nir_rt_load_primitive_id_from_hit(b, NULL /* is_procedural */, &hit_in);
         break;

      case nir_ray_query_value_intersection_barycentrics:
         sysval = hit_in.tri_bary;
         break;

      case nir_ray_query_value_intersection_front_face:
         sysval = hit_in.front_face;
         break;

      case nir_ray_query_value_intersection_object_ray_direction:
         sysval = world_ray_in.dir;
         break;

      case nir_ray_query_value_intersection_object_ray_origin:
         sysval = world_ray_in.orig;
         break;

      case nir_ray_query_value_intersection_object_to_world: {
         struct brw_nir_rt_bvh_instance_leaf_defs leaf;
         brw_nir_rt_load_bvh_instance_leaf(b, &leaf, hit_in.inst_leaf_ptr);
         sysval = leaf.object_to_world[nir_intrinsic_column(intrin)];
         break;
      }

      case nir_ray_query_value_intersection_world_to_object: {
         struct brw_nir_rt_bvh_instance_leaf_defs leaf;
         brw_nir_rt_load_bvh_instance_leaf(b, &leaf, hit_in.inst_leaf_ptr);
         sysval = leaf.world_to_object[nir_intrinsic_column(intrin)];
         break;
      }

      case nir_ray_query_value_intersection_candidate_aabb_opaque:
         sysval = hit_in.front_face;
         break;

      case nir_ray_query_value_tmin:
         sysval = world_ray_in.t_near;
         break;

      case nir_ray_query_value_flags:
         sysval = nir_u2u32(b, world_ray_in.ray_flags);
         break;

      case nir_ray_query_value_world_ray_direction:
         sysval = world_ray_in.dir;
         break;

      case nir_ray_query_value_world_ray_origin:
         sysval = world_ray_in.orig;
         break;

      default:
         unreachable("Invalid ray query");
      }

      assert(sysval);
      nir_ssa_def_rewrite_uses(&intrin->dest.ssa, sysval);
      break;
   }

   default:
      unreachable("Invalid intrinsic");
   }
}

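/* Lowers all the ray-query intrinsics of a function. The ray-query globals
 * are loaded once at the top of the function and reused by every lowered
 * intrinsic.
 */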
static void
lower_ray_query_impl(nir_function_impl *impl, struct lowering_state *state)
{
   nir_builder _b, *b = &_b;
   nir_builder_init(&_b, impl);

   b->cursor = nir_before_block(nir_start_block(b->impl));

   state->rq_globals = nir_load_ray_query_global_intel(b);

   brw_nir_rt_load_globals_addr(b, &state->globals, state->rq_globals);

   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
         if (intrin->intrinsic != nir_intrinsic_rq_initialize &&
             intrin->intrinsic != nir_intrinsic_rq_terminate &&
             intrin->intrinsic != nir_intrinsic_rq_proceed &&
             intrin->intrinsic != nir_intrinsic_rq_generate_intersection &&
             intrin->intrinsic != nir_intrinsic_rq_confirm_intersection &&
             intrin->intrinsic != nir_intrinsic_rq_load)
            continue;

         lower_ray_query_intrinsic(b, intrin, state);
      }
   }

   nir_metadata_preserve(impl, nir_metadata_none);
}

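/* Entry point of the pass: gather the ray-query variables used by the
 * shader, reserve a dword of scratch per query for the packed ctrl/level
 * state, then lower every ray-query intrinsic of the entrypoint.
 */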
bool
brw_nir_lower_ray_queries(nir_shader *shader,
                          const struct intel_device_info *devinfo)
{
   struct lowering_state state = {
      .devinfo = devinfo,
      .queries = _mesa_pointer_hash_table_create(NULL),
   };

   assert(exec_list_length(&shader->functions) == 1);

   /* Find query variables */
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr(instr, block)
         maybe_create_brw_var(instr, &state);
   }

   bool progress = state.n_queries > 0;

   if (progress) {
      state.state_scratch_base_offset = shader->scratch_size;
      shader->scratch_size += SIZEOF_QUERY_STATE * state.n_queries;

      lower_ray_query_impl(impl, &state);

      nir_remove_dead_derefs(shader);
      nir_remove_dead_variables(shader,
                                nir_var_shader_temp | nir_var_function_temp,
                                NULL);

      nir_metadata_preserve(impl, nir_metadata_none);
   }

   ralloc_free(state.queries);

   return progress;
}