/*
 * Copyright © 2020 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#ifndef BRW_NIR_RT_BUILDER_H
#define BRW_NIR_RT_BUILDER_H

/* This file provides helpers to access the memory-based data structures that
 * the RT hardware reads/writes, and to compute their locations in memory.
 *
 * See also "Memory Based Data Structures for Ray Tracing" (BSpec 47547) and
 * "Ray Tracing Address Computation for Memory Resident Structures" (BSpec
 * 47550).
 */

#include "brw_rt.h"
#include "nir_builder.h"

#define is_access_for_builder(b) \
   ((b)->shader->info.stage == MESA_SHADER_FRAGMENT ? \
    ACCESS_INCLUDE_HELPERS : 0)

static inline nir_def *
brw_nir_rt_load(nir_builder *b, nir_def *addr, unsigned align,
                unsigned components, unsigned bit_size)
{
   return nir_build_load_global(b, components, bit_size, addr,
                                .align_mul = align,
                                .access = is_access_for_builder(b));
}

static inline void
brw_nir_rt_store(nir_builder *b, nir_def *addr, unsigned align,
                 nir_def *value, unsigned write_mask)
{
   nir_build_store_global(b, value, addr,
                          .align_mul = align,
                          .write_mask = (write_mask) &
                                        BITFIELD_MASK(value->num_components),
                          .access = is_access_for_builder(b));
}
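
/* Illustrative usage of the two helpers above (an example sketch, not part
 * of the driver's lowering code): load a 16B-aligned vec4 of 32-bit values
 * from a 64-bit global address and write it back with all four components
 * enabled.  The nir_builder is assumed to already be positioned inside a
 * shader implementation.
 *
 *    nir_def *addr = nir_load_scratch_base_ptr(b, 1, 64, 1);
 *    nir_def *data = brw_nir_rt_load(b, addr, 16, 4, 32);
 *    brw_nir_rt_store(b, addr, 16, data, 0xf);
 */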

static inline nir_def *
brw_nir_rt_load_const(nir_builder *b, unsigned components,
                      nir_def *addr, nir_def *pred)
{
   return nir_load_global_const_block_intel(b, components, addr, pred);
}

static inline nir_def *
brw_load_btd_dss_id(nir_builder *b)
{
   return nir_load_topology_id_intel(b, .base = BRW_TOPOLOGY_ID_DSS);
}

static inline nir_def *
brw_nir_rt_load_num_simd_lanes_per_dss(nir_builder *b,
                                       const struct intel_device_info *devinfo)
{
   return nir_imm_int(b, devinfo->num_thread_per_eu *
                         devinfo->max_eus_per_subslice *
                         16 /* The RT computation is based off SIMD16 */);
}

static inline nir_def *
brw_load_eu_thread_simd(nir_builder *b)
{
   return nir_load_topology_id_intel(b, .base = BRW_TOPOLOGY_ID_EU_THREAD_SIMD);
}

static inline nir_def *
brw_nir_rt_async_stack_id(nir_builder *b)
{
   return nir_iadd(b, nir_umul_32x16(b, nir_load_ray_num_dss_rt_stacks_intel(b),
                                        brw_load_btd_dss_id(b)),
                      nir_load_btd_stack_id_intel(b));
}

static inline nir_def *
brw_nir_rt_sync_stack_id(nir_builder *b)
{
   return brw_load_eu_thread_simd(b);
}

/* We have our own load/store scratch helpers because they emit a global
 * memory read or write based on the scratch_base_ptr system value rather
 * than a load/store_scratch intrinsic.
 */
static inline nir_def *
brw_nir_rt_load_scratch(nir_builder *b, uint32_t offset, unsigned align,
                        unsigned num_components, unsigned bit_size)
{
   nir_def *addr =
      nir_iadd_imm(b, nir_load_scratch_base_ptr(b, 1, 64, 1), offset);
   return brw_nir_rt_load(b, addr, MIN2(align, BRW_BTD_STACK_ALIGN),
                             num_components, bit_size);
}

static inline void
brw_nir_rt_store_scratch(nir_builder *b, uint32_t offset, unsigned align,
                         nir_def *value, nir_component_mask_t write_mask)
{
   nir_def *addr =
      nir_iadd_imm(b, nir_load_scratch_base_ptr(b, 1, 64, 1), offset);
   brw_nir_rt_store(b, addr, MIN2(align, BRW_BTD_STACK_ALIGN),
                    value, write_mask);
}
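
/* Example sketch (hypothetical offset and value): spill a 1-component 64-bit
 * value to the BTD stack and reload it later.  Real offsets come from the
 * BRW_BTD_STACK_* layout in brw_rt.h, as brw_nir_btd_return() below shows.
 *
 *    const uint32_t example_offset = 16;            // hypothetical
 *    brw_nir_rt_store_scratch(b, example_offset, 8, some_64bit_def, 0x1);
 *    ...
 *    nir_def *reloaded =
 *       brw_nir_rt_load_scratch(b, example_offset, 8, 1, 64);
 */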

static inline void
brw_nir_btd_spawn(nir_builder *b, nir_def *record_addr)
{
   nir_btd_spawn_intel(b, nir_load_btd_global_arg_addr_intel(b), record_addr);
}

static inline void
brw_nir_btd_retire(nir_builder *b)
{
   nir_btd_retire_intel(b);
}

/** This is a pseudo-op which does a bindless return
 *
 * It loads the return address from the stack and calls btd_spawn to spawn the
 * resume shader.
 */
static inline void
brw_nir_btd_return(struct nir_builder *b)
{
   nir_def *resume_addr =
      brw_nir_rt_load_scratch(b, BRW_BTD_STACK_RESUME_BSR_ADDR_OFFSET,
                              8 /* align */, 1, 64);
   brw_nir_btd_spawn(b, resume_addr);
}

static inline void
assert_def_size(nir_def *def, unsigned num_components, unsigned bit_size)
{
   assert(def->num_components == num_components);
   assert(def->bit_size == bit_size);
}

static inline nir_def *
brw_nir_num_rt_stacks(nir_builder *b,
                      const struct intel_device_info *devinfo)
{
   return nir_imul_imm(b, nir_load_ray_num_dss_rt_stacks_intel(b),
                          intel_device_info_dual_subslice_id_bound(devinfo));
}

static inline nir_def *
brw_nir_rt_sw_hotzone_addr(nir_builder *b,
                           const struct intel_device_info *devinfo)
{
   nir_def *offset32 =
      nir_imul_imm(b, brw_nir_rt_async_stack_id(b),
                      BRW_RT_SIZEOF_HOTZONE);

   offset32 = nir_iadd(b, offset32, nir_ineg(b,
      nir_imul_imm(b, brw_nir_num_rt_stacks(b, devinfo),
                      BRW_RT_SIZEOF_HOTZONE)));

   return nir_iadd(b, nir_load_ray_base_mem_addr_intel(b),
                      nir_i2i64(b, offset32));
}

static inline nir_def *
brw_nir_rt_sync_stack_addr(nir_builder *b,
                           nir_def *base_mem_addr,
                           const struct intel_device_info *devinfo)
{
   /* For ray queries (synchronous ray tracing), the formula is similar but
    * goes down from rtMemBasePtr:
    *
    *    syncBase  = RTDispatchGlobals.rtMemBasePtr
    *              - (DSSID * NUM_SIMD_LANES_PER_DSS + SyncStackID + 1)
    *              * syncStackSize
    *
    * We assume that we can calculate a 32-bit offset first and then add it
    * to the 64-bit base address at the end.
    */
   nir_def *offset32 =
      nir_imul(b,
               nir_iadd(b,
                        nir_imul(b, brw_load_btd_dss_id(b),
                                    brw_nir_rt_load_num_simd_lanes_per_dss(b, devinfo)),
                        nir_iadd_imm(b, brw_nir_rt_sync_stack_id(b), 1)),
               nir_imm_int(b, BRW_RT_SIZEOF_RAY_QUERY));
   return nir_isub(b, base_mem_addr, nir_u2u64(b, offset32));
}
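
/* Worked example for the formula above (illustrative numbers, not from the
 * BSpec): with DSSID = 2, NUM_SIMD_LANES_PER_DSS = 1024, SyncStackID = 5 and
 * syncStackSize = BRW_RT_SIZEOF_RAY_QUERY, the 32-bit offset is
 * (2 * 1024 + 5 + 1) * sizeof(RayQuery) = 2054 * sizeof(RayQuery), and the
 * stack lives that many bytes *below* rtMemBasePtr.  The "+ 1" makes the
 * subtraction land at the start of this stack's slot rather than at its end.
 */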

static inline nir_def *
brw_nir_rt_stack_addr(nir_builder *b)
{
   /* From the BSpec "Address Computation for Memory Based Data Structures:
    * Ray and TraversalStack (Async Ray Tracing)":
    *
    *    stackBase = RTDispatchGlobals.rtMemBasePtr
    *              + (DSSID * RTDispatchGlobals.numDSSRTStacks + stackID)
    *              * RTDispatchGlobals.stackSizePerRay // 64B aligned
    *
    * We assume that we can calculate a 32-bit offset first and then add it
    * to the 64-bit base address at the end.
    */
   nir_def *offset32 =
      nir_imul(b, brw_nir_rt_async_stack_id(b),
                  nir_load_ray_hw_stack_size_intel(b));
   return nir_iadd(b, nir_load_ray_base_mem_addr_intel(b),
                      nir_u2u64(b, offset32));
}
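
/* Worked example (illustrative numbers only): with DSSID = 3,
 * numDSSRTStacks = 2048, stackID = 7 and stackSizePerRay = 256B, the async
 * stack for this ray starts at rtMemBasePtr + (3 * 2048 + 7) * 256 bytes.
 * brw_nir_rt_async_stack_id() computes the (DSSID * numDSSRTStacks + stackID)
 * part, so the multiply here only has to fold in the per-ray stack size.
 */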

static inline nir_def *
brw_nir_rt_mem_hit_addr_from_addr(nir_builder *b,
                        nir_def *stack_addr,
                        bool committed)
{
   return nir_iadd_imm(b, stack_addr, committed ? 0 : BRW_RT_SIZEOF_HIT_INFO);
}

static inline nir_def *
brw_nir_rt_mem_hit_addr(nir_builder *b, bool committed)
{
   return nir_iadd_imm(b, brw_nir_rt_stack_addr(b),
                          committed ? 0 : BRW_RT_SIZEOF_HIT_INFO);
}

static inline nir_def *
brw_nir_rt_hit_attrib_data_addr(nir_builder *b)
{
   return nir_iadd_imm(b, brw_nir_rt_stack_addr(b),
                          BRW_RT_OFFSETOF_HIT_ATTRIB_DATA);
}

static inline nir_def *
brw_nir_rt_mem_ray_addr(nir_builder *b,
                        nir_def *stack_addr,
                        enum brw_rt_bvh_level bvh_level)
{
   /* From the BSpec "Address Computation for Memory Based Data Structures:
    * Ray and TraversalStack (Async Ray Tracing)":
    *
    *    rayBase = stackBase + sizeof(HitInfo) * 2 // 64B aligned
    *    rayPtr  = rayBase + bvhLevel * sizeof(Ray); // 64B aligned
    *
    * In Vulkan, we always have exactly two levels of BVH: World and Object.
    */
   uint32_t offset = BRW_RT_SIZEOF_HIT_INFO * 2 +
                     bvh_level * BRW_RT_SIZEOF_RAY;
   return nir_iadd_imm(b, stack_addr, offset);
}
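
/* To illustrate the per-stack layout implied by the helpers above (the real
 * sizes are whatever brw_rt.h defines; the byte values below just assume
 * sizeof(HitInfo) == 32 and sizeof(Ray) == 64 for the sake of the example):
 *
 *    stackBase + 0    committed MemHit
 *    stackBase + 32   potential (uncommitted) MemHit
 *    stackBase + 64   MemRay for bvhLevel 0 (world)
 *    stackBase + 128  MemRay for bvhLevel 1 (object)
 */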

static inline nir_def *
brw_nir_rt_sw_stack_addr(nir_builder *b,
                         const struct intel_device_info *devinfo)
{
   nir_def *addr = nir_load_ray_base_mem_addr_intel(b);

   nir_def *offset32 = nir_imul(b, brw_nir_num_rt_stacks(b, devinfo),
                                       nir_load_ray_hw_stack_size_intel(b));
   addr = nir_iadd(b, addr, nir_u2u64(b, offset32));

   nir_def *offset_in_stack =
      nir_imul(b, nir_u2u64(b, brw_nir_rt_async_stack_id(b)),
                  nir_u2u64(b, nir_load_ray_sw_stack_size_intel(b)));

   return nir_iadd(b, addr, offset_in_stack);
}

static inline nir_def *
nir_unpack_64_4x16_split_z(nir_builder *b, nir_def *val)
{
   return nir_unpack_32_2x16_split_x(b, nir_unpack_64_2x32_split_y(b, val));
}

struct brw_nir_rt_globals_defs {
   nir_def *base_mem_addr;
   nir_def *call_stack_handler_addr;
   nir_def *hw_stack_size;
   nir_def *num_dss_rt_stacks;
   nir_def *hit_sbt_addr;
   nir_def *hit_sbt_stride;
   nir_def *miss_sbt_addr;
   nir_def *miss_sbt_stride;
   nir_def *sw_stack_size;
   nir_def *launch_size;
   nir_def *call_sbt_addr;
   nir_def *call_sbt_stride;
   nir_def *resume_sbt_addr;
};

static inline void
brw_nir_rt_load_globals_addr(nir_builder *b,
                             struct brw_nir_rt_globals_defs *defs,
                             nir_def *addr)
{
   nir_def *data;
   data = brw_nir_rt_load_const(b, 16, addr, nir_imm_true(b));
   defs->base_mem_addr = nir_pack_64_2x32(b, nir_trim_vector(b, data, 2));

   defs->call_stack_handler_addr =
      nir_pack_64_2x32(b, nir_channels(b, data, 0x3 << 2));

   defs->hw_stack_size = nir_channel(b, data, 4);
   defs->num_dss_rt_stacks = nir_iand_imm(b, nir_channel(b, data, 5), 0xffff);
   defs->hit_sbt_addr =
      nir_pack_64_2x32_split(b, nir_channel(b, data, 8),
                                nir_extract_i16(b, nir_channel(b, data, 9),
                                                   nir_imm_int(b, 0)));
   defs->hit_sbt_stride =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 9));
   defs->miss_sbt_addr =
      nir_pack_64_2x32_split(b, nir_channel(b, data, 10),
                                nir_extract_i16(b, nir_channel(b, data, 11),
                                                   nir_imm_int(b, 0)));
   defs->miss_sbt_stride =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 11));
   defs->sw_stack_size = nir_channel(b, data, 12);
   defs->launch_size = nir_channels(b, data, 0x7u << 13);

   data = brw_nir_rt_load_const(b, 8, nir_iadd_imm(b, addr, 64), nir_imm_true(b));
   defs->call_sbt_addr =
      nir_pack_64_2x32_split(b, nir_channel(b, data, 0),
                                nir_extract_i16(b, nir_channel(b, data, 1),
                                                   nir_imm_int(b, 0)));
   defs->call_sbt_stride =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 1));

   defs->resume_sbt_addr =
      nir_pack_64_2x32(b, nir_channels(b, data, 0x3 << 2));
}

static inline void
brw_nir_rt_load_globals(nir_builder *b,
                        struct brw_nir_rt_globals_defs *defs)
{
   brw_nir_rt_load_globals_addr(b, defs, nir_load_btd_global_arg_addr_intel(b));
}
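
/* Illustrative usage (a sketch; miss_index is a hypothetical 32-bit nir_def
 * provided by the caller): pull the RTDispatchGlobals fields into SSA defs
 * and compute a miss shader record address from the miss SBT base and
 * stride.
 *
 *    struct brw_nir_rt_globals_defs globals;
 *    brw_nir_rt_load_globals(b, &globals);
 *    nir_def *miss_rec_addr =
 *       nir_iadd(b, globals.miss_sbt_addr,
 *                   nir_u2u64(b, nir_imul(b, miss_index,
 *                                nir_u2u32(b, globals.miss_sbt_stride))));
 */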

static inline nir_def *
brw_nir_rt_unpack_leaf_ptr(nir_builder *b, nir_def *vec2)
{
   /* Hit record leaf pointers are 42-bit and assumed to be in 64B chunks.
    * This leaves 22 bits at the top for other stuff.
    */
   nir_def *ptr64 = nir_imul_imm(b, nir_pack_64_2x32(b, vec2), 64);

   /* The top 16 bits (remember, we shifted by 6 already) contain garbage
    * that we need to get rid of.
    */
   nir_def *ptr_lo = nir_unpack_64_2x32_split_x(b, ptr64);
   nir_def *ptr_hi = nir_unpack_64_2x32_split_y(b, ptr64);
   ptr_hi = nir_extract_i16(b, ptr_hi, nir_imm_int(b, 0));
   return nir_pack_64_2x32_split(b, ptr_lo, ptr_hi);
}
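
/* To spell out the arithmetic above: the 42-bit field is a 64B-chunk index,
 * so multiplying by 64 (a shift left by 6) turns it into a byte address in
 * bits [0:47] of ptr64, while 16 leftover bits of the "other stuff" end up
 * in bits [48:63].  Re-packing the high dword through nir_extract_i16 keeps
 * only its low 16 bits (bits [32:47] of the pointer, sign-extended), which
 * drops that garbage and yields a canonical 48-bit address.
 */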

/**
 * MemHit memory layout (BSpec 47547):
 *
 *      name            bits    description
 *    - t               32      hit distance of current hit (or initial traversal distance)
 *    - u               32      barycentric hit coordinates
 *    - v               32      barycentric hit coordinates
 *    - primIndexDelta  16      prim index delta for compressed meshlets and quads
 *    - valid            1      set if there is a hit
 *    - leafType         3      type of node primLeafPtr is pointing to
 *    - primLeafIndex    4      index of the hit primitive inside the leaf
 *    - bvhLevel         3      the instancing level at which the hit occurred
 *    - frontFace        1      whether we hit the front-facing side of a triangle (also used to pass opaque flag when calling intersection shaders)
 *    - pad0             4      unused bits
 *    - primLeafPtr     42      pointer to BVH leaf node (multiple of 64 bytes)
 *    - hitGroupRecPtr0 22      LSB of hit group record of the hit triangle (multiple of 16 bytes)
 *    - instLeafPtr     42      pointer to BVH instance leaf node (in multiple of 64 bytes)
 *    - hitGroupRecPtr1 22      MSB of hit group record of the hit triangle (multiple of 32 bytes)
 */
struct brw_nir_rt_mem_hit_defs {
   nir_def *t;
   nir_def *tri_bary; /**< Only valid for triangle geometry */
   nir_def *aabb_hit_kind; /**< Only valid for AABB geometry */
   nir_def *valid;
   nir_def *leaf_type;
   nir_def *prim_index_delta;
   nir_def *prim_leaf_index;
   nir_def *bvh_level;
   nir_def *front_face;
   nir_def *done; /**< Only for ray queries */
   nir_def *prim_leaf_ptr;
   nir_def *inst_leaf_ptr;
};

static inline void
brw_nir_rt_load_mem_hit_from_addr(nir_builder *b,
                                  struct brw_nir_rt_mem_hit_defs *defs,
                                  nir_def *stack_addr,
                                  bool committed)
{
   nir_def *hit_addr =
      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, committed);

   nir_def *data = brw_nir_rt_load(b, hit_addr, 16, 4, 32);
   defs->t = nir_channel(b, data, 0);
   defs->aabb_hit_kind = nir_channel(b, data, 1);
   defs->tri_bary = nir_channels(b, data, 0x6);
   nir_def *bitfield = nir_channel(b, data, 3);
   defs->prim_index_delta =
      nir_ubitfield_extract(b, bitfield, nir_imm_int(b, 0), nir_imm_int(b, 16));
   defs->valid = nir_i2b(b, nir_iand_imm(b, bitfield, 1u << 16));
   defs->leaf_type =
      nir_ubitfield_extract(b, bitfield, nir_imm_int(b, 17), nir_imm_int(b, 3));
   defs->prim_leaf_index =
      nir_ubitfield_extract(b, bitfield, nir_imm_int(b, 20), nir_imm_int(b, 4));
   defs->bvh_level =
      nir_ubitfield_extract(b, bitfield, nir_imm_int(b, 24), nir_imm_int(b, 3));
   defs->front_face = nir_i2b(b, nir_iand_imm(b, bitfield, 1 << 27));
   defs->done = nir_i2b(b, nir_iand_imm(b, bitfield, 1 << 28));

   data = brw_nir_rt_load(b, nir_iadd_imm(b, hit_addr, 16), 16, 4, 32);
   defs->prim_leaf_ptr =
      brw_nir_rt_unpack_leaf_ptr(b, nir_channels(b, data, 0x3 << 0));
   defs->inst_leaf_ptr =
      brw_nir_rt_unpack_leaf_ptr(b, nir_channels(b, data, 0x3 << 2));
}

static inline void
brw_nir_rt_load_mem_hit(nir_builder *b,
                        struct brw_nir_rt_mem_hit_defs *defs,
                        bool committed)
{
   brw_nir_rt_load_mem_hit_from_addr(b, defs, brw_nir_rt_stack_addr(b),
                                     committed);
}
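
/* Illustrative usage (a sketch, not taken from the lowering passes): decode
 * the committed hit for the current stack and derive a "procedural leaf"
 * predicate from its leaf type, the same check used further down in
 * brw_nir_rt_load_primitive_id_from_hit().
 *
 *    struct brw_nir_rt_mem_hit_defs hit = {};
 *    brw_nir_rt_load_mem_hit(b, &hit, true /* committed */);
 *    nir_def *is_procedural =
 *       nir_ieq_imm(b, hit.leaf_type, BRW_RT_BVH_NODE_TYPE_PROCEDURAL);
 */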

static inline void
brw_nir_memcpy_global(nir_builder *b,
                      nir_def *dst_addr, uint32_t dst_align,
                      nir_def *src_addr, uint32_t src_align,
                      uint32_t size)
{
   /* We're going to copy in 16B chunks */
   assert(size % 16 == 0);
   dst_align = MIN2(dst_align, 16);
   src_align = MIN2(src_align, 16);

   for (unsigned offset = 0; offset < size; offset += 16) {
      nir_def *data =
         brw_nir_rt_load(b, nir_iadd_imm(b, src_addr, offset), 16,
                         4, 32);
      brw_nir_rt_store(b, nir_iadd_imm(b, dst_addr, offset), 16,
                       data, 0xf /* write_mask */);
   }
}

static inline void
brw_nir_memclear_global(nir_builder *b,
                        nir_def *dst_addr, uint32_t dst_align,
                        uint32_t size)
{
   /* We're going to clear in 16B chunks */
   assert(size % 16 == 0);
   dst_align = MIN2(dst_align, 16);

   nir_def *zero = nir_imm_ivec4(b, 0, 0, 0, 0);
   for (unsigned offset = 0; offset < size; offset += 16) {
      brw_nir_rt_store(b, nir_iadd_imm(b, dst_addr, offset), dst_align,
                       zero, 0xf /* write_mask */);
   }
}

static inline nir_def *
brw_nir_rt_query_done(nir_builder *b, nir_def *stack_addr)
{
   struct brw_nir_rt_mem_hit_defs hit_in = {};
   brw_nir_rt_load_mem_hit_from_addr(b, &hit_in, stack_addr,
                                     false /* committed */);

   return hit_in.done;
}

static inline void
brw_nir_rt_set_dword_bit_at(nir_builder *b,
                            nir_def *addr,
                            uint32_t addr_offset,
                            uint32_t bit)
{
   nir_def *dword_addr = nir_iadd_imm(b, addr, addr_offset);
   nir_def *dword = brw_nir_rt_load(b, dword_addr, 4, 1, 32);
   brw_nir_rt_store(b, dword_addr, 4, nir_ior_imm(b, dword, 1u << bit), 0x1);
}

static inline void
brw_nir_rt_query_mark_done(nir_builder *b, nir_def *stack_addr)
{
   brw_nir_rt_set_dword_bit_at(b,
                               brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr,
                                                                 false /* committed */),
                               4 * 3 /* dword offset */, 28 /* bit */);
}
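
/* Dword 3 of MemHit is the packed bitfield decoded by
 * brw_nir_rt_load_mem_hit_from_addr() above; bit 28 falls in the pad0 area
 * of the hardware layout and is what this code uses as the ray query "done"
 * flag.
 */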

/* This helper clears the 3rd dword (where the valid bit is located) of both
 * the committed and potential MemHit structures.
 */
static inline void
brw_nir_rt_query_mark_init(nir_builder *b, nir_def *stack_addr)
{
   nir_def *dword_addr;

   for (uint32_t i = 0; i < 2; i++) {
      dword_addr =
         nir_iadd_imm(b,
                      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr,
                                                        i == 0 /* committed */),
                      4 * 3 /* dword offset */);
      brw_nir_rt_store(b, dword_addr, 4, nir_imm_int(b, 0), 0x1);
   }
}

/* This helper is pretty much a memcpy of the uncommitted hit structure into
 * the committed one, just adding the valid bit.
 */
static inline void
brw_nir_rt_commit_hit_addr(nir_builder *b, nir_def *stack_addr)
{
   nir_def *dst_addr =
      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, true /* committed */);
   nir_def *src_addr =
      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, false /* committed */);

   for (unsigned offset = 0; offset < BRW_RT_SIZEOF_HIT_INFO; offset += 16) {
      nir_def *data =
         brw_nir_rt_load(b, nir_iadd_imm(b, src_addr, offset), 16, 4, 32);

      if (offset == 0) {
         data = nir_vec4(b,
                         nir_channel(b, data, 0),
                         nir_channel(b, data, 1),
                         nir_channel(b, data, 2),
                         nir_ior_imm(b,
                                     nir_channel(b, data, 3),
                                     0x1 << 16 /* valid */));

         /* Also write the potential hit as we change it. */
         brw_nir_rt_store(b, nir_iadd_imm(b, src_addr, offset), 16,
                          data, 0xf /* write_mask */);
      }

      brw_nir_rt_store(b, nir_iadd_imm(b, dst_addr, offset), 16,
                       data, 0xf /* write_mask */);
   }
}

static inline void
brw_nir_rt_commit_hit(nir_builder *b)
{
   nir_def *stack_addr = brw_nir_rt_stack_addr(b);
   brw_nir_rt_commit_hit_addr(b, stack_addr);
}

static inline void
brw_nir_rt_generate_hit_addr(nir_builder *b, nir_def *stack_addr, nir_def *t_val)
{
   nir_def *committed_addr =
      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, true /* committed */);
   nir_def *potential_addr =
      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, false /* committed */);

   /* Set:
    *
    *   potential.t     = t_val;
    *   potential.valid = true;
    */
   nir_def *potential_hit_dwords_0_3 =
      brw_nir_rt_load(b, potential_addr, 16, 4, 32);
   potential_hit_dwords_0_3 =
      nir_vec4(b,
               t_val,
               nir_channel(b, potential_hit_dwords_0_3, 1),
               nir_channel(b, potential_hit_dwords_0_3, 2),
               nir_ior_imm(b, nir_channel(b, potential_hit_dwords_0_3, 3),
                           (0x1 << 16) /* valid */));
   brw_nir_rt_store(b, potential_addr, 16, potential_hit_dwords_0_3, 0xf /* write_mask */);

   /* Set:
    *
    *   committed.t               = t_val;
    *   committed.u               = 0.0f;
    *   committed.v               = 0.0f;
    *   committed.valid           = true;
    *   committed.leaf_type       = potential.leaf_type;
    *   committed.bvh_level       = BRW_RT_BVH_LEVEL_OBJECT;
    *   committed.front_face      = false;
    *   committed.prim_leaf_index = 0;
    *   committed.done            = false;
    */
   nir_def *committed_hit_dwords_0_3 =
      brw_nir_rt_load(b, committed_addr, 16, 4, 32);
   committed_hit_dwords_0_3 =
      nir_vec4(b,
               t_val,
               nir_imm_float(b, 0.0f),
               nir_imm_float(b, 0.0f),
               nir_ior_imm(b,
                           nir_ior_imm(b, nir_channel(b, potential_hit_dwords_0_3, 3), 0x000e0000),
                           (0x1 << 16)                     /* valid */ |
                           (BRW_RT_BVH_LEVEL_OBJECT << 24) /* bvh_level */));
   brw_nir_rt_store(b, committed_addr, 16, committed_hit_dwords_0_3, 0xf /* write_mask */);

   /* Set:
    *
    *   committed.prim_leaf_ptr   = potential.prim_leaf_ptr;
    *   committed.inst_leaf_ptr   = potential.inst_leaf_ptr;
    */
   brw_nir_memcpy_global(b,
                         nir_iadd_imm(b, committed_addr, 16), 16,
                         nir_iadd_imm(b, potential_addr, 16), 16,
                         16);
}

struct brw_nir_rt_mem_ray_defs {
   nir_def *orig;
   nir_def *dir;
   nir_def *t_near;
   nir_def *t_far;
   nir_def *root_node_ptr;
   nir_def *ray_flags;
   nir_def *hit_group_sr_base_ptr;
   nir_def *hit_group_sr_stride;
   nir_def *miss_sr_ptr;
   nir_def *shader_index_multiplier;
   nir_def *inst_leaf_ptr;
   nir_def *ray_mask;
};

static inline void
brw_nir_rt_store_mem_ray_query_at_addr(nir_builder *b,
                                       nir_def *ray_addr,
                                       const struct brw_nir_rt_mem_ray_defs *defs)
{
   assert_def_size(defs->orig, 3, 32);
   assert_def_size(defs->dir, 3, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 0), 16,
      nir_vec4(b, nir_channel(b, defs->orig, 0),
                  nir_channel(b, defs->orig, 1),
                  nir_channel(b, defs->orig, 2),
                  nir_channel(b, defs->dir, 0)),
      ~0 /* write mask */);

   assert_def_size(defs->t_near, 1, 32);
   assert_def_size(defs->t_far, 1, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 16), 16,
      nir_vec4(b, nir_channel(b, defs->dir, 1),
                  nir_channel(b, defs->dir, 2),
                  defs->t_near,
                  defs->t_far),
      ~0 /* write mask */);

   assert_def_size(defs->root_node_ptr, 1, 64);
   assert_def_size(defs->ray_flags, 1, 16);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 32), 16,
      nir_vec2(b, nir_unpack_64_2x32_split_x(b, defs->root_node_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, defs->root_node_ptr),
                     defs->ray_flags)),
      0x3 /* write mask */);

   /* leaf_ptr is optional */
   nir_def *inst_leaf_ptr;
   if (defs->inst_leaf_ptr) {
      inst_leaf_ptr = defs->inst_leaf_ptr;
   } else {
      inst_leaf_ptr = nir_imm_int64(b, 0);
   }

   assert_def_size(inst_leaf_ptr, 1, 64);
   assert_def_size(defs->ray_mask, 1, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 56), 8,
      nir_vec2(b, nir_unpack_64_2x32_split_x(b, inst_leaf_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, inst_leaf_ptr),
                     nir_unpack_32_2x16_split_x(b, defs->ray_mask))),
      ~0 /* write mask */);
}

static inline void
brw_nir_rt_store_mem_ray(nir_builder *b,
                         const struct brw_nir_rt_mem_ray_defs *defs,
                         enum brw_rt_bvh_level bvh_level)
{
   nir_def *ray_addr =
      brw_nir_rt_mem_ray_addr(b, brw_nir_rt_stack_addr(b), bvh_level);

   assert_def_size(defs->orig, 3, 32);
   assert_def_size(defs->dir, 3, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 0), 16,
      nir_vec4(b, nir_channel(b, defs->orig, 0),
                  nir_channel(b, defs->orig, 1),
                  nir_channel(b, defs->orig, 2),
                  nir_channel(b, defs->dir, 0)),
      ~0 /* write mask */);

   assert_def_size(defs->t_near, 1, 32);
   assert_def_size(defs->t_far, 1, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 16), 16,
      nir_vec4(b, nir_channel(b, defs->dir, 1),
                  nir_channel(b, defs->dir, 2),
                  defs->t_near,
                  defs->t_far),
      ~0 /* write mask */);

   assert_def_size(defs->root_node_ptr, 1, 64);
   assert_def_size(defs->ray_flags, 1, 16);
   assert_def_size(defs->hit_group_sr_base_ptr, 1, 64);
   assert_def_size(defs->hit_group_sr_stride, 1, 16);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 32), 16,
      nir_vec4(b, nir_unpack_64_2x32_split_x(b, defs->root_node_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, defs->root_node_ptr),
                     defs->ray_flags),
                  nir_unpack_64_2x32_split_x(b, defs->hit_group_sr_base_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, defs->hit_group_sr_base_ptr),
                     defs->hit_group_sr_stride)),
      ~0 /* write mask */);

   /* leaf_ptr is optional */
   nir_def *inst_leaf_ptr;
   if (defs->inst_leaf_ptr) {
      inst_leaf_ptr = defs->inst_leaf_ptr;
   } else {
      inst_leaf_ptr = nir_imm_int64(b, 0);
   }

   assert_def_size(defs->miss_sr_ptr, 1, 64);
   assert_def_size(defs->shader_index_multiplier, 1, 32);
   assert_def_size(inst_leaf_ptr, 1, 64);
   assert_def_size(defs->ray_mask, 1, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 48), 16,
      nir_vec4(b, nir_unpack_64_2x32_split_x(b, defs->miss_sr_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, defs->miss_sr_ptr),
                     nir_unpack_32_2x16_split_x(b,
                        nir_ishl(b, defs->shader_index_multiplier,
                                    nir_imm_int(b, 8)))),
                  nir_unpack_64_2x32_split_x(b, inst_leaf_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, inst_leaf_ptr),
                     nir_unpack_32_2x16_split_x(b, defs->ray_mask))),
      ~0 /* write mask */);
}
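
/* Example sketch (illustrative values only, not taken from the lowering
 * code) of filling a MemRay and writing it for the world-level BVH.  The
 * sizes satisfy the assert_def_size() checks above; root_node_ptr,
 * hit_sbt_addr and miss_sbt_addr are assumed to be 1-component 64-bit
 * nir_defs produced elsewhere (e.g. via brw_nir_rt_load_globals()).
 *
 *    struct brw_nir_rt_mem_ray_defs ray = {
 *       .orig = nir_imm_vec3(b, 0.0f, 0.0f, 0.0f),
 *       .dir = nir_imm_vec3(b, 0.0f, 0.0f, 1.0f),
 *       .t_near = nir_imm_float(b, 0.0f),
 *       .t_far = nir_imm_float(b, 1000.0f),
 *       .root_node_ptr = root_node_ptr,
 *       .ray_flags = nir_imm_intN_t(b, 0, 16),
 *       .hit_group_sr_base_ptr = hit_sbt_addr,
 *       .hit_group_sr_stride = nir_imm_intN_t(b, 32, 16),
 *       .miss_sr_ptr = miss_sbt_addr,
 *       .shader_index_multiplier = nir_imm_int(b, 1),
 *       .ray_mask = nir_imm_int(b, 0xff),
 *    };
 *    brw_nir_rt_store_mem_ray(b, &ray, BRW_RT_BVH_LEVEL_WORLD);
 */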

static inline void
brw_nir_rt_load_mem_ray_from_addr(nir_builder *b,
                                  struct brw_nir_rt_mem_ray_defs *defs,
                                  nir_def *ray_base_addr,
                                  enum brw_rt_bvh_level bvh_level)
{
   nir_def *ray_addr = brw_nir_rt_mem_ray_addr(b,
                                                   ray_base_addr,
                                                   bvh_level);

   nir_def *data[4] = {
      brw_nir_rt_load(b, nir_iadd_imm(b, ray_addr,  0), 16, 4, 32),
      brw_nir_rt_load(b, nir_iadd_imm(b, ray_addr, 16), 16, 4, 32),
      brw_nir_rt_load(b, nir_iadd_imm(b, ray_addr, 32), 16, 4, 32),
      brw_nir_rt_load(b, nir_iadd_imm(b, ray_addr, 48), 16, 4, 32),
   };

   defs->orig = nir_trim_vector(b, data[0], 3);
   defs->dir = nir_vec3(b, nir_channel(b, data[0], 3),
                           nir_channel(b, data[1], 0),
                           nir_channel(b, data[1], 1));
   defs->t_near = nir_channel(b, data[1], 2);
   defs->t_far = nir_channel(b, data[1], 3);
   defs->root_node_ptr =
      nir_pack_64_2x32_split(b, nir_channel(b, data[2], 0),
                                nir_extract_i16(b, nir_channel(b, data[2], 1),
                                                   nir_imm_int(b, 0)));
   defs->ray_flags =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data[2], 1));
   defs->hit_group_sr_base_ptr =
      nir_pack_64_2x32_split(b, nir_channel(b, data[2], 2),
                                nir_extract_i16(b, nir_channel(b, data[2], 3),
                                                   nir_imm_int(b, 0)));
   defs->hit_group_sr_stride =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data[2], 3));
   defs->miss_sr_ptr =
      nir_pack_64_2x32_split(b, nir_channel(b, data[3], 0),
                                nir_extract_i16(b, nir_channel(b, data[3], 1),
                                                   nir_imm_int(b, 0)));
   defs->shader_index_multiplier =
      nir_ushr(b, nir_unpack_32_2x16_split_y(b, nir_channel(b, data[3], 1)),
                  nir_imm_int(b, 8));
   defs->inst_leaf_ptr =
      nir_pack_64_2x32_split(b, nir_channel(b, data[3], 2),
                                nir_extract_i16(b, nir_channel(b, data[3], 3),
                                                   nir_imm_int(b, 0)));
   defs->ray_mask =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data[3], 3));
}

static inline void
brw_nir_rt_load_mem_ray(nir_builder *b,
                        struct brw_nir_rt_mem_ray_defs *defs,
                        enum brw_rt_bvh_level bvh_level)
{
   brw_nir_rt_load_mem_ray_from_addr(b, defs, brw_nir_rt_stack_addr(b),
                                     bvh_level);
}

struct brw_nir_rt_bvh_instance_leaf_defs {
   nir_def *shader_index;
   nir_def *contribution_to_hit_group_index;
   nir_def *world_to_object[4];
   nir_def *instance_id;
   nir_def *instance_index;
   nir_def *object_to_world[4];
};

static inline void
brw_nir_rt_load_bvh_instance_leaf(nir_builder *b,
                                  struct brw_nir_rt_bvh_instance_leaf_defs *defs,
                                  nir_def *leaf_addr)
{
   nir_def *leaf_desc = brw_nir_rt_load(b, leaf_addr, 4, 2, 32);

   defs->shader_index =
      nir_iand_imm(b, nir_channel(b, leaf_desc, 0), (1 << 24) - 1);
   defs->contribution_to_hit_group_index =
      nir_iand_imm(b, nir_channel(b, leaf_desc, 1), (1 << 24) - 1);

   defs->world_to_object[0] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 16), 4, 3, 32);
   defs->world_to_object[1] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 28), 4, 3, 32);
   defs->world_to_object[2] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 40), 4, 3, 32);
   /* The last column of the two matrices is swapped between them, probably
    * because it makes things easier/faster for the hardware somehow.
    */
   defs->object_to_world[3] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 52), 4, 3, 32);

   nir_def *data =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 64), 4, 4, 32);
   defs->instance_id = nir_channel(b, data, 2);
   defs->instance_index = nir_channel(b, data, 3);

   defs->object_to_world[0] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 80), 4, 3, 32);
   defs->object_to_world[1] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 92), 4, 3, 32);
   defs->object_to_world[2] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 104), 4, 3, 32);
   defs->world_to_object[3] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 116), 4, 3, 32);
}

struct brw_nir_rt_bvh_primitive_leaf_defs {
   nir_def *shader_index;
   nir_def *geom_mask;
   nir_def *geom_index;
   nir_def *type;
   nir_def *geom_flags;
};

static inline void
brw_nir_rt_load_bvh_primitive_leaf(nir_builder *b,
                                   struct brw_nir_rt_bvh_primitive_leaf_defs *defs,
                                   nir_def *leaf_addr)
{
   nir_def *desc = brw_nir_rt_load(b, leaf_addr, 4, 2, 32);

   defs->shader_index =
      nir_ubitfield_extract(b, nir_channel(b, desc, 0),
                            nir_imm_int(b, 23), nir_imm_int(b, 0));
   defs->geom_mask =
      nir_ubitfield_extract(b, nir_channel(b, desc, 0),
                            nir_imm_int(b, 31), nir_imm_int(b, 24));

   defs->geom_index =
      nir_ubitfield_extract(b, nir_channel(b, desc, 1),
                            nir_imm_int(b, 28), nir_imm_int(b, 0));
   defs->type =
      nir_ubitfield_extract(b, nir_channel(b, desc, 1),
                            nir_imm_int(b, 29), nir_imm_int(b, 29));
   defs->geom_flags =
      nir_ubitfield_extract(b, nir_channel(b, desc, 1),
                            nir_imm_int(b, 31), nir_imm_int(b, 30));
}

struct brw_nir_rt_bvh_primitive_leaf_positions_defs {
   nir_def *positions[3];
};

static inline void
brw_nir_rt_load_bvh_primitive_leaf_positions(nir_builder *b,
                                             struct brw_nir_rt_bvh_primitive_leaf_positions_defs *defs,
                                             nir_def *leaf_addr)
{
   for (unsigned i = 0; i < ARRAY_SIZE(defs->positions); i++) {
      defs->positions[i] =
         brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 16 + i * 4 * 3), 4, 3, 32);
   }
}

static inline nir_def *
brw_nir_rt_load_primitive_id_from_hit(nir_builder *b,
                                      nir_def *is_procedural,
                                      const struct brw_nir_rt_mem_hit_defs *defs)
{
   if (!is_procedural) {
      is_procedural =
         nir_ieq_imm(b, defs->leaf_type,
                        BRW_RT_BVH_NODE_TYPE_PROCEDURAL);
   }

   nir_def *prim_id_proc, *prim_id_quad;
   nir_push_if(b, is_procedural);
   {
      /* For procedural leaves, the index is in dw[3]. */
      nir_def *offset =
         nir_iadd_imm(b, nir_ishl_imm(b, defs->prim_leaf_index, 2), 12);
      prim_id_proc = nir_load_global(b, nir_iadd(b, defs->prim_leaf_ptr,
                                                 nir_u2u64(b, offset)),
                                     4, /* align */ 1, 32);
   }
   nir_push_else(b, NULL);
   {
      /* For quad leaves, the index is in dw[2] and there is an additional
       * 16-bit offset in dw[3].
       */
      prim_id_quad = nir_load_global(b, nir_iadd_imm(b, defs->prim_leaf_ptr, 8),
                                     4, /* align */ 1, 32);
      prim_id_quad = nir_iadd(b,
                              prim_id_quad,
                              defs->prim_index_delta);
   }
   nir_pop_if(b, NULL);

   return nir_if_phi(b, prim_id_proc, prim_id_quad);
}

static inline nir_def *
brw_nir_rt_acceleration_structure_to_root_node(nir_builder *b,
                                               nir_def *as_addr)
{
   /* The HW memory structure in which we specify what acceleration structure
    * to traverse takes the address of the root node in the acceleration
    * structure, not the address of the acceleration structure itself. To find
    * the root node, we have to read the root node offset from the
    * acceleration structure, which is stored in its first QWord.
    *
    * But if the acceleration structure pointer is NULL, then we should return
    * NULL as the root node pointer as well.
    *
    * TODO: we could optimize this by assuming that for a given version of the
    * BVH, we can find the root node at a given offset.
    */
   nir_def *root_node_ptr, *null_node_ptr;
   nir_push_if(b, nir_ieq_imm(b, as_addr, 0));
   {
      null_node_ptr = nir_imm_int64(b, 0);
   }
   nir_push_else(b, NULL);
   {
      root_node_ptr =
         nir_iadd(b, as_addr, brw_nir_rt_load(b, as_addr, 256, 1, 64));
   }
   nir_pop_if(b, NULL);

   return nir_if_phi(b, null_node_ptr, root_node_ptr);
}
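
/* Illustrative usage (as_addr is a hypothetical 1-component 64-bit nir_def
 * holding the acceleration structure's device address): convert the
 * application-provided acceleration structure address into the root node
 * pointer the hardware expects, e.g. when filling
 * brw_nir_rt_mem_ray_defs::root_node_ptr.
 *
 *    nir_def *root_node_ptr =
 *       brw_nir_rt_acceleration_structure_to_root_node(b, as_addr);
 */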

#endif /* BRW_NIR_RT_BUILDER_H */