/*
 * Copyright © 2022 Konstantin Seurer
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir/nir.h"
#include "nir/nir_builder.h"

#include "util/hash_table.h"

#include "bvh/bvh.h"
#include "nir/radv_nir_rt_common.h"
#include "radv_debug.h"
#include "radv_nir.h"
#include "radv_private.h"
#include "radv_shader.h"

/* Traversal stack size. Traversal supports backtracking, so we can go deeper than this size if
 * needed. However, we keep a large stack size to avoid it being put into registers, which hurts
 * occupancy. */
#define MAX_SCRATCH_STACK_ENTRY_COUNT 76

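/* Wrapper around the nir_variable backing one field of a lowered ray query. If the shader
 * declares an array of ray queries, the backing variable becomes an array of array_length
 * elements and the rq_load/rq_store helpers below index it with the query's array index. */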
typedef struct {
   nir_variable *variable;
   unsigned array_length;
} rq_variable;

static rq_variable *
rq_variable_create(void *ctx, nir_shader *shader, unsigned array_length, const struct glsl_type *type, const char *name)
{
   rq_variable *result = ralloc(ctx, rq_variable);
   result->array_length = array_length;

   const struct glsl_type *variable_type = type;
   if (array_length != 1)
      variable_type = glsl_array_type(type, array_length, glsl_get_explicit_stride(type));

   result->variable = nir_variable_create(shader, nir_var_shader_temp, variable_type, name);

   return result;
}

static nir_def *
nir_load_array(nir_builder *b, nir_variable *array, nir_def *index)
{
   return nir_load_deref(b, nir_build_deref_array(b, nir_build_deref_var(b, array), index));
}

static void
nir_store_array(nir_builder *b, nir_variable *array, nir_def *index, nir_def *value, unsigned writemask)
{
   nir_store_deref(b, nir_build_deref_array(b, nir_build_deref_var(b, array), index), value, writemask);
}

static nir_deref_instr *
rq_deref_var(nir_builder *b, nir_def *index, rq_variable *var)
{
   if (var->array_length == 1)
      return nir_build_deref_var(b, var->variable);

   return nir_build_deref_array(b, nir_build_deref_var(b, var->variable), index);
}

static nir_def *
rq_load_var(nir_builder *b, nir_def *index, rq_variable *var)
{
   if (var->array_length == 1)
      return nir_load_var(b, var->variable);

   return nir_load_array(b, var->variable, index);
}

static void
rq_store_var(nir_builder *b, nir_def *index, rq_variable *var, nir_def *value, unsigned writemask)
{
   if (var->array_length == 1) {
      nir_store_var(b, var->variable, value, writemask);
   } else {
      nir_store_array(b, var->variable, index, value, writemask);
   }
}

static void
rq_copy_var(nir_builder *b, nir_def *index, rq_variable *dst, rq_variable *src, unsigned mask)
{
   rq_store_var(b, index, dst, rq_load_var(b, index, src), mask);
}

static nir_def *
rq_load_array(nir_builder *b, nir_def *index, rq_variable *var, nir_def *array_index)
{
   if (var->array_length == 1)
      return nir_load_array(b, var->variable, array_index);

   return nir_load_deref(
      b, nir_build_deref_array(b, nir_build_deref_array(b, nir_build_deref_var(b, var->variable), index), array_index));
}

static void
rq_store_array(nir_builder *b, nir_def *index, rq_variable *var, nir_def *array_index, nir_def *value,
               unsigned writemask)
{
   if (var->array_length == 1) {
      nir_store_array(b, var->variable, array_index, value, writemask);
   } else {
      nir_store_deref(
         b,
         nir_build_deref_array(b, nir_build_deref_array(b, nir_build_deref_var(b, var->variable), index), array_index),
         value, writemask);
   }
}

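/* Lowered state of a rayQueryEXT object: the ray and flags it was initialized with, the candidate
 * and committed (closest) intersections, and the BVH traversal state needed to resume traversal
 * on the next rayQueryProceedEXT() call. */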
struct ray_query_traversal_vars {
   rq_variable *origin;
   rq_variable *direction;

   rq_variable *bvh_base;
   rq_variable *stack;
   rq_variable *top_stack;
   rq_variable *stack_low_watermark;
   rq_variable *current_node;
   rq_variable *previous_node;
   rq_variable *instance_top_node;
   rq_variable *instance_bottom_node;
};

struct ray_query_intersection_vars {
   rq_variable *primitive_id;
   rq_variable *geometry_id_and_flags;
   rq_variable *instance_addr;
   rq_variable *intersection_type;
   rq_variable *opaque;
   rq_variable *frontface;
   rq_variable *sbt_offset_and_flags;
   rq_variable *barycentrics;
   rq_variable *t;
};

struct ray_query_vars {
   rq_variable *root_bvh_base;
   rq_variable *flags;
   rq_variable *cull_mask;
   rq_variable *origin;
   rq_variable *tmin;
   rq_variable *direction;

   rq_variable *incomplete;

   struct ray_query_intersection_vars closest;
   struct ray_query_intersection_vars candidate;

   struct ray_query_traversal_vars trav;

   rq_variable *stack;
   uint32_t shared_base;
   uint32_t stack_entries;

   nir_intrinsic_instr *initialize;
};

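/* Concatenates base_name and a suffix into a fresh ralloc'd string so every lowered variable
 * carries the name of the ray query it belongs to. Relies on "ctx" and "base_name" being in
 * scope at the expansion site. */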
#define VAR_NAME(name) strcat(strcpy(ralloc_size(ctx, strlen(base_name) + strlen(name) + 1), base_name), name)

static struct ray_query_traversal_vars
init_ray_query_traversal_vars(void *ctx, nir_shader *shader, unsigned array_length, const char *base_name)
{
   struct ray_query_traversal_vars result;

   const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);

   result.origin = rq_variable_create(ctx, shader, array_length, vec3_type, VAR_NAME("_origin"));
   result.direction = rq_variable_create(ctx, shader, array_length, vec3_type, VAR_NAME("_direction"));

   result.bvh_base = rq_variable_create(ctx, shader, array_length, glsl_uint64_t_type(), VAR_NAME("_bvh_base"));
   result.stack = rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_stack"));
   result.top_stack = rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_top_stack"));
   result.stack_low_watermark =
      rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_stack_low_watermark"));
   result.current_node = rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_current_node"));
   result.previous_node = rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_previous_node"));
   result.instance_top_node =
      rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_instance_top_node"));
   result.instance_bottom_node =
      rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_instance_bottom_node"));
   return result;
}

static struct ray_query_intersection_vars
init_ray_query_intersection_vars(void *ctx, nir_shader *shader, unsigned array_length, const char *base_name)
{
   struct ray_query_intersection_vars result;

   const struct glsl_type *vec2_type = glsl_vector_type(GLSL_TYPE_FLOAT, 2);

   result.primitive_id = rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_primitive_id"));
   result.geometry_id_and_flags =
      rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_geometry_id_and_flags"));
   result.instance_addr =
      rq_variable_create(ctx, shader, array_length, glsl_uint64_t_type(), VAR_NAME("_instance_addr"));
   result.intersection_type =
      rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_intersection_type"));
   result.opaque = rq_variable_create(ctx, shader, array_length, glsl_bool_type(), VAR_NAME("_opaque"));
   result.frontface = rq_variable_create(ctx, shader, array_length, glsl_bool_type(), VAR_NAME("_frontface"));
   result.sbt_offset_and_flags =
      rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_sbt_offset_and_flags"));
   result.barycentrics = rq_variable_create(ctx, shader, array_length, vec2_type, VAR_NAME("_barycentrics"));
   result.t = rq_variable_create(ctx, shader, array_length, glsl_float_type(), VAR_NAME("_t"));

   return result;
}

static void
init_ray_query_vars(nir_shader *shader, unsigned array_length, struct ray_query_vars *dst, const char *base_name,
                    uint32_t max_shared_size)
{
   void *ctx = dst;
   const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);

   dst->root_bvh_base = rq_variable_create(dst, shader, array_length, glsl_uint64_t_type(), VAR_NAME("_root_bvh_base"));
   dst->flags = rq_variable_create(dst, shader, array_length, glsl_uint_type(), VAR_NAME("_flags"));
   dst->cull_mask = rq_variable_create(dst, shader, array_length, glsl_uint_type(), VAR_NAME("_cull_mask"));
   dst->origin = rq_variable_create(dst, shader, array_length, vec3_type, VAR_NAME("_origin"));
   dst->tmin = rq_variable_create(dst, shader, array_length, glsl_float_type(), VAR_NAME("_tmin"));
   dst->direction = rq_variable_create(dst, shader, array_length, vec3_type, VAR_NAME("_direction"));

   dst->incomplete = rq_variable_create(dst, shader, array_length, glsl_bool_type(), VAR_NAME("_incomplete"));

   dst->closest = init_ray_query_intersection_vars(dst, shader, array_length, VAR_NAME("_closest"));
   dst->candidate = init_ray_query_intersection_vars(dst, shader, array_length, VAR_NAME("_candidate"));

   dst->trav = init_ray_query_traversal_vars(dst, shader, array_length, VAR_NAME("_top"));

   uint32_t workgroup_size =
      shader->info.workgroup_size[0] * shader->info.workgroup_size[1] * shader->info.workgroup_size[2];
   uint32_t shared_stack_entries = shader->info.ray_queries == 1 ? 16 : 8;
   uint32_t shared_stack_size = workgroup_size * shared_stack_entries * 4;
   uint32_t shared_offset = align(shader->info.shared_size, 4);
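   /* Prefer a small traversal stack in shared memory: only for compute shaders, non-arrayed ray
    * queries, and only if the workgroup's stack still fits next to the shader's existing shared
    * memory. Otherwise, fall back to a per-invocation scratch array. */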
   if (shader->info.stage != MESA_SHADER_COMPUTE || array_length > 1 ||
       shared_offset + shared_stack_size > max_shared_size) {
      dst->stack =
         rq_variable_create(dst, shader, array_length,
                            glsl_array_type(glsl_uint_type(), MAX_SCRATCH_STACK_ENTRY_COUNT, 0), VAR_NAME("_stack"));
      dst->stack_entries = MAX_SCRATCH_STACK_ENTRY_COUNT;
   } else {
      dst->stack = NULL;
      dst->shared_base = shared_offset;
      dst->stack_entries = shared_stack_entries;

      shader->info.shared_size = shared_offset + shared_stack_size;
   }
}

#undef VAR_NAME

static void
lower_ray_query(nir_shader *shader, nir_variable *ray_query, struct hash_table *ht, uint32_t max_shared_size)
{
   struct ray_query_vars *vars = ralloc(ht, struct ray_query_vars);

   unsigned array_length = 1;
   if (glsl_type_is_array(ray_query->type))
      array_length = glsl_get_length(ray_query->type);

   init_ray_query_vars(shader, array_length, vars, ray_query->name == NULL ? "" : ray_query->name, max_shared_size);

   _mesa_hash_table_insert(ht, ray_query, vars);
}

static void
copy_candidate_to_closest(nir_builder *b, nir_def *index, struct ray_query_vars *vars)
{
   rq_copy_var(b, index, vars->closest.barycentrics, vars->candidate.barycentrics, 0x3);
   rq_copy_var(b, index, vars->closest.geometry_id_and_flags, vars->candidate.geometry_id_and_flags, 0x1);
   rq_copy_var(b, index, vars->closest.instance_addr, vars->candidate.instance_addr, 0x1);
   rq_copy_var(b, index, vars->closest.intersection_type, vars->candidate.intersection_type, 0x1);
   rq_copy_var(b, index, vars->closest.opaque, vars->candidate.opaque, 0x1);
   rq_copy_var(b, index, vars->closest.frontface, vars->candidate.frontface, 0x1);
   rq_copy_var(b, index, vars->closest.sbt_offset_and_flags, vars->candidate.sbt_offset_and_flags, 0x1);
   rq_copy_var(b, index, vars->closest.primitive_id, vars->candidate.primitive_id, 0x1);
   rq_copy_var(b, index, vars->closest.t, vars->candidate.t, 0x1);
}

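/* If RayFlagsTerminateOnFirstHitKHR is set (known at lowering time via ray_flags, or tested
 * dynamically on the stored flags), mark the query as complete so the next rayQueryProceedEXT()
 * returns false. break_on_terminate is used when this runs inside the traversal loop. */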
static void
insert_terminate_on_first_hit(nir_builder *b, nir_def *index, struct ray_query_vars *vars,
                              const struct radv_ray_flags *ray_flags, bool break_on_terminate)
{
   nir_def *terminate_on_first_hit;
   if (ray_flags)
      terminate_on_first_hit = ray_flags->terminate_on_first_hit;
   else
      terminate_on_first_hit =
         nir_test_mask(b, rq_load_var(b, index, vars->flags), SpvRayFlagsTerminateOnFirstHitKHRMask);
   nir_push_if(b, terminate_on_first_hit);
   {
      rq_store_var(b, index, vars->incomplete, nir_imm_false(b), 0x1);
      if (break_on_terminate)
         nir_jump(b, nir_jump_break);
   }
   nir_pop_if(b, NULL);
}

static void
lower_rq_confirm_intersection(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr, struct ray_query_vars *vars)
{
   copy_candidate_to_closest(b, index, vars);
   insert_terminate_on_first_hit(b, index, vars, NULL, false);
}

static void
lower_rq_generate_intersection(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr, struct ray_query_vars *vars)
{
   nir_push_if(b, nir_iand(b, nir_fge(b, rq_load_var(b, index, vars->closest.t), instr->src[1].ssa),
                           nir_fge(b, instr->src[1].ssa, rq_load_var(b, index, vars->tmin))));
   {
      copy_candidate_to_closest(b, index, vars);
      insert_terminate_on_first_hit(b, index, vars, NULL, false);
      rq_store_var(b, index, vars->closest.t, instr->src[1].ssa, 0x1);
   }
   nir_pop_if(b, NULL);
}

enum rq_intersection_type { intersection_type_none, intersection_type_triangle, intersection_type_aabb };

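/* rayQueryInitializeEXT: store the ray and its flags, load the BVH offset from the acceleration
 * structure header to compute the root BVH base address, and reset the traversal state (stack,
 * current/previous node, instance state) so traversal starts at the root node. */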
static void
lower_rq_initialize(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr, struct ray_query_vars *vars,
                    struct radv_instance *instance)
{
   rq_store_var(b, index, vars->flags, instr->src[2].ssa, 0x1);
   rq_store_var(b, index, vars->cull_mask, nir_ishl_imm(b, instr->src[3].ssa, 24), 0x1);

   rq_store_var(b, index, vars->origin, instr->src[4].ssa, 0x7);
   rq_store_var(b, index, vars->trav.origin, instr->src[4].ssa, 0x7);

   rq_store_var(b, index, vars->tmin, instr->src[5].ssa, 0x1);

   rq_store_var(b, index, vars->direction, instr->src[6].ssa, 0x7);
   rq_store_var(b, index, vars->trav.direction, instr->src[6].ssa, 0x7);

   rq_store_var(b, index, vars->closest.t, instr->src[7].ssa, 0x1);
   rq_store_var(b, index, vars->closest.intersection_type, nir_imm_int(b, intersection_type_none), 0x1);

   nir_def *accel_struct = instr->src[1].ssa;

   /* Make sure that instance data loads don't hang in case of a miss by setting a valid initial address. */
   rq_store_var(b, index, vars->closest.instance_addr, accel_struct, 1);
   rq_store_var(b, index, vars->candidate.instance_addr, accel_struct, 1);

   nir_def *bvh_offset = nir_build_load_global(
      b, 1, 32, nir_iadd_imm(b, accel_struct, offsetof(struct radv_accel_struct_header, bvh_offset)),
      .access = ACCESS_NON_WRITEABLE);
   nir_def *bvh_base = nir_iadd(b, accel_struct, nir_u2u64(b, bvh_offset));
   bvh_base = build_addr_to_node(b, bvh_base);

   rq_store_var(b, index, vars->root_bvh_base, bvh_base, 0x1);
   rq_store_var(b, index, vars->trav.bvh_base, bvh_base, 1);

   if (vars->stack) {
      rq_store_var(b, index, vars->trav.stack, nir_imm_int(b, 0), 0x1);
      rq_store_var(b, index, vars->trav.stack_low_watermark, nir_imm_int(b, 0), 0x1);
   } else {
      nir_def *base_offset = nir_imul_imm(b, nir_load_local_invocation_index(b), sizeof(uint32_t));
      base_offset = nir_iadd_imm(b, base_offset, vars->shared_base);
      rq_store_var(b, index, vars->trav.stack, base_offset, 0x1);
      rq_store_var(b, index, vars->trav.stack_low_watermark, base_offset, 0x1);
   }

   rq_store_var(b, index, vars->trav.current_node, nir_imm_int(b, RADV_BVH_ROOT_NODE), 0x1);
   rq_store_var(b, index, vars->trav.previous_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
   rq_store_var(b, index, vars->trav.instance_top_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
   rq_store_var(b, index, vars->trav.instance_bottom_node, nir_imm_int(b, RADV_BVH_NO_INSTANCE_ROOT), 0x1);

   rq_store_var(b, index, vars->trav.top_stack, nir_imm_int(b, -1), 1);

   rq_store_var(b, index, vars->incomplete, nir_imm_bool(b, !(instance->debug_flags & RADV_DEBUG_NO_RT)), 0x1);

   vars->initialize = instr;
}

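/* rayQueryGet*EXT: read back query state. "committed" selects between the committed (closest) and
 * candidate intersection; instance-related values are fetched from the instance node in the BVH
 * using the stored instance address. */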
static nir_def *
lower_rq_load(struct radv_device *device, nir_builder *b, nir_def *index, nir_intrinsic_instr *instr,
              struct ray_query_vars *vars)
{
   bool committed = nir_intrinsic_committed(instr);
   struct ray_query_intersection_vars *intersection = committed ? &vars->closest : &vars->candidate;

   uint32_t column = nir_intrinsic_column(instr);

   nir_ray_query_value value = nir_intrinsic_ray_query_value(instr);
   switch (value) {
   case nir_ray_query_value_flags:
      return rq_load_var(b, index, vars->flags);
   case nir_ray_query_value_intersection_barycentrics:
      return rq_load_var(b, index, intersection->barycentrics);
   case nir_ray_query_value_intersection_candidate_aabb_opaque:
      return nir_iand(b, rq_load_var(b, index, vars->candidate.opaque),
                      nir_ieq_imm(b, rq_load_var(b, index, vars->candidate.intersection_type), intersection_type_aabb));
   case nir_ray_query_value_intersection_front_face:
      return rq_load_var(b, index, intersection->frontface);
   case nir_ray_query_value_intersection_geometry_index:
      return nir_iand_imm(b, rq_load_var(b, index, intersection->geometry_id_and_flags), 0xFFFFFF);
   case nir_ray_query_value_intersection_instance_custom_index: {
      nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
      return nir_iand_imm(
         b,
         nir_build_load_global(
            b, 1, 32,
            nir_iadd_imm(b, instance_node_addr, offsetof(struct radv_bvh_instance_node, custom_instance_and_mask))),
         0xFFFFFF);
   }
   case nir_ray_query_value_intersection_instance_id: {
      nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
      return nir_build_load_global(
         b, 1, 32, nir_iadd_imm(b, instance_node_addr, offsetof(struct radv_bvh_instance_node, instance_id)));
   }
   case nir_ray_query_value_intersection_instance_sbt_index:
      return nir_iand_imm(b, rq_load_var(b, index, intersection->sbt_offset_and_flags), 0xFFFFFF);
   case nir_ray_query_value_intersection_object_ray_direction: {
      nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
      nir_def *wto_matrix[3];
      nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
      return nir_build_vec3_mat_mult(b, rq_load_var(b, index, vars->direction), wto_matrix, false);
   }
   case nir_ray_query_value_intersection_object_ray_origin: {
      nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
      nir_def *wto_matrix[3];
      nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
      return nir_build_vec3_mat_mult(b, rq_load_var(b, index, vars->origin), wto_matrix, true);
   }
   case nir_ray_query_value_intersection_object_to_world: {
      nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
      nir_def *rows[3];
      for (unsigned r = 0; r < 3; ++r)
         rows[r] = nir_build_load_global(
            b, 4, 32,
            nir_iadd_imm(b, instance_node_addr, offsetof(struct radv_bvh_instance_node, otw_matrix) + r * 16));

      return nir_vec3(b, nir_channel(b, rows[0], column), nir_channel(b, rows[1], column),
                      nir_channel(b, rows[2], column));
   }
   case nir_ray_query_value_intersection_primitive_index:
      return rq_load_var(b, index, intersection->primitive_id);
   case nir_ray_query_value_intersection_t:
      return rq_load_var(b, index, intersection->t);
   case nir_ray_query_value_intersection_type: {
      nir_def *intersection_type = rq_load_var(b, index, intersection->intersection_type);
      if (!committed)
         intersection_type = nir_iadd_imm(b, intersection_type, -1);

      return intersection_type;
   }
   case nir_ray_query_value_intersection_world_to_object: {
      nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);

      nir_def *wto_matrix[3];
      nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);

      nir_def *vals[3];
      for (unsigned i = 0; i < 3; ++i)
         vals[i] = nir_channel(b, wto_matrix[i], column);

      return nir_vec(b, vals, 3);
   }
   case nir_ray_query_value_tmin:
      return rq_load_var(b, index, vars->tmin);
   case nir_ray_query_value_world_ray_direction:
      return rq_load_var(b, index, vars->direction);
   case nir_ray_query_value_world_ray_origin:
      return rq_load_var(b, index, vars->origin);
   case nir_ray_query_value_intersection_triangle_vertex_positions: {
      nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
      nir_def *primitive_id = rq_load_var(b, index, intersection->primitive_id);
      return radv_load_vertex_position(device, b, instance_node_addr, primitive_id, nir_intrinsic_column(instr));
   }
   default:
      unreachable("Invalid nir_ray_query_value!");
   }

   return NULL;
}

struct traversal_data {
   struct ray_query_vars *vars;
   nir_def *index;
};

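/* Traversal callback for AABB leaves: record the candidate and break out of the traversal loop so
 * the shader can inspect it (rayQueryProceedEXT() returns true with a candidate AABB). */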
static void
handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersection,
                      const struct radv_ray_traversal_args *args)
{
   struct traversal_data *data = args->data;
   struct ray_query_vars *vars = data->vars;
   nir_def *index = data->index;

   rq_store_var(b, index, vars->candidate.primitive_id, intersection->primitive_id, 1);
   rq_store_var(b, index, vars->candidate.geometry_id_and_flags, intersection->geometry_id_and_flags, 1);
   rq_store_var(b, index, vars->candidate.opaque, intersection->opaque, 0x1);
   rq_store_var(b, index, vars->candidate.intersection_type, nir_imm_int(b, intersection_type_aabb), 0x1);

   nir_jump(b, nir_jump_break);
}

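/* Traversal callback for triangle hits: record the candidate. Opaque triangles are committed
 * immediately; non-opaque ones break out of the traversal loop so the shader can confirm or
 * reject the candidate. */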
static void
handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *intersection,
                          const struct radv_ray_traversal_args *args, const struct radv_ray_flags *ray_flags)
{
   struct traversal_data *data = args->data;
   struct ray_query_vars *vars = data->vars;
   nir_def *index = data->index;

   rq_store_var(b, index, vars->candidate.barycentrics, intersection->barycentrics, 3);
   rq_store_var(b, index, vars->candidate.primitive_id, intersection->base.primitive_id, 1);
   rq_store_var(b, index, vars->candidate.geometry_id_and_flags, intersection->base.geometry_id_and_flags, 1);
   rq_store_var(b, index, vars->candidate.t, intersection->t, 0x1);
   rq_store_var(b, index, vars->candidate.opaque, intersection->base.opaque, 0x1);
   rq_store_var(b, index, vars->candidate.frontface, intersection->frontface, 0x1);
   rq_store_var(b, index, vars->candidate.intersection_type, nir_imm_int(b, intersection_type_triangle), 0x1);

   nir_push_if(b, intersection->base.opaque);
   {
      copy_candidate_to_closest(b, index, vars);
      insert_terminate_on_first_hit(b, index, vars, ray_flags, true);
   }
   nir_push_else(b, NULL);
   {
      nir_jump(b, nir_jump_break);
   }
   nir_pop_if(b, NULL);
}

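/* Stack callbacks for radv_build_ray_traversal: entries either live in the per-query scratch
 * array or in the shared-memory region reserved in init_ray_query_vars, where the stack is
 * strided by the workgroup size (see args.stack_stride in lower_rq_proceed). */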
static void
store_stack_entry(nir_builder *b, nir_def *index, nir_def *value, const struct radv_ray_traversal_args *args)
{
   struct traversal_data *data = args->data;
   if (data->vars->stack)
      rq_store_array(b, data->index, data->vars->stack, index, value, 1);
   else
      nir_store_shared(b, value, index, .base = 0, .align_mul = 4);
}

static nir_def *
load_stack_entry(nir_builder *b, nir_def *index, const struct radv_ray_traversal_args *args)
{
   struct traversal_data *data = args->data;
   if (data->vars->stack)
      return rq_load_array(b, data->index, data->vars->stack, index);
   else
      return nir_load_shared(b, 1, 32, index, .base = 0, .align_mul = 4);
}

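/* rayQueryProceedEXT: run the shared BVH traversal loop (radv_build_ray_traversal) with the
 * query's saved state and return whether the query is still incomplete. If this proceed call is
 * dominated by its rq_initialize and the cull mask is a constant 0xFF, the mask test is skipped. */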
static nir_def *
lower_rq_proceed(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr, struct ray_query_vars *vars,
                 struct radv_device *device)
{
   nir_metadata_require(nir_cf_node_get_function(&instr->instr.block->cf_node), nir_metadata_dominance);

   bool ignore_cull_mask = false;
   if (nir_block_dominates(vars->initialize->instr.block, instr->instr.block)) {
      nir_src cull_mask = vars->initialize->src[3];
      if (nir_src_is_const(cull_mask) && nir_src_as_uint(cull_mask) == 0xFF)
         ignore_cull_mask = true;
   }

   nir_variable *inv_dir = nir_local_variable_create(b->impl, glsl_vector_type(GLSL_TYPE_FLOAT, 3), "inv_dir");
   nir_store_var(b, inv_dir, nir_frcp(b, rq_load_var(b, index, vars->trav.direction)), 0x7);

   struct radv_ray_traversal_vars trav_vars = {
      .tmax = rq_deref_var(b, index, vars->closest.t),
      .origin = rq_deref_var(b, index, vars->trav.origin),
      .dir = rq_deref_var(b, index, vars->trav.direction),
      .inv_dir = nir_build_deref_var(b, inv_dir),
      .bvh_base = rq_deref_var(b, index, vars->trav.bvh_base),
      .stack = rq_deref_var(b, index, vars->trav.stack),
      .top_stack = rq_deref_var(b, index, vars->trav.top_stack),
      .stack_low_watermark = rq_deref_var(b, index, vars->trav.stack_low_watermark),
      .current_node = rq_deref_var(b, index, vars->trav.current_node),
      .previous_node = rq_deref_var(b, index, vars->trav.previous_node),
      .instance_top_node = rq_deref_var(b, index, vars->trav.instance_top_node),
      .instance_bottom_node = rq_deref_var(b, index, vars->trav.instance_bottom_node),
      .instance_addr = rq_deref_var(b, index, vars->candidate.instance_addr),
      .sbt_offset_and_flags = rq_deref_var(b, index, vars->candidate.sbt_offset_and_flags),
   };

   struct traversal_data data = {
      .vars = vars,
      .index = index,
   };

   struct radv_ray_traversal_args args = {
      .root_bvh_base = rq_load_var(b, index, vars->root_bvh_base),
      .flags = rq_load_var(b, index, vars->flags),
      .cull_mask = rq_load_var(b, index, vars->cull_mask),
      .origin = rq_load_var(b, index, vars->origin),
      .tmin = rq_load_var(b, index, vars->tmin),
      .dir = rq_load_var(b, index, vars->direction),
      .vars = trav_vars,
      .stack_entries = vars->stack_entries,
      .ignore_cull_mask = ignore_cull_mask,
      .stack_store_cb = store_stack_entry,
      .stack_load_cb = load_stack_entry,
      .aabb_cb = handle_candidate_aabb,
      .triangle_cb = handle_candidate_triangle,
      .data = &data,
   };

   if (vars->stack) {
      args.stack_stride = 1;
      args.stack_base = 0;
   } else {
      uint32_t workgroup_size =
         b->shader->info.workgroup_size[0] * b->shader->info.workgroup_size[1] * b->shader->info.workgroup_size[2];
      args.stack_stride = workgroup_size * 4;
      args.stack_base = vars->shared_base;
   }

   nir_push_if(b, rq_load_var(b, index, vars->incomplete));
   {
      nir_def *incomplete = radv_build_ray_traversal(device, b, &args);
      rq_store_var(b, index, vars->incomplete, nir_iand(b, rq_load_var(b, index, vars->incomplete), incomplete), 1);
   }
   nir_pop_if(b, NULL);

   return rq_load_var(b, index, vars->incomplete);
}

static void
lower_rq_terminate(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr, struct ray_query_vars *vars)
{
   rq_store_var(b, index, vars->incomplete, nir_imm_false(b), 0x1);
}

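/* Entry point of the pass: replaces every ray query variable (global and function-local) with a
 * set of plain shader-temp variables and rewrites all rq_* intrinsics into loads/stores of those
 * variables plus an inlined BVH traversal loop. Roughly, a GLSL query of the form
 *
 *    rayQueryInitializeEXT(q, tlas, flags, mask, origin, tmin, dir, tmax);
 *    while (rayQueryProceedEXT(q)) { ... }
 *
 * ends up as a set of per-query temporaries written by lower_rq_initialize() and a
 * radv_build_ray_traversal() loop emitted by lower_rq_proceed(). */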
bool
radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device)
{
   bool progress = false;
   struct hash_table *query_ht = _mesa_pointer_hash_table_create(NULL);

   nir_foreach_variable_in_list (var, &shader->variables) {
      if (!var->data.ray_query)
         continue;

      lower_ray_query(shader, var, query_ht, device->physical_device->max_shared_size);

      progress = true;
   }

   nir_foreach_function (function, shader) {
      if (!function->impl)
         continue;

      nir_builder builder = nir_builder_create(function->impl);

      nir_foreach_variable_in_list (var, &function->impl->locals) {
         if (!var->data.ray_query)
            continue;

         lower_ray_query(shader, var, query_ht, device->physical_device->max_shared_size);

         progress = true;
      }

      nir_foreach_block (block, function->impl) {
         nir_foreach_instr_safe (instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;

            nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);

            if (!nir_intrinsic_is_ray_query(intrinsic->intrinsic))
               continue;

            nir_deref_instr *ray_query_deref = nir_instr_as_deref(intrinsic->src[0].ssa->parent_instr);
            nir_def *index = NULL;

            if (ray_query_deref->deref_type == nir_deref_type_array) {
               index = ray_query_deref->arr.index.ssa;
               ray_query_deref = nir_instr_as_deref(ray_query_deref->parent.ssa->parent_instr);
            }

            assert(ray_query_deref->deref_type == nir_deref_type_var);

            struct ray_query_vars *vars =
               (struct ray_query_vars *)_mesa_hash_table_search(query_ht, ray_query_deref->var)->data;

            builder.cursor = nir_before_instr(instr);

            nir_def *new_dest = NULL;

            switch (intrinsic->intrinsic) {
            case nir_intrinsic_rq_confirm_intersection:
               lower_rq_confirm_intersection(&builder, index, intrinsic, vars);
               break;
            case nir_intrinsic_rq_generate_intersection:
               lower_rq_generate_intersection(&builder, index, intrinsic, vars);
               break;
            case nir_intrinsic_rq_initialize:
               lower_rq_initialize(&builder, index, intrinsic, vars, device->instance);
               break;
            case nir_intrinsic_rq_load:
               new_dest = lower_rq_load(device, &builder, index, intrinsic, vars);
               break;
            case nir_intrinsic_rq_proceed:
               new_dest = lower_rq_proceed(&builder, index, intrinsic, vars, device);
               break;
            case nir_intrinsic_rq_terminate:
               lower_rq_terminate(&builder, index, intrinsic, vars);
               break;
            default:
               unreachable("Unsupported ray query intrinsic!");
            }

            if (new_dest)
               nir_def_rewrite_uses(&intrinsic->def, new_dest);

            nir_instr_remove(instr);
            nir_instr_free(instr);

            progress = true;
         }
      }

      nir_metadata_preserve(function->impl, nir_metadata_none);
   }

   ralloc_free(query_ht);

   return progress;
}