/*
 * Copyright © 2020 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#pragma once

/* This file provides helpers to access the memory-based data structures that
 * the RT hardware reads/writes, and to compute their locations in memory.
 *
 * See also "Memory Based Data Structures for Ray Tracing" (BSpec 47547) and
 * "Ray Tracing Address Computation for Memory Resident Structures" (BSpec
 * 47550).
 */

#include "brw_rt.h"
#include "nir_builder.h"

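/* Loads/stores emitted through these helpers are also performed for helper
 * invocations when building a fragment shader, since ray queries can be used
 * from fragment shaders too (hence ACCESS_INCLUDE_HELPERS below); presumably
 * the RT stack accesses need to happen for every lane of the SIMD thread.
 */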
#define is_access_for_builder(b) \
   ((b)->shader->info.stage == MESA_SHADER_FRAGMENT ? \
    ACCESS_INCLUDE_HELPERS : 0)

static inline nir_def *
brw_nir_rt_load(nir_builder *b, nir_def *addr, unsigned align,
                unsigned components, unsigned bit_size)
{
   return nir_build_load_global(b, components, bit_size, addr,
                                .align_mul = align,
                                .access = is_access_for_builder(b));
}

static inline void
brw_nir_rt_store(nir_builder *b, nir_def *addr, unsigned align,
                 nir_def *value, unsigned write_mask)
{
   nir_build_store_global(b, value, addr,
                          .align_mul = align,
                          .write_mask = (write_mask) &
                                        BITFIELD_MASK(value->num_components),
                          .access = is_access_for_builder(b));
}

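/* Uniform (convergent) 64B block load from read-only global memory.  Used
 * here for data that is constant across the dispatch, e.g. the RT dispatch
 * globals.
 */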
static inline nir_def *
brw_nir_rt_load_const(nir_builder *b, unsigned components, nir_def *addr)
{
   return nir_load_global_constant_uniform_block_intel(
      b, components, 32, addr,
      .access = ACCESS_CAN_REORDER | ACCESS_NON_WRITEABLE,
      .align_mul = 64);
}

static inline nir_def *
brw_load_btd_dss_id(nir_builder *b)
{
   return nir_load_topology_id_intel(b, .base = BRW_TOPOLOGY_ID_DSS);
}

static inline nir_def *
brw_load_eu_thread_simd(nir_builder *b)
{
   return nir_load_topology_id_intel(b, .base = BRW_TOPOLOGY_ID_EU_THREAD_SIMD);
}

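/* Flat, device-wide index of the current asynchronous RT stack: each DSS
 * owns numDSSRTStacks stacks, so the index is
 * DSSID * numDSSRTStacks + stackID.
 */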
static inline nir_def *
brw_nir_rt_async_stack_id(nir_builder *b)
{
   return nir_iadd(b, nir_umul_32x16(b, nir_load_ray_num_dss_rt_stacks_intel(b),
                                        brw_load_btd_dss_id(b)),
                      nir_load_btd_stack_id_intel(b));
}

static inline nir_def *
brw_nir_rt_sync_stack_id(nir_builder *b)
{
   return brw_load_eu_thread_simd(b);
}

/* We have our own load/store scratch helpers because they emit a global
 * memory read or write based on the scratch_base_ptr system value rather
 * than a load/store_scratch intrinsic.
 */
static inline nir_def *
brw_nir_rt_load_scratch(nir_builder *b, uint32_t offset, unsigned align,
                        unsigned num_components, unsigned bit_size)
{
   nir_def *addr =
      nir_iadd_imm(b, nir_load_scratch_base_ptr(b, 1, 64, 1), offset);
   return brw_nir_rt_load(b, addr, MIN2(align, BRW_BTD_STACK_ALIGN),
                             num_components, bit_size);
}

static inline void
brw_nir_rt_store_scratch(nir_builder *b, uint32_t offset, unsigned align,
                         nir_def *value, nir_component_mask_t write_mask)
{
   nir_def *addr =
      nir_iadd_imm(b, nir_load_scratch_base_ptr(b, 1, 64, 1), offset);
   brw_nir_rt_store(b, addr, MIN2(align, BRW_BTD_STACK_ALIGN),
                    value, write_mask);
}

static inline void
brw_nir_btd_spawn(nir_builder *b, nir_def *record_addr)
{
   nir_btd_spawn_intel(b, nir_load_btd_global_arg_addr_intel(b), record_addr);
}

static inline void
brw_nir_btd_retire(nir_builder *b)
{
   nir_btd_retire_intel(b);
}

/** This is a pseudo-op which does a bindless return
 *
 * It loads the return address from the stack and calls btd_spawn to spawn the
 * resume shader.
 */
static inline void
brw_nir_btd_return(struct nir_builder *b)
{
   nir_def *resume_addr =
      brw_nir_rt_load_scratch(b, BRW_BTD_STACK_RESUME_BSR_ADDR_OFFSET,
                              8 /* align */, 1, 64);
   brw_nir_btd_spawn(b, resume_addr);
}

static inline void
assert_def_size(nir_def *def, unsigned num_components, unsigned bit_size)
{
   assert(def->num_components == num_components);
   assert(def->bit_size == bit_size);
}

static inline nir_def *
brw_nir_num_rt_stacks(nir_builder *b,
                      const struct intel_device_info *devinfo)
{
   return nir_imul_imm(b, nir_load_ray_num_dss_rt_stacks_intel(b),
                          intel_device_info_dual_subslice_id_bound(devinfo));
}

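/* The SW hotzones live immediately below rtMemBasePtr (the HW stacks grow
 * upwards from it, see brw_nir_rt_stack_addr), one BRW_RT_SIZEOF_HOTZONE
 * record per async stack, so the 32-bit offset computed here is negative.
 */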
static inline nir_def *
brw_nir_rt_sw_hotzone_addr(nir_builder *b,
                           const struct intel_device_info *devinfo)
{
   nir_def *offset32 =
      nir_imul_imm(b, brw_nir_rt_async_stack_id(b),
                      BRW_RT_SIZEOF_HOTZONE);

   offset32 = nir_iadd(b, offset32, nir_ineg(b,
      nir_imul_imm(b, brw_nir_num_rt_stacks(b, devinfo),
                      BRW_RT_SIZEOF_HOTZONE)));

   return nir_iadd(b, nir_load_ray_base_mem_addr_intel(b),
                      nir_i2i64(b, offset32));
}

static inline nir_def *
brw_nir_rt_sync_stack_addr(nir_builder *b,
                           nir_def *base_mem_addr,
                           nir_def *num_dss_rt_stacks)
{
   /* Bspec 47547 (Xe) and 56936 (Xe2+) say:
    *    For Ray queries (Synchronous Ray Tracing), the formula is similar but
    *    goes down from rtMemBasePtr :
    *
    *       syncBase  = RTDispatchGlobals.rtMemBasePtr
    *                 - (DSSID * NUM_SIMD_LANES_PER_DSS + SyncStackID + 1)
    *                 * syncStackSize
    *
    *    We assume that we can calculate a 32-bit offset first and then add it
    *    to the 64-bit base address at the end.
    *
    * However, on HSD 14020275151 it's clarified that the HW uses
    * NUM_SYNC_STACKID_PER_DSS instead.
    */
   nir_def *offset32 =
      nir_imul(b,
               nir_iadd(b,
                        nir_imul(b, brw_load_btd_dss_id(b),
                                    num_dss_rt_stacks),
                        nir_iadd_imm(b, brw_nir_rt_sync_stack_id(b), 1)),
               nir_imm_int(b, BRW_RT_SIZEOF_RAY_QUERY));
   return nir_isub(b, base_mem_addr, nir_u2u64(b, offset32));
}

static inline nir_def *
brw_nir_rt_stack_addr(nir_builder *b)
{
   /* From the BSpec "Address Computation for Memory Based Data Structures:
    * Ray and TraversalStack (Async Ray Tracing)":
    *
    *    stackBase = RTDispatchGlobals.rtMemBasePtr
    *              + (DSSID * RTDispatchGlobals.numDSSRTStacks + stackID)
    *              * RTDispatchGlobals.stackSizePerRay // 64B aligned
    *
    * We assume that we can calculate a 32-bit offset first and then add it
    * to the 64-bit base address at the end.
    */
   nir_def *offset32 =
      nir_imul(b, brw_nir_rt_async_stack_id(b),
                  nir_load_ray_hw_stack_size_intel(b));
   return nir_iadd(b, nir_load_ray_base_mem_addr_intel(b),
                      nir_u2u64(b, offset32));
}

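/* Both MemHit structures sit at the very start of the HW stack: the
 * committed hit first, immediately followed by the potential (uncommitted)
 * one.
 */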
static inline nir_def *
brw_nir_rt_mem_hit_addr_from_addr(nir_builder *b,
                                  nir_def *stack_addr,
                                  bool committed)
{
   return nir_iadd_imm(b, stack_addr, committed ? 0 : BRW_RT_SIZEOF_HIT_INFO);
}

static inline nir_def *
brw_nir_rt_mem_hit_addr(nir_builder *b, bool committed)
{
   return nir_iadd_imm(b, brw_nir_rt_stack_addr(b),
                          committed ? 0 : BRW_RT_SIZEOF_HIT_INFO);
}

static inline nir_def *
brw_nir_rt_hit_attrib_data_addr(nir_builder *b)
{
   return nir_iadd_imm(b, brw_nir_rt_stack_addr(b),
                          BRW_RT_OFFSETOF_HIT_ATTRIB_DATA);
}

static inline nir_def *
brw_nir_rt_mem_ray_addr(nir_builder *b,
                        nir_def *stack_addr,
                        enum brw_rt_bvh_level bvh_level)
{
   /* From the BSpec "Address Computation for Memory Based Data Structures:
    * Ray and TraversalStack (Async Ray Tracing)":
    *
    *    rayBase = stackBase + sizeof(HitInfo) * 2 // 64B aligned
    *    rayPtr  = rayBase + bvhLevel * sizeof(Ray); // 64B aligned
    *
    * In Vulkan, we always have exactly two levels of BVH: World and Object.
    */
   uint32_t offset = BRW_RT_SIZEOF_HIT_INFO * 2 +
                     bvh_level * BRW_RT_SIZEOF_RAY;
   return nir_iadd_imm(b, stack_addr, offset);
}

static inline nir_def *
brw_nir_rt_sw_stack_addr(nir_builder *b,
                         const struct intel_device_info *devinfo)
{
   nir_def *addr = nir_load_ray_base_mem_addr_intel(b);

   nir_def *offset32 = nir_imul(b, brw_nir_num_rt_stacks(b, devinfo),
                                       nir_load_ray_hw_stack_size_intel(b));
   addr = nir_iadd(b, addr, nir_u2u64(b, offset32));

   nir_def *offset_in_stack =
      nir_imul(b, nir_u2u64(b, brw_nir_rt_async_stack_id(b)),
                  nir_u2u64(b, nir_load_ray_sw_stack_size_intel(b)));

   return nir_iadd(b, addr, offset_in_stack);
}

static inline nir_def *
nir_unpack_64_4x16_split_z(nir_builder *b, nir_def *val)
{
   return nir_unpack_32_2x16_split_x(b, nir_unpack_64_2x32_split_y(b, val));
}

struct brw_nir_rt_globals_defs {
   nir_def *base_mem_addr;
   nir_def *call_stack_handler_addr;
   nir_def *hw_stack_size;
   nir_def *num_dss_rt_stacks;
   nir_def *hit_sbt_addr;
   nir_def *hit_sbt_stride;
   nir_def *miss_sbt_addr;
   nir_def *miss_sbt_stride;
   nir_def *sw_stack_size;
   nir_def *launch_size;
   nir_def *call_sbt_addr;
   nir_def *call_sbt_stride;
   nir_def *resume_sbt_addr;
};

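/* Unpack the RT dispatch globals (RTDispatchGlobals, see BSpec 47547).  The
 * first 64B block holds rtMemBasePtr, the call-stack handler, stack sizes,
 * the launch size, and the hit/miss SBT bases (48-bit pointers packed with
 * 16-bit strides); the second 64B block holds the callable and resume SBT
 * pointers.
 */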
static inline void
brw_nir_rt_load_globals_addr(nir_builder *b,
                             struct brw_nir_rt_globals_defs *defs,
                             nir_def *addr)
{
   nir_def *data;
   data = brw_nir_rt_load_const(b, 16, addr);
   defs->base_mem_addr = nir_pack_64_2x32(b, nir_trim_vector(b, data, 2));

   defs->call_stack_handler_addr =
      nir_pack_64_2x32(b, nir_channels(b, data, 0x3 << 2));

   defs->hw_stack_size = nir_channel(b, data, 4);
   defs->num_dss_rt_stacks = nir_iand_imm(b, nir_channel(b, data, 5), 0xffff);
   defs->hit_sbt_addr =
      nir_pack_64_2x32_split(b, nir_channel(b, data, 8),
                                nir_extract_i16(b, nir_channel(b, data, 9),
                                                   nir_imm_int(b, 0)));
   defs->hit_sbt_stride =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 9));
   defs->miss_sbt_addr =
      nir_pack_64_2x32_split(b, nir_channel(b, data, 10),
                                nir_extract_i16(b, nir_channel(b, data, 11),
                                                   nir_imm_int(b, 0)));
   defs->miss_sbt_stride =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 11));
   defs->sw_stack_size = nir_channel(b, data, 12);
   defs->launch_size = nir_channels(b, data, 0x7u << 13);

   data = brw_nir_rt_load_const(b, 8, nir_iadd_imm(b, addr, 64));
   defs->call_sbt_addr =
      nir_pack_64_2x32_split(b, nir_channel(b, data, 0),
                                nir_extract_i16(b, nir_channel(b, data, 1),
                                                   nir_imm_int(b, 0)));
   defs->call_sbt_stride =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 1));

   defs->resume_sbt_addr =
      nir_pack_64_2x32(b, nir_channels(b, data, 0x3 << 2));
}

static inline void
brw_nir_rt_load_globals(nir_builder *b,
                        struct brw_nir_rt_globals_defs *defs)
{
   brw_nir_rt_load_globals_addr(b, defs, nir_load_btd_global_arg_addr_intel(b));
}

static inline nir_def *
brw_nir_rt_unpack_leaf_ptr(nir_builder *b, nir_def *vec2)
{
   /* Hit record leaf pointers are 42-bit and assumed to be in 64B chunks.
    * This leaves 22 bits at the top for other stuff.
    */
   nir_def *ptr64 = nir_imul_imm(b, nir_pack_64_2x32(b, vec2), 64);

   /* The top 16 bits (remember, we shifted by 6 already) contain garbage
    * that we need to get rid of.
    */
   nir_def *ptr_lo = nir_unpack_64_2x32_split_x(b, ptr64);
   nir_def *ptr_hi = nir_unpack_64_2x32_split_y(b, ptr64);
   ptr_hi = nir_extract_i16(b, ptr_hi, nir_imm_int(b, 0));
   return nir_pack_64_2x32_split(b, ptr_lo, ptr_hi);
}

/**
 * MemHit memory layout (BSpec 47547):
 *
 *      name            bits    description
 *    - t               32      hit distance of current hit (or initial traversal distance)
 *    - u               32      barycentric hit coordinates
 *    - v               32      barycentric hit coordinates
 *    - primIndexDelta  16      prim index delta for compressed meshlets and quads
 *    - valid            1      set if there is a hit
 *    - leafType         3      type of node primLeafPtr is pointing to
 *    - primLeafIndex    4      index of the hit primitive inside the leaf
 *    - bvhLevel         3      the instancing level at which the hit occurred
 *    - frontFace        1      whether we hit the front-facing side of a triangle (also used to pass opaque flag when calling intersection shaders)
 *    - pad0             4      unused bits
 *    - primLeafPtr     42      pointer to BVH leaf node (multiple of 64 bytes)
 *    - hitGroupRecPtr0 22      LSB of hit group record of the hit triangle (multiple of 16 bytes)
 *    - instLeafPtr     42      pointer to BVH instance leaf node (in multiple of 64 bytes)
 *    - hitGroupRecPtr1 22      MSB of hit group record of the hit triangle (multiple of 32 bytes)
 */
struct brw_nir_rt_mem_hit_defs {
   nir_def *t;
   nir_def *tri_bary; /**< Only valid for triangle geometry */
   nir_def *aabb_hit_kind; /**< Only valid for AABB geometry */
   nir_def *valid;
   nir_def *leaf_type;
   nir_def *prim_index_delta;
   nir_def *prim_leaf_index;
   nir_def *bvh_level;
   nir_def *front_face;
   nir_def *done; /**< Only for ray queries */
   nir_def *prim_leaf_ptr;
   nir_def *inst_leaf_ptr;
};

static inline void
brw_nir_rt_load_mem_hit_from_addr(nir_builder *b,
                                  struct brw_nir_rt_mem_hit_defs *defs,
                                  nir_def *stack_addr,
                                  bool committed)
{
   nir_def *hit_addr =
      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, committed);

   nir_def *data = brw_nir_rt_load(b, hit_addr, 16, 4, 32);
   defs->t = nir_channel(b, data, 0);
   defs->aabb_hit_kind = nir_channel(b, data, 1);
   defs->tri_bary = nir_channels(b, data, 0x6);
   nir_def *bitfield = nir_channel(b, data, 3);
   defs->prim_index_delta =
      nir_ubitfield_extract(b, bitfield, nir_imm_int(b, 0), nir_imm_int(b, 16));
   defs->valid = nir_i2b(b, nir_iand_imm(b, bitfield, 1u << 16));
   defs->leaf_type =
      nir_ubitfield_extract(b, bitfield, nir_imm_int(b, 17), nir_imm_int(b, 3));
   defs->prim_leaf_index =
      nir_ubitfield_extract(b, bitfield, nir_imm_int(b, 20), nir_imm_int(b, 4));
   defs->bvh_level =
      nir_ubitfield_extract(b, bitfield, nir_imm_int(b, 24), nir_imm_int(b, 3));
   defs->front_face = nir_i2b(b, nir_iand_imm(b, bitfield, 1 << 27));
   defs->done = nir_i2b(b, nir_iand_imm(b, bitfield, 1 << 28));

   data = brw_nir_rt_load(b, nir_iadd_imm(b, hit_addr, 16), 16, 4, 32);
   defs->prim_leaf_ptr =
      brw_nir_rt_unpack_leaf_ptr(b, nir_channels(b, data, 0x3 << 0));
   defs->inst_leaf_ptr =
      brw_nir_rt_unpack_leaf_ptr(b, nir_channels(b, data, 0x3 << 2));
}

static inline void
brw_nir_rt_load_mem_hit(nir_builder *b,
                        struct brw_nir_rt_mem_hit_defs *defs,
                        bool committed)
{
   brw_nir_rt_load_mem_hit_from_addr(b, defs, brw_nir_rt_stack_addr(b),
                                     committed);
}

static inline void
brw_nir_memcpy_global(nir_builder *b,
                      nir_def *dst_addr, uint32_t dst_align,
                      nir_def *src_addr, uint32_t src_align,
                      uint32_t size)
{
   /* We're going to copy in 16B chunks */
   assert(size % 16 == 0);
   dst_align = MIN2(dst_align, 16);
   src_align = MIN2(src_align, 16);

   for (unsigned offset = 0; offset < size; offset += 16) {
      nir_def *data =
         brw_nir_rt_load(b, nir_iadd_imm(b, src_addr, offset), 16,
                         4, 32);
      brw_nir_rt_store(b, nir_iadd_imm(b, dst_addr, offset), 16,
                       data, 0xf /* write_mask */);
   }
}

static inline void
brw_nir_memclear_global(nir_builder *b,
                        nir_def *dst_addr, uint32_t dst_align,
                        uint32_t size)
{
   /* We're going to copy in 16B chunks */
   assert(size % 16 == 0);
   dst_align = MIN2(dst_align, 16);

   nir_def *zero = nir_imm_ivec4(b, 0, 0, 0, 0);
   for (unsigned offset = 0; offset < size; offset += 16) {
      brw_nir_rt_store(b, nir_iadd_imm(b, dst_addr, offset), dst_align,
                       zero, 0xf /* write_mask */);
   }
}

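/* A ray query is considered done once the 'done' bit (bit 28 of dword 3 of
 * the potential MemHit) has been set by brw_nir_rt_query_mark_done() below.
 */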
static inline nir_def *
brw_nir_rt_query_done(nir_builder *b, nir_def *stack_addr)
{
   struct brw_nir_rt_mem_hit_defs hit_in = {};
   brw_nir_rt_load_mem_hit_from_addr(b, &hit_in, stack_addr,
                                     false /* committed */);

   return hit_in.done;
}

static inline void
brw_nir_rt_set_dword_bit_at(nir_builder *b,
                            nir_def *addr,
                            uint32_t addr_offset,
                            uint32_t bit)
{
   nir_def *dword_addr = nir_iadd_imm(b, addr, addr_offset);
   nir_def *dword = brw_nir_rt_load(b, dword_addr, 4, 1, 32);
   brw_nir_rt_store(b, dword_addr, 4, nir_ior_imm(b, dword, 1u << bit), 0x1);
}

static inline void
brw_nir_rt_query_mark_done(nir_builder *b, nir_def *stack_addr)
{
   brw_nir_rt_set_dword_bit_at(b,
                               brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr,
                                                                 false /* committed */),
                               4 * 3 /* dword offset */, 28 /* bit */);
}

/* This helper clears the 3rd dword of both MemHit structures (committed and
 * potential), which is where the valid bit is located.
 */
static inline void
brw_nir_rt_query_mark_init(nir_builder *b, nir_def *stack_addr)
{
   nir_def *dword_addr;

   for (uint32_t i = 0; i < 2; i++) {
      dword_addr =
         nir_iadd_imm(b,
                      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr,
                                                        i == 0 /* committed */),
                      4 * 3 /* dword offset */);
      brw_nir_rt_store(b, dword_addr, 4, nir_imm_int(b, 0), 0x1);
   }
}

/* This helper is pretty much a memcpy of the uncommitted (potential) hit
 * structure into the committed one, additionally setting the valid bit.
 */
static inline void
brw_nir_rt_commit_hit_addr(nir_builder *b, nir_def *stack_addr)
{
   nir_def *dst_addr =
      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, true /* committed */);
   nir_def *src_addr =
      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, false /* committed */);

   for (unsigned offset = 0; offset < BRW_RT_SIZEOF_HIT_INFO; offset += 16) {
      nir_def *data =
         brw_nir_rt_load(b, nir_iadd_imm(b, src_addr, offset), 16, 4, 32);

      if (offset == 0) {
         data = nir_vec4(b,
                         nir_channel(b, data, 0),
                         nir_channel(b, data, 1),
                         nir_channel(b, data, 2),
                         nir_ior_imm(b,
                                     nir_channel(b, data, 3),
                                     0x1 << 16 /* valid */));

         /* Also write the potential hit as we change it. */
         brw_nir_rt_store(b, nir_iadd_imm(b, src_addr, offset), 16,
                          data, 0xf /* write_mask */);
      }

      brw_nir_rt_store(b, nir_iadd_imm(b, dst_addr, offset), 16,
                       data, 0xf /* write_mask */);
   }
}

static inline void
brw_nir_rt_commit_hit(nir_builder *b)
{
   nir_def *stack_addr = brw_nir_rt_stack_addr(b);
   brw_nir_rt_commit_hit_addr(b, stack_addr);
}

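/* Write a hit at distance t_val into both the potential and the committed
 * MemHit structures, e.g. when an intersection is reported from the shader
 * side rather than by fixed-function traversal.
 */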
static inline void
brw_nir_rt_generate_hit_addr(nir_builder *b, nir_def *stack_addr, nir_def *t_val)
{
   nir_def *committed_addr =
      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, true /* committed */);
   nir_def *potential_addr =
      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, false /* committed */);

   /* Set:
    *
    *   potential.t     = t_val;
    *   potential.valid = true;
    */
   nir_def *potential_hit_dwords_0_3 =
      brw_nir_rt_load(b, potential_addr, 16, 4, 32);
   potential_hit_dwords_0_3 =
      nir_vec4(b,
               t_val,
               nir_channel(b, potential_hit_dwords_0_3, 1),
               nir_channel(b, potential_hit_dwords_0_3, 2),
               nir_ior_imm(b, nir_channel(b, potential_hit_dwords_0_3, 3),
                           (0x1 << 16) /* valid */));
   brw_nir_rt_store(b, potential_addr, 16, potential_hit_dwords_0_3, 0xf /* write_mask */);

   /* Set:
    *
    *   committed.t               = t_val;
    *   committed.u               = 0.0f;
    *   committed.v               = 0.0f;
    *   committed.valid           = true;
    *   committed.leaf_type       = potential.leaf_type;
    *   committed.bvh_level       = BRW_RT_BVH_LEVEL_OBJECT;
    *   committed.front_face      = false;
    *   committed.prim_leaf_index = 0;
    *   committed.done            = false;
    */
   nir_def *committed_hit_dwords_0_3 =
      brw_nir_rt_load(b, committed_addr, 16, 4, 32);
   committed_hit_dwords_0_3 =
      nir_vec4(b,
               t_val,
               nir_imm_float(b, 0.0f),
               nir_imm_float(b, 0.0f),
               nir_ior_imm(b,
                           nir_ior_imm(b, nir_channel(b, potential_hit_dwords_0_3, 3), 0x000e0000),
                           (0x1 << 16)                     /* valid */ |
                           (BRW_RT_BVH_LEVEL_OBJECT << 24) /* bvh_level */));
   brw_nir_rt_store(b, committed_addr, 16, committed_hit_dwords_0_3, 0xf /* write_mask */);

   /* Set:
    *
    *   committed.prim_leaf_ptr   = potential.prim_leaf_ptr;
    *   committed.inst_leaf_ptr   = potential.inst_leaf_ptr;
    */
   brw_nir_memcpy_global(b,
                         nir_iadd_imm(b, committed_addr, 16), 16,
                         nir_iadd_imm(b, potential_addr, 16), 16,
                         16);
}

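/* MemRay memory layout, as stored/loaded by the helpers below (see BSpec
 * 47547): the ray origin and direction (3 floats each), t_near/t_far, then
 * 64-bit pointers packed as 48-bit addresses with 16 bits of extra data:
 * rootNodePtr + rayFlags, hitGroupSRBasePtr + hitGroupSRStride,
 * missSRPtr + shaderIndexMultiplier, and instLeafPtr + rayMask.
 */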
struct brw_nir_rt_mem_ray_defs {
   nir_def *orig;
   nir_def *dir;
   nir_def *t_near;
   nir_def *t_far;
   nir_def *root_node_ptr;
   nir_def *ray_flags;
   nir_def *hit_group_sr_base_ptr;
   nir_def *hit_group_sr_stride;
   nir_def *miss_sr_ptr;
   nir_def *shader_index_multiplier;
   nir_def *inst_leaf_ptr;
   nir_def *ray_mask;
};

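/* Ray queries don't use the shader record pointers, so only the
 * rootNodePtr/rayFlags and instLeafPtr/rayMask parts of the MemRay are
 * written here (hence the partial write mask at offset 32 and the store at
 * offset 56).
 */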
static inline void
brw_nir_rt_store_mem_ray_query_at_addr(nir_builder *b,
                                       nir_def *ray_addr,
                                       const struct brw_nir_rt_mem_ray_defs *defs)
{
   assert_def_size(defs->orig, 3, 32);
   assert_def_size(defs->dir, 3, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 0), 16,
      nir_vec4(b, nir_channel(b, defs->orig, 0),
                  nir_channel(b, defs->orig, 1),
                  nir_channel(b, defs->orig, 2),
                  nir_channel(b, defs->dir, 0)),
      ~0 /* write mask */);

   assert_def_size(defs->t_near, 1, 32);
   assert_def_size(defs->t_far, 1, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 16), 16,
      nir_vec4(b, nir_channel(b, defs->dir, 1),
                  nir_channel(b, defs->dir, 2),
                  defs->t_near,
                  defs->t_far),
      ~0 /* write mask */);

   assert_def_size(defs->root_node_ptr, 1, 64);
   assert_def_size(defs->ray_flags, 1, 16);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 32), 16,
      nir_vec2(b, nir_unpack_64_2x32_split_x(b, defs->root_node_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, defs->root_node_ptr),
                     defs->ray_flags)),
      0x3 /* write mask */);

   /* leaf_ptr is optional */
   nir_def *inst_leaf_ptr;
   if (defs->inst_leaf_ptr) {
      inst_leaf_ptr = defs->inst_leaf_ptr;
   } else {
      inst_leaf_ptr = nir_imm_int64(b, 0);
   }

   assert_def_size(inst_leaf_ptr, 1, 64);
   assert_def_size(defs->ray_mask, 1, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 56), 8,
      nir_vec2(b, nir_unpack_64_2x32_split_x(b, inst_leaf_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, inst_leaf_ptr),
                     nir_unpack_32_2x16_split_x(b, defs->ray_mask))),
      ~0 /* write mask */);
}

static inline void
brw_nir_rt_store_mem_ray(nir_builder *b,
                         const struct brw_nir_rt_mem_ray_defs *defs,
                         enum brw_rt_bvh_level bvh_level)
{
   nir_def *ray_addr =
      brw_nir_rt_mem_ray_addr(b, brw_nir_rt_stack_addr(b), bvh_level);

   assert_def_size(defs->orig, 3, 32);
   assert_def_size(defs->dir, 3, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 0), 16,
      nir_vec4(b, nir_channel(b, defs->orig, 0),
                  nir_channel(b, defs->orig, 1),
                  nir_channel(b, defs->orig, 2),
                  nir_channel(b, defs->dir, 0)),
      ~0 /* write mask */);

   assert_def_size(defs->t_near, 1, 32);
   assert_def_size(defs->t_far, 1, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 16), 16,
      nir_vec4(b, nir_channel(b, defs->dir, 1),
                  nir_channel(b, defs->dir, 2),
                  defs->t_near,
                  defs->t_far),
      ~0 /* write mask */);

   assert_def_size(defs->root_node_ptr, 1, 64);
   assert_def_size(defs->ray_flags, 1, 16);
   assert_def_size(defs->hit_group_sr_base_ptr, 1, 64);
   assert_def_size(defs->hit_group_sr_stride, 1, 16);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 32), 16,
      nir_vec4(b, nir_unpack_64_2x32_split_x(b, defs->root_node_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, defs->root_node_ptr),
                     defs->ray_flags),
                  nir_unpack_64_2x32_split_x(b, defs->hit_group_sr_base_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, defs->hit_group_sr_base_ptr),
                     defs->hit_group_sr_stride)),
      ~0 /* write mask */);

   /* leaf_ptr is optional */
   nir_def *inst_leaf_ptr;
   if (defs->inst_leaf_ptr) {
      inst_leaf_ptr = defs->inst_leaf_ptr;
   } else {
      inst_leaf_ptr = nir_imm_int64(b, 0);
   }

   assert_def_size(defs->miss_sr_ptr, 1, 64);
   assert_def_size(defs->shader_index_multiplier, 1, 32);
   assert_def_size(inst_leaf_ptr, 1, 64);
   assert_def_size(defs->ray_mask, 1, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 48), 16,
      nir_vec4(b, nir_unpack_64_2x32_split_x(b, defs->miss_sr_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, defs->miss_sr_ptr),
                     nir_unpack_32_2x16_split_x(b,
                        nir_ishl(b, defs->shader_index_multiplier,
                                    nir_imm_int(b, 8)))),
                  nir_unpack_64_2x32_split_x(b, inst_leaf_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, inst_leaf_ptr),
                     nir_unpack_32_2x16_split_x(b, defs->ray_mask))),
      ~0 /* write mask */);
}

static inline void
brw_nir_rt_load_mem_ray_from_addr(nir_builder *b,
                                  struct brw_nir_rt_mem_ray_defs *defs,
                                  nir_def *ray_base_addr,
                                  enum brw_rt_bvh_level bvh_level)
{
   nir_def *ray_addr = brw_nir_rt_mem_ray_addr(b,
                                               ray_base_addr,
                                               bvh_level);

   nir_def *data[4] = {
      brw_nir_rt_load(b, nir_iadd_imm(b, ray_addr,  0), 16, 4, 32),
      brw_nir_rt_load(b, nir_iadd_imm(b, ray_addr, 16), 16, 4, 32),
      brw_nir_rt_load(b, nir_iadd_imm(b, ray_addr, 32), 16, 4, 32),
      brw_nir_rt_load(b, nir_iadd_imm(b, ray_addr, 48), 16, 4, 32),
   };

   defs->orig = nir_trim_vector(b, data[0], 3);
   defs->dir = nir_vec3(b, nir_channel(b, data[0], 3),
                           nir_channel(b, data[1], 0),
                           nir_channel(b, data[1], 1));
   defs->t_near = nir_channel(b, data[1], 2);
   defs->t_far = nir_channel(b, data[1], 3);
   defs->root_node_ptr =
      nir_pack_64_2x32_split(b, nir_channel(b, data[2], 0),
                                nir_extract_i16(b, nir_channel(b, data[2], 1),
                                                   nir_imm_int(b, 0)));
   defs->ray_flags =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data[2], 1));
   defs->hit_group_sr_base_ptr =
      nir_pack_64_2x32_split(b, nir_channel(b, data[2], 2),
                                nir_extract_i16(b, nir_channel(b, data[2], 3),
                                                   nir_imm_int(b, 0)));
   defs->hit_group_sr_stride =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data[2], 3));
   defs->miss_sr_ptr =
      nir_pack_64_2x32_split(b, nir_channel(b, data[3], 0),
                                nir_extract_i16(b, nir_channel(b, data[3], 1),
                                                   nir_imm_int(b, 0)));
   defs->shader_index_multiplier =
      nir_ushr(b, nir_unpack_32_2x16_split_y(b, nir_channel(b, data[3], 1)),
                  nir_imm_int(b, 8));
   defs->inst_leaf_ptr =
      nir_pack_64_2x32_split(b, nir_channel(b, data[3], 2),
                                nir_extract_i16(b, nir_channel(b, data[3], 3),
                                                   nir_imm_int(b, 0)));
   defs->ray_mask =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data[3], 3));
}

static inline void
brw_nir_rt_load_mem_ray(nir_builder *b,
                        struct brw_nir_rt_mem_ray_defs *defs,
                        enum brw_rt_bvh_level bvh_level)
{
   brw_nir_rt_load_mem_ray_from_addr(b, defs, brw_nir_rt_stack_addr(b),
                                     bvh_level);
}

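/* The HW instance leaf stores the world-to-object and object-to-world
 * transforms as three 3-component column vectors plus a translation; note
 * that the translation columns of the two matrices are swapped in memory
 * (see the loads below).
 */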
struct brw_nir_rt_bvh_instance_leaf_defs {
   nir_def *shader_index;
   nir_def *contribution_to_hit_group_index;
   nir_def *world_to_object[4];
   nir_def *instance_id;
   nir_def *instance_index;
   nir_def *object_to_world[4];
};

static inline void
brw_nir_rt_load_bvh_instance_leaf(nir_builder *b,
                                  struct brw_nir_rt_bvh_instance_leaf_defs *defs,
                                  nir_def *leaf_addr)
{
   nir_def *leaf_desc = brw_nir_rt_load(b, leaf_addr, 4, 2, 32);

   defs->shader_index =
      nir_iand_imm(b, nir_channel(b, leaf_desc, 0), (1 << 24) - 1);
   defs->contribution_to_hit_group_index =
      nir_iand_imm(b, nir_channel(b, leaf_desc, 1), (1 << 24) - 1);

   defs->world_to_object[0] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 16), 4, 3, 32);
   defs->world_to_object[1] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 28), 4, 3, 32);
   defs->world_to_object[2] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 40), 4, 3, 32);
   /* The last column of the matrices is swapped between the two probably
    * because it makes it easier/faster for hardware somehow.
    */
   defs->object_to_world[3] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 52), 4, 3, 32);

   nir_def *data =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 64), 4, 4, 32);
   defs->instance_id = nir_channel(b, data, 2);
   defs->instance_index = nir_channel(b, data, 3);

   defs->object_to_world[0] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 80), 4, 3, 32);
   defs->object_to_world[1] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 92), 4, 3, 32);
   defs->object_to_world[2] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 104), 4, 3, 32);
   defs->world_to_object[3] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 116), 4, 3, 32);
}

struct brw_nir_rt_bvh_primitive_leaf_defs {
   nir_def *shader_index;
   nir_def *geom_mask;
   nir_def *geom_index;
   nir_def *type;
   nir_def *geom_flags;
};

static inline void
brw_nir_rt_load_bvh_primitive_leaf(nir_builder *b,
                                   struct brw_nir_rt_bvh_primitive_leaf_defs *defs,
                                   nir_def *leaf_addr)
{
   nir_def *desc = brw_nir_rt_load(b, leaf_addr, 4, 2, 32);

   defs->shader_index =
      nir_ubitfield_extract(b, nir_channel(b, desc, 0),
                            nir_imm_int(b, 23), nir_imm_int(b, 0));
   defs->geom_mask =
      nir_ubitfield_extract(b, nir_channel(b, desc, 0),
                            nir_imm_int(b, 31), nir_imm_int(b, 24));

   defs->geom_index =
      nir_ubitfield_extract(b, nir_channel(b, desc, 1),
                            nir_imm_int(b, 28), nir_imm_int(b, 0));
   defs->type =
      nir_ubitfield_extract(b, nir_channel(b, desc, 1),
                            nir_imm_int(b, 29), nir_imm_int(b, 29));
   defs->geom_flags =
      nir_ubitfield_extract(b, nir_channel(b, desc, 1),
                            nir_imm_int(b, 31), nir_imm_int(b, 30));
}

struct brw_nir_rt_bvh_primitive_leaf_positions_defs {
   nir_def *positions[3];
};

static inline void
brw_nir_rt_load_bvh_primitive_leaf_positions(nir_builder *b,
                                             struct brw_nir_rt_bvh_primitive_leaf_positions_defs *defs,
                                             nir_def *leaf_addr)
{
   for (unsigned i = 0; i < ARRAY_SIZE(defs->positions); i++) {
      defs->positions[i] =
         brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 16 + i * 4 * 3), 4, 3, 32);
   }
}

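/* is_procedural may be NULL, in which case it is derived from the hit's
 * leaf_type.  Procedural and quad leaves store the primitive index at
 * different offsets, hence the branch below.
 */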
static inline nir_def *
brw_nir_rt_load_primitive_id_from_hit(nir_builder *b,
                                      nir_def *is_procedural,
                                      const struct brw_nir_rt_mem_hit_defs *defs)
{
   if (!is_procedural) {
      is_procedural =
         nir_ieq_imm(b, defs->leaf_type,
                        BRW_RT_BVH_NODE_TYPE_PROCEDURAL);
   }

   nir_def *prim_id_proc, *prim_id_quad;
   nir_push_if(b, is_procedural);
   {
      /* For procedural leaves, the index is in dw[3]. */
      nir_def *offset =
         nir_iadd_imm(b, nir_ishl_imm(b, defs->prim_leaf_index, 2), 12);
      prim_id_proc = nir_load_global(b, nir_iadd(b, defs->prim_leaf_ptr,
                                                 nir_u2u64(b, offset)),
                                     4, /* align */ 1, 32);
   }
   nir_push_else(b, NULL);
   {
      /* For quad leaves, the index is in dw[2] and there is an additional
       * 16-bit offset in dw[3].
       */
      prim_id_quad = nir_load_global(b, nir_iadd_imm(b, defs->prim_leaf_ptr, 8),
                                     4, /* align */ 1, 32);
      prim_id_quad = nir_iadd(b,
                              prim_id_quad,
                              defs->prim_index_delta);
   }
   nir_pop_if(b, NULL);

   return nir_if_phi(b, prim_id_proc, prim_id_quad);
}

static inline nir_def *
brw_nir_rt_acceleration_structure_to_root_node(nir_builder *b,
                                               nir_def *as_addr)
{
   /* The HW memory structure in which we specify what acceleration structure
    * to traverse takes the address of the root node in the acceleration
    * structure, not the acceleration structure itself. To find that, we have
    * to read the root node offset from the acceleration structure, which is
    * the first QWord.
    *
    * But if the acceleration structure pointer is NULL, then we should return
    * NULL as the root node pointer.
    *
    * TODO: we could optimize this by assuming that for a given version of the
    * BVH, we can find the root node at a given offset.
    */
   nir_def *root_node_ptr, *null_node_ptr;
   nir_push_if(b, nir_ieq_imm(b, as_addr, 0));
   {
      null_node_ptr = nir_imm_int64(b, 0);
   }
   nir_push_else(b, NULL);
   {
      root_node_ptr =
         nir_iadd(b, as_addr, brw_nir_rt_load(b, as_addr, 256, 1, 64));
   }
   nir_pop_if(b, NULL);

   return nir_if_phi(b, null_node_ptr, root_node_ptr);
}