1 /*
2  * Copyright © 2020 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #ifndef BRW_NIR_RT_BUILDER_H
25 #define BRW_NIR_RT_BUILDER_H
26 
27 /* This file provides helpers to access the memory-based data structures that
28  * the RT hardware reads/writes, and to compute their locations.
29  *
30  * See also "Memory Based Data Structures for Ray Tracing" (BSpec 47547) and
31  * "Ray Tracing Address Computation for Memory Resident Structures" (BSpec
32  * 47550).
33  */
34 
35 #include "brw_rt.h"
36 #include "nir_builder.h"
37 
38 #define is_access_for_builder(b) \
39    ((b)->shader->info.stage == MESA_SHADER_FRAGMENT ? \
40     ACCESS_INCLUDE_HELPERS : 0)
41 
42 static inline nir_ssa_def *
43 brw_nir_rt_load(nir_builder *b, nir_ssa_def *addr, unsigned align,
44                 unsigned components, unsigned bit_size)
45 {
46    return nir_build_load_global(b, components, bit_size, addr,
47                                 .align_mul = align,
48                                 .access = is_access_for_builder(b));
49 }
50 
51 static inline void
52 brw_nir_rt_store(nir_builder *b, nir_ssa_def *addr, unsigned align,
53                  nir_ssa_def *value, unsigned write_mask)
54 {
55    nir_build_store_global(b, value, addr,
56                           .align_mul = align,
57                           .write_mask = (write_mask) &
58                                         BITFIELD_MASK(value->num_components),
59                           .access = is_access_for_builder(b));
60 }
61 
62 static inline nir_ssa_def *
63 brw_nir_rt_load_const(nir_builder *b, unsigned components,
64                       nir_ssa_def *addr, nir_ssa_def *pred)
65 {
66    return nir_build_load_global_const_block_intel(b, components, addr, pred);
67 }
68 
69 static inline nir_ssa_def *
70 brw_load_btd_dss_id(nir_builder *b)
71 {
72    return nir_build_load_topology_id_intel(b, .base = BRW_TOPOLOGY_ID_DSS);
73 }
74 
75 static inline nir_ssa_def *
76 brw_nir_rt_load_num_simd_lanes_per_dss(nir_builder *b,
77                                        const struct intel_device_info *devinfo)
78 {
79    return nir_imm_int(b, devinfo->num_thread_per_eu *
80                          devinfo->max_eus_per_subslice *
81                          16 /* The RT computation is based on SIMD16 */);
82 }
83 
84 static inline nir_ssa_def *
85 brw_load_eu_thread_simd(nir_builder *b)
86 {
87    return nir_build_load_topology_id_intel(b, .base = BRW_TOPOLOGY_ID_EU_THREAD_SIMD);
88 }
89 
90 static inline nir_ssa_def *
91 brw_nir_rt_async_stack_id(nir_builder *b)
92 {
93    return nir_iadd(b, nir_umul_32x16(b, nir_load_ray_num_dss_rt_stacks_intel(b),
94                                         brw_load_btd_dss_id(b)),
95                       nir_load_btd_stack_id_intel(b));
96 }
97 
98 static inline nir_ssa_def *
99 brw_nir_rt_sync_stack_id(nir_builder *b)
100 {
101    return brw_load_eu_thread_simd(b);
102 }
103 
104 /* We have our own load/store scratch helpers because they emit a global
105  * memory read or write based on the scratch_base_ptr system value rather
106  * than a load/store_scratch intrinsic.
107  */
108 static inline nir_ssa_def *
109 brw_nir_rt_load_scratch(nir_builder *b, uint32_t offset, unsigned align,
110                         unsigned num_components, unsigned bit_size)
111 {
112    nir_ssa_def *addr =
113       nir_iadd_imm(b, nir_load_scratch_base_ptr(b, 1, 64, 1), offset);
114    return brw_nir_rt_load(b, addr, MIN2(align, BRW_BTD_STACK_ALIGN),
115                              num_components, bit_size);
116 }
117 
118 static inline void
119 brw_nir_rt_store_scratch(nir_builder *b, uint32_t offset, unsigned align,
120                          nir_ssa_def *value, nir_component_mask_t write_mask)
121 {
122    nir_ssa_def *addr =
123       nir_iadd_imm(b, nir_load_scratch_base_ptr(b, 1, 64, 1), offset);
124    brw_nir_rt_store(b, addr, MIN2(align, BRW_BTD_STACK_ALIGN),
125                     value, write_mask);
126 }
127 
128 static inline void
129 brw_nir_btd_spawn(nir_builder *b, nir_ssa_def *record_addr)
130 {
131    nir_btd_spawn_intel(b, nir_load_btd_global_arg_addr_intel(b), record_addr);
132 }
133 
134 static inline void
135 brw_nir_btd_retire(nir_builder *b)
136 {
137    nir_btd_retire_intel(b);
138 }
139 
140 /** This is a pseudo-op which does a bindless return
141  *
142  * It loads the return address from the stack and calls btd_spawn to spawn the
143  * resume shader.
144  */
145 static inline void
146 brw_nir_btd_return(struct nir_builder *b)
147 {
148    nir_ssa_def *resume_addr =
149       brw_nir_rt_load_scratch(b, BRW_BTD_STACK_RESUME_BSR_ADDR_OFFSET,
150                               8 /* align */, 1, 64);
151    brw_nir_btd_spawn(b, resume_addr);
152 }
153 
154 static inline void
155 assert_def_size(nir_ssa_def *def, unsigned num_components, unsigned bit_size)
156 {
157    assert(def->num_components == num_components);
158    assert(def->bit_size == bit_size);
159 }
160 
161 static inline nir_ssa_def *
162 brw_nir_num_rt_stacks(nir_builder *b,
163                       const struct intel_device_info *devinfo)
164 {
165    return nir_imul_imm(b, nir_load_ray_num_dss_rt_stacks_intel(b),
166                           intel_device_info_num_dual_subslices(devinfo));
167 }
168 
169 static inline nir_ssa_def *
170 brw_nir_rt_sw_hotzone_addr(nir_builder *b,
171                            const struct intel_device_info *devinfo)
172 {
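   /* Note on the layout (derived from the arithmetic below): the SW hotzones
    * occupy one BRW_RT_SIZEOF_HOTZONE slot per async stack, placed immediately
    * *below* rayBaseMemAddr.  The 32-bit offset computed here is therefore
    * negative, which is why it is sign-extended with nir_i2i64() rather than
    * zero-extended before being added to the 64-bit base address.
    */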
173    nir_ssa_def *offset32 =
174       nir_imul_imm(b, brw_nir_rt_async_stack_id(b),
175                       BRW_RT_SIZEOF_HOTZONE);
176 
177    offset32 = nir_iadd(b, offset32, nir_ineg(b,
178       nir_imul_imm(b, brw_nir_num_rt_stacks(b, devinfo),
179                       BRW_RT_SIZEOF_HOTZONE)));
180 
181    return nir_iadd(b, nir_load_ray_base_mem_addr_intel(b),
182                       nir_i2i64(b, offset32));
183 }
184 
185 static inline nir_ssa_def *
186 brw_nir_rt_sync_stack_addr(nir_builder *b,
187                            nir_ssa_def *base_mem_addr,
188                            const struct intel_device_info *devinfo)
189 {
190    /* For ray queries (synchronous ray tracing), the formula is similar but
191     * grows downward from rtMemBasePtr:
192     *
193     *    syncBase  = RTDispatchGlobals.rtMemBasePtr
194     *              - (DSSID * NUM_SIMD_LANES_PER_DSS + SyncStackID + 1)
195     *              * syncStackSize
196     *
197     * We assume that we can calculate a 32-bit offset first and then add it
198     * to the 64-bit base address at the end.
199     */
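   /* Illustrative numbers (not from the BSpec): with DSSID = 3, a hypothetical
    * 1024 SIMD lanes per DSS (e.g. 8 threads/EU * 8 EUs * SIMD16) and
    * SyncStackID = 5, this stack starts at
    *    rtMemBasePtr - (3 * 1024 + 5 + 1) * BRW_RT_SIZEOF_RAY_QUERY.
    */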
200    nir_ssa_def *offset32 =
201       nir_imul(b,
202                nir_iadd(b,
203                         nir_imul(b, brw_load_btd_dss_id(b),
204                                     brw_nir_rt_load_num_simd_lanes_per_dss(b, devinfo)),
205                         nir_iadd_imm(b, brw_nir_rt_sync_stack_id(b), 1)),
206                nir_imm_int(b, BRW_RT_SIZEOF_RAY_QUERY));
207    return nir_isub(b, base_mem_addr, nir_u2u64(b, offset32));
208 }
209 
210 static inline nir_ssa_def *
211 brw_nir_rt_stack_addr(nir_builder *b)
212 {
213    /* From the BSpec "Address Computation for Memory Based Data Structures:
214     * Ray and TraversalStack (Async Ray Tracing)":
215     *
216     *    stackBase = RTDispatchGlobals.rtMemBasePtr
217     *              + (DSSID * RTDispatchGlobals.numDSSRTStacks + stackID)
218     *              * RTDispatchGlobals.stackSizePerRay // 64B aligned
219     *
220     * We assume that we can calculate a 32-bit offset first and then add it
221     * to the 64-bit base address at the end.
222     */
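   /* Illustrative numbers (not from the BSpec): with DSSID = 3,
    * numDSSRTStacks = 2048 and stackID = 5, the async stack ID computed by
    * brw_nir_rt_async_stack_id() above is 3 * 2048 + 5, so this stack starts
    * at rtMemBasePtr + (3 * 2048 + 5) * stackSizePerRay.
    */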
223    nir_ssa_def *offset32 =
224       nir_imul(b, brw_nir_rt_async_stack_id(b),
225                   nir_load_ray_hw_stack_size_intel(b));
226    return nir_iadd(b, nir_load_ray_base_mem_addr_intel(b),
227                       nir_u2u64(b, offset32));
228 }
229 
230 static inline nir_ssa_def *
231 brw_nir_rt_mem_hit_addr_from_addr(nir_builder *b,
232                         nir_ssa_def *stack_addr,
233                         bool committed)
234 {
235    return nir_iadd_imm(b, stack_addr, committed ? 0 : BRW_RT_SIZEOF_HIT_INFO);
236 }
237 
238 static inline nir_ssa_def *
239 brw_nir_rt_mem_hit_addr(nir_builder *b, bool committed)
240 {
241    return nir_iadd_imm(b, brw_nir_rt_stack_addr(b),
242                           committed ? 0 : BRW_RT_SIZEOF_HIT_INFO);
243 }
244 
245 static inline nir_ssa_def *
246 brw_nir_rt_hit_attrib_data_addr(nir_builder *b)
247 {
248    return nir_iadd_imm(b, brw_nir_rt_stack_addr(b),
249                           BRW_RT_OFFSETOF_HIT_ATTRIB_DATA);
250 }
251 
252 static inline nir_ssa_def *
253 brw_nir_rt_mem_ray_addr(nir_builder *b,
254                         nir_ssa_def *stack_addr,
255                         enum brw_rt_bvh_level bvh_level)
256 {
257    /* From the BSpec "Address Computation for Memory Based Data Structures:
258     * Ray and TraversalStack (Async Ray Tracing)":
259     *
260     *    rayBase = stackBase + sizeof(HitInfo) * 2 // 64B aligned
261     *    rayPtr  = rayBase + bvhLevel * sizeof(Ray); // 64B aligned
262     *
263     * In Vulkan, we always have exactly two levels of BVH: World and Object.
264     */
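   /* For example, if BRW_RT_SIZEOF_HIT_INFO and BRW_RT_SIZEOF_RAY are both
    * 64B (as the 64B-aligned notes above suggest), the world-space ray
    * (bvhLevel 0) sits at stackBase + 128 and the object-space ray
    * (bvhLevel 1) at stackBase + 192.
    */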
265    uint32_t offset = BRW_RT_SIZEOF_HIT_INFO * 2 +
266                      bvh_level * BRW_RT_SIZEOF_RAY;
267    return nir_iadd_imm(b, stack_addr, offset);
268 }
269 
270 static inline nir_ssa_def *
271 brw_nir_rt_sw_stack_addr(nir_builder *b,
272                          const struct intel_device_info *devinfo)
273 {
274    nir_ssa_def *addr = nir_load_ray_base_mem_addr_intel(b);
275 
276    nir_ssa_def *offset32 = nir_imul(b, brw_nir_num_rt_stacks(b, devinfo),
277                                        nir_load_ray_hw_stack_size_intel(b));
278    addr = nir_iadd(b, addr, nir_u2u64(b, offset32));
279 
280    nir_ssa_def *offset_in_stack =
281       nir_imul(b, nir_u2u64(b, brw_nir_rt_async_stack_id(b)),
282                   nir_u2u64(b, nir_load_ray_sw_stack_size_intel(b)));
283 
284    return nir_iadd(b, addr, offset_in_stack);
285 }
286 
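/* Helper (documented here for clarity): returns bits [47:32] of a 64-bit
 * value, i.e. the third 16-bit chunk.  The ray/SBT packing code below uses it
 * to combine the upper half of a 48-bit pointer with a 16-bit field in a
 * single dword.
 */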
287 static inline nir_ssa_def *
288 nir_unpack_64_4x16_split_z(nir_builder *b, nir_ssa_def *val)
289 {
290    return nir_unpack_32_2x16_split_x(b, nir_unpack_64_2x32_split_y(b, val));
291 }
292 
293 struct brw_nir_rt_globals_defs {
294    nir_ssa_def *base_mem_addr;
295    nir_ssa_def *call_stack_handler_addr;
296    nir_ssa_def *hw_stack_size;
297    nir_ssa_def *num_dss_rt_stacks;
298    nir_ssa_def *hit_sbt_addr;
299    nir_ssa_def *hit_sbt_stride;
300    nir_ssa_def *miss_sbt_addr;
301    nir_ssa_def *miss_sbt_stride;
302    nir_ssa_def *sw_stack_size;
303    nir_ssa_def *launch_size;
304    nir_ssa_def *call_sbt_addr;
305    nir_ssa_def *call_sbt_stride;
306    nir_ssa_def *resume_sbt_addr;
307 };
308 
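/* Reads the RTDispatchGlobals structure with two constant block loads (16
 * dwords at the base address and 8 dwords at a 64B offset).  Note how the
 * 64-bit SBT base addresses are reassembled: the low dword is taken as-is and
 * the following dword holds the upper 16 address bits (sign-extended with
 * nir_extract_i16) in its low half and the SBT stride in its high half.
 */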
309 static inline void
310 brw_nir_rt_load_globals_addr(nir_builder *b,
311                              struct brw_nir_rt_globals_defs *defs,
312                              nir_ssa_def *addr)
313 {
314    nir_ssa_def *data;
315    data = brw_nir_rt_load_const(b, 16, addr, nir_imm_true(b));
316    defs->base_mem_addr = nir_pack_64_2x32(b, nir_channels(b, data, 0x3));
317 
318    defs->call_stack_handler_addr =
319       nir_pack_64_2x32(b, nir_channels(b, data, 0x3 << 2));
320 
321    defs->hw_stack_size = nir_channel(b, data, 4);
322    defs->num_dss_rt_stacks = nir_iand_imm(b, nir_channel(b, data, 5), 0xffff);
323    defs->hit_sbt_addr =
324       nir_pack_64_2x32_split(b, nir_channel(b, data, 8),
325                                 nir_extract_i16(b, nir_channel(b, data, 9),
326                                                    nir_imm_int(b, 0)));
327    defs->hit_sbt_stride =
328       nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 9));
329    defs->miss_sbt_addr =
330       nir_pack_64_2x32_split(b, nir_channel(b, data, 10),
331                                 nir_extract_i16(b, nir_channel(b, data, 11),
332                                                    nir_imm_int(b, 0)));
333    defs->miss_sbt_stride =
334       nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 11));
335    defs->sw_stack_size = nir_channel(b, data, 12);
336    defs->launch_size = nir_channels(b, data, 0x7u << 13);
337 
338    data = brw_nir_rt_load_const(b, 8, nir_iadd_imm(b, addr, 64), nir_imm_true(b));
339    defs->call_sbt_addr =
340       nir_pack_64_2x32_split(b, nir_channel(b, data, 0),
341                                 nir_extract_i16(b, nir_channel(b, data, 1),
342                                                    nir_imm_int(b, 0)));
343    defs->call_sbt_stride =
344       nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 1));
345 
346    defs->resume_sbt_addr =
347       nir_pack_64_2x32(b, nir_channels(b, data, 0x3 << 2));
348 }
349 
350 static inline void
351 brw_nir_rt_load_globals(nir_builder *b,
352                         struct brw_nir_rt_globals_defs *defs)
353 {
354    brw_nir_rt_load_globals_addr(b, defs, nir_load_btd_global_arg_addr_intel(b));
355 }
356 
357 static inline nir_ssa_def *
358 brw_nir_rt_unpack_leaf_ptr(nir_builder *b, nir_ssa_def *vec2)
359 {
360    /* Hit record leaf pointers are 42-bit and assumed to be in 64B chunks.
361     * This leaves 22 bits at the top for other stuff.
362     */
363    nir_ssa_def *ptr64 = nir_imul_imm(b, nir_pack_64_2x32(b, vec2), 64);
364 
365    /* The top 16 bits (remember, we shifted by 6 already) contain garbage
366     * that we need to get rid of.
367     */
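   /* In other words: the packed 42-bit value times 64 yields a 48-bit byte
    * address, and the nir_extract_i16() below sign-extends bit 47 through the
    * top 16 bits to form the final 64-bit pointer.
    */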
368    nir_ssa_def *ptr_lo = nir_unpack_64_2x32_split_x(b, ptr64);
369    nir_ssa_def *ptr_hi = nir_unpack_64_2x32_split_y(b, ptr64);
370    ptr_hi = nir_extract_i16(b, ptr_hi, nir_imm_int(b, 0));
371    return nir_pack_64_2x32_split(b, ptr_lo, ptr_hi);
372 }
373 
374 /**
375  * MemHit memory layout (BSpec 47547) :
376  *
377  *      name            bits    description
378  *    - t               32      hit distance of current hit (or initial traversal distance)
379  *    - u               32      barycentric hit coordinates
380  *    - v               32      barycentric hit coordinates
381  *    - primIndexDelta  16      prim index delta for compressed meshlets and quads
382  *    - valid            1      set if there is a hit
383  *    - leafType         3      type of node primLeafPtr is pointing to
384  *    - primLeafIndex    4      index of the hit primitive inside the leaf
385  *    - bvhLevel         3      the instancing level at which the hit occurred
386  *    - frontFace        1      whether we hit the front-facing side of a triangle (also used to pass the opaque flag when calling intersection shaders)
387  *    - pad0             4      unused bits
388  *    - primLeafPtr     42      pointer to BVH leaf node (multiple of 64 bytes)
389  *    - hitGroupRecPtr0 22      LSB of hit group record of the hit triangle (multiple of 16 bytes)
390  *    - instLeafPtr     42      pointer to BVH instance leaf node (in multiple of 64 bytes)
391  *    - hitGroupRecPtr1 22      MSB of hit group record of the hit triangle (multiple of 32 bytes)
392  */
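/* Note: the `done` flag consumed by the ray-query helpers below is not part of
 * the HW layout above; the SW reuses bit 28 of dword 3 (within pad0) for it.
 */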
393 struct brw_nir_rt_mem_hit_defs {
394    nir_ssa_def *t;
395    nir_ssa_def *tri_bary; /**< Only valid for triangle geometry */
396    nir_ssa_def *aabb_hit_kind; /**< Only valid for AABB geometry */
397    nir_ssa_def *valid;
398    nir_ssa_def *leaf_type;
399    nir_ssa_def *prim_leaf_index;
400    nir_ssa_def *bvh_level;
401    nir_ssa_def *front_face;
402    nir_ssa_def *done; /**< Only for ray queries */
403    nir_ssa_def *prim_leaf_ptr;
404    nir_ssa_def *inst_leaf_ptr;
405 };
406 
407 static inline void
408 brw_nir_rt_load_mem_hit_from_addr(nir_builder *b,
409                                   struct brw_nir_rt_mem_hit_defs *defs,
410                                   nir_ssa_def *stack_addr,
411                                   bool committed)
412 {
413    nir_ssa_def *hit_addr =
414       brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, committed);
415 
416    nir_ssa_def *data = brw_nir_rt_load(b, hit_addr, 16, 4, 32);
417    defs->t = nir_channel(b, data, 0);
418    defs->aabb_hit_kind = nir_channel(b, data, 1);
419    defs->tri_bary = nir_channels(b, data, 0x6);
420    nir_ssa_def *bitfield = nir_channel(b, data, 3);
421    defs->valid = nir_i2b(b, nir_iand_imm(b, bitfield, 1u << 16));
422    defs->leaf_type =
423       nir_ubitfield_extract(b, bitfield, nir_imm_int(b, 17), nir_imm_int(b, 3));
424    defs->prim_leaf_index =
425       nir_ubitfield_extract(b, bitfield, nir_imm_int(b, 20), nir_imm_int(b, 4));
426    defs->bvh_level =
427       nir_ubitfield_extract(b, bitfield, nir_imm_int(b, 24), nir_imm_int(b, 3));
428    defs->front_face = nir_i2b(b, nir_iand_imm(b, bitfield, 1 << 27));
429    defs->done = nir_i2b(b, nir_iand_imm(b, bitfield, 1 << 28));
430 
431    data = brw_nir_rt_load(b, nir_iadd_imm(b, hit_addr, 16), 16, 4, 32);
432    defs->prim_leaf_ptr =
433       brw_nir_rt_unpack_leaf_ptr(b, nir_channels(b, data, 0x3 << 0));
434    defs->inst_leaf_ptr =
435       brw_nir_rt_unpack_leaf_ptr(b, nir_channels(b, data, 0x3 << 2));
436 }
437 
438 static inline void
439 brw_nir_rt_init_mem_hit_at_addr(nir_builder *b,
440                                 nir_ssa_def *stack_addr,
441                                 bool committed,
442                                 nir_ssa_def *t_max)
443 {
444    nir_ssa_def *mem_hit_addr =
445       brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, committed);
446 
447    /* Set the t_max value from the ray initialization */
448    nir_ssa_def *hit_t_addr = mem_hit_addr;
449    brw_nir_rt_store(b, hit_t_addr, 4, t_max, 0x1);
450 
451    /* Clear all the flags packed behind primIndexDelta */
452    nir_ssa_def *state_addr = nir_iadd_imm(b, mem_hit_addr, 12);
453    brw_nir_rt_store(b, state_addr, 4, nir_imm_int(b, 0), 0x1);
454 }
455 
456 static inline void
457 brw_nir_rt_load_mem_hit(nir_builder *b,
458                         struct brw_nir_rt_mem_hit_defs *defs,
459                         bool committed)
460 {
461    brw_nir_rt_load_mem_hit_from_addr(b, defs, brw_nir_rt_stack_addr(b),
462                                      committed);
463 }
464 
465 static inline void
466 brw_nir_memcpy_global(nir_builder *b,
467                       nir_ssa_def *dst_addr, uint32_t dst_align,
468                       nir_ssa_def *src_addr, uint32_t src_align,
469                       uint32_t size)
470 {
471    /* We're going to copy in 16B chunks */
472    assert(size % 16 == 0);
473    dst_align = MIN2(dst_align, 16);
474    src_align = MIN2(src_align, 16);
475 
476    for (unsigned offset = 0; offset < size; offset += 16) {
477       nir_ssa_def *data =
478          brw_nir_rt_load(b, nir_iadd_imm(b, src_addr, offset), src_align,
479                          4, 32);
480       brw_nir_rt_store(b, nir_iadd_imm(b, dst_addr, offset), dst_align,
481                        data, 0xf /* write_mask */);
482    }
483 }
484 
485 static inline void
486 brw_nir_memclear_global(nir_builder *b,
487                         nir_ssa_def *dst_addr, uint32_t dst_align,
488                         uint32_t size)
489 {
490    /* We're going to clear in 16B chunks */
491    assert(size % 16 == 0);
492    dst_align = MIN2(dst_align, 16);
493 
494    nir_ssa_def *zero = nir_imm_ivec4(b, 0, 0, 0, 0);
495    for (unsigned offset = 0; offset < size; offset += 16) {
496       brw_nir_rt_store(b, nir_iadd_imm(b, dst_addr, offset), dst_align,
497                        zero, 0xf /* write_mask */);
498    }
499 }
500 
501 static inline nir_ssa_def *
502 brw_nir_rt_query_done(nir_builder *b, nir_ssa_def *stack_addr)
503 {
504    struct brw_nir_rt_mem_hit_defs hit_in = {};
505    brw_nir_rt_load_mem_hit_from_addr(b, &hit_in, stack_addr,
506                                      false /* committed */);
507 
508    return hit_in.done;
509 }
510 
511 static inline void
512 brw_nir_rt_set_dword_bit_at(nir_builder *b,
513                             nir_ssa_def *addr,
514                             uint32_t addr_offset,
515                             uint32_t bit)
516 {
517    nir_ssa_def *dword_addr = nir_iadd_imm(b, addr, addr_offset);
518    nir_ssa_def *dword = brw_nir_rt_load(b, dword_addr, 4, 1, 32);
519    brw_nir_rt_store(b, dword_addr, 4, nir_ior_imm(b, dword, 1u << bit), 0x1);
520 }
521 
522 static inline void
523 brw_nir_rt_query_mark_done(nir_builder *b, nir_ssa_def *stack_addr)
524 {
525    brw_nir_rt_set_dword_bit_at(b,
526                                brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr,
527                                                                  false /* committed */),
528                                4 * 3 /* dword offset */, 28 /* bit */);
529 }
530 
531 /* This helper clears dword 3 of the MemHit structure, where the valid and
532  * done bits are located, for both the committed and the potential hit.
533  */
534 static inline void
535 brw_nir_rt_query_mark_init(nir_builder *b, nir_ssa_def *stack_addr)
536 {
537    nir_ssa_def *dword_addr;
538 
539    for (uint32_t i = 0; i < 2; i++) {
540       dword_addr =
541          nir_iadd_imm(b,
542                       brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr,
543                                                         i == 0 /* committed */),
544                       4 * 3 /* dword offset */);
545       brw_nir_rt_store(b, dword_addr, 4, nir_imm_int(b, 0), 0x1);
546    }
547 }
548 
549 /* This helper is essentially a memcpy of the uncommitted (potential) hit
550  * structure into the committed one, additionally setting the valid bit.
551  */
552 static inline void
553 brw_nir_rt_commit_hit_addr(nir_builder *b, nir_ssa_def *stack_addr)
554 {
555    nir_ssa_def *dst_addr =
556       brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, true /* committed */);
557    nir_ssa_def *src_addr =
558       brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, false /* committed */);
559 
560    for (unsigned offset = 0; offset < BRW_RT_SIZEOF_HIT_INFO; offset += 16) {
561       nir_ssa_def *data =
562          brw_nir_rt_load(b, nir_iadd_imm(b, src_addr, offset), 16, 4, 32);
563 
564       if (offset == 0) {
565          data = nir_vec4(b,
566                          nir_channel(b, data, 0),
567                          nir_channel(b, data, 1),
568                          nir_channel(b, data, 2),
569                          nir_ior_imm(b,
570                                      nir_channel(b, data, 3),
571                                      0x1 << 16 /* valid */));
572 
573          /* Also write the potential hit as we change it. */
574          brw_nir_rt_store(b, nir_iadd_imm(b, src_addr, offset), 16,
575                           data, 0xf /* write_mask */);
576       }
577 
578       brw_nir_rt_store(b, nir_iadd_imm(b, dst_addr, offset), 16,
579                        data, 0xf /* write_mask */);
580    }
581 }
582 
583 static inline void
584 brw_nir_rt_commit_hit(nir_builder *b)
585 {
586    nir_ssa_def *stack_addr = brw_nir_rt_stack_addr(b);
587    brw_nir_rt_commit_hit_addr(b, stack_addr);
588 }
589 
590 static inline void
591 brw_nir_rt_generate_hit_addr(nir_builder *b, nir_ssa_def *stack_addr, nir_ssa_def *t_val)
592 {
593    nir_ssa_def *committed_addr =
594       brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, true /* committed */);
595    nir_ssa_def *potential_addr =
596       brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, false /* committed */);
597 
598    /* Set:
599     *
600     *   potential.t     = t_val;
601     *   potential.valid = true;
602     */
603    nir_ssa_def *potential_hit_dwords_0_3 =
604       brw_nir_rt_load(b, potential_addr, 16, 4, 32);
605    potential_hit_dwords_0_3 =
606       nir_vec4(b,
607                t_val,
608                nir_channel(b, potential_hit_dwords_0_3, 1),
609                nir_channel(b, potential_hit_dwords_0_3, 2),
610                nir_ior_imm(b, nir_channel(b, potential_hit_dwords_0_3, 3),
611                            (0x1 << 16) /* valid */));
612    brw_nir_rt_store(b, potential_addr, 16, potential_hit_dwords_0_3, 0xf /* write_mask */);
613 
614    /* Set:
615     *
616     *   committed.t               = t_val;
617     *   committed.u               = 0.0f;
618     *   committed.v               = 0.0f;
619     *   committed.valid           = true;
620     *   committed.leaf_type       = potential.leaf_type;
621     *   committed.bvh_level       = BRW_RT_BVH_LEVEL_OBJECT;
622     *   committed.front_face      = false;
623     *   committed.prim_leaf_index = 0;
624     *   committed.done            = false;
625     */
626    nir_ssa_def *committed_hit_dwords_0_3 =
627       brw_nir_rt_load(b, committed_addr, 16, 4, 32);
628    committed_hit_dwords_0_3 =
629       nir_vec4(b,
630                t_val,
631                nir_imm_float(b, 0.0f),
632                nir_imm_float(b, 0.0f),
633                nir_ior_imm(b,
634                            nir_iand_imm(b, nir_channel(b, potential_hit_dwords_0_3, 3), 0x000e0000) /* keep leaf_type, clear the rest */,
635                            (0x1 << 16)                     /* valid */ |
636                            (BRW_RT_BVH_LEVEL_OBJECT << 24) /* bvh_level */));
637    brw_nir_rt_store(b, committed_addr, 16, committed_hit_dwords_0_3, 0xf /* write_mask */);
638 
639    /* Set:
640     *
641     *   committed.prim_leaf_ptr   = potential.prim_leaf_ptr;
642     *   committed.inst_leaf_ptr   = potential.inst_leaf_ptr;
643     */
644    brw_nir_memcpy_global(b,
645                          nir_iadd_imm(b, committed_addr, 16), 16,
646                          nir_iadd_imm(b, potential_addr, 16), 16,
647                          16);
648 }
649 
650 struct brw_nir_rt_mem_ray_defs {
651    nir_ssa_def *orig;
652    nir_ssa_def *dir;
653    nir_ssa_def *t_near;
654    nir_ssa_def *t_far;
655    nir_ssa_def *root_node_ptr;
656    nir_ssa_def *ray_flags;
657    nir_ssa_def *hit_group_sr_base_ptr;
658    nir_ssa_def *hit_group_sr_stride;
659    nir_ssa_def *miss_sr_ptr;
660    nir_ssa_def *shader_index_multiplier;
661    nir_ssa_def *inst_leaf_ptr;
662    nir_ssa_def *ray_mask;
663 };
664 
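/* The store/load helpers below pack each 64-bit pointer of MemRay as a 48-bit
 * address sharing a dword with a 16-bit field: root_node_ptr with ray_flags,
 * hit_group_sr_base_ptr with hit_group_sr_stride, miss_sr_ptr with
 * (shader_index_multiplier << 8), and inst_leaf_ptr with the low 16 bits of
 * ray_mask.
 */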
665 static inline void
666 brw_nir_rt_store_mem_ray_query_at_addr(nir_builder *b,
667                                        nir_ssa_def *ray_addr,
668                                        const struct brw_nir_rt_mem_ray_defs *defs)
669 {
670    assert_def_size(defs->orig, 3, 32);
671    assert_def_size(defs->dir, 3, 32);
672    brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 0), 16,
673       nir_vec4(b, nir_channel(b, defs->orig, 0),
674                   nir_channel(b, defs->orig, 1),
675                   nir_channel(b, defs->orig, 2),
676                   nir_channel(b, defs->dir, 0)),
677       ~0 /* write mask */);
678 
679    assert_def_size(defs->t_near, 1, 32);
680    assert_def_size(defs->t_far, 1, 32);
681    brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 16), 16,
682       nir_vec4(b, nir_channel(b, defs->dir, 1),
683                   nir_channel(b, defs->dir, 2),
684                   defs->t_near,
685                   defs->t_far),
686       ~0 /* write mask */);
687 
688    assert_def_size(defs->root_node_ptr, 1, 64);
689    assert_def_size(defs->ray_flags, 1, 16);
690    brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 32), 16,
691       nir_vec2(b, nir_unpack_64_2x32_split_x(b, defs->root_node_ptr),
692                   nir_pack_32_2x16_split(b,
693                      nir_unpack_64_4x16_split_z(b, defs->root_node_ptr),
694                      defs->ray_flags)),
695       0x3 /* write mask */);
696 
697    /* leaf_ptr is optional */
698    nir_ssa_def *inst_leaf_ptr;
699    if (defs->inst_leaf_ptr) {
700       inst_leaf_ptr = defs->inst_leaf_ptr;
701    } else {
702       inst_leaf_ptr = nir_imm_int64(b, 0);
703    }
704 
705    assert_def_size(inst_leaf_ptr, 1, 64);
706    assert_def_size(defs->ray_mask, 1, 32);
707    brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 56), 8,
708       nir_vec2(b, nir_unpack_64_2x32_split_x(b, inst_leaf_ptr),
709                   nir_pack_32_2x16_split(b,
710                      nir_unpack_64_4x16_split_z(b, inst_leaf_ptr),
711                      nir_unpack_32_2x16_split_x(b, defs->ray_mask))),
712       ~0 /* write mask */);
713 }
714 
715 static inline void
716 brw_nir_rt_store_mem_ray(nir_builder *b,
717                          const struct brw_nir_rt_mem_ray_defs *defs,
718                          enum brw_rt_bvh_level bvh_level)
719 {
720    nir_ssa_def *ray_addr =
721       brw_nir_rt_mem_ray_addr(b, brw_nir_rt_stack_addr(b), bvh_level);
722 
723    assert_def_size(defs->orig, 3, 32);
724    assert_def_size(defs->dir, 3, 32);
725    brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 0), 16,
726       nir_vec4(b, nir_channel(b, defs->orig, 0),
727                   nir_channel(b, defs->orig, 1),
728                   nir_channel(b, defs->orig, 2),
729                   nir_channel(b, defs->dir, 0)),
730       ~0 /* write mask */);
731 
732    assert_def_size(defs->t_near, 1, 32);
733    assert_def_size(defs->t_far, 1, 32);
734    brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 16), 16,
735       nir_vec4(b, nir_channel(b, defs->dir, 1),
736                   nir_channel(b, defs->dir, 2),
737                   defs->t_near,
738                   defs->t_far),
739       ~0 /* write mask */);
740 
741    assert_def_size(defs->root_node_ptr, 1, 64);
742    assert_def_size(defs->ray_flags, 1, 16);
743    assert_def_size(defs->hit_group_sr_base_ptr, 1, 64);
744    assert_def_size(defs->hit_group_sr_stride, 1, 16);
745    brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 32), 16,
746       nir_vec4(b, nir_unpack_64_2x32_split_x(b, defs->root_node_ptr),
747                   nir_pack_32_2x16_split(b,
748                      nir_unpack_64_4x16_split_z(b, defs->root_node_ptr),
749                      defs->ray_flags),
750                   nir_unpack_64_2x32_split_x(b, defs->hit_group_sr_base_ptr),
751                   nir_pack_32_2x16_split(b,
752                      nir_unpack_64_4x16_split_z(b, defs->hit_group_sr_base_ptr),
753                      defs->hit_group_sr_stride)),
754       ~0 /* write mask */);
755 
756    /* leaf_ptr is optional */
757    nir_ssa_def *inst_leaf_ptr;
758    if (defs->inst_leaf_ptr) {
759       inst_leaf_ptr = defs->inst_leaf_ptr;
760    } else {
761       inst_leaf_ptr = nir_imm_int64(b, 0);
762    }
763 
764    assert_def_size(defs->miss_sr_ptr, 1, 64);
765    assert_def_size(defs->shader_index_multiplier, 1, 32);
766    assert_def_size(inst_leaf_ptr, 1, 64);
767    assert_def_size(defs->ray_mask, 1, 32);
768    brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 48), 16,
769       nir_vec4(b, nir_unpack_64_2x32_split_x(b, defs->miss_sr_ptr),
770                   nir_pack_32_2x16_split(b,
771                      nir_unpack_64_4x16_split_z(b, defs->miss_sr_ptr),
772                      nir_unpack_32_2x16_split_x(b,
773                         nir_ishl(b, defs->shader_index_multiplier,
774                                     nir_imm_int(b, 8)))),
775                   nir_unpack_64_2x32_split_x(b, inst_leaf_ptr),
776                   nir_pack_32_2x16_split(b,
777                      nir_unpack_64_4x16_split_z(b, inst_leaf_ptr),
778                      nir_unpack_32_2x16_split_x(b, defs->ray_mask))),
779       ~0 /* write mask */);
780 }
781 
782 static inline void
783 brw_nir_rt_load_mem_ray_from_addr(nir_builder *b,
784                                   struct brw_nir_rt_mem_ray_defs *defs,
785                                   nir_ssa_def *ray_base_addr,
786                                   enum brw_rt_bvh_level bvh_level)
787 {
788    nir_ssa_def *ray_addr = brw_nir_rt_mem_ray_addr(b,
789                                                    ray_base_addr,
790                                                    bvh_level);
791 
792    nir_ssa_def *data[4] = {
793       brw_nir_rt_load(b, nir_iadd_imm(b, ray_addr,  0), 16, 4, 32),
794       brw_nir_rt_load(b, nir_iadd_imm(b, ray_addr, 16), 16, 4, 32),
795       brw_nir_rt_load(b, nir_iadd_imm(b, ray_addr, 32), 16, 4, 32),
796       brw_nir_rt_load(b, nir_iadd_imm(b, ray_addr, 48), 16, 4, 32),
797    };
798 
799    defs->orig = nir_channels(b, data[0], 0x7);
800    defs->dir = nir_vec3(b, nir_channel(b, data[0], 3),
801                            nir_channel(b, data[1], 0),
802                            nir_channel(b, data[1], 1));
803    defs->t_near = nir_channel(b, data[1], 2);
804    defs->t_far = nir_channel(b, data[1], 3);
805    defs->root_node_ptr =
806       nir_pack_64_2x32_split(b, nir_channel(b, data[2], 0),
807                                 nir_extract_i16(b, nir_channel(b, data[2], 1),
808                                                    nir_imm_int(b, 0)));
809    defs->ray_flags =
810       nir_unpack_32_2x16_split_y(b, nir_channel(b, data[2], 1));
811    defs->hit_group_sr_base_ptr =
812       nir_pack_64_2x32_split(b, nir_channel(b, data[2], 2),
813                                 nir_extract_i16(b, nir_channel(b, data[2], 3),
814                                                    nir_imm_int(b, 0)));
815    defs->hit_group_sr_stride =
816       nir_unpack_32_2x16_split_y(b, nir_channel(b, data[2], 3));
817    defs->miss_sr_ptr =
818       nir_pack_64_2x32_split(b, nir_channel(b, data[3], 0),
819                                 nir_extract_i16(b, nir_channel(b, data[3], 1),
820                                                    nir_imm_int(b, 0)));
821    defs->shader_index_multiplier =
822       nir_ushr(b, nir_unpack_32_2x16_split_y(b, nir_channel(b, data[3], 1)),
823                   nir_imm_int(b, 8));
824    defs->inst_leaf_ptr =
825       nir_pack_64_2x32_split(b, nir_channel(b, data[3], 2),
826                                 nir_extract_i16(b, nir_channel(b, data[3], 3),
827                                                    nir_imm_int(b, 0)));
828    defs->ray_mask =
829       nir_unpack_32_2x16_split_y(b, nir_channel(b, data[3], 3));
830 }
831 
832 static inline void
833 brw_nir_rt_load_mem_ray(nir_builder *b,
834                         struct brw_nir_rt_mem_ray_defs *defs,
835                         enum brw_rt_bvh_level bvh_level)
836 {
837    brw_nir_rt_load_mem_ray_from_addr(b, defs, brw_nir_rt_stack_addr(b),
838                                      bvh_level);
839 }
840 
841 struct brw_nir_rt_bvh_instance_leaf_defs {
842    nir_ssa_def *shader_index;
843    nir_ssa_def *contribution_to_hit_group_index;
844    nir_ssa_def *world_to_object[4];
845    nir_ssa_def *instance_id;
846    nir_ssa_def *instance_index;
847    nir_ssa_def *object_to_world[4];
848 };
849 
850 static inline void
851 brw_nir_rt_load_bvh_instance_leaf(nir_builder *b,
852                                   struct brw_nir_rt_bvh_instance_leaf_defs *defs,
853                                   nir_ssa_def *leaf_addr)
854 {
855    defs->shader_index =
856       nir_iand_imm(b, brw_nir_rt_load(b, leaf_addr, 4, 1, 32), (1 << 24) - 1);
857    defs->contribution_to_hit_group_index =
858       nir_iand_imm(b,
859                    brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 4), 4, 1, 32),
860                    (1 << 24) - 1);
861 
862    defs->world_to_object[0] =
863       brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 16), 4, 3, 32);
864    defs->world_to_object[1] =
865       brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 28), 4, 3, 32);
866    defs->world_to_object[2] =
867       brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 40), 4, 3, 32);
868    /* The last (translation) column of the two matrices is stored swapped,
869     * presumably because that layout is easier/faster for the hardware.
870     */
871    defs->object_to_world[3] =
872       brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 52), 4, 3, 32);
873 
874    nir_ssa_def *data =
875       brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 64), 4, 4, 32);
876    defs->instance_id = nir_channel(b, data, 2);
877    defs->instance_index = nir_channel(b, data, 3);
878 
879    defs->object_to_world[0] =
880       brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 80), 4, 3, 32);
881    defs->object_to_world[1] =
882       brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 92), 4, 3, 32);
883    defs->object_to_world[2] =
884       brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 104), 4, 3, 32);
885    defs->world_to_object[3] =
886       brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 116), 4, 3, 32);
887 }
888 
889 struct brw_nir_rt_bvh_primitive_leaf_defs {
890    nir_ssa_def *shader_index;
891    nir_ssa_def *geom_mask;
892    nir_ssa_def *geom_index;
893    nir_ssa_def *type;
894    nir_ssa_def *geom_flags;
895 };
896 
897 static inline void
898 brw_nir_rt_load_bvh_primitive_leaf(nir_builder *b,
899                                    struct brw_nir_rt_bvh_primitive_leaf_defs *defs,
900                                    nir_ssa_def *leaf_addr)
901 {
902    nir_ssa_def *desc = brw_nir_rt_load(b, leaf_addr, 4, 2, 32);
903 
904    defs->shader_index =
905       nir_ubitfield_extract(b, nir_channel(b, desc, 0),
906                             nir_imm_int(b, 23), nir_imm_int(b, 0));
907    defs->geom_mask =
908       nir_ubitfield_extract(b, nir_channel(b, desc, 0),
909                             nir_imm_int(b, 31), nir_imm_int(b, 24));
910 
911    defs->geom_index =
912       nir_ubitfield_extract(b, nir_channel(b, desc, 1),
913                             nir_imm_int(b, 28), nir_imm_int(b, 0));
914    defs->type =
915       nir_ubitfield_extract(b, nir_channel(b, desc, 1),
916                             nir_imm_int(b, 29), nir_imm_int(b, 29));
917    defs->geom_flags =
918       nir_ubitfield_extract(b, nir_channel(b, desc, 1),
919                             nir_imm_int(b, 31), nir_imm_int(b, 30));
920 }
921 
922 static inline nir_ssa_def *
923 brw_nir_rt_load_primitive_id_from_hit(nir_builder *b,
924                                       nir_ssa_def *is_procedural,
925                                       const struct brw_nir_rt_mem_hit_defs *defs)
926 {
927    if (!is_procedural) {
928       is_procedural =
929          nir_ieq(b, defs->leaf_type,
930                     nir_imm_int(b, BRW_RT_BVH_NODE_TYPE_PROCEDURAL));
931    }
932 
933    /* The IDs are located in the leaf; index it using the hit's leaf index.
934     *
935     * The index is in dw[3] for procedural leaves and dw[2] for quad leaves.
936     */
937    nir_ssa_def *offset =
938       nir_bcsel(b, is_procedural,
939                    nir_iadd_imm(b, nir_ishl_imm(b, defs->prim_leaf_index, 2), 12),
940                    nir_imm_int(b, 8));
941    return nir_load_global(b, nir_iadd(b, defs->prim_leaf_ptr,
942                                          nir_u2u64(b, offset)),
943                              4, /* align */ 1, 32);
944 }
945 
946 static inline nir_ssa_def *
947 brw_nir_rt_acceleration_structure_to_root_node(nir_builder *b,
948                                                nir_ssa_def *as_addr)
949 {
950    /* The HW memory structure in which we specify what acceleration structure
951     * to traverse takes the address of the root node in the acceleration
952     * structure, not of the acceleration structure itself. To find that, we
953     * have to read the root node offset from the acceleration structure, which
954     * is stored in its first QWord.
955     *
956     * But if the acceleration structure pointer is NULL, then we should return
957     * NULL as the root node pointer.
958     *
959     * TODO: we could optimize this by assuming that, for a given version of
960     * the BVH, the root node is at a fixed offset.
961     */
962    nir_ssa_def *root_node_ptr, *null_node_ptr;
963    nir_push_if(b, nir_ieq(b, as_addr, nir_imm_int64(b, 0)));
964    {
965       null_node_ptr = nir_imm_int64(b, 0);
966    }
967    nir_push_else(b, NULL);
968    {
969       root_node_ptr =
970          nir_iadd(b, as_addr, brw_nir_rt_load(b, as_addr, 256, 1, 64));
971    }
972    nir_pop_if(b, NULL);
973 
974    return nir_if_phi(b, null_node_ptr, root_node_ptr);
975 }
976 
977 #endif /* BRW_NIR_RT_BUILDER_H */
978