/*
 * Copyright © 2024 Valve Corporation
 * SPDX-License-Identifier: MIT
 */

#include "tu_shader.h"

#include "bvh/tu_build_interface.h"

#include "compiler/spirv/spirv.h"

#include "nir_builder.h"
#include "nir_deref.h"

enum rq_intersection_var_index {
   rq_intersection_primitive_id,
   rq_intersection_geometry_id,
   rq_intersection_origin,
   rq_intersection_direction,
   rq_intersection_instance,
   rq_intersection_type_flags,
   rq_intersection_sbt_offset,
   rq_intersection_barycentrics,
   rq_intersection_t,
};

static const glsl_type *
get_rq_intersection_type(void)
{
   struct glsl_struct_field fields[] = {
#define FIELD(_type, _name) \
   [rq_intersection_##_name] = glsl_struct_field(_type, #_name),
      FIELD(glsl_uint_type(), primitive_id)
      FIELD(glsl_uint_type(), geometry_id)
      FIELD(glsl_vec_type(3), origin)
      FIELD(glsl_vec_type(3), direction)
      FIELD(glsl_uint_type(), instance)
      FIELD(glsl_uint_type(), type_flags)
      FIELD(glsl_uint_type(), sbt_offset)
      FIELD(glsl_vec2_type(), barycentrics)
      FIELD(glsl_float_type(), t)
#undef FIELD
   };

   return glsl_struct_type(fields, ARRAY_SIZE(fields), "ray_query_intersection", false);
}

enum rq_var_index {
   rq_index_accel_struct_base,
   rq_index_root_bvh_base,
   rq_index_bvh_base,
   rq_index_flags,
   rq_index_tmin,
   rq_index_world_origin,
   rq_index_world_direction,
   rq_index_incomplete,
   rq_index_closest,
   rq_index_candidate,
   rq_index_stack_ptr,
   rq_index_top_stack,
   rq_index_stack_low_watermark,
   rq_index_current_node,
   rq_index_previous_node,
   rq_index_instance_top_node,
   rq_index_instance_bottom_node,
   rq_index_stack,
};

/* Driver-internal flag to indicate that we haven't found an intersection */
#define TU_INTERSECTION_TYPE_NO_INTERSECTION (1u << 0)

#define MAX_STACK_DEPTH 8
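/* The traversal stack is a small ring buffer of MAX_STACK_DEPTH entries kept
 * in the ray query struct. Entries that fall out of the ring (tracked via
 * stack_low_watermark in build_ray_traversal()) are reconstructed on the way
 * back up by re-intersecting the parent found through the parent pointers, so
 * traversal stays correct even when the stack overflows.
 */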

static const glsl_type *
get_rq_type(void)
{
   const glsl_type *intersection_type = get_rq_intersection_type();

   struct glsl_struct_field fields[] = {
#define FIELD(_type, _name) \
   [rq_index_##_name] = glsl_struct_field(_type, #_name),
      FIELD(glsl_uvec2_type(), accel_struct_base)
      FIELD(glsl_uvec2_type(), root_bvh_base)
      FIELD(glsl_uvec2_type(), bvh_base)
      FIELD(glsl_uint_type(), flags)
      FIELD(glsl_float_type(), tmin)
      FIELD(glsl_vec_type(3), world_origin)
      FIELD(glsl_vec_type(3), world_direction)
      FIELD(glsl_bool_type(), incomplete)
      FIELD(intersection_type, closest)
      FIELD(intersection_type, candidate)
      FIELD(glsl_uint_type(), stack_ptr)
      FIELD(glsl_uint_type(), top_stack)
      FIELD(glsl_uint_type(), stack_low_watermark)
      FIELD(glsl_uint_type(), current_node)
      FIELD(glsl_uint_type(), previous_node)
      FIELD(glsl_uint_type(), instance_top_node)
      FIELD(glsl_uint_type(), instance_bottom_node)
      FIELD(glsl_array_type(glsl_uvec2_type(), MAX_STACK_DEPTH, 0), stack)
#undef FIELD
   };

   return glsl_struct_type(fields, ARRAY_SIZE(fields), "ray_query", false);
}

struct rq_var {
   nir_variable *rq;

   nir_intrinsic_instr *initialization;
   nir_def *uav_index;
};

static void
lower_ray_query(nir_shader *shader, nir_function_impl *impl,
                nir_variable *ray_query, struct hash_table *ht)
{
   struct rq_var *var = rzalloc(ht, struct rq_var);
   const glsl_type *type = ray_query->type;

   const glsl_type *rq_type = glsl_type_wrap_in_arrays(get_rq_type(), type);

   if (impl)
      var->rq = nir_local_variable_create(impl, rq_type, "ray_query");
   else
      var->rq = nir_variable_create(shader, nir_var_shader_temp, rq_type, "ray_query");

   _mesa_hash_table_insert(ht, ray_query, var);
}

static nir_deref_instr *
get_rq_deref(nir_builder *b, struct hash_table *ht, nir_def *def,
             struct rq_var **rq_var_out)
{
   nir_deref_instr *deref = nir_instr_as_deref(def->parent_instr);

   nir_deref_path path;
   nir_deref_path_init(&path, deref, NULL);
   assert(path.path[0]->deref_type == nir_deref_type_var);

   nir_variable *opaque_var = nir_deref_instr_get_variable(path.path[0]);
   struct hash_entry *entry = _mesa_hash_table_search(ht, opaque_var);
   assert(entry);

   struct rq_var *rq = (struct rq_var *)entry->data;

   nir_deref_instr *out_deref = nir_build_deref_var(b, rq->rq);

   if (glsl_type_is_array(opaque_var->type)) {
      nir_deref_instr **p = &path.path[1];
      for (; *p; p++) {
         if ((*p)->deref_type == nir_deref_type_array) {
            nir_def *index = (*p)->arr.index.ssa;

            out_deref = nir_build_deref_array(b, out_deref, index);
         } else {
            unreachable("Unsupported deref type");
         }
      }
   }

   nir_deref_path_finish(&path);

   if (rq_var_out)
      *rq_var_out = rq;

   return out_deref;
}

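/* src[1] of rq_initialize is the acceleration structure: either the result of
 * a load_vulkan_descriptor (when the ray query is initialized directly from a
 * descriptor) or a 64-bit device address. Only the descriptor form can use
 * the UAV path below.
 */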
static nir_def *
get_rq_initialize_uav_index(nir_intrinsic_instr *intr, struct rq_var *var)
{
   if (intr->src[1].ssa->parent_instr->type == nir_instr_type_intrinsic &&
       nir_instr_as_intrinsic(intr->src[1].ssa->parent_instr)->intrinsic ==
       nir_intrinsic_load_vulkan_descriptor) {
      return intr->src[1].ssa;
   } else {
      return NULL;
   }
}

/* Before we modify control flow, walk the shader and determine ray query
 * instructions for which we know the ray query has been initialized via a
 * descriptor instead of a pointer, and record the UAV descriptor.
 */
static void
calc_uav_index(nir_function_impl *impl, struct hash_table *ht)
{
   nir_metadata_require(impl, nir_metadata_dominance);

   nir_foreach_block (block, impl) {
      nir_foreach_instr (instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

         nir_def *rq_def;
         switch (intr->intrinsic) {
         case nir_intrinsic_rq_initialize:
         case nir_intrinsic_rq_load:
         case nir_intrinsic_rq_proceed:
            rq_def = intr->src[0].ssa;
            break;
         default:
            continue;
         }

         nir_deref_instr *deref = nir_instr_as_deref(rq_def->parent_instr);

         if (deref->deref_type != nir_deref_type_var)
            continue;

         nir_variable *opaque_var = deref->var;
         struct hash_entry *entry = _mesa_hash_table_search(ht, opaque_var);
         assert(entry);

         struct rq_var *rq = (struct rq_var *)entry->data;

         if (intr->intrinsic == nir_intrinsic_rq_initialize) {
            rq->initialization = intr;
            rq->uav_index = get_rq_initialize_uav_index(intr, rq);
         } else {
            if (rq->initialization &&
                nir_block_dominates(rq->initialization->instr.block,
                                    block) && rq->uav_index) {
               _mesa_hash_table_insert(ht, instr, rq->uav_index);
            }
         }
      }
   }
}

/* Return a pointer to the TLAS descriptor, which is actually a UAV
 * descriptor, if we know that the ray query has been initialized via a
 * descriptor and not a pointer. If not known, returns NULL.
 */
static nir_def *
get_uav_index(nir_instr *cur_instr, struct hash_table *ht)
{
   struct hash_entry *entry = _mesa_hash_table_search(ht, cur_instr);
   if (entry)
      return (nir_def *)entry->data;
   return NULL;
}

/* Load some data from the TLAS header or instance descriptors. This uses the
 * UAV descriptor when available, via "uav_index" which should be obtained
 * from get_uav_index().
 */
static nir_def *
load_tlas(nir_builder *b, nir_def *tlas,
          nir_def *uav_index, nir_def *index,
          unsigned offset, unsigned components)
{
   if (uav_index) {
      return nir_load_uav_ir3(b, components, 32, uav_index,
                              nir_vec2(b, index, nir_imm_int(b, offset / 4)),
                              .access = (gl_access_qualifier)(
                                 ACCESS_NON_WRITEABLE |
                                 ACCESS_CAN_REORDER),
                              .align_mul = AS_RECORD_SIZE,
                              .align_offset = offset);
   } else {
      return nir_load_global_ir3(b, components, 32,
                                 tlas,
                                 nir_iadd_imm(b, nir_imul_imm(b, index, AS_RECORD_SIZE / 4),
                                              offset / 4),
                                 /* The required alignment of the
                                  * user-specified base from the Vulkan spec.
                                  */
                                 .align_mul = 256,
                                 .align_offset = 0);
   }
}

/* The first record is the TLAS header and the rest of the records are
 * instances, so we need to add 1 to the instance ID when reading data in an
 * instance.
 */
#define load_instance_offset(b, tlas, uav_index, instance, \
                             field, offset, components) \
   load_tlas(b, tlas, uav_index, nir_iadd_imm(b, instance, 1), \
             offsetof(struct tu_instance_descriptor, field) + offset, \
             components)

#define load_instance(b, tlas, uav_index, instance, field, components) \
   load_instance_offset(b, tlas, uav_index, instance, field, 0, components)

#define rq_deref(b, _rq, name) nir_build_deref_struct(b, _rq, rq_index_##name)
#define rq_load(b, _rq, name) nir_load_deref(b, rq_deref(b, _rq, name))
#define rq_store(b, _rq, name, val, wrmask) \
   nir_store_deref(b, rq_deref(b, _rq, name), val, wrmask)
#define rqi_deref(b, _rq, name) nir_build_deref_struct(b, _rq, rq_intersection_##name)
#define rqi_load(b, _rq, name) nir_load_deref(b, rqi_deref(b, _rq, name))
#define rqi_store(b, _rq, name, val, wrmask) \
   nir_store_deref(b, rqi_deref(b, _rq, name), val, wrmask)
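
/* These helpers just build struct-member derefs on the lowered ray_query
 * struct: e.g. rq_load(b, rq, tmin) loads the "tmin" field of the ray query
 * and rqi_store(b, candidate, t, val, 0x1) stores into the "t" field of an
 * intersection struct.
 */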

static void
lower_rq_initialize(nir_builder *b, struct hash_table *ht,
                    nir_intrinsic_instr *intr)
{
   struct rq_var *var;
   nir_deref_instr *rq = get_rq_deref(b, ht, intr->src[0].ssa, &var);

   if (nir_instr_as_deref(intr->src[0].ssa->parent_instr)->deref_type ==
       nir_deref_type_var) {
      var->initialization = intr;
   } else {
      var->initialization = NULL;
   }

   nir_def *uav_index = get_rq_initialize_uav_index(intr, var);

   nir_def *tlas = intr->src[1].ssa;
   nir_def *flags = intr->src[2].ssa;
   nir_def *cull_mask = intr->src[3].ssa;
   nir_def *origin = intr->src[4].ssa;
   nir_def *tmin = intr->src[5].ssa;
   nir_def *direction = intr->src[6].ssa;
   nir_def *tmax = intr->src[7].ssa;

   nir_def *tlas_base;
   if (uav_index) {
      tlas_base = load_tlas(b, NULL, uav_index, nir_imm_int(b, 0),
                            offsetof(struct tu_accel_struct_header,
                                     self_ptr), 2);
   } else {
      tlas_base = nir_unpack_64_2x32(b, tlas);
   }

   rq_store(b, rq, accel_struct_base, tlas_base, 0x3);

   nir_def *root_bvh_base = load_tlas(b, tlas_base, uav_index, nir_imm_int(b, 0),
                                      offsetof(struct tu_accel_struct_header,
                                               bvh_ptr), 2);

   nir_deref_instr *closest = rq_deref(b, rq, closest);
   nir_deref_instr *candidate = rq_deref(b, rq, candidate);

   rq_store(b, rq, flags,
            /* Fill out initial fourth src of ray_intersection */
            nir_ior_imm(b,
                        nir_ior(b, nir_ishl_imm(b, flags, 4),
                                nir_ishl_imm(b, cull_mask, 16)),
                        0b1111), 0x1);
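
   /* The packed "flags" word ends up as: bits 0-3 hold the leaf-processed
    * mask (seeded to 0b1111 and updated during traversal), bits 4-15 hold
    * the SPIR-V ray flags, bits 16-23 hold the cull mask, and bits 24-31 are
    * later overwritten with the per-instance flags.
    */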

   rqi_store(b, candidate, origin, origin, 0x7);
   rqi_store(b, candidate, direction, direction, 0x7);

   rq_store(b, rq, tmin, tmin, 0x1);
   rq_store(b, rq, world_origin, origin, 0x7);
   rq_store(b, rq, world_direction, direction, 0x7);

   rqi_store(b, closest, t, tmax, 0x1);
   rqi_store(b, closest, type_flags, nir_imm_int(b, TU_INTERSECTION_TYPE_NO_INTERSECTION), 0x1);

   /* Make sure that instance data loads don't hang in case of a miss by setting a valid initial instance. */
   rqi_store(b, closest, instance, nir_imm_int(b, 0), 0x1);
   rqi_store(b, candidate, instance, nir_imm_int(b, 0), 0x1);

   rq_store(b, rq, root_bvh_base, root_bvh_base, 0x3);
   rq_store(b, rq, bvh_base, root_bvh_base, 0x3);

   rq_store(b, rq, stack_ptr, nir_imm_int(b, 0), 0x1);
   rq_store(b, rq, top_stack, nir_imm_int(b, -1), 0x1);
   rq_store(b, rq, stack_low_watermark, nir_imm_int(b, 0), 0x1);
   rq_store(b, rq, current_node, nir_imm_int(b, 0), 0x1);
   rq_store(b, rq, previous_node, nir_imm_int(b, VK_BVH_INVALID_NODE), 0x1);
   rq_store(b, rq, instance_top_node, nir_imm_int(b, VK_BVH_INVALID_NODE), 0x1);
   rq_store(b, rq, instance_bottom_node, nir_imm_int(b, VK_BVH_INVALID_NODE), 0x1);

   rq_store(b, rq, incomplete, nir_imm_true(b), 0x1);
}

static void
insert_terminate_on_first_hit(nir_builder *b, nir_deref_instr *rq)
{
   nir_def *terminate_on_first_hit =
      nir_test_mask(b, rq_load(b, rq, flags),
                    SpvRayFlagsTerminateOnFirstHitKHRMask << 4);
   nir_push_if(b, terminate_on_first_hit);
   {
      rq_store(b, rq, incomplete, nir_imm_false(b), 0x1);
   }
   nir_pop_if(b, NULL);
}

static void
lower_rq_confirm_intersection(nir_builder *b, struct hash_table *ht, nir_intrinsic_instr *intr)
{
   nir_deref_instr *rq = get_rq_deref(b, ht, intr->src[0].ssa, NULL);
   nir_copy_deref(b, rq_deref(b, rq, closest), rq_deref(b, rq, candidate));
   insert_terminate_on_first_hit(b, rq);
}

static void
lower_rq_generate_intersection(nir_builder *b, struct hash_table *ht, nir_intrinsic_instr *intr)
{
   nir_deref_instr *rq = get_rq_deref(b, ht, intr->src[0].ssa, NULL);
   nir_deref_instr *closest = rq_deref(b, rq, closest);
   nir_deref_instr *candidate = rq_deref(b, rq, candidate);

   nir_push_if(b, nir_iand(b, nir_fge(b, rqi_load(b, closest, t),
                                      intr->src[1].ssa),
                           nir_fge(b, intr->src[1].ssa,
                                   rq_load(b, rq, tmin))));
   {
      nir_copy_deref(b, closest, candidate);
      insert_terminate_on_first_hit(b, rq);
      rqi_store(b, closest, t, intr->src[1].ssa, 0x1);
   }
   nir_pop_if(b, NULL);
}

static void
lower_rq_terminate(nir_builder *b, struct hash_table *ht, nir_intrinsic_instr *intr)
{
   nir_deref_instr *rq = get_rq_deref(b, ht, intr->src[0].ssa, NULL);
   rq_store(b, rq, incomplete, nir_imm_false(b), 0x1);
}

static nir_def *
lower_rq_load(nir_builder *b, struct hash_table *ht, nir_intrinsic_instr *intr)
{
   struct rq_var *var;
   nir_deref_instr *rq = get_rq_deref(b, ht, intr->src[0].ssa, &var);
   nir_def *uav_index = get_uav_index(&intr->instr, ht);
   nir_def *tlas = rq_load(b, rq, accel_struct_base);
   nir_deref_instr *closest = rq_deref(b, rq, closest);
   nir_deref_instr *candidate = rq_deref(b, rq, candidate);
   bool committed = nir_intrinsic_committed(intr);
   nir_deref_instr *intersection = committed ? closest : candidate;

   uint32_t column = nir_intrinsic_column(intr);

   nir_ray_query_value value = nir_intrinsic_ray_query_value(intr);
   switch (value) {
   case nir_ray_query_value_flags: {
      nir_def *flags = rq_load(b, rq, flags);
      return nir_ubitfield_extract(b, flags, nir_imm_int(b, 4),
                                   nir_imm_int(b, 12));
   }
   case nir_ray_query_value_intersection_barycentrics:
      return rqi_load(b, intersection, barycentrics);
   case nir_ray_query_value_intersection_candidate_aabb_opaque:
      return nir_ieq_imm(b, nir_iand_imm(b, rqi_load(b, candidate, type_flags),
                                         TU_INTERSECTION_TYPE_AABB |
                                         TU_INTERSECTION_TYPE_NONOPAQUE |
                                         TU_INTERSECTION_TYPE_NO_INTERSECTION),
                         TU_INTERSECTION_TYPE_AABB);
   case nir_ray_query_value_intersection_front_face:
      return nir_inot(b, nir_test_mask(b, rqi_load(b, intersection, type_flags),
                                       TU_INTERSECTION_BACK_FACE));
   case nir_ray_query_value_intersection_geometry_index:
      return rqi_load(b, intersection, geometry_id);
   case nir_ray_query_value_intersection_instance_custom_index: {
      nir_def *instance = rqi_load(b, intersection, instance);
      return load_instance(b, tlas, uav_index, instance, custom_instance_index, 1);
   }
   case nir_ray_query_value_intersection_instance_id:
      return rqi_load(b, intersection, instance);
   case nir_ray_query_value_intersection_instance_sbt_index:
      return rqi_load(b, intersection, sbt_offset);
   case nir_ray_query_value_intersection_object_ray_direction:
      return rqi_load(b, intersection, direction);
   case nir_ray_query_value_intersection_object_ray_origin:
      return rqi_load(b, intersection, origin);
   case nir_ray_query_value_intersection_object_to_world: {
      nir_def *instance = rqi_load(b, intersection, instance);
      nir_def *rows[3];
      for (unsigned r = 0; r < 3; ++r)
         rows[r] = load_instance_offset(b, tlas, uav_index, instance,
                                        otw_matrix.values,
                                        r * 16, 4);

      return nir_vec3(b, nir_channel(b, rows[0], column), nir_channel(b, rows[1], column),
                      nir_channel(b, rows[2], column));
   }
   case nir_ray_query_value_intersection_primitive_index:
      return rqi_load(b, intersection, primitive_id);
   case nir_ray_query_value_intersection_t:
      return rqi_load(b, intersection, t);
   case nir_ray_query_value_intersection_type: {
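      /* Map type_flags onto the SPIR-V intersection type values: for
       * candidates the AABB bit alone gives 0 (triangle) or 1 (AABB); for
       * committed intersections, adding has_intersection yields 0 (none),
       * 1 (triangle) or 2 (generated, i.e. a confirmed AABB candidate).
       */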
      nir_def *intersection_type =
         nir_iand_imm(b, nir_ishr_imm(b, rqi_load(b, intersection, type_flags),
                                      util_logbase2(TU_INTERSECTION_TYPE_AABB)), 1);
      if (committed) {
         nir_def *has_intersection =
            nir_inot(b,
                     nir_test_mask(b, rqi_load(b, intersection, type_flags),
                                   TU_INTERSECTION_TYPE_NO_INTERSECTION));
         intersection_type = nir_iadd(b, intersection_type,
                                      nir_b2i32(b, has_intersection));
      }
      return intersection_type;
   }
   case nir_ray_query_value_intersection_world_to_object: {
      nir_def *instance = rqi_load(b, intersection, instance);
      nir_def *rows[3];
      for (unsigned r = 0; r < 3; ++r)
         rows[r] = load_instance_offset(b, tlas, uav_index, instance,
                                        wto_matrix.values, r * 16, 4);

      return nir_vec3(b, nir_channel(b, rows[0], column), nir_channel(b, rows[1], column),
                      nir_channel(b, rows[2], column));
   }
   case nir_ray_query_value_tmin:
      return rq_load(b, rq, tmin);
   case nir_ray_query_value_world_ray_direction:
      return rq_load(b, rq, world_direction);
   case nir_ray_query_value_world_ray_origin:
      return rq_load(b, rq, world_origin);
   default:
      unreachable("Invalid nir_ray_query_value!");
   }

   return NULL;
}

/* For the initialization of instance_bottom_node. Explicitly different than
 * VK_BVH_INVALID_NODE or any real node, to ensure we never exit an instance
 * when we're not in one.
 */
#define TU_BVH_NO_INSTANCE_ROOT 0xfffffffeu

nir_def *
nir_build_vec3_mat_mult(nir_builder *b, nir_def *vec, nir_def *matrix[], bool translation)
{
   nir_def *result_components[3] = {
      nir_channel(b, matrix[0], 3),
      nir_channel(b, matrix[1], 3),
      nir_channel(b, matrix[2], 3),
   };
   for (unsigned i = 0; i < 3; ++i) {
      for (unsigned j = 0; j < 3; ++j) {
         nir_def *v = nir_fmul(b, nir_channels(b, vec, 1 << j), nir_channels(b, matrix[i], 1 << j));
         result_components[i] = (translation || j) ? nir_fadd(b, result_components[i], v) : v;
      }
   }
   return nir_vec(b, result_components, 3);
}

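/* Parent pointers are stored immediately below the BVH base, one 32-bit entry
 * per node, so the parent of node N lives at bvh_base - 4 * (N + 1).
 */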
static nir_def *
fetch_parent_node(nir_builder *b, nir_def *bvh, nir_def *node)
{
   nir_def *offset = nir_iadd_imm(b, nir_imul_imm(b, node, 4), 4);

   return nir_build_load_global(b, 1, 32, nir_isub(b, nir_pack_64_2x32(b, bvh),
                                                   nir_u2u64(b, offset)), .align_mul = 4);
}

static nir_def *
build_ray_traversal(nir_builder *b, nir_deref_instr *rq,
                    nir_def *tlas, nir_def *uav_index)
{
   nir_deref_instr *candidate = rq_deref(b, rq, candidate);
   nir_deref_instr *closest = rq_deref(b, rq, closest);

   nir_variable *incomplete = nir_local_variable_create(b->impl, glsl_bool_type(), "incomplete");
   nir_store_var(b, incomplete, nir_imm_true(b), 0x1);

   nir_push_loop(b);
   {
      /* Go up the stack if current_node == VK_BVH_INVALID_NODE */
      nir_push_if(b, nir_ieq_imm(b, rq_load(b, rq, current_node), VK_BVH_INVALID_NODE));
      {
         /* Early exit if we never overflowed the stack, to avoid having to backtrack to
          * the root for no reason. */
         nir_push_if(b, nir_ilt_imm(b, rq_load(b, rq, stack_ptr), 1));
         {
            nir_store_var(b, incomplete, nir_imm_false(b), 0x1);
            nir_jump(b, nir_jump_break);
         }
         nir_pop_if(b, NULL);

         nir_def *stack_instance_exit =
            nir_ige(b, rq_load(b, rq, top_stack), rq_load(b, rq, stack_ptr));
         nir_def *root_instance_exit =
            nir_ieq(b, rq_load(b, rq, previous_node), rq_load(b, rq, instance_bottom_node));
         nir_if *instance_exit = nir_push_if(b, nir_ior(b, stack_instance_exit, root_instance_exit));
         instance_exit->control = nir_selection_control_dont_flatten;
         {
            rq_store(b, rq, top_stack, nir_imm_int(b, -1), 1);
            rq_store(b, rq, previous_node, rq_load(b, rq, instance_top_node), 1);
            rq_store(b, rq, instance_bottom_node, nir_imm_int(b, TU_BVH_NO_INSTANCE_ROOT), 1);

            rq_store(b, rq, bvh_base, rq_load(b, rq, root_bvh_base), 3);
            rqi_store(b, candidate, origin, rq_load(b, rq, world_origin), 7);
            rqi_store(b, candidate, direction, rq_load(b, rq, world_direction), 7);
         }
         nir_pop_if(b, NULL);

         nir_push_if(
            b, nir_ige(b, rq_load(b, rq, stack_low_watermark), rq_load(b, rq, stack_ptr)));
         {
            /* Get the parent of the previous node using the parent pointers.
             * We will re-intersect the parent and figure out what index the
             * previous node was below.
             */
            nir_def *prev = rq_load(b, rq, previous_node);
            nir_def *bvh_addr = rq_load(b, rq, bvh_base);

            nir_def *parent = fetch_parent_node(b, bvh_addr, prev);
            nir_push_if(b, nir_ieq_imm(b, parent, VK_BVH_INVALID_NODE));
            {
               nir_store_var(b, incomplete, nir_imm_false(b), 0x1);
               nir_jump(b, nir_jump_break);
            }
            nir_pop_if(b, NULL);
            rq_store(b, rq, current_node, parent, 0x1);
         }
         nir_push_else(b, NULL);
         {
            /* Go up the stack and get the next child of the parent. */
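            /* Each stack entry is a uvec2: .x is the node index the child
             * offsets are relative to, and .y is the packed children word
             * from the intersection result, with 3-bit child offsets packed
             * from bit 8 upward and the low 5 bits acting as a cursor that
             * holds the bit position of the last-consumed offset. Stepping
             * the cursor down by 3 yields the next child; once it reaches the
             * field at bit 8 this is the last child and the entry is popped.
             */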
            nir_def *stack_ptr = nir_iadd_imm(b, rq_load(b, rq, stack_ptr), -1);

            nir_def *stack_idx =
               nir_umod_imm(b, stack_ptr, MAX_STACK_DEPTH);
            nir_deref_instr *stack_deref =
               nir_build_deref_array(b, rq_deref(b, rq, stack), stack_idx);
            nir_def *stack_entry = nir_load_deref(b, stack_deref);
            nir_def *children_base = nir_channel(b, stack_entry, 0);
            nir_def *children = nir_channel(b, stack_entry, 1);

            nir_def *next_child_idx =
               nir_iadd_imm(b, nir_iand_imm(b, children, 0x1f), -3);

            nir_def *child_offset =
               nir_iand_imm(b, nir_ishr(b, children, next_child_idx), 0x7);
            nir_def *bvh_node = nir_iadd(b, children_base, child_offset);

            nir_push_if(b, nir_ieq_imm(b, next_child_idx, 8));
            {
               rq_store(b, rq, stack_ptr, stack_ptr, 1);
            }
            nir_push_else(b, NULL);
            {
               children = nir_bitfield_insert(b, children, next_child_idx,
                                              nir_imm_int(b, 0),
                                              nir_imm_int(b, 5));
               nir_store_deref(b, stack_deref,
                               nir_vec2(b, nir_undef(b, 1, 32), children),
                               0x2);
            }
            nir_pop_if(b, NULL);

            rq_store(b, rq, current_node, bvh_node, 0x1);
            /* We don't need previous_node when we have the stack. Indicate to
             * the internal intersection handling below that this isn't the
             * underflow case.
             */
            rq_store(b, rq, previous_node, nir_imm_int(b, VK_BVH_INVALID_NODE), 0x1);
         }
         nir_pop_if(b, NULL);
      }
      nir_push_else(b, NULL);
      {
         rq_store(b, rq, previous_node, nir_imm_int(b, VK_BVH_INVALID_NODE), 0x1);
      }
      nir_pop_if(b, NULL);

      nir_def *bvh_node = rq_load(b, rq, current_node);
      nir_def *bvh_base = rq_load(b, rq, bvh_base);

      nir_def *prev_node = rq_load(b, rq, previous_node);
      rq_store(b, rq, previous_node, bvh_node, 0x1);
      rq_store(b, rq, current_node, nir_imm_int(b, VK_BVH_INVALID_NODE), 0x1);

      nir_def *origin = rqi_load(b, candidate, origin);
      nir_def *tmin = rq_load(b, rq, tmin);
      nir_def *direction = rqi_load(b, candidate, direction);
      nir_def *tmax = rqi_load(b, closest, t);

      nir_def *intrinsic_result =
         nir_ray_intersection_ir3(b, 32, bvh_base, bvh_node,
                                  nir_vec8(b,
                                           nir_channel(b, origin, 0),
                                           nir_channel(b, origin, 1),
                                           nir_channel(b, origin, 2),
                                           tmin,
                                           nir_channel(b, direction, 0),
                                           nir_channel(b, direction, 1),
                                           nir_channel(b, direction, 2),
                                           tmax),
                                  rq_load(b, rq, flags));

      nir_def *intersection_flags = nir_channel(b, intrinsic_result, 0);
      nir_def *intersection_count =
         nir_ubitfield_extract_imm(b, intersection_flags, 4, 4);
      nir_def *intersection_id = nir_channel(b, intrinsic_result, 1);
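
      /* As consumed below: channel 0 of the result holds the intersection
       * flags (type bits plus a 4-bit intersection count in bits 4-7),
       * channel 1 the intersection id (primitive/instance id for leaves, the
       * children base for internal nodes), channel 2 t and channels 3-4 the
       * barycentrics for triangle hits, while for internal nodes channel 3
       * holds the packed children word.
       */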

      nir_push_if(b, nir_test_mask(b, intersection_flags,
                                   TU_INTERSECTION_TYPE_LEAF));
      {
         nir_def *processed_mask = nir_iand_imm(b, intersection_flags, 0xf);

         /* Keep processing the current node if the mask isn't yet 0 */
         rq_store(b, rq, current_node,
                  nir_bcsel(b, nir_ieq_imm(b, processed_mask, 0),
                            nir_imm_int(b, VK_BVH_INVALID_NODE),
                            bvh_node), 1);

         /* If the mask is 0, replace with the initial 0xf for the next
          * intersection.
          */
         processed_mask =
            nir_bcsel(b, nir_ieq_imm(b, processed_mask, 0),
                      nir_imm_int(b, 0xf), processed_mask);

         /* Replace the mask in the flags. */
         rq_store(b, rq, flags,
                  nir_bitfield_insert(b, rq_load(b, rq, flags),
                                      processed_mask, nir_imm_int(b, 0),
                                      nir_imm_int(b, 4)), 1);

         nir_push_if(b, nir_ieq_imm(b, intersection_count, 0));
         {
            nir_jump(b, nir_jump_continue);
         }
         nir_pop_if(b, NULL);

         nir_push_if(b, nir_test_mask(b, intersection_flags,
                                      TU_INTERSECTION_TYPE_TLAS));
         {
            /* instance */
            rqi_store(b, candidate, instance, intersection_id, 1);

            nir_def *wto_matrix[3];
            for (unsigned i = 0; i < 3; i++)
               wto_matrix[i] = load_instance_offset(b, tlas, uav_index,
                                                    intersection_id,
                                                    wto_matrix.values,
                                                    i * 16, 4);

            nir_def *sbt_offset_and_flags =
               load_instance(b, tlas, uav_index, intersection_id,
                             sbt_offset_and_flags, 1);
            nir_def *blas_bvh =
               load_instance(b, tlas, uav_index, intersection_id,
                             bvh_ptr, 2);

            nir_def *instance_flags = nir_iand_imm(b, sbt_offset_and_flags,
                                                   0xff000000);
            nir_def *sbt_offset = nir_iand_imm(b, sbt_offset_and_flags,
                                               0x00ffffff);
            nir_def *flags = rq_load(b, rq, flags);
            flags = nir_ior(b, nir_iand_imm(b, flags, 0x00ffffff),
                            instance_flags);
            rq_store(b, rq, flags, flags, 1);

            rqi_store(b, candidate, sbt_offset, sbt_offset, 1);

            rq_store(b, rq, top_stack, rq_load(b, rq, stack_ptr), 1);
            rq_store(b, rq, bvh_base, blas_bvh, 3);

            /* Push the instance root node onto the stack */
            rq_store(b, rq, current_node, nir_imm_int(b, 0), 0x1);
            rq_store(b, rq, instance_bottom_node, nir_imm_int(b, 0), 1);
            rq_store(b, rq, instance_top_node, bvh_node, 1);

            /* Transform the ray into object space */
            rqi_store(b, candidate, origin,
                      nir_build_vec3_mat_mult(b, rq_load(b, rq, world_origin),
                                              wto_matrix, true), 7);
            rqi_store(b, candidate, direction,
                      nir_build_vec3_mat_mult(b, rq_load(b, rq, world_direction),
                                              wto_matrix, false), 7);
         }
         nir_push_else(b, NULL);
         {
            /* AABB & triangle */
            rqi_store(b, candidate, type_flags,
                      nir_iand_imm(b, intersection_flags,
                                   TU_INTERSECTION_TYPE_AABB |
                                   TU_INTERSECTION_TYPE_NONOPAQUE |
                                   TU_INTERSECTION_BACK_FACE), 1);

            rqi_store(b, candidate, primitive_id, intersection_id, 1);

            /* TODO: Implement optimization to try to combine these into 1
             * 32-bit ID, for compressed nodes.
             *
             * load_global_ir3 doesn't have the required range so we have to
             * do the offset math ourselves.
             */
            nir_def *offset =
               nir_ior_imm(b, nir_imul_imm(b, nir_u2u64(b, bvh_node),
                                           sizeof(tu_leaf_node)),
                           offsetof(struct tu_leaf_node, geometry_id));
            nir_def *geometry_id_ptr = nir_iadd(b, nir_pack_64_2x32(b, bvh_base),
                                                offset);
            nir_def *geometry_id =
               nir_build_load_global(b, 1, 32, geometry_id_ptr,
                                     .access = ACCESS_NON_WRITEABLE,
                                     .align_mul = sizeof(struct tu_leaf_node),
                                     .align_offset = offsetof(struct tu_leaf_node,
                                                              geometry_id));
            rqi_store(b, candidate, geometry_id, geometry_id, 1);

            nir_push_if(b, nir_test_mask(b, intersection_flags,
                                         TU_INTERSECTION_TYPE_AABB));
            {
               nir_jump(b, nir_jump_break);
            }
            nir_push_else(b, NULL);
            {
               rqi_store(b, candidate, barycentrics,
                         nir_vec2(b, nir_channel(b, intrinsic_result, 3),
                                  nir_channel(b, intrinsic_result, 4)), 0x3);
               rqi_store(b, candidate, t, nir_channel(b, intrinsic_result,
                                                      2), 0x1);
               nir_push_if(b, nir_test_mask(b, intersection_flags,
                                            TU_INTERSECTION_TYPE_NONOPAQUE));
               {
                  nir_jump(b, nir_jump_break);
               }
               nir_push_else(b, NULL);
               {
                  nir_copy_deref(b, closest, candidate);
                  nir_def *terminate_on_first_hit =
                     nir_test_mask(b, rq_load(b, rq, flags),
                                   SpvRayFlagsTerminateOnFirstHitKHRMask << 4);
                  nir_push_if(b, terminate_on_first_hit);
                  {
                     nir_store_var(b, incomplete, nir_imm_false(b), 0x1);
                     nir_jump(b, nir_jump_break);
                  }
                  nir_pop_if(b, NULL);
               }
               nir_pop_if(b, NULL);
            }
            nir_pop_if(b, NULL);
         }
         nir_pop_if(b, NULL);
      }
      nir_push_else(b, NULL);
      {
         /* internal */
         nir_push_if(b, nir_ine_imm(b, intersection_count, 0));
         {
            nir_def *children = nir_channel(b, intrinsic_result, 3);

            nir_push_if(b, nir_ieq_imm(b, prev_node, VK_BVH_INVALID_NODE));
            {
               /* The children array returned by the HW is specially set up so
                * that we can do this to get the first child.
                */
               nir_def *first_child_offset =
                  nir_iand_imm(b, nir_ishr(b, children, children), 0x7);

               rq_store(b, rq, current_node,
                        nir_iadd(b, intersection_id, first_child_offset),
                        0x1);

               nir_push_if(b, nir_igt_imm(b, intersection_count, 1));
               {
                  nir_def *stack_ptr = rq_load(b, rq, stack_ptr);
                  nir_def *stack_idx = nir_umod_imm(b, stack_ptr,
                                                    MAX_STACK_DEPTH);
                  nir_def *stack_entry =
                     nir_vec2(b, intersection_id, children);
                  nir_store_deref(b,
                                  nir_build_deref_array(b, rq_deref(b, rq, stack),
                                                        stack_idx),
                                  stack_entry, 0x7);
                  rq_store(b, rq, stack_ptr,
                           nir_iadd_imm(b, rq_load(b, rq, stack_ptr), 1), 0x1);

                  nir_def *new_watermark =
                     nir_iadd_imm(b, rq_load(b, rq, stack_ptr),
                                  -MAX_STACK_DEPTH);
                  new_watermark = nir_imax(b, rq_load(b, rq,
                                                      stack_low_watermark),
                                           new_watermark);
                  rq_store(b, rq, stack_low_watermark, new_watermark, 0x1);
               }
               nir_pop_if(b, NULL);
            }
            nir_push_else(b, NULL);
            {
               /* The underflow case. We have the previous_node and an array
                * of intersecting children of its parent, and we need to find
                * its position in the array so that we can return the next
                * child in the array or VK_BVH_INVALID_NODE if it's the last
                * child.
                */
               nir_def *prev_offset =
                  nir_isub(b, prev_node, intersection_id);

               /* A bit-pattern with ones at the LSB of each child's
                * position.
                */
               uint32_t ones = 0b1001001001001001001001 << 8;

               /* Replicate prev_offset into the position of each child. */
               nir_def *prev_offset_repl =
                  nir_imul_imm(b, prev_offset, ones);

               /* a == b <=> a ^ b == 0. Reduce the problem to finding the
                * child whose bits are 0.
                */
               nir_def *diff = nir_ixor(b, prev_offset_repl, children);

               /* This magic formula comes from Hacker's Delight, section 6.1
                * "Find First 0-byte", adapted for 3-bit "bytes". The first
                * zero byte will be the lowest byte with 1 set in the highest
                * position (i.e. bit 2). We need to then subtract 2 to get the
                * current position and 5 to get the next position.
                */
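               /* E.g. if previous_node matched the 3-bit field at bits 11-13,
                * that field of diff is zero, the expression below sets bit 13
                * (stray bits can only appear in higher fields and find_lsb
                * ignores them), so diff becomes 13; since 13 != 8 + 2, the
                * next child offset is read from the field at bits 8-10.
                */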
               diff = nir_iand_imm(b, nir_iand(b, nir_iadd_imm(b, diff, -ones),
                                               nir_inot(b, diff)),
                                   ones << 2);
               diff = nir_find_lsb(b, diff);

               nir_def *next_offset =
                  nir_iand_imm(b, nir_ishr(b, children,
                                           nir_iadd_imm(b, diff, -5)),
                               0x7);

               nir_def *next =
                  nir_bcsel(b, nir_ieq_imm(b, diff, 8 + 2),
                            nir_imm_int(b, VK_BVH_INVALID_NODE),
                            nir_iadd(b, next_offset, intersection_id));
               rq_store(b, rq, current_node, next, 0x1);
            }
            nir_pop_if(b, NULL);
         }
         nir_pop_if(b, NULL);
      }
      nir_pop_if(b, NULL);
   }
   nir_pop_loop(b, NULL);

   return nir_load_var(b, incomplete);
}

static nir_def *
lower_rq_proceed(nir_builder *b, struct hash_table *ht, nir_intrinsic_instr *intr)
{
   struct rq_var *var;
   nir_deref_instr *rq = get_rq_deref(b, ht, intr->src[0].ssa, &var);
   nir_def *uav_index = get_uav_index(&intr->instr, ht);
   nir_def *tlas = rq_load(b, rq, accel_struct_base);

   nir_push_if(b, nir_load_deref(b, rq_deref(b, rq, incomplete)));
   {
      nir_def *incomplete = build_ray_traversal(b, rq, tlas, uav_index);
      nir_store_deref(b, rq_deref(b, rq, incomplete), incomplete, 0x1);
   }
   nir_pop_if(b, NULL);

   return nir_load_deref(b, rq_deref(b, rq, incomplete));
}

bool
tu_nir_lower_ray_queries(nir_shader *shader)
{
   bool progress = false;
   struct hash_table *query_ht = _mesa_pointer_hash_table_create(NULL);

   nir_foreach_variable_in_list (var, &shader->variables) {
      if (!var->data.ray_query)
         continue;

      lower_ray_query(shader, NULL, var, query_ht);

      progress = true;
   }

   nir_foreach_function (function, shader) {
      if (!function->impl)
         continue;

      nir_builder builder = nir_builder_create(function->impl);

      nir_foreach_variable_in_list (var, &function->impl->locals) {
         if (!var->data.ray_query)
            continue;

         lower_ray_query(shader, function->impl, var, query_ht);

         progress = true;
      }

      calc_uav_index(function->impl, query_ht);

      nir_foreach_block (block, function->impl) {
         nir_foreach_instr_safe (instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;

            nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);

            if (!nir_intrinsic_is_ray_query(intrinsic->intrinsic))
               continue;

            builder.cursor = nir_before_instr(instr);

            nir_def *new_dest = NULL;

            switch (intrinsic->intrinsic) {
            case nir_intrinsic_rq_confirm_intersection:
               lower_rq_confirm_intersection(&builder, query_ht, intrinsic);
               break;
            case nir_intrinsic_rq_generate_intersection:
               lower_rq_generate_intersection(&builder, query_ht, intrinsic);
               break;
            case nir_intrinsic_rq_initialize:
               lower_rq_initialize(&builder, query_ht, intrinsic);
               break;
            case nir_intrinsic_rq_load:
               new_dest = lower_rq_load(&builder, query_ht, intrinsic);
               break;
            case nir_intrinsic_rq_proceed:
               new_dest = lower_rq_proceed(&builder, query_ht, intrinsic);
               break;
            case nir_intrinsic_rq_terminate:
               lower_rq_terminate(&builder, query_ht, intrinsic);
               break;
            default:
               unreachable("Unsupported ray query intrinsic!");
            }

            if (new_dest)
               nir_def_rewrite_uses(&intrinsic->def, new_dest);

            nir_instr_remove(instr);
            nir_instr_free(instr);

            progress = true;
         }
      }

      nir_metadata_preserve(function->impl, nir_metadata_none);
   }

   ralloc_free(query_ht);

   return progress;
}