1//
2// Copyright (C) 2009-2021 Intel Corporation
3//
4// SPDX-License-Identifier: MIT
5//
6//
7
8#include "intrinsics.h"
9#include "AABB3f.h"
10#include "AABB.h"
11#include "GRLGen12.h"
12#include "quad.h"
13#include "common.h"
14#include "instance.h"
15
16#include "api_interface.h"
17
18#include "binned_sah_shared.h"
19
20
21#if 0
22#define LOOP_TRIPWIRE_INIT uint _loop_trip=0;
23
24#define LOOP_TRIPWIRE_INCREMENT(max_iterations) \
25    _loop_trip++;\
26    if ( _loop_trip > max_iterations  )\
27    {\
28        if( get_local_id(0) == 0 )\
29            printf( "@@@@@@@@@@@@@@@@@@@@ TRIPWIRE!!!!!!!!!!! group=%u\n", get_group_id(0) );\
30        break;\
31    }
32#else
33
34#define LOOP_TRIPWIRE_INIT
35#define LOOP_TRIPWIRE_INCREMENT(max_iterations)
36
37#endif
38
39
40// =========================================================
41//             DFS
42// =========================================================
43
44// there are 128 threads x SIMD16 == 2048 lanes in a DSS
45//   There is 128KB of SLM.  Upper limit of 64KB per WG, so target is 2 groups of 1024 lanes @ 64K each
46//     --> Full occupancy requires using less than 64B per lane
47//
48//   Groups of 256 lanes gives us 16KB per group
49//
50
51// We use subgroups very heavily here in order to avoid
52//    use of per-thread scratch space for intermediate values
53
54#define DFS_WG_SIZE 256
55#define DFS_NUM_SUBGROUPS 16
56#define DFS_BVH2_NODE_COUNT (2*(DFS_WG_SIZE)-1)
57#define TREE_ARITY 6
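// Note: the LocalBVH2 built over at most DFS_WG_SIZE primrefs has one primref per
// leaf in the worst case, so as a binary tree it never contains more than
// 2*DFS_WG_SIZE - 1 nodes; DFS_BVH2_NODE_COUNT sizes the node array accordingly.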
58
59// FlatTree node limits:
60// these are the derivations if we always collapse to one primitive and pack nodes as tightly as possible
61//   If BVH2 construction is allowed to terminate early and place multiple prims in a leaf, these numbers will be too low
62#if 0
63
64// maximum flattree size is the number of inner nodes in a full M-ary tree with one leaf per primitive
65//  This is given by I = (L-1)/(M-1)
66//  For a 256 thread workgroup, L=256, M=6, this gives: 51
67#define DFS_MAX_FLATTREE_NODES 51
68
69
70// A flattree leaf is a node which contains only primitives.
71//
72//  The maximum number of leaves is related to the number of nodes as:
73//   L(N) = ((M-1)*N + 1) / M
74//
75#define DFS_MAX_FLATTREE_LEAFS 43  // = 43 for 256 thread WG (L=256, M=6)
76
77#else
78
79//  This is the result of estimate_qbvh6_nodes(256)
80
81#define DFS_MAX_FLATTREE_LEAFS 256
82#define DFS_MAX_FLATTREE_NODES 307 // 256 fat-leaves + 51 inner nodes.  51 = ceil((256-1)/5)
83#define DFS_MAX_FLATTREE_DEPTH 52  // worst-case depth of the flattree (a degenerate chain of inner nodes, plus the leaf level)
84
85#endif
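// Minimal sketch (illustrative only, compiled out): runtime evaluation of the two
// closed-form bounds quoted in the #if 0 branch above, using ceiling division.
// The helper names are not part of the builder.  For L=256, M=6 they reproduce
// 51 inner nodes, and for N=51 nodes they reproduce 43 leaves.
#if 0
GRL_INLINE uint flattree_max_inner_nodes( uint num_leaves, uint arity )   // I = ceil((L-1)/(M-1))
{
    return (num_leaves - 1 + (arity - 2)) / (arity - 1);
}
GRL_INLINE uint flattree_max_leaves( uint num_nodes, uint arity )         // L(N) = ceil(((M-1)*N + 1) / M)
{
    return ((arity - 1) * num_nodes + 1 + (arity - 1)) / arity;
}
#endif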
86
87#define uniform     // annotation only: marks values that are uniform across a subgroup
88#define varying     // annotation only: marks values that vary per SIMD lane
89
90
91struct DFSArgs
92{
93    global struct BVHBase* bvh_base;
94    global PrimRef* primref_buffer;
95    ushort leaf_node_type;
96    ushort inner_node_type;
97    ushort leaf_size_in_bytes;
98    bool need_backpointers;
99    bool need_masks;
100    ushort num_primrefs;
101    global uint* primref_index_buffer;
102};
103
104
105struct DFSPrimRefAABB
106{
107    half lower[3];
108    half upper[3];
109};
110
111GRL_INLINE void DFSPrimRefAABB_init( struct DFSPrimRefAABB* bb )
112{
113    bb->lower[0] = 1;
114    bb->lower[1] = 1;
115    bb->lower[2] = 1;
116    bb->upper[0] = 0;
117    bb->upper[1] = 0;
118    bb->upper[2] = 0;
119}
120
121GRL_INLINE void DFSPrimRefAABB_extend( struct DFSPrimRefAABB* aabb, struct DFSPrimRefAABB* v )
122{
123    aabb->lower[0] = min( aabb->lower[0], v->lower[0] );
124    aabb->lower[1] = min( aabb->lower[1], v->lower[1] );
125    aabb->lower[2] = min( aabb->lower[2], v->lower[2] );
126    aabb->upper[0] = max( aabb->upper[0], v->upper[0] );
127    aabb->upper[1] = max( aabb->upper[1], v->upper[1] );
128    aabb->upper[2] = max( aabb->upper[2], v->upper[2] );
129}
130
131GRL_INLINE float DFSPrimRefAABB_halfArea( struct DFSPrimRefAABB* aabb )
132{
133    const half3 d = (half3)(aabb->upper[0] - aabb->lower[0], aabb->upper[1] - aabb->lower[1], aabb->upper[2] - aabb->lower[2]);
134    return fma( d.x, (d.y + d.z), d.y * d.z );
135}
136
137GRL_INLINE struct DFSPrimRefAABB DFSPrimRefAABB_sub_group_reduce( struct DFSPrimRefAABB* aabb )
138{
139    struct DFSPrimRefAABB bounds;
140    bounds.lower[0] = sub_group_reduce_min( aabb->lower[0] );
141    bounds.lower[1] = sub_group_reduce_min( aabb->lower[1] );
142    bounds.lower[2] = sub_group_reduce_min( aabb->lower[2] );
143    bounds.upper[0] = sub_group_reduce_max( aabb->upper[0] );
144    bounds.upper[1] = sub_group_reduce_max( aabb->upper[1] );
145    bounds.upper[2] = sub_group_reduce_max( aabb->upper[2] );
146    return bounds;
147}
148
149struct DFSPrimRef
150{
151    struct DFSPrimRefAABB aabb;
152    uint2 meta;
153};
154
155struct PrimRefMeta
156{
157    uchar2 meta;
158};
159
160GRL_INLINE uint PrimRefMeta_GetInputIndex( struct PrimRefMeta* it )
161{
162    return it->meta.x;
163}
164GRL_INLINE uint PrimRefMeta_GetInstanceMask( struct PrimRefMeta* it )
165{
166    return it->meta.y;
167}
168
169
170struct PrimRefSet
171{
172    struct AABB3f root_aabb;
173    struct DFSPrimRefAABB AABB[DFS_WG_SIZE];
174    uint2 meta[DFS_WG_SIZE];
175
176};
177
178GRL_INLINE local struct DFSPrimRefAABB* PrimRefSet_GetAABBPointer( local struct PrimRefSet* refs, ushort id )
179{
180    return &refs->AABB[id];
181}
182
183GRL_INLINE float PrimRefSet_GetMaxAABBArea( local struct PrimRefSet* refs )
184{
185    float3 root_l = AABB3f_load_lower( &refs->root_aabb );
186    float3 root_u = AABB3f_load_upper( &refs->root_aabb );
187    float3 d = root_u - root_l;
188    float scale = 1.0f / max( d.x, max( d.y, d.z ) );
189
190    half3 dh = convert_half3_rtp( d * scale );
191    return fma( dh.x, (dh.y + dh.z), dh.y * dh.z );
192}
193
194GRL_INLINE float3 ulp3( float3 v ) {
195
196    return fabs(v) * FLT_EPSILON;
197}
198
199GRL_INLINE struct AABB PrimRefSet_ConvertAABB( local struct PrimRefSet* refs, struct DFSPrimRefAABB* box )
200{
201    float3 root_l = AABB3f_load_lower( &refs->root_aabb );
202    float3 root_u = AABB3f_load_upper( &refs->root_aabb );
203    float3 d = root_u - root_l;
204    float scale = max( d.x, max( d.y, d.z ) );
205
206    float3 l = convert_float3_rtz( (half3)(box->lower[0], box->lower[1], box->lower[2]) );
207    float3 u = convert_float3_rtp( (half3)(box->upper[0], box->upper[1], box->upper[2]) );
208    l =  l * scale + root_l ;
209    u =  u * scale + root_l ;
210
211    // clamping is necessary in case a vertex lies exactly on the upper AABB plane.
212    //   If we use unclamped values, roundoff error in the scale factor calculation can cause us
213    //   to snap to a flattened AABB that lies outside of the original one, resulting in missed geometry.
214    u = min( u, root_u );
215    l = min( l, root_u );
216
217    struct AABB r;
218    r.lower.xyz = l.xyz;
219    r.upper.xyz = u.xyz;
220    return r;
221}
222
223GRL_INLINE PrimRef PrimRefSet_GetFullPrecisionAABB( local struct PrimRefSet* refs, ushort id )
224{
225    struct AABB r;
226    r = PrimRefSet_ConvertAABB( refs, &refs->AABB[id] );
227    r.lower.w = 0;
228    r.upper.w = 0;
229    return r;
230}
231
232GRL_INLINE uint PrimRefSet_GetInputIndex( local struct PrimRefSet* refs, ushort id )
233{
234    return refs->meta[id].x;
235}
236
237GRL_INLINE uint PrimRefSet_GetInstanceMask( local struct PrimRefSet* refs, ushort id )
238{
239    return refs->meta[id].y;
240}
241GRL_INLINE struct PrimRefMeta PrimRefSet_GetMeta( local struct PrimRefSet* refs, ushort id )
242{
243    struct PrimRefMeta meta;
244    meta.meta.x = refs->meta[id].x;
245    meta.meta.y = refs->meta[id].y;
246    return meta;
247}
248
249
250GRL_INLINE struct DFSPrimRef PrimRefSet_GetPrimRef( local struct PrimRefSet* refs, ushort id )
251{
252    struct DFSPrimRef r;
253    r.aabb = refs->AABB[id];
254    r.meta = refs->meta[id];
255    return r;
256}
257
258
259GRL_INLINE void PrimRefSet_SetPrimRef_FullPrecision( local struct PrimRefSet* refs, PrimRef ref, ushort id )
260{
261
262    float3 root_l = AABB3f_load_lower( &refs->root_aabb );
263    float3 root_u = AABB3f_load_upper( &refs->root_aabb );
264    float3 d      = root_u - root_l;
265    float scale   = 1.0f / max(d.x, max(d.y,d.z));
266
267    float3 l = ref.lower.xyz;
268    float3 u = ref.upper.xyz;
269    half3 lh = convert_half3_rtz( (l - root_l) * scale );
270    half3 uh = convert_half3_rtp( (u - root_l) * scale );
271
272    refs->AABB[id].lower[0] = lh.x;
273    refs->AABB[id].lower[1] = lh.y;
274    refs->AABB[id].lower[2] = lh.z;
275    refs->AABB[id].upper[0] = uh.x;
276    refs->AABB[id].upper[1] = uh.y;
277    refs->AABB[id].upper[2] = uh.z;
278    refs->meta[id].x = id;
279    refs->meta[id].y = PRIMREF_instanceMask(&ref);
280
281
282}
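// Minimal debug-style sketch (illustrative, compiled out): because lower bounds are
// quantized with round-toward-zero and upper bounds with round-toward-+infinity, the
// decoded 16-bit box must always contain the original full-precision primref.  This
// mirrors the per-component comparisons printed by PrimRefSet_CheckBounds below.
#if 0
GRL_INLINE bool PrimRefSet_IsConservative( local struct PrimRefSet* refs, PrimRef ref, ushort id )
{
    PrimRef q = PrimRefSet_GetFullPrecisionAABB( refs, id );
    return q.lower.x <= ref.lower.x && q.lower.y <= ref.lower.y && q.lower.z <= ref.lower.z &&
           q.upper.x >= ref.upper.x && q.upper.y >= ref.upper.y && q.upper.z >= ref.upper.z;
}
#endif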
283
284GRL_INLINE void PrimRefSet_SetPrimRef( local struct PrimRefSet* refs, struct DFSPrimRef ref, ushort id )
285{
286    refs->AABB[id] = ref.aabb;
287    refs->meta[id] = ref.meta;
288}
289
290GRL_INLINE struct AABB3f PrimRefSet_GetRootAABB( local struct PrimRefSet* refs )
291{
292    return refs->root_aabb;
293}
294
295GRL_INLINE void SUBGROUP_PrimRefSet_Initialize( local struct PrimRefSet* refs )
296{
297    if ( get_sub_group_local_id() == 0 )
298        AABB3f_init( &refs->root_aabb ); // TODO_OPT: subgroup-vectorized version of AABB3f_init
299}
300
301
302GRL_INLINE void PrimRefSet_Printf( local struct PrimRefSet* refs, ushort num_prims )
303{
304
305    barrier( CLK_LOCAL_MEM_FENCE );
306    if ( get_local_id( 0 ) == 0 )
307    {
308        printf( "Scene AABB:\n" );
309        struct AABB3f rootBox = PrimRefSet_GetRootAABB( refs );
310        AABB3f_print( &rootBox );
311
312        float ma = PrimRefSet_GetMaxAABBArea( refs );
313
314        for ( uint i = 0; i < num_prims; i++ )
315        {
316            printf( "Ref: %u\n", i );
317            struct AABB r = PrimRefSet_GetFullPrecisionAABB( refs, i );
318            AABB_print( &r );
319
320            float a = DFSPrimRefAABB_halfArea( PrimRefSet_GetAABBPointer( refs, i ) );
321            printf( "Scaled Area: %f / %f = %f \n", a, ma, a / ma );
322
323        }
324    }
325    barrier( CLK_LOCAL_MEM_FENCE );
326}
327
328
329
330GRL_INLINE void PrimRefSet_CheckBounds( local struct PrimRefSet* refs, ushort num_prims, PrimRef* primref_buffer )
331{
332
333    barrier( CLK_LOCAL_MEM_FENCE );
334    if ( get_local_id( 0 ) == 0 )
335    {
336
337        for ( uint i = 0; i < num_prims; i++ )
338        {
339            PrimRef ref = primref_buffer[i];
340            struct AABB r2 = PrimRefSet_GetFullPrecisionAABB( refs, i );
341
342            struct DFSPrimRefAABB* box = &refs->AABB[i];
343            float3 l = convert_float3_rtz( (half3)(box->lower[0], box->lower[1], box->lower[2]) );
344            float3 u = convert_float3_rtp( (half3)(box->upper[0], box->upper[1], box->upper[2]) );
345
346            printf( " halfs:{%x,%x,%x}{%x,%x,%x}\n", as_uint(l.x), as_uint(l.y), as_uint(l.z), as_uint(u.x), as_uint(u.y), as_uint(u.z) );
347
348            printf( " {%f,%f,%f} {%f,%f,%f}    {%f,%f,%f} {%f,%f,%f} {%u,%u,%u,%u,%u,%u}\n",
349                ref.lower.x, ref.lower.y, ref.lower.z, r2.lower.x, r2.lower.y, r2.lower.z,
350                ref.upper.x, ref.upper.y, ref.upper.z, r2.upper.x, r2.upper.y, r2.upper.z,
351                r2.lower.x <= ref.lower.x,
352                r2.lower.y <= ref.lower.y,
353                r2.lower.z <= ref.lower.z,
354
355                r2.upper.x >= ref.upper.x,
356                r2.upper.y >= ref.upper.y,
357                r2.upper.z >= ref.upper.z );
358
359        }
360
361    }
362    barrier( CLK_LOCAL_MEM_FENCE );
363}
364
365
366
367struct LocalBVH2
368{
369    uint num_nodes;
370    uint nodes[DFS_BVH2_NODE_COUNT];
371
372    // nodes are a bitfield:
373    //    bits 8:0 (9b)     ==> number of primrefs in this subtree
374    //
375    //    bits 17:9 (9b)    ==> for an inner node:  contains offset to a pair of children
376    //                      ==> for a leaf node: contains index of the first primref in this leaf
377    //
378    //    bits 30:18 (13b)  ==> quantized AABB area (relative to root box)
379    //    bit 31 (1b)       ==> is_inner flag
380    //
381    // NOTE: The left child offset of any node is always odd.. therefore, it is possible to recover a bit if we need it
382    //        by storing only the 8 MSBs
383};
384
385#define DFS_BVH2_AREA_QUANT 8191.0f
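// The 13-bit area field of a LocalBVH2 node stores a box's half-area relative to the
// largest representable box of the set (PrimRefSet_GetMaxAABBArea), scaled into
// [0, DFS_BVH2_AREA_QUANT].  Minimal sketch of the encode step performed later in
// DFS_ConstructBVH2 (helper name is illustrative only):
#if 0
GRL_INLINE ushort LocalBVH2_QuantizeArea( float half_area, float max_area )
{
    float area_scale = DFS_BVH2_AREA_QUANT / max_area;
    return convert_ushort_rtn( half_area * area_scale );  // result fits in 13 bits (0..8191)
}
#endif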
386
387
388
389GRL_INLINE void SUBGROUP_LocalBVH2_Initialize( local struct LocalBVH2* tree, ushort num_prims )
390{
391    tree->num_nodes = 1; // include the root node
392    tree->nodes[0] = num_prims; // initialize root node as a leaf containing the full subtree
393
394}
395
396GRL_INLINE void LocalBVH2_CreateInnerNode( local struct LocalBVH2* tree, ushort node_index,
397                           ushort start_left, ushort start_right,
398                           ushort quantized_left_area, ushort quantized_right_area )
399{
400    uint child_pos   = atomic_add_local( &tree->num_nodes, 2 );
401
402    // set the inner node flag and child position in the parent
403    // leave the other bits intact
404    uint parent_node = tree->nodes[node_index];
405    parent_node |= 0x80000000;
406    parent_node = (parent_node & ~(0x1ff<<9)) | (child_pos << 9);
407    tree->nodes[node_index] = parent_node;
408
409    // setup children as leaf nodes with prim-count zero
410    uint left_child  = (convert_uint(start_left) << 9)  | (convert_uint( quantized_left_area )  << 18);
411    uint right_child = (convert_uint(start_right) << 9) | (convert_uint( quantized_right_area ) << 18);
412    tree->nodes[child_pos]      = left_child;
413    tree->nodes[child_pos + 1]  = right_child;
414
415}
416
417GRL_INLINE ushort LocalBVH2_IncrementPrimCount( local struct LocalBVH2* tree, ushort node_index )
418{
419    // increment only the lower bits.  Given a correct tree construction algorithm, this will not overflow into the MSBs
420    return (atomic_inc_local( &tree->nodes[node_index] )) & 0x1ff;
421}
422
423GRL_INLINE ushort LocalBVH2_GetNodeArea( local struct LocalBVH2* tree, ushort nodeID )
424{
425    return (tree->nodes[nodeID] >> 18) & 0x1FFF;
426}
427
428GRL_INLINE bool LocalBVH2_IsInnerNode( local struct LocalBVH2* tree, ushort nodeID )
429{
430    return (tree->nodes[nodeID] & 0x80000000) != 0;
431}
432
433
434GRL_INLINE ushort2 LocalBVH2_GetChildIndices( local struct LocalBVH2* tree, ushort nodeID )
435{
436    ushort idx = ((tree->nodes[nodeID] >> 9) & 0x1FF);
437    return (ushort2)(idx, idx + 1);
438}
439
440GRL_INLINE ushort LocalBVH2_GetSubtreePrimCount( local struct LocalBVH2* tree, ushort node )
441{
442    return tree->nodes[node] & 0x1FF;
443}
444
445GRL_INLINE ushort LocalBVH2_GetLeafPrimStart( local struct LocalBVH2* tree, ushort node )
446{
447    return ((tree->nodes[node] >> 9) & 0x1FF);
448}
449
450
451GRL_INLINE void LocalBVH2_Printf( local struct LocalBVH2* tree )
452{
453    barrier( CLK_LOCAL_MEM_FENCE );
454
455    if ( get_local_id( 0 ) == 0 )
456    {
457        printf( "Nodes: %u\n", tree->num_nodes );
458
459        for ( uint i = 0; i < tree->num_nodes; i++ )
460        {
461            uint num_prims = LocalBVH2_GetSubtreePrimCount( tree, i );
462            printf( "%3u : 0x%08x  %3u 0x%04x ", i, tree->nodes[i], num_prims, LocalBVH2_GetNodeArea(tree,i) );
463            if ( LocalBVH2_IsInnerNode( tree, i ) )
464            {
465                ushort2 kids = LocalBVH2_GetChildIndices( tree, i );
466                printf( " INNER ( %3u %3u )\n", kids.x, kids.y );
467            }
468            else
469            {
470                printf( " LEAF {" );
471                for ( uint j = 0; j < num_prims; j++ )
472                    printf( " %3u ", LocalBVH2_GetLeafPrimStart( tree, i ) + j );
473                printf( "}\n" );
474            }
475        }
476    }
477
478    barrier( CLK_LOCAL_MEM_FENCE );
479}
480
481struct FlatTreeInnerNode
482{
483    uint DW0;                // lower 16b are the index of the corresponding LocalBVH2 node.. Bits 30:16 are per-child atomic completion flags used during refit.  Bit 31 is a leaf marker
484    ushort parent_index;
485    ushort first_child;
486    uchar index_in_parent;
487    uchar num_children;
488
489    //struct DFSPrimRefAABB AABB;
490};
491
492struct FlatTree
493{
494    uint num_nodes;
495    uint qnode_byte_offset; // byte offset from the BVHBase to the flat-tree's first QNode
496    uint qnode_base_index;
497
498    struct FlatTreeInnerNode nodes[DFS_MAX_FLATTREE_NODES];
499    uchar primref_back_pointers[DFS_WG_SIZE];
500};
501
502GRL_INLINE void FlatTree_Printf( local struct FlatTree* flat_tree )
503{
504    barrier( CLK_LOCAL_MEM_FENCE );
505    if ( get_local_id( 0 ) == 0 )
506    {
507        printf( "NumNodes: %u\n", flat_tree->num_nodes );
508        for ( uint i = 0; i < flat_tree->num_nodes; i++ )
509        {
510            ushort bvh2_node = flat_tree->nodes[i].DW0 & 0xffff;
511            printf( "%2u  Parent: %2u  Index_in_parent: %u, NumKids: %u  FirstKid: %3u bvh2: %3u DW0: 0x%x\n",
512                i,
513                flat_tree->nodes[i].parent_index,
514                flat_tree->nodes[i].index_in_parent,
515                flat_tree->nodes[i].num_children,
516                flat_tree->nodes[i].first_child,
517                bvh2_node,
518                flat_tree->nodes[i].DW0 );
519        }
520    }
521    barrier( CLK_LOCAL_MEM_FENCE );
522}
523
524
525
526
527GRL_INLINE ushort FlatTree_GetNodeCount( local struct FlatTree* flat_tree )
528{
529    return flat_tree->num_nodes;
530}
531
532GRL_INLINE uint FlatTree_GetParentIndex( local struct FlatTree* flat_tree, ushort id )
533{
534    return flat_tree->nodes[id].parent_index;
535}
536
537GRL_INLINE ushort FlatTree_GetBVH2Root( local struct FlatTree* flat_tree, ushort node_index )
538{
539    return (flat_tree->nodes[node_index].DW0) & 0xffff;
540}
541
542GRL_INLINE ushort FlatTree_GetNumChildren( local struct FlatTree* flat_tree, ushort node_index )
543{
544    return flat_tree->nodes[node_index].num_children;
545}
546
547GRL_INLINE bool FlatTree_IsLeafNode( local struct FlatTree* flat_tree, ushort node_index )
548{
549    return (flat_tree->nodes[node_index].DW0 & 0x80000000) != 0;
550}
551
552
553GRL_INLINE uint FlatTree_GetQNodeByteOffset( struct FlatTree* flat_tree, ushort node_index )
554{
555    return flat_tree->qnode_byte_offset + node_index * sizeof(struct QBVHNodeN);
556}
557
558GRL_INLINE uint FlatTree_GetQNodeIndex( struct FlatTree* flat_tree, ushort node_index )
559{
560    return flat_tree->qnode_base_index + node_index;
561}
562
563GRL_INLINE void FlatTree_AllocateQNodes( struct FlatTree* flat_tree, struct DFSArgs args )
564{
565    uint node_base = 64*allocate_inner_nodes( args.bvh_base, flat_tree->num_nodes );
566    flat_tree->qnode_base_index  = (node_base - BVH_ROOT_NODE_OFFSET) / sizeof( struct QBVHNodeN );
567    flat_tree->qnode_byte_offset = node_base;
568}
569
570GRL_INLINE ushort FlatTree_GetFirstChild( struct FlatTree* flat_tree, ushort node_index )
571{
572    return flat_tree->nodes[node_index].first_child;
573}
574
575GRL_INLINE ushort FlatTree_GetPrimRefStart( struct FlatTree* flat_tree, ushort node_index )
576{
577    return flat_tree->nodes[node_index].first_child;
578}
579GRL_INLINE ushort FlatTree_GetPrimRefCount( struct FlatTree* flat_tree, ushort node_index )
580{
581    return flat_tree->nodes[node_index].num_children;
582}
583
584GRL_INLINE uint FlatTree_BuildBackPointer( local struct FlatTree* flat_tree, ushort node_index )
585{
586    uint parent_index = flat_tree->nodes[node_index].parent_index + flat_tree->qnode_base_index;
587    parent_index = (parent_index << 6) | (FlatTree_GetNumChildren( flat_tree, node_index ) << 3);
588    return parent_index;
589}
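// Illustrative decode of the backpointer layout produced above (hypothetical helpers,
// compiled out): bits 31:6 hold the parent's QNode index and bits 5:3 the child count.
#if 0
GRL_INLINE uint BackPointer_GetParentIndex( uint bp )  { return bp >> 6; }
GRL_INLINE uint BackPointer_GetNumChildren( uint bp )  { return (bp >> 3) & 0x7; }
#endif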
590
591
592GRL_INLINE void SUBGROUP_FlatTree_Initialize( uniform local struct FlatTree* flat_tree, struct DFSArgs args )
593{
594    if ( get_sub_group_local_id() == 0 )
595    {
596        flat_tree->num_nodes    = 1;
597        flat_tree->nodes[0].DW0 = 0; // point first node at BVH2 root node, which is assumed to be at index zero
598    }
599
600}
601/*
602GRL_INLINE void SUBGROUP_FlatTree_ReduceAndSetAABB( uniform local struct FlatTree* flat_tree,
603                                         uniform ushort node_index,
604                                         varying local struct DFSPrimRefAABB* box )
605{
606    // TODO_OPT: Replace this with an optimized reduction which exploits the fact that we only ever have 6 active lanes
607    //       Try using the "negated max" trick here to compute min/max simultaneously, with max in top 6 lanes
608    //          This will replace 6 reductions with 3
609
610    // TODO_OPT:  This only utilizes up to 6 SIMD lanes.  We can use up to 12 of them by putting
611    //  min into even lanes, and -max into odd lanes, and using a manual min-reduction on pairs of lanes
612
613    struct DFSPrimRefAABB bb = DFSPrimRefAABB_sub_group_reduce( box );
614    if( get_sub_group_local_id() )
615        flat_tree->nodes[node_index].AABB = bb;
616}
617*/
618
619GRL_INLINE void SUBGROUP_FlatTree_CreateInnerNode( uniform local struct FlatTree* flat_tree,
620                                        uniform ushort flat_tree_root,
621                                        varying ushort sg_child_bvh2_root,
622                                        uniform ushort num_children )
623{
624    uniform uint lane = get_sub_group_local_id();
625
626    // increment counter to allocate new nodes.. set required root node fields
627    uniform uint child_base;
628    if ( lane == 0 )
629    {
630        child_base = atomic_add_local( &flat_tree->num_nodes, num_children );
631        flat_tree->nodes[flat_tree_root].first_child  = (uchar) child_base;
632        flat_tree->nodes[flat_tree_root].num_children = num_children;
633
634        // initialize mask bits for this node's live children
635        uint child_mask = ((1 << num_children) - 1) << 16;
636        flat_tree->nodes[flat_tree_root].DW0 |= child_mask;
637    }
638
639    child_base = sub_group_broadcast( child_base, 0 );
640
641    // initialize child nodes
642    if ( lane < num_children )
643    {
644        varying uint child = child_base + lane;
645        flat_tree->nodes[child].DW0 = sg_child_bvh2_root;
646        flat_tree->nodes[child].index_in_parent = lane;
647        flat_tree->nodes[child].parent_index = flat_tree_root;
648    }
649
650}
651
652
653
654GRL_INLINE void SUBGROUP_FlatTree_CreateLeafNode( uniform local struct FlatTree* flat_tree,
655                                       uniform ushort flat_tree_root,
656                                       uniform ushort primref_start,
657                                       uniform ushort num_prims )
658{
659    ushort lane = get_sub_group_local_id();
660    if ( lane < num_prims )
661    {
662        flat_tree->primref_back_pointers[primref_start + lane] = (uchar) flat_tree_root;
663        if ( lane == 0 )
664        {
665            flat_tree->nodes[flat_tree_root].first_child  = (uchar) primref_start;
666            flat_tree->nodes[flat_tree_root].num_children = (uchar) num_prims;
667            flat_tree->nodes[flat_tree_root].DW0 |= 0x80000000;
668        }
669    }
670}
671
672
673GRL_INLINE uniform bool SUBGROUP_FlatTree_SignalRefitComplete( uniform local struct FlatTree* flat_tree, uniform ushort* p_node_index )
674{
675    uniform ushort node_index       = *p_node_index;
676    uniform ushort parent           = flat_tree->nodes[node_index].parent_index;
677    uniform ushort index_in_parent  = flat_tree->nodes[node_index].index_in_parent;
678
679    // clear the corresponding mask bit in the parent node
680    uniform uint child_mask         = (0x10000 << index_in_parent);
681    uniform uint old_mask_bits = 0;
682    if( get_sub_group_local_id() == 0 )
683        old_mask_bits = atomic_xor( &flat_tree->nodes[parent].DW0, child_mask );
684
685    old_mask_bits = sub_group_broadcast( old_mask_bits, 0 );
686
687    // if we cleared the last mask bit, this subgroup proceeds up the tree and refits the next node
688    //  otherwise, it looks for something else to do
689    if ( ((old_mask_bits^child_mask) & 0xffff0000) == 0 )
690    {
691        *p_node_index = parent;
692        return true;
693    }
694
695    return false;
696}
697
698/*
699GRL_INLINE local struct DFSPrimRefAABB* FlatTree_GetChildAABB( local struct FlatTree* flat_tree,
700                                            local struct PrimRefSet* prim_refs,
701                                            ushort node_index, ushort child_index )
702{
703    ushort child_id = FlatTree_GetFirstChild( flat_tree, node_index ) + child_index;
704
705    if( !FlatTree_IsLeafNode( flat_tree, node_index ) )
706        return &flat_tree->nodes[child_id].AABB;
707    else
708        return PrimRefSet_GetAABBPointer( prim_refs, child_id );
709}
710*/
711GRL_INLINE uint FlatTree_GetPrimRefBackPointer( local struct FlatTree* flat_tree, ushort primref_index )
712{
713    return flat_tree->primref_back_pointers[primref_index] * sizeof(struct QBVHNodeN) + flat_tree->qnode_byte_offset;
714}
715
716
717GRL_INLINE void FlatTree_check_boxes(local struct FlatTree* flat_tree,
718    global struct AABB* primref_buffer,
719    local struct AABB3f* boxes,
720    local struct PrimRefMeta* meta )
721
722{
723    barrier(CLK_LOCAL_MEM_FENCE);
724    if (get_local_id(0) == 0)
725    {
726        printf("checking flattree bounds...\n");
727
728        for (uint i = 0; i < flat_tree->num_nodes; i++)
729        {
730            struct AABB rb;
731            rb.lower.xyz = AABB3f_load_lower(&boxes[i]);
732            rb.upper.xyz = AABB3f_load_upper(&boxes[i]);
733
734            uint offs  = FlatTree_GetFirstChild( flat_tree, i );
735            uint count = FlatTree_GetNumChildren( flat_tree, i );
736
737            for (uint c = 0; c < count; c++)
738            {
739                struct AABB lb;
740                if (FlatTree_IsLeafNode( flat_tree, i ))
741                {
742                    lb = primref_buffer[ PrimRefMeta_GetInputIndex( &meta[offs+c] ) ];
743                }
744                else
745                {
746                    lb.lower.xyz = AABB3f_load_lower(&boxes[ offs+c ]);
747                    lb.upper.xyz = AABB3f_load_upper(&boxes[ offs+c ]);
748                }
749
750                if( !AABB_subset( &lb, &rb ) )
751                    printf("Bad bounds!!  child %u of %u   %f : %f  %f : %f %f : %f    %f : %f  %f : %f %f : %f \n",
752                        c, i ,
753                        rb.lower.x, rb.upper.x, rb.lower.y, rb.upper.y, rb.lower.z, rb.upper.z,
754                        lb.lower.x, lb.upper.x, lb.lower.y, lb.upper.y, lb.lower.z, lb.upper.z
755                        );
756            }
757        }
758    }
759    barrier(CLK_LOCAL_MEM_FENCE);
760}
761
762
763struct FlatTreeScheduler
764{
765    int   num_leafs;
766    uint  writeout_produce_count;
767    uint  writeout_consume_count;
768    uint  active_subgroups;
769    uint  num_built_nodes;
770    uint  num_levels;   // number of depth levels in the tree
771
772    //uchar leaf_indices[DFS_MAX_FLATTREE_LEAFS];     // indices of leaf FlatTree nodes to be refitted
773    //uchar writeout_indices[DFS_MAX_FLATTREE_NODES]; // indices of flattree nodes to be written out or collapsed
774
775    ushort level_ordered_nodes[DFS_MAX_FLATTREE_NODES]; // node indices sorted by depth (pre-order, high depth before low depth)
776    ushort level_start[DFS_MAX_FLATTREE_DEPTH]; // first node at given level in the level-ordered node array
777    uint level_count[DFS_MAX_FLATTREE_DEPTH];  // number of nodes at given level
778};
779
780GRL_INLINE void SUBGROUP_FlatTreeScheduler_Initialize( uniform local struct FlatTreeScheduler* scheduler )
781{
782    scheduler->num_built_nodes = 0;
783    scheduler->num_leafs = 0;
784    scheduler->writeout_produce_count = 0;
785    scheduler->writeout_consume_count = 0;
786    scheduler->active_subgroups = DFS_NUM_SUBGROUPS;
787}
788/*
789GRL_INLINE void SUBGROUP_FlatTreeScheduler_QueueLeafForRefit( uniform local struct FlatTreeScheduler* scheduler,
790                                                   uniform ushort leaf )
791{
792    if ( get_sub_group_local_id() == 0 )
793        scheduler->leaf_indices[atomic_inc( &scheduler->num_leafs )] = leaf;
794}*/
795
796GRL_INLINE void SUBGROUP_FlatTreeScheduler_SignalNodeBuilt( uniform local struct FlatTreeScheduler* scheduler, uniform ushort node )
797{
798    if ( get_sub_group_local_id() == 0 )
799        atomic_inc_local( &scheduler->num_built_nodes );
800}
801
802GRL_INLINE uint FlatTreeScheduler_GetNumBuiltNodes( uniform local struct FlatTreeScheduler* scheduler )
803{
804    return scheduler->num_built_nodes;
805}
806
807/*
808GRL_INLINE void SUBGROUP_FlatTreeScheduler_QueueNodeForWriteOut( uniform local struct FlatTreeScheduler* scheduler, uniform ushort node )
809{
810    if ( get_sub_group_local_id() == 0 )
811        scheduler->writeout_indices[atomic_inc( &scheduler->writeout_produce_count )] = node;
812}*/
813
814/*
815GRL_INLINE bool SUBGROUP_FlatTreeScheduler_GetRefitTask( uniform local struct FlatTreeScheduler* scheduler, uniform ushort* leaf_idx )
816{
817    // schedule the leaves in reverse order to ensure that later leaves
818    //   complete before earlier ones.. This prevents contention during the WriteOut stage
819    //
820    // There is a barrier between this function and 'QueueLeafForRefit' so we can safely decrement the same counter
821    //   that we incremented earlier
822    varying int idx = 0;
823    if( get_sub_group_local_id() == 0 )
824        idx = atomic_dec( &scheduler->num_leafs );
825
826    sub_group_barrier( CLK_LOCAL_MEM_FENCE );
827    idx = sub_group_broadcast( idx, 0 );
828
829    if ( idx <= 0 )
830        return false;
831
832    *leaf_idx = scheduler->leaf_indices[idx-1];
833    return true;
834}*/
835
836/*
837// Signal the scheduler that a subgroup has reached the DONE state.
838//  Return true if this is the last subgroup to be done
839void SUBGROUP_FlatTreeScheduler_SubGroupDone( local struct FlatTreeScheduler* scheduler )
840{
841    if ( get_sub_group_local_id() == 0 )
842        atomic_dec( &scheduler->active_subgroups );
843}
844*/
845
846/*
847
848#define STATE_SCHEDULE_REFIT    0x1234
849#define STATE_SCHEDULE_WRITEOUT 0x5679
850#define STATE_REFIT             0xabcd
851#define STATE_WRITEOUT          0xefef
852#define STATE_DONE              0xaabb
853
854// Get a flattree node to write out.  Returns the new scheduler state
855GRL_INLINE ushort SUBGROUP_FlatTreeScheduler_GetWriteOutTask( uniform local struct FlatTreeScheduler* scheduler,
856                                                   uniform ushort num_nodes,
857                                                   uniform ushort* node_idx )
858{
859    uniform ushort return_state = STATE_WRITEOUT;
860    uniform ushort idx = 0;
861    if ( get_sub_group_local_id() == 0 )
862    {
863        idx = atomic_inc( &scheduler->writeout_consume_count );
864
865        if ( idx >= scheduler->writeout_produce_count )
866        {
867            // more consumers than there are produced tasks....
868
869            if ( scheduler->writeout_produce_count == num_nodes )
870            {
871                // if all nodes have been written out, flattening is done
872                return_state = STATE_DONE;
873            }
874            else
875            {
876                // some writeout tasks remain, and have not been produced by refit threads yet
877                //   we need to put this one back
878                atomic_dec( &scheduler->writeout_consume_count );
879                return_state = STATE_SCHEDULE_WRITEOUT;
880            }
881        }
882        else
883        {
884            // scheduled successfully
885            idx = scheduler->writeout_indices[idx];
886        }
887    }
888
889    *node_idx = sub_group_broadcast( idx, 0 );
890    return sub_group_broadcast( return_state, 0 );
891
892}
893*/
894
895
896/*
897GRL_INLINE void FlatTreeScheduler_Printf( local struct FlatTreeScheduler* scheduler )
898{
899    barrier( CLK_LOCAL_MEM_FENCE );
900
901    if ( get_local_id( 0 ) == 0 )
902    {
903        printf( "***SCHEDULER***\n" );
904        printf( "built_nodes=%u  active_sgs=%u  leafs=%u wo_p=%u  wo_c=%u\n", scheduler->num_built_nodes, scheduler->active_subgroups, scheduler->num_leafs,
905            scheduler->writeout_produce_count, scheduler->writeout_consume_count );
906        printf( "leafs for refit: {" );
907
908        int nleaf = max( scheduler->num_leafs, 0 );
909
910        for ( uint i = 0; i < nleaf; i++ )
911            printf( "%u ", scheduler->leaf_indices[i] );
912        printf( "}\n" );
913
914        printf( "writeout queue: %u:%u {", scheduler->writeout_produce_count, scheduler->writeout_consume_count );
915        for ( uint i = 0; i < scheduler->writeout_produce_count; i++ )
916            printf( "%u ", scheduler->writeout_indices[i] );
917        printf( "}\n" );
918    }
919
920    barrier( CLK_LOCAL_MEM_FENCE );
921
922}
923*/
924
925
926GRL_INLINE void SUBGROUP_BuildFlatTreeNode( local struct LocalBVH2* bvh2,
927                                 local struct FlatTree* flat_tree,
928                                 local struct FlatTreeScheduler* scheduler,
929                                 uniform ushort flat_tree_root )
930{
931    varying ushort lane = get_sub_group_local_id();
932    varying ushort bvh2_root = FlatTree_GetBVH2Root( flat_tree, flat_tree_root );
933
934    if ( !LocalBVH2_IsInnerNode( bvh2, bvh2_root ) )
935    {
936        uniform ushort num_prims        = LocalBVH2_GetSubtreePrimCount( bvh2, bvh2_root );
937        uniform ushort primref_start    = LocalBVH2_GetLeafPrimStart( bvh2, bvh2_root );
938
939        SUBGROUP_FlatTree_CreateLeafNode( flat_tree, flat_tree_root, primref_start, num_prims );
940    }
941    else
942    {
943        // collapse BVH2 into BVH6.
944        // We will spread the root node's children across the subgroup, and keep adding SIMD lanes until we have enough
945        uniform ushort num_children = 2;
946
947        uniform ushort2 kids =  LocalBVH2_GetChildIndices( bvh2, bvh2_root );
948        varying ushort sg_bvh2_node = kids.x;
949        if ( lane == 1 )
950            sg_bvh2_node = kids.y;
951
952        do
953        {
954            // choose the inner node with maximum area to replace.
955            // Its left child goes in its old location.  Its right child goes in a new lane
956
957            varying ushort sg_area   = LocalBVH2_GetNodeArea( bvh2, sg_bvh2_node );
958            varying bool sg_is_inner = LocalBVH2_IsInnerNode( bvh2, sg_bvh2_node );
959            sg_area = (sg_is_inner && lane < num_children) ? sg_area : 0; // prevent early exit if the largest child is a leaf
960
961            uniform ushort max_area  = sub_group_reduce_max( sg_area );
962            varying bool sg_reducable = max_area == sg_area && (lane < num_children) && sg_is_inner;
963            uniform uint mask         = intel_sub_group_ballot( sg_reducable );
964
965            // TODO_OPT:  Some of these ops seem redundant.. look at trimming further
966            // TODO_OPT:  sub_group_reduce_max results in too many instructions...... unroll the loop and specialize it..
967            //       or ask IGC to give us a version that declares a static maximum number of subgroups to use
968
969            if ( mask == 0 )
970                break;
971
972            // choose the inner node with maximum area to replace
973            uniform ushort victim_child = ctz( mask );
974            uniform ushort victim_node  = sub_group_broadcast( sg_bvh2_node, victim_child );
975            uniform ushort2 kids        = LocalBVH2_GetChildIndices( bvh2, victim_node );
976
977            if ( lane == victim_child )
978                sg_bvh2_node = kids.x;
979            else if ( lane == num_children )
980                sg_bvh2_node = kids.y;
981
982
983            num_children++;
984
985
986        }while ( num_children < TREE_ARITY );
987
988        SUBGROUP_FlatTree_CreateInnerNode( flat_tree, flat_tree_root, sg_bvh2_node, num_children );
989    }
990
991}
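// Illustrative scalar sketch of the collapse performed by the subgroup loop above
// (compiled out; the helper name and private 'children' array are not part of the builder):
// starting from the two children of the BVH2 root, repeatedly replace the largest-area
// inner node in the child set with its two children until TREE_ARITY children are
// gathered or only leaves remain.
#if 0
GRL_INLINE ushort LocalBVH2_CollapseNode_Scalar( local struct LocalBVH2* bvh2, ushort bvh2_root, ushort* children )
{
    ushort2 kids = LocalBVH2_GetChildIndices( bvh2, bvh2_root );
    children[0] = kids.x;
    children[1] = kids.y;
    ushort num_children = 2;

    while ( num_children < TREE_ARITY )
    {
        // find the inner child with the largest quantized area
        ushort victim = TREE_ARITY;
        ushort best_area = 0;
        for ( ushort i = 0; i < num_children; i++ )
        {
            ushort area = LocalBVH2_GetNodeArea( bvh2, children[i] );
            if ( LocalBVH2_IsInnerNode( bvh2, children[i] ) && area >= best_area )
            {
                best_area = area;
                victim    = i;
            }
        }
        if ( victim == TREE_ARITY )
            break; // all remaining children are leaves

        // replace the victim with its left child and append its right child
        ushort2 vkids = LocalBVH2_GetChildIndices( bvh2, children[victim] );
        children[victim]         = vkids.x;
        children[num_children++] = vkids.y;
    }
    return num_children;
}
#endif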
992
993
994GRL_INLINE void SUBGROUP_DFS_BuildFlatTree( uniform local struct LocalBVH2* bvh2,
995                                 uniform local struct FlatTree* flat_tree,
996                                 uniform local struct FlatTreeScheduler* scheduler
997                                )
998{
999
1000    uniform ushort flat_tree_node_index = get_sub_group_id();
1001    uniform ushort num_nodes     = 1;
1002    uniform ushort num_built     = 0;
1003
1004    uint tid = get_local_id(0);
1005    if (tid < DFS_MAX_FLATTREE_DEPTH)
1006    {
1007        scheduler->level_start[tid] = DFS_MAX_FLATTREE_NODES;
1008        scheduler->level_count[tid] = 0;
1009        scheduler->num_levels = 0;
1010    }
1011
1012    LOOP_TRIPWIRE_INIT;
1013
1014    do
1015    {
1016        // process one flat tree node per sub group, as many as are available
1017        //
1018        //  The first pass will only run one sub-group, the second up to 6, the third up to 36, and so on
1019        //     nodes will be processed in breadth-first order, but they are not guaranteed to be stored in this order
1020        //      due to use of atomic counters for node allocation
1021        //
1022        if ( flat_tree_node_index < num_nodes )
1023        {
1024            SUBGROUP_BuildFlatTreeNode( bvh2, flat_tree, scheduler, flat_tree_node_index );
1025            SUBGROUP_FlatTreeScheduler_SignalNodeBuilt( scheduler, flat_tree_node_index );
1026            flat_tree_node_index += get_num_sub_groups();
1027        }
1028
1029        barrier( CLK_LOCAL_MEM_FENCE );
1030
1031        // bump up the node count if new nodes were created
1032        // stop as soon as all flattree nodes have been processed
1033        num_nodes = FlatTree_GetNodeCount( flat_tree );
1034        num_built = FlatTreeScheduler_GetNumBuiltNodes( scheduler );
1035
1036        barrier( CLK_LOCAL_MEM_FENCE );
1037
1038        LOOP_TRIPWIRE_INCREMENT( 300 );
1039
1040    } while ( num_built < num_nodes );
1041
1042    barrier( CLK_LOCAL_MEM_FENCE );
1043
1044
1045    // determine depth of each node, compute node ranges and counts for each depth level,
1046    //  and prepare a depth-ordered node index array
1047    uint depth = 0;
1048    uint level_pos = 0;
1049    for( uint i=tid; i<num_nodes; i += get_local_size(0) )
1050    {
1051        // compute depth of this node
1052        uint node_index = i;
1053        while ( node_index != 0 )
1054        {
1055            node_index = FlatTree_GetParentIndex( flat_tree, node_index );
1056            depth++;
1057        }
1058
1059        // assign this node a position within its depth level
1060        level_pos = atomic_inc_local( &scheduler->level_count[depth] );
1061
1062        // compute total number of levels
1063        atomic_max_local( &scheduler->num_levels, depth+1 );
1064    }
1065
1066    barrier( CLK_LOCAL_MEM_FENCE );
1067
1068    for( uint i=tid; i<num_nodes; i += get_local_size(0) )
1069    {
1070        // prefix-sum level start positions.  Re-computed for each thread
1071        // TODO:  Hierarchical reduction ??
1072        uint level_start=0;
1073        for( uint d=0; d<depth; d++ )
1074            level_start += scheduler->level_count[d];
1075
1076        scheduler->level_start[depth] = level_start;
1077
1078        // scatter node indices into level-ordered node array
1079        scheduler->level_ordered_nodes[level_start + level_pos] = tid;
1080    }
1081
1082    barrier( CLK_LOCAL_MEM_FENCE );
1083
1084}
1085
1086/*
1087GRL_INLINE bool SUBGROUP_RefitNode( uniform local struct FlatTree* flat_tree,
1088                         uniform local struct PrimRefSet* prim_refs,
1089                         uniform ushort* p_node_index )
1090{
1091
1092    // fetch and reduce child AABBs across the subgroup
1093    uniform ushort node_index = *p_node_index;
1094    uniform ushort num_kids = FlatTree_GetNumChildren( flat_tree, node_index );
1095    varying ushort sg_child_index = (get_sub_group_local_id() < num_kids) ? get_sub_group_local_id() : 0;
1096
1097    varying local struct DFSPrimRefAABB* box = FlatTree_GetChildAABB( flat_tree, prim_refs, node_index, sg_child_index );
1098
1099    SUBGROUP_FlatTree_ReduceAndSetAABB( flat_tree, node_index, box );
1100
1101    if ( node_index == 0 )
1102        return false; // if we just refitted the root, we can stop now
1103
1104    // signal the parent node that this node was refitted.  If this was the last child to be refitted
1105    //    returns true and sets 'node_index' to the parent node, so that this thread can continue refitting
1106    return SUBGROUP_FlatTree_SignalRefitComplete( flat_tree, p_node_index );
1107}*/
1108
1109GRL_INLINE struct QBVHNodeN* qnode_ptr( BVHBase* bvh_mem, uint byte_offset )
1110{
1111    return (struct QBVHNodeN*)(((char*)bvh_mem) + byte_offset);
1112}
1113
1114GRL_INLINE void SUBGROUP_WriteQBVHNode(
1115        uniform local struct FlatTree* flat_tree,
1116        uniform local struct PrimRefMeta* primref_meta,
1117        uniform local struct AABB3f* boxes,
1118        uniform ushort flat_tree_root,
1119        uniform struct DFSArgs args,
1120        uniform local uchar* masks
1121      )
1122{
1123
1124
1125    uniform ushort num_children = FlatTree_GetNumChildren( flat_tree, flat_tree_root );
1126    uniform bool is_leaf        = FlatTree_IsLeafNode( flat_tree, flat_tree_root );
1127
1128    varying ushort lane = get_sub_group_local_id();
1129    varying ushort sg_child_index = (lane < num_children) ? lane : 0;
1130
1131    uniform ushort child_base = FlatTree_GetFirstChild( flat_tree, flat_tree_root );
1132
1133    varying struct AABB sg_box4;
1134    if (FlatTree_IsLeafNode( flat_tree, flat_tree_root ))
1135    {
1136        // fetch AABBs for primrefs
1137        sg_box4 = args.primref_buffer[ PrimRefMeta_GetInputIndex( &primref_meta[child_base + sg_child_index] ) ];
1138
1139    }
1140    else
1141    {
1142        // fetch AABBs for child nodes
1143        sg_box4.lower.xyz = AABB3f_load_lower( &boxes[child_base+sg_child_index] );
1144        sg_box4.upper.xyz = AABB3f_load_upper( &boxes[child_base+sg_child_index] );
1145    }
1146
1147
1148    struct QBVHNodeN* qnode = qnode_ptr( args.bvh_base, FlatTree_GetQNodeByteOffset( flat_tree, flat_tree_root ) );
1149
1150    uniform int offset;
1151    uniform uint child_type;
1152    if ( is_leaf )
1153    {
1154        char* leaf_mem = (char*)BVHBase_GetQuadLeaves( args.bvh_base );
1155
1156        leaf_mem += ( FlatTree_GetPrimRefStart( flat_tree, flat_tree_root )) * args.leaf_size_in_bytes;
1157
1158        offset = (int)(leaf_mem - (char*)qnode);
1159        child_type = args.leaf_node_type;
1160    }
1161    else
1162    {
1163        struct QBVHNodeN* kid = qnode_ptr( args.bvh_base, FlatTree_GetQNodeByteOffset( flat_tree, FlatTree_GetFirstChild( flat_tree, flat_tree_root ) ) );
1164        offset = (int) ((char*)kid - (char*)qnode);
1165        child_type = args.inner_node_type;
1166    }
1167    offset = offset >> 6;   // child offsets are encoded in 64-byte units
1168
1169    if (child_type == NODE_TYPE_INSTANCE)
1170    {
1171        uint instanceMask = PrimRefMeta_GetInstanceMask( &primref_meta[child_base + sg_child_index] );
1172        subgroup_setInstanceQBVHNodeN( offset, &sg_box4, num_children, qnode, lane < num_children ? instanceMask : 0 );
1173    }
1174    else
1175    {
1176        uint mask = BVH_NODE_DEFAULT_MASK;
1177        if( args.need_masks )
1178            mask = masks[flat_tree_root];
1179
1180        subgroup_setQBVHNodeN( offset, child_type, &sg_box4, num_children, qnode, mask );
1181    }
1182
1183    if ( args.need_backpointers )
1184    {
1185        global uint* back_pointers = (global uint*) BVHBase_GetBackPointers( args.bvh_base );
1186        uint idx = FlatTree_GetQNodeIndex( flat_tree, flat_tree_root );
1187        uint bp = FlatTree_BuildBackPointer( flat_tree, flat_tree_root );
1188        back_pointers[idx] = bp;
1189    }
1190
1191    /*
1192    // TODO_OPT:  Eventually this section should also handle leaf splitting due to mixed primref types
1193    //    For now this is done by the leaf creation pipeline, but that path should probably be refactored
1194    //      such that all inner node creation is done in one place
1195
1196    uniform ushort num_children = FlatTree_GetNumChildren( flat_tree, flat_tree_root );
1197    uniform bool is_leaf        = FlatTree_IsLeafNode( flat_tree, flat_tree_root );
1198
1199    varying ushort lane = get_sub_group_local_id();
1200    varying ushort sg_child_index = (lane < num_children) ? lane : 0;
1201
1202    varying local struct DFSPrimRefAABB* sg_box = FlatTree_GetChildAABB( flat_tree, prim_refs, flat_tree_root, sg_child_index );
1203
1204    varying struct AABB sg_box4 = PrimRefSet_ConvertAABB( prim_refs, sg_box );
1205
1206    struct QBVHNodeN* qnode = qnode_ptr( args.bvh_base, FlatTree_GetQNodeByteOffset( flat_tree, flat_tree_root ) );
1207
1208    uniform int offset;
1209    uniform uint child_type;
1210    if ( is_leaf )
1211    {
1212        char* leaf_mem = (char*)BVHBase_GetQuadLeaves( args.bvh_base );
1213
1214        leaf_mem += ( FlatTree_GetPrimRefStart( flat_tree, flat_tree_root )) * args.leaf_size_in_bytes;
1215
1216        offset = (int)(leaf_mem - (char*)qnode);
1217        child_type = args.leaf_node_type;
1218    }
1219    else
1220    {
1221        struct QBVHNodeN* kid = qnode_ptr( args.bvh_base, FlatTree_GetQNodeByteOffset( flat_tree, FlatTree_GetFirstChild( flat_tree, flat_tree_root ) ) );
1222        offset = (int) ((char*)kid - (char*)qnode);
1223        child_type = args.inner_node_type;
1224    }
1225    offset = offset >> 6;
1226
1227    if (child_type == NODE_TYPE_INSTANCE)
1228    {
1229        uint instanceMask = PrimRefSet_GetInstanceMask( prim_refs, FlatTree_GetPrimRefStart(flat_tree, flat_tree_root) + lane );
1230        subgroup_setInstanceQBVHNodeN( offset, &sg_box4, num_children, qnode, lane < num_children ? instanceMask : 0 );
1231    }
1232    else
1233        subgroup_setQBVHNodeN( offset, child_type, &sg_box4, num_children, qnode );
1234
1235    if ( args.need_backpointers )
1236    {
1237        global uint* back_pointers = (global uint*) BVHBase_GetBackPointers( args.bvh_base );
1238        uint idx = FlatTree_GetQNodeIndex( flat_tree, flat_tree_root );
1239        uint bp = FlatTree_BuildBackPointer( flat_tree, flat_tree_root );
1240        back_pointers[idx] = bp;
1241    }
1242    */
1243}
1244
1245/*
1246GRL_INLINE void SUBGROUP_DFS_RefitAndWriteOutFlatTree(
1247    uniform local struct FlatTree* flat_tree,
1248    uniform local struct PrimRefSet* prim_refs,
1249    uniform local struct FlatTreeScheduler* scheduler,
1250    uniform struct DFSArgs args)
1251{
1252
1253    uniform ushort state = STATE_SCHEDULE_REFIT;
1254    uniform ushort node_index = 0;
1255    uniform ushort num_nodes = FlatTree_GetNodeCount(flat_tree);
1256
1257    {
1258        LOOP_TRIPWIRE_INIT;
1259
1260        bool active = true;
1261        bool continue_refit = false;
1262        while (1)
1263        {
1264            if (active)
1265            {
1266                if (continue_refit || SUBGROUP_FlatTreeScheduler_GetRefitTask(scheduler, &node_index))
1267                {
1268                    continue_refit = SUBGROUP_RefitNode(flat_tree, prim_refs, &node_index);
1269                }
1270                else
1271                {
1272                    active = false;
1273                    if (get_sub_group_local_id() == 0)
1274                        atomic_dec(&scheduler->active_subgroups);
1275
1276                    sub_group_barrier(CLK_LOCAL_MEM_FENCE);
1277                }
1278            }
1279
1280            barrier(CLK_LOCAL_MEM_FENCE); // finish all atomics
1281            if (scheduler->active_subgroups == 0)
1282                break;
1283            barrier(CLK_LOCAL_MEM_FENCE); // finish all checks.. prevent race between thread which loops around and thread which doesn't
1284
1285            LOOP_TRIPWIRE_INCREMENT(200);
1286        }
1287    }
1288
1289    for (uint i = get_sub_group_id(); i < num_nodes; i += get_num_sub_groups())
1290        SUBGROUP_WriteQBVHInnerNodes(flat_tree, prim_refs, i, args);
1291
1292    barrier(CLK_LOCAL_MEM_FENCE);
1293
1294
1295    // JDB:  Version below attempts to interleave refit and qnode write-out
1296    //  This could theoretically reduce thread idle time, but it is more complex and does more atomics for scheduling
1297
1298#if 0
1299    // after we've constructed the flat tree (phase 1), there are two things that need to happen:
1300    //   PHASE 2:  Refit the flat tree, computing all of the node ABBs
1301    //   PHASE 3:  Write the nodes out to memory
1302    //
1303    //  all of this is sub-group centric.  Different subgroups can execute phases 2 and 3 concurrently
1304    //
1305
1306    // TODO_OPT:  The scheduling algorithm might need to be re-thought.
1307    //  Fused EUs are very hard to reason about.   It's possible that by scheduling independent
1308    //  SGs in this way we would lose a lot of performance due to fused EU serialization.
1309    //     Needs to be tested experimentally if such a thing is possible
1310
1311    uniform ushort state = STATE_SCHEDULE_REFIT;
1312    uniform ushort node_index = 0;
1313    uniform ushort num_nodes = FlatTree_GetNodeCount(flat_tree);
1314
1315    LOOP_TRIPWIRE_INIT;
1316
1317    do
1318    {
1319        // barrier necessary to protect access to scheduler->active_subgroups
1320        barrier(CLK_LOCAL_MEM_FENCE);
1321
1322        if (state == STATE_SCHEDULE_REFIT)
1323        {
1324            if (SUBGROUP_FlatTreeScheduler_GetRefitTask(scheduler, &node_index))
1325                state = STATE_REFIT;
1326            else
1327                state = STATE_SCHEDULE_WRITEOUT; // fallthrough
1328        }
1329        if (state == STATE_SCHEDULE_WRITEOUT)
1330        {
1331            state = SUBGROUP_FlatTreeScheduler_GetWriteOutTask(scheduler, num_nodes, &node_index);
1332            if (state == STATE_DONE)
1333                SUBGROUP_FlatTreeScheduler_SubGroupDone(scheduler);
1334        }
1335
1336
1337        // A barrier is necessary to ensure that 'QueueNodeForWriteOut' is synchronized with 'GetWriteOutTask'
1338        //  Note that in theory we could have the write-out tasks spin until the refit tasks clear, which would make this barrier unnecessary
1339        //   However, we cannot do this safely on SKUs which do not support independent subgroup forward progress.
1340        barrier(CLK_LOCAL_MEM_FENCE);
1341
1342        if (state == STATE_REFIT)
1343        {
1344            uniform ushort prev_node = node_index;
1345            uniform bool continue_refit = SUBGROUP_RefitNode(flat_tree, prim_refs, &node_index);
1346
1347            SUBGROUP_FlatTreeScheduler_QueueNodeForWriteOut(scheduler, prev_node);
1348
1349            if (!continue_refit)
1350                state = STATE_SCHEDULE_REFIT;
1351        }
1352        else if (state == STATE_WRITEOUT)
1353        {
1354            SUBGROUP_WriteQBVHInnerNodes(flat_tree, prim_refs, node_index, args);
1355            state = STATE_SCHEDULE_WRITEOUT;
1356        }
1357        // A barrier is necessary to ensure that 'QueueNodeForWriteOut' is synchronized with 'GetWriteOutTask'
1358        barrier(CLK_LOCAL_MEM_FENCE);
1359
1360        LOOP_TRIPWIRE_INCREMENT(200);
1361
1362    } while (scheduler->active_subgroups > 0);
1363
1364#endif
1365}
1366*/
1367
1368GRL_INLINE void DFS_CreatePrimRefSet( struct DFSArgs args,
1369                           local struct PrimRefSet* prim_refs )
1370{
1371    ushort id = get_local_id( 0 );
1372    ushort num_primrefs = args.num_primrefs;
1373
1374
1375    PrimRef ref;
1376    struct AABB3f local_aabb;
1377    if ( id < num_primrefs )
1378    {
1379        ref = args.primref_buffer[id];
1380        AABB3f_set_lower( &local_aabb, ref.lower.xyz );
1381        AABB3f_set_upper( &local_aabb, ref.upper.xyz );
1382    }
1383    else
1384    {
1385        AABB3f_init( &local_aabb );
1386    }
1387
1388    AABB3f_atomic_merge_localBB_nocheck( &prim_refs->root_aabb, &local_aabb );
1389
1390    barrier( CLK_LOCAL_MEM_FENCE );
1391
1392    if ( id < num_primrefs )
1393        PrimRefSet_SetPrimRef_FullPrecision( prim_refs, ref, id );
1394}
1395
1396
1397
1398struct BVHBuildLocals
1399{
1400    float  Al[DFS_WG_SIZE];
1401    float  Ar[DFS_WG_SIZE];
1402    uchar2 axis_and_left_count[ DFS_WG_SIZE ];
1403    uint   sah[DFS_WG_SIZE];
1404    uint   num_active_threads;
1405};
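// Illustrative sketch (hypothetical helpers, compiled out) of the packed 32-bit split
// key that DFS_ConstructBVH2 below feeds to atomic_min so each subtree selects a unique,
// deterministic winner:  bits 31:10 hold the upper bits of as_uint(sah), bits 9:8 hold
// (2 - axis) so that ties prefer the larger axis, and bits 7:0 hold the primref position.
#if 0
GRL_INLINE uint DFS_PackSplitKey( float sah, ushort axis, ushort primref_position )
{
    return (as_uint( sah ) & ~1023u) | ((uint)(2 - axis) << 8) | primref_position;
}
GRL_INLINE ushort DFS_UnpackSplitIndex( uint key ) { return key & 255; }
#endif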
1406
1407
1408GRL_INLINE void DFS_ConstructBVH2( local struct LocalBVH2* bvh2,
1409                        local struct PrimRefSet* prim_refs,
1410                        ushort num_prims,
1411                        local struct BVHBuildLocals* locals )
1412{
1413    ushort tid = get_local_id( 0 );
1414
1415    ushort bvh2_root         = 0;
1416    ushort prim_range_start  = 0;
1417    ushort primref_position = tid;
1418
1419    bool active_thread       = tid < num_prims;
1420    float root_area  = PrimRefSet_GetMaxAABBArea( prim_refs );
1421    float area_scale = DFS_BVH2_AREA_QUANT / root_area;
1422
1423    locals->num_active_threads = num_prims;
1424    barrier( CLK_LOCAL_MEM_FENCE );
1425
1426    LOOP_TRIPWIRE_INIT;
1427
1428    do
1429    {
1430        if(active_thread && prim_range_start == primref_position)
1431            locals->sah[primref_position] = UINT_MAX;
1432
1433        if ( active_thread )
1434        {
1435            local struct DFSPrimRefAABB* my_box = PrimRefSet_GetAABBPointer( prim_refs, primref_position );
1436
1437            // each thread evaluates a possible split candidate.  Scan primrefs and compute sah cost
1438            //  do this axis-by-axis to keep register pressure low
1439            float best_sah    = INFINITY;
1440            ushort best_axis  = 3;
1441            ushort best_count = 0;
1442            float best_al     = INFINITY;
1443            float best_ar     = INFINITY;
1444
1445            struct DFSPrimRefAABB box_left[3];
1446            struct DFSPrimRefAABB box_right[3];
1447            float CSplit[3];
1448            ushort count_left[3];
1449
1450            for ( ushort axis = 0; axis < 3; axis++ )
1451            {
1452                DFSPrimRefAABB_init( &box_left[axis] );
1453                DFSPrimRefAABB_init( &box_right[axis] );
1454
1455                CSplit[axis] = my_box->lower[axis] + my_box->upper[axis];
1456                count_left[axis] = 0;
1457            }
1458
1459            // scan primrefs in our subtree and partition using this thread's prim as a split plane
1460            {
1461                struct DFSPrimRefAABB box = *PrimRefSet_GetAABBPointer( prim_refs, prim_range_start );
1462
1463                for ( ushort p = 1; p < num_prims; p++ )
1464                {
1465                        struct DFSPrimRefAABB next_box = *PrimRefSet_GetAABBPointer( prim_refs, prim_range_start + p ); //preloading box for next iteration
1466
1467                        for( ushort axis = 0; axis < 3; axis++ )
1468                        {
1469                            float c = box.lower[axis] + box.upper[axis];
1470
1471                            if ( c < CSplit[axis] )
1472                            {
1473                                // this primitive is to our left.
1474                                DFSPrimRefAABB_extend( &box_left[axis], &box );
1475                                count_left[axis]++;
1476                            }
1477                            else
1478                            {
1479                                // this primitive is to our right
1480                                DFSPrimRefAABB_extend( &box_right[axis], &box );
1481                            }
1482                        }
1483
1484                        box = next_box;
1485                }
1486
1487                // last iteration without preloading box
1488                for( ushort axis = 0; axis < 3; axis++ )
1489                {
1490                    float c = box.lower[axis] + box.upper[axis];
1491
1492                    if ( c < CSplit[axis] )
1493                    {
1494                        // this primitive is to our left.
1495                        DFSPrimRefAABB_extend( &box_left[axis], &box );
1496                        count_left[axis]++;
1497                    }
1498                    else
1499                    {
1500                        // this primitive is to our right
1501                        DFSPrimRefAABB_extend( &box_right[axis], &box );
1502                    }
1503                }
1504            }
1505
1506            for ( ushort axis = 0; axis < 3; axis++ )
1507            {
1508                float Al = DFSPrimRefAABB_halfArea( &box_left[axis]  );
1509                float Ar = DFSPrimRefAABB_halfArea( &box_right[axis] );
1510
1511                // Avoid NANs in SAH calculation in the corner case where all prims go right
1512                //  In this case we set Al=Ar, because such a split will only be selected if all primrefs
1513        //    are coincident.  In that case, we will fall back to split-in-the-middle and both subtrees
1514                //    should store the same quantized area value
1515                if ( count_left[axis] == 0 )
1516                    Al = Ar;
1517
1518                // compute sah cost
1519                ushort count_right = num_prims - count_left[axis];
1520                float sah = Ar * count_right + Al * count_left[axis];
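                // e.g. (illustrative numbers): num_prims=8, count_left=3, Al=0.2, Ar=0.5
                //      sah = 0.5*(8-3) + 0.2*3 = 2.5 + 0.6 = 3.1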
1521
1522                // keep this split if it is better than the previous one, or if the previous one was a corner-case
1523                if ( sah < best_sah || best_count == 0 )
1524                {
1525                    // yes, keep it
1526                    best_axis   = axis;
1527                    best_sah    = sah;
1528                    best_count  = count_left[axis];
1529                    best_al     = Al;
1530                    best_ar     = Ar;
1531                }
1532            }
1533
1534
1535            // write split information to SLM
1536            locals->Al[primref_position]             = best_al;
1537            locals->Ar[primref_position]             = best_ar;
1538            locals->axis_and_left_count[primref_position].x = best_axis;
1539            locals->axis_and_left_count[primref_position].y = best_count;
1540
1541            uint sah = as_uint(best_sah);
1542            // break ties by axis to ensure deterministic split selection
1543        //  otherwise the builder can produce a non-deterministic tree structure from run to run
1544            //  based on the ordering of primitives (which can vary due to non-determinism in atomic counters)
1545            // Embed split axis and index into sah value; compute min over sah and max over axis
1546            sah = ( ( sah & ~1023 ) | ( 2 - best_axis ) << 8 | primref_position );
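            // resulting key layout: bits [31:10] = truncated SAH bits, [9:8] = (2 - axis), [7:0] = primref position.
            // Positions fit in 8 bits because a workgroup holds at most 256 primrefs, and non-negative floats
            // compare correctly when reinterpreted as unsigned integers.  atomic_min therefore selects the
            // lowest SAH first, then the largest axis, then the lowest position.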
1547
1548            // reduce on split candidates in our local subtree and decide the best one
1549            atomic_min_local( &locals->sah[ prim_range_start ], sah);
1550        }
1551
1552        barrier( CLK_LOCAL_MEM_FENCE );
1553
1554        ushort split_index      = locals->sah[ prim_range_start ] & 255;
1555        ushort split_axis       = locals->axis_and_left_count[split_index].x;
1556        ushort split_left_count = locals->axis_and_left_count[split_index].y;
1557        float split_al          = locals->Al[split_index];
1558        float split_ar          = locals->Ar[split_index];
1559
1560        if ( (primref_position == prim_range_start) && active_thread )
1561        {
1562            // first thread in a given subtree creates the inner node
1563            ushort quantized_left_area  = convert_ushort_rtn( split_al * area_scale );
1564            ushort quantized_right_area = convert_ushort_rtn( split_ar * area_scale );
1565            ushort start_left  = prim_range_start;
1566            ushort start_right = prim_range_start + split_left_count;
1567            if ( split_left_count == 0 )
1568                start_right = start_left + (num_prims / 2); // handle split-in-the-middle case
1569
1570            LocalBVH2_CreateInnerNode( bvh2, bvh2_root,
1571                                      start_left, start_right,
1572                                      quantized_left_area, quantized_right_area );
1573        }
1574
1575        barrier( CLK_LOCAL_MEM_FENCE );
1576
1577        struct DFSPrimRef ref;
1578        ushort new_primref_position;
1579
1580        if ( active_thread )
1581        {
1582            ushort2 kids = LocalBVH2_GetChildIndices( bvh2, bvh2_root );
1583            bool go_left;
1584
1585            if ( split_left_count == 0 )
1586            {
1587                // We chose a split with no left-side prims
1588                //  This will only happen if all primrefs are located in the exact same position
1589                //   In that case, fall back to split-in-the-middle
1590                split_left_count = (num_prims / 2);
1591                go_left = (primref_position - prim_range_start < split_left_count);
1592            }
1593            else
1594            {
1595                // determine what side of the split this thread's primref belongs on
1596                local struct DFSPrimRefAABB* my_box     = PrimRefSet_GetAABBPointer( prim_refs, primref_position );
1597                local struct DFSPrimRefAABB* split_box  = PrimRefSet_GetAABBPointer( prim_refs, split_index );
1598                float c = my_box->lower[split_axis] + my_box->upper[split_axis];
1599                float Csplit = split_box->lower[split_axis] + split_box->upper[split_axis];
1600                go_left = c < Csplit;
1601            }
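            // note: this is the same centroid test used when the split was scored, so exactly
            // split_left_count primrefs take the left branch here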
1602
1603            // adjust state variables for next loop iteration
1604            bvh2_root                    = (go_left) ? kids.x : kids.y;
1605            num_prims                    = (go_left) ? split_left_count : (num_prims - split_left_count);
1606            prim_range_start             = (go_left) ? prim_range_start : prim_range_start + split_left_count;
1607
1608            // determine the new primref position by incrementing a counter in the destination subtree
1609            new_primref_position = prim_range_start + LocalBVH2_IncrementPrimCount( bvh2, bvh2_root );
1610
1611            // load our primref from its previous position
1612            ref = PrimRefSet_GetPrimRef( prim_refs, primref_position );
1613        }
1614
1615        barrier( CLK_LOCAL_MEM_FENCE );
1616
1617        if ( active_thread )
1618        {
1619            // write our primref into its sorted position
1620            PrimRefSet_SetPrimRef( prim_refs, ref, new_primref_position );
1621            primref_position = new_primref_position;
1622
1623            // deactivate all threads whose subtrees are small enough to form a leaf
1624            if ( num_prims <= TREE_ARITY )
1625            {
1626                active_thread = false;
1627                atomic_dec_local( &locals->num_active_threads );
1628            }
1629        }
1630
1631        barrier( CLK_LOCAL_MEM_FENCE );
1632
1633        LOOP_TRIPWIRE_INCREMENT( 50 );
1634
1635
1636    } while ( locals->num_active_threads > 0 );
1637
1638
1639}
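
// Illustrative only (not compiled into the build): a minimal scalar sketch of the candidate scoring used in the
// workgroup loop above, assuming the candidate plane is each primitive's own centroid and the cost is
// half-area * primitive count.  'Box', 'box_init', 'box_extend' and 'box_half_area' are hypothetical helpers
// invented for this sketch; they are not the kernel's DFSPrimRefAABB utilities.
#if 0
typedef struct { float lower[3], upper[3]; } Box;

static void box_init( Box* b )
{
    for (int a = 0; a < 3; a++) { b->lower[a] = INFINITY; b->upper[a] = -INFINITY; }
}

static void box_extend( Box* a, const Box* b )
{
    for (int i = 0; i < 3; i++)
    {
        a->lower[i] = fmin( a->lower[i], b->lower[i] );
        a->upper[i] = fmax( a->upper[i], b->upper[i] );
    }
}

static float box_half_area( const Box* b )
{
    float dx = b->upper[0] - b->lower[0];
    float dy = b->upper[1] - b->lower[1];
    float dz = b->upper[2] - b->lower[2];
    return dx*dy + dy*dz + dz*dx;
}

// Score one candidate: primitive 'candidate' supplies the split plane along 'axis'.
// Every primitive whose (doubled) centroid lies strictly left of the candidate's goes left;
// the candidate itself always lands on the right side.
static float sweep_sah_cost( const Box* prims, int n, int candidate, int axis, int* count_left_out )
{
    Box left, right;
    box_init( &left );
    box_init( &right );
    int count_left = 0;

    float csplit = prims[candidate].lower[axis] + prims[candidate].upper[axis];
    for (int p = 0; p < n; p++)
    {
        float c = prims[p].lower[axis] + prims[p].upper[axis];
        if (c < csplit) { box_extend( &left, &prims[p] );  count_left++; }
        else            { box_extend( &right, &prims[p] ); }
    }

    *count_left_out = count_left;
    if (count_left == 0)   // all primrefs coincident: the kernel falls back to split-in-the-middle
        return box_half_area( &right ) * n;

    return box_half_area( &left ) * count_left + box_half_area( &right ) * (n - count_left);
}
#endif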
1640
1641
1642
1643// fast path for #prims <= TREE_ARITY
1644GRL_INLINE void Trivial_DFS( struct DFSArgs args )
1645{
1646
1647    ushort tid = get_local_id( 0 );
1648
1649    PrimRef myRef;
1650    AABB_init( &myRef );
1651    if( tid < args.num_primrefs )
1652        myRef = args.primref_buffer[tid];
1653
1654    uint node_offset;
1655    if ( tid == 0 )
1656        node_offset = 64*allocate_inner_nodes( args.bvh_base, 1 );
1657    node_offset = sub_group_broadcast(node_offset,0);
1658
1659    char* bvh_mem = (char*) args.bvh_base;
1660    struct QBVHNodeN* qnode  = (struct QBVHNodeN*) (bvh_mem + node_offset);
1661
1662    uint child_type = args.leaf_node_type;
1663    uint prim_base  = args.bvh_base->quadLeafStart * 64;
1664
1665    char* leaf_mem = bvh_mem + prim_base;
1666    int offset = (int)( leaf_mem  - (char*)qnode );
1667
1668    if (child_type == NODE_TYPE_INSTANCE)
1669    {
1670        subgroup_setInstanceQBVHNodeN( offset >> 6, &myRef, args.num_primrefs, qnode, tid < args.num_primrefs ? PRIMREF_instanceMask(&myRef) : 0  );
1671    }
1672    else
1673        subgroup_setQBVHNodeN( offset >> 6, child_type, &myRef, args.num_primrefs, qnode, BVH_NODE_DEFAULT_MASK );
1674
1675    if ( tid < args.num_primrefs )
1676    {
1677        global uint* primref_back_pointers = args.primref_index_buffer + args.num_primrefs;
1678        uint bp = node_offset;
1679
1680        // TODO_OPT:  The leaf creation pipeline could be simplified by keeping a sideband buffer containing the
1681        //    fatleaf index + position within the fatleaf for each primref, instead of forcing the leaf creation shader to reconstruct it.
1682        //    We should probably also do the fat-leaf splitting here.
1683        args.primref_buffer[tid]        = myRef;
1684        args.primref_index_buffer[tid]  = tid;
1685
1686        primref_back_pointers[tid] = bp / sizeof(struct QBVHNodeN);
1687
1688        if ( tid == 0 && args.need_backpointers )
1689        {
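            // root backpointer: as packed here, bits [31:6] hold the parent node index (all ones marks the root,
            // which has no parent) and bits [5:3] hold the child count.  This layout is inferred from the
            // packing below; the remaining low bits are left clear.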
1690            uint bp = ((uint)-1) << 6;
1691            bp |= (args.num_primrefs) << 3;
1692            *(InnerNode_GetBackPointer(BVHBase_GetBackPointers( args.bvh_base ),0)) = bp;
1693        }
1694    }
1695}
1696
1697
1698
1699
1700
1701void SUBGROUP_DFS_ComputeFlatTreeBoxesAndMasks( uniform local struct FlatTree* flat_tree,
1702                                                uniform local struct FlatTreeScheduler* flat_scheduler,
1703                                                uniform local struct AABB3f* boxes,
1704                                                uniform local struct PrimRefMeta* primref_meta,
1705                                                uniform global struct AABB* primref_buffer,
1706                                                uniform local uchar* masks,
1707                                                bool need_masks )
1708
1709{
1710    uniform int num_levels = (int) flat_scheduler->num_levels;
1711    varying ushort lane = get_sub_group_local_id();
1712
1713    // iterate over depth levels in the tree... deepest to shallowest
1714    for (uniform int level = num_levels - 1; level >= 0; level--)
1715    {
1716        // loop over a range of flattree nodes at this level, one node per sub-group
1717        // TODO_OPT:  Try enabling this code to process two nodes in a SIMD16 subgroup
1718        uniform ushort level_start      = flat_scheduler->level_start[level];
1719        uniform ushort level_node_count = flat_scheduler->level_count[level];
1720
1721        for (uniform ushort i = get_sub_group_id(); i < level_node_count; i += get_num_sub_groups())
1722        {
1723            uniform ushort node_index = flat_scheduler->level_ordered_nodes[ level_start + i ];
1724
1725            varying struct AABB box;
1726            AABB_init(&box);
1727
1728            uniform uint child_base   = FlatTree_GetFirstChild( flat_tree, node_index );
1729            uniform uint num_children = FlatTree_GetNumChildren( flat_tree, node_index );
1730            varying uint child_index  = child_base + ((lane<num_children)?lane : 0);
1731
1732            varying uint mask = 0xff;
1733            if (FlatTree_IsLeafNode( flat_tree, node_index ))
1734            {
1735                // fetch AABBs for primrefs
1736                box = primref_buffer[ PrimRefMeta_GetInputIndex( &primref_meta[child_index] ) ];
1737                if( need_masks )
1738                    mask = PRIMREF_instanceMask(&box);
1739            }
1740            else
1741            {
1742                // fetch AABBs for child nodes
1743                box.lower.xyz = AABB3f_load_lower( &boxes[child_index] );
1744                box.upper.xyz = AABB3f_load_upper( &boxes[child_index] );
1745                if ( need_masks )
1746                    mask = masks[child_index];
1747            }
1748
1749
1750            // reduce and write box
1751            box = AABB_sub_group_reduce_N6( &box );
1752            if( lane == 0 )
1753                AABB3f_set( &boxes[node_index], box.lower.xyz, box.upper.xyz );
1754
1755            if( need_masks )
1756            {
1757                mask = sub_group_reduce_or_N6(mask);
1758                masks[node_index] = mask;
1759            }
1760
1761        }
1762
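        // one barrier per level: all boxes/masks written at this (deeper) level must be visible in SLM
        // before the next, shallower level reads them as child data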
1763        barrier( CLK_LOCAL_MEM_FENCE );
1764    }
1765}
1766
1767
1768void SUBGROUP_DFS_WriteNodes(
1769    uniform local struct FlatTree* flat_tree,
1770    uniform local struct AABB3f* boxes,
1771    uniform local struct PrimRefMeta* primref_meta,
1772    uniform struct DFSArgs args,
1773    uniform local uchar* masks
1774    )
1775
1776{
1777    uniform uint num_nodes = FlatTree_GetNodeCount(flat_tree);
1778
1779    for ( uniform uint i = get_sub_group_id(); i < num_nodes; i += get_num_sub_groups() )
1780    {
1781        SUBGROUP_WriteQBVHNode( flat_tree, primref_meta, boxes, i, args, masks );
1782    }
1783
1784}
1785
1786
1787
1788
1789struct Single_WG_build_SLM
1790{
1791    struct FlatTree           flat_tree;
1792    struct FlatTreeScheduler  flat_scheduler;
1793    struct PrimRefMeta primitive_meta[DFS_WG_SIZE];
1794
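    // The two union members cover the two phases of the build: s1 (primref set + BVH2 + build locals) is only
    // needed up through flattree construction, while s2 (boxes + masks) is only needed afterwards for box/mask
    // propagation and node writing.  Overlaying them presumably keeps the per-workgroup SLM footprint within budget.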
1795    union
1796    {
1797        struct{
1798            struct PrimRefSet         prim_refs;
1799            struct LocalBVH2          bvh2;
1800            struct BVHBuildLocals     bvh2_locals;
1801        } s1;
1802
1803        struct {
1804            struct AABB3f boxes[DFS_MAX_FLATTREE_NODES];
1805            uchar masks[DFS_MAX_FLATTREE_NODES];
1806        } s2;
1807    } u;
1808
1809};
1810
1811
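//
// Single-workgroup DFS build.  Overview of the phases below:
//   1. load the primrefs into SLM
//   2. build a binary BVH2 over them by repeated SAH partitioning
//   3. collapse the BVH2 into a flattree of QBVH6-arity nodes
//   4. compute flattree boxes/masks bottom-up and write the QBVH6 nodes
//   5. emit the sorted primref index buffer and backpointers for the leaf creation pipeline
//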
1812GRL_INLINE void execute_single_WG_build(
1813        struct DFSArgs args,
1814        local struct Single_WG_build_SLM* slm
1815    )
1816{
1817
1818    ushort tid = get_local_id( 0 );
1819
1820    //
1821    // Initialize the various SLM structures.  Different sub-groups take different init paths.
1822    //    NOTE: only even-numbered subgroups are used here to avoid the fused-EU serialization bug
1823    //
1824    if ( get_sub_group_id() == 0 )
1825        SUBGROUP_FlatTree_Initialize( &slm->flat_tree, args );
1826    else if ( get_sub_group_id() == 2 )
1827        SUBGROUP_LocalBVH2_Initialize( &slm->u.s1.bvh2, args.num_primrefs );
1828    else if ( get_sub_group_id() == 4 )
1829        SUBGROUP_FlatTreeScheduler_Initialize( &slm->flat_scheduler );
1830    else if ( get_sub_group_id() == 6 )
1831        SUBGROUP_PrimRefSet_Initialize( &slm->u.s1.prim_refs );
1832
1833    barrier( CLK_LOCAL_MEM_FENCE );
1834
1835    // load the PrimRefs
1836    DFS_CreatePrimRefSet( args, &slm->u.s1.prim_refs );
1837
1838    // build the BVH2
1839    DFS_ConstructBVH2( &slm->u.s1.bvh2, &slm->u.s1.prim_refs, args.num_primrefs, &slm->u.s1.bvh2_locals );
1840
1841    // copy out metadata for primrefs now that they have been sorted
1842    if( tid < args.num_primrefs )
1843    {
1844        slm->primitive_meta[tid] = PrimRefSet_GetMeta( &slm->u.s1.prim_refs, tid );
1845    }
1846    barrier( CLK_LOCAL_MEM_FENCE );
1847
1848    // collapse into a FlatTree
1849    SUBGROUP_DFS_BuildFlatTree( &slm->u.s1.bvh2, &slm->flat_tree, &slm->flat_scheduler );
1850
1851    // allocate output QBVH6 nodes
1852    if ( get_local_id( 0 ) == 0 )
1853        FlatTree_AllocateQNodes( &slm->flat_tree, args );
1854
1855    barrier( CLK_LOCAL_MEM_FENCE );
1856
1857    SUBGROUP_DFS_ComputeFlatTreeBoxesAndMasks( &slm->flat_tree, &slm->flat_scheduler, &slm->u.s2.boxes[0], slm->primitive_meta, args.primref_buffer, slm->u.s2.masks, args.need_masks );
1858
1859    //FlatTree_Printf( &slm->flat_tree );
1860    //FlatTree_check_boxes ( &slm->flat_tree, args.primref_buffer, &slm->u.s2.boxes[0], slm->primitive_meta );
1861
1862    SUBGROUP_DFS_WriteNodes( &slm->flat_tree, &slm->u.s2.boxes[0], slm->primitive_meta, args, slm->u.s2.masks );
1863
1864
1865    // generate the sorted primref index buffer and backpointers to feed the leaf creation pipeline
1866    if ( tid < args.num_primrefs )
1867    {
1868        uint input_index = PrimRefMeta_GetInputIndex(&slm->primitive_meta[tid]);
1869
1870        uint bp = FlatTree_GetPrimRefBackPointer( &slm->flat_tree, tid );
1871        global uint* primref_back_pointers = args.primref_index_buffer + args.num_primrefs;
1872
1873        args.primref_index_buffer[tid] = input_index;
1874
1875        primref_back_pointers[tid] = bp / sizeof(struct QBVHNodeN);
1876
1877        if ( tid == 0 && args.need_backpointers  )
1878        {
1879            *(InnerNode_GetBackPointer(BVHBase_GetBackPointers( args.bvh_base ),0)) |= ((uint)-1) << 6;
1880        }
1881    }
1882}
1883
1884
1885
1886
1887GRL_ANNOTATE_IGC_DO_NOT_SPILL
1888__attribute__( (reqd_work_group_size( DFS_WG_SIZE, 1, 1 )) )
1889__attribute__( (intel_reqd_sub_group_size( 16 )) )
1890kernel void DFS( global struct Globals* globals,
1891                 global char* bvh_mem,
1892                 global PrimRef* primref_buffer,
1893                 global uint* primref_index_buffer,
1894                 uint alloc_backpointers
1895                 )
1896{
1897    struct DFSArgs args;
1898    args.bvh_base             = (global struct BVHBase*) bvh_mem;
1899    args.leaf_node_type       = globals->leafPrimType;
1900    args.inner_node_type      = NODE_TYPE_INTERNAL;
1901    args.leaf_size_in_bytes   = globals->leafSize;
1902    args.primref_buffer       = primref_buffer;
1903    args.need_backpointers    = alloc_backpointers != 0;
1904    args.num_primrefs         = globals->numPrimitives;
1905    args.primref_index_buffer = primref_index_buffer;
1906    args.need_masks           = args.leaf_node_type == NODE_TYPE_INSTANCE;
1907
1908    if ( args.num_primrefs <= TREE_ARITY )
1909    {
1910        // TODO_OPT: This decision should be made using indirect dispatch
1911        if( get_sub_group_id() == 0 )
1912            Trivial_DFS( args );
1913        return;
1914    }
1915
1916    local struct Single_WG_build_SLM slm;
1917
1918    execute_single_WG_build( args, &slm );
1919}
1920
1921
1922
1923
1924GRL_ANNOTATE_IGC_DO_NOT_SPILL
1925__attribute__( (reqd_work_group_size( DFS_WG_SIZE, 1, 1 )) )
1926__attribute__( (intel_reqd_sub_group_size( 16 )) )
1927kernel void DFS_single_wg(
1928    global struct Globals* globals,
1929    global char* bvh_mem,
1930    global PrimRef* primref_buffer,
1931    global uint* primref_index_buffer,
1932    uint sah_flags
1933)
1934{
1935    struct DFSArgs args;
1936    args.bvh_base = (global struct BVHBase*) bvh_mem;
1937    args.leaf_node_type = globals->leafPrimType;
1938    args.inner_node_type = NODE_TYPE_INTERNAL;
1939    args.leaf_size_in_bytes = globals->leafSize;
1940    args.primref_buffer = primref_buffer;
1941    args.need_backpointers = sah_flags & SAH_FLAG_NEED_BACKPOINTERS;
1942    args.num_primrefs = globals->numPrimitives;
1943    args.primref_index_buffer = primref_index_buffer;
1944    args.need_masks = sah_flags & SAH_FLAG_NEED_MASKS;
1945
1946    local struct Single_WG_build_SLM slm;
1947
1948    execute_single_WG_build( args, &slm );
1949}
1950
1951
1952GRL_ANNOTATE_IGC_DO_NOT_SPILL
1953__attribute__( (reqd_work_group_size( 16, 1, 1 )) )
1954__attribute__( (intel_reqd_sub_group_size( 16 )) )
1955kernel void DFS_trivial(
1956    global struct Globals* globals,
1957    global char* bvh_mem,
1958    global PrimRef* primref_buffer,
1959    global uint* primref_index_buffer,
1960    uint sah_flags
1961)
1962{
1963    struct DFSArgs args;
1964    args.bvh_base = (global struct BVHBase*) bvh_mem;
1965    args.leaf_node_type = globals->leafPrimType;
1966    args.inner_node_type = NODE_TYPE_INTERNAL;
1967    args.leaf_size_in_bytes = globals->leafSize;
1968    args.primref_buffer = primref_buffer;
1969    args.need_backpointers = sah_flags & SAH_FLAG_NEED_BACKPOINTERS;
1970    args.num_primrefs = globals->numPrimitives;
1971    args.primref_index_buffer = primref_index_buffer;
1972    args.need_masks = sah_flags & SAH_FLAG_NEED_MASKS;
1973
1974    Trivial_DFS( args );
1975}
1976
1977
1978struct DFSArgs dfs_args_from_sah_globals( global struct SAHBuildGlobals* sah_globals )
1979{
1980    struct DFSArgs args;
1981    args.bvh_base               = (global struct BVHBase*) sah_globals->p_bvh_base;
1982    args.leaf_node_type         = sah_globals->leaf_type;
1983    args.inner_node_type        = NODE_TYPE_INTERNAL;
1984    args.leaf_size_in_bytes     = sah_globals->leaf_size;
1985    args.primref_buffer         = (global PrimRef*) sah_globals->p_primrefs_buffer;
1986    args.need_backpointers      = sah_globals->flags & SAH_FLAG_NEED_BACKPOINTERS;
1987    args.num_primrefs           = sah_globals->num_primrefs;
1988    args.primref_index_buffer   = (global uint*) sah_globals->p_primref_index_buffers;
1989    args.need_masks             = sah_globals->flags & SAH_FLAG_NEED_MASKS;
1990
1991    return args;
1992}
1993
1994
1995GRL_ANNOTATE_IGC_DO_NOT_SPILL
1996__attribute__((reqd_work_group_size(DFS_WG_SIZE, 1, 1)))
1997__attribute__((intel_reqd_sub_group_size(16)))
1998kernel void DFS_single_wg_batchable(
1999    global struct SAHBuildGlobals* globals_buffer,
2000    global struct VContextScheduler* scheduler
2001)
2002{
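    // Batched variant: the scheduler presumably packs the trivial builds into the first num_trivial_builds
    // entries of globals_buffer (consumed by DFS_trivial_batchable), so the single-workgroup builds for this
    // dispatch start after that offset.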
2003    global struct SAHBuildGlobals* sah_globals = globals_buffer + scheduler->num_trivial_builds + get_group_id(0);
2004
2005    struct DFSArgs args = dfs_args_from_sah_globals( sah_globals );
2006
2007    local struct Single_WG_build_SLM slm;
2008
2009    execute_single_WG_build(args, &slm);
2010}
2011
2012
2013GRL_ANNOTATE_IGC_DO_NOT_SPILL
2014__attribute__((reqd_work_group_size(16, 1, 1)))
2015__attribute__((intel_reqd_sub_group_size(16)))
2016kernel void DFS_trivial_batchable(
2017    global struct SAHBuildGlobals* globals_buffer
2018)
2019{
2020    global struct SAHBuildGlobals* sah_globals = globals_buffer + get_group_id(0);
2021
2022    struct DFSArgs args = dfs_args_from_sah_globals(sah_globals);
2023
2024    Trivial_DFS(args);
2025}