/*
 * Copyright © 2024 Valve Corporation
 * SPDX-License-Identifier: MIT
 */

#include "tu_shader.h"

#include "bvh/tu_build_interface.h"

#include "compiler/spirv/spirv.h"

#include "nir_builder.h"
#include "nir_deref.h"

enum rq_intersection_var_index {
   rq_intersection_primitive_id,
   rq_intersection_geometry_id,
   rq_intersection_origin,
   rq_intersection_direction,
   rq_intersection_instance,
   rq_intersection_type_flags,
   rq_intersection_sbt_offset,
   rq_intersection_barycentrics,
   rq_intersection_t,
};

static const glsl_type *
get_rq_intersection_type(void)
{
   struct glsl_struct_field fields[] = {
#define FIELD(_type, _name) \
      [rq_intersection_##_name] = glsl_struct_field(_type, #_name),
      FIELD(glsl_uint_type(), primitive_id)
      FIELD(glsl_uint_type(), geometry_id)
      FIELD(glsl_vec_type(3), origin)
      FIELD(glsl_vec_type(3), direction)
      FIELD(glsl_uint_type(), instance)
      FIELD(glsl_uint_type(), type_flags)
      FIELD(glsl_uint_type(), sbt_offset)
      FIELD(glsl_vec2_type(), barycentrics)
      FIELD(glsl_float_type(), t)
#undef FIELD
   };

   return glsl_struct_type(fields, ARRAY_SIZE(fields), "ray_query_intersection", false);
}

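/* Per-query traversal state, kept in a local (or shader-temp) variable.
 * 64-bit GPU addresses are stored as uvec2 (lo, hi) pairs and converted
 * with nir_pack_64_2x32/nir_unpack_64_2x32 where a real pointer is needed.
 */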
enum rq_var_index {
   rq_index_accel_struct_base,
   rq_index_root_bvh_base,
   rq_index_bvh_base,
   rq_index_flags,
   rq_index_tmin,
   rq_index_world_origin,
   rq_index_world_direction,
   rq_index_incomplete,
   rq_index_closest,
   rq_index_candidate,
   rq_index_stack_ptr,
   rq_index_top_stack,
   rq_index_stack_low_watermark,
   rq_index_current_node,
   rq_index_previous_node,
   rq_index_instance_top_node,
   rq_index_instance_bottom_node,
   rq_index_stack,
};

/* Driver-internal flag to indicate that we haven't found an intersection */
#define TU_INTERSECTION_TYPE_NO_INTERSECTION (1u << 0)

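/* The traversal stack is a fixed-size ring buffer in the ray query state:
 * pushes past MAX_STACK_DEPTH overwrite the oldest entries and raise
 * stack_low_watermark. Whenever the stack pointer falls back below the
 * watermark, the missing entries are reconstructed by walking parent
 * pointers (see fetch_parent_node() and the underflow handling in
 * build_ray_traversal()).
 */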
#define MAX_STACK_DEPTH 8

static const glsl_type *
get_rq_type(void)
{
   const glsl_type *intersection_type = get_rq_intersection_type();

   struct glsl_struct_field fields[] = {
#define FIELD(_type, _name) \
      [rq_index_##_name] = glsl_struct_field(_type, #_name),
      FIELD(glsl_uvec2_type(), accel_struct_base)
      FIELD(glsl_uvec2_type(), root_bvh_base)
      FIELD(glsl_uvec2_type(), bvh_base)
      FIELD(glsl_uint_type(), flags)
      FIELD(glsl_float_type(), tmin)
      FIELD(glsl_vec_type(3), world_origin)
      FIELD(glsl_vec_type(3), world_direction)
      FIELD(glsl_bool_type(), incomplete)
      FIELD(intersection_type, closest)
      FIELD(intersection_type, candidate)
      FIELD(glsl_uint_type(), stack_ptr)
      FIELD(glsl_uint_type(), top_stack)
      FIELD(glsl_uint_type(), stack_low_watermark)
      FIELD(glsl_uint_type(), current_node)
      FIELD(glsl_uint_type(), previous_node)
      FIELD(glsl_uint_type(), instance_top_node)
      FIELD(glsl_uint_type(), instance_bottom_node)
      FIELD(glsl_array_type(glsl_uvec2_type(), MAX_STACK_DEPTH, 0), stack)
#undef FIELD
   };

   return glsl_struct_type(fields, ARRAY_SIZE(fields), "ray_query", false);
}

struct rq_var {
   nir_variable *rq;

   nir_intrinsic_instr *initialization;
   nir_def *uav_index;
};

static void
lower_ray_query(nir_shader *shader, nir_function_impl *impl,
                nir_variable *ray_query, struct hash_table *ht)
{
   struct rq_var *var = rzalloc(ht, struct rq_var);
   const glsl_type *type = ray_query->type;

   const glsl_type *rq_type = glsl_type_wrap_in_arrays(get_rq_type(), type);

   if (impl)
      var->rq = nir_local_variable_create(impl, rq_type, "ray_query");
   else
      var->rq = nir_variable_create(shader, nir_var_shader_temp, rq_type, "ray_query");

   _mesa_hash_table_insert(ht, ray_query, var);
}

static nir_deref_instr *
get_rq_deref(nir_builder *b, struct hash_table *ht, nir_def *def,
             struct rq_var **rq_var_out)
{
   nir_deref_instr *deref = nir_instr_as_deref(def->parent_instr);

   nir_deref_path path;
   nir_deref_path_init(&path, deref, NULL);
   assert(path.path[0]->deref_type == nir_deref_type_var);

   nir_variable *opaque_var = nir_deref_instr_get_variable(path.path[0]);
   struct hash_entry *entry = _mesa_hash_table_search(ht, opaque_var);
   assert(entry);

   struct rq_var *rq = (struct rq_var *)entry->data;

   nir_deref_instr *out_deref = nir_build_deref_var(b, rq->rq);

   if (glsl_type_is_array(opaque_var->type)) {
      nir_deref_instr **p = &path.path[1];
      for (; *p; p++) {
         if ((*p)->deref_type == nir_deref_type_array) {
            nir_def *index = (*p)->arr.index.ssa;

            out_deref = nir_build_deref_array(b, out_deref, index);
         } else {
            unreachable("Unsupported deref type");
         }
      }
   }

   nir_deref_path_finish(&path);

   if (rq_var_out)
      *rq_var_out = rq;

   return out_deref;
}

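/* If the acceleration structure source of rq_initialize comes directly from
 * a descriptor load, return that descriptor so TLAS accesses can go through
 * the UAV path; otherwise return NULL and accesses fall back to global
 * loads through the pointer.
 */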
static nir_def *
get_rq_initialize_uav_index(nir_intrinsic_instr *intr, struct rq_var *var)
{
   if (intr->src[1].ssa->parent_instr->type == nir_instr_type_intrinsic &&
       nir_instr_as_intrinsic(intr->src[1].ssa->parent_instr)->intrinsic ==
       nir_intrinsic_load_vulkan_descriptor) {
      return intr->src[1].ssa;
   } else {
      return NULL;
   }
}

/* Before we modify control flow, walk the shader and determine ray query
 * instructions for which we know the ray query has been initialized via a
 * descriptor instead of a pointer, and record the UAV descriptor.
 */
static void
calc_uav_index(nir_function_impl *impl, struct hash_table *ht)
{
   nir_metadata_require(impl, nir_metadata_dominance);

   nir_foreach_block (block, impl) {
      nir_foreach_instr (instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

         nir_def *rq_def;
         switch (intr->intrinsic) {
         case nir_intrinsic_rq_initialize:
         case nir_intrinsic_rq_load:
         case nir_intrinsic_rq_proceed:
            rq_def = intr->src[0].ssa;
            break;
         default:
            continue;
         }

         nir_deref_instr *deref = nir_instr_as_deref(rq_def->parent_instr);

         if (deref->deref_type != nir_deref_type_var)
            continue;

         nir_variable *opaque_var = deref->var;
         struct hash_entry *entry = _mesa_hash_table_search(ht, opaque_var);
         assert(entry);

         struct rq_var *rq = (struct rq_var *)entry->data;

         if (intr->intrinsic == nir_intrinsic_rq_initialize) {
            rq->initialization = intr;
            rq->uav_index = get_rq_initialize_uav_index(intr, rq);
         } else {
            if (rq->initialization &&
                nir_block_dominates(rq->initialization->instr.block,
                                    block) && rq->uav_index) {
               _mesa_hash_table_insert(ht, instr, rq->uav_index);
            }
         }
      }
   }
}

/* Return the TLAS descriptor, which is actually a UAV descriptor, if we
 * know that the ray query has been initialized via a descriptor and not a
 * pointer. If not known, return NULL.
 */
static nir_def *
get_uav_index(nir_instr *cur_instr, struct hash_table *ht)
{
   struct hash_entry *entry = _mesa_hash_table_search(ht, cur_instr);
   if (entry)
      return (nir_def *)entry->data;
   return NULL;
}

/* Load some data from the TLAS header or instance descriptors. This uses the
 * UAV descriptor when available, via "uav_index" which should be obtained
 * from get_uav_index().
 */
static nir_def *
load_tlas(nir_builder *b, nir_def *tlas,
          nir_def *uav_index, nir_def *index,
          unsigned offset, unsigned components)
{
   if (uav_index) {
      return nir_load_uav_ir3(b, components, 32, uav_index,
                              nir_vec2(b, index, nir_imm_int(b, offset / 4)),
                              .access = (gl_access_qualifier)(
                                          ACCESS_NON_WRITEABLE |
                                          ACCESS_CAN_REORDER),
                              .align_mul = AS_RECORD_SIZE,
                              .align_offset = offset);
   } else {
      return nir_load_global_ir3(b, components, 32,
                                 tlas,
                                 nir_iadd_imm(b, nir_imul_imm(b, index, AS_RECORD_SIZE / 4),
                                              offset / 4),
                                 /* The required alignment of the
                                  * user-specified base from the Vulkan spec.
                                  */
                                 .align_mul = 256,
                                 .align_offset = 0);
   }
}

/* The first record is the TLAS header and the rest of the records are
 * instances, so we need to add 1 to the instance ID when reading data in an
 * instance.
 */
#define load_instance_offset(b, tlas, uav_index, instance, \
                             field, offset, components) \
   load_tlas(b, tlas, uav_index, nir_iadd_imm(b, instance, 1), \
             offsetof(struct tu_instance_descriptor, field) + offset, \
             components)

#define load_instance(b, tlas, uav_index, instance, field, components) \
   load_instance_offset(b, tlas, uav_index, instance, field, 0, components)

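/* Convenience accessors for fields of the lowered ray query struct (rq_*)
 * and of its closest/candidate intersection sub-structs (rqi_*).
 */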
#define rq_deref(b, _rq, name) nir_build_deref_struct(b, _rq, rq_index_##name)
#define rq_load(b, _rq, name) nir_load_deref(b, rq_deref(b, _rq, name))
#define rq_store(b, _rq, name, val, wrmask) \
   nir_store_deref(b, rq_deref(b, _rq, name), val, wrmask)
#define rqi_deref(b, _rq, name) nir_build_deref_struct(b, _rq, rq_intersection_##name)
#define rqi_load(b, _rq, name) nir_load_deref(b, rqi_deref(b, _rq, name))
#define rqi_store(b, _rq, name, val, wrmask) \
   nir_store_deref(b, rqi_deref(b, _rq, name), val, wrmask)

static void
lower_rq_initialize(nir_builder *b, struct hash_table *ht,
                    nir_intrinsic_instr *intr)
{
   struct rq_var *var;
   nir_deref_instr *rq = get_rq_deref(b, ht, intr->src[0].ssa, &var);

   if (nir_instr_as_deref(intr->src[0].ssa->parent_instr)->deref_type ==
       nir_deref_type_var) {
      var->initialization = intr;
   } else {
      var->initialization = NULL;
   }

   nir_def *uav_index = get_rq_initialize_uav_index(intr, var);

   nir_def *tlas = intr->src[1].ssa;
   nir_def *flags = intr->src[2].ssa;
   nir_def *cull_mask = intr->src[3].ssa;
   nir_def *origin = intr->src[4].ssa;
   nir_def *tmin = intr->src[5].ssa;
   nir_def *direction = intr->src[6].ssa;
   nir_def *tmax = intr->src[7].ssa;

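   /* If the query was initialized through a descriptor, we don't have the
    * TLAS address on hand, so load it from the self_ptr field of the header.
    */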
   nir_def *tlas_base;
   if (uav_index) {
      tlas_base = load_tlas(b, NULL, uav_index, nir_imm_int(b, 0),
                            offsetof(struct tu_accel_struct_header,
                                     self_ptr), 2);
   } else {
      tlas_base = nir_unpack_64_2x32(b, tlas);
   }

   rq_store(b, rq, accel_struct_base, tlas_base, 0x3);

   nir_def *root_bvh_base = load_tlas(b, tlas_base, uav_index, nir_imm_int(b, 0),
                                      offsetof(struct tu_accel_struct_header,
                                               bvh_ptr), 2);

   nir_deref_instr *closest = rq_deref(b, rq, closest);
   nir_deref_instr *candidate = rq_deref(b, rq, candidate);

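   /* The flags word packs, from low to high: a 4-bit mask of leaf sub-nodes
    * left to process (initialized to 0b1111 and maintained during
    * traversal), the SPIR-V ray flags at bit 4, the 8-bit cull mask at bit
    * 16, and the instance flags at bit 24 (filled in when entering an
    * instance).
    */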
   rq_store(b, rq, flags,
            /* Fill out initial fourth src of ray_intersection */
            nir_ior_imm(b,
                        nir_ior(b, nir_ishl_imm(b, flags, 4),
                                nir_ishl_imm(b, cull_mask, 16)),
                        0b1111), 0x1);

   rqi_store(b, candidate, origin, origin, 0x7);
   rqi_store(b, candidate, direction, direction, 0x7);

   rq_store(b, rq, tmin, tmin, 0x1);
   rq_store(b, rq, world_origin, origin, 0x7);
   rq_store(b, rq, world_direction, direction, 0x7);

   rqi_store(b, closest, t, tmax, 0x1);
   rqi_store(b, closest, type_flags, nir_imm_int(b, TU_INTERSECTION_TYPE_NO_INTERSECTION), 0x1);

   /* Make sure that instance data loads don't hang in case of a miss by setting a valid initial instance. */
   rqi_store(b, closest, instance, nir_imm_int(b, 0), 0x1);
   rqi_store(b, candidate, instance, nir_imm_int(b, 0), 0x1);

   rq_store(b, rq, root_bvh_base, root_bvh_base, 0x3);
   rq_store(b, rq, bvh_base, root_bvh_base, 0x3);

   rq_store(b, rq, stack_ptr, nir_imm_int(b, 0), 0x1);
   rq_store(b, rq, top_stack, nir_imm_int(b, -1), 0x1);
   rq_store(b, rq, stack_low_watermark, nir_imm_int(b, 0), 0x1);
   rq_store(b, rq, current_node, nir_imm_int(b, 0), 0x1);
   rq_store(b, rq, previous_node, nir_imm_int(b, VK_BVH_INVALID_NODE), 0x1);
   rq_store(b, rq, instance_top_node, nir_imm_int(b, VK_BVH_INVALID_NODE), 0x1);
   rq_store(b, rq, instance_bottom_node, nir_imm_int(b, VK_BVH_INVALID_NODE), 0x1);

   rq_store(b, rq, incomplete, nir_imm_true(b), 0x1);
}

static void
insert_terminate_on_first_hit(nir_builder *b, nir_deref_instr *rq)
{
   nir_def *terminate_on_first_hit =
      nir_test_mask(b, rq_load(b, rq, flags),
                    SpvRayFlagsTerminateOnFirstHitKHRMask << 4);
   nir_push_if(b, terminate_on_first_hit);
   {
      rq_store(b, rq, incomplete, nir_imm_false(b), 0x1);
   }
   nir_pop_if(b, NULL);
}

static void
lower_rq_confirm_intersection(nir_builder *b, struct hash_table *ht, nir_intrinsic_instr *intr)
{
   nir_deref_instr *rq = get_rq_deref(b, ht, intr->src[0].ssa, NULL);
   nir_copy_deref(b, rq_deref(b, rq, closest), rq_deref(b, rq, candidate));
   insert_terminate_on_first_hit(b, rq);
}

static void
lower_rq_generate_intersection(nir_builder *b, struct hash_table *ht, nir_intrinsic_instr *intr)
{
   nir_deref_instr *rq = get_rq_deref(b, ht, intr->src[0].ssa, NULL);
   nir_deref_instr *closest = rq_deref(b, rq, closest);
   nir_deref_instr *candidate = rq_deref(b, rq, candidate);

   nir_push_if(b, nir_iand(b, nir_fge(b, rqi_load(b, closest, t),
                                      intr->src[1].ssa),
                           nir_fge(b, intr->src[1].ssa,
                                   rq_load(b, rq, tmin))));
   {
      nir_copy_deref(b, closest, candidate);
      insert_terminate_on_first_hit(b, rq);
      rqi_store(b, closest, t, intr->src[1].ssa, 0x1);
   }
   nir_pop_if(b, NULL);
}

static void
lower_rq_terminate(nir_builder *b, struct hash_table *ht, nir_intrinsic_instr *intr)
{
   nir_deref_instr *rq = get_rq_deref(b, ht, intr->src[0].ssa, NULL);
   rq_store(b, rq, incomplete, nir_imm_false(b), 0x1);
}

static nir_def *
lower_rq_load(nir_builder *b, struct hash_table *ht, nir_intrinsic_instr *intr)
{
   struct rq_var *var;
   nir_deref_instr *rq = get_rq_deref(b, ht, intr->src[0].ssa, &var);
   nir_def *uav_index = get_uav_index(&intr->instr, ht);
   nir_def *tlas = rq_load(b, rq, accel_struct_base);
   nir_deref_instr *closest = rq_deref(b, rq, closest);
   nir_deref_instr *candidate = rq_deref(b, rq, candidate);
   bool committed = nir_intrinsic_committed(intr);
   nir_deref_instr *intersection = committed ? closest : candidate;

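   /* Matrix queries are lowered to three per-row loads; "column" selects
    * which column of the row-major 3x4 transform to assemble into a vec3.
    */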
   uint32_t column = nir_intrinsic_column(intr);

   nir_ray_query_value value = nir_intrinsic_ray_query_value(intr);
   switch (value) {
   case nir_ray_query_value_flags: {
      nir_def *flags = rq_load(b, rq, flags);
      return nir_ubitfield_extract(b, flags, nir_imm_int(b, 4),
                                   nir_imm_int(b, 12));
   }
   case nir_ray_query_value_intersection_barycentrics:
      return rqi_load(b, intersection, barycentrics);
   case nir_ray_query_value_intersection_candidate_aabb_opaque:
      return nir_ieq_imm(b, nir_iand_imm(b, rqi_load(b, candidate, type_flags),
                                         TU_INTERSECTION_TYPE_AABB |
                                         TU_INTERSECTION_TYPE_NONOPAQUE |
                                         TU_INTERSECTION_TYPE_NO_INTERSECTION),
                         TU_INTERSECTION_TYPE_AABB);
   case nir_ray_query_value_intersection_front_face:
      return nir_inot(b, nir_test_mask(b, rqi_load(b, intersection, type_flags),
                                       TU_INTERSECTION_BACK_FACE));
   case nir_ray_query_value_intersection_geometry_index:
      return rqi_load(b, intersection, geometry_id);
   case nir_ray_query_value_intersection_instance_custom_index: {
      nir_def *instance = rqi_load(b, intersection, instance);
      return load_instance(b, tlas, uav_index, instance, custom_instance_index, 1);
   }
   case nir_ray_query_value_intersection_instance_id:
      return rqi_load(b, intersection, instance);
   case nir_ray_query_value_intersection_instance_sbt_index:
      return rqi_load(b, intersection, sbt_offset);
   case nir_ray_query_value_intersection_object_ray_direction:
      return rqi_load(b, intersection, direction);
   case nir_ray_query_value_intersection_object_ray_origin:
      return rqi_load(b, intersection, origin);
   case nir_ray_query_value_intersection_object_to_world: {
      nir_def *instance = rqi_load(b, intersection, instance);
      nir_def *rows[3];
      for (unsigned r = 0; r < 3; ++r)
         rows[r] = load_instance_offset(b, tlas, uav_index, instance,
                                        otw_matrix.values,
                                        r * 16, 4);

      return nir_vec3(b, nir_channel(b, rows[0], column), nir_channel(b, rows[1], column),
                      nir_channel(b, rows[2], column));
   }
   case nir_ray_query_value_intersection_primitive_index:
      return rqi_load(b, intersection, primitive_id);
   case nir_ray_query_value_intersection_t:
      return rqi_load(b, intersection, t);
   case nir_ray_query_value_intersection_type: {
      nir_def *intersection_type =
         nir_iand_imm(b, nir_ishr_imm(b, rqi_load(b, intersection, type_flags),
                                      util_logbase2(TU_INTERSECTION_TYPE_AABB)), 1);
      if (committed) {
         nir_def *has_intersection =
            nir_inot(b,
                     nir_test_mask(b, rqi_load(b, intersection, type_flags),
                                   TU_INTERSECTION_TYPE_NO_INTERSECTION));
         intersection_type = nir_iadd(b, intersection_type,
                                      nir_b2i32(b, has_intersection));
      }
      return intersection_type;
   }
   case nir_ray_query_value_intersection_world_to_object: {
      nir_def *instance = rqi_load(b, intersection, instance);
      nir_def *rows[3];
      for (unsigned r = 0; r < 3; ++r)
         rows[r] = load_instance_offset(b, tlas, uav_index, instance,
                                        wto_matrix.values, r * 16, 4);

      return nir_vec3(b, nir_channel(b, rows[0], column), nir_channel(b, rows[1], column),
                      nir_channel(b, rows[2], column));
   }
   case nir_ray_query_value_tmin:
      return rq_load(b, rq, tmin);
   case nir_ray_query_value_world_ray_direction:
      return rq_load(b, rq, world_direction);
   case nir_ray_query_value_world_ray_origin:
      return rq_load(b, rq, world_origin);
   default:
      unreachable("Invalid nir_ray_query_value!");
   }

   return NULL;
}

/* For the initialization of instance_bottom_node. Explicitly different than
 * VK_BVH_INVALID_NODE or any real node, to ensure we never exit an instance
 * when we're not in one.
 */
#define TU_BVH_NO_INSTANCE_ROOT 0xfffffffeu

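/* Multiply a vec3 by a 3x4 row-major matrix given as three vec4 rows. When
 * "translation" is set the fourth column is added in, i.e. the vector is
 * transformed as a point rather than a direction.
 */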
nir_def *
nir_build_vec3_mat_mult(nir_builder *b, nir_def *vec, nir_def *matrix[], bool translation)
{
   nir_def *result_components[3] = {
      nir_channel(b, matrix[0], 3),
      nir_channel(b, matrix[1], 3),
      nir_channel(b, matrix[2], 3),
   };
   for (unsigned i = 0; i < 3; ++i) {
      for (unsigned j = 0; j < 3; ++j) {
         nir_def *v = nir_fmul(b, nir_channels(b, vec, 1 << j), nir_channels(b, matrix[i], 1 << j));
         result_components[i] = (translation || j) ? nir_fadd(b, result_components[i], v) : v;
      }
   }
   return nir_vec(b, result_components, 3);
}

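/* Parent links are stored as an array growing downwards from the BVH base:
 * the link for node i lives at base - 4 * (i + 1). They let traversal
 * resume after ring-buffer stack entries have been overwritten.
 */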
static nir_def *
fetch_parent_node(nir_builder *b, nir_def *bvh, nir_def *node)
{
   nir_def *offset = nir_iadd_imm(b, nir_imul_imm(b, node, 4), 4);

   return nir_build_load_global(b, 1, 32, nir_isub(b, nir_pack_64_2x32(b, bvh),
                                                   nir_u2u64(b, offset)), .align_mul = 4);
}

static nir_def *
build_ray_traversal(nir_builder *b, nir_deref_instr *rq,
                    nir_def *tlas, nir_def *uav_index)
{
   nir_deref_instr *candidate = rq_deref(b, rq, candidate);
   nir_deref_instr *closest = rq_deref(b, rq, closest);

   nir_variable *incomplete = nir_local_variable_create(b->impl, glsl_bool_type(), "incomplete");
   nir_store_var(b, incomplete, nir_imm_true(b), 0x1);

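   /* Main traversal loop. Each iteration either pops/backtracks to find the
    * next node (when current_node is VK_BVH_INVALID_NODE), or intersects
    * current_node with the HW ray_intersection instruction and processes
    * the result: descend into an instance, record a leaf hit, or push
    * internal children.
    */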
   nir_push_loop(b);
   {
      /* Go up the stack if current_node == VK_BVH_INVALID_NODE */
      nir_push_if(b, nir_ieq_imm(b, rq_load(b, rq, current_node), VK_BVH_INVALID_NODE));
      {
         /* Early exit if we never overflowed the stack, to avoid having to backtrack to
          * the root for no reason. */
         nir_push_if(b, nir_ilt_imm(b, rq_load(b, rq, stack_ptr), 1));
         {
            nir_store_var(b, incomplete, nir_imm_false(b), 0x1);
            nir_jump(b, nir_jump_break);
         }
         nir_pop_if(b, NULL);

         nir_def *stack_instance_exit =
            nir_ige(b, rq_load(b, rq, top_stack), rq_load(b, rq, stack_ptr));
         nir_def *root_instance_exit =
            nir_ieq(b, rq_load(b, rq, previous_node), rq_load(b, rq, instance_bottom_node));
         nir_if *instance_exit = nir_push_if(b, nir_ior(b, stack_instance_exit, root_instance_exit));
         instance_exit->control = nir_selection_control_dont_flatten;
         {
            rq_store(b, rq, top_stack, nir_imm_int(b, -1), 1);
            rq_store(b, rq, previous_node, rq_load(b, rq, instance_top_node), 1);
            rq_store(b, rq, instance_bottom_node, nir_imm_int(b, TU_BVH_NO_INSTANCE_ROOT), 1);

            rq_store(b, rq, bvh_base, rq_load(b, rq, root_bvh_base), 3);
            rqi_store(b, candidate, origin, rq_load(b, rq, world_origin), 7);
            rqi_store(b, candidate, direction, rq_load(b, rq, world_direction), 7);
         }
         nir_pop_if(b, NULL);

         nir_push_if(
            b, nir_ige(b, rq_load(b, rq, stack_low_watermark), rq_load(b, rq, stack_ptr)));
         {
            /* Get the parent of the previous node using the parent pointers.
             * We will re-intersect the parent and figure out what index the
             * previous node was below.
             */
            nir_def *prev = rq_load(b, rq, previous_node);
            nir_def *bvh_addr = rq_load(b, rq, bvh_base);

            nir_def *parent = fetch_parent_node(b, bvh_addr, prev);
            nir_push_if(b, nir_ieq_imm(b, parent, VK_BVH_INVALID_NODE));
            {
               nir_store_var(b, incomplete, nir_imm_false(b), 0x1);
               nir_jump(b, nir_jump_break);
            }
            nir_pop_if(b, NULL);
            rq_store(b, rq, current_node, parent, 0x1);
         }
         nir_push_else(b, NULL);
         {
            /* Go up the stack and get the next child of the parent. */
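            /* A stack entry is (children_base, children). The children word
             * appears to pack 3-bit child offsets starting at bit 8, with
             * its low 5 bits acting as a cursor holding the bit position of
             * the next offset to consume; the entry is popped once the
             * cursor reaches the last offset at bit 8.
             */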
            nir_def *stack_ptr = nir_iadd_imm(b, rq_load(b, rq, stack_ptr), -1);

            nir_def *stack_idx =
               nir_umod_imm(b, stack_ptr, MAX_STACK_DEPTH);
            nir_deref_instr *stack_deref =
               nir_build_deref_array(b, rq_deref(b, rq, stack), stack_idx);
            nir_def *stack_entry = nir_load_deref(b, stack_deref);
            nir_def *children_base = nir_channel(b, stack_entry, 0);
            nir_def *children = nir_channel(b, stack_entry, 1);

            nir_def *next_child_idx =
               nir_iadd_imm(b, nir_iand_imm(b, children, 0x1f), -3);

            nir_def *child_offset =
               nir_iand_imm(b, nir_ishr(b, children, next_child_idx), 0x7);
            nir_def *bvh_node = nir_iadd(b, children_base, child_offset);

            nir_push_if(b, nir_ieq_imm(b, next_child_idx, 8));
            {
               rq_store(b, rq, stack_ptr, stack_ptr, 1);
            }
            nir_push_else(b, NULL);
            {
               children = nir_bitfield_insert(b, children, next_child_idx,
                                              nir_imm_int(b, 0),
                                              nir_imm_int(b, 5));
               nir_store_deref(b, stack_deref,
                               nir_vec2(b, nir_undef(b, 1, 32), children),
                               0x2);
            }
            nir_pop_if(b, NULL);

            rq_store(b, rq, current_node, bvh_node, 0x1);
            /* We don't need previous_node when we have the stack. Indicate to
             * the internal intersection handling below that this isn't the
             * underflow case.
             */
            rq_store(b, rq, previous_node, nir_imm_int(b, VK_BVH_INVALID_NODE), 0x1);
         }
         nir_pop_if(b, NULL);
      }
      nir_push_else(b, NULL);
      {
         rq_store(b, rq, previous_node, nir_imm_int(b, VK_BVH_INVALID_NODE), 0x1);
      }
      nir_pop_if(b, NULL);

      nir_def *bvh_node = rq_load(b, rq, current_node);
      nir_def *bvh_base = rq_load(b, rq, bvh_base);

      nir_def *prev_node = rq_load(b, rq, previous_node);
      rq_store(b, rq, previous_node, bvh_node, 0x1);
      rq_store(b, rq, current_node, nir_imm_int(b, VK_BVH_INVALID_NODE), 0x1);

      nir_def *origin = rqi_load(b, candidate, origin);
      nir_def *tmin = rq_load(b, rq, tmin);
      nir_def *direction = rqi_load(b, candidate, direction);
      nir_def *tmax = rqi_load(b, closest, t);

      nir_def *intrinsic_result =
         nir_ray_intersection_ir3(b, 32, bvh_base, bvh_node,
                                  nir_vec8(b,
                                           nir_channel(b, origin, 0),
                                           nir_channel(b, origin, 1),
                                           nir_channel(b, origin, 2),
                                           tmin,
                                           nir_channel(b, direction, 0),
                                           nir_channel(b, direction, 1),
                                           nir_channel(b, direction, 2),
                                           tmax),
                                  rq_load(b, rq, flags));

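      /* ray_intersection result layout: channel 0 holds the intersection
       * type flags with the intersection count in bits [7:4], channel 1 the
       * intersected node/leaf id, channel 2 the hit t, and channels 3-4 the
       * triangle barycentrics (channel 3 doubles as the packed children
       * word for internal nodes).
       */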
      nir_def *intersection_flags = nir_channel(b, intrinsic_result, 0);
      nir_def *intersection_count =
         nir_ubitfield_extract_imm(b, intersection_flags, 4, 4);
      nir_def *intersection_id = nir_channel(b, intrinsic_result, 1);

      nir_push_if(b, nir_test_mask(b, intersection_flags,
                                   TU_INTERSECTION_TYPE_LEAF));
      {
         nir_def *processed_mask = nir_iand_imm(b, intersection_flags, 0xf);

         /* Keep processing the current node if the mask isn't yet 0 */
         rq_store(b, rq, current_node,
                  nir_bcsel(b, nir_ieq_imm(b, processed_mask, 0),
                            nir_imm_int(b, VK_BVH_INVALID_NODE),
                            bvh_node), 1);

         /* If the mask is 0, replace with the initial 0xf for the next
          * intersection.
          */
         processed_mask =
            nir_bcsel(b, nir_ieq_imm(b, processed_mask, 0),
                      nir_imm_int(b, 0xf), processed_mask);

         /* Replace the mask in the flags. */
         rq_store(b, rq, flags,
                  nir_bitfield_insert(b, rq_load(b, rq, flags),
                                      processed_mask, nir_imm_int(b, 0),
                                      nir_imm_int(b, 4)), 1);

         nir_push_if(b, nir_ieq_imm(b, intersection_count, 0));
         {
            nir_jump(b, nir_jump_continue);
         }
         nir_pop_if(b, NULL);

         nir_push_if(b, nir_test_mask(b, intersection_flags,
                                      TU_INTERSECTION_TYPE_TLAS));
         {
            /* instance */
            rqi_store(b, candidate, instance, intersection_id, 1);

            nir_def *wto_matrix[3];
            for (unsigned i = 0; i < 3; i++)
               wto_matrix[i] = load_instance_offset(b, tlas, uav_index,
                                                    intersection_id,
                                                    wto_matrix.values,
                                                    i * 16, 4);

            nir_def *sbt_offset_and_flags =
               load_instance(b, tlas, uav_index, intersection_id,
                             sbt_offset_and_flags, 1);
            nir_def *blas_bvh =
               load_instance(b, tlas, uav_index, intersection_id,
                             bvh_ptr, 2);

            nir_def *instance_flags = nir_iand_imm(b, sbt_offset_and_flags,
                                                   0xff000000);
            nir_def *sbt_offset = nir_iand_imm(b, sbt_offset_and_flags,
                                               0x00ffffff);
            nir_def *flags = rq_load(b, rq, flags);
            flags = nir_ior(b, nir_iand_imm(b, flags, 0x00ffffff),
                            instance_flags);
            rq_store(b, rq, flags, flags, 1);

            rqi_store(b, candidate, sbt_offset, sbt_offset, 1);

            rq_store(b, rq, top_stack, rq_load(b, rq, stack_ptr), 1);
            rq_store(b, rq, bvh_base, blas_bvh, 3);

            /* Push the instance root node onto the stack */
            rq_store(b, rq, current_node, nir_imm_int(b, 0), 0x1);
            rq_store(b, rq, instance_bottom_node, nir_imm_int(b, 0), 1);
            rq_store(b, rq, instance_top_node, bvh_node, 1);

            /* Transform the ray into object space */
            rqi_store(b, candidate, origin,
                      nir_build_vec3_mat_mult(b, rq_load(b, rq, world_origin),
                                              wto_matrix, true), 7);
            rqi_store(b, candidate, direction,
                      nir_build_vec3_mat_mult(b, rq_load(b, rq, world_direction),
                                              wto_matrix, false), 7);
         }
         nir_push_else(b, NULL);
         {
            /* AABB & triangle */
            rqi_store(b, candidate, type_flags,
                      nir_iand_imm(b, intersection_flags,
                                   TU_INTERSECTION_TYPE_AABB |
                                   TU_INTERSECTION_TYPE_NONOPAQUE |
                                   TU_INTERSECTION_BACK_FACE), 1);

            rqi_store(b, candidate, primitive_id, intersection_id, 1);

            /* TODO: Implement optimization to try to combine these into 1
             * 32-bit ID, for compressed nodes.
             *
             * load_global_ir3 doesn't have the required range so we have to
             * do the offset math ourselves.
             */
            nir_def *offset =
               nir_ior_imm(b, nir_imul_imm(b, nir_u2u64(b, bvh_node),
                                            sizeof(tu_leaf_node)),
                           offsetof(struct tu_leaf_node, geometry_id));
            nir_def *geometry_id_ptr = nir_iadd(b, nir_pack_64_2x32(b, bvh_base),
                                                offset);
            nir_def *geometry_id =
               nir_build_load_global(b, 1, 32, geometry_id_ptr,
                                     .access = ACCESS_NON_WRITEABLE,
                                     .align_mul = sizeof(struct tu_leaf_node),
                                     .align_offset = offsetof(struct tu_leaf_node,
                                                              geometry_id));
            rqi_store(b, candidate, geometry_id, geometry_id, 1);

            nir_push_if(b, nir_test_mask(b, intersection_flags,
                                         TU_INTERSECTION_TYPE_AABB));
            {
               nir_jump(b, nir_jump_break);
            }
            nir_push_else(b, NULL);
            {
               rqi_store(b, candidate, barycentrics,
                         nir_vec2(b, nir_channel(b, intrinsic_result, 3),
                                  nir_channel(b, intrinsic_result, 4)), 0x3);
               rqi_store(b, candidate, t, nir_channel(b, intrinsic_result,
                                                      2), 0x1);
               nir_push_if(b, nir_test_mask(b, intersection_flags,
                                            TU_INTERSECTION_TYPE_NONOPAQUE));
               {
                  nir_jump(b, nir_jump_break);
               }
               nir_push_else(b, NULL);
               {
                  nir_copy_deref(b, closest, candidate);
                  nir_def *terminate_on_first_hit =
                     nir_test_mask(b, rq_load(b, rq, flags),
                                   SpvRayFlagsTerminateOnFirstHitKHRMask << 4);
                  nir_push_if(b, terminate_on_first_hit);
                  {
                     nir_store_var(b, incomplete, nir_imm_false(b), 0x1);
                     nir_jump(b, nir_jump_break);
                  }
                  nir_pop_if(b, NULL);
               }
               nir_pop_if(b, NULL);
            }
            nir_pop_if(b, NULL);
         }
         nir_pop_if(b, NULL);
      }
      nir_push_else(b, NULL);
      {
         /* internal */
         nir_push_if(b, nir_ine_imm(b, intersection_count, 0));
         {
            nir_def *children = nir_channel(b, intrinsic_result, 3);

            nir_push_if(b, nir_ieq_imm(b, prev_node, VK_BVH_INVALID_NODE));
            {
               /* The children array returned by the HW is specially set up so
                * that we can do this to get the first child.
                */
               nir_def *first_child_offset =
                  nir_iand_imm(b, nir_ishr(b, children, children), 0x7);

               rq_store(b, rq, current_node,
                        nir_iadd(b, intersection_id, first_child_offset),
                        0x1);

               nir_push_if(b, nir_igt_imm(b, intersection_count, 1));
               {
                  nir_def *stack_ptr = rq_load(b, rq, stack_ptr);
                  nir_def *stack_idx = nir_umod_imm(b, stack_ptr,
                                                    MAX_STACK_DEPTH);
                  nir_def *stack_entry =
                     nir_vec2(b, intersection_id, children);
                  nir_store_deref(b,
                                  nir_build_deref_array(b, rq_deref(b, rq, stack),
                                                        stack_idx),
                                  stack_entry, 0x7);
                  rq_store(b, rq, stack_ptr,
                           nir_iadd_imm(b, rq_load(b, rq, stack_ptr), 1), 0x1);

                  nir_def *new_watermark =
                     nir_iadd_imm(b, rq_load(b, rq, stack_ptr),
                                  -MAX_STACK_DEPTH);
                  new_watermark = nir_imax(b, rq_load(b, rq,
                                                      stack_low_watermark),
                                           new_watermark);
                  rq_store(b, rq, stack_low_watermark, new_watermark, 0x1);
               }
               nir_pop_if(b, NULL);
            }
            nir_push_else(b, NULL);
            {
               /* The underflow case. We have the previous_node and an array
                * of intersecting children of its parent, and we need to find
                * its position in the array so that we can return the next
                * child in the array or VK_BVH_INVALID_NODE if it's the last
                * child.
                */
               nir_def *prev_offset =
                  nir_isub(b, prev_node, intersection_id);

               /* A bit-pattern with ones at the LSB of each child's
                * position.
                */
               uint32_t ones = 0b1001001001001001001001 << 8;

               /* Replicate prev_offset into the position of each child. */
               nir_def *prev_offset_repl =
                  nir_imul_imm(b, prev_offset, ones);

               /* a == b <=> a ^ b == 0. Reduce the problem to finding the
                * child whose bits are 0.
                */
               nir_def *diff = nir_ixor(b, prev_offset_repl, children);

               /* This magic formula comes from Hacker's Delight, section 6.1
                * "Find First 0-byte", adapted for 3-bit "bytes". The first
                * zero byte will be the lowest byte with 1 set in the highest
                * position (i.e. bit 2). We need to then subtract 2 to get the
                * current position and 5 to get the next position.
                */
               diff = nir_iand_imm(b, nir_iand(b, nir_iadd_imm(b, diff, -ones),
                                               nir_inot(b, diff)),
                                   ones << 2);
               diff = nir_find_lsb(b, diff);

               nir_def *next_offset =
                  nir_iand_imm(b, nir_ishr(b, children,
                                           nir_iadd_imm(b, diff, -5)),
                               0x7);

               nir_def *next =
                  nir_bcsel(b, nir_ieq_imm(b, diff, 8 + 2),
                            nir_imm_int(b, VK_BVH_INVALID_NODE),
                            nir_iadd(b, next_offset, intersection_id));
               rq_store(b, rq, current_node, next, 0x1);
            }
            nir_pop_if(b, NULL);
         }
         nir_pop_if(b, NULL);
      }
      nir_pop_if(b, NULL);
   }
   nir_pop_loop(b, NULL);

   return nir_load_var(b, incomplete);
}

static nir_def *
lower_rq_proceed(nir_builder *b, struct hash_table *ht, nir_intrinsic_instr *intr)
{
   struct rq_var *var;
   nir_deref_instr *rq = get_rq_deref(b, ht, intr->src[0].ssa, &var);
   nir_def *uav_index = get_uav_index(&intr->instr, ht);
   nir_def *tlas = rq_load(b, rq, accel_struct_base);

   nir_push_if(b, nir_load_deref(b, rq_deref(b, rq, incomplete)));
   {
      nir_def *incomplete = build_ray_traversal(b, rq, tlas, uav_index);
      nir_store_deref(b, rq_deref(b, rq, incomplete), incomplete, 0x1);
   }
   nir_pop_if(b, NULL);

   return nir_load_deref(b, rq_deref(b, rq, incomplete));
}

bool
tu_nir_lower_ray_queries(nir_shader *shader)
{
   bool progress = false;
   struct hash_table *query_ht = _mesa_pointer_hash_table_create(NULL);

   nir_foreach_variable_in_list (var, &shader->variables) {
      if (!var->data.ray_query)
         continue;

      lower_ray_query(shader, NULL, var, query_ht);

      progress = true;
   }

   nir_foreach_function (function, shader) {
      if (!function->impl)
         continue;

      nir_builder builder = nir_builder_create(function->impl);

      nir_foreach_variable_in_list (var, &function->impl->locals) {
         if (!var->data.ray_query)
            continue;

         lower_ray_query(shader, function->impl, var, query_ht);

         progress = true;
      }

      calc_uav_index(function->impl, query_ht);

      nir_foreach_block (block, function->impl) {
         nir_foreach_instr_safe (instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;

            nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);

            if (!nir_intrinsic_is_ray_query(intrinsic->intrinsic))
               continue;

            builder.cursor = nir_before_instr(instr);

            nir_def *new_dest = NULL;

            switch (intrinsic->intrinsic) {
            case nir_intrinsic_rq_confirm_intersection:
               lower_rq_confirm_intersection(&builder, query_ht, intrinsic);
               break;
            case nir_intrinsic_rq_generate_intersection:
               lower_rq_generate_intersection(&builder, query_ht, intrinsic);
               break;
            case nir_intrinsic_rq_initialize:
               lower_rq_initialize(&builder, query_ht, intrinsic);
               break;
            case nir_intrinsic_rq_load:
               new_dest = lower_rq_load(&builder, query_ht, intrinsic);
               break;
            case nir_intrinsic_rq_proceed:
               new_dest = lower_rq_proceed(&builder, query_ht, intrinsic);
               break;
            case nir_intrinsic_rq_terminate:
               lower_rq_terminate(&builder, query_ht, intrinsic);
               break;
            default:
               unreachable("Unsupported ray query intrinsic!");
            }

            if (new_dest)
               nir_def_rewrite_uses(&intrinsic->def, new_dest);

            nir_instr_remove(instr);
            nir_instr_free(instr);

            progress = true;
         }
      }

      nir_metadata_preserve(function->impl, nir_metadata_none);
   }

   ralloc_free(query_ht);

   return progress;
}