/*
 * Copyright © 2021 Google
 *
 * SPDX-License-Identifier: MIT
 */

#include "nir/radv_nir_rt_common.h"
#include "bvh/bvh.h"
#include "radv_debug.h"

#if AMD_LLVM_AVAILABLE
#include <llvm/Config/llvm-config.h>
#endif

static nir_def *build_node_to_addr(struct radv_device *device, nir_builder *b, nir_def *node, bool skip_type_and);

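/* Compare-and-swap step of the small sorting network used in
 * intersect_ray_amd_software_box below: if the distance in chan_2 is smaller
 * than the one in chan_1, swap both the distances and the matching child
 * indices so that closer children are visited first. */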
static void
nir_sort_hit_pair(nir_builder *b, nir_variable *var_distances, nir_variable *var_indices, uint32_t chan_1,
                  uint32_t chan_2)
{
   nir_def *ssa_distances = nir_load_var(b, var_distances);
   nir_def *ssa_indices = nir_load_var(b, var_indices);
   /* if (distances[chan_2] < distances[chan_1]) { */
   nir_push_if(b, nir_flt(b, nir_channel(b, ssa_distances, chan_2), nir_channel(b, ssa_distances, chan_1)));
   {
      /* swap(distances[chan_2], distances[chan_1]); */
      nir_def *new_distances[4] = {nir_undef(b, 1, 32), nir_undef(b, 1, 32), nir_undef(b, 1, 32), nir_undef(b, 1, 32)};
      nir_def *new_indices[4] = {nir_undef(b, 1, 32), nir_undef(b, 1, 32), nir_undef(b, 1, 32), nir_undef(b, 1, 32)};
      new_distances[chan_2] = nir_channel(b, ssa_distances, chan_1);
      new_distances[chan_1] = nir_channel(b, ssa_distances, chan_2);
      new_indices[chan_2] = nir_channel(b, ssa_indices, chan_1);
      new_indices[chan_1] = nir_channel(b, ssa_indices, chan_2);
      nir_store_var(b, var_distances, nir_vec(b, new_distances, 4), (1u << chan_1) | (1u << chan_2));
      nir_store_var(b, var_indices, nir_vec(b, new_indices, 4), (1u << chan_1) | (1u << chan_2));
   }
   /* } */
   nir_pop_if(b, NULL);
}

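/* Software fallback for box-node intersection, used when the hardware
 * nir_bvh64_intersect_ray_amd intrinsic is unavailable: slab-test the ray
 * against the four child AABBs of a box32 node and return the four child ids
 * sorted front to back, with 0xffffffff for missed or inactive children. */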
static nir_def *
intersect_ray_amd_software_box(struct radv_device *device, nir_builder *b, nir_def *bvh_node, nir_def *ray_tmax,
                               nir_def *origin, nir_def *dir, nir_def *inv_dir)
{
   const struct glsl_type *vec4_type = glsl_vector_type(GLSL_TYPE_FLOAT, 4);
   const struct glsl_type *uvec4_type = glsl_vector_type(GLSL_TYPE_UINT, 4);

   bool old_exact = b->exact;
   b->exact = true;

   nir_def *node_addr = build_node_to_addr(device, b, bvh_node, false);

   /* vec4 distances = vec4(INF, INF, INF, INF); */
   nir_variable *distances = nir_variable_create(b->shader, nir_var_shader_temp, vec4_type, "distances");
   nir_store_var(b, distances, nir_imm_vec4(b, INFINITY, INFINITY, INFINITY, INFINITY), 0xf);

   /* uvec4 child_indices = uvec4(0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff); */
   nir_variable *child_indices = nir_variable_create(b->shader, nir_var_shader_temp, uvec4_type, "child_indices");
   nir_store_var(b, child_indices, nir_imm_ivec4(b, 0xffffffffu, 0xffffffffu, 0xffffffffu, 0xffffffffu), 0xf);

   /* Need to remove infinities here because otherwise we get nasty NaN propagation
    * if the direction has 0s in it. */
   /* inv_dir = clamp(inv_dir, -FLT_MAX, FLT_MAX); */
   inv_dir = nir_fclamp(b, inv_dir, nir_imm_float(b, -FLT_MAX), nir_imm_float(b, FLT_MAX));

   for (int i = 0; i < 4; i++) {
      const uint32_t child_offset = offsetof(struct radv_bvh_box32_node, children[i]);
      const uint32_t coord_offsets[2] = {
         offsetof(struct radv_bvh_box32_node, coords[i].min.x),
         offsetof(struct radv_bvh_box32_node, coords[i].max.x),
      };

      /* node->children[i] -> uint */
      nir_def *child_index = nir_build_load_global(b, 1, 32, nir_iadd_imm(b, node_addr, child_offset), .align_mul = 64,
                                                   .align_offset = child_offset % 64);
      /* node->coords[i][0], node->coords[i][1] -> vec3 */
      nir_def *node_coords[2] = {
         nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, coord_offsets[0]), .align_mul = 64,
                               .align_offset = coord_offsets[0] % 64),
         nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, coord_offsets[1]), .align_mul = 64,
                               .align_offset = coord_offsets[1] % 64),
      };

      /* If x of the aabb min is NaN, then this is an inactive aabb.
       * We don't need to care about any other components being NaN as that is UB.
       * https://www.khronos.org/registry/vulkan/specs/1.2-extensions/html/chap36.html#VkAabbPositionsKHR
       */
      nir_def *min_x = nir_channel(b, node_coords[0], 0);
      nir_def *min_x_is_not_nan = nir_inot(b, nir_fneu(b, min_x, min_x)); /* NaN != NaN -> true */

      /* vec3 bound0 = (node->coords[i][0] - origin) * inv_dir; */
      nir_def *bound0 = nir_fmul(b, nir_fsub(b, node_coords[0], origin), inv_dir);
      /* vec3 bound1 = (node->coords[i][1] - origin) * inv_dir; */
      nir_def *bound1 = nir_fmul(b, nir_fsub(b, node_coords[1], origin), inv_dir);

      /* float tmin = max(max(min(bound0.x, bound1.x), min(bound0.y, bound1.y)), min(bound0.z,
       * bound1.z)); */
      nir_def *tmin = nir_fmax(b,
                               nir_fmax(b, nir_fmin(b, nir_channel(b, bound0, 0), nir_channel(b, bound1, 0)),
                                        nir_fmin(b, nir_channel(b, bound0, 1), nir_channel(b, bound1, 1))),
                               nir_fmin(b, nir_channel(b, bound0, 2), nir_channel(b, bound1, 2)));

      /* float tmax = min(min(max(bound0.x, bound1.x), max(bound0.y, bound1.y)), max(bound0.z,
       * bound1.z)); */
      nir_def *tmax = nir_fmin(b,
                               nir_fmin(b, nir_fmax(b, nir_channel(b, bound0, 0), nir_channel(b, bound1, 0)),
                                        nir_fmax(b, nir_channel(b, bound0, 1), nir_channel(b, bound1, 1))),
                               nir_fmax(b, nir_channel(b, bound0, 2), nir_channel(b, bound1, 2)));

      /* if (!isnan(node->coords[i][0].x) && tmax >= max(0.0f, tmin) && tmin < ray_tmax) { */
      nir_push_if(b, nir_iand(b, min_x_is_not_nan,
                              nir_iand(b, nir_fge(b, tmax, nir_fmax(b, nir_imm_float(b, 0.0f), tmin)),
                                       nir_flt(b, tmin, ray_tmax))));
      {
         /* child_indices[i] = node->children[i]; */
         nir_def *new_child_indices[4] = {child_index, child_index, child_index, child_index};
         nir_store_var(b, child_indices, nir_vec(b, new_child_indices, 4), 1u << i);

         /* distances[i] = tmin; */
         nir_def *new_distances[4] = {tmin, tmin, tmin, tmin};
         nir_store_var(b, distances, nir_vec(b, new_distances, 4), 1u << i);
      }
      /* } */
      nir_pop_if(b, NULL);
   }

   /* Sort our distances with a sorting network. */
   nir_sort_hit_pair(b, distances, child_indices, 0, 1);
   nir_sort_hit_pair(b, distances, child_indices, 2, 3);
   nir_sort_hit_pair(b, distances, child_indices, 0, 2);
   nir_sort_hit_pair(b, distances, child_indices, 1, 3);
   nir_sort_hit_pair(b, distances, child_indices, 1, 2);

   b->exact = old_exact;
   return nir_load_var(b, child_indices);
}

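/* Software fallback for triangle-node intersection, based on the watertight
 * ray/triangle test referenced below. The returned vec4 is laid out so that
 * the caller (insert_traversal_triangle_case) recovers
 *    t = result.x / result.y, frontface = result.y > 0,
 *    barycentrics = result.zw / result.y,
 * with result.x = INF encoding a miss. */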
static nir_def *
intersect_ray_amd_software_tri(struct radv_device *device, nir_builder *b, nir_def *bvh_node, nir_def *ray_tmax,
                               nir_def *origin, nir_def *dir, nir_def *inv_dir)
{
   const struct glsl_type *vec4_type = glsl_vector_type(GLSL_TYPE_FLOAT, 4);

   bool old_exact = b->exact;
   b->exact = true;

   nir_def *node_addr = build_node_to_addr(device, b, bvh_node, false);

   const uint32_t coord_offsets[3] = {
      offsetof(struct radv_bvh_triangle_node, coords[0]),
      offsetof(struct radv_bvh_triangle_node, coords[1]),
      offsetof(struct radv_bvh_triangle_node, coords[2]),
   };

   /* node->coords[0], node->coords[1], node->coords[2] -> vec3 */
   nir_def *node_coords[3] = {
      nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, coord_offsets[0]), .align_mul = 64,
                            .align_offset = coord_offsets[0] % 64),
      nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, coord_offsets[1]), .align_mul = 64,
                            .align_offset = coord_offsets[1] % 64),
      nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, coord_offsets[2]), .align_mul = 64,
                            .align_offset = coord_offsets[2] % 64),
   };

   nir_variable *result = nir_variable_create(b->shader, nir_var_shader_temp, vec4_type, "result");
   nir_store_var(b, result, nir_imm_vec4(b, INFINITY, 1.0f, 0.0f, 0.0f), 0xf);

   /* Based on watertight Ray/Triangle intersection from
    * http://jcgt.org/published/0002/01/05/paper.pdf */

   /* Calculate the dimension where the ray direction is largest */
   nir_def *abs_dir = nir_fabs(b, dir);

   nir_def *abs_dirs[3] = {
      nir_channel(b, abs_dir, 0),
      nir_channel(b, abs_dir, 1),
      nir_channel(b, abs_dir, 2),
   };
   /* Find index of greatest value of abs_dir and put that as kz. */
   nir_def *kz = nir_bcsel(b, nir_fge(b, abs_dirs[0], abs_dirs[1]),
                           nir_bcsel(b, nir_fge(b, abs_dirs[0], abs_dirs[2]), nir_imm_int(b, 0), nir_imm_int(b, 2)),
                           nir_bcsel(b, nir_fge(b, abs_dirs[1], abs_dirs[2]), nir_imm_int(b, 1), nir_imm_int(b, 2)));
   nir_def *kx = nir_imod_imm(b, nir_iadd_imm(b, kz, 1), 3);
   nir_def *ky = nir_imod_imm(b, nir_iadd_imm(b, kx, 1), 3);
   nir_def *k_indices[3] = {kx, ky, kz};
   nir_def *k = nir_vec(b, k_indices, 3);

   /* Swap kx and ky dimensions to preserve winding order */
   unsigned swap_xy_swizzle[4] = {1, 0, 2, 3};
   k = nir_bcsel(b, nir_flt_imm(b, nir_vector_extract(b, dir, kz), 0.0f), nir_swizzle(b, k, swap_xy_swizzle, 3), k);

   kx = nir_channel(b, k, 0);
   ky = nir_channel(b, k, 1);
   kz = nir_channel(b, k, 2);

   /* Calculate shear constants */
   nir_def *sz = nir_frcp(b, nir_vector_extract(b, dir, kz));
   nir_def *sx = nir_fmul(b, nir_vector_extract(b, dir, kx), sz);
   nir_def *sy = nir_fmul(b, nir_vector_extract(b, dir, ky), sz);

   /* Calculate vertices relative to ray origin */
   nir_def *v_a = nir_fsub(b, node_coords[0], origin);
   nir_def *v_b = nir_fsub(b, node_coords[1], origin);
   nir_def *v_c = nir_fsub(b, node_coords[2], origin);

   /* Perform shear and scale */
   nir_def *ax = nir_fsub(b, nir_vector_extract(b, v_a, kx), nir_fmul(b, sx, nir_vector_extract(b, v_a, kz)));
   nir_def *ay = nir_fsub(b, nir_vector_extract(b, v_a, ky), nir_fmul(b, sy, nir_vector_extract(b, v_a, kz)));
   nir_def *bx = nir_fsub(b, nir_vector_extract(b, v_b, kx), nir_fmul(b, sx, nir_vector_extract(b, v_b, kz)));
   nir_def *by = nir_fsub(b, nir_vector_extract(b, v_b, ky), nir_fmul(b, sy, nir_vector_extract(b, v_b, kz)));
   nir_def *cx = nir_fsub(b, nir_vector_extract(b, v_c, kx), nir_fmul(b, sx, nir_vector_extract(b, v_c, kz)));
   nir_def *cy = nir_fsub(b, nir_vector_extract(b, v_c, ky), nir_fmul(b, sy, nir_vector_extract(b, v_c, kz)));

   ax = nir_f2f64(b, ax);
   ay = nir_f2f64(b, ay);
   bx = nir_f2f64(b, bx);
   by = nir_f2f64(b, by);
   cx = nir_f2f64(b, cx);
   cy = nir_f2f64(b, cy);

   nir_def *u = nir_fsub(b, nir_fmul(b, cx, by), nir_fmul(b, cy, bx));
   nir_def *v = nir_fsub(b, nir_fmul(b, ax, cy), nir_fmul(b, ay, cx));
   nir_def *w = nir_fsub(b, nir_fmul(b, bx, ay), nir_fmul(b, by, ax));

   /* Perform edge tests. */
   nir_def *cond_back =
      nir_ior(b, nir_ior(b, nir_flt_imm(b, u, 0.0f), nir_flt_imm(b, v, 0.0f)), nir_flt_imm(b, w, 0.0f));

   nir_def *cond_front =
      nir_ior(b, nir_ior(b, nir_fgt_imm(b, u, 0.0f), nir_fgt_imm(b, v, 0.0f)), nir_fgt_imm(b, w, 0.0f));

   nir_def *cond = nir_inot(b, nir_iand(b, cond_back, cond_front));

   nir_push_if(b, cond);
   {
      nir_def *det = nir_fadd(b, u, nir_fadd(b, v, w));

      sz = nir_f2f64(b, sz);

      v_a = nir_f2f64(b, v_a);
      v_b = nir_f2f64(b, v_b);
      v_c = nir_f2f64(b, v_c);

      nir_def *az = nir_fmul(b, sz, nir_vector_extract(b, v_a, kz));
      nir_def *bz = nir_fmul(b, sz, nir_vector_extract(b, v_b, kz));
      nir_def *cz = nir_fmul(b, sz, nir_vector_extract(b, v_c, kz));

      nir_def *t = nir_fadd(b, nir_fadd(b, nir_fmul(b, u, az), nir_fmul(b, v, bz)), nir_fmul(b, w, cz));

      nir_def *t_signed = nir_fmul(b, nir_fsign(b, det), t);

      nir_def *det_cond_front = nir_inot(b, nir_flt_imm(b, t_signed, 0.0f));

      nir_push_if(b, det_cond_front);
      {
         nir_def *det_abs = nir_fabs(b, det);

         t = nir_f2f32(b, nir_fdiv(b, t, det_abs));
         v = nir_f2f32(b, nir_fdiv(b, v, det_abs));
         w = nir_f2f32(b, nir_fdiv(b, w, det_abs));

         nir_def *indices[4] = {t, nir_f2f32(b, nir_fsign(b, det)), v, w};
         nir_store_var(b, result, nir_vec(b, indices, 4), 0xf);
      }
      nir_pop_if(b, NULL);
   }
   nir_pop_if(b, NULL);

   b->exact = old_exact;
   return nir_load_var(b, result);
}

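/* Convert a BVH byte address into a node id. Node ids are the address divided
 * by 8 with the low 3 bits left for the node type (nodes are 64-byte aligned,
 * so those bits are zero here); the mask also drops the high canonical-VA
 * bits that build_node_to_addr adds back. */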
nir_def *
build_addr_to_node(nir_builder *b, nir_def *addr)
{
   const uint64_t bvh_size = 1ull << 42;
   nir_def *node = nir_ushr_imm(b, addr, 3);
   return nir_iand_imm(b, node, (bvh_size - 1) << 3);
}

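/* Inverse of build_addr_to_node: strip the node-type bits (unless the caller
 * passes skip_type_and because the value is already a plain pointer) and
 * rebuild the byte address. */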
static nir_def *
build_node_to_addr(struct radv_device *device, nir_builder *b, nir_def *node, bool skip_type_and)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);
   nir_def *addr = skip_type_and ? node : nir_iand_imm(b, node, ~7ull);
   addr = nir_ishl_imm(b, addr, 3);
   /* Assumes everything is in the top half of address space, which is true in
    * GFX9+ for now. */
   return pdev->info.gfx_level >= GFX9 ? nir_ior_imm(b, addr, 0xffffull << 48) : addr;
}

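/* Multiply a vec3 by a 3x4 row-major matrix, where matrix[i] holds row i as a
 * vec4 and channel 3 is the translation column:
 *
 *    result[i] = dot(matrix[i].xyz, vec) + (translation ? matrix[i].w : 0.0);
 */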
nir_def *
nir_build_vec3_mat_mult(nir_builder *b, nir_def *vec, nir_def *matrix[], bool translation)
{
   nir_def *result_components[3] = {
      nir_channel(b, matrix[0], 3),
      nir_channel(b, matrix[1], 3),
      nir_channel(b, matrix[2], 3),
   };
   for (unsigned i = 0; i < 3; ++i) {
      for (unsigned j = 0; j < 3; ++j) {
         nir_def *v = nir_fmul(b, nir_channels(b, vec, 1 << j), nir_channels(b, matrix[i], 1 << j));
         result_components[i] = (translation || j) ? nir_fadd(b, result_components[i], v) : v;
      }
   }
   return nir_vec(b, result_components, 3);
}

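/* Load the three vec4 rows of the world-to-object matrix stored in an
 * instance node into out[0..2]. */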
void
nir_build_wto_matrix_load(nir_builder *b, nir_def *instance_addr, nir_def **out)
{
   unsigned offset = offsetof(struct radv_bvh_instance_node, wto_matrix);
   for (unsigned i = 0; i < 3; ++i) {
      out[i] = nir_build_load_global(b, 4, 32, nir_iadd_imm(b, instance_addr, offset + i * 16), .align_mul = 64,
                                     .align_offset = offset + i * 16);
   }
}

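/* Fetch one vertex position of a triangle leaf. The address computed below is
 *    bvh_ptr + sizeof(root box node) + primitive_id * sizeof(triangle node)
 *            + index * sizeof(vec3),
 * i.e. it assumes triangle leaves are packed in primitive_id order right after
 * the root box node, with the three positions at the start of each node. */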
nir_def *
radv_load_vertex_position(struct radv_device *device, nir_builder *b, nir_def *instance_addr, nir_def *primitive_id,
                          uint32_t index)
{
   nir_def *bvh_addr_id =
      nir_build_load_global(b, 1, 64, nir_iadd_imm(b, instance_addr, offsetof(struct radv_bvh_instance_node, bvh_ptr)));
   nir_def *bvh_addr = build_node_to_addr(device, b, bvh_addr_id, true);

   nir_def *offset = nir_imul_imm(b, primitive_id, sizeof(struct radv_bvh_triangle_node));
   offset = nir_iadd_imm(b, offset, sizeof(struct radv_bvh_box32_node) + index * 3 * sizeof(float));

   return nir_build_load_global(b, 3, 32, nir_iadd(b, bvh_addr, nir_u2u64(b, offset)));
}

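/* Note the encoding trick below: the geometry opaque flag and the instance
 * FORCE_OPAQUE / NO_FORCE_NOT_OPAQUE overrides live in the top bits of their
 * respective words (see bvh.h), so ORing the two words and doing one unsigned
 * compare against both instance bits tests them together. The per-ray
 * force_opaque / force_not_opaque flags then take precedence. */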
/* When a hit is opaque the any_hit shader is skipped for this hit and the hit
 * is assumed to be an actual hit. */
static nir_def *
hit_is_opaque(nir_builder *b, nir_def *sbt_offset_and_flags, const struct radv_ray_flags *ray_flags,
              nir_def *geometry_id_and_flags)
{
   nir_def *opaque = nir_uge_imm(b, nir_ior(b, geometry_id_and_flags, sbt_offset_and_flags),
                                 RADV_INSTANCE_FORCE_OPAQUE | RADV_INSTANCE_NO_FORCE_NOT_OPAQUE);
   opaque = nir_bcsel(b, ray_flags->force_opaque, nir_imm_true(b), opaque);
   opaque = nir_bcsel(b, ray_flags->force_not_opaque, nir_imm_false(b), opaque);
   return opaque;
}

static nir_def *
create_bvh_descriptor(nir_builder *b, const struct radv_physical_device *pdev, struct radv_ray_flags *ray_flags)
{
   /* We create a BVH descriptor that covers the entire memory range. That way we can always
    * use the same descriptor, which avoids divergence when different rays hit different
    * instances at the cost of having to use 64-bit node ids. */
   const uint64_t bvh_size = 1ull << 42;
   nir_def *desc = nir_imm_ivec4(b, 0, 1u << 31 /* Enable box sorting */, (bvh_size - 1) & 0xFFFFFFFFu,
                                 ((bvh_size - 1) >> 32) | (1u << 24 /* Return IJ for triangles */) | (1u << 31));

   if (pdev->info.gfx_level >= GFX11) {
      /* Instead of the default box sorting (closest point), use largest for terminate_on_first_hit rays and midpoint
       * for closest hit; this makes it more likely that the ray traversal will visit fewer nodes. */
      const uint32_t box_sort_largest = 1;
      const uint32_t box_sort_midpoint = 2;

      /* Only use largest/midpoint sorting when all invocations have the same ray flags, otherwise
       * fall back to the default closest point. */
      nir_def *box_sort = nir_imm_int(b, 1u << 31);
      box_sort = nir_bcsel(b, nir_vote_any(b, 1, ray_flags->terminate_on_first_hit), box_sort,
                           nir_imm_int(b, (box_sort_midpoint << 21) | (1u << 31)));
      box_sort = nir_bcsel(b, nir_vote_all(b, 1, ray_flags->terminate_on_first_hit),
                           nir_imm_int(b, (box_sort_largest << 21) | (1u << 31)), box_sort);

      desc = nir_vector_insert(b, desc, box_sort, nir_imm_int(b, 1));
   }

   return desc;
}

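/* Handle a triangle hit reported by the intersection result: recover t and
 * the barycentrics from the raw result, apply facing and cull ray flags, and
 * hand the hit to the triangle_cb callback if it survives. */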
static void
insert_traversal_triangle_case(struct radv_device *device, nir_builder *b, const struct radv_ray_traversal_args *args,
                               const struct radv_ray_flags *ray_flags, nir_def *result, nir_def *bvh_node)
{
   if (!args->triangle_cb)
      return;

   struct radv_triangle_intersection intersection;
   intersection.t = nir_channel(b, result, 0);
   nir_def *div = nir_channel(b, result, 1);
   intersection.t = nir_fdiv(b, intersection.t, div);

   nir_def *tmax = nir_load_deref(b, args->vars.tmax);

   nir_push_if(b, nir_flt(b, intersection.t, tmax));
   {
      intersection.frontface = nir_fgt_imm(b, div, 0);
      nir_def *switch_ccw =
         nir_test_mask(b, nir_load_deref(b, args->vars.sbt_offset_and_flags), RADV_INSTANCE_TRIANGLE_FLIP_FACING);
      intersection.frontface = nir_ixor(b, intersection.frontface, switch_ccw);

      nir_def *not_cull = ray_flags->no_skip_triangles;
      nir_def *not_facing_cull =
         nir_bcsel(b, intersection.frontface, ray_flags->no_cull_front, ray_flags->no_cull_back);

      not_cull = nir_iand(b, not_cull,
                          nir_ior(b, not_facing_cull,
                                  nir_test_mask(b, nir_load_deref(b, args->vars.sbt_offset_and_flags),
                                                RADV_INSTANCE_TRIANGLE_FACING_CULL_DISABLE)));

      nir_push_if(b, nir_iand(b, nir_flt(b, args->tmin, intersection.t), not_cull));
      {
         intersection.base.node_addr = build_node_to_addr(device, b, bvh_node, false);
         nir_def *triangle_info = nir_build_load_global(
            b, 2, 32,
            nir_iadd_imm(b, intersection.base.node_addr, offsetof(struct radv_bvh_triangle_node, triangle_id)));
         intersection.base.primitive_id = nir_channel(b, triangle_info, 0);
         intersection.base.geometry_id_and_flags = nir_channel(b, triangle_info, 1);
         intersection.base.opaque = hit_is_opaque(b, nir_load_deref(b, args->vars.sbt_offset_and_flags), ray_flags,
                                                  intersection.base.geometry_id_and_flags);

         not_cull = nir_bcsel(b, intersection.base.opaque, ray_flags->no_cull_opaque, ray_flags->no_cull_no_opaque);
         nir_push_if(b, not_cull);
         {
            nir_def *divs[2] = {div, div};
            intersection.barycentrics = nir_fdiv(b, nir_channels(b, result, 0xc), nir_vec(b, divs, 2));

            args->triangle_cb(b, &intersection, args, ray_flags);
         }
         nir_pop_if(b, NULL);
      }
      nir_pop_if(b, NULL);
   }
   nir_pop_if(b, NULL);
}

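/* Handle an AABB leaf: load its primitive/geometry ids and pass the hit to
 * the aabb_cb callback unless AABBs or the hit's opacity class are culled. */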
static void
insert_traversal_aabb_case(struct radv_device *device, nir_builder *b, const struct radv_ray_traversal_args *args,
                           const struct radv_ray_flags *ray_flags, nir_def *bvh_node)
{
   if (!args->aabb_cb)
      return;

   nir_push_if(b, ray_flags->no_skip_aabbs);
   {
      struct radv_leaf_intersection intersection;
      intersection.node_addr = build_node_to_addr(device, b, bvh_node, false);
      nir_def *triangle_info = nir_build_load_global(
         b, 2, 32, nir_iadd_imm(b, intersection.node_addr, offsetof(struct radv_bvh_aabb_node, primitive_id)));
      intersection.primitive_id = nir_channel(b, triangle_info, 0);
      intersection.geometry_id_and_flags = nir_channel(b, triangle_info, 1);
      intersection.opaque = hit_is_opaque(b, nir_load_deref(b, args->vars.sbt_offset_and_flags), ray_flags,
                                          intersection.geometry_id_and_flags);

      nir_push_if(b, nir_bcsel(b, intersection.opaque, ray_flags->no_cull_opaque, ray_flags->no_cull_no_opaque));
      {
         args->aabb_cb(b, &intersection, args);
      }
      nir_pop_if(b, NULL);
   }
   nir_pop_if(b, NULL);
}

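/* Parent links are stored as a reversed array directly below the BVH base
 * address: the link for node id n lives at bvh - 4 * (n / 8 + 1). This is
 * what makes the stackless backtracking in radv_build_ray_traversal work. */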
static nir_def *
fetch_parent_node(nir_builder *b, nir_def *bvh, nir_def *node)
{
   nir_def *offset = nir_iadd_imm(b, nir_imul_imm(b, nir_udiv_imm(b, node, 8), 4), 4);

   return nir_build_load_global(b, 1, 32, nir_isub(b, bvh, nir_u2u64(b, offset)), .align_mul = 4);
}

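/* Evaluate a ray flag, folding it to a constant when the flag is statically
 * known to be set or unset (args->set_flags / args->unset_flags) and only
 * emitting a runtime test of args->flags otherwise. */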
static nir_def *
radv_test_flag(nir_builder *b, const struct radv_ray_traversal_args *args, uint32_t flag, bool set)
{
   nir_def *result;
   if (args->set_flags & flag)
      result = nir_imm_true(b);
   else if (args->unset_flags & flag)
      result = nir_imm_false(b);
   else
      result = nir_test_mask(b, args->flags, flag);

   return set ? result : nir_inot(b, result);
}

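/* Build the main BVH traversal loop. The rough structure, in pseudocode:
 *
 *    loop {
 *       if (current_node == RADV_BVH_INVALID_NODE) {
 *          if (the short stack is back at its base) break;   (traversal done)
 *          if (we left an instance) restore the top-level ray and BVH base;
 *          if (the stack underflowed past the low watermark)
 *             backtrack through the parent link of previous_node;   (stackless)
 *          else
 *             pop current_node from the short stack;
 *       }
 *       intersect current_node (hardware intrinsic on GFX10.3+ when not
 *       emulating RT, software fallback otherwise) and either push/select
 *       child boxes, enter an instance, or report triangle/AABB hits through
 *       the callbacks in args;
 *    }
 *
 * Returns true if the loop was exited early from inside a callback (e.g. when
 * terminating on the first hit) and false if traversal ran to completion. */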
nir_def *
radv_build_ray_traversal(struct radv_device *device, nir_builder *b, const struct radv_ray_traversal_args *args)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);
   nir_variable *incomplete = nir_local_variable_create(b->impl, glsl_bool_type(), "incomplete");
   nir_store_var(b, incomplete, nir_imm_true(b), 0x1);

   struct radv_ray_flags ray_flags = {
      .force_opaque = radv_test_flag(b, args, SpvRayFlagsOpaqueKHRMask, true),
      .force_not_opaque = radv_test_flag(b, args, SpvRayFlagsNoOpaqueKHRMask, true),
      .terminate_on_first_hit = radv_test_flag(b, args, SpvRayFlagsTerminateOnFirstHitKHRMask, true),
      .no_cull_front = radv_test_flag(b, args, SpvRayFlagsCullFrontFacingTrianglesKHRMask, false),
      .no_cull_back = radv_test_flag(b, args, SpvRayFlagsCullBackFacingTrianglesKHRMask, false),
      .no_cull_opaque = radv_test_flag(b, args, SpvRayFlagsCullOpaqueKHRMask, false),
      .no_cull_no_opaque = radv_test_flag(b, args, SpvRayFlagsCullNoOpaqueKHRMask, false),
      .no_skip_triangles = radv_test_flag(b, args, SpvRayFlagsSkipTrianglesKHRMask, false),
      .no_skip_aabbs = radv_test_flag(b, args, SpvRayFlagsSkipAABBsKHRMask, false),
   };

   nir_def *desc = create_bvh_descriptor(b, pdev, &ray_flags);
   nir_def *vec3ones = nir_imm_vec3(b, 1.0, 1.0, 1.0);

   nir_push_loop(b);
   {
      nir_push_if(b, nir_ieq_imm(b, nir_load_deref(b, args->vars.current_node), RADV_BVH_INVALID_NODE));
      {
         /* Early exit if we never overflowed the stack, to avoid having to backtrack to
          * the root for no reason. */
         nir_push_if(b, nir_ilt_imm(b, nir_load_deref(b, args->vars.stack), args->stack_base + args->stack_stride));
         {
            nir_store_var(b, incomplete, nir_imm_false(b), 0x1);
            nir_jump(b, nir_jump_break);
         }
         nir_pop_if(b, NULL);

         nir_def *stack_instance_exit =
            nir_ige(b, nir_load_deref(b, args->vars.top_stack), nir_load_deref(b, args->vars.stack));
         nir_def *root_instance_exit =
            nir_ieq(b, nir_load_deref(b, args->vars.previous_node), nir_load_deref(b, args->vars.instance_bottom_node));
         nir_if *instance_exit = nir_push_if(b, nir_ior(b, stack_instance_exit, root_instance_exit));
         instance_exit->control = nir_selection_control_dont_flatten;
         {
            nir_store_deref(b, args->vars.top_stack, nir_imm_int(b, -1), 1);
            nir_store_deref(b, args->vars.previous_node, nir_load_deref(b, args->vars.instance_top_node), 1);
            nir_store_deref(b, args->vars.instance_bottom_node, nir_imm_int(b, RADV_BVH_NO_INSTANCE_ROOT), 1);

            nir_store_deref(b, args->vars.bvh_base, args->root_bvh_base, 1);
            nir_store_deref(b, args->vars.origin, args->origin, 7);
            nir_store_deref(b, args->vars.dir, args->dir, 7);
            nir_store_deref(b, args->vars.inv_dir, nir_fdiv(b, vec3ones, args->dir), 7);
         }
         nir_pop_if(b, NULL);

         nir_push_if(
            b, nir_ige(b, nir_load_deref(b, args->vars.stack_low_watermark), nir_load_deref(b, args->vars.stack)));
         {
            nir_def *prev = nir_load_deref(b, args->vars.previous_node);
            nir_def *bvh_addr = build_node_to_addr(device, b, nir_load_deref(b, args->vars.bvh_base), true);

            nir_def *parent = fetch_parent_node(b, bvh_addr, prev);
            nir_push_if(b, nir_ieq_imm(b, parent, RADV_BVH_INVALID_NODE));
            {
               nir_store_var(b, incomplete, nir_imm_false(b), 0x1);
               nir_jump(b, nir_jump_break);
            }
            nir_pop_if(b, NULL);
            nir_store_deref(b, args->vars.current_node, parent, 0x1);
         }
         nir_push_else(b, NULL);
         {
            nir_store_deref(b, args->vars.stack,
                            nir_iadd_imm(b, nir_load_deref(b, args->vars.stack), -args->stack_stride), 1);

            nir_def *stack_ptr =
               nir_umod_imm(b, nir_load_deref(b, args->vars.stack), args->stack_stride * args->stack_entries);
            nir_def *bvh_node = args->stack_load_cb(b, stack_ptr, args);
            nir_store_deref(b, args->vars.current_node, bvh_node, 0x1);
            nir_store_deref(b, args->vars.previous_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
         }
         nir_pop_if(b, NULL);
      }
      nir_push_else(b, NULL);
      {
         nir_store_deref(b, args->vars.previous_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
      }
      nir_pop_if(b, NULL);

      nir_def *bvh_node = nir_load_deref(b, args->vars.current_node);

      nir_def *prev_node = nir_load_deref(b, args->vars.previous_node);
      nir_store_deref(b, args->vars.previous_node, bvh_node, 0x1);
      nir_store_deref(b, args->vars.current_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);

      nir_def *global_bvh_node = nir_iadd(b, nir_load_deref(b, args->vars.bvh_base), nir_u2u64(b, bvh_node));

      nir_def *intrinsic_result = NULL;
      if (pdev->info.gfx_level >= GFX10_3 && !radv_emulate_rt(pdev)) {
         intrinsic_result =
            nir_bvh64_intersect_ray_amd(b, 32, desc, nir_unpack_64_2x32(b, global_bvh_node),
                                        nir_load_deref(b, args->vars.tmax), nir_load_deref(b, args->vars.origin),
                                        nir_load_deref(b, args->vars.dir), nir_load_deref(b, args->vars.inv_dir));
      }

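      /* The node type lives in the low bits of the node id, so a few bit tests
       * are enough to tell triangle, box, instance and AABB nodes apart (see
       * the radv_bvh_node_* values in bvh.h). */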
      nir_push_if(b, nir_test_mask(b, bvh_node, BITFIELD64_BIT(ffs(radv_bvh_node_box16) - 1)));
      {
         nir_push_if(b, nir_test_mask(b, bvh_node, BITFIELD64_BIT(ffs(radv_bvh_node_instance) - 1)));
         {
            nir_push_if(b, nir_test_mask(b, bvh_node, BITFIELD64_BIT(ffs(radv_bvh_node_aabb) - 1)));
            {
               insert_traversal_aabb_case(device, b, args, &ray_flags, global_bvh_node);
            }
            nir_push_else(b, NULL);
            {
               if (args->vars.iteration_instance_count) {
                  nir_def *iteration_instance_count = nir_load_deref(b, args->vars.iteration_instance_count);
                  iteration_instance_count = nir_iadd_imm(b, iteration_instance_count, 1 << 16);
                  nir_store_deref(b, args->vars.iteration_instance_count, iteration_instance_count, 0x1);
               }

               /* instance */
               nir_def *instance_node_addr = build_node_to_addr(device, b, global_bvh_node, false);
               nir_store_deref(b, args->vars.instance_addr, instance_node_addr, 1);

               nir_def *instance_data =
                  nir_build_load_global(b, 4, 32, instance_node_addr, .align_mul = 64, .align_offset = 0);

               nir_def *wto_matrix[3];
               nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);

               nir_store_deref(b, args->vars.sbt_offset_and_flags, nir_channel(b, instance_data, 3), 1);

               if (!args->ignore_cull_mask) {
                  nir_def *instance_and_mask = nir_channel(b, instance_data, 2);
                  nir_push_if(b, nir_ult(b, nir_iand(b, instance_and_mask, args->cull_mask), nir_imm_int(b, 1 << 24)));
                  {
                     nir_jump(b, nir_jump_continue);
                  }
                  nir_pop_if(b, NULL);
               }

               nir_store_deref(b, args->vars.top_stack, nir_load_deref(b, args->vars.stack), 1);
               nir_store_deref(b, args->vars.bvh_base, nir_pack_64_2x32(b, nir_trim_vector(b, instance_data, 2)), 1);

               /* Push the instance root node onto the stack */
               nir_store_deref(b, args->vars.current_node, nir_imm_int(b, RADV_BVH_ROOT_NODE), 0x1);
               nir_store_deref(b, args->vars.instance_bottom_node, nir_imm_int(b, RADV_BVH_ROOT_NODE), 1);
               nir_store_deref(b, args->vars.instance_top_node, bvh_node, 1);

               /* Transform the ray into object space */
               nir_store_deref(b, args->vars.origin, nir_build_vec3_mat_mult(b, args->origin, wto_matrix, true), 7);
               nir_store_deref(b, args->vars.dir, nir_build_vec3_mat_mult(b, args->dir, wto_matrix, false), 7);
               nir_store_deref(b, args->vars.inv_dir, nir_fdiv(b, vec3ones, nir_load_deref(b, args->vars.dir)), 7);
            }
            nir_pop_if(b, NULL);
         }
         nir_push_else(b, NULL);
         {
            nir_def *result = intrinsic_result;
            if (!result) {
               /* If we didn't run the intrinsic because the hardware didn't support it,
                * emulate ray/box intersection here. */
               result = intersect_ray_amd_software_box(
                  device, b, global_bvh_node, nir_load_deref(b, args->vars.tmax), nir_load_deref(b, args->vars.origin),
                  nir_load_deref(b, args->vars.dir), nir_load_deref(b, args->vars.inv_dir));
            }

            /* box */
            nir_push_if(b, nir_ieq_imm(b, prev_node, RADV_BVH_INVALID_NODE));
            {
               nir_def *new_nodes[4];
               for (unsigned i = 0; i < 4; ++i)
                  new_nodes[i] = nir_channel(b, result, i);

               for (unsigned i = 1; i < 4; ++i)
                  nir_push_if(b, nir_ine_imm(b, new_nodes[i], RADV_BVH_INVALID_NODE));

               for (unsigned i = 4; i-- > 1;) {
                  nir_def *stack = nir_load_deref(b, args->vars.stack);
                  nir_def *stack_ptr = nir_umod_imm(b, stack, args->stack_entries * args->stack_stride);
                  args->stack_store_cb(b, stack_ptr, new_nodes[i], args);
                  nir_store_deref(b, args->vars.stack, nir_iadd_imm(b, stack, args->stack_stride), 1);

                  if (i == 1) {
                     nir_def *new_watermark =
                        nir_iadd_imm(b, nir_load_deref(b, args->vars.stack), -args->stack_entries * args->stack_stride);
                     new_watermark = nir_imax(b, nir_load_deref(b, args->vars.stack_low_watermark), new_watermark);
                     nir_store_deref(b, args->vars.stack_low_watermark, new_watermark, 0x1);
                  }

                  nir_pop_if(b, NULL);
               }
               nir_store_deref(b, args->vars.current_node, new_nodes[0], 0x1);
            }
            nir_push_else(b, NULL);
            {
               nir_def *next = nir_imm_int(b, RADV_BVH_INVALID_NODE);
               for (unsigned i = 0; i < 3; ++i) {
                  next = nir_bcsel(b, nir_ieq(b, prev_node, nir_channel(b, result, i)), nir_channel(b, result, i + 1),
                                   next);
               }
               nir_store_deref(b, args->vars.current_node, next, 0x1);
            }
            nir_pop_if(b, NULL);
         }
         nir_pop_if(b, NULL);
      }
      nir_push_else(b, NULL);
      {
         nir_def *result = intrinsic_result;
         if (!result) {
            /* If we didn't run the intrinsic because the hardware didn't support it,
             * emulate ray/tri intersection here. */
            result = intersect_ray_amd_software_tri(
               device, b, global_bvh_node, nir_load_deref(b, args->vars.tmax), nir_load_deref(b, args->vars.origin),
               nir_load_deref(b, args->vars.dir), nir_load_deref(b, args->vars.inv_dir));
         }
         insert_traversal_triangle_case(device, b, args, &ray_flags, result, global_bvh_node);
      }
      nir_pop_if(b, NULL);

      if (args->vars.iteration_instance_count) {
         nir_def *iteration_instance_count = nir_load_deref(b, args->vars.iteration_instance_count);
         iteration_instance_count = nir_iadd_imm(b, iteration_instance_count, 1);
         nir_store_deref(b, args->vars.iteration_instance_count, iteration_instance_count, 0x1);
      }
   }
   nir_pop_loop(b, NULL);

   return nir_load_var(b, incomplete);
}