1 /*
2  * Copyright © 2021 Google
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "nir/radv_nir_rt_common.h"
25 #include "bvh/bvh.h"
26 #include "radv_debug.h"
27 
28 #if LLVM_AVAILABLE
29 #include <llvm/Config/llvm-config.h>
30 #endif
31 
32 static nir_def *build_node_to_addr(struct radv_device *device, nir_builder *b, nir_def *node, bool skip_type_and);
33 
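/* Compare-exchange step of the small sorting network used by
 * intersect_ray_amd_software_box: if the distance in chan_2 is smaller than the
 * one in chan_1, swap both the distances and the corresponding child indices.
 * Only the two affected channels are written back (note the write masks); the
 * remaining lanes of the temporary vectors stay undef. */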
34 static void
35 nir_sort_hit_pair(nir_builder *b, nir_variable *var_distances, nir_variable *var_indices, uint32_t chan_1,
36                   uint32_t chan_2)
37 {
38    nir_def *ssa_distances = nir_load_var(b, var_distances);
39    nir_def *ssa_indices = nir_load_var(b, var_indices);
40    /* if (distances[chan_2] < distances[chan_1]) { */
41    nir_push_if(b, nir_flt(b, nir_channel(b, ssa_distances, chan_2), nir_channel(b, ssa_distances, chan_1)));
42    {
43       /* swap(distances[chan_2], distances[chan_1]); */
44       nir_def *new_distances[4] = {nir_undef(b, 1, 32), nir_undef(b, 1, 32), nir_undef(b, 1, 32), nir_undef(b, 1, 32)};
45       nir_def *new_indices[4] = {nir_undef(b, 1, 32), nir_undef(b, 1, 32), nir_undef(b, 1, 32), nir_undef(b, 1, 32)};
46       new_distances[chan_2] = nir_channel(b, ssa_distances, chan_1);
47       new_distances[chan_1] = nir_channel(b, ssa_distances, chan_2);
48       new_indices[chan_2] = nir_channel(b, ssa_indices, chan_1);
49       new_indices[chan_1] = nir_channel(b, ssa_indices, chan_2);
50       nir_store_var(b, var_distances, nir_vec(b, new_distances, 4), (1u << chan_1) | (1u << chan_2));
51       nir_store_var(b, var_indices, nir_vec(b, new_indices, 4), (1u << chan_1) | (1u << chan_2));
52    }
53    /* } */
54    nir_pop_if(b, NULL);
55 }
56 
57 static nir_def *
58 intersect_ray_amd_software_box(struct radv_device *device, nir_builder *b, nir_def *bvh_node, nir_def *ray_tmax,
59                                nir_def *origin, nir_def *dir, nir_def *inv_dir)
60 {
61    const struct glsl_type *vec4_type = glsl_vector_type(GLSL_TYPE_FLOAT, 4);
62    const struct glsl_type *uvec4_type = glsl_vector_type(GLSL_TYPE_UINT, 4);
63 
64    bool old_exact = b->exact;
65    b->exact = true;
66 
67    nir_def *node_addr = build_node_to_addr(device, b, bvh_node, false);
68 
69    /* vec4 distances = vec4(INF, INF, INF, INF); */
70    nir_variable *distances = nir_variable_create(b->shader, nir_var_shader_temp, vec4_type, "distances");
71    nir_store_var(b, distances, nir_imm_vec4(b, INFINITY, INFINITY, INFINITY, INFINITY), 0xf);
72 
73    /* uvec4 child_indices = uvec4(0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff); */
74    nir_variable *child_indices = nir_variable_create(b->shader, nir_var_shader_temp, uvec4_type, "child_indices");
75    nir_store_var(b, child_indices, nir_imm_ivec4(b, 0xffffffffu, 0xffffffffu, 0xffffffffu, 0xffffffffu), 0xf);
76 
77    /* Need to remove infinities here because otherwise we get nasty NaN propagation
78     * if the direction has 0s in it. */
79    /* inv_dir = clamp(inv_dir, -FLT_MAX, FLT_MAX); */
80    inv_dir = nir_fclamp(b, inv_dir, nir_imm_float(b, -FLT_MAX), nir_imm_float(b, FLT_MAX));
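   /* Example of the NaN problem: dir.x == 0 gives inv_dir.x == +/-INF, and if a
    * box happens to have coords.min.x == origin.x the slab test below would
    * compute 0 * INF == NaN, poisoning the min/max comparisons. With inv_dir.x
    * clamped to +/-FLT_MAX the same product is exactly 0, and nonzero
    * differences just become very large (possibly infinite) values that the
    * comparisons handle fine. */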
81 
82    for (int i = 0; i < 4; i++) {
83       const uint32_t child_offset = offsetof(struct radv_bvh_box32_node, children[i]);
84       const uint32_t coord_offsets[2] = {
85          offsetof(struct radv_bvh_box32_node, coords[i].min.x),
86          offsetof(struct radv_bvh_box32_node, coords[i].max.x),
87       };
88 
89       /* node->children[i] -> uint */
90       nir_def *child_index = nir_build_load_global(b, 1, 32, nir_iadd_imm(b, node_addr, child_offset), .align_mul = 64,
91                                                    .align_offset = child_offset % 64);
92       /* node->coords[i][0], node->coords[i][1] -> vec3 */
93       nir_def *node_coords[2] = {
94          nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, coord_offsets[0]), .align_mul = 64,
95                                .align_offset = coord_offsets[0] % 64),
96          nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, coord_offsets[1]), .align_mul = 64,
97                                .align_offset = coord_offsets[1] % 64),
98       };
99 
100       /* If x of the aabb min is NaN, then this is an inactive aabb.
101        * We don't need to care about any other components being NaN as that is UB.
102        * https://www.khronos.org/registry/vulkan/specs/1.2-extensions/html/chap36.html#VkAabbPositionsKHR
103        */
104       nir_def *min_x = nir_channel(b, node_coords[0], 0);
105       nir_def *min_x_is_not_nan = nir_inot(b, nir_fneu(b, min_x, min_x)); /* NaN != NaN -> true */
106 
107       /* vec3 bound0 = (node->coords[i][0] - origin) * inv_dir; */
108       nir_def *bound0 = nir_fmul(b, nir_fsub(b, node_coords[0], origin), inv_dir);
109       /* vec3 bound1 = (node->coords[i][1] - origin) * inv_dir; */
110       nir_def *bound1 = nir_fmul(b, nir_fsub(b, node_coords[1], origin), inv_dir);
111 
112       /* float tmin = max(max(min(bound0.x, bound1.x), min(bound0.y, bound1.y)), min(bound0.z,
113        * bound1.z)); */
114       nir_def *tmin = nir_fmax(b,
115                                nir_fmax(b, nir_fmin(b, nir_channel(b, bound0, 0), nir_channel(b, bound1, 0)),
116                                         nir_fmin(b, nir_channel(b, bound0, 1), nir_channel(b, bound1, 1))),
117                                nir_fmin(b, nir_channel(b, bound0, 2), nir_channel(b, bound1, 2)));
118 
119       /* float tmax = min(min(max(bound0.x, bound1.x), max(bound0.y, bound1.y)), max(bound0.z,
120        * bound1.z)); */
121       nir_def *tmax = nir_fmin(b,
122                                nir_fmin(b, nir_fmax(b, nir_channel(b, bound0, 0), nir_channel(b, bound1, 0)),
123                                         nir_fmax(b, nir_channel(b, bound0, 1), nir_channel(b, bound1, 1))),
124                                nir_fmax(b, nir_channel(b, bound0, 2), nir_channel(b, bound1, 2)));
125 
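      /* A child box is hit iff its entry distance does not exceed its exit
       * distance (tmin <= tmax), the box is not entirely behind the ray origin
       * (tmax >= 0) and the entry point lies closer than the current ray extent
       * (tmin < ray_tmax); the first two conditions are folded into
       * tmax >= max(0, tmin) below. */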
126       /* if (!isnan(node->coords[i][0].x) && tmax >= max(0.0f, tmin) && tmin < ray_tmax) { */
127       nir_push_if(b, nir_iand(b, min_x_is_not_nan,
128                               nir_iand(b, nir_fge(b, tmax, nir_fmax(b, nir_imm_float(b, 0.0f), tmin)),
129                                        nir_flt(b, tmin, ray_tmax))));
130       {
131          /* child_indices[i] = node->children[i]; */
132          nir_def *new_child_indices[4] = {child_index, child_index, child_index, child_index};
133          nir_store_var(b, child_indices, nir_vec(b, new_child_indices, 4), 1u << i);
134 
135          /* distances[i] = tmin; */
136          nir_def *new_distances[4] = {tmin, tmin, tmin, tmin};
137          nir_store_var(b, distances, nir_vec(b, new_distances, 4), 1u << i);
138       }
139       /* } */
140       nir_pop_if(b, NULL);
141    }
142 
143    /* Sort our distances with a sorting network. */
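   /* The five compare-exchanges on the pairs (0,1), (2,3), (0,2), (1,3) and
    * (1,2) form a complete sorting network for four elements, so afterwards the
    * child indices are ordered by ascending hit distance; missed children keep
    * distance INF and index 0xffffffff and therefore end up last. */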
144    nir_sort_hit_pair(b, distances, child_indices, 0, 1);
145    nir_sort_hit_pair(b, distances, child_indices, 2, 3);
146    nir_sort_hit_pair(b, distances, child_indices, 0, 2);
147    nir_sort_hit_pair(b, distances, child_indices, 1, 3);
148    nir_sort_hit_pair(b, distances, child_indices, 1, 2);
149 
150    b->exact = old_exact;
151    return nir_load_var(b, child_indices);
152 }
153 
154 static nir_def *
155 intersect_ray_amd_software_tri(struct radv_device *device, nir_builder *b, nir_def *bvh_node, nir_def *ray_tmax,
156                                nir_def *origin, nir_def *dir, nir_def *inv_dir)
157 {
158    const struct glsl_type *vec4_type = glsl_vector_type(GLSL_TYPE_FLOAT, 4);
159 
160    bool old_exact = b->exact;
161    b->exact = true;
162 
163    nir_def *node_addr = build_node_to_addr(device, b, bvh_node, false);
164 
165    const uint32_t coord_offsets[3] = {
166       offsetof(struct radv_bvh_triangle_node, coords[0]),
167       offsetof(struct radv_bvh_triangle_node, coords[1]),
168       offsetof(struct radv_bvh_triangle_node, coords[2]),
169    };
170 
171    /* node->coords[0], node->coords[1], node->coords[2] -> vec3 */
172    nir_def *node_coords[3] = {
173       nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, coord_offsets[0]), .align_mul = 64,
174                             .align_offset = coord_offsets[0] % 64),
175       nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, coord_offsets[1]), .align_mul = 64,
176                             .align_offset = coord_offsets[1] % 64),
177       nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, coord_offsets[2]), .align_mul = 64,
178                             .align_offset = coord_offsets[2] % 64),
179    };
180 
181    nir_variable *result = nir_variable_create(b->shader, nir_var_shader_temp, vec4_type, "result");
182    nir_store_var(b, result, nir_imm_vec4(b, INFINITY, 1.0f, 0.0f, 0.0f), 0xf);
183 
184    /* Based on watertight Ray/Triangle intersection from
185     * http://jcgt.org/published/0002/01/05/paper.pdf */
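   /* Sketch of the approach: pick the dominant ray direction axis kz, permute
    * the axes and apply a shear (sx, sy, sz) so the ray becomes (0, 0, 1). The
    * vertices, translated to the ray origin, go through the same transform, and
    * the 2D cross products u, v, w act as edge functions: the ray hits the
    * triangle only if all three share a sign. t and the barycentrics follow by
    * dividing by det = u + v + w. The edge functions are evaluated in double
    * precision, which presumably also covers the paper's exact-zero fall-back
    * path. */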
186 
187    /* Calculate the dimension where the ray direction is largest */
188    nir_def *abs_dir = nir_fabs(b, dir);
189 
190    nir_def *abs_dirs[3] = {
191       nir_channel(b, abs_dir, 0),
192       nir_channel(b, abs_dir, 1),
193       nir_channel(b, abs_dir, 2),
194    };
195    /* Find the index of the largest component of abs_dir and use it as kz. */
196    nir_def *kz = nir_bcsel(b, nir_fge(b, abs_dirs[0], abs_dirs[1]),
197                            nir_bcsel(b, nir_fge(b, abs_dirs[0], abs_dirs[2]), nir_imm_int(b, 0), nir_imm_int(b, 2)),
198                            nir_bcsel(b, nir_fge(b, abs_dirs[1], abs_dirs[2]), nir_imm_int(b, 1), nir_imm_int(b, 2)));
199    nir_def *kx = nir_imod_imm(b, nir_iadd_imm(b, kz, 1), 3);
200    nir_def *ky = nir_imod_imm(b, nir_iadd_imm(b, kx, 1), 3);
201    nir_def *k_indices[3] = {kx, ky, kz};
202    nir_def *k = nir_vec(b, k_indices, 3);
203 
204    /* Swap kx and ky dimensions to preserve winding order */
205    unsigned swap_xy_swizzle[4] = {1, 0, 2, 3};
206    k = nir_bcsel(b, nir_flt_imm(b, nir_vector_extract(b, dir, kz), 0.0f), nir_swizzle(b, k, swap_xy_swizzle, 3), k);
207 
208    kx = nir_channel(b, k, 0);
209    ky = nir_channel(b, k, 1);
210    kz = nir_channel(b, k, 2);
211 
212    /* Calculate shear constants */
213    nir_def *sz = nir_frcp(b, nir_vector_extract(b, dir, kz));
214    nir_def *sx = nir_fmul(b, nir_vector_extract(b, dir, kx), sz);
215    nir_def *sy = nir_fmul(b, nir_vector_extract(b, dir, ky), sz);
216 
217    /* Calculate vertices relative to ray origin */
218    nir_def *v_a = nir_fsub(b, node_coords[0], origin);
219    nir_def *v_b = nir_fsub(b, node_coords[1], origin);
220    nir_def *v_c = nir_fsub(b, node_coords[2], origin);
221 
222    /* Perform shear and scale */
223    nir_def *ax = nir_fsub(b, nir_vector_extract(b, v_a, kx), nir_fmul(b, sx, nir_vector_extract(b, v_a, kz)));
224    nir_def *ay = nir_fsub(b, nir_vector_extract(b, v_a, ky), nir_fmul(b, sy, nir_vector_extract(b, v_a, kz)));
225    nir_def *bx = nir_fsub(b, nir_vector_extract(b, v_b, kx), nir_fmul(b, sx, nir_vector_extract(b, v_b, kz)));
226    nir_def *by = nir_fsub(b, nir_vector_extract(b, v_b, ky), nir_fmul(b, sy, nir_vector_extract(b, v_b, kz)));
227    nir_def *cx = nir_fsub(b, nir_vector_extract(b, v_c, kx), nir_fmul(b, sx, nir_vector_extract(b, v_c, kz)));
228    nir_def *cy = nir_fsub(b, nir_vector_extract(b, v_c, ky), nir_fmul(b, sy, nir_vector_extract(b, v_c, kz)));
229 
230    ax = nir_f2f64(b, ax);
231    ay = nir_f2f64(b, ay);
232    bx = nir_f2f64(b, bx);
233    by = nir_f2f64(b, by);
234    cx = nir_f2f64(b, cx);
235    cy = nir_f2f64(b, cy);
236 
237    nir_def *u = nir_fsub(b, nir_fmul(b, cx, by), nir_fmul(b, cy, bx));
238    nir_def *v = nir_fsub(b, nir_fmul(b, ax, cy), nir_fmul(b, ay, cx));
239    nir_def *w = nir_fsub(b, nir_fmul(b, bx, ay), nir_fmul(b, by, ax));
240 
241    /* Perform edge tests. */
242    nir_def *cond_back =
243       nir_ior(b, nir_ior(b, nir_flt_imm(b, u, 0.0f), nir_flt_imm(b, v, 0.0f)), nir_flt_imm(b, w, 0.0f));
244 
245    nir_def *cond_front =
246       nir_ior(b, nir_ior(b, nir_fgt_imm(b, u, 0.0f), nir_fgt_imm(b, v, 0.0f)), nir_fgt_imm(b, w, 0.0f));
247 
248    nir_def *cond = nir_inot(b, nir_iand(b, cond_back, cond_front));
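   /* cond_back is true if any edge function is negative, cond_front if any is
    * positive. Accepting only !(cond_back && cond_front) keeps the cases where
    * u, v and w are all <= 0 or all >= 0, i.e. back-facing and front-facing
    * hits alike, including hits exactly on an edge (where one of them is
    * zero). */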
249 
250    nir_push_if(b, cond);
251    {
252       nir_def *det = nir_fadd(b, u, nir_fadd(b, v, w));
253 
254       sz = nir_f2f64(b, sz);
255 
256       v_a = nir_f2f64(b, v_a);
257       v_b = nir_f2f64(b, v_b);
258       v_c = nir_f2f64(b, v_c);
259 
260       nir_def *az = nir_fmul(b, sz, nir_vector_extract(b, v_a, kz));
261       nir_def *bz = nir_fmul(b, sz, nir_vector_extract(b, v_b, kz));
262       nir_def *cz = nir_fmul(b, sz, nir_vector_extract(b, v_c, kz));
263 
264       nir_def *t = nir_fadd(b, nir_fadd(b, nir_fmul(b, u, az), nir_fmul(b, v, bz)), nir_fmul(b, w, cz));
265 
266       nir_def *t_signed = nir_fmul(b, nir_fsign(b, det), t);
267 
268       nir_def *det_cond_front = nir_inot(b, nir_flt_imm(b, t_signed, 0.0f));
269 
270       nir_push_if(b, det_cond_front);
271       {
272          t = nir_f2f32(b, nir_fdiv(b, t, det));
273          v = nir_f2f32(b, nir_fdiv(b, v, det));
274          w = nir_f2f32(b, nir_fdiv(b, w, det));
275 
276          nir_def *indices[4] = {t, nir_imm_float(b, 1.0), v, w};
277          nir_store_var(b, result, nir_vec(b, indices, 4), 0xf);
278       }
279       nir_pop_if(b, NULL);
280    }
281    nir_pop_if(b, NULL);
282 
283    b->exact = old_exact;
284    return nir_load_var(b, result);
285 }
286 
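/* Node ids and node addresses appear to be related by a factor of 8: an id is
 * the node's byte address divided by 8, and since nodes are 64-byte aligned the
 * low 3 bits of the id are free to carry the node type (the "& 7" in the
 * traversal loop below). build_node_to_addr reverses the conversion, clearing
 * the type bits and, on GFX9+, restoring the canonical high address bits that
 * the id does not store. */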
287 nir_def *
288 build_addr_to_node(nir_builder *b, nir_def *addr)
289 {
290    const uint64_t bvh_size = 1ull << 42;
291    nir_def *node = nir_ushr_imm(b, addr, 3);
292    return nir_iand_imm(b, node, (bvh_size - 1) << 3);
293 }
294 
295 static nir_def *
296 build_node_to_addr(struct radv_device *device, nir_builder *b, nir_def *node, bool skip_type_and)
297 {
298    nir_def *addr = skip_type_and ? node : nir_iand_imm(b, node, ~7ull);
299    addr = nir_ishl_imm(b, addr, 3);
300    /* Assumes everything is in the top half of address space, which is true in
301     * GFX9+ for now. */
302    return device->physical_device->rad_info.gfx_level >= GFX9 ? nir_ior_imm(b, addr, 0xffffull << 48) : addr;
303 }
304 
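/* Multiplies a vec3 by a 3x4 matrix given as three vec4 rows with the
 * translation in the .w components. With translation == true this is an affine
 * point transform (the .w column is added in); with translation == false only
 * the 3x3 part is applied, as needed for direction vectors. */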
305 nir_def *
306 nir_build_vec3_mat_mult(nir_builder *b, nir_def *vec, nir_def *matrix[], bool translation)
307 {
308    nir_def *result_components[3] = {
309       nir_channel(b, matrix[0], 3),
310       nir_channel(b, matrix[1], 3),
311       nir_channel(b, matrix[2], 3),
312    };
313    for (unsigned i = 0; i < 3; ++i) {
314       for (unsigned j = 0; j < 3; ++j) {
315          nir_def *v = nir_fmul(b, nir_channels(b, vec, 1 << j), nir_channels(b, matrix[i], 1 << j));
316          result_components[i] = (translation || j) ? nir_fadd(b, result_components[i], v) : v;
317       }
318    }
319    return nir_vec(b, result_components, 3);
320 }
321 
322 void
323 nir_build_wto_matrix_load(nir_builder *b, nir_def *instance_addr, nir_def **out)
324 {
325    unsigned offset = offsetof(struct radv_bvh_instance_node, wto_matrix);
326    for (unsigned i = 0; i < 3; ++i) {
327       out[i] = nir_build_load_global(b, 4, 32, nir_iadd_imm(b, instance_addr, offset + i * 16), .align_mul = 64,
328                                      .align_offset = offset + i * 16);
329    }
330 }
331 
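/* Loads one vertex position of a triangle leaf straight from the BVH. Judging
 * by the offset computation below, triangle nodes are laid out back to back
 * starting right after the root box32 node, with the three vertices stored as
 * consecutive vec3s at the start of each triangle node. */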
332 nir_def *
333 radv_load_vertex_position(struct radv_device *device, nir_builder *b, nir_def *instance_addr, nir_def *primitive_id,
334                           uint32_t index)
335 {
336    nir_def *bvh_addr_id =
337       nir_build_load_global(b, 1, 64, nir_iadd_imm(b, instance_addr, offsetof(struct radv_bvh_instance_node, bvh_ptr)));
338    nir_def *bvh_addr = build_node_to_addr(device, b, bvh_addr_id, true);
339 
340    nir_def *offset = nir_imul_imm(b, primitive_id, sizeof(struct radv_bvh_triangle_node));
341    offset = nir_iadd_imm(b, offset, sizeof(struct radv_bvh_box32_node) + index * 3 * sizeof(float));
342 
343    return nir_build_load_global(b, 3, 32, nir_iadd(b, bvh_addr, nir_u2u64(b, offset)));
344 }
345 
346 /* When a hit is opaque, the any_hit shader is skipped for that hit and the hit
347  * is assumed to be an actual hit. */
348 static nir_def *
349 hit_is_opaque(nir_builder *b, nir_def *sbt_offset_and_flags, const struct radv_ray_flags *ray_flags,
350               nir_def *geometry_id_and_flags)
351 {
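   /* Assuming RADV_INSTANCE_FORCE_OPAQUE and RADV_INSTANCE_NO_FORCE_NOT_OPAQUE
    * occupy the two most significant bits of these 32-bit flag words, the
    * unsigned comparison below is a branch-free test that both bits are present
    * in the OR of the geometry and instance flags. The per-ray force flags then
    * override the result via the two bcsels. */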
352    nir_def *opaque = nir_uge_imm(b, nir_ior(b, geometry_id_and_flags, sbt_offset_and_flags),
353                                  RADV_INSTANCE_FORCE_OPAQUE | RADV_INSTANCE_NO_FORCE_NOT_OPAQUE);
354    opaque = nir_bcsel(b, ray_flags->force_opaque, nir_imm_true(b), opaque);
355    opaque = nir_bcsel(b, ray_flags->force_not_opaque, nir_imm_false(b), opaque);
356    return opaque;
357 }
358 
359 static nir_def *
360 create_bvh_descriptor(nir_builder *b)
361 {
362    /* We create a BVH descriptor that covers the entire memory range. That way we can always
363     * use the same descriptor, which avoids divergence when different rays hit different
364     * instances at the cost of having to use 64-bit node ids. */
365    const uint64_t bvh_size = 1ull << 42;
366    return nir_imm_ivec4(b, 0, 1u << 31 /* Enable box sorting */, (bvh_size - 1) & 0xFFFFFFFFu,
367                         ((bvh_size - 1) >> 32) | (1u << 24 /* Return IJ for triangles */) | (1u << 31));
368 }
369 
370 static void
371 insert_traversal_triangle_case(struct radv_device *device, nir_builder *b, const struct radv_ray_traversal_args *args,
372                                const struct radv_ray_flags *ray_flags, nir_def *result, nir_def *bvh_node)
373 {
374    if (!args->triangle_cb)
375       return;
376 
377    struct radv_triangle_intersection intersection;
378    intersection.t = nir_channel(b, result, 0);
379    nir_def *div = nir_channel(b, result, 1);
380    intersection.t = nir_fdiv(b, intersection.t, div);
381 
382    nir_def *tmax = nir_load_deref(b, args->vars.tmax);
383 
384    nir_push_if(b, nir_flt(b, intersection.t, tmax));
385    {
386       intersection.frontface = nir_fgt_imm(b, div, 0);
387       nir_def *switch_ccw =
388          nir_test_mask(b, nir_load_deref(b, args->vars.sbt_offset_and_flags), RADV_INSTANCE_TRIANGLE_FLIP_FACING);
389       intersection.frontface = nir_ixor(b, intersection.frontface, switch_ccw);
390 
391       nir_def *not_cull = ray_flags->no_skip_triangles;
392       nir_def *not_facing_cull =
393          nir_bcsel(b, intersection.frontface, ray_flags->no_cull_front, ray_flags->no_cull_back);
394 
395       not_cull = nir_iand(b, not_cull,
396                           nir_ior(b, not_facing_cull,
397                                   nir_test_mask(b, nir_load_deref(b, args->vars.sbt_offset_and_flags),
398                                                 RADV_INSTANCE_TRIANGLE_FACING_CULL_DISABLE)));
399 
400       nir_push_if(b, nir_iand(b,
401                               nir_flt(b, args->tmin, intersection.t),
402                               not_cull));
403       {
404          intersection.base.node_addr = build_node_to_addr(device, b, bvh_node, false);
405          nir_def *triangle_info = nir_build_load_global(
406             b, 2, 32,
407             nir_iadd_imm(b, intersection.base.node_addr, offsetof(struct radv_bvh_triangle_node, triangle_id)));
408          intersection.base.primitive_id = nir_channel(b, triangle_info, 0);
409          intersection.base.geometry_id_and_flags = nir_channel(b, triangle_info, 1);
410          intersection.base.opaque = hit_is_opaque(b, nir_load_deref(b, args->vars.sbt_offset_and_flags), ray_flags,
411                                                   intersection.base.geometry_id_and_flags);
412 
413          not_cull = nir_bcsel(b, intersection.base.opaque, ray_flags->no_cull_opaque, ray_flags->no_cull_no_opaque);
414          nir_push_if(b, not_cull);
415          {
416             nir_def *divs[2] = {div, div};
417             intersection.barycentrics = nir_fdiv(b, nir_channels(b, result, 0xc), nir_vec(b, divs, 2));
418 
419             args->triangle_cb(b, &intersection, args, ray_flags);
420          }
421          nir_pop_if(b, NULL);
422       }
423       nir_pop_if(b, NULL);
424    }
425    nir_pop_if(b, NULL);
426 }
427 
428 static void
429 insert_traversal_aabb_case(struct radv_device *device, nir_builder *b, const struct radv_ray_traversal_args *args,
430                            const struct radv_ray_flags *ray_flags, nir_def *bvh_node)
431 {
432    if (!args->aabb_cb)
433       return;
434 
435    struct radv_leaf_intersection intersection;
436    intersection.node_addr = build_node_to_addr(device, b, bvh_node, false);
437    nir_def *triangle_info = nir_build_load_global(
438       b, 2, 32, nir_iadd_imm(b, intersection.node_addr, offsetof(struct radv_bvh_aabb_node, primitive_id)));
439    intersection.primitive_id = nir_channel(b, triangle_info, 0);
440    intersection.geometry_id_and_flags = nir_channel(b, triangle_info, 1);
441    intersection.opaque = hit_is_opaque(b, nir_load_deref(b, args->vars.sbt_offset_and_flags), ray_flags,
442                                        intersection.geometry_id_and_flags);
443 
444    nir_def *not_cull = nir_bcsel(b, intersection.opaque, ray_flags->no_cull_opaque, ray_flags->no_cull_no_opaque);
445    not_cull = nir_iand(b, not_cull, ray_flags->no_skip_aabbs);
446    nir_push_if(b, not_cull);
447    {
448       args->aabb_cb(b, &intersection, args);
449    }
450    nir_pop_if(b, NULL);
451 }
452 
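/* The BVH builder appears to store one 32-bit parent id per 64-byte node slot
 * in a table growing downwards from the BVH base address, so the entry for
 * `node` lives at bvh - 4 * (node / 8) - 4. These links are what make the
 * stackless backtracking in radv_build_ray_traversal possible once the short
 * stack has run dry. */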
453 static nir_def *
454 fetch_parent_node(nir_builder *b, nir_def *bvh, nir_def *node)
455 {
456    nir_def *offset = nir_iadd_imm(b, nir_imul_imm(b, nir_udiv_imm(b, node, 8), 4), 4);
457 
458    return nir_build_load_global(b, 1, 32, nir_isub(b, bvh, nir_u2u64(b, offset)), .align_mul = 4);
459 }
460 
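/* Emits the generic BVH traversal loop: a small ring-buffer stack of node ids
 * with parent-pointer backtracking once the stack underflows, hardware ray
 * intersection instructions where available and NIR emulation otherwise, and
 * the leaf callbacks supplied through args. The returned boolean presumably
 * stays true when the loop was left early (e.g. by a callback terminating on
 * the first hit) and becomes false once traversal has genuinely finished. */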
461 nir_def *
462 radv_build_ray_traversal(struct radv_device *device, nir_builder *b, const struct radv_ray_traversal_args *args)
463 {
464    nir_variable *incomplete = nir_local_variable_create(b->impl, glsl_bool_type(), "incomplete");
465    nir_store_var(b, incomplete, nir_imm_true(b), 0x1);
466 
467    nir_def *desc = create_bvh_descriptor(b);
468    nir_def *vec3ones = nir_imm_vec3(b, 1.0, 1.0, 1.0);
469 
470    struct radv_ray_flags ray_flags = {
471       .force_opaque = nir_test_mask(b, args->flags, SpvRayFlagsOpaqueKHRMask),
472       .force_not_opaque = nir_test_mask(b, args->flags, SpvRayFlagsNoOpaqueKHRMask),
473       .terminate_on_first_hit = nir_test_mask(b, args->flags, SpvRayFlagsTerminateOnFirstHitKHRMask),
474       .no_cull_front = nir_ieq_imm(b, nir_iand_imm(b, args->flags, SpvRayFlagsCullFrontFacingTrianglesKHRMask), 0),
475       .no_cull_back = nir_ieq_imm(b, nir_iand_imm(b, args->flags, SpvRayFlagsCullBackFacingTrianglesKHRMask), 0),
476       .no_cull_opaque = nir_ieq_imm(b, nir_iand_imm(b, args->flags, SpvRayFlagsCullOpaqueKHRMask), 0),
477       .no_cull_no_opaque = nir_ieq_imm(b, nir_iand_imm(b, args->flags, SpvRayFlagsCullNoOpaqueKHRMask), 0),
478       .no_skip_triangles = nir_ieq_imm(b, nir_iand_imm(b, args->flags, SpvRayFlagsSkipTrianglesKHRMask), 0),
479       .no_skip_aabbs = nir_ieq_imm(b, nir_iand_imm(b, args->flags, SpvRayFlagsSkipAABBsKHRMask), 0),
480    };
481    nir_push_loop(b);
482    {
483       nir_push_if(b, nir_ieq_imm(b, nir_load_deref(b, args->vars.current_node), RADV_BVH_INVALID_NODE));
484       {
485          /* Early exit if we never overflowed the stack, to avoid having to backtrack to
486           * the root for no reason. */
487          nir_push_if(b, nir_ilt_imm(b, nir_load_deref(b, args->vars.stack), args->stack_base + args->stack_stride));
488          {
489             nir_store_var(b, incomplete, nir_imm_false(b), 0x1);
490             nir_jump(b, nir_jump_break);
491          }
492          nir_pop_if(b, NULL);
493 
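         /* Leaving an instance: either the stack has been popped back to the
          * point where the instance was entered (top_stack), or backtracking has
          * arrived back at the instance BVH's root. In both cases switch back to
          * the top-level BVH and restore the world-space ray. */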
494          nir_def *stack_instance_exit =
495             nir_ige(b, nir_load_deref(b, args->vars.top_stack), nir_load_deref(b, args->vars.stack));
496          nir_def *root_instance_exit =
497             nir_ieq(b, nir_load_deref(b, args->vars.previous_node), nir_load_deref(b, args->vars.instance_bottom_node));
498          nir_if *instance_exit = nir_push_if(b, nir_ior(b, stack_instance_exit, root_instance_exit));
499          instance_exit->control = nir_selection_control_dont_flatten;
500          {
501             nir_store_deref(b, args->vars.top_stack, nir_imm_int(b, -1), 1);
502             nir_store_deref(b, args->vars.previous_node, nir_load_deref(b, args->vars.instance_top_node), 1);
503             nir_store_deref(b, args->vars.instance_bottom_node, nir_imm_int(b, RADV_BVH_NO_INSTANCE_ROOT), 1);
504 
505             nir_store_deref(b, args->vars.bvh_base, args->root_bvh_base, 1);
506             nir_store_deref(b, args->vars.origin, args->origin, 7);
507             nir_store_deref(b, args->vars.dir, args->dir, 7);
508             nir_store_deref(b, args->vars.inv_dir, nir_fdiv(b, vec3ones, args->dir), 7);
509          }
510          nir_pop_if(b, NULL);
511 
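         /* If the stack has dropped to the low watermark, everything below it
          * either was never pushed or has been overwritten in the ring buffer,
          * so instead of popping we backtrack one level through the parent
          * links stored alongside the BVH. Otherwise pop the next node id from
          * the short stack. */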
512          nir_push_if(
513             b, nir_ige(b, nir_load_deref(b, args->vars.stack_low_watermark), nir_load_deref(b, args->vars.stack)));
514          {
515             nir_def *prev = nir_load_deref(b, args->vars.previous_node);
516             nir_def *bvh_addr = build_node_to_addr(device, b, nir_load_deref(b, args->vars.bvh_base), true);
517 
518             nir_def *parent = fetch_parent_node(b, bvh_addr, prev);
519             nir_push_if(b, nir_ieq_imm(b, parent, RADV_BVH_INVALID_NODE));
520             {
521                nir_store_var(b, incomplete, nir_imm_false(b), 0x1);
522                nir_jump(b, nir_jump_break);
523             }
524             nir_pop_if(b, NULL);
525             nir_store_deref(b, args->vars.current_node, parent, 0x1);
526          }
527          nir_push_else(b, NULL);
528          {
529             nir_store_deref(b, args->vars.stack,
530                             nir_iadd_imm(b, nir_load_deref(b, args->vars.stack), -args->stack_stride), 1);
531 
532             nir_def *stack_ptr =
533                nir_umod_imm(b, nir_load_deref(b, args->vars.stack), args->stack_stride * args->stack_entries);
534             nir_def *bvh_node = args->stack_load_cb(b, stack_ptr, args);
535             nir_store_deref(b, args->vars.current_node, bvh_node, 0x1);
536             nir_store_deref(b, args->vars.previous_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
537          }
538          nir_pop_if(b, NULL);
539       }
540       nir_push_else(b, NULL);
541       {
542          nir_store_deref(b, args->vars.previous_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
543       }
544       nir_pop_if(b, NULL);
545 
546       nir_def *bvh_node = nir_load_deref(b, args->vars.current_node);
547 
548       nir_def *prev_node = nir_load_deref(b, args->vars.previous_node);
549       nir_store_deref(b, args->vars.previous_node, bvh_node, 0x1);
550       nir_store_deref(b, args->vars.current_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
551 
552       nir_def *global_bvh_node = nir_iadd(b, nir_load_deref(b, args->vars.bvh_base), nir_u2u64(b, bvh_node));
553 
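      /* On GPUs with ray tracing instructions the per-node test is done by the
       * bvh64_intersect_ray intrinsic. Its result uses the same layout the
       * software fallbacks below produce: the four child ids sorted front to
       * back for box nodes, and a t numerator and denominator plus the two
       * barycentric numerators for triangle nodes (the software triangle path
       * simply returns a denominator of 1). */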
554       nir_def *intrinsic_result = NULL;
555       if (!radv_emulate_rt(device->physical_device)) {
556          intrinsic_result =
557             nir_bvh64_intersect_ray_amd(b, 32, desc, nir_unpack_64_2x32(b, global_bvh_node),
558                                         nir_load_deref(b, args->vars.tmax), nir_load_deref(b, args->vars.origin),
559                                         nir_load_deref(b, args->vars.dir), nir_load_deref(b, args->vars.inv_dir));
560       }
561 
562       nir_def *node_type = nir_iand_imm(b, bvh_node, 7);
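      /* The node type lives in the low 3 bits of the id, ordered so that
       * triangle nodes compare lowest, followed by the box node types, then
       * instance and AABB nodes; two unsigned comparisons are therefore enough
       * to dispatch to the right case. */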
563       nir_push_if(b, nir_uge_imm(b, node_type, radv_bvh_node_box16));
564       {
565          nir_push_if(b, nir_uge_imm(b, node_type, radv_bvh_node_instance));
566          {
567             nir_push_if(b, nir_ieq_imm(b, node_type, radv_bvh_node_aabb));
568             {
569                insert_traversal_aabb_case(device, b, args, &ray_flags, global_bvh_node);
570             }
571             nir_push_else(b, NULL);
572             {
573                if (args->vars.iteration_instance_count) {
574                   nir_def *iteration_instance_count = nir_load_deref(b, args->vars.iteration_instance_count);
575                   iteration_instance_count = nir_iadd_imm(b, iteration_instance_count, 1 << 16);
576                   nir_store_deref(b, args->vars.iteration_instance_count, iteration_instance_count, 0x1);
577                }
578 
579                /* instance */
580                nir_def *instance_node_addr = build_node_to_addr(device, b, global_bvh_node, false);
581                nir_store_deref(b, args->vars.instance_addr, instance_node_addr, 1);
582 
583                nir_def *instance_data =
584                   nir_build_load_global(b, 4, 32, instance_node_addr, .align_mul = 64, .align_offset = 0);
585 
586                nir_def *wto_matrix[3];
587                nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
588 
589                nir_store_deref(b, args->vars.sbt_offset_and_flags, nir_channel(b, instance_data, 3), 1);
590 
591                if (!args->ignore_cull_mask) {
592                   nir_def *instance_and_mask = nir_channel(b, instance_data, 2);
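                  /* Channel 2 of the instance data presumably packs the 24-bit
                   * instance custom index in the low bits and the 8-bit instance
                   * mask in the top byte; assuming args->cull_mask is positioned
                   * in that same top byte, an AND result below 1 << 24 means the
                   * two masks share no bits and the instance is skipped. */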
593                   nir_push_if(b, nir_ult(b, nir_iand(b, instance_and_mask, args->cull_mask), nir_imm_int(b, 1 << 24)));
594                   {
595                      nir_jump(b, nir_jump_continue);
596                   }
597                   nir_pop_if(b, NULL);
598                }
599 
600                nir_store_deref(b, args->vars.top_stack, nir_load_deref(b, args->vars.stack), 1);
601                nir_store_deref(b, args->vars.bvh_base, nir_pack_64_2x32(b, nir_trim_vector(b, instance_data, 2)), 1);
602 
603                /* Push the instance root node onto the stack */
604                nir_store_deref(b, args->vars.current_node, nir_imm_int(b, RADV_BVH_ROOT_NODE), 0x1);
605                nir_store_deref(b, args->vars.instance_bottom_node, nir_imm_int(b, RADV_BVH_ROOT_NODE), 1);
606                nir_store_deref(b, args->vars.instance_top_node, bvh_node, 1);
607 
608                /* Transform the ray into object space */
609                nir_store_deref(b, args->vars.origin, nir_build_vec3_mat_mult(b, args->origin, wto_matrix, true), 7);
610                nir_store_deref(b, args->vars.dir, nir_build_vec3_mat_mult(b, args->dir, wto_matrix, false), 7);
611                nir_store_deref(b, args->vars.inv_dir, nir_fdiv(b, vec3ones, nir_load_deref(b, args->vars.dir)), 7);
612             }
613             nir_pop_if(b, NULL);
614          }
615          nir_push_else(b, NULL);
616          {
617             nir_def *result = intrinsic_result;
618             if (!result) {
619                /* If we didn't run the intrinsic because the hardware didn't support it,
620                 * emulate ray/box intersection here */
621                result = intersect_ray_amd_software_box(
622                   device, b, global_bvh_node, nir_load_deref(b, args->vars.tmax), nir_load_deref(b, args->vars.origin),
623                   nir_load_deref(b, args->vars.dir), nir_load_deref(b, args->vars.inv_dir));
624             }
625 
626             /* box */
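            /* First visit of this box node (we did not get here by
             * backtracking): push the valid children 1..3 onto the stack and
             * descend into child 0. The nested push_ifs stop pushing at the
             * first invalid child, which works because the sorted result packs
             * valid children first; after pushing, the low watermark is raised
             * so that entries which may have been overwritten in the ring
             * buffer are written off. If we got here by backtracking from a
             * child instead, find that child in the result and continue with
             * its next sibling. */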
627             nir_push_if(b, nir_ieq_imm(b, prev_node, RADV_BVH_INVALID_NODE));
628             {
629                nir_def *new_nodes[4];
630                for (unsigned i = 0; i < 4; ++i)
631                   new_nodes[i] = nir_channel(b, result, i);
632 
633                for (unsigned i = 1; i < 4; ++i)
634                   nir_push_if(b, nir_ine_imm(b, new_nodes[i], RADV_BVH_INVALID_NODE));
635 
636                for (unsigned i = 4; i-- > 1;) {
637                   nir_def *stack = nir_load_deref(b, args->vars.stack);
638                   nir_def *stack_ptr = nir_umod_imm(b, stack, args->stack_entries * args->stack_stride);
639                   args->stack_store_cb(b, stack_ptr, new_nodes[i], args);
640                   nir_store_deref(b, args->vars.stack, nir_iadd_imm(b, stack, args->stack_stride), 1);
641 
642                   if (i == 1) {
643                      nir_def *new_watermark =
644                         nir_iadd_imm(b, nir_load_deref(b, args->vars.stack), -args->stack_entries * args->stack_stride);
645                      new_watermark = nir_imax(b, nir_load_deref(b, args->vars.stack_low_watermark), new_watermark);
646                      nir_store_deref(b, args->vars.stack_low_watermark, new_watermark, 0x1);
647                   }
648 
649                   nir_pop_if(b, NULL);
650                }
651                nir_store_deref(b, args->vars.current_node, new_nodes[0], 0x1);
652             }
653             nir_push_else(b, NULL);
654             {
655                nir_def *next = nir_imm_int(b, RADV_BVH_INVALID_NODE);
656                for (unsigned i = 0; i < 3; ++i) {
657                   next = nir_bcsel(b, nir_ieq(b, prev_node, nir_channel(b, result, i)), nir_channel(b, result, i + 1),
658                                    next);
659                }
660                nir_store_deref(b, args->vars.current_node, next, 0x1);
661             }
662             nir_pop_if(b, NULL);
663          }
664          nir_pop_if(b, NULL);
665       }
666       nir_push_else(b, NULL);
667       {
668          nir_def *result = intrinsic_result;
669          if (!result) {
670             /* If we didn't run the intrinsic because the hardware didn't support it,
671              * emulate ray/tri intersection here */
672             result = intersect_ray_amd_software_tri(
673                device, b, global_bvh_node, nir_load_deref(b, args->vars.tmax), nir_load_deref(b, args->vars.origin),
674                nir_load_deref(b, args->vars.dir), nir_load_deref(b, args->vars.inv_dir));
675          }
676          insert_traversal_triangle_case(device, b, args, &ray_flags, result, global_bvh_node);
677       }
678       nir_pop_if(b, NULL);
679 
680       if (args->vars.iteration_instance_count) {
681          nir_def *iteration_instance_count = nir_load_deref(b, args->vars.iteration_instance_count);
682          iteration_instance_count = nir_iadd_imm(b, iteration_instance_count, 1);
683          nir_store_deref(b, args->vars.iteration_instance_count, iteration_instance_count, 0x1);
684       }
685    }
686    nir_pop_loop(b, NULL);
687 
688    return nir_load_var(b, incomplete);
689 }
690