/*
 * Copyright (c) 2021 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "brw_nir_rt.h"
#include "brw_nir_rt_builder.h"

#include "nir_deref.h"

#include "util/macros.h"

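/* Per-shader lowering state: the table of ray query variables found in the
 * shader, the loaded RT globals, and the base offset of the scratch space
 * holding each query's packed ctrl/level dword.
 */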
struct lowering_state {
   const struct intel_device_info *devinfo;

   struct hash_table *queries;
   uint32_t n_queries;

   struct brw_nir_rt_globals_defs globals;
   nir_ssa_def *rq_globals;

   uint32_t state_scratch_base_offset;
};

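/* One entry per ray query variable, keyed by the original opaque variable. */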
struct brw_ray_query {
   nir_variable *opaque_var;
   uint32_t id;
};

#define SIZEOF_QUERY_STATE (sizeof(uint32_t))

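/* When a shader uses more than one ray query, they cannot all live in the
 * per-lane HW synchronous stack at the same time. Each query then gets a
 * shadow copy in memory that is filled into / spilled out of the HW stack
 * around each synchronous trace.
 */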
static bool
need_spill_fill(struct lowering_state *state)
{
   return state->n_queries > 1;
}

/**
 * This pass converts opaque RayQuery structures from SPIR-V into a vec3 where
 * the first 2 elements store a global address for the query and the third
 * element is a counter incremented for each executed
 * nir_intrinsic_rq_proceed.
 */

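/* Register the ray query variable behind a deref (if any) in the state hash
 * table and assign it a query ID. Arrays of queries reserve one ID per
 * element.
 */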
static bool
maybe_create_brw_var(nir_instr *instr, struct lowering_state *state)
{
   if (instr->type != nir_instr_type_deref)
      return false;

   nir_deref_instr *deref = nir_instr_as_deref(instr);
   if (deref->deref_type != nir_deref_type_var &&
       deref->deref_type != nir_deref_type_array)
      return false;

   nir_variable *opaque_var = nir_deref_instr_get_variable(deref);
   if (!opaque_var || !opaque_var->data.ray_query)
      return false;

   struct hash_entry *entry = _mesa_hash_table_search(state->queries, opaque_var);
   if (entry)
      return false;

   struct brw_ray_query *rq = rzalloc(state->queries, struct brw_ray_query);
   rq->opaque_var = opaque_var;
   rq->id = state->n_queries;

   _mesa_hash_table_insert(state->queries, opaque_var, rq);

   unsigned aoa_size = glsl_get_aoa_size(opaque_var->type);
   state->n_queries += MAX2(1, aoa_size);

   return true;
}

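/* Compute the scratch offset of the query's packed ctrl/level state and,
 * when spill/fill is required, the 64-bit global address of its per-lane
 * shadow copy. Returns NULL for the shadow address when the HW stack can be
 * used directly.
 */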
static nir_ssa_def *
get_ray_query_shadow_addr(nir_builder *b,
                          nir_deref_instr *deref,
                          struct lowering_state *state,
                          nir_ssa_def **out_state_offset)
{
   nir_deref_path path;
   nir_deref_path_init(&path, deref, NULL);
   assert(path.path[0]->deref_type == nir_deref_type_var);

   nir_variable *opaque_var = nir_deref_instr_get_variable(path.path[0]);
   struct hash_entry *entry = _mesa_hash_table_search(state->queries, opaque_var);
   assert(entry);

   struct brw_ray_query *rq = entry->data;

   /* Base address in shadow memory of this ray query variable. */
   nir_ssa_def *base_addr =
      nir_iadd_imm(b, state->globals.resume_sbt_addr,
                   brw_rt_ray_queries_shadow_stack_size(state->devinfo) * rq->id);

   bool spill_fill = need_spill_fill(state);
   *out_state_offset = nir_imm_int(b, state->state_scratch_base_offset +
                                      SIZEOF_QUERY_STATE * rq->id);

   if (!spill_fill)
      return NULL;

   /* Just emit code and let constant-folding go to town */
   nir_deref_instr **p = &path.path[1];
   for (; *p; p++) {
      if ((*p)->deref_type == nir_deref_type_array) {
         nir_ssa_def *index = nir_ssa_for_src(b, (*p)->arr.index, 1);

         /* Advance the scratch state offset by the array index times the
          * state size of one element (accounting for nested arrays).
          */
         uint32_t local_state_offset = SIZEOF_QUERY_STATE *
            MAX2(1, glsl_get_aoa_size((*p)->type));
         *out_state_offset =
            nir_iadd(b, *out_state_offset,
                     nir_imul_imm(b, index, local_state_offset));

         /* Advance the shadow memory address by the array index times the
          * shadow stack size of one element (accounting for nested arrays).
          */
         uint64_t size = MAX2(1, glsl_get_aoa_size((*p)->type)) *
            brw_rt_ray_queries_shadow_stack_size(state->devinfo);

         nir_ssa_def *mul = nir_amul_imm(b, nir_i2i64(b, index), size);

         base_addr = nir_iadd(b, base_addr, mul);
      } else {
         unreachable("Unsupported deref type");
      }
   }

   nir_deref_path_finish(&path);

   /* Add the lane offset to the shadow memory address */
   nir_ssa_def *lane_offset =
      nir_imul_imm(
         b,
         nir_iadd(
            b,
            nir_imul(
               b,
               brw_load_btd_dss_id(b),
               brw_nir_rt_load_num_simd_lanes_per_dss(b, state->devinfo)),
            brw_nir_rt_sync_stack_id(b)),
         BRW_RT_SIZEOF_SHADOW_RAY_QUERY);

   return nir_iadd(b, base_addr, nir_i2i64(b, lane_offset));
}

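/* The trace control and BVH level of a query are packed together in a single
 * scratch dword (level in bits 0-1, ctrl in the bits above). Read the current
 * values and optionally write new ones back.
 */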
static void
update_trace_ctrl_level(nir_builder *b,
                        nir_ssa_def *state_scratch_offset,
                        nir_ssa_def **out_old_ctrl,
                        nir_ssa_def **out_old_level,
                        nir_ssa_def *new_ctrl,
                        nir_ssa_def *new_level)
{
   nir_ssa_def *old_value = nir_load_scratch(b, 1, 32, state_scratch_offset, 4);
   nir_ssa_def *old_ctrl = nir_ishr_imm(b, old_value, 2);
   nir_ssa_def *old_level = nir_iand_imm(b, old_value, 0x3);

   if (out_old_ctrl)
      *out_old_ctrl = old_ctrl;
   if (out_old_level)
      *out_old_level = old_level;

   if (new_ctrl || new_level) {
      if (!new_ctrl)
         new_ctrl = old_ctrl;
      if (!new_level)
         new_level = old_level;

      nir_ssa_def *new_value = nir_ior(b, nir_ishl_imm(b, new_ctrl, 2), new_level);
      nir_store_scratch(b, new_value, state_scratch_offset, 4, 0x1);
   }
}

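/* Copy a ray query from its shadow copy in memory into the HW synchronous
 * stack before handing it to the RT hardware.
 */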
static void
fill_query(nir_builder *b,
           nir_ssa_def *hw_stack_addr,
           nir_ssa_def *shadow_stack_addr,
           nir_ssa_def *ctrl)
{
   brw_nir_memcpy_global(b, hw_stack_addr, 64, shadow_stack_addr, 64,
                         BRW_RT_SIZEOF_RAY_QUERY);
}

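/* Copy a ray query from the HW synchronous stack back into its shadow copy
 * in memory after the RT hardware has updated it.
 */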
static void
spill_query(nir_builder *b,
            nir_ssa_def *hw_stack_addr,
            nir_ssa_def *shadow_stack_addr)
{
   brw_nir_memcpy_global(b, shadow_stack_addr, 64, hw_stack_addr, 64,
                         BRW_RT_SIZEOF_RAY_QUERY);
}

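/* Lower a single rq_* intrinsic into loads/stores on the HW ray query stack
 * (or the query's shadow copy when several queries are live) plus updates of
 * the packed ctrl/level state kept in scratch.
 */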
static void
lower_ray_query_intrinsic(nir_builder *b,
                          nir_intrinsic_instr *intrin,
                          struct lowering_state *state)
{
   nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);

   b->cursor = nir_instr_remove(&intrin->instr);

   nir_ssa_def *ctrl_level_addr;
   nir_ssa_def *shadow_stack_addr =
      get_ray_query_shadow_addr(b, deref, state, &ctrl_level_addr);
   nir_ssa_def *hw_stack_addr =
      brw_nir_rt_sync_stack_addr(b, state->globals.base_mem_addr, state->devinfo);
   nir_ssa_def *stack_addr = shadow_stack_addr ? shadow_stack_addr : hw_stack_addr;

   switch (intrin->intrinsic) {
   case nir_intrinsic_rq_initialize: {
      nir_ssa_def *as_addr = intrin->src[1].ssa;
      nir_ssa_def *ray_flags = intrin->src[2].ssa;
      /* From the SPIR-V spec:
       *
       *    "Only the 8 least-significant bits of Cull Mask are used by
       *    this instruction - other bits are ignored.
       *
       *    Only the 16 least-significant bits of Miss Index are used by
       *    this instruction - other bits are ignored."
       */
      nir_ssa_def *cull_mask = nir_iand_imm(b, intrin->src[3].ssa, 0xff);
      nir_ssa_def *ray_orig = intrin->src[4].ssa;
      nir_ssa_def *ray_t_min = intrin->src[5].ssa;
      nir_ssa_def *ray_dir = intrin->src[6].ssa;
      nir_ssa_def *ray_t_max = intrin->src[7].ssa;

      nir_ssa_def *root_node_ptr =
         brw_nir_rt_acceleration_structure_to_root_node(b, as_addr);

      struct brw_nir_rt_mem_ray_defs ray_defs = {
         .root_node_ptr = root_node_ptr,
         .ray_flags = nir_u2u16(b, ray_flags),
         .ray_mask = cull_mask,
         .orig = ray_orig,
         .t_near = ray_t_min,
         .dir = ray_dir,
         .t_far = ray_t_max,
      };

      nir_ssa_def *ray_addr =
         brw_nir_rt_mem_ray_addr(b, stack_addr, BRW_RT_BVH_LEVEL_WORLD);

      brw_nir_rt_query_mark_init(b, stack_addr);
      brw_nir_rt_init_mem_hit_at_addr(b, stack_addr, false, ray_t_max);
      brw_nir_rt_init_mem_hit_at_addr(b, stack_addr, true, ray_t_max);
      brw_nir_rt_store_mem_ray_query_at_addr(b, ray_addr, &ray_defs);

      update_trace_ctrl_level(b, ctrl_level_addr,
                              NULL, NULL,
                              nir_imm_int(b, GEN_RT_TRACE_RAY_INITAL),
                              nir_imm_int(b, BRW_RT_BVH_LEVEL_WORLD));
      break;
   }

   case nir_intrinsic_rq_proceed: {
      nir_ssa_def *not_done =
         nir_inot(b, brw_nir_rt_query_done(b, stack_addr));
      nir_ssa_def *not_done_then, *not_done_else;

      nir_push_if(b, not_done);
      {
         nir_ssa_def *ctrl, *level;
         update_trace_ctrl_level(b, ctrl_level_addr,
                                 &ctrl, &level,
                                 NULL,
                                 NULL);

         /* Mark the query as done because we are handing it over to the HW
          * for processing. If the HW makes any progress, it will write back
          * some data and, as a side effect, clear the "done" bit. If no
          * progress is made, the HW does not write anything back and we can
          * use this bit to detect that.
          */
         brw_nir_rt_query_mark_done(b, stack_addr);

         if (shadow_stack_addr)
            fill_query(b, hw_stack_addr, shadow_stack_addr, ctrl);

         nir_trace_ray_intel(b, state->rq_globals, level, ctrl, .synchronous = true);

         struct brw_nir_rt_mem_hit_defs hit_in = {};
         brw_nir_rt_load_mem_hit_from_addr(b, &hit_in, hw_stack_addr, false);

         if (shadow_stack_addr)
            spill_query(b, hw_stack_addr, shadow_stack_addr);

         update_trace_ctrl_level(b, ctrl_level_addr,
                                 NULL, NULL,
                                 nir_imm_int(b, GEN_RT_TRACE_RAY_CONTINUE),
                                 hit_in.bvh_level);

         not_done_then = nir_inot(b, hit_in.done);
      }
      nir_push_else(b, NULL);
      {
         not_done_else = nir_imm_false(b);
      }
      nir_pop_if(b, NULL);
      not_done = nir_if_phi(b, not_done_then, not_done_else);
      nir_ssa_def_rewrite_uses(&intrin->dest.ssa, not_done);
      break;
   }

   case nir_intrinsic_rq_confirm_intersection: {
      brw_nir_memcpy_global(b,
                            brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, true), 16,
                            brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, false), 16,
                            BRW_RT_SIZEOF_HIT_INFO);
      update_trace_ctrl_level(b, ctrl_level_addr,
                              NULL, NULL,
                              nir_imm_int(b, GEN_RT_TRACE_RAY_COMMIT),
                              nir_imm_int(b, BRW_RT_BVH_LEVEL_OBJECT));
      break;
   }

   case nir_intrinsic_rq_generate_intersection: {
      brw_nir_rt_generate_hit_addr(b, stack_addr, intrin->src[1].ssa);
      update_trace_ctrl_level(b, ctrl_level_addr,
                              NULL, NULL,
                              nir_imm_int(b, GEN_RT_TRACE_RAY_COMMIT),
                              nir_imm_int(b, BRW_RT_BVH_LEVEL_OBJECT));
      break;
   }

   case nir_intrinsic_rq_terminate: {
      brw_nir_rt_query_mark_done(b, stack_addr);
      break;
   }

   case nir_intrinsic_rq_load: {
      const bool committed = nir_src_as_bool(intrin->src[1]);

      struct brw_nir_rt_mem_ray_defs world_ray_in = {};
      struct brw_nir_rt_mem_ray_defs object_ray_in = {};
      struct brw_nir_rt_mem_hit_defs hit_in = {};
      brw_nir_rt_load_mem_ray_from_addr(b, &world_ray_in, stack_addr,
                                        BRW_RT_BVH_LEVEL_WORLD);
      brw_nir_rt_load_mem_ray_from_addr(b, &object_ray_in, stack_addr,
                                        BRW_RT_BVH_LEVEL_OBJECT);
      brw_nir_rt_load_mem_hit_from_addr(b, &hit_in, stack_addr, committed);

      nir_ssa_def *sysval = NULL;
      switch (nir_intrinsic_base(intrin)) {
      case nir_ray_query_value_intersection_type:
         if (committed) {
            /* Values we want to generate:
             *
             *    RayQueryCommittedIntersectionNoneEXT = 0U      <= hit_in.valid == false
             *    RayQueryCommittedIntersectionTriangleEXT = 1U  <= hit_in.leaf_type == BRW_RT_BVH_NODE_TYPE_QUAD (4)
             *    RayQueryCommittedIntersectionGeneratedEXT = 2U <= hit_in.leaf_type == BRW_RT_BVH_NODE_TYPE_PROCEDURAL (3)
             */
            sysval =
               nir_bcsel(b, nir_ieq(b, hit_in.leaf_type, nir_imm_int(b, 4)),
                         nir_imm_int(b, 1), nir_imm_int(b, 2));
            sysval =
               nir_bcsel(b, hit_in.valid,
                         sysval, nir_imm_int(b, 0));
         } else {
            /* 0 -> triangle, 1 -> AABB */
            sysval =
               nir_b2i32(b,
                         nir_ieq(b, hit_in.leaf_type,
                                 nir_imm_int(b, BRW_RT_BVH_NODE_TYPE_PROCEDURAL)));
         }
         break;

      case nir_ray_query_value_intersection_t:
         sysval = hit_in.t;
         break;

      case nir_ray_query_value_intersection_instance_custom_index: {
         struct brw_nir_rt_bvh_instance_leaf_defs leaf;
         brw_nir_rt_load_bvh_instance_leaf(b, &leaf, hit_in.inst_leaf_ptr);
         sysval = leaf.instance_id;
         break;
      }

      case nir_ray_query_value_intersection_instance_id: {
         struct brw_nir_rt_bvh_instance_leaf_defs leaf;
         brw_nir_rt_load_bvh_instance_leaf(b, &leaf, hit_in.inst_leaf_ptr);
         sysval = leaf.instance_index;
         break;
      }

      case nir_ray_query_value_intersection_instance_sbt_index: {
         struct brw_nir_rt_bvh_instance_leaf_defs leaf;
         brw_nir_rt_load_bvh_instance_leaf(b, &leaf, hit_in.inst_leaf_ptr);
         sysval = leaf.contribution_to_hit_group_index;
         break;
      }

      case nir_ray_query_value_intersection_geometry_index: {
         nir_ssa_def *geometry_index_dw =
            nir_load_global(b, nir_iadd_imm(b, hit_in.prim_leaf_ptr, 4), 4,
                            1, 32);
         sysval = nir_iand_imm(b, geometry_index_dw, BITFIELD_MASK(29));
         break;
      }

      case nir_ray_query_value_intersection_primitive_index:
         sysval = brw_nir_rt_load_primitive_id_from_hit(b, NULL /* is_procedural */, &hit_in);
         break;

      case nir_ray_query_value_intersection_barycentrics:
         sysval = hit_in.tri_bary;
         break;

      case nir_ray_query_value_intersection_front_face:
         sysval = hit_in.front_face;
         break;

      case nir_ray_query_value_intersection_object_ray_direction:
         sysval = world_ray_in.dir;
         break;

      case nir_ray_query_value_intersection_object_ray_origin:
         sysval = world_ray_in.orig;
         break;

      case nir_ray_query_value_intersection_object_to_world: {
         struct brw_nir_rt_bvh_instance_leaf_defs leaf;
         brw_nir_rt_load_bvh_instance_leaf(b, &leaf, hit_in.inst_leaf_ptr);
         sysval = leaf.object_to_world[nir_intrinsic_column(intrin)];
         break;
      }

      case nir_ray_query_value_intersection_world_to_object: {
         struct brw_nir_rt_bvh_instance_leaf_defs leaf;
         brw_nir_rt_load_bvh_instance_leaf(b, &leaf, hit_in.inst_leaf_ptr);
         sysval = leaf.world_to_object[nir_intrinsic_column(intrin)];
         break;
      }

      case nir_ray_query_value_intersection_candidate_aabb_opaque:
         sysval = hit_in.front_face;
         break;

      case nir_ray_query_value_tmin:
         sysval = world_ray_in.t_near;
         break;

      case nir_ray_query_value_flags:
         sysval = nir_u2u32(b, world_ray_in.ray_flags);
         break;

      case nir_ray_query_value_world_ray_direction:
         sysval = world_ray_in.dir;
         break;

      case nir_ray_query_value_world_ray_origin:
         sysval = world_ray_in.orig;
         break;

      default:
         unreachable("Invalid ray query");
      }

      assert(sysval);
      nir_ssa_def_rewrite_uses(&intrin->dest.ssa, sysval);
      break;
   }

   default:
      unreachable("Invalid intrinsic");
   }
}

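/* Lower every ray query intrinsic in the given function, loading the ray
 * query globals once at the top of the entry block.
 */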
static void
lower_ray_query_impl(nir_function_impl *impl, struct lowering_state *state)
{
   nir_builder _b, *b = &_b;
   nir_builder_init(&_b, impl);

   b->cursor = nir_before_block(nir_start_block(b->impl));

   state->rq_globals = nir_load_ray_query_global_intel(b);

   brw_nir_rt_load_globals_addr(b, &state->globals, state->rq_globals);

   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
         if (intrin->intrinsic != nir_intrinsic_rq_initialize &&
             intrin->intrinsic != nir_intrinsic_rq_terminate &&
             intrin->intrinsic != nir_intrinsic_rq_proceed &&
             intrin->intrinsic != nir_intrinsic_rq_generate_intersection &&
             intrin->intrinsic != nir_intrinsic_rq_confirm_intersection &&
             intrin->intrinsic != nir_intrinsic_rq_load)
            continue;

         lower_ray_query_intrinsic(b, intrin, state);
      }
   }

   nir_metadata_preserve(impl, nir_metadata_none);
}

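/* Pass entry point: collect the shader's ray query variables, reserve
 * scratch space for their ctrl/level state, then lower the rq_* intrinsics
 * in the entrypoint. Returns true if any query was found and lowered.
 */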
bool
brw_nir_lower_ray_queries(nir_shader *shader,
                          const struct intel_device_info *devinfo)
{
   struct lowering_state state = {
      .devinfo = devinfo,
      .queries = _mesa_pointer_hash_table_create(NULL),
   };

   assert(exec_list_length(&shader->functions) == 1);

   /* Find query variables */
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr(instr, block)
         maybe_create_brw_var(instr, &state);
   }

   bool progress = state.n_queries > 0;

   if (progress) {
      state.state_scratch_base_offset = shader->scratch_size;
      shader->scratch_size += SIZEOF_QUERY_STATE * state.n_queries;

      lower_ray_query_impl(impl, &state);

      nir_remove_dead_derefs(shader);
      nir_remove_dead_variables(shader,
                                nir_var_shader_temp | nir_var_function_temp,
                                NULL);

      nir_metadata_preserve(impl, nir_metadata_none);
   }

   ralloc_free(state.queries);

   return progress;
}