1//
2// Copyright (C) 2009-2021 Intel Corporation
3//
4// SPDX-License-Identifier: MIT
5//
6//
7
8#include "binned_sah_shared.h"
9
10#include "libs/lsc_intrinsics.h"
11#include "intrinsics.h"
12#include "AABB.h"
13#include "AABB3f.h"
14
15#include "qbvh6.h"
16#include "common.h"
17
20#define SGPRINT_16x(prefix,fmt,type,val)  {\
21                                        type v0 = sub_group_broadcast( val, 0 );\
22                                        type v1 = sub_group_broadcast( val, 1 );\
23                                        type v2 = sub_group_broadcast( val, 2 );\
24                                        type v3 = sub_group_broadcast( val, 3 );\
25                                        type v4 = sub_group_broadcast( val, 4 );\
26                                        type v5 = sub_group_broadcast( val, 5 );\
27                                        type v6 = sub_group_broadcast( val, 6 );\
28                                        type v7 = sub_group_broadcast( val, 7 );\
29                                        type v8 = sub_group_broadcast( val, 8 );\
30                                        type v9 = sub_group_broadcast( val, 9 );\
31                                        type v10 = sub_group_broadcast( val, 10 );\
32                                        type v11 = sub_group_broadcast( val, 11 );\
33                                        type v12 = sub_group_broadcast( val, 12 );\
34                                        type v13 = sub_group_broadcast( val, 13 );\
35                                        type v14 = sub_group_broadcast( val, 14 );\
36                                        type v15 = sub_group_broadcast( val, 15 );\
37                                        sub_group_barrier(CLK_LOCAL_MEM_FENCE); \
38                                        if( get_sub_group_local_id() == 0 ) { \
39                                        printf(prefix fmt fmt fmt fmt fmt fmt fmt fmt \
40                                                      fmt fmt fmt fmt fmt fmt fmt fmt"\n" , \
41                                            v0,v1,v2,v3,v4,v5,v6,v7,v8,v9,v10,v11,v12,v13,v14,v15);}}
42
43
44#define SGPRINT_6x(prefix,fmt,type,val)  {\
45                                        type v0 = sub_group_broadcast( val, 0 );\
46                                        type v1 = sub_group_broadcast( val, 1 );\
47                                        type v2 = sub_group_broadcast( val, 2 );\
48                                        type v3 = sub_group_broadcast( val, 3 );\
49                                        type v4 = sub_group_broadcast( val, 4 );\
50                                        type v5 = sub_group_broadcast( val, 5 );\
51                                        sub_group_barrier(CLK_LOCAL_MEM_FENCE); \
52                                        if( get_sub_group_local_id() == 0 ) { \
53                                        printf(prefix fmt fmt fmt fmt fmt fmt "\n" , \
54                                            v0,v1,v2,v3,v4,v5);}}
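
// The SGPRINT_* macros above are debug helpers: they broadcast the per-lane value 'val'
// to lane 0 and print one line containing every lane's copy.  'fmt' must be a string
// literal matching 'type'.  A hypothetical use (names are illustrative only):
//
//     SGPRINT_16x( "bin counts: ", "%u ", uint, my_per_lane_count );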
55
56#define BFS_WG_SIZE  512
57
58#define BFS_NUM_VCONTEXTS 256 // must be multiple of 64
59
60#define TREE_ARITY 6
61
62#define DFS_WG_SIZE  256
63#define DFS_THRESHOLD 256
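
// BFS_WG_SIZE and DFS_WG_SIZE are the work-group sizes of the BFS binning and DFS
// subtree-build phases.  BFS_NUM_VCONTEXTS bounds how many virtual build contexts the
// scheduler keeps in flight at once.  TREE_ARITY is 6, presumably matching the 6-wide
// QBVH6 nodes, and subtrees with at most DFS_THRESHOLD primrefs are handed off to the
// single-work-group DFS path (see is_leaf()/is_dfs()/is_bfs() further below).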
64
65
66void BFSDispatchQueue_print(struct BFSDispatchQueue* q, uint n)
67{
68    for (uint i = 0; i < q->num_dispatches; i++)
69        printf("   %u,ctx=%u,batch=%u\n", q->wg_count[i], q->records[i].context_id, q->records[i].batch_index);
70}
71
72void VContextScheduler_print(struct VContextScheduler* scheduler)
73{
74    if (get_local_id(0) == 0)
75    {
76        printf("SCHEDULER:\n");
77        printf("    bfs=%u dfs=%u\n", scheduler->num_bfs_wgs, scheduler->num_dfs_wgs);
78
79        printf("BFS QUEUE:\n");
80        BFSDispatchQueue_print(&scheduler->bfs_queue, scheduler->num_bfs_wgs);
81
82
83        printf("DFS QUEUE\n");
84        for (uint i = 0; i < scheduler->num_dfs_wgs; i++)
85        {
86            struct DFSDispatchRecord* r = &scheduler->dfs_queue.records[i];
87            printf("    (%u-%u) root=%u  depth=%u  batch_index=%u\n",
88                r->primref_base, r->primref_base + r->num_primrefs,
89                r->bvh2_base, r->tree_depth, r->batch_index);
90        }
91
92        printf("CONTEXTS:\n");
93        for (uint i = 0; i < BFS_NUM_VCONTEXTS; i++)
94        {
95            if (scheduler->vcontext_state[i] != VCONTEXT_STATE_UNALLOCATED)
96            {
97                printf(" context: %u  state=%u\n", i, scheduler->vcontext_state[i]);
98                printf("     prims: %u-%u\n", scheduler->contexts[i].dispatch_primref_begin, scheduler->contexts[i].dispatch_primref_end);
99                printf("     depth: %u\n", scheduler->contexts[i].tree_depth);
100                printf("     root: %u\n", scheduler->contexts[i].bvh2_root);
101                printf("     batch: %u\n", scheduler->contexts[i].batch_index);
102            }
103        }
104
105
106
107    }
108
109}
110
111
112inline float3 select_min(float3 v, bool mask)
113{
114    return (float3)(mask ? v.x : (float)(INFINITY),
115        mask ? v.y : (float)(INFINITY),
116        mask ? v.z : (float)(INFINITY));
117}
118inline float3 select_max(float3 v, bool mask)
119{
120    return (float3)(mask ? v.x : -(float)(INFINITY),
121        mask ? v.y : -(float)(INFINITY),
122        mask ? v.z : -(float)(INFINITY));
123}
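
// select_min/select_max return the identity element for min/max (+INF or -INF) on
// masked-off lanes, so inactive lanes do not perturb subsequent min/max reductions.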
124
125///////////////////////////////////////////////////////////////////////////
126
//  The 'LRBounds' structure stores negated maxima, so that both the min and max updates
//  can be expressed as atomic_min operations and fused into a single message
129
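// Layout assumed by the accessors and merge routines below: 'scalars.Array' aliases the
// four AABB3f boxes as 24 floats, each box stored as
//
//     { lower.x, lower.y, lower.z, -upper.x, -upper.y, -upper.z }
//
//     Array[ 0.. 5] = left centroid bounds      Array[ 6..11] = left geom bounds
//     Array[12..17] = right centroid bounds     Array[18..23] = right geom bounds
//
// The getters negate the upper components back on the way out.
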
130struct AABB3f LRBounds_get_left_centroid( LRBounds* b )
131{
132    struct AABB3f* pbox = &b->boxes.left_centroid_bounds;
133    return AABB3f_construct( AABB3f_load_lower(pbox), -AABB3f_load_upper(pbox) );
134}
135struct AABB3f LRBounds_get_right_centroid( LRBounds* b )
136{
137    struct AABB3f* pbox = &b->boxes.right_centroid_bounds;
138    return AABB3f_construct( AABB3f_load_lower(pbox), -AABB3f_load_upper(pbox) );
139}
140struct AABB3f LRBounds_get_left_geom( LRBounds* b )
141{
142    struct AABB3f* pbox = &b->boxes.left_geom_bounds;
143    return AABB3f_construct( AABB3f_load_lower(pbox), -AABB3f_load_upper(pbox) );
144}
145struct AABB3f LRBounds_get_right_geom( LRBounds* b )
146{
147    struct AABB3f* pbox = &b->boxes.right_geom_bounds;
148    return AABB3f_construct( AABB3f_load_lower(pbox), -AABB3f_load_upper(pbox) );
149}
150
151
152void LRBounds_merge_left( local LRBounds* b, float3 CMin, float3 CMax, float3 GMin, float3 GMax )
153{
    // All of the input vectors have come from sub-group reductions and are thus uniform.
    //   Using per-component atomic_min calls (as in the commented-out block below) results
    //   in IGC generating 12 atomic_min messages and a large stack of movs.
    //   The code below should result in 1 atomic_min message and a similarly large stack of movs.
157
158    float mergeVal0 = INFINITY;
159    float mergeVal1 = INFINITY;
160    uint i = get_sub_group_local_id();
161
162    // insert the various merge values into one register
163    //  We use two parallel variables here to enable some ILP
164
165    uint imod = (i>=6) ? (i-6) : i;
166    mergeVal0 = (imod==0)  ?  CMin.x : mergeVal0;
167    mergeVal1 = (imod==0)  ?  GMin.x : mergeVal1;
168
169    mergeVal0 = (imod==1)  ?  CMin.y : mergeVal0;
170    mergeVal1 = (imod==1)  ?  GMin.y : mergeVal1;
171
172    mergeVal0 = (imod==2)  ?  CMin.z : mergeVal0;
173    mergeVal1 = (imod==2)  ?  GMin.z : mergeVal1;
174
175    mergeVal0 = (imod==3)  ? -CMax.x : mergeVal0;
176    mergeVal1 = (imod==3)  ? -GMax.x : mergeVal1;
177
178    mergeVal0 = (imod==4)  ? -CMax.y : mergeVal0;
179    mergeVal1 = (imod==4)  ? -GMax.y : mergeVal1;
180
181    mergeVal0 = (imod==5)  ? -CMax.z : mergeVal0;
182    mergeVal1 = (imod==5)  ? -GMax.z : mergeVal1;
183
184    float merge = (i<6) ? mergeVal0 : mergeVal1;
185    if( i < 12 )
186        atomic_min( &b->scalars.Array[i], merge );
187
188    //atomic_min( &b->boxes.left_centroid_bounds.lower[0], CMin.x );
189    //atomic_min( &b->boxes.left_centroid_bounds.lower[1], CMin.y );
190    //atomic_min( &b->boxes.left_centroid_bounds.lower[2], CMin.z );
191    //atomic_min( &b->boxes.left_centroid_bounds.upper[0], -CMax.x );
192    //atomic_min( &b->boxes.left_centroid_bounds.upper[1], -CMax.y );
193    //atomic_min( &b->boxes.left_centroid_bounds.upper[2], -CMax.z );
194    //atomic_min( &b->boxes.left_geom_bounds.lower[0],      GMin.x );
195    //atomic_min( &b->boxes.left_geom_bounds.lower[1],      GMin.y );
196    //atomic_min( &b->boxes.left_geom_bounds.lower[2],      GMin.z );
197    //atomic_min( &b->boxes.left_geom_bounds.upper[0], -GMax.x );
198    //atomic_min( &b->boxes.left_geom_bounds.upper[1], -GMax.y );
199    //atomic_min( &b->boxes.left_geom_bounds.upper[2], -GMax.z );
200}
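
// Lane-to-slot mapping used above (LRBounds_merge_right is identical but writes Array[i+12]):
//
//     lanes 0..5  -> left centroid box : { CMin.x, CMin.y, CMin.z, -CMax.x, -CMax.y, -CMax.z }
//     lanes 6..11 -> left geom box     : { GMin.x, GMin.y, GMin.z, -GMax.x, -GMax.y, -GMax.z }
//
// so a single predicated atomic_min across 12 lanes updates both boxes, lower and
// (negated) upper, in one message.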
201
202void LRBounds_merge_right( local LRBounds* b, float3 CMin, float3 CMax, float3 GMin, float3 GMax )
203{
    // All of the input vectors have come from sub-group reductions and are thus uniform.
    //   Using per-component atomic_min calls (as in the commented-out block below) results
    //   in IGC generating 12 atomic_min messages and a large stack of movs.
    //   The code below should result in 1 atomic_min message and a similarly large stack of movs.
207
208    float mergeVal0 = INFINITY;
209    float mergeVal1 = INFINITY;
210    uint i = get_sub_group_local_id();
211
212    // insert the various merge values into one register
213    //  We use two parallel variables here to enable some ILP
214
215    uint imod = (i>=6) ? (i-6) : i;
216    mergeVal0 = (imod==0)  ?  CMin.x : mergeVal0;
217    mergeVal1 = (imod==0)  ?  GMin.x : mergeVal1;
218
219    mergeVal0 = (imod==1)  ?  CMin.y : mergeVal0;
220    mergeVal1 = (imod==1)  ?  GMin.y : mergeVal1;
221
222    mergeVal0 = (imod==2)  ?  CMin.z : mergeVal0;
223    mergeVal1 = (imod==2)  ?  GMin.z : mergeVal1;
224
225    mergeVal0 = (imod==3)  ? -CMax.x : mergeVal0;
226    mergeVal1 = (imod==3)  ? -GMax.x : mergeVal1;
227
228    mergeVal0 = (imod==4)  ? -CMax.y : mergeVal0;
229    mergeVal1 = (imod==4)  ? -GMax.y : mergeVal1;
230
231    mergeVal0 = (imod==5)  ? -CMax.z : mergeVal0;
232    mergeVal1 = (imod==5)  ? -GMax.z : mergeVal1;
233
234    float merge = (i<6) ? mergeVal0 : mergeVal1;
235    if( i < 12 )
236        atomic_min( &b->scalars.Array[i+12], merge );
237
238    //atomic_min( &b->boxes.right_centroid_bounds.lower[0],  CMin.x );
239    //atomic_min( &b->boxes.right_centroid_bounds.lower[1],  CMin.y );
240    //atomic_min( &b->boxes.right_centroid_bounds.lower[2],  CMin.z );
241    //atomic_min( &b->boxes.right_centroid_bounds.upper[0], -CMax.x );
242    //atomic_min( &b->boxes.right_centroid_bounds.upper[1], -CMax.y );
243    //atomic_min( &b->boxes.right_centroid_bounds.upper[2], -CMax.z );
244    //atomic_min( &b->boxes.right_geom_bounds.lower[0],      GMin.x );
245    //atomic_min( &b->boxes.right_geom_bounds.lower[1],      GMin.y );
246    //atomic_min( &b->boxes.right_geom_bounds.lower[2],      GMin.z );
247    //atomic_min( &b->boxes.right_geom_bounds.upper[0],     -GMax.x );
248    //atomic_min( &b->boxes.right_geom_bounds.upper[1],     -GMax.y );
249    //atomic_min( &b->boxes.right_geom_bounds.upper[2],     -GMax.z );
250}
251
252void LRBounds_merge( global LRBounds* globalBounds, local LRBounds* localBounds )
253{
254    uint i = get_local_id(0);
255    if( i < 24 )
256        atomic_min(&globalBounds->scalars.Array[i], localBounds->scalars.Array[i] );
257}
258
259
260void LRBounds_init( LRBounds* bounds )
261{
262    uint i = get_local_id(0) * 4;
263    if( i < 24 )
264    {
265        // compiler should merge it into a 4xdword send
266        bounds->scalars.Array[i+0] = INFINITY;
267        bounds->scalars.Array[i+1] = INFINITY;
268        bounds->scalars.Array[i+2] = INFINITY;
269        bounds->scalars.Array[i+3] = INFINITY;
270    }
271
272}
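
// Note: initializing every slot to +INFINITY is correct for the whole array: the 'lower'
// slots are plain minima, and the negated 'upper' slots starting at +INF correspond to an
// empty box whose maximum is -INF.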
273
274
275inline void LRBounds_init_subgroup( LRBounds* bounds)
276{
277    uint sg_size = get_sub_group_size();
278    uint lane = get_sub_group_local_id();
279
280    for (uint i = lane * 4; i < 24; i += sg_size * 4)
281    {
282        // compiler should merge it into a 4xdword send
283        bounds->scalars.Array[i+0] = INFINITY;
284        bounds->scalars.Array[i+1] = INFINITY;
285        bounds->scalars.Array[i+2] = INFINITY;
286        bounds->scalars.Array[i+3] = INFINITY;
287    }
288
289}
290
291///////////////////////////////////////////////////////////////////////////
292
293inline void BinInfo_init(struct BFS_BinInfo* bin_info)
294{
295    for (uint id = get_local_id(0) * 4; id < 18 * BFS_NUM_BINS; id += get_local_size(0) * 4)
296    {
297        float inf = INFINITY;
298        // compiler should merge it into a 4xdword send
299        bin_info->min_max[id+0] = inf;
300        bin_info->min_max[id+1] = inf;
301        bin_info->min_max[id+2] = inf;
302        bin_info->min_max[id+3] = inf;
303    }
304    for (uint id = get_local_id(0) * 4; id < 3 * BFS_NUM_BINS; id += get_local_size(0) * 4)
305    {
306        // compiler should merge it into a 4xdword send
307        bin_info->counts[id+0] = 0;
308        bin_info->counts[id+1] = 0;
309        bin_info->counts[id+2] = 0;
310        bin_info->counts[id+3] = 0;
311    }
312}
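
// BFS_BinInfo layout assumed here (consistent with BinInfo_get_AABB below):
//   min_max[] holds BFS_NUM_BINS bins * 3 axes * 6 floats, each (axis,bin) slot stored as
//   { min.x, min.y, min.z, -max.x, -max.y, -max.z }, again using negated maxima so that
//   only atomic_min is needed.  counts[] holds one uint per (axis,bin) pair.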
313
314
315// copy global to local
316inline void BinInfo_copy( local struct BFS_BinInfo* local_bin_info, global struct BFS_BinInfo* global_bin_info )
317{
318    for (uint id = get_local_id(0); id < 18 * BFS_NUM_BINS; id += get_local_size(0))
319    {
        local_bin_info->min_max[id] = global_bin_info->min_max[id];
323    }
324    for (uint id = get_local_id(0); id < 3 * BFS_NUM_BINS; id += get_local_size(0))
325    {
326        local_bin_info->counts[id] = global_bin_info->counts[id];
327    }
328}
329
330inline void BinInfo_init_subgroup(struct BFS_BinInfo* bin_info)
331{
332    uint sg_size = get_sub_group_size();
333    uint lane = get_sub_group_local_id();
334
335    for (uint i = lane * 4; i < 3 * BFS_NUM_BINS; i += sg_size * 4)
336    {
337        // compiler should merge it into a 4xdword send
338        bin_info->counts[i+0] = 0;
339        bin_info->counts[i+1] = 0;
340        bin_info->counts[i+2] = 0;
341        bin_info->counts[i+3] = 0;
342    }
343
344
345    for (uint i = lane * 4; i < 18 * BFS_NUM_BINS; i += sg_size * 4)
346    {
347        // compiler should merge it into a 4xdword send
348        bin_info->min_max[i+0] = INFINITY;
349        bin_info->min_max[i+1] = INFINITY;
350        bin_info->min_max[i+2] = INFINITY;
351        bin_info->min_max[i+3] = INFINITY;
352    }
353
354}
355
356float3 shuffle_down_float3( float3 a, float3 b, uint delta )
357{
358    return (float3)(
359        intel_sub_group_shuffle_down( a.x, b.x, delta ),
360        intel_sub_group_shuffle_down( a.y, b.y, delta ),
361        intel_sub_group_shuffle_down( a.z, b.z, delta )
362        );
363}
364
365
366
367
368void BinInfo_primref_ballot_loop( local struct BFS_BinInfo* bin_info, uint axis, uint bin, float3 lower, float3 upper, bool active_lane )
369{
370    local float* bins_min = &bin_info->min_max[0];
371    local float* bins_max = &bin_info->min_max[3];
372
373    varying uint place = (bin + axis*BFS_NUM_BINS);
374    varying uint lane = get_sub_group_local_id();
375
376    uniform uint active_mask = intel_sub_group_ballot(active_lane);
377
378    while( active_mask )
379    {
380        uniform uint leader     = ctz( active_mask );
381        uniform uint lead_place = intel_sub_group_shuffle( place, leader );
382        varying bool matching_bin = lead_place == place && active_lane;
383
384        varying float3 lo = (float3)(INFINITY,INFINITY,INFINITY);
385        varying float3 hi = (float3)(-INFINITY,-INFINITY,-INFINITY);
386        if (matching_bin)
387        {
388            lo = lower.xyz;
389            hi = upper.xyz;
390        }
391
392        lo = sub_group_reduce_min_float3( lo );
393        hi = sub_group_reduce_max_float3( hi );
394
395        {
396            // atomic min operation vectorized across 6 lanes
397            //    [ lower.xyz ][-][upper.xyz][-]
398            //
399            // Lanes 3 and 7 are inactive
400
401            uint lmod = lane % 4;
402            uint ldiv = lane / 4;
403            float vlo = lo.x;
404            float vhi = hi.x;
405            vlo = (lmod == 1) ? lo.y : vlo;
406            vhi = (lmod == 1) ? hi.y : vhi;
407            vlo = (lmod == 2) ? lo.z : vlo;
408            vhi = (lmod == 2) ? hi.z : vhi;
409
410            float v = (ldiv == 0) ? vlo : -vhi;
411
412            if( (1<<lane) & 0x77 )
413                atomic_min( &bin_info->min_max[ 6*lead_place + lmod + 3*ldiv ], v );
414        }
415
416      //if( lane == 0 )
417      //    atomic_add_local(&bin_info->counts[lead_place], popcount(active_mask & intel_sub_group_ballot(matching_bin)) );
418
419        active_mask = active_mask & intel_sub_group_ballot(!matching_bin);
420    }
421}
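
// Sketch of the ballot loop above: each iteration picks the lowest active lane as the
// 'leader', gathers all lanes that map to the same (axis,bin) slot, reduces their bounds
// across the sub-group, and flushes the result with one predicated atomic_min (mask 0x77:
// lanes 0..2 carry min.xyz, lanes 4..6 carry -max.xyz).  Lanes whose slot was just flushed
// are cleared from active_mask, so the loop runs once per distinct bin present in the
// sub-group rather than once per primref.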
422
423inline void BinInfo_add_primref(struct BinMapping* binMapping, local struct BFS_BinInfo* bin_info, PrimRef* primref, bool active_lane )
424{
425
426    const float4 lower = primref->lower;
427    const float4 upper = primref->upper;
428    const float4 p = lower + upper;
429    const uint4 i = convert_uint4( (p - binMapping->ofs) * binMapping->scale );
430
431    BinInfo_primref_ballot_loop( bin_info, 0, i.x, lower.xyz, upper.xyz, active_lane );
432    BinInfo_primref_ballot_loop( bin_info, 1, i.y, lower.xyz, upper.xyz, active_lane );
433    BinInfo_primref_ballot_loop( bin_info, 2, i.z, lower.xyz, upper.xyz, active_lane );
434
435    if (active_lane)
436    {
437        atomic_inc_local( &bin_info->counts[i.x + 0 * BFS_NUM_BINS] );
438        atomic_inc_local( &bin_info->counts[i.y + 1 * BFS_NUM_BINS] );
439        atomic_inc_local( &bin_info->counts[i.z + 2 * BFS_NUM_BINS] );
440    }
441}
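
// Bin selection above follows the usual binned-SAH convention: p = lower + upper is twice
// the primref centroid, and binMapping->ofs/scale are presumably set up elsewhere so that
// (p - ofs) * scale maps the centroid range of the current subtree onto [0, BFS_NUM_BINS).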
442
443inline void BinInfo_merge(global struct BFS_BinInfo* global_info, local struct BFS_BinInfo* local_info)
444{
446    for (uint id = get_local_id(0); id < 18 * BFS_NUM_BINS; id += get_local_size(0))
447    {
448            float v = local_info->min_max[id];
449            if( v != INFINITY )
450                atomic_min(&global_info->min_max[id], v);
451    }
452    for (uint id = get_local_id(0); id < 3 * BFS_NUM_BINS; id += get_local_size(0))
453    {
454            uint c = local_info->counts[id];
455            if( c )
456                atomic_add_global(&global_info->counts[id], c);
457    }
458}
459
460inline struct AABB3f BinInfo_get_AABB(struct BFS_BinInfo* bin_info, ushort bin, ushort axis)
461{
462    float* min = &bin_info->min_max[6*(bin + axis*BFS_NUM_BINS)];
463    float* max = min + 3;
464    struct AABB3f box;
465    for (uint i = 0; i < 3; i++)
466    {
467        box.lower[i] = min[i];
468        box.upper[i] = -max[i];
469    }
470
471    return box;
472}
473
474inline uint3 BinInfo_get_counts(struct BFS_BinInfo* bin_info, ushort bin)
475{
476    uint3 counts;
477    counts.x = bin_info->counts[bin + 0 * BFS_NUM_BINS]; // TODO: block load these
478    counts.y = bin_info->counts[bin + 1 * BFS_NUM_BINS];
479    counts.z = bin_info->counts[bin + 2 * BFS_NUM_BINS];
480    return counts;
481}
482inline uint BinInfo_get_count(struct BFS_BinInfo* bin_info, ushort bin, ushort axis)
483{
484    return bin_info->counts[bin + axis * BFS_NUM_BINS];
485}
486
487
488void BVH2_Initialize( struct BVH2* bvh )
489{
490    bvh->num_nodes = 1;
491}
492
493inline bool BVH2_IsInnerNode( global struct BVH2* bvh, uint node_index )
494{
495    global struct BVH2Node* n = ((global struct BVH2Node*)(bvh + 1)) + node_index;
496    return (n->meta_ss & 0x10000) != 0;
497}
498inline uint BVH2_GetRoot( struct BVH2* bvh )
499{
500    return 0;
501}
502
503//////////////////////////////////////////////
504// BVH2NodeMetaData funcs
505//////////////////////////////////////////////
506struct BVH2NodeMetaData
507{
    uint  meta_u;   // leaf:  primref start.  inner: index of its first (left) child
509    uint  meta_ss;
510};
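
// meta_ss packing, as implied by the accessors below and by BVH2_Write*Node:
//   bits  0..15 : leaf  -> primref count
//                 inner -> index delta from the left child to the right child
//   bit      16 : set for inner nodes
//   bits 24..31 : 8-bit mask (see BVH2_GetMask)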
511
512inline struct BVH2NodeMetaData BVH2_GetNodeMetaData( global struct BVH2* bvh, uint node_index )
513{
514    global struct BVH2Node* n = ((global struct BVH2Node*)(bvh + 1)) + node_index;
515    struct BVH2NodeMetaData meta;
516    meta.meta_u = n->meta_u;
517    meta.meta_ss = n->meta_ss;
518    return meta;
519}
520
521inline bool BVH2NodeMetaData_IsInnerNode( struct BVH2NodeMetaData* meta )
522{
523    return (meta->meta_ss & 0x10000) != 0;
524}
525
526inline ushort BVH2NodeMetaData_GetLeafPrimCount( struct BVH2NodeMetaData* meta )
527{
528    return meta->meta_ss & 0xffff;
529}
530
531inline uint BVH2NodeMetaData_GetLeafPrimStart( struct BVH2NodeMetaData* meta )
532{
533    return meta->meta_u;
534}
535
536inline uint BVH2NodeMetaData_GetMask( struct BVH2NodeMetaData* meta )
537{
538    return (meta->meta_ss>>24);
539}
540
541//////////////////////////////////////////////
542
543inline ushort BVH2_GetLeafPrimCount( struct BVH2* bvh, uint node_index )
544{
545    struct BVH2Node* n = ((struct BVH2Node*)(bvh + 1)) + node_index;
546    return n->meta_ss & 0xffff;
547}
548inline uint BVH2_GetLeafPrimStart( struct BVH2* bvh, uint node_index )
549{
550    struct BVH2Node* n = ((struct BVH2Node*)(bvh + 1)) + node_index;
551    return n->meta_u;
552}
553inline uint2 BVH2_GetChildIndices( struct BVH2* bvh, uint node_index )
554{
555    struct BVH2Node* n = ((struct BVH2Node*)(bvh + 1)) + node_index;
556    uint2 idx;
557    idx.x = n->meta_u;
558    idx.y = idx.x + (n->meta_ss & 0xffff);
559    return idx;
560}
561
562inline float BVH2_GetNodeArea( global struct BVH2* bvh, uint node_index )
563{
564    global struct BVH2Node* n = ((global struct BVH2Node*)(bvh + 1)) + node_index;
565    return AABB3f_halfArea( &n->box );
566}
567
568
569inline struct AABB3f BVH2_GetNodeBox( global struct BVH2* bvh, uint node_index )
570{
571    global struct BVH2Node* n = ((global struct BVH2Node*)(bvh + 1)) + node_index;
572    return n->box;
573}
574inline void BVH2_SetNodeBox( global struct BVH2* bvh, uint node_index, struct AABB3f* box )
575{
576    global struct BVH2Node* n = ((global struct BVH2Node*)(bvh + 1)) + node_index;
577    n->box = *box;
578}
579
580inline void BVH2_SetNodeBox_lu( global struct BVH2* bvh, uint node_index, float3 lower, float3 upper )
581{
582    global struct BVH2Node* n = ((global struct BVH2Node*)(bvh + 1)) + node_index;
583    AABB3f_set( &n->box, lower, upper );
584}
585
586inline void BVH2_InitNodeBox( struct BVH2* bvh, uint node_index )
587{
588    struct BVH2Node* n = ((struct BVH2Node*)(bvh + 1)) + node_index;
589    AABB3f_init( &n->box );
590}
591
592inline struct AABB BVH2_GetAABB( global struct BVH2* bvh, uint node_index )
593{
594    global struct BVH2Node* n = ((global struct BVH2Node*)(bvh + 1)) + node_index;
595    struct AABB r;
596    r.lower.xyz = AABB3f_load_lower( &n->box );
597    r.upper.xyz = AABB3f_load_upper( &n->box );
598    return r;
599}
600
601inline void BVH2_WriteInnerNode( global struct BVH2* bvh, uint node_index, struct AABB3f* box, uint2 child_offsets, uint mask )
602{
603    global struct BVH2Node* n = ((global struct BVH2Node*)(bvh + 1)) + node_index;
604    n->box = *box;
605    n->meta_u   = child_offsets.x;
606    n->meta_ss  = 0x10000 + (child_offsets.y - child_offsets.x) + (mask<<24);
607  //  n->is_inner  = true;
608}
609
610inline void BVH2_WriteLeafNode( global struct BVH2* bvh, uint node_index, struct AABB3f* box, uint prim_start, uint prim_count, uint mask  )
611{
612    global struct BVH2Node* n = ((global struct BVH2Node*)(bvh + 1)) + node_index;
613    n->box = *box;
614    n->meta_u   = prim_start;
615    n->meta_ss  = prim_count + (mask<<24);
617}
618
619inline uint BVH2_GetMask( global struct BVH2* bvh, uint node_index )
620{
621    global struct BVH2Node* n = ((global struct BVH2Node*)(bvh + 1)) + node_index;
622    return (n->meta_ss>>24);
623}
624
625
626uint BVH2_AllocateNodes( global struct BVH2* bvh, uint num_nodes )
627{
628    return atomic_add_global( &bvh->num_nodes, num_nodes );
629}
630
631inline void BVH2_AtomicMergeNodeBox( global struct BVH2* bvh, uint node_index, float3 lower, float3 upper )
632{
633    global struct BVH2Node* n = ((global struct BVH2Node*)(bvh + 1)) + node_index;
634    AABB3f_atomic_merge_global_lu( &n->box, lower, upper );
635}
636
637
638void BVH2_print( global struct BVH2* bvh, uint start_node )
639{
640    if ( get_local_id( 0 ) == 0 && get_sub_group_id() == 0 )
641    {
642        uint num_nodes = bvh->num_nodes;
643
644        uint2 stack[BFS_MAX_DEPTH * 2];
645        uint sp = 0;
646
647        printf( "allocated_nodes=%u\n", num_nodes );
648
649        stack[sp++] = (uint2)(start_node, 0);
650        while ( sp > 0 )
651        {
652            uint2 data = stack[--sp];
653            uint node = data.x;
654            uint depth = data.y;
655
656            for ( uint i = 0; i < depth; i++ )
657                printf( "    " );
658
659            if ( BVH2_IsInnerNode( bvh, node ) )
660            {
661                uint2 kids = BVH2_GetChildIndices( bvh, node );
662                printf( " %5u: inner: %u %u \n", node, kids.x, kids.y );
663                stack[sp++] = (uint2)(kids.y, depth + 1);
664                stack[sp++] = (uint2)(kids.x, depth + 1);
665
666                struct AABB3f l = BVH2_GetNodeBox( bvh, kids.x );
667                struct AABB3f r = BVH2_GetNodeBox( bvh, kids.y );
668                struct AABB3f p = BVH2_GetNodeBox( bvh, node );
669
670                float3 pl = AABB3f_load_lower( &p );
671                float3 pu = AABB3f_load_upper( &p );
672                float3 ll = AABB3f_load_lower( &l );
673                float3 lu = AABB3f_load_upper( &l );
674                float3 rl = AABB3f_load_lower( &r );
675                float3 ru = AABB3f_load_upper( &r );
676                if ( any( ll < pl ) || any( rl < pl ) ||
677                     any( lu > pu ) || any( ru > pu ) )
678                {
679                    for ( uint i = 0; i < depth; i++ )
680                        printf( "    " );
681
682                    printf( "BAD_BOUNDS!!!!!!!! %u\n", node );
683                }
684
685
686            }
687            else
688            {
689
690                uint start = BVH2_GetLeafPrimStart( bvh, node );
691                uint count = BVH2_GetLeafPrimCount( bvh, node );
692                printf( " %5u: leaf: start=%u count=%u\n  ",node,start,count );
693
694            }
695        }
696    }
697    barrier( CLK_LOCAL_MEM_FENCE );
698}
699
700
701global uint* SAHBuildGlobals_GetPrimrefIndices_In( struct SAHBuildGlobals* globals, bool odd_pass )
702{
703    uint num_refs = globals->num_primrefs;
704    global uint* ib = (global uint*) globals->p_primref_index_buffers;
705    return ib + (odd_pass ? num_refs : 0);
706}
707
708global uint* SAHBuildGlobals_GetPrimrefIndices_Out( struct SAHBuildGlobals* globals, bool odd_pass )
709{
710    uint num_refs = globals->num_primrefs;
711    global uint* ib = (global uint*) globals->p_primref_index_buffers;
712    return ib + (odd_pass ? 0 : num_refs);
713}
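
// The primref index storage is ping-ponged between BFS passes: it holds two arrays of
// num_primrefs indices back to back, and 'odd_pass' selects which half is read (In) and
// which is written (Out), so the two halves swap roles every pass.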
714
715global PrimRef* SAHBuildGlobals_GetPrimrefs( struct SAHBuildGlobals* globals )
716{
717    return (global PrimRef*) globals->p_primrefs_buffer;
718}
719
720global struct BVH2* SAHBuildGlobals_GetBVH2( struct SAHBuildGlobals* globals )
721{
722    return (global struct BVH2*)globals->p_bvh2;
723}
724
725uint SAHBuildGlobals_GetLeafSizeInBytes( struct SAHBuildGlobals* globals )
726{
727    return globals->leaf_size;
728}
729
730uint SAHBuildGlobals_GetLeafType( struct SAHBuildGlobals* globals )
731{
732    return globals->leaf_type;
733}
734
735uint SAHBuildGlobals_GetInternalNodeType( struct SAHBuildGlobals* globals )
736{
737    return NODE_TYPE_INTERNAL;
738}
739
740global struct BVHBase* SAHBuildGlobals_GetBVHBase( struct SAHBuildGlobals* globals )
741{
742    return (global struct BVHBase*) globals->p_bvh_base;
743}
744
745uint SAHBuildGlobals_GetTotalPrimRefs( struct SAHBuildGlobals* globals )
746{
747    return globals->num_primrefs;
748}
749
750inline bool SAHBuildGlobals_NeedBackPointers( struct SAHBuildGlobals* globals )
751{
752    return globals->flags & SAH_FLAG_NEED_BACKPOINTERS;
753}
754inline bool SAHBuildGlobals_NeedMasks( struct SAHBuildGlobals* globals )
755{
756    return globals->flags & SAH_FLAG_NEED_MASKS;
757}
758
759
760void SAHBuildGlobals_print( struct SAHBuildGlobals* globals )
761{
762    if ( get_local_id( 0 ) == 0 )
763    {
764        printf( "SAHBuildGlobals: %p\n", globals );
765        printf( "  p_primref_index_buffers =%p\n", globals->p_primref_index_buffers );
766        printf( "  p_primrefs_buffer =%p\n",       globals->p_primrefs_buffer );
767        printf( "  p_bvh2 =%p\n",                  globals->p_bvh2 );
768        printf( "  p_globals =%p\n",               globals->p_globals );
769        printf( "  p_bvh_base =%p\n",              globals->p_bvh_base );
770        printf( "  num_primrefs = %u\n",           globals->num_primrefs );
771        printf( "  leaf_size = %u\n",              globals->leaf_size );
772        printf( "  leaf_type = %u\n",              globals->leaf_type );
773        printf( "  p_qnode_buffer = %p\n",         globals->p_qnode_root_buffer);
774    }
775
776    barrier( CLK_LOCAL_MEM_FENCE );
777}
778
779
780uint get_num_wgs(uint thread_count, uint WG_SIZE)
781{
782    return (thread_count + WG_SIZE - 1) / WG_SIZE;
783}
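
// Ceil-division: get_num_wgs(n, wg) == ceil(n / wg), e.g. get_num_wgs(1000, 512) == 2.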
784
785
786
787
788
789struct BFSDispatchArgs
790{
791    global struct VContextScheduler* scheduler;
792    global struct VContext* context;
793    global struct BVH2* bvh2;
794    global uint* primref_index_in;
795    global uint* primref_index_out;
796    global PrimRef* primref_buffer;
797
798    uint   wg_primref_begin;
799    uint   wg_primref_end;
800    uint   dispatch_primref_begin;
801    uint   dispatch_primref_end;
802    uint   context_id;
803    uint   num_wgs;
804    uint   bvh2_root;
805    uint   global_num_primrefs;
806    bool   do_mask_processing;
807};
808
809
810
811
812// TODO_OPT:  Enable larger WGs
813//    We need a way to do this in a portable fashion.
814//     Gen12 can support larger WGs than Gen9 can
815//
816GRL_ANNOTATE_IGC_DO_NOT_SPILL
817__attribute__( (reqd_work_group_size( 512, 1, 1 )) )
818kernel void
819begin( global struct VContextScheduler* scheduler,
820       dword leaf_size,
821       dword leaf_type,
822       global uint* primref_index_buffers,
823       global PrimRef* primref_buffer,
824       global struct BVH2* bvh2,
825       global struct BVHBase* bvh_base,
826       global struct Globals* globals,
827       global struct SAHBuildGlobals* sah_globals,
828       global uint2* qnode_root_buffer,
829       dword sah_globals_flags
830     )
831{
832    dword num_primrefs = globals->numPrimitives;
833    if ( get_local_id( 0 ) == 0 )
834    {
835        sah_globals->p_primrefs_buffer       = (qword) primref_buffer;
836        sah_globals->p_primref_index_buffers = (qword)primref_index_buffers;
837        sah_globals->p_bvh2                  = (qword) bvh2;
838        sah_globals->p_bvh_base              = (qword) bvh_base;
839        sah_globals->leaf_size               = leaf_size;
840        sah_globals->leaf_type               = leaf_type;
841        sah_globals->num_primrefs            = num_primrefs;
842        sah_globals->p_globals               = (qword) globals;
843        sah_globals->p_qnode_root_buffer     = (gpuva_t) qnode_root_buffer;
844        sah_globals->flags                   = sah_globals_flags;
845
846        // initialize the spill stack
847        scheduler->bfs2_spill_stack.size = 0;
848
849        // initialize BVH2 node counter
850        BVH2_Initialize( bvh2 );
851
852        // configure first vcontext for first build
853        scheduler->contexts[0].dispatch_primref_begin = 0;
854        scheduler->contexts[0].dispatch_primref_end   = num_primrefs;
855        scheduler->contexts[0].bvh2_root              = BVH2_GetRoot( bvh2 );
856        scheduler->contexts[0].tree_depth             = 0;
857        scheduler->contexts[0].batch_index            = 0;
858
859        scheduler->bfs_queue.records[0].context_id = 0;
860
861        scheduler->contexts[0].num_left = 0;
862        scheduler->contexts[0].num_right = 0;
863        scheduler->contexts[0].lr_mask = 0;
864
        // copy centroid bounds into the BVH2 root node
866        BVH2_SetNodeBox_lu( bvh2, BVH2_GetRoot( bvh2 ), globals->centroidBounds.lower.xyz, globals->centroidBounds.upper.xyz );
867
        // zero the trivial-build counters. These are only used by the batch-build path,
        // but the single-WG QNode path (if used) depends on them.
870        scheduler->num_trivial_builds = 0;
871        scheduler->num_single_builds = 0;
872
873        // initialize the root-buffer counters
874        sah_globals->root_buffer_num_produced     = 0;
875        sah_globals->root_buffer_num_produced_hi  = 0;
876        sah_globals->root_buffer_num_consumed     = 0;
877        sah_globals->root_buffer_num_consumed_hi  = 0;
878    }
879
880    // initialize vcontext states
881    for ( uint i = get_local_id( 0 ); i < BFS_NUM_VCONTEXTS; i += get_local_size( 0 ) )
882        scheduler->vcontext_state[i] = (i==0) ? VCONTEXT_STATE_EXECUTING : VCONTEXT_STATE_UNALLOCATED;
883
884    // initialize global bin info in vcontext - only context[0] will be used in first iteration
885    BinInfo_init( &scheduler->contexts[0].global_bin_info );
886    LRBounds_init( &scheduler->contexts[0].lr_bounds );
887
888   // barrier( CLK_GLOBAL_MEM_FENCE  ); // lsc flush ... driver now does these as part of COMPUTE_WALKER
889}
890
897// TODO_OPT:  Enable larger WGs
898//    We need a way to do this in a portable fashion.
899//     Gen12 can support larger WGs than Gen9 can
900//
901GRL_ANNOTATE_IGC_DO_NOT_SPILL
902__attribute__((reqd_work_group_size(512, 1, 1)))
903kernel void
904categorize_builds_and_init_scheduler(
905    global struct VContextScheduler* scheduler,
906    global gpuva_t* globals_ptrs,                // OCL-C does not allow kernel parameters to be pointer-to-pointer, so we trick it...
907    global struct SAHBuildBuffersInfo* buffers_info,
908    global struct SAHBuildGlobals* builds_out,
909    dword num_builds
910)
911{
912    local uint num_trivial;
913    local uint num_single;
914    local uint num_full;
915
916    if (get_group_id(0) == 0) // first workgroup performs build categorization
917    {
918        if (get_local_id(0) == 0)
919        {
920            num_trivial = 0;
921            num_single = 0;
922            num_full = 0;
923        }
924
925        barrier(CLK_LOCAL_MEM_FENCE);
926
927        // first pass, count builds of each type
928        uint triv = 0;
929        uint single = 0;
930        uint full = 0;
931        for (uint i = get_local_id(0); i < num_builds; i += get_local_size(0))
932        {
933            global struct Globals* globals = (global struct Globals*) globals_ptrs[i];
934            dword num_refs = globals->numPrimitives;
935
936            if (num_refs <= TRIVIAL_BUILD_THRESHOLD)
937                triv++;
938            else if (num_refs <= SINGLE_WG_BUILD_THRESHOLD)
939                single++;
940            else
941                full++;
942        }
943
944        // merge counts across work-group.  These variables are now offsets into this thread's ranges
945        triv   = atomic_add_local(&num_trivial, triv);
946        single = atomic_add_local(&num_single, single);
947        full   = atomic_add_local(&num_full, full);
948
949        barrier(CLK_LOCAL_MEM_FENCE);
950
951        global struct SAHBuildGlobals* trivial_builds_out = builds_out;
952        global struct SAHBuildGlobals* single_builds_out = builds_out + num_trivial;
953        global struct SAHBuildGlobals* full_builds_out = builds_out + num_trivial + num_single;
954
955        for (uint i = get_local_id(0); i < num_builds; i += get_local_size(0))
956        {
957            global struct Globals* globals = (global struct Globals*) globals_ptrs[i];
958            global struct SAHBuildBuffersInfo* buffers = &buffers_info[i];
959
960            dword num_refs = globals->numPrimitives;
961            dword leaf_type = globals->leafPrimType;
962            dword leaf_size = globals->leafSize;
963
964            global struct SAHBuildGlobals* place;
965            if (num_refs <= TRIVIAL_BUILD_THRESHOLD)
966                place = trivial_builds_out + (triv++);
967            else if (num_refs <= SINGLE_WG_BUILD_THRESHOLD)
968                place = single_builds_out + (single++);
969            else
970                place = full_builds_out + (full++);
971
972            place->p_primref_index_buffers = buffers->p_primref_index_buffers;
973            place->p_primrefs_buffer    = buffers->p_primrefs_buffer;
974            place->p_bvh2               = buffers->p_bvh2;
975            place->p_bvh_base           = buffers->p_bvh_base;
976            place->p_globals            = (gpuva_t)globals;
977            place->num_primrefs         = num_refs;
978            place->leaf_size            = leaf_size;
979            place->leaf_type            = leaf_type;
980            place->flags                = buffers->sah_globals_flags;
981            place->p_qnode_root_buffer  = buffers->p_qnode_root_buffer;
982
            // only initialize the BVH2 if it will actually be used by the build:
            //   trivial and single-WG builds will not use it
985            if( num_refs > SINGLE_WG_BUILD_THRESHOLD )
986            {
987                // initialize BVH2 node counter
988                global struct BVH2* bvh2 = SAHBuildGlobals_GetBVH2(place);
989                BVH2_Initialize(bvh2);
990
                // copy centroid bounds into the BVH2 root node
992                BVH2_SetNodeBox_lu(bvh2, BVH2_GetRoot(bvh2), globals->centroidBounds.lower.xyz, globals->centroidBounds.upper.xyz);
993            }
994        }
995
996        if (get_local_id(0) == 0)
997        {
998            scheduler->num_trivial_builds   = num_trivial;
999            scheduler->num_single_builds    = num_single;
1000            scheduler->batched_build_offset = num_trivial + num_single;
1001            scheduler->batched_build_count  = num_full;
1002        }
1003    }
1004    else // second workgroup initializes the scheduler
1005    {
1006        // initialize vcontext states
1007        for (uint i = get_local_id(0); i < BFS_NUM_VCONTEXTS; i += get_local_size(0))
1008            scheduler->vcontext_state[i] = (i == 0) ? VCONTEXT_STATE_EXECUTING : VCONTEXT_STATE_UNALLOCATED;
1009
1010        // initialize global bin info in vcontexts
1011        for (uint i = get_sub_group_id(); i < BFS_NUM_VCONTEXTS; i += get_num_sub_groups())
1012            BinInfo_init_subgroup(&scheduler->contexts[i].global_bin_info);
1013
1014        // initialize the spill stack
1015        if (get_local_id(0) == 0)
1016            scheduler->bfs2_spill_stack.size = 0;
1017    }
1018
1019    //barrier( CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE );// lsc flush ... driver now does these as part of COMPUTE_WALKER
1020}
1021
1022
1023
1024
1025
1026GRL_ANNOTATE_IGC_DO_NOT_SPILL
1027__attribute__((reqd_work_group_size(BFS_NUM_VCONTEXTS, 1, 1)))
1028kernel void
1029begin_batchable(
1030    global struct VContextScheduler* scheduler,
1031    global struct SAHBuildGlobals* sah_globals
1032)
1033{
1034    ushort scheduler_build_offset = scheduler->batched_build_offset;
1035    ushort scheduler_num_builds   = scheduler->batched_build_count;
1036
1037    ushort num_builds = min( scheduler_num_builds, (ushort)BFS_NUM_VCONTEXTS );
1038
1039    uint num_wgs = 0;
1040
1041    ushort tid = get_local_id(0);
1042    if ( tid < num_builds )
1043    {
1044        ushort batch_index = scheduler_build_offset + tid;
1045
1046        uint num_primrefs = sah_globals[batch_index].num_primrefs;
1047
1048        // configure first vcontext for first build
1049        scheduler->contexts[tid].dispatch_primref_begin = 0;
1050        scheduler->contexts[tid].dispatch_primref_end   = num_primrefs;
1051        scheduler->contexts[tid].bvh2_root              = BVH2_GetRoot( SAHBuildGlobals_GetBVH2(&sah_globals[batch_index]) );
1052        scheduler->contexts[tid].tree_depth             = 0;
1053        scheduler->contexts[tid].batch_index            = batch_index;
1054        scheduler->vcontext_state[tid] = VCONTEXT_STATE_EXECUTING;
1055
1056        scheduler->contexts[tid].num_left = 0;
1057        scheduler->contexts[tid].num_right = 0;
1058        scheduler->contexts[tid].lr_mask   = 0;
1059
1060        num_wgs = get_num_wgs( num_primrefs, BFS_WG_SIZE );
1061
1062        scheduler->bfs_queue.wg_count[tid] = num_wgs;
1063        scheduler->bfs_queue.records[tid].batch_index = batch_index;
1064        scheduler->bfs_queue.records[tid].context_id  = tid;
1065    }
1066
1067    num_wgs = work_group_reduce_add(num_wgs);
1068
1069    if (tid == 0)
1070    {
1071        // write out build count and offset for next BFS iteration
1072        scheduler->batched_build_offset = scheduler_build_offset + num_builds;
1073        scheduler->batched_build_count  = scheduler_num_builds - num_builds;
1074
1075        // write out initial WG count and loop termination mask for command streamer to consume
1076        scheduler->batched_build_wg_count  = num_wgs;
1077        scheduler->batched_build_loop_mask = (scheduler_num_builds > num_builds) ? 1 : 0;
1078
1079        scheduler->bfs_queue.num_dispatches = num_builds;
1080    }
1081
1082    for ( uint i = get_sub_group_id(); i < num_builds; i += get_num_sub_groups() )
1083        BinInfo_init_subgroup( &scheduler->contexts[i].global_bin_info );
1084
1085    for ( uint i = get_sub_group_id(); i < num_builds; i += get_num_sub_groups() )
1086        LRBounds_init_subgroup( &scheduler->contexts[i].lr_bounds );
1087}
1088
1089
1090
1091bool is_leaf( uint num_refs )
1092{
1093    return num_refs <= TREE_ARITY;
1094}
1095
1096bool is_dfs( uint num_refs )
1097{
    return num_refs > TREE_ARITY && num_refs <= DFS_THRESHOLD;
1099}
1100
1101bool is_bfs( uint num_refs )
1102{
1103    return num_refs > DFS_THRESHOLD;
1104}
1105
1106int2 is_leaf_2( uint2 num_refs )
1107{
1108    return num_refs.xy <= TREE_ARITY;
1109}
1110int2 is_bfs_2( uint2 num_refs )
1111{
1112    return num_refs.xy > DFS_THRESHOLD;
1113}
1114
1115int2 is_dfs_2( uint2 num_refs )
1116{
1117    return num_refs.xy > TREE_ARITY && num_refs.xy <= DFS_THRESHOLD;
1118}
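
// Classification used by the scheduler: a child range with at most TREE_ARITY primrefs
// becomes a leaf, a range in (TREE_ARITY, DFS_THRESHOLD] is finished by a single DFS
// work group, and anything larger goes through another BFS binning pass.  The *_2
// variants classify a (left,right) pair of counts in one step.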
1119
1120#if 0
1121GRL_ANNOTATE_IGC_DO_NOT_SPILL
1122__attribute__((reqd_work_group_size(16, 1, 1)))
1123__attribute__((intel_reqd_sub_group_size(16)))
1124kernel void
1125sg_scheduler( global struct VContextScheduler* scheduler )
1126{
1127    local struct BFS1SpillStackEntry SLM_local_spill_stack[BFS_NUM_VCONTEXTS];
1128    local uchar SLM_context_state[BFS_NUM_VCONTEXTS];
1129    local vcontext_id_t SLM_free_list[BFS_NUM_VCONTEXTS];
1130    local vcontext_id_t SLM_exec_list[BFS_NUM_VCONTEXTS];
1131
1132
1133    varying ushort lane = get_sub_group_local_id();
1134
1135    uniform uint free_list_size = 0;
1136    uniform uint exec_list_size = 0;
1137
1138    // read context states, build lists of free and executing contexts
1139    for (varying uint i = lane; i < BFS_NUM_VCONTEXTS; i += get_sub_group_size())
1140    {
1141        uchar state = scheduler->vcontext_state[i];
1142        SLM_context_state[i] = state;
1143
1144        uniform ushort exec_mask = intel_sub_group_ballot(state == VCONTEXT_STATE_EXECUTING);
1145
1146        varying ushort prefix_exec = subgroup_bit_prefix_exclusive(exec_mask);
1147        varying ushort prefix_free = lane - prefix_exec;
1148        varying ushort exec_list_pos = exec_list_size + prefix_exec;
1149        varying ushort free_list_pos = free_list_size + prefix_free;
1150
1151        if (state == VCONTEXT_STATE_EXECUTING)
1152            SLM_exec_list[exec_list_pos] = i;
1153        else
1154            SLM_free_list[free_list_pos] = i;
1155
1156        uniform ushort num_exec = popcount(exec_mask);
1157        exec_list_size += num_exec;
1158        free_list_size += get_sub_group_size() - num_exec;
1159    }
1160
1161    uniform uint total_bfs_dispatches = 0;
1162    uniform uint total_dfs_dispatches = 0;
1163    uniform uint bfs_spill_stack_size   = 0;
1164    uniform uint total_bfs_wgs      = 0;
1165
1166    // process executing context.  accumulate bfs/dfs dispatches and free-list entries
1167    for (uint i = 0; i < exec_list_size; i+= get_sub_group_size() )
1168    {
1169        varying ushort num_dfs_dispatches     = 0;
1170        varying ushort num_bfs_spills         = 0;
1171
1172        varying ushort num_bfs_children;
1173        varying ushort context_id;
1174        struct VContext* context;
1175        varying uint num_left      ;
1176        varying uint num_right     ;
1177        varying uint primref_begin ;
1178        varying uint primref_end   ;
1179        varying uint depth         ;
1180
1181        bool active_lane = (i + lane) < exec_list_size;
1182        if ( active_lane )
1183        {
1184            context_id = SLM_exec_list[i + lane];
1185            context    = &scheduler->contexts[context_id];
1186
1187            num_left      = context->num_left;
1188            num_right     = context->num_right;
1189            primref_begin = context->dispatch_primref_begin;
1190            primref_end   = context->dispatch_primref_end;
1191            depth         = context->tree_depth;
1192
1193            // get dispatch counts
1194
1195            num_dfs_dispatches = is_dfs(num_left) + is_dfs(num_right);
1196            num_bfs_children = is_bfs(num_left) + is_bfs(num_right);
1197            num_bfs_spills = (num_bfs_children == 2) ? 1 : 0;
1198        }
1199
1200        // allocate space for DFS, BFS dispatches, and BFS spills
1201        varying uint dfs_pos               = total_dfs_dispatches + sub_group_scan_exclusive_add(num_dfs_dispatches);
1202        varying ushort mask_bfs_spills     = intel_sub_group_ballot(num_bfs_children & 2); // spill if #children == 2
1203        varying ushort mask_bfs_dispatches = intel_sub_group_ballot(num_bfs_children & 3); // dispatch if #children == 1 or 2
1204        varying uint bfs_spill_pos         = bfs_spill_stack_size + subgroup_bit_prefix_exclusive(mask_bfs_spills);
1205        varying uint bfs_dispatch_pos      = total_bfs_dispatches + subgroup_bit_prefix_exclusive(mask_bfs_dispatches);
1206
1207        total_dfs_dispatches += sub_group_reduce_add(num_dfs_dispatches);
1208        bfs_spill_stack_size += popcount(mask_bfs_spills);
1209        total_bfs_dispatches += popcount(mask_bfs_dispatches);
1210
1211        varying uint num_bfs_wgs = 0;
1212        if (active_lane)
1213        {
1214            if (num_dfs_dispatches)
1215            {
1216                if (is_dfs(num_left))
1217                {
1218                    scheduler->dfs_queue.records[dfs_pos].primref_base = primref_begin;
1219                    scheduler->dfs_queue.records[dfs_pos].num_primrefs = num_left;
1220                    scheduler->dfs_queue.records[dfs_pos].bvh2_base = context->left_bvh2_root;
1221                    scheduler->dfs_queue.records[dfs_pos].tree_depth = depth + 1;
1222                    dfs_pos++;
1223                }
1224                if (is_dfs(num_right))
1225                {
1226                    scheduler->dfs_queue.records[dfs_pos].primref_base = primref_begin + num_left;
1227                    scheduler->dfs_queue.records[dfs_pos].num_primrefs = num_right;
1228                    scheduler->dfs_queue.records[dfs_pos].bvh2_base = context->right_bvh2_root;
1229                    scheduler->dfs_queue.records[dfs_pos].tree_depth = depth + 1;
1230                }
1231            }
1232
1233            uint num_bfs_children = is_bfs(num_left) + is_bfs(num_right);
1234            if (num_bfs_children == 2)
1235            {
1236                // spill the right child.. push an entry onto local spill stack
1237                SLM_local_spill_stack[bfs_spill_pos].primref_begin = primref_begin + num_left;
1238                SLM_local_spill_stack[bfs_spill_pos].primref_end = primref_end;
1239                SLM_local_spill_stack[bfs_spill_pos].bvh2_root = context->right_bvh2_root;
1240                SLM_local_spill_stack[bfs_spill_pos].tree_depth = depth + 1;
1241
1242                // setup BFS1 dispatch for left child
1243                context->dispatch_primref_end = primref_begin + num_left;
1244                context->bvh2_root = context->left_bvh2_root;
1245                context->tree_depth = depth + 1;
1246                num_bfs_wgs = get_num_wgs(num_left, BFS_WG_SIZE);
1247
1248                scheduler->bfs_queue.wg_count[bfs_dispatch_pos]           = num_bfs_wgs;
1249                scheduler->bfs_queue.records[bfs_dispatch_pos].context_id = context_id;
1250            }
1251            else if (num_bfs_children == 1)
1252            {
1253                // setup BFS1 dispatch for whichever child wants it
1254                if (is_bfs(num_left))
1255                {
1256                    // bfs on left child
1257                    context->dispatch_primref_end = context->dispatch_primref_begin + num_left;
1258                    context->bvh2_root = context->left_bvh2_root;
1259                    context->tree_depth = depth + 1;
1260                    num_bfs_wgs = get_num_wgs(num_left, BFS_WG_SIZE);
1261                }
1262                else
1263                {
1264                    // bfs on right child
1265                    context->dispatch_primref_begin = context->dispatch_primref_begin + num_left;
1266                    context->bvh2_root = context->right_bvh2_root;
1267                    context->tree_depth = depth + 1;
1268                    num_bfs_wgs = get_num_wgs(num_right, BFS_WG_SIZE);
1269                }
1270
1271                scheduler->bfs_queue.wg_count[bfs_dispatch_pos]           = num_bfs_wgs;
1272                scheduler->bfs_queue.records[bfs_dispatch_pos].context_id = context_id;
1273            }
1274            else
1275            {
1276                // no bfs dispatch.. this context is now free
1277                SLM_context_state[context_id] = VCONTEXT_STATE_UNALLOCATED;
1278            }
1279        }
1280
1281        // count bfs work groups
1282        total_bfs_wgs += sub_group_reduce_add(num_bfs_wgs);
1283
1284        // add newly deallocated contexts to the free list
1285        uniform uint free_mask = intel_sub_group_ballot( active_lane && num_bfs_children == 0);
1286        varying uint free_list_pos = free_list_size + subgroup_bit_prefix_exclusive(free_mask);
1287        free_list_size += popcount(free_mask);
1288
1289        if ( free_mask & (1<<lane) )
1290            SLM_free_list[free_list_pos] = context_id;
1291
1292    }
1293
1294    barrier(CLK_LOCAL_MEM_FENCE);
1295
1296    // if we have more free contexts than spills, read additional spills from the scheduler's spill stack
1297    uniform uint memory_spill_stack_size = scheduler->bfs2_spill_stack.size;
1298
1299    if(bfs_spill_stack_size < free_list_size && memory_spill_stack_size > 0 )
1300    {
1301        uniform uint read_count = min(free_list_size - bfs_spill_stack_size, memory_spill_stack_size);
1302
1303        for (varying uint i = lane; i < read_count; i+= get_sub_group_size())
1304            SLM_local_spill_stack[bfs_spill_stack_size + i] = scheduler->bfs2_spill_stack.entries[memory_spill_stack_size - 1 - i];
1305
1306        bfs_spill_stack_size += read_count;
1307        memory_spill_stack_size -= read_count;
1308    }
1309
1310    // steal pending BFS work and assign it to free contexts
1311    uniform uint num_steals = min(bfs_spill_stack_size, free_list_size);
1312
1313    for (uniform uint i = 0; i < num_steals; i += get_sub_group_size())
1314    {
1315        varying uint num_bfs_wgs = 0;
1316
1317        if (i + lane < num_steals)
1318        {
1319            uint context_id = SLM_free_list[i+lane];
1320            struct VContext* context = &scheduler->contexts[context_id];
1321            struct BFS1SpillStackEntry entry = SLM_local_spill_stack[i+lane];
1322
1323            context->dispatch_primref_begin = entry.primref_begin;
1324            context->dispatch_primref_end = entry.primref_end;
1325            context->bvh2_root = entry.bvh2_root;
1326            context->tree_depth = entry.tree_depth;
1327
1328            num_bfs_wgs = get_num_wgs(entry.primref_end - entry.primref_begin, BFS_WG_SIZE);
1329
1330            scheduler->bfs_queue.wg_count[total_bfs_dispatches + i + lane] = num_bfs_wgs;
1331            scheduler->bfs_queue.records[total_bfs_dispatches + i + lane].context_id = context_id;
1332
1333            SLM_context_state[context_id] = VCONTEXT_STATE_EXECUTING;
1334        }
1335
1336        total_bfs_wgs += sub_group_reduce_add( num_bfs_wgs );
1337    }
1338
1339    total_bfs_dispatches += num_steals;
1340
1341    //  write out excess spills to global spill stack
1342    uniform uint extra_spills = bfs_spill_stack_size - num_steals;
1343    for (varying uint i = lane; i < extra_spills; i += get_sub_group_size())
1344    {
1345        scheduler->bfs2_spill_stack.entries[memory_spill_stack_size + i] = SLM_local_spill_stack[num_steals+i];
1346    }
1347
1348
1349    // write out modified context states
1350    for ( varying uint i = lane; i < BFS_NUM_VCONTEXTS; i += get_sub_group_size())
1351        scheduler->vcontext_state[i] = SLM_context_state[i];
1352
1353
1354    if (get_local_id(0) == 0)
1355    {
1356        // write out new memory stack size
1357        scheduler->bfs2_spill_stack.size = memory_spill_stack_size + extra_spills;
1358
1359        // store workgroup counters
1360        scheduler->bfs_queue.num_dispatches = total_bfs_dispatches;
1361        scheduler->num_bfs_wgs = total_bfs_wgs;
1362        scheduler->num_dfs_wgs = total_dfs_dispatches;
1363    }
1364
1365  //  barrier(CLK_GLOBAL_MEM_FENCE); // make memory writes globally visible// lsc flush ... driver now does these as part of COMPUTE_WALKER
1366}
1367#endif
1368
1369#define SCHEDULER_SG_SIZE 16
1370#define SCHEDULER_WG_SIZE BFS_NUM_VCONTEXTS
1371#define SCHEDULER_NUM_SGS (SCHEDULER_WG_SIZE / SCHEDULER_SG_SIZE)
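
// The scheduler kernel below runs as a single work group with one work item per virtual
// context (SCHEDULER_WG_SIZE == BFS_NUM_VCONTEXTS), organized into SIMD16 sub-groups.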
1372
1373
1374struct BFSDispatchArgs get_bfs_args_from_record_batchable(
1375    struct BFSDispatchRecord* record,
1376    global struct VContextScheduler* scheduler,
1377    global struct SAHBuildGlobals* globals_buffer );
1378
1379GRL_ANNOTATE_IGC_DO_NOT_SPILL
1380__attribute__((reqd_work_group_size(SCHEDULER_WG_SIZE, 1, 1)))
1381__attribute__((intel_reqd_sub_group_size(SCHEDULER_SG_SIZE)))
1382kernel void
1383scheduler(global struct VContextScheduler* scheduler, global struct SAHBuildGlobals* sah_globals )
1384{
1385    local struct BFS1SpillStackEntry SLM_local_spill_stack[2 * BFS_NUM_VCONTEXTS];
1386    local uint SLM_local_spill_stack_size;
1387    local uint SLM_dfs_dispatch_count;
1388
1389    if (get_local_id(0) == 0)
1390    {
1391        SLM_local_spill_stack_size = 0;
1392        SLM_dfs_dispatch_count = 0;
1393    }
1394
1395    uint context_id = get_local_id(0);
1396    uint state = scheduler->vcontext_state[context_id];
1397    uint initial_state = state;
1398
1399    uint batch_index = 0;
1400    global struct VContext* context = &scheduler->contexts[context_id];
1401
1402    barrier(CLK_LOCAL_MEM_FENCE);
1403
1404
1405    uint global_spill_stack_size = scheduler->bfs2_spill_stack.size;
1406
1407
1408    if (state == VCONTEXT_STATE_EXECUTING)
1409    {
1410        uint left_bvh2_root;
1411        uint right_bvh2_root;
1412
1413        uint num_left = context->num_left;
1414        uint num_right = context->num_right;
1415
1416        uint primref_begin = context->dispatch_primref_begin;
1417        uint primref_end = context->dispatch_primref_end;
1418
1419        uint depth = context->tree_depth;
1420        uint batch_index = context->batch_index;
1421
1422        struct BFSDispatchRecord record;
1423        record.context_id = context_id;
1424        record.batch_index = context->batch_index;
1425
1426        struct BFSDispatchArgs args = get_bfs_args_from_record_batchable( &record, scheduler, sah_globals);
1427
1428        // do cleanup of bfs_pass2
1429        {
1430            // compute geom bounds
1431            struct AABB3f left_geom_bounds;
1432            struct AABB3f right_geom_bounds;
1433            struct AABB3f left_centroid_bounds;
1434            struct AABB3f right_centroid_bounds;
1435            uint2 lr_counts = (uint2)(num_left, num_right);
1436
1437            {
1438                left_centroid_bounds    = LRBounds_get_left_centroid( &context->lr_bounds );
1439                left_geom_bounds        = LRBounds_get_left_geom(  &context->lr_bounds );
1440                right_centroid_bounds   = LRBounds_get_right_centroid( &context->lr_bounds );
1441                right_geom_bounds       = LRBounds_get_right_geom( &context->lr_bounds );
1442            }
1443
1444            int2 v_is_leaf = is_leaf_2( lr_counts );
1445            int2 v_is_dfs  = is_dfs_2( lr_counts );
1446            int2 v_is_bfs  = is_bfs_2( lr_counts );
1447            uint left_mask  = args.do_mask_processing ? context->lr_mask & 0xff : 0xff;
1448            uint right_mask = args.do_mask_processing ? (context->lr_mask & 0xff00) >> 8 : 0xff;
1449
1450            // how many BVH2 nodes do we need to allocate?  For DFS, we need to pre-allocate the full subtree
1451            uint2 lr_node_counts = select( (uint2)(1,1), (2*lr_counts-1), v_is_dfs );
1452            uint left_node_count = lr_node_counts.x;
1453            uint right_node_count = lr_node_counts.y;
1454
1455            // allocate the nodes
1456            uint first_node = BVH2_AllocateNodes( args.bvh2, left_node_count + right_node_count );
1457
1458            // point our root node at its children
1459            left_bvh2_root  = first_node;
1460            right_bvh2_root = first_node + left_node_count;
1461
1462            // store combined geom bounds in the root node's AABB; we previously stored centroid bounds there,
1463            //   but node creation requires geom bounds
1464            struct AABB3f geom_bounds = left_geom_bounds;
1465            AABB3f_extend(&geom_bounds, &right_geom_bounds);
1466            BVH2_WriteInnerNode( args.bvh2, args.bvh2_root, &geom_bounds, (uint2)(left_bvh2_root,right_bvh2_root), left_mask | right_mask );
1467
1468//            printf(" node: %u  mask: %x\n", args.bvh2_root, left_mask|right_mask );
1469
1470            // store the appropriate AABBs in the child nodes
1471            //   - BFS passes need centroid bounds
1472            //   - DFS passes need geom bounds
1473            //  Here we also write leaf connectivity information (prim start+count)
1474            //   this will be overwritten later if we are creating an inner node
1475            struct AABB3f left_box, right_box;
1476            left_box  = AABB3f_select( left_geom_bounds,  left_centroid_bounds,  v_is_bfs.xxx );
1477            right_box = AABB3f_select( right_geom_bounds, right_centroid_bounds, v_is_bfs.yyy );
1478
1479            uint left_start  = primref_begin;
1480            uint right_start = primref_begin + num_left;
1481            BVH2_WriteLeafNode( args.bvh2, left_bvh2_root,  &left_box, left_start,  num_left, left_mask );
1482            BVH2_WriteLeafNode( args.bvh2, right_bvh2_root, &right_box, right_start, num_right, right_mask );
1483
1484            // make input and output primref index buffers consistent in the event we're creating a leaf
1485            //   There should only ever be one leaf created, otherwise we'd have done a DFS pass sooner
1486            if (any( v_is_leaf.xy ))
1487            {
1488                uint start    = v_is_leaf.x ? left_start : right_start;
1489                uint num_refs = v_is_leaf.x ? num_left : num_right;
1490
1491                for(uint i = 0; i < num_refs; i++)
1492                {
1493                    args.primref_index_in[start + i] = args.primref_index_out[start + i];
1494                }
1495            }
1496        }
1497
1498        // when BFS2 finishes, we need to dispatch two child tasks.
1499        //   DFS dispatches can run free and do not need a context
1500        //   BFS dispatches need a context.
1501        //  In the case where both of the child nodes are BFS, the current context can immediately run one of the child dispatches
1502        //   and the other is spilled for an unallocated context to pick up
1503
1504        uint num_dfs_dispatches = is_dfs(num_left) + is_dfs(num_right);
1505        if (num_dfs_dispatches)
1506        {
1507            uint dfs_pos = atomic_add_local(&SLM_dfs_dispatch_count, num_dfs_dispatches);
1508            if (is_dfs(num_left))
1509            {
1510                scheduler->dfs_queue.records[dfs_pos].primref_base = primref_begin;
1511                scheduler->dfs_queue.records[dfs_pos].num_primrefs = num_left;
1512                scheduler->dfs_queue.records[dfs_pos].bvh2_base = left_bvh2_root;
1513                scheduler->dfs_queue.records[dfs_pos].tree_depth = depth + 1;
1514                scheduler->dfs_queue.records[dfs_pos].batch_index = batch_index;
1515                dfs_pos++;
1516            }
1517            if (is_dfs(num_right))
1518            {
1519                scheduler->dfs_queue.records[dfs_pos].primref_base = primref_begin + num_left;
1520                scheduler->dfs_queue.records[dfs_pos].num_primrefs = num_right;
1521                scheduler->dfs_queue.records[dfs_pos].bvh2_base = right_bvh2_root;
1522                scheduler->dfs_queue.records[dfs_pos].tree_depth = depth + 1;
1523                scheduler->dfs_queue.records[dfs_pos].batch_index = batch_index;
1524            }
1525        }
1526
1527        uint num_bfs_children = is_bfs(num_left) + is_bfs(num_right);
1528        if (num_bfs_children)
1529        {
1530            uint place = atomic_add_local(&SLM_local_spill_stack_size, num_bfs_children);
1531            if (is_bfs(num_left))
1532            {
1533                SLM_local_spill_stack[place].primref_begin = primref_begin;
1534                SLM_local_spill_stack[place].primref_end = primref_begin + num_left;
1535                SLM_local_spill_stack[place].bvh2_root = left_bvh2_root;
1536                SLM_local_spill_stack[place].tree_depth = depth + 1;
1537                SLM_local_spill_stack[place].batch_index = batch_index;
1538                place++;
1539            }
1540            if (is_bfs(num_right))
1541            {
1542                SLM_local_spill_stack[place].primref_begin = primref_begin + num_left;
1543                SLM_local_spill_stack[place].primref_end = primref_end;
1544                SLM_local_spill_stack[place].bvh2_root = right_bvh2_root;
1545                SLM_local_spill_stack[place].tree_depth = depth + 1;
1546                SLM_local_spill_stack[place].batch_index = batch_index;
1547                place++;
1548            }
1549        }
1550    }
1551
1552    barrier(CLK_LOCAL_MEM_FENCE);
1553
1554    uint local_spill_stack_size = SLM_local_spill_stack_size;
1555
1556    struct BFS1SpillStackEntry entry;
1557    state = VCONTEXT_STATE_UNALLOCATED;
1558    if (context_id < local_spill_stack_size)
1559    {
1560        // pull BFS work from the local spill stack if there's enough work there
1561        entry = SLM_local_spill_stack[context_id];
1562        state = VCONTEXT_STATE_EXECUTING;
1563    }
1564    else if ((context_id - local_spill_stack_size) < (global_spill_stack_size))
1565    {
1566        // if there isn't enough work on the local stack, consume from the global one
1567        uint global_pos = (global_spill_stack_size - 1) - (context_id - local_spill_stack_size);
1568        entry = scheduler->bfs2_spill_stack.entries[global_pos];
1569        state = VCONTEXT_STATE_EXECUTING;
1570    }
1571
1572    // contexts which received work set themselves up for the next BFS1 dispatch
1573    uint num_bfs_wgs = 0;
1574    uint num_bfs_dispatches = 0;
1575    if (state == VCONTEXT_STATE_EXECUTING)
1576    {
1577        context->dispatch_primref_begin = entry.primref_begin;
1578        context->dispatch_primref_end = entry.primref_end;
1579        context->bvh2_root = entry.bvh2_root;
1580        context->tree_depth = entry.tree_depth;
1581        context->batch_index = entry.batch_index;
1582
1583        context->num_left = 0;
1584        context->num_right = 0;
1585        context->lr_mask = 0;
1586
1587        batch_index = entry.batch_index;
1588        num_bfs_wgs = get_num_wgs(entry.primref_end - entry.primref_begin, BFS_WG_SIZE);
1589        num_bfs_dispatches = 1;
1590    }
1591
1592
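    // Reconcile the global spill stack:
    //  - if the local spill stack overflowed the context pool, append the overflow (no global
    //    entries were consumed in that case);
    //  - otherwise shrink the global stack by the number of entries the contexts just consumed.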
1593    if (local_spill_stack_size > BFS_NUM_VCONTEXTS)
1594    {
1595        // write out additional spills if we produced more work than we can consume
1596        uint excess_spills = local_spill_stack_size - BFS_NUM_VCONTEXTS;
1597        uint write_base = global_spill_stack_size;
1598        uint lid = get_local_id(0);
1599        if (lid < excess_spills)
1600            scheduler->bfs2_spill_stack.entries[write_base + lid] = SLM_local_spill_stack[BFS_NUM_VCONTEXTS + lid];
1601
1602        if (lid == 0)
1603            scheduler->bfs2_spill_stack.size = global_spill_stack_size + excess_spills;
1604    }
1605    else if (global_spill_stack_size > 0)
1606    {
1607        // otherwise, if we consumed any spills from the global stack, update the stack size
1608        if (get_local_id(0) == 0)
1609        {
1610            uint global_spills_consumed = min(global_spill_stack_size, BFS_NUM_VCONTEXTS - local_spill_stack_size);
1611            scheduler->bfs2_spill_stack.size = global_spill_stack_size - global_spills_consumed;
1612        }
1613    }
1614
1615
1616    // Do various WG reductions.  The code below is a hand-written version of the following:
1617    //
1618    // uint bfs_dispatch_queue_pos     = work_group_scan_exclusive_add( num_bfs_dispatches );
1619    // uint reduce_num_bfs_wgs         = work_group_reduce_add(num_bfs_wgs);
1620    // uint reduce_num_bfs_dispatches  = work_group_reduce_add(num_bfs_dispatches);
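    // Each sub-group writes its partial sums to SLM; after one barrier every sub-group re-reads
    // all partials, so the cross-sub-group combine is replicated in every sub-group instead of
    // being serialized through a single thread.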
1621    uint bfs_dispatch_queue_pos;
1622    uint reduce_num_bfs_dispatches;
1623    uint reduce_num_bfs_wgs;
1624    local uint partial_dispatches[SCHEDULER_WG_SIZE / SCHEDULER_SG_SIZE];
1625    local uint partial_wgs[SCHEDULER_WG_SIZE / SCHEDULER_SG_SIZE];
1626    {
1627        partial_dispatches[get_sub_group_id()] = sub_group_reduce_add(num_bfs_dispatches);
1628        partial_wgs[get_sub_group_id()] = sub_group_reduce_add(num_bfs_wgs);
1629
1630        uint sg_prefix = sub_group_scan_exclusive_add(num_bfs_dispatches);
1631
1632        uint prefix_dispatches = 0;
1633        uint total_dispatches = 0;
1634        uint total_wgs = 0;
1635        ushort lane = get_sub_group_local_id();
1636
1637        barrier(CLK_LOCAL_MEM_FENCE);
1638
1639        for (ushort i = 0; i < SCHEDULER_NUM_SGS; i += SCHEDULER_SG_SIZE) // this loop is intended to be fully unrolled after compilation
1640        {
1641            uint p_dispatch = partial_dispatches[i + lane];
1642            uint p_wg = partial_wgs[i + lane];
1643
1644            prefix_dispatches += (i + lane < get_sub_group_id()) ? p_dispatch : 0;
1645            total_dispatches += p_dispatch;
1646            total_wgs += p_wg;
1647        }
1648
1649        bfs_dispatch_queue_pos = sg_prefix + sub_group_reduce_add(prefix_dispatches);
1650        reduce_num_bfs_dispatches = sub_group_reduce_add(total_dispatches);
1651        reduce_num_bfs_wgs = sub_group_reduce_add(total_wgs);
1652    }
1653
1654    // insert records into BFS queue
1655    if (num_bfs_dispatches)
1656    {
1657        scheduler->bfs_queue.wg_count[bfs_dispatch_queue_pos] = num_bfs_wgs;
1658        scheduler->bfs_queue.records[bfs_dispatch_queue_pos].context_id = context_id;
1659        scheduler->bfs_queue.records[bfs_dispatch_queue_pos].batch_index = batch_index;
1660    }
1661
1662
1663    // store modified vcontext state if it has changed
1664    if (initial_state != state)
1665        scheduler->vcontext_state[context_id] = state;
1666
1667
1668    // store workgroup counters
1669    if (get_local_id(0) == 0)
1670    {
1671        scheduler->bfs_queue.num_dispatches = reduce_num_bfs_dispatches;
1672        scheduler->num_bfs_wgs = reduce_num_bfs_wgs;
1673        scheduler->num_dfs_wgs = SLM_dfs_dispatch_count;
1674    }
1675
1676    const uint contexts_to_clear = min( (uint)BFS_NUM_VCONTEXTS, (uint)(local_spill_stack_size+global_spill_stack_size) );
1677
1678    for ( uint i = get_sub_group_id(); i < contexts_to_clear; i += get_num_sub_groups() )
1679        BinInfo_init_subgroup( &scheduler->contexts[i].global_bin_info );
1680
1681    for ( uint i = get_sub_group_id(); i < contexts_to_clear; i += get_num_sub_groups() )
1682        LRBounds_init_subgroup( &scheduler->contexts[i].lr_bounds );
1683}
1684
1685#if 0
1686uint record_search( struct BFSDispatchRecord* record_out, global struct BFSDispatchQueue* queue )
1687{
1688    uint group = get_group_id(0);
1689    ushort lane = get_sub_group_local_id();
1690    uint num_dispatches = queue->num_dispatches;
1691    uint base = 0;
1692    for (uint i = 0; i < num_dispatches; i += get_sub_group_size())
1693    {
1694        uint counts = intel_sub_group_block_read(&queue->wg_count[i]);
1695
1696        for (uint j = 0; j < get_sub_group_size(); j++)
1697        {
1698            uint n = sub_group_broadcast(counts, j);
1699            if (group < n)
1700            {
1701                *record_out = queue->records[i + j];
1702                return group;
1703            }
1704            group -= n;
1705        }
1706    }
1707
1708    return 0; // NOTE: unreachable in practice
1709}
1710#endif
1711
1712
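// record_search: map this workgroup's flat group id onto the BFS dispatch queue.  Dispatch i owns
// wg_count[i] consecutive workgroups; the loop walks the queue one sub-group-sized chunk at a
// time and uses an exclusive scan of the counts to find the dispatch whose range contains
// get_group_id(0).  The matching record is written to record_out and the workgroup's index
// *within* that dispatch is returned.
//   Example: wg_count = {3,5,2} -> group 6 resolves to dispatch 1 with local group index 3.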
1713uint record_search(struct BFSDispatchRecord* record_out, global struct BFSDispatchQueue* queue)
1714{
1715    uint group = get_group_id(0);
1716
1717    uint num_dispatches = queue->num_dispatches;
1718
1719    uint dispatch_id = 0;
1720    uint local_id = 0;
1721    uint i = 0;
1722    do
1723    {
1724        uint counts = intel_sub_group_block_read(&queue->wg_count[i]);
1725        uint prefix = sub_group_scan_exclusive_add(counts);
1726
1727        uint g = group - prefix;
1728        uint ballot = intel_sub_group_ballot(g < counts);
1729        if (ballot)
1730        {
1731            uint lane = ctz(ballot);
1732            dispatch_id = i + lane;
1733            local_id = intel_sub_group_shuffle(g, lane);
1734            break;
1735        }
1736
1737        group -= sub_group_broadcast(prefix + counts, get_sub_group_size() - 1);
1738
1739        i += get_sub_group_size();
1740    } while (i < num_dispatches);
1741
1742
1743    *record_out = queue->records[dispatch_id];
1744    return local_id;
1745}
1746
1747
1748
1749
1750struct BFSDispatchArgs get_bfs_args(struct BFSDispatchRecord* record, global struct VContextScheduler* scheduler, global struct SAHBuildGlobals* globals, uint local_group_id)
1751{
1752    uint context_id = record->context_id;
1753    struct VContext* context = &scheduler->contexts[context_id];
1754    bool odd_pass = context->tree_depth & 1;
1755
1756    struct BFSDispatchArgs args;
1757    args.scheduler              = scheduler;
1758    args.primref_index_in       = SAHBuildGlobals_GetPrimrefIndices_In( globals, odd_pass );
1759    args.primref_index_out      = SAHBuildGlobals_GetPrimrefIndices_Out( globals, odd_pass );
1760    args.primref_buffer         = SAHBuildGlobals_GetPrimrefs( globals );
1761    args.wg_primref_begin       = context->dispatch_primref_begin + local_group_id * BFS_WG_SIZE;
1762    args.wg_primref_end         = min( args.wg_primref_begin + BFS_WG_SIZE, context->dispatch_primref_end );
1763    args.dispatch_primref_begin = context->dispatch_primref_begin;
1764    args.dispatch_primref_end   = context->dispatch_primref_end;
1765    args.context_id             = context_id;
1766    args.context                = &scheduler->contexts[context_id];
1767    args.num_wgs                = ((args.dispatch_primref_end - args.dispatch_primref_begin) + BFS_WG_SIZE - 1) / BFS_WG_SIZE;
1768    args.bvh2_root              = context->bvh2_root;
1769    args.bvh2 = SAHBuildGlobals_GetBVH2( globals );
1770    args.global_num_primrefs = SAHBuildGlobals_GetTotalPrimRefs( globals );
1771    args.do_mask_processing = SAHBuildGlobals_NeedMasks( globals );
1772    return args;
1773}
1774
1775struct BFSDispatchArgs get_bfs_args_queue( global struct BFSDispatchQueue* queue,
1776                                           global struct VContextScheduler* scheduler,
1777                                           global struct SAHBuildGlobals* globals )
1778{
1779
1780    // TODO_OPT:  Load this entire prefix array into SLM instead of searching..
1781    //    Or use sub-group ops
1782
1783    struct BFSDispatchRecord record;
1784    uint local_group_id = record_search(&record, queue);
1785
1786    return get_bfs_args(&record, scheduler, globals, local_group_id);
1787}
1788
1789
1790struct BFSDispatchArgs get_bfs_args_from_record( struct BFSDispatchRecord* record,
1791                                           global struct VContextScheduler* scheduler,
1792                                           global struct SAHBuildGlobals* globals )
1793{
1794    return get_bfs_args(record, scheduler, globals, 0);
1795}
1796
1797
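// The *_batchable variants serve several builds from a single dispatch: record.batch_index selects
// the SAHBuildGlobals entry (and therefore the primref, index and BVH2 buffers) that this
// workgroup operates on.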
1798struct BFSDispatchArgs get_bfs_args_batchable(
1799    global struct BFSDispatchQueue* queue,
1800    global struct VContextScheduler* scheduler,
1801    global struct SAHBuildGlobals* globals_buffer )
1802{
1803
1804    // TODO_OPT:  Load this entire prefix array into SLM instead of searching..
1805    //    Or use sub-group ops
1806
1807    struct BFSDispatchRecord record;
1808    uint local_group_id = record_search(&record, queue);
1809
1810    global struct SAHBuildGlobals* globals = globals_buffer + record.batch_index;
1811
1812    return get_bfs_args(&record, scheduler, globals, local_group_id);
1813}
1814
1815
1816struct BFSDispatchArgs get_bfs_args_from_record_batchable(
1817    struct BFSDispatchRecord* record,
1818    global struct VContextScheduler* scheduler,
1819    global struct SAHBuildGlobals* globals_buffer )
1820{
1821    global struct SAHBuildGlobals* globals = globals_buffer + record->batch_index;
1822
1823    return get_bfs_args(record, scheduler, globals, 0);
1824}
1825
1826struct BFSDispatchArgs get_bfs_args_initial( global struct VContextScheduler* scheduler, global struct SAHBuildGlobals* globals )
1827{
1828    uint context_id = 0;
1829
1830    uint num_refs = SAHBuildGlobals_GetTotalPrimRefs( globals );
1831
1832    struct BFSDispatchArgs args;
1833    args.scheduler = scheduler;
1834    args.primref_index_in   = SAHBuildGlobals_GetPrimrefIndices_In( globals, false );
1835    args.primref_index_out  = SAHBuildGlobals_GetPrimrefIndices_Out( globals, false );
1836    args.primref_buffer     = SAHBuildGlobals_GetPrimrefs( globals );
1837    args.wg_primref_begin   = get_group_id(0) * BFS_WG_SIZE;
1838    args.wg_primref_end     = min( args.wg_primref_begin + BFS_WG_SIZE, num_refs );
1839    args.dispatch_primref_begin = 0;
1840    args.dispatch_primref_end   = num_refs;
1841    args.context_id = context_id;
1842    args.context = &scheduler->contexts[context_id];
1843    args.num_wgs = ((args.dispatch_primref_end - args.dispatch_primref_begin) + BFS_WG_SIZE - 1) / BFS_WG_SIZE;
1844    args.bvh2 = SAHBuildGlobals_GetBVH2( globals );
1845    args.bvh2_root = BVH2_GetRoot( args.bvh2 );
1846    args.global_num_primrefs = SAHBuildGlobals_GetTotalPrimRefs( globals );
1847    args.do_mask_processing = SAHBuildGlobals_NeedMasks(globals);
1848    return args;
1849}
1850
1851
1852inline void BinMapping_init( struct BinMapping* binMapping, struct AABB3f* centBounds, const uint bins )
1853{
1854    const float4 eps = 1E-34f;
1855    const float4 omega = 1E+34f;
1856    float3 l = AABB3f_load_lower( centBounds );
1857    float3 u = AABB3f_load_upper( centBounds );
1858    float4 diag;
1859    diag.xyz = max( eps.xyz, u - l );
1860    diag.w = 0;
1861    float4 scale = (float4)(0.99f * (float)bins) / diag;
1862    scale = select( (float4)(0.0f), scale, diag > eps );
1863    scale = select( (float4)(0.0f), scale, diag < omega );
1864    binMapping->scale = scale;
1865    binMapping->ofs.xyz = l.xyz;
1866    binMapping->ofs.w = 0;
1867}
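
// The centroid bounds fed to BinMapping_init are accumulated from (lower + upper), i.e. twice the
// true centroid (see DO_BFS_pass2), so a primref's bin in dimension d is evaluated exactly as in
// is_left():  convert_uint_rtz( ((lower[d] + upper[d]) - ofs[d]) * scale[d] ).
// The 0.99f factor keeps the largest centroid strictly below 'bins', so bin indices stay in
// [0, bins-1].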
1868
1869
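// getBestSplit packs each lane's candidate into a 64-bit key:  (sah bits << 32) | (bin ID << 2) | axis.
// The sah costs are non-negative (and forced to +inf for invalid candidates), so their IEEE bit
// patterns order the same way as unsigned integers; a single sub_group_reduce_min over the packed
// keys therefore selects the lowest-cost (sah, bin, axis) triple, which BinInfo_reduce unpacks below.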
1870inline ulong getBestSplit( float3 sah, uint ID, const float4 scale, const ulong defaultSplit )
1871{
1872    ulong splitX = (((ulong)as_uint( sah.x )) << 32) | ((uint)ID << 2) | 0;
1873    ulong splitY = (((ulong)as_uint( sah.y )) << 32) | ((uint)ID << 2) | 1;
1874    ulong splitZ = (((ulong)as_uint( sah.z )) << 32) | ((uint)ID << 2) | 2;
1875    /* ignore zero sized dimensions */
1876    splitX = select( splitX, defaultSplit, (ulong)(scale.x == 0) );
1877    splitY = select( splitY, defaultSplit, (ulong)(scale.y == 0) );
1878    splitZ = select( splitZ, defaultSplit, (ulong)(scale.z == 0) );
1879    ulong bestSplit = min( min( splitX, splitY ), splitZ );
1880    bestSplit = sub_group_reduce_min( bestSplit );
1881    return bestSplit;
1882}
1883
1884
1885
1886inline float left_to_right_area16( struct AABB3f* low )
1887{
1888    struct AABB3f low_prefix = AABB3f_sub_group_scan_exclusive_min_max( low );
1889    return halfArea_AABB3f( &low_prefix );
1890}
1891
1892inline uint left_to_right_counts16( uint low )
1893{
1894    return sub_group_scan_exclusive_add( low );
1895}
1896
1897inline float right_to_left_area16( struct AABB3f* low )
1898{
1899    const uint subgroupLocalID = get_sub_group_local_id();
1900    const uint subgroup_size = get_sub_group_size();
1901    const uint ID = subgroup_size - 1 - subgroupLocalID;
1902    struct AABB3f low_reverse = AABB3f_sub_group_shuffle( low, ID );
1903    struct AABB3f low_prefix = AABB3f_sub_group_scan_inclusive_min_max( &low_reverse );
1904    const float low_area = intel_sub_group_shuffle( halfArea_AABB3f( &low_prefix ), ID );
1905    return low_area;
1906}
1907
1908inline uint right_to_left_counts16( uint low )
1909{
1910    const uint subgroupLocalID = get_sub_group_local_id();
1911    const uint subgroup_size = get_sub_group_size();
1912    const uint ID = subgroup_size - 1 - subgroupLocalID;
1913    const uint low_reverse = intel_sub_group_shuffle( low, ID );
1914    const uint low_prefix = sub_group_scan_inclusive_add( low_reverse );
1915    return intel_sub_group_shuffle( low_prefix, ID );
1916}
1917
1918inline float2 left_to_right_area32( struct AABB3f* low, struct AABB3f* high )
1919{
1920    struct AABB3f low_prefix = AABB3f_sub_group_scan_exclusive_min_max( low );
1921    struct AABB3f low_reduce = AABB3f_sub_group_reduce( low );
1922    struct AABB3f high_prefix = AABB3f_sub_group_scan_exclusive_min_max( high );
1923    AABB3f_extend( &high_prefix, &low_reduce );
1924    const float low_area = halfArea_AABB3f( &low_prefix );
1925    const float high_area = halfArea_AABB3f( &high_prefix );
1926    return (float2)(low_area, high_area);
1927}
1928
1929inline uint2 left_to_right_counts32( uint low, uint high )
1930{
1931    const uint low_prefix = sub_group_scan_exclusive_add( low );
1932    const uint low_reduce = sub_group_reduce_add( low );
1933    const uint high_prefix = sub_group_scan_exclusive_add( high );
1934    return (uint2)(low_prefix, low_reduce + high_prefix);
1935}
1936
1937inline float2 right_to_left_area32( struct AABB3f* low, struct AABB3f* high )
1938{
1939    const uint subgroupLocalID = get_sub_group_local_id();
1940    const uint subgroup_size = get_sub_group_size();
1941    const uint ID = subgroup_size - 1 - subgroupLocalID;
1942    struct AABB3f low_reverse = AABB3f_sub_group_shuffle( high, ID );
1943    struct AABB3f high_reverse = AABB3f_sub_group_shuffle( low, ID );
1944    struct AABB3f low_prefix = AABB3f_sub_group_scan_inclusive_min_max( &low_reverse );
1945    struct AABB3f low_reduce = AABB3f_sub_group_reduce( &low_reverse );
1946    struct AABB3f high_prefix = AABB3f_sub_group_scan_inclusive_min_max( &high_reverse );
1947    AABB3f_extend( &high_prefix, &low_reduce );
1948    const float low_area = intel_sub_group_shuffle( halfArea_AABB3f( &high_prefix ), ID );
1949    const float high_area = intel_sub_group_shuffle( halfArea_AABB3f( &low_prefix ), ID );
1950    return (float2)(low_area, high_area);
1951}
1952
1953inline uint2 right_to_left_counts32( uint low, uint high )
1954{
1955    const uint subgroupLocalID = get_sub_group_local_id();
1956    const uint subgroup_size = get_sub_group_size();
1957    const uint ID = subgroup_size - 1 - subgroupLocalID;
1958    const uint low_reverse = intel_sub_group_shuffle( high, ID );
1959    const uint high_reverse = intel_sub_group_shuffle( low, ID );
1960    const uint low_prefix = sub_group_scan_inclusive_add( low_reverse );
1961    const uint low_reduce = sub_group_reduce_add( low_reverse );
1962    const uint high_prefix = sub_group_scan_inclusive_add( high_reverse ) + low_reduce;
1963    return (uint2)(intel_sub_group_shuffle( high_prefix, ID ), intel_sub_group_shuffle( low_prefix, ID ));
1964}
1965
1966inline uint fastDivideBy6_uint( uint v )
1967{
1968#if 1
1969    const ulong u = (ulong)v >> 1;
1970    return (uint)((u * 0x55555556ul) >> 32);
1971#else
1972    return v / 6;
1973#endif
1974}
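
// Sanity-check sketch for the constant above, kept under #if 0 like other disabled code in this
// file.  0x55555556 == ceil(2^32 / 3), so ((v>>1) * 0x55555556) >> 32 == (v>>1)/3 == v/6 for all
// 32-bit v (e.g. v = 37: u = 18, (18 * 0x55555556) >> 32 == 6).  The kernel name is illustrative
// only and is not part of the build.
#if 0
kernel void fastDivideBy6_selftest( global uint* failures )
{
    uint v = (uint) get_global_id( 0 );
    if ( fastDivideBy6_uint( v ) != (v / 6) )
        atomic_inc( failures );
}
#endif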
1975
1976inline uint3 fastDivideBy6_uint3( uint3 v )
1977{
1978    return (uint3)(fastDivideBy6_uint( v.x ), fastDivideBy6_uint( v.y ), fastDivideBy6_uint( v.z ));
1979}
1980
1981#define SAH_LOG_BLOCK_SHIFT 2
1982
1983inline struct BFS_Split BinInfo_reduce( struct BFS_BinInfo* binInfo, const float4 scale )
1984{
1985    const uint subgroupLocalID = get_sub_group_local_id();
1986    const uint subgroup_size = get_sub_group_size();
1987
1988    struct AABB3f boundsX = BinInfo_get_AABB( binInfo, subgroupLocalID, 0 );
1989
1990    const float lr_areaX = left_to_right_area16( &boundsX );
1991    const float rl_areaX = right_to_left_area16( &boundsX );
1992
1993    struct AABB3f boundsY = BinInfo_get_AABB( binInfo, subgroupLocalID, 1 );
1994
1995    const float lr_areaY = left_to_right_area16( &boundsY );
1996    const float rl_areaY = right_to_left_area16( &boundsY );
1997
1998    struct AABB3f boundsZ = BinInfo_get_AABB( binInfo, subgroupLocalID, 2 );
1999
2000    const float lr_areaZ = left_to_right_area16( &boundsZ );
2001    const float rl_areaZ = right_to_left_area16( &boundsZ );
2002
2003    const uint3 counts = BinInfo_get_counts( binInfo, subgroupLocalID );
2004
2005    const uint lr_countsX = left_to_right_counts16( counts.x );
2006    const uint rl_countsX = right_to_left_counts16( counts.x );
2007    const uint lr_countsY = left_to_right_counts16( counts.y );
2008    const uint rl_countsY = right_to_left_counts16( counts.y );
2009    const uint lr_countsZ = left_to_right_counts16( counts.z );
2010    const uint rl_countsZ = right_to_left_counts16( counts.z );
2011
2012    const float3 lr_area = (float3)(lr_areaX, lr_areaY, lr_areaZ);
2013    const float3 rl_area = (float3)(rl_areaX, rl_areaY, rl_areaZ);
2014
2015    const uint3 lr_count = fastDivideBy6_uint3( (uint3)(lr_countsX, lr_countsY, lr_countsZ) + 6 - 1 );
2016    const uint3 rl_count = fastDivideBy6_uint3( (uint3)(rl_countsX, rl_countsY, rl_countsZ) + 6 - 1 );
2017    float3 sah = fma( lr_area, convert_float3( lr_count ), rl_area * convert_float3( rl_count ) );
2018
2019    /* first bin is invalid */
2020    sah.x = select( (float)(INFINITY), sah.x, subgroupLocalID != 0 );
2021    sah.y = select( (float)(INFINITY), sah.y, subgroupLocalID != 0 );
2022    sah.z = select( (float)(INFINITY), sah.z, subgroupLocalID != 0 );
2023
2024    const ulong defaultSplit = (((ulong)as_uint( (float)(INFINITY) )) << 32);
2025
2026    const ulong bestSplit = getBestSplit( sah, subgroupLocalID, scale, defaultSplit );
2027
2028    struct BFS_Split split;
2029    split.sah = as_float( (uint)(bestSplit >> 32) );
2030    split.dim = (uint)bestSplit & 3;
2031    split.pos = (uint)bestSplit >> 2;
2032
2033    return split;
2034}
2035
2036
2037struct BFS_BinInfoReduce3_SLM
2038{
2039    uint sah[3*BFS_NUM_BINS];
2040};
2041
2042
2043
2044inline struct BFS_Split BinInfo_reduce3( local struct BFS_BinInfoReduce3_SLM* slm, struct BFS_BinInfo* binInfo, const float4 scale )
2045{
2046    // process each bin/axis combination across sub-groups
2047    for (uint i = get_sub_group_id(); i < 3 * BFS_NUM_BINS; i += get_num_sub_groups())
2048    {
2049        uint my_bin  = i % BFS_NUM_BINS;
2050        uint my_axis = i / BFS_NUM_BINS;
2051
2052        float3 left_lower  = (float3)(INFINITY,INFINITY,INFINITY);
2053        float3 left_upper  = -left_lower;
2054        float3 right_lower = (float3)(INFINITY,INFINITY,INFINITY);
2055        float3 right_upper = -right_lower;
2056
2057        // load the other bins and assign them to the left or to the right
2058        //  of this subgroup's bin
2059        uint lane = get_sub_group_local_id();
2060        struct AABB3f sg_bins = BinInfo_get_AABB(binInfo,lane,my_axis);
2061
2062        bool is_left = (lane < my_bin);
2063        float3 lower = AABB3f_load_lower(&sg_bins);
2064        float3 upper = AABB3f_load_upper(&sg_bins);
2065
2066        float3 lower_l = select_min( lower, is_left  );
2067        float3 upper_l = select_max( upper, is_left  );
2068        float3 lower_r = select_min( lower, !is_left );
2069        float3 upper_r = select_max( upper, !is_left );
2070
2071        lower_l = sub_group_reduce_min_float3( lower_l );
2072        lower_r = sub_group_reduce_min_float3( lower_r );
2073        upper_l = sub_group_reduce_max_float3( upper_l );
2074        upper_r = sub_group_reduce_max_float3( upper_r );
2075        float3 dl = upper_l - lower_l;
2076        float3 dr = upper_r - lower_r;
2077        float area_l =  dl.x* (dl.y + dl.z) + (dl.y * dl.z);
2078        float area_r =  dr.x* (dr.y + dr.z) + (dr.y * dr.z);
2079
2080        // get the counts
2081        uint sg_bin_count = BinInfo_get_count(binInfo, lane, my_axis);
2082        uint count_l = (is_left) ?  sg_bin_count : 0;
2083        uint count_r = (is_left) ?  0 : sg_bin_count;
2084        count_l = sub_group_reduce_add(count_l);
2085        count_r = sub_group_reduce_add(count_r);
2086
2087        // compute sah
2088        count_l = fastDivideBy6_uint(count_l + 6 - 1);
2089        count_r = fastDivideBy6_uint(count_r + 6 - 1);
2090        float lr_partial = area_l * count_l;
2091        float rl_partial = area_r * count_r;
2092        float sah = lr_partial + rl_partial;
2093
2094        // first bin is invalid
2095        sah = select((float)(INFINITY), sah, my_bin != 0);
2096
2097        // ignore zero sized dimensions
2098        sah = select( sah, (float)(INFINITY), (scale.x == 0 && my_axis == 0) );
2099        sah = select( sah, (float)(INFINITY), (scale.y == 0 && my_axis == 1) );
2100        sah = select( sah, (float)(INFINITY), (scale.z == 0 && my_axis == 2) );
2101
2102        // tuck the axis into the bottom bits of the sah cost.
2103        //  The result is an integer between 0 and +inf (0x7F800000).
2104        //  If all 3 axes have infinite sah cost, we will select axis 0.
2105        slm->sah[i] = (as_uint(sah)&~0x3) | my_axis;
2106    }
2107
2108    barrier( CLK_LOCAL_MEM_FENCE );
2109
2110    // reduce split candidates down to one subgroup.
2111    //  sah is non-negative here, so its bit pattern can be compared as an unsigned integer,
2112    //   which results in a faster sub_group_reduce_min()
2113    //
2114    uint best_sah = 0xffffffff;
2115
2116    uint lid = get_sub_group_local_id();
2117    if (lid < BFS_NUM_BINS)
2118    {
2119        best_sah = slm->sah[lid];
2120        lid += BFS_NUM_BINS;
2121        best_sah = min( best_sah, slm->sah[lid] );
2122        lid += BFS_NUM_BINS;
2123        best_sah = min( best_sah, slm->sah[lid] );
2124    }
2125
2126    uint reduced_bestsah = sub_group_reduce_min( best_sah );
2127    uint best_bin = ctz(intel_sub_group_ballot(best_sah == reduced_bestsah));
2128    uint best_axis = as_uint(reduced_bestsah) & 0x3;
2129
2130    struct BFS_Split ret;
2131    ret.sah = as_float(reduced_bestsah);
2132    ret.dim = best_axis;
2133    ret.pos = best_bin;
2134    return ret;
2135}
2136
2137
2138struct BFS_BinInfoReduce_SLM
2139{
2140    struct
2141    {
2142        float sah;
2143        uint bin;
2144    } axisInfo[3];
2145};
2146
2147
2148
2149inline struct BFS_Split BinInfo_reduce2( local struct BFS_BinInfoReduce_SLM* slm, struct BFS_BinInfo* binInfo, const float4 scale, uint num_primrefs)
2150{
2151    ushort my_axis = get_sub_group_id();
2152    ushort my_bin  = get_sub_group_local_id();
2153
2154    if (my_axis < 3)
2155    {
2156        struct AABB3f aabb = BinInfo_get_AABB(binInfo, my_bin, my_axis);
2157        uint count         = BinInfo_get_count(binInfo, my_bin, my_axis);
2158
2159        float lr_area = left_to_right_area16(&aabb);
2160        float rl_area = right_to_left_area16(&aabb);
2161
2162        uint lr_count = sub_group_scan_exclusive_add(count);
2163        uint rl_count = num_primrefs - lr_count;
2164
2165        lr_count = fastDivideBy6_uint(lr_count + 6 - 1);
2166        rl_count = fastDivideBy6_uint(rl_count + 6 - 1);
2167        float lr_partial = lr_area * lr_count;
2168        float rl_partial = rl_area * rl_count;
2169        float sah = lr_partial + rl_partial;
2170
2171        // first bin is invalid
2172        sah = select((float)(INFINITY), sah, my_bin != 0);
2173
2174        float best_sah = sub_group_reduce_min( sah );
2175        uint best_bin = ctz(intel_sub_group_ballot(sah == best_sah));
2176
2177        // ignore zero sized dimensions
2178        best_sah = select( best_sah, (float)(INFINITY), (scale.x == 0 && my_axis == 0) );
2179        best_sah = select( best_sah, (float)(INFINITY), (scale.y == 0 && my_axis == 1) );
2180        best_sah = select( best_sah, (float)(INFINITY), (scale.z == 0 && my_axis == 2) );
2181
2182        if (get_sub_group_local_id() == 0)
2183        {
2184            slm->axisInfo[my_axis].sah = best_sah;
2185            slm->axisInfo[my_axis].bin = best_bin;
2186        }
2187    }
2188    barrier( CLK_LOCAL_MEM_FENCE );
2189
2190    float sah = (float)(INFINITY);
2191    if( get_sub_group_local_id() < 3 )
2192        sah = slm->axisInfo[get_sub_group_local_id()].sah;
2193
2194    float bestsah = min(sub_group_broadcast(sah, 0), min(sub_group_broadcast(sah, 1), sub_group_broadcast(sah, 2)));
2195    uint bestAxis = ctz( intel_sub_group_ballot(bestsah == sah) );
2196
2197    struct BFS_Split split;
2198    split.sah = bestsah;
2199    split.dim = bestAxis;
2200    split.pos = slm->axisInfo[bestAxis].bin;
2201    return split;
2202}
2203
2204
2205inline bool is_left( struct BinMapping* binMapping, struct BFS_Split* split, struct AABB* primref )
2206{
2207    const uint dim = split->dim;
2208    const float lower = primref->lower[dim];
2209    const float upper = primref->upper[dim];
2210    const float c = lower + upper;
2211    const uint pos = convert_uint_rtz( (c - binMapping->ofs[dim]) * binMapping->scale[dim] );
2212    return pos < split->pos;
2213}
2214
2215struct BFS_Pass1_SLM
2216{
2217    struct BFS_BinInfo bin_info;
2218//    struct BFS_BinInfoReduce3_SLM reduce3;
2219};
2220
2221
2222void DO_BFS_pass1( local struct BFS_Pass1_SLM*  slm,
2223                   uint thread_primref_id,
2224                   bool thread_primref_valid,
2225                   struct BFSDispatchArgs args
2226                  )
2227{
2228    local struct BFS_BinInfo* local_bin_info = &slm->bin_info;
2229    global struct VContext* context  = args.context;
2230    struct AABB3f centroid_bounds    = BVH2_GetNodeBox( args.bvh2, args.bvh2_root ); // root AABB is initialized to centroid bounds
2231
2232    struct BinMapping bin_mapping;
2233    BinMapping_init( &bin_mapping, &centroid_bounds, BFS_NUM_BINS );
2234
2235    // fetch this thread's primref
2236    PrimRef ref;
2237    if ( thread_primref_valid )
2238        ref = args.primref_buffer[thread_primref_id];
2239
2240    // init bin info
2241    BinInfo_init( local_bin_info );
2242
2243    // fence on local bin-info init
2244    barrier( CLK_LOCAL_MEM_FENCE );
2245
2246    // merge this thread's primref into local bin info
2247    BinInfo_add_primref( &bin_mapping, local_bin_info, &ref, thread_primref_valid );
2248
2249    // fence on local bin-info update
2250    barrier( CLK_LOCAL_MEM_FENCE );
2251
2252    BinInfo_merge(&context->global_bin_info, local_bin_info);
2253}
2254
2255
2256GRL_ANNOTATE_IGC_DO_NOT_SPILL
2257__attribute__( (reqd_work_group_size(BFS_WG_SIZE,1,1)))
2258__attribute__((intel_reqd_sub_group_size(16)))
2259kernel void
2260BFS_pass1_indexed(
2261    global struct VContextScheduler* scheduler,
2262    global struct SAHBuildGlobals* sah_globals )
2263{
2264    local struct BFS_Pass1_SLM slm;
2265    struct BFSDispatchArgs args = get_bfs_args_queue( &scheduler->bfs_queue, scheduler, sah_globals );
2266
2267    bool thread_primref_valid = (args.wg_primref_begin + get_local_id( 0 )) < args.wg_primref_end;
2268    uint thread_primref_id = 0;
2269    if ( thread_primref_valid )
2270        thread_primref_id = args.primref_index_in[args.wg_primref_begin + get_local_id( 0 )];
2271
2272    DO_BFS_pass1( &slm, thread_primref_id, thread_primref_valid, args );
2273}
2274
2275
2276__attribute__( (reqd_work_group_size( BFS_WG_SIZE, 1, 1 )) )
2277__attribute__((intel_reqd_sub_group_size(16)))
2278kernel void
2279BFS_pass1_initial( global struct VContextScheduler* scheduler, global struct SAHBuildGlobals* sah_globals )
2280{
2281    local struct BFS_Pass1_SLM slm;
2282    struct BFSDispatchArgs args = get_bfs_args_initial( scheduler, sah_globals );
2283
2284    uint thread_primref_id    = args.wg_primref_begin + get_local_id( 0 );
2285    bool thread_primref_valid = thread_primref_id < args.wg_primref_end;
2286
2287    DO_BFS_pass1( &slm, thread_primref_id, thread_primref_valid, args );
2288}
2289
2290
2291GRL_ANNOTATE_IGC_DO_NOT_SPILL
2292__attribute__((reqd_work_group_size(BFS_WG_SIZE, 1, 1)))
2293__attribute__((intel_reqd_sub_group_size(16)))
2294kernel void
2295BFS_pass1_indexed_batchable(
2296    global struct VContextScheduler* scheduler,
2297    global struct SAHBuildGlobals* globals_buffer )
2298{
2299    local struct BFS_Pass1_SLM slm;
2300    struct BFSDispatchArgs args = get_bfs_args_batchable( &scheduler->bfs_queue, scheduler, globals_buffer );
2301
2302    bool thread_primref_valid = (args.wg_primref_begin + get_local_id(0)) < args.wg_primref_end;
2303    uint thread_primref_id = 0;
2304    if (thread_primref_valid)
2305        thread_primref_id = args.primref_index_in[args.wg_primref_begin + get_local_id(0)];
2306
2307    DO_BFS_pass1(&slm, thread_primref_id, thread_primref_valid, args);
2308}
2309
2310
2311GRL_ANNOTATE_IGC_DO_NOT_SPILL
2312__attribute__((reqd_work_group_size(BFS_WG_SIZE, 1, 1)))
2313__attribute__((intel_reqd_sub_group_size(16)))
2314kernel void
2315BFS_pass1_initial_batchable( global struct VContextScheduler* scheduler, global struct SAHBuildGlobals* globals_buffer )
2316{
2317    local struct BFS_Pass1_SLM slm;
2318    struct BFSDispatchArgs args = get_bfs_args_batchable( &scheduler->bfs_queue, scheduler, globals_buffer );
2319
2320    uint thread_primref_id = args.wg_primref_begin + get_local_id(0);
2321    bool thread_primref_valid = thread_primref_id < args.wg_primref_end;
2322
2323    DO_BFS_pass1(&slm, thread_primref_id, thread_primref_valid, args);
2324}
2325
2326
2327/////////////////////////////////////////////////////////////////////////////////////////////////
2328///
2329///        BVH2 construction -- BFS Phase Pass2
2330///
2331/////////////////////////////////////////////////////////////////////////////////////////////////
2332
2333struct BFS_Pass2_SLM
2334{
2335    struct BFS_BinInfoReduce3_SLM reduce3;
2336    //struct AABB3f left_centroid_bounds;
2337    //struct AABB3f right_centroid_bounds;
2338    //struct AABB3f left_geom_bounds;
2339    //struct AABB3f right_geom_bounds;
2340    LRBounds lr_bounds;
2341    uint left_count;
2342    uint right_count;
2343    uint lr_mask;
2344    uint left_primref_base;
2345    uint right_primref_base;
2346//    uint num_wgs;
2347
2348//    uint output_indices[BFS_WG_SIZE];
2349};
2350
2351
2352
2353
2354
2355
2356
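
// DO_BFS_pass2: partition this workgroup's primrefs around the split chosen from the context's
// global bin info.  Left-side primref IDs are written in ascending order starting at
// dispatch_primref_begin; right-side IDs are written in descending order starting at
// dispatch_primref_end-1, so the two partitions meet inside the dispatch's index range.
// Per-side centroid/geom bounds and (optionally) instance masks are accumulated into the
// context for the scheduler to consume.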
2357void DO_BFS_pass2(
2358    local struct BFS_Pass2_SLM* slm,
2359    uint thread_primref_id,
2360    bool thread_primref_valid,
2361    struct BFSDispatchArgs args
2362)
2363{
2364    global struct VContext* context = args.context;
2365
2366    struct AABB3f centroid_bounds = BVH2_GetNodeBox( args.bvh2, args.bvh2_root );
2367
2368    // load the thread's primref
2369    PrimRef ref;
2370    if ( thread_primref_valid )
2371        ref = args.primref_buffer[thread_primref_id];
2372
2373    struct BinMapping bin_mapping;
2374    BinMapping_init( &bin_mapping, &centroid_bounds, BFS_NUM_BINS );
2375
2376    // initialize working SLM space
2377    LRBounds_init(&slm->lr_bounds);
2378    if(get_local_id(0) == 0)
2379    {
2380        slm->left_count  = 0;
2381        slm->right_count = 0;
2382
2383        if( args.do_mask_processing )
2384            slm->lr_mask = 0;
2385    }
2386
2387    // compute split - every workgroup does the same computation
2388    // local barrier inside BinInfo_reduce3
2389    struct BFS_Split split = BinInfo_reduce3( &slm->reduce3, &context->global_bin_info,bin_mapping.scale );
2390
2391    uint wg_prim_count = args.wg_primref_end - args.wg_primref_begin;
2392
2393    // partition primrefs into L/R subsets...
2394    bool go_left = false;
2395    if (split.sah == (float)(INFINITY))      // no valid split; split in the middle.  This can happen due to floating-point limit cases in huge scenes
2396        go_left = get_local_id(0) < (wg_prim_count / 2);
2397    else
2398        go_left = is_left( &bin_mapping, &split, &ref );
2399
2400    // assign this primref a position in the output array, and expand corresponding centroid-bounds
2401    uint local_index;
2402    {
2403        float3 centroid = ref.lower.xyz + ref.upper.xyz;
2404
2405        uint l_ballot = intel_sub_group_ballot(  go_left && thread_primref_valid );
2406        uint r_ballot = intel_sub_group_ballot( !go_left && thread_primref_valid );
2407        if (l_ballot)
2408        {
2409            bool active_lane = l_ballot & (1 << get_sub_group_local_id());
2410            float3 Cmin, Cmax, Gmin, Gmax;
2411            Cmin = select_min( centroid.xyz, active_lane );
2412            Cmax = select_max( centroid.xyz, active_lane );
2413            Gmin = select_min( ref.lower.xyz, active_lane );
2414            Gmax = select_max( ref.upper.xyz, active_lane );
2415
2416            Cmin = sub_group_reduce_min_float3( Cmin );
2417            Cmax = sub_group_reduce_max_float3( Cmax );
2418            Gmin = sub_group_reduce_min_float3( Gmin );
2419            Gmax = sub_group_reduce_max_float3( Gmax );
2420
2421            LRBounds_merge_left( &slm->lr_bounds, Cmin,Cmax,Gmin,Gmax );
2422        }
2423
2424        if (r_ballot)
2425        {
2426            bool active_lane = r_ballot & (1 << get_sub_group_local_id());
2427            float3 Cmin, Cmax, Gmin, Gmax;
2428            Cmin = select_min(centroid.xyz, active_lane);
2429            Cmax = select_max(centroid.xyz, active_lane);
2430            Gmin = select_min(ref.lower.xyz, active_lane);
2431            Gmax = select_max(ref.upper.xyz, active_lane);
2432
2433            Cmin = sub_group_reduce_min_float3(Cmin);
2434            Cmax = sub_group_reduce_max_float3(Cmax);
2435            Gmin = sub_group_reduce_min_float3(Gmin);
2436            Gmax = sub_group_reduce_max_float3(Gmax);
2437
2438            LRBounds_merge_right( &slm->lr_bounds, Cmin,Cmax,Gmin,Gmax );
2439        }
2440
2441        if( args.do_mask_processing )
2442        {
2443            uint mask =0;
2444            if (thread_primref_valid)
2445            {
2446                mask = PRIMREF_instanceMask(&ref) ;
2447                mask = go_left  ? mask : mask<<8;
2448            }
2449
2450            // TODO OPT:  there is no 'sub_group_reduce_or'  and IGC does not do the reduction trick
2451            //   for atomics on sub-group uniform addresses
2452            for( uint i= get_sub_group_size()/2; i>0; i/= 2)
2453                mask = mask | intel_sub_group_shuffle_down(mask,mask,i);
2454            if( get_sub_group_local_id() == 0 )
2455                atomic_or_local( &slm->lr_mask, mask );
2456        }
2457
2458        uint l_base = 0;
2459        uint r_base = 0;
2460        if( get_sub_group_local_id() == 0 && l_ballot )
2461            l_base = atomic_add_local( &slm->left_count, popcount(l_ballot) );
2462        if( get_sub_group_local_id() == 0 && r_ballot )
2463            r_base = atomic_add_local( &slm->right_count, popcount(r_ballot) );
2464
2465        sub_group_barrier( CLK_LOCAL_MEM_FENCE );
2466        l_base = sub_group_broadcast(l_base,0);
2467        r_base = sub_group_broadcast(r_base,0);
2468
2469        l_base = l_base + subgroup_bit_prefix_exclusive( l_ballot );
2470        r_base = r_base + subgroup_bit_prefix_exclusive( r_ballot );
2471
2472        local_index = (go_left) ? l_base : r_base;
2473    }
2474
2475
2476    barrier( CLK_LOCAL_MEM_FENCE );
2477
2478    // merge local into global
2479    // TODO_OPT:  Look at spreading some of this across subgroups
2480    if ( get_sub_group_id() == 0 )
2481    {
2482        // allocate primref space for this wg and merge local/global centroid bounds
2483        uint num_left  = slm->left_count;
2484        {
2485            if (num_left && get_sub_group_local_id() == 0)
2486            {
2487                num_left = atomic_add_global( &context->num_left, num_left );
2488                slm->left_primref_base = args.dispatch_primref_begin + num_left;
2489            }
2490        }
2491        uint num_right = slm->right_count;
2492        {
2493            if (num_right && get_sub_group_local_id() == 0)
2494            {
2495                num_right = atomic_add_global( &context->num_right, num_right );
2496                slm->right_primref_base = (args.dispatch_primref_end - 1) - num_right;
2497            }
2498        }
2499
2500        if( args.do_mask_processing && get_sub_group_local_id() == 0 )
2501            atomic_or_global( &context->lr_mask, slm->lr_mask );
2502    }
2503
2504    barrier( CLK_LOCAL_MEM_FENCE );
2505
2506    LRBounds_merge( &context->lr_bounds, &slm->lr_bounds );
2507
2508    // move thread's primref ID into correct position in output index buffer
2509    if (thread_primref_valid)
2510    {
2511        uint pos = go_left ? slm->left_primref_base + local_index
2512            : slm->right_primref_base - local_index;
2513
2514        args.primref_index_out[pos] = thread_primref_id;
2515    }
2516}
2517
2518
2519GRL_ANNOTATE_IGC_DO_NOT_SPILL
2520__attribute__( (reqd_work_group_size( BFS_WG_SIZE, 1, 1 )) )
2521__attribute__( (intel_reqd_sub_group_size( 16 )) )
2522kernel void
2523BFS_pass2_indexed( global struct VContextScheduler* scheduler, global struct SAHBuildGlobals* sah_globals )
2524{
2525    local struct BFS_Pass2_SLM slm;
2526    struct BFSDispatchArgs args = get_bfs_args_queue( &scheduler->bfs_queue, scheduler, sah_globals );
2527
2528    bool thread_primref_valid = (args.wg_primref_begin + get_local_id( 0 )) < args.wg_primref_end;
2529    uint thread_primref_id = 0;
2530    if ( thread_primref_valid )
2531        thread_primref_id = args.primref_index_in[args.wg_primref_begin + get_local_id( 0 )];
2532
2533    DO_BFS_pass2( &slm, thread_primref_id, thread_primref_valid, args );
2534}
2535
2536
2537GRL_ANNOTATE_IGC_DO_NOT_SPILL
2538__attribute__( (reqd_work_group_size( BFS_WG_SIZE, 1, 1 )) )
2539__attribute__( (intel_reqd_sub_group_size( 16 )) )
2540kernel void
2541BFS_pass2_initial( global struct VContextScheduler* scheduler, global struct SAHBuildGlobals* sah_globals )
2542{
2543    local struct BFS_Pass2_SLM slm;
2544    struct BFSDispatchArgs args = get_bfs_args_initial( scheduler, sah_globals );
2545
2546    uint thread_primref_id    = args.wg_primref_begin + get_local_id( 0 );
2547    bool thread_primref_valid = thread_primref_id < args.wg_primref_end;
2548
2549    DO_BFS_pass2( &slm, thread_primref_id, thread_primref_valid, args );
2550}
2551
2552
2553__attribute__((reqd_work_group_size(BFS_WG_SIZE, 1, 1)))
2554__attribute__((intel_reqd_sub_group_size(16)))
2555kernel void
2556BFS_pass2_indexed_batchable( global struct VContextScheduler* scheduler, global struct SAHBuildGlobals* globals_buffer )
2557{
2558    local struct BFS_Pass2_SLM slm;
2559    struct BFSDispatchArgs args = get_bfs_args_batchable(&scheduler->bfs_queue, scheduler, globals_buffer );
2560
2561    bool thread_primref_valid = (args.wg_primref_begin + get_local_id(0)) < args.wg_primref_end;
2562    uint thread_primref_id = 0;
2563    if (thread_primref_valid)
2564        thread_primref_id = args.primref_index_in[args.wg_primref_begin + get_local_id(0)];
2565
2566    DO_BFS_pass2(&slm, thread_primref_id, thread_primref_valid, args);
2567
2568}
2569
2570
2571GRL_ANNOTATE_IGC_DO_NOT_SPILL
2572__attribute__((reqd_work_group_size(BFS_WG_SIZE, 1, 1)))
2573__attribute__((intel_reqd_sub_group_size(16)))
2574kernel void
2575BFS_pass2_initial_batchable(global struct VContextScheduler* scheduler, global struct SAHBuildGlobals* globals_buffer)
2576{
2577    local struct BFS_Pass2_SLM slm;
2578    struct BFSDispatchArgs args = get_bfs_args_batchable(&scheduler->bfs_queue, scheduler, globals_buffer );
2579
2580    uint thread_primref_id = args.wg_primref_begin + get_local_id(0);
2581    bool thread_primref_valid = thread_primref_id < args.wg_primref_end;
2582
2583    DO_BFS_pass2(&slm, thread_primref_id, thread_primref_valid, args);
2584}
2585
2586
2587
2588
2589/////////////////////////////////////////////////////////////////////////////////////////////////
2590///
2591///        BVH2 construction -- DFS Phase
2592///
2593/////////////////////////////////////////////////////////////////////////////////////////////////
2594
2595struct DFSArgs
2596{
2597    uint primref_base;
2598    uint global_bvh2_base;
2599    bool do_mask_processing;
2600    ushort num_primrefs;
2601    global uint* primref_indices_in;
2602    global uint* primref_indices_out;
2603    global PrimRef* primref_buffer;
2604    global struct BVH2* global_bvh2;
2605};
2606
2607
2608struct DFSPrimRefAABB
2609{
2610    half lower[3];
2611    half upper[3];
2612};
2613
2614void DFSPrimRefAABB_init( struct DFSPrimRefAABB* bb )
2615{
2616    bb->lower[0] = 1;
2617    bb->lower[1] = 1;
2618    bb->lower[2] = 1;
2619    bb->upper[0] = 0;
2620    bb->upper[1] = 0;
2621    bb->upper[2] = 0;
2622}
2623
2624void DFSPrimRefAABB_extend( struct DFSPrimRefAABB* aabb, struct DFSPrimRefAABB* v )
2625{
2626    aabb->lower[0] = min( aabb->lower[0], v->lower[0] );
2627    aabb->lower[1] = min( aabb->lower[1], v->lower[1] );
2628    aabb->lower[2] = min( aabb->lower[2], v->lower[2] );
2629    aabb->upper[0] = max( aabb->upper[0], v->upper[0] );
2630    aabb->upper[1] = max( aabb->upper[1], v->upper[1] );
2631    aabb->upper[2] = max( aabb->upper[2], v->upper[2] );
2632}
2633
2634half DFSPrimRefAABB_halfArea( struct DFSPrimRefAABB* aabb )
2635{
2636    const half3 d = (half3)(aabb->upper[0] - aabb->lower[0], aabb->upper[1] - aabb->lower[1], aabb->upper[2] - aabb->lower[2]);
2637    return fma( d.x, (d.y + d.z), d.y * d.z );
2638}
2639
2640struct DFSPrimRef
2641{
2642    struct DFSPrimRefAABB aabb;
2643    ushort2 meta;
2644};
2645
2646void DFSPrimRef_SetBVH2Root( struct DFSPrimRef* ref, ushort root )
2647{
2648    ref->meta.y = root;
2649}
2650
2651uint DFSPrimRef_GetInputIndex( struct DFSPrimRef* ref )
2652{
2653    return ref->meta.x;
2654}
2655
2656uint DFSPrimRef_GetBVH2Parent( struct DFSPrimRef* ref )
2657{
2658    return ref->meta.y;
2659}
2660
2661
2662struct PrimRefSet
2663{
2664    struct DFSPrimRefAABB AABB[DFS_WG_SIZE];
2665    ushort2 meta[DFS_WG_SIZE];
2666    uint input_indices[DFS_WG_SIZE];
2667};
2668
2669
2670
2671
2672local struct DFSPrimRefAABB* PrimRefSet_GetAABBPointer( local struct PrimRefSet* refs, ushort id )
2673{
2674    return &refs->AABB[id];
2675}
2676struct DFSPrimRef PrimRefSet_GetPrimRef( local struct PrimRefSet* refs, ushort id )
2677{
2678    struct DFSPrimRef r;
2679    r.aabb = refs->AABB[id];
2680    r.meta = refs->meta[id];
2681    return r;
2682}
2683void PrimRefSet_SetPrimRef( local struct PrimRefSet* refs, struct DFSPrimRef ref, ushort id )
2684{
2685    refs->AABB[id] = ref.aabb;
2686    refs->meta[id] = ref.meta;
2687}
2688
2689void PrimRefSet_SetPrimRef_FullPrecision( struct AABB3f* root_aabb, local struct PrimRefSet* refs, PrimRef ref, ushort id )
2690{
2691    float3 root_l = AABB3f_load_lower( root_aabb );
2692    float3 root_u = AABB3f_load_upper( root_aabb );
2693    float3 d = root_u - root_l;
2694    float scale = 1.0f / max( d.x, max( d.y, d.z ) );
2695
2696    float3 l = ref.lower.xyz;
2697    float3 u = ref.upper.xyz;
2698    half3 lh = convert_half3_rtz( (l - root_l) * scale );
2699    half3 uh = convert_half3_rtp( (u - root_l) * scale );
2700
2701    refs->AABB[id].lower[0] = lh.x;
2702    refs->AABB[id].lower[1] = lh.y;
2703    refs->AABB[id].lower[2] = lh.z;
2704    refs->AABB[id].upper[0] = uh.x;
2705    refs->AABB[id].upper[1] = uh.y;
2706    refs->AABB[id].upper[2] = uh.z;
2707    refs->meta[id].x = id;
2708    refs->meta[id].y = 0;
2709}
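
// The primrefs handled by a DFS workgroup are re-encoded above as half-precision AABBs in a local
// coordinate frame: positions are taken relative to the subtree root's lower corner and scaled by
// 1/max_extent, so every coordinate lands in [0,1].  Lower bounds are rounded toward zero and
// upper bounds toward +infinity, which keeps the quantized boxes conservative.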
2710
2711
2712
2713void DFS_CreatePrimRefSet( struct DFSArgs args,
2714                           local struct PrimRefSet* prim_refs )
2715{
2716    ushort id = get_local_id( 0 );
2717    ushort num_primrefs = args.num_primrefs;
2718
2719    struct AABB3f box = BVH2_GetNodeBox( args.global_bvh2, args.global_bvh2_base );
2720    if ( id < num_primrefs )
2721    {
2722        PrimRef ref = args.primref_buffer[args.primref_indices_in[id]];
2723        prim_refs->input_indices[id] = args.primref_indices_in[id];
2724        PrimRefSet_SetPrimRef_FullPrecision( &box, prim_refs, ref, id );
2725    }
2726}
2727
2728struct ThreadRangeInfo
2729{
2730    uchar start;
2731    uchar local_num_prims;
2732    uchar bvh2_root;
2733    bool  active;
2734};
2735
2736struct BVHBuildLocals // size:  ~3.8K
2737{
2738    uchar2                 axis_and_left_count[ DFS_WG_SIZE ];
2739    struct ThreadRangeInfo range[ DFS_WG_SIZE ];
2740    uint                   sah[ DFS_WG_SIZE ];
2741};
2742
2743#define LOCAL_BVH2_NODE_COUNT (2*(DFS_WG_SIZE) -1)
2744
2745struct LocalBVH2
2746{
2747    uint nodes[LOCAL_BVH2_NODE_COUNT];
2748    uint num_nodes;
2749
2750    // bit layout for a node is:
2751    //  uchar child_ptr;    // this is right_child_index >> 1.   right child's msb is always 0
2752    //  uchar primref_base; // index of the node's first primref.  will be 0 at the root
2753    //  uchar parent_dist;  // distance in nodes from this node to its parent
2754    //  uchar prim_counter; // number of prims in this subtree.  For a complete tree (256 prims), the root may be off by 1
2755
2756    // for a WG size of 256, 8b is enough for parent distance, because the tree is built in level order
2757    //    the maximum distance between parent and child occurs for a complete tree.
2758    //    in this scenario the left-most leaf has index 255, its parent has index 127, the deltas to the children are 128 and 129
2759};
2760
2761
2762void LocalBVH2_Initialize( struct LocalBVH2* bvh2, ushort num_prims )
2763{
2764    bvh2->num_nodes = 1;
2765    bvh2->nodes[0] = min(num_prims,(ushort)255);
2766}
2767
2768
2769
2770void LocalBVH2_Initialize_Presplit(struct LocalBVH2* bvh2, ushort num_prims, ushort left_count, ushort right_count )
2771{
2772    bvh2->num_nodes = 3;
2773    bvh2->nodes[0] = min(num_prims, (ushort)255);
2774
2775    ushort bvh2_root = 0;
2776    ushort child_place = 1;
2777
2778    uint child_ptr = (child_place + 1) >> 1;
2779    bvh2->nodes[bvh2_root] |= (child_ptr) << 24;
2780
2781    uint parent_dist = child_place - bvh2_root;
2782
2783    // initialize child nodes
2784    ushort primref_base_left = 0;
2785    ushort primref_base_right = left_count;
2786    uint left = (primref_base_left << 16) + ((parent_dist << 8)) + left_count;
2787    uint right = (primref_base_right << 16) + ((parent_dist + 1) << 8) + right_count;
2788    bvh2->nodes[child_place] = left;
2789    bvh2->nodes[child_place + 1] = right;
2790}
2791
2792
2793void LocalBVH2_CreateInnerNode( local struct LocalBVH2* bvh2, ushort bvh2_root, uint primref_base_left, uint primref_base_right )
2794{
2795    ushort child_place = atomic_add_local( &(bvh2-> num_nodes), 2 );
2796
2797    uint child_ptr   = (child_place + 1) >> 1;
2798    bvh2->nodes[bvh2_root] |= (child_ptr) << 24;
2799
2800    uint parent_dist = child_place - bvh2_root;
2801
2802    // initialize child nodes
2803    uint left  = (primref_base_left << 16)  + ((parent_dist << 8));
2804    uint right = (primref_base_right << 16) + ((parent_dist + 1) << 8);
2805    bvh2->nodes[child_place]     = left;
2806    bvh2->nodes[child_place + 1] = right;
2807}
2808
2809ushort2 LocalBVH2_GetChildIndices( struct LocalBVH2* bvh2, ushort bvh2_root )
2810{
2811    ushort right_idx = (bvh2->nodes[bvh2_root] & 0xff000000) >> 23;
2812    return (ushort2)(right_idx - 1, right_idx);
2813}
2814
2815
2816ushort LocalBVH2_IncrementPrimCount( local struct LocalBVH2* bvh2, ushort bvh2_root )
2817{
2818    // increment only the lower 8 bits.  Algorithm will not overflow by design
2819    return atomic_inc_local( &bvh2->nodes[bvh2_root] ) & 0xff;
2820}
2821
2822ushort LocalBVH2_SetLeafPrimCount(local struct LocalBVH2* bvh2, ushort bvh2_root, ushort count)
2823{
2824    return bvh2->nodes[bvh2_root] |= (count & 0xff);
2825}
2826
2827bool LocalBVH2_IsRoot( struct LocalBVH2* bvh2, ushort node_id )
2828{
2829    return node_id == 0;
2830}
2831
2832ushort LocalBVH2_GetLeafPrimrefStart( struct LocalBVH2* bvh2, ushort bvh2_node_id )
2833{
2834    return (bvh2->nodes[bvh2_node_id] >> 16) & 255;
2835}
2836
2837bool LocalBVH2_IsLeftChild( struct LocalBVH2* bvh2, ushort parent_node, ushort current_node )
2838{
2839    return (current_node & 1); // nodes are allocated in pairs.  first node is root, left child is an odd index
2840}
2841
2842ushort LocalBVH2_GetParent( struct LocalBVH2* bvh2, ushort node )
2843{
2844    return node - ((bvh2->nodes[node] >> 8) & 255);
2845}
2846
2847uint LocalBVH2_GetNodeCount( struct LocalBVH2* bvh2 )
2848{
2849    return bvh2->num_nodes;
2850}
2851
2852bool LocalBVH2_IsLeaf( struct LocalBVH2* bvh2, ushort node_index )
2853{
2854    return (bvh2->nodes[node_index] & 255) <= TREE_ARITY;
2855}
2856
2857ushort LocalBVH2_GetLeafPrimCount( struct LocalBVH2* bvh2, ushort node_index )
2858{
2859    return (bvh2->nodes[node_index] & 255);
2860}
2861
2862void DFS_ConstructBVH2( local struct LocalBVH2* bvh2,
2863                        local struct PrimRefSet* prim_refs,
2864                        ushort bvh2_root,
2865                        ushort prim_range_start,
2866                        ushort local_num_prims,
2867                        ushort global_num_prims,
2868                        local struct BVHBuildLocals* locals,
2869                        local uint* num_active_threads )
2870{
2871    ushort tid = get_local_id( 0 );
2872    ushort primref_position = tid;
2873
2874    bool active_thread = tid < global_num_prims;
2875
2876    // Handle cases where initial binner creates leaves
2877    if ( active_thread && local_num_prims <= TREE_ARITY )
2878    {
2879        struct DFSPrimRef ref = PrimRefSet_GetPrimRef(prim_refs, primref_position);
2880        DFSPrimRef_SetBVH2Root(&ref, bvh2_root);
2881        PrimRefSet_SetPrimRef(prim_refs, ref, primref_position);
2882        active_thread = false;
2883        if (primref_position == prim_range_start)
2884            atomic_sub_local(num_active_threads, local_num_prims);
2885    }
2886
2887    barrier( CLK_LOCAL_MEM_FENCE );
2888
2889    locals->range[ tid ].start           = prim_range_start;
2890    locals->range[ tid ].local_num_prims = local_num_prims;
2891    locals->range[ tid ].bvh2_root       = bvh2_root;
2892    locals->range[ tid ].active          = active_thread;
2893
2894    do
2895    {
2896        if(active_thread && prim_range_start == primref_position)
2897            locals->sah[primref_position] = UINT_MAX;
2898
2899        barrier( CLK_LOCAL_MEM_FENCE );
2900
2901        if ( active_thread )
2902        {
2903            local struct DFSPrimRefAABB* my_box = PrimRefSet_GetAABBPointer( prim_refs, primref_position );
2904
2905            // each thread evaluates a possible split candidate.  Scan primrefs and compute sah cost
2906            //  do this axis-by-axis to keep register pressure low
2907            float best_sah = INFINITY;
2908            ushort best_axis = 3;
2909            ushort best_count = 0;
2910
2911            struct DFSPrimRefAABB box_left[3];
2912            struct DFSPrimRefAABB box_right[3];
2913            float CSplit[3];
2914            ushort count_left[3];
2915
2916            for ( ushort axis = 0; axis < 3; axis++ )
2917            {
2918                DFSPrimRefAABB_init( &box_left[axis] );
2919                DFSPrimRefAABB_init( &box_right[axis] );
2920
2921                CSplit[axis] = my_box->lower[axis] + my_box->upper[axis];
2922                count_left[axis] = 0;
2923            }
2924
2925            // scan primrefs in our subtree and partition using this thread's prim as a split plane
2926            {
2927                struct DFSPrimRefAABB box = *PrimRefSet_GetAABBPointer( prim_refs, prim_range_start );
2928
2929                for ( ushort p = 1; p < local_num_prims; p++ )
2930                {
2931                        struct DFSPrimRefAABB next_box = *PrimRefSet_GetAABBPointer( prim_refs, prim_range_start + p ); // preload the box for the next iteration
2932
2933                        for( ushort axis = 0; axis < 3; axis++ )
2934                        {
2935                            float c = box.lower[axis] + box.upper[axis];
2936
2937                            if ( c < CSplit[axis] )
2938                            {
2939                                // this primitive is to our left.
2940                                DFSPrimRefAABB_extend( &box_left[axis], &box );
2941                                count_left[axis]++;
2942                            }
2943                            else
2944                            {
2945                                // this primitive is to our right
2946                                DFSPrimRefAABB_extend( &box_right[axis], &box );
2947                            }
2948                        }
2949
2950                        box = next_box;
2951                }
2952
2953                // last iteration without preloading box
2954                for( ushort axis = 0; axis < 3; axis++ )
2955                {
2956                    float c = box.lower[axis] + box.upper[axis];
2957
2958                    if ( c < CSplit[axis] )
2959                    {
2960                        // this primitive is to our left.
2961                        DFSPrimRefAABB_extend( &box_left[axis], &box );
2962                        count_left[axis]++;
2963                    }
2964                    else
2965                    {
2966                        // this primitive is to our right
2967                        DFSPrimRefAABB_extend( &box_right[axis], &box );
2968                    }
2969                }
2970
2971            }
2972
2973            for ( ushort axis = 0; axis < 3; axis++ )
2974            {
2975                float Al = DFSPrimRefAABB_halfArea( &box_left[axis] );
2976                float Ar = DFSPrimRefAABB_halfArea( &box_right[axis] );
2977
2978                // Avoid NANs in SAH calculation in the corner case where all prims go right
2979                //  In this case we set Al=Ar, because such a split will only be selected if all primrefs
2980                //    are coincident.  In that case, we will fall back to split-in-the-middle and both subtrees
2981                //    should store the same quantized area value
2982                if ( count_left[axis] == 0 )
2983                    Al = Ar;
2984
2985                // compute sah cost
2986                ushort count_right = local_num_prims - count_left[axis];
2987                float sah = Ar * count_right + Al * count_left[axis];
2988
2989                // keep this split if it is better than the previous one, or if the previous one was a corner-case
2990                if ( sah < best_sah || best_count == 0 )
2991                {
2992                    // yes, keep it
2993                    best_axis = axis;
2994                    best_sah = sah;
2995                    best_count = count_left[axis];
2996                }
2997            }
2998
2999            // write split information to SLM
3000            locals->axis_and_left_count[primref_position].x = best_axis;
3001            locals->axis_and_left_count[primref_position].y = best_count;
3002            uint sah = as_uint(best_sah);
3003            // break ties by axis to ensure deterministic split selection
3004            //  otherwise the builder can produce a non-deterministic tree structure from run to run
3005            //  based on the ordering of primitives (which can vary due to non-determinism in the atomic counters)
3006            // Embed split axis and index into sah value; compute min over sah and max over axis
3007            sah = ( ( sah & ~1023 ) | ( 2 - best_axis ) << 8 | tid );
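            // resulting key layout (valid because DFS_WG_SIZE <= 256 and best_axis is 0..2):
            //   bits 31..10 : quantized SAH cost (smaller is better)
            //   bits  9..8  : (2 - best_axis), so the min-reduction prefers the larger axis on SAH ties
            //   bits  7..0  : tid, which makes every key unique and identifies the winning thread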
3008
3009            // reduce on split candidates in our local subtree and decide the best one
3010            atomic_min_local( &locals->sah[ prim_range_start ], sah);
3011        }
3012
3013
3014        barrier( CLK_LOCAL_MEM_FENCE );
3015
3016        ushort split_index = locals->sah[ prim_range_start ] & 255;
3017        ushort split_axis = locals->axis_and_left_count[split_index].x;
3018        ushort split_left_count = locals->axis_and_left_count[split_index].y;
3019
3020        if ( (primref_position == split_index) && active_thread )
3021        {
3022            // first thread in a given subtree creates the inner node
3023            ushort start_left  = prim_range_start;
3024            ushort start_right = prim_range_start + split_left_count;
3025            if ( split_left_count == 0 )
3026                start_right = start_left + (local_num_prims / 2); // handle split-in-the-middle case
3027
3028            LocalBVH2_CreateInnerNode( bvh2, bvh2_root, start_left, start_right );
3029        }
3030
3031
3032        barrier( CLK_LOCAL_MEM_FENCE );
3033
3034        struct DFSPrimRef ref;
3035        ushort new_primref_position;
3036
3037        if ( active_thread )
3038        {
3039            ushort2 kids = LocalBVH2_GetChildIndices( bvh2, bvh2_root );
3040            bool go_left;
3041
3042            if ( split_left_count == 0 )
3043            {
3044                // We chose a split with no left-side prims
3045                //  This will only happen if all primrefs are located in the exact same position
3046                //   In that case, fall back to split-in-the-middle
3047                split_left_count = (local_num_prims / 2);
3048                go_left = (primref_position - prim_range_start < split_left_count);
3049            }
3050            else
3051            {
3052                // determine what side of the split this thread's primref belongs on
3053                local struct DFSPrimRefAABB* my_box    = PrimRefSet_GetAABBPointer( prim_refs, primref_position );
3054                local struct DFSPrimRefAABB* split_box = PrimRefSet_GetAABBPointer( prim_refs, split_index );
3055                float c = my_box->lower[split_axis] + my_box->upper[split_axis];
3056                float Csplit = split_box->lower[split_axis] + split_box->upper[split_axis];
3057                go_left = c < Csplit;
3058            }
3059
3060            // adjust state variables for next loop iteration
3061            bvh2_root = (go_left) ? kids.x : kids.y;
3062            local_num_prims = (go_left) ? split_left_count : (local_num_prims - split_left_count);
3063            prim_range_start = (go_left) ? prim_range_start : prim_range_start + split_left_count;
3064
3065            // determine the new primref position by incrementing a counter in the destination subtree
3066            new_primref_position = prim_range_start + LocalBVH2_IncrementPrimCount( bvh2, bvh2_root );
3067
3068            // load our primref from its previous position
3069            ref = PrimRefSet_GetPrimRef( prim_refs, primref_position );
3070        }
3071
3072        barrier( CLK_LOCAL_MEM_FENCE );
3073
3074        if ( active_thread )
3075        {
3076            // write our primref into its sorted position and note which node it went in
3077            DFSPrimRef_SetBVH2Root( &ref, bvh2_root );
3078            PrimRefSet_SetPrimRef( prim_refs, ref, new_primref_position );
3079            primref_position = new_primref_position;
3080
3081
3082            // deactivate all threads whose subtrees are small enough to form a leaf
3083            if ( local_num_prims <= TREE_ARITY )
3084            {
3085                active_thread = false;
3086                if( primref_position == prim_range_start )
3087                    atomic_sub_local( num_active_threads, local_num_prims );
3088            }
3089
3090            locals->range[ primref_position ].start           = prim_range_start;
3091            locals->range[ primref_position ].local_num_prims = local_num_prims;
3092            locals->range[ primref_position ].bvh2_root       = bvh2_root;
3093            locals->range[ primref_position ].active          = active_thread;
3094        }
3095
3096        barrier( CLK_LOCAL_MEM_FENCE );
3097
3098        // if there will be another iteration, reload this thread's state from SLM
3099        if(*num_active_threads)
3100        {
3101            prim_range_start = locals->range[ tid ].start;
3102            local_num_prims  = locals->range[ tid ].local_num_prims;
3103            bvh2_root        = locals->range[ tid ].bvh2_root;
3104            active_thread    = locals->range[ tid ].active;
3105            primref_position = tid;
3106        }
3107        else
3108        {
3109            break;
3110        }
3111
3112    } while ( true );
3113
3114}
3115
3116
3117#define REFIT_BIT_DWORDS (LOCAL_BVH2_NODE_COUNT - DFS_WG_SIZE)/32
3118
3119struct RefitBits
3120{
3121    uint bits[REFIT_BIT_DWORDS];
3122};
3123
3124struct DFS_SLM
3125{
3126    union
3127    {
3128        struct LocalBVH2 bvh2;
3129        struct {
3130            struct AABB3f centroid_bounds;
3131            uint left_count;
3132            uint right_count;
3133            struct BFS_BinInfo bins;
3134            struct BFS_BinInfoReduce3_SLM reduce3;
3135        } binning;
3136
3137    } u1;
3138
3139    union
3140    {
3141        struct {
3142            struct PrimRefSet prim_refs;
3143            struct BVHBuildLocals locals;
3144        } pass0;
3145
3146        struct AABB3f node_boxes[LOCAL_BVH2_NODE_COUNT];
3147
3148    } u2;
3149
3150    union
3151    {
3152        uchar bytes[DFS_WG_SIZE];
3153        uint dwords[DFS_WG_SIZE/4];
3154    } mask_info;
3155
3156    struct RefitBits refit_bits;
3157
3158};
3159
3160
3161void DFS_InitialBinningPass(
3162    local struct BFS_BinInfo* bins,
3163    local struct BFS_BinInfoReduce3_SLM* reduce3,
3164    uniform local struct AABB3f* centroid_bounds,
3165    local struct PrimRefSet* refs,
3166    local uint* left_counter,
3167    local uint* right_counter,
3168    ushort num_refs )
3169{
3170    uint tid = get_local_id(0);
3171
3172    // initialize SLM structures
3173    if (tid == 0)
3174    {
3175        AABB3f_init(centroid_bounds);
3176        *left_counter = 0;
3177        *right_counter = 0;
3178    }
3179
3180    BinInfo_init(bins);
3181
3182    PrimRef ref;
3183    struct DFSPrimRef dfs_ref;
3184
3185    if (tid < num_refs)
3186    {
3187        dfs_ref = PrimRefSet_GetPrimRef(refs, tid);
3188        struct DFSPrimRefAABB box = dfs_ref.aabb;
3189        ref.lower.xyz = (float3)(box.lower[0], box.lower[1], box.lower[2]);
3190        ref.upper.xyz = (float3)(box.upper[0], box.upper[1], box.upper[2]);
3191    }
3192
3193    barrier(CLK_LOCAL_MEM_FENCE);
3194
3195    // compute centroid bounds so that we can bin
3196    if (tid < num_refs)
3197    {
3198        float3 centroid = ref.lower.xyz + ref.upper.xyz;
3199        Uniform_AABB3f_atomic_merge_local_sub_group_lu(centroid_bounds, centroid, centroid);
3200    }
3201
3202    barrier(CLK_LOCAL_MEM_FENCE);
3203
3204    // add primrefs to bins
3205    struct BinMapping mapping;
3206    BinMapping_init(&mapping, centroid_bounds, BFS_NUM_BINS);
3207
3208    BinInfo_add_primref( &mapping, bins, &ref, tid<num_refs );
3209
3210    barrier(CLK_LOCAL_MEM_FENCE);
3211
3212    // compute split - every sub_group computes different bin
3213    struct BFS_Split split = BinInfo_reduce3(reduce3, bins, mapping.scale);
3214
3215
3216    bool go_left = false;
3217    uint local_pos = 0;
3218    if (tid < num_refs)
3219    {
3220        // partition primrefs into L/R subsets...
3221        if (split.sah == (float)(INFINITY))      // no valid split: split in the middle.  This can happen due to floating-point limit cases in huge scenes
3222            go_left = tid < (num_refs / 2);
3223        else
3224            go_left = is_left(&mapping, &split, &ref);
3225
3226        if (go_left)
3227            local_pos = atomic_inc_local(left_counter);
3228        else
3229            local_pos = num_refs - (1+ atomic_inc_local(right_counter));
3230
3231        PrimRefSet_SetPrimRef(refs, dfs_ref, local_pos);
3232    }
3233
3234}
3235
3236
3237void Do_DFS( struct DFSArgs args, local struct DFS_SLM* slm, local uint* num_active_threads )
3238{
3239    local struct LocalBVH2* bvh2 = &slm->u1.bvh2;
3240
3241    global struct BVH2* global_bvh2 = args.global_bvh2;
3242
3243    PrimRef ref;
3244    uint parent_node;
3245
3246    {
3247        local struct BVHBuildLocals* locals = &slm->u2.pass0.locals;
3248        local struct PrimRefSet* prim_refs = &slm->u2.pass0.prim_refs;
3249
3250        DFS_CreatePrimRefSet(args, prim_refs);
3251
3252        uint local_id = get_local_id(0);
3253
3254        ushort bvh2_root = 0;
3255        ushort prim_range_start = 0;
3256        ushort local_num_prims = args.num_primrefs;
3257
3258        if(local_id == 0)
3259            *num_active_threads = local_num_prims;
3260
3261        // barrier for DFS_CreatePrimRefSet and num_active_threads
3262        barrier(CLK_LOCAL_MEM_FENCE);
3263
3264        // initial binning pass if number of primrefs is large
3265        if( args.num_primrefs > 32 )
3266        {
3267            DFS_InitialBinningPass(&slm->u1.binning.bins, &slm->u1.binning.reduce3, &slm->u1.binning.centroid_bounds, prim_refs,
3268                &slm->u1.binning.left_count, &slm->u1.binning.right_count, args.num_primrefs);
3269
3270            barrier(CLK_LOCAL_MEM_FENCE);
3271
3272            ushort left_count = slm->u1.binning.left_count;
3273            ushort right_count = args.num_primrefs - left_count;
3274            if (get_local_id(0) == 0)
3275                LocalBVH2_Initialize_Presplit(bvh2, args.num_primrefs, left_count, right_count);
3276
3277            bvh2_root        = (local_id < left_count) ? 1 : 2;
3278            local_num_prims = (local_id < left_count) ? left_count : right_count;
3279            prim_range_start = (local_id < left_count) ? 0 : left_count;
3280        }
3281        else
3282        {
3283            if (get_local_id(0) == 0)
3284                LocalBVH2_Initialize(bvh2, args.num_primrefs);
3285        }
3286
3287        DFS_ConstructBVH2( bvh2, prim_refs, bvh2_root, prim_range_start, local_num_prims, args.num_primrefs, locals, num_active_threads);
3288
3289        // move the prim refs into their sorted position
3290        //  keep this thread's primref around for later use
3291        if ( local_id < args.num_primrefs )
3292        {
3293            struct DFSPrimRef dfs_ref = PrimRefSet_GetPrimRef( prim_refs, local_id );
3294
3295            uint input_id = DFSPrimRef_GetInputIndex( &dfs_ref );
3296
3297            parent_node = DFSPrimRef_GetBVH2Parent( &dfs_ref );
3298
3299            uint primref_index = prim_refs->input_indices[input_id];
3300            ref = args.primref_buffer[primref_index];
3301            args.primref_indices_out[local_id] = primref_index;
3302            args.primref_indices_in[local_id] = primref_index;
3303            // these buffers are not read again until the end of the kernel
3304        }
3305
3306        barrier( CLK_LOCAL_MEM_FENCE );
3307
3308    }
3309
3310
3311    // initialize flags used to determine when subtrees have been fully refit
3312    if ( get_local_id( 0 ) < REFIT_BIT_DWORDS )
3313        slm->refit_bits.bits[get_local_id( 0 )] = 0;
3314
3315
3316    // stash full-precision primref AABBs in slm storage
3317    local struct AABB3f* slm_boxes = &slm->u2.node_boxes[0];
3318    bool active_thread = get_local_id( 0 ) < args.num_primrefs;
3319    if( active_thread )
3320    {
3321        AABB3f_set( &slm_boxes[get_local_id( 0 )], ref.lower.xyz, ref.upper.xyz );
3322
3323        // stash instance masks in SLM storage
3324        if( args.do_mask_processing )
3325            slm->mask_info.bytes[get_local_id(0)] = PRIMREF_instanceMask( &ref );
3326    }
3327
3328    barrier( CLK_LOCAL_MEM_FENCE );
3329
3330    // Refit leaf nodes
3331    uint box_index;
3332    if ( active_thread )
3333    {
3334        // the thread for the first primref in every leaf is the one that will ascend
3335        // remaining threads merge their AABB/mask into the first one and terminate
3336        uint first_ref = LocalBVH2_GetLeafPrimrefStart( bvh2, parent_node );
3337        if ( first_ref != get_local_id( 0 ) )
3338        {
3339            AABB3f_atomic_merge_local_lu( &slm_boxes[first_ref], ref.lower.xyz, ref.upper.xyz );
3340
3341            if( args.do_mask_processing )
3342            {
3343                uint dword_index = first_ref/4;
3344                uint shift       = (first_ref%4)*8;
3345                uint mask = PRIMREF_instanceMask(&ref) << shift;
3346                atomic_or_local( &slm->mask_info.dwords[dword_index], mask );
3347            }
3348            active_thread = false; // switch off all primref threads except the first one
3349        }
3350
3351        box_index = first_ref;
3352    }
3353
3354    barrier( CLK_LOCAL_MEM_FENCE );
3355
3356    if ( active_thread )
3357    {
3358        uint current_node = parent_node;
3359        parent_node = LocalBVH2_GetParent( bvh2, current_node );
3360
3361        // write out the leaf node's AABB
3362        uint num_prims = LocalBVH2_GetLeafPrimCount( bvh2, current_node );
3363        uint prim_offs = args.primref_base + LocalBVH2_GetLeafPrimrefStart( bvh2, current_node );
3364
3365        uint mask = 0xff;
3366        if( args.do_mask_processing )
3367            mask = slm->mask_info.bytes[box_index];
3368
3369        BVH2_WriteLeafNode( global_bvh2, args.global_bvh2_base + current_node, &slm_boxes[box_index], prim_offs, num_prims, mask );
3370
3371        // we no longer need the BVH2 bits for this node, so re-purpose the memory to store the AABB index
3372        bvh2->nodes[current_node] = box_index;
3373
3374        // toggle flag bit in parent node.  The second thread to flip the bit is the one that gets to proceed
3375        uint thread_mask = (1 << (parent_node % 32));
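        // atomic_xor returns the previous value: the first child to arrive sees the bit clear
        // and retires; the second child sees it set and continues up the tree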
3376        if ( (atomic_xor_local( &slm->refit_bits.bits[parent_node / 32], thread_mask ) & thread_mask) == 0 )
3377            active_thread = false;
3378    }
3379
3380    // count the active threads in this sub_group and add them to the workgroup's active-thread counter
3381    uint sg_active = sub_group_reduce_add(active_thread ? 1 : 0);
3382    if(get_sub_group_local_id() == 0)
3383    {
3384        atomic_add_local(num_active_threads, sg_active);
3385    }
3386
3387    // refit internal nodes:
3388    // walk up the tree and refit AABBs
3389
3390    do
3391    {
3392        barrier( CLK_LOCAL_MEM_FENCE ); // this barrier ensures all threads have read num_active_threads before it is modified
3393        if ( active_thread )
3394        {
3395            uint current_node = parent_node;
3396            parent_node = LocalBVH2_GetParent( bvh2, current_node );
3397
3398            // pull left/right box indices from current node
3399            ushort2 kids = LocalBVH2_GetChildIndices( bvh2, current_node );
3400
3401            uint left_box = bvh2->nodes[kids.x];
3402            uint right_box = bvh2->nodes[kids.y];
3403
3404            struct AABB3f left = slm_boxes[left_box];
3405            struct AABB3f right = slm_boxes[right_box];
3406            AABB3f_extend( &left, &right );
3407
3408            uint2 child_offsets = (uint2)(
3409                args.global_bvh2_base + kids.x,
3410                args.global_bvh2_base + kids.y);
3411
3412            uint mask = 0xff;
3413            if( args.do_mask_processing )
3414            {
3415                mask = slm->mask_info.bytes[left_box]
3416                     | slm->mask_info.bytes[right_box];
3417                slm->mask_info.bytes[left_box] = mask;
3418            }
3419
3420            BVH2_WriteInnerNode( args.global_bvh2, args.global_bvh2_base+current_node, &left, child_offsets, mask );
3421
3422            slm_boxes[left_box] = left;
3423            bvh2->nodes[current_node] = left_box;
3424
3425            // stop at the root
3426            if ( LocalBVH2_IsRoot( bvh2, current_node ) )
3427            {
3428                active_thread = false;
3429                atomic_dec_local(num_active_threads);
3430            }
3431            else
3432            {
3433                // toggle flag bit in parent node.  The second thread to flip the bit is the one that gets to proceed
3434                uint mask = (1 << (parent_node % 32));
3435                if ( (atomic_xor_local( &slm->refit_bits.bits[parent_node / 32], mask ) & mask) == 0 )
3436                {
3437                    active_thread = false;
3438                    atomic_dec_local(num_active_threads);
3439                }
3440            }
3441        }
3442
3443        barrier( CLK_LOCAL_MEM_FENCE );
3444    } while ( *num_active_threads > 0 );
3445}
3446
3447
3448GRL_ANNOTATE_IGC_DO_NOT_SPILL
3449__attribute__( (reqd_work_group_size(DFS_WG_SIZE,1,1) ))
3450__attribute__( (intel_reqd_sub_group_size(16)) )
3451kernel void
3452DFS( global struct VContextScheduler* scheduler,
3453     global struct SAHBuildGlobals* globals_buffer )
3454{
3455    local struct DFS_SLM slm;
3456    local struct DFSDispatchRecord record;
3457    local uint num_active_threads;
3458
3459    if ( get_local_id( 0 ) == 0  )
3460    {
3461        // pop an entry off the DFS dispatch queue
3462        //uint wg_index = atomic_dec_global( &scheduler->num_dfs_wgs ) - 1;
3463        //record = scheduler->dfs_queue.records[wg_index];
3464
3465        // TODO:  The version above races, but is considerably faster... investigate
3466        uint wg_index = get_group_id(0);
3467        record = scheduler->dfs_queue.records[wg_index];
3468        write_mem_fence( CLK_LOCAL_MEM_FENCE );
3469        atomic_dec_global( &scheduler->num_dfs_wgs );
3470    }
3471
3472    barrier( CLK_LOCAL_MEM_FENCE );
3473
3474
3475    bool odd_pass = record.tree_depth & 1;
3476
3477    global struct SAHBuildGlobals* sah_globals = globals_buffer + record.batch_index;
3478
3479    struct DFSArgs args;
3480    args.num_primrefs = record.num_primrefs;
3481    args.primref_indices_in   = SAHBuildGlobals_GetPrimrefIndices_In( sah_globals, odd_pass );
3482    args.primref_indices_out  = SAHBuildGlobals_GetPrimrefIndices_Out( sah_globals, odd_pass );
3483    args.primref_buffer       = SAHBuildGlobals_GetPrimrefs( sah_globals );
3484    args.global_bvh2          = SAHBuildGlobals_GetBVH2( sah_globals );
3485    args.primref_indices_in  += record.primref_base;
3486    args.primref_indices_out += record.primref_base;
3487    args.primref_base         = record.primref_base;
3488    args.global_bvh2_base     = record.bvh2_base;
3489    args.do_mask_processing   = SAHBuildGlobals_NeedMasks( sah_globals );
3490
3491    Do_DFS( args, &slm, &num_active_threads );
3492
3493}
3494
3495
3496/////////////////////////////////////////////////////////////////////////////////////////////////
3497///
3498///        BVH2 to BVH6
3499///
3500/////////////////////////////////////////////////////////////////////////////////////////////////
3501
3502
3503
3504struct BuildFlatTreeArgs
3505{
3506    ushort leaf_size_in_bytes;
3507    ushort leaf_type;
3508    ushort inner_node_type;
3509    bool do_mask_processing;
3510
3511    global uint* primref_indices;
3512    global PrimRef* primref_buffer;
3513    global struct Globals* globals;
3514    global struct BVHBase* bvh_base;
3515    global struct BVH2* bvh2;
3516};
3517
3518
3519// lane i in the return value is the index of the ith largest primref in the input
3520// the return value can be used with shuffle() to move data into its sorted position
3521    //  the elements of 'key' must be unique; only the first 6 elements are sorted
3522varying ushort SUBGROUP_get_sort_indices_N6( varying uint key )
3523{
3524    // each lane computes the number of items larger than it
3525    // this is its position in the descending order
3526    //   TODO_OPT:  Compiler can vectorize these uint16 adds by packing into lower and upper halves of same GPR.... make sure it does it
3527    //     if compiler is not generating optimal code, consider moving to Cm
3528
3529    varying ushort cmp0 = (sub_group_broadcast(key, 0) > key) ? 1 : 0;
3530    varying ushort cmp1 = (sub_group_broadcast(key, 1) > key) ? 1 : 0;
3531    varying ushort cmp2 = (sub_group_broadcast(key, 2) > key) ? 1 : 0;
3532    varying ushort cmp3 = (sub_group_broadcast(key, 3) > key) ? 1 : 0;
3533    varying ushort cmp4 = (sub_group_broadcast(key, 4) > key) ? 1 : 0;
3534    varying ushort cmp5 = (sub_group_broadcast(key, 5) > key) ? 1 : 0;
3535    varying ushort a = cmp0 + cmp2 + cmp4;
3536    varying ushort b = cmp1 + cmp3 + cmp5;
3537    varying ushort num_larger = a + b;
3538
3539    // each lane determines which of the input elements it should pull
3540    varying ushort lane = get_sub_group_local_id();
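    // element 0 contributes an offset of 0 whether or not it lands in this lane, so the
    // first ternary below is written as (... ? 0 : 0) purely for symmetry with the other terms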
3541    a  = (sub_group_broadcast(num_larger, 0) == lane) ? 0 : 0;
3542    b  = (sub_group_broadcast(num_larger, 1) == lane) ? 1 : 0;
3543    a += (sub_group_broadcast(num_larger, 2) == lane) ? 2 : 0;
3544    b += (sub_group_broadcast(num_larger, 3) == lane) ? 3 : 0;
3545    a += (sub_group_broadcast(num_larger, 4) == lane) ? 4 : 0;
3546    b += (sub_group_broadcast(num_larger, 5) == lane) ? 5 : 0;
3547    return a + b;
3548}
3549
3550uint SUBGROUP_area_to_sort_key( varying float area, uniform ushort num_children )
3551{
3552    varying ushort lane = get_sub_group_local_id();
3553    area = (lane < num_children) ? area : 0;        // put inactive nodes last
3554
3555    // drop LSBs and break ties by lane number to ensure unique keys
3556    // use descending lane IDs to ensure that sort is stable if the upper MSBs are equal.
3557    //     If we do not do this it can lead to non-deterministic tree structure
3558    return (as_uint(area) & 0xffffff80) + (lane^(get_sub_group_size()-1));
3559}
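
// Illustrative sketch (added for clarity; SUBGROUP_sort_areas_descending_sketch is not part of
// the original code): the two helpers above are combined like this by the flat-tree builder
// below to sort up to 6 per-lane AABB areas into descending order within a subgroup.
varying float SUBGROUP_sort_areas_descending_sketch( varying float area, uniform ushort num_children )
{
    varying uint key        = SUBGROUP_area_to_sort_key( area, num_children ); // unique keys; inactive lanes sort last
    varying ushort src_lane = SUBGROUP_get_sort_indices_N6( key );             // lane i pulls the i-th largest element
    return intel_sub_group_shuffle( area, src_lane );                          // lanes 0..num_children-1 now hold areas in descending order
}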
3560
3561// lane i in the return value is the index of the ith largest primref in the input
3562// the return value can be used with shuffle() to move data into its sorted position
3563    //  the elements of 'key' must be unique; only the first 6 elements are sorted
3564varying ushort SUBGROUP_get_sort_indices_N6_2xSIMD8_in_SIMD16( varying uint key )
3565{
3566    // each lane computes the number of items larger than it
3567    // this is its position in the descending order
3568    //   TODO_OPT:  Compiler can vectorize these uint16 adds by packing into lower and upper halves of same GPR.... make sure it does it
3569    //     if compiler is not generating optimal code, consider moving to Cm
3570
3571    varying ushort cmp0 = (sub_group_broadcast(key, 0) > key) ? 1 : 0;
3572    varying ushort cmp1 = (sub_group_broadcast(key, 1) > key) ? 1 : 0;
3573    varying ushort cmp2 = (sub_group_broadcast(key, 2) > key) ? 1 : 0;
3574    varying ushort cmp3 = (sub_group_broadcast(key, 3) > key) ? 1 : 0;
3575    varying ushort cmp4 = (sub_group_broadcast(key, 4) > key) ? 1 : 0;
3576    varying ushort cmp5 = (sub_group_broadcast(key, 5) > key) ? 1 : 0;
3577    varying ushort a = cmp0 + cmp2 + cmp4;
3578    varying ushort b = cmp1 + cmp3 + cmp5;
3579    varying ushort num_larger = a + b;
3580
3581    varying ushort cmp0_1 = (sub_group_broadcast(key, 8) > key) ? 1 : 0;
3582    varying ushort cmp1_1 = (sub_group_broadcast(key, 9) > key) ? 1 : 0;
3583    varying ushort cmp2_1 = (sub_group_broadcast(key, 10) > key) ? 1 : 0;
3584    varying ushort cmp3_1 = (sub_group_broadcast(key, 11) > key) ? 1 : 0;
3585    varying ushort cmp4_1 = (sub_group_broadcast(key, 12) > key) ? 1 : 0;
3586    varying ushort cmp5_1 = (sub_group_broadcast(key, 13) > key) ? 1 : 0;
3587    varying ushort a_1 = cmp0_1 + cmp2_1 + cmp4_1;
3588    varying ushort b_1 = cmp1_1 + cmp3_1 + cmp5_1;
3589    varying ushort num_larger_1 = a_1 + b_1;
3590
3591    // each lane determines which of the input elements it should pull
3592    varying ushort lane = get_sub_group_local_id();
3593    if(lane < 8)
3594    {
3595        a  = (sub_group_broadcast(num_larger, 0) == lane) ? 0 : 0;
3596        b  = (sub_group_broadcast(num_larger, 1) == lane) ? 1 : 0;
3597        a += (sub_group_broadcast(num_larger, 2) == lane) ? 2 : 0;
3598        b += (sub_group_broadcast(num_larger, 3) == lane) ? 3 : 0;
3599        a += (sub_group_broadcast(num_larger, 4) == lane) ? 4 : 0;
3600        b += (sub_group_broadcast(num_larger, 5) == lane) ? 5 : 0;
3601    }
3602    else
3603    {
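        // lanes 8..15 pull from source lanes 8..13: the base offset of 8 is added unconditionally
        // (hence the (... ? 8 : 8) below), and the remaining terms add 1..5 when elements 9..13 match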
3604        a  = (sub_group_broadcast(num_larger_1, 8)  == lane-8) ? 8 : 8;
3605        b  = (sub_group_broadcast(num_larger_1, 9)  == lane-8) ? 1 : 0;
3606        a += (sub_group_broadcast(num_larger_1, 10) == lane-8) ? 2 : 0;
3607        b += (sub_group_broadcast(num_larger_1, 11) == lane-8) ? 3 : 0;
3608        a += (sub_group_broadcast(num_larger_1, 12) == lane-8) ? 4 : 0;
3609        b += (sub_group_broadcast(num_larger_1, 13) == lane-8) ? 5 : 0;
3610    }
3611
3612    return a + b;
3613}
3614
3615uint SUBGROUP_area_to_sort_key_2xSIMD8_in_SIMD16( varying float area, uniform ushort num_children )
3616{
3617    varying ushort lane = get_sub_group_local_id() % 8;
3618    area = (lane < num_children) ? area : 0;        // put inactive nodes last
3619
3620    // drop LSBs and break ties by lane number to ensure unique keys
3621    // use descending lane IDs to ensure that sort is stable if the upper MSBs are equal.
3622    //     If we do not do this it can lead to non-deterministic tree structure
3623    return (as_uint(area) & 0xffffff80) + (lane^7);
3624}
3625
3626ushort SUBGROUP_BuildFlatTreeNode(
3627    uniform struct BuildFlatTreeArgs args,
3628    uniform uint bvh2_root,
3629    uniform struct InternalNode* qnode,
3630    uniform uint qnode_index,
3631    varying uint3* sg_children_out // if an inner node is created, receives the indices of the 6 child nodes (x), the QNode position (y), and num_children (z)
3632                                   //  if a leaf is created, receives the number of primrefs (z)
3633) // return value is the number of child nodes or 0 for a leaf
3634{
3635    global struct BVH2* bvh2 = args.bvh2;
3636    varying ushort lane = get_sub_group_local_id();
3637
3638    global struct BVHBase* base = args.bvh_base;
3639
3640
3641    if ( !BVH2_IsInnerNode( bvh2, bvh2_root ) )
3642    {
3643        uniform ushort num_prims   = BVH2_GetLeafPrimCount( bvh2, bvh2_root );
3644        uniform uint primref_start = BVH2_GetLeafPrimStart( bvh2, bvh2_root );
3645        varying uint primref_index = primref_start + ((lane < num_prims) ? lane : 0);
3646
3647        varying uint ref_id = args.primref_indices[primref_index];
3648        varying PrimRef ref = args.primref_buffer[ref_id];
3649        uniform char* leaf_mem_base = (char*)BVHBase_GetQuadLeaves( args.bvh_base );
3650        uniform char* leaf_mem = leaf_mem_base + primref_start * args.leaf_size_in_bytes;
3651
3652        uniform int offset = (int)(leaf_mem - (char*)qnode);
3653        offset = offset >> 6;
3654
3655        varying uint key = SUBGROUP_area_to_sort_key(AABB_halfArea(&ref), num_prims );
3656        varying ushort sort_index = SUBGROUP_get_sort_indices_N6(key);
3657        ref = PrimRef_sub_group_shuffle(&ref, sort_index);
3658        ref_id = intel_sub_group_shuffle(ref_id, sort_index);
3659
3660        if (lane < num_prims)
3661            args.primref_indices[primref_index] = ref_id;
3662
3663        uint global_num_prims = args.globals->numPrimitives;
3664        char* bvh_mem = (char*) args.bvh_base;
3665
3666        if(lane < num_prims)
3667            args.primref_indices[primref_index + global_num_prims] = qnode - (struct InternalNode*)bvh_mem;
3668
3669        if (args.leaf_type == NODE_TYPE_INSTANCE)
3670            subgroup_setInstanceQBVHNodeN( offset, &ref, num_prims, (struct QBVHNodeN*)qnode, lane < num_prims ? PRIMREF_instanceMask(&ref) : 0 );
3671        else
3672            subgroup_setQBVHNodeN( offset, args.leaf_type, &ref, num_prims, (struct QBVHNodeN*)qnode, BVH_NODE_DEFAULT_MASK );
3673
3674        sg_children_out->z = num_prims;
3675        return 0;
3676    }
3677    else
3678    {
3679        // collapse BVH2 into BVH6.
3680        // We will spread the root node's children across the subgroup, and keep adding SIMD lanes until we have enough
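        // e.g. starting from the root's two children: each iteration picks the largest-area child
        // that is still an inner node, replaces it in-place with its left child and appends its
        // right child in a new lane, until TREE_ARITY children are reached or only leaves remain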
3681        uniform ushort num_children = 2;
3682
3683        uniform uint2 kids = BVH2_GetChildIndices( bvh2, bvh2_root );
3684        varying uint sg_bvh2_node = kids.x;
3685        if ( lane == 1 )
3686            sg_bvh2_node = kids.y;
3687
3688        do
3689        {
3690            // choose the inner node with maximum area to replace.
3691            // Its left child goes in its old location.  Its right child goes in a new lane
3692
3693            // TODO_OPT:  We re-read the AABBs again and again to compute area
3694            //   ... store per-lane boxes instead and pre-compute areas
3695
3696            varying float sg_area = BVH2_GetNodeArea( bvh2, sg_bvh2_node );
3697            varying bool sg_is_inner = BVH2_IsInnerNode( bvh2, sg_bvh2_node );
3698            sg_area = (sg_is_inner && lane < num_children) ? sg_area : 0; // prevent early exit if the largest child is a leaf
3699
3700            uniform float max_area = sub_group_reduce_max_N6( sg_area );
3701            varying bool sg_reducable = max_area == sg_area && (lane < num_children) && sg_is_inner;
3702            uniform uint mask = intel_sub_group_ballot( sg_reducable );
3703
3704            // TODO_OPT:  Some of these ops seem redundant.. look at trimming further
3705
3706            if ( mask == 0 )
3707                break;
3708
3709            // choose the inner node with maximum area to replace
3710            uniform ushort victim_child = ctz( mask );
3711            uniform uint victim_node = sub_group_broadcast( sg_bvh2_node, victim_child );
3712            kids = BVH2_GetChildIndices( bvh2, victim_node );
3713
3714            if ( lane == victim_child )
3715                sg_bvh2_node = kids.x;
3716            else if ( lane == num_children )
3717                sg_bvh2_node = kids.y;
3718
3719            num_children++;
3720
3721        } while ( num_children < TREE_ARITY );
3722
3723        // allocate inner node space
3724        uniform uint kids_offset;
3725        if (get_sub_group_local_id() == 0)
3726            kids_offset = allocate_inner_nodes( args.bvh_base, num_children );
3727        kids_offset = sub_group_broadcast(kids_offset, 0);
3728
3729        uniform struct QBVHNodeN* kid = (((struct QBVHNodeN*)args.bvh_base) + kids_offset);
3730        uniform int offset = (int)((char*)kid - (char*)qnode) >> 6;
3731
3732#if 0
3733        uniform uint kids_offset;
3734        if ( get_sub_group_local_id() == 0 )
3735            kids_offset = alloc_node_mem( args.globals, sizeof( struct QBVHNodeN ) * num_children );
3736        kids_offset = sub_group_broadcast( kids_offset, 0 );
3737
3738
3739        // create inner node
3740        uniform struct QBVHNodeN* kid = (struct QBVHNodeN*) ((char*)(args.bvh_base) + kids_offset);
3741        uniform int offset = (int)((char*)kid - (char*)qnode) >> 6;
3742#endif
3743        uniform uint child_type = args.inner_node_type;
3744
3745        // sort child nodes in descending order by AABB area
3746        varying struct AABB box   = BVH2_GetAABB( bvh2, sg_bvh2_node );
3747        varying uint key          = SUBGROUP_area_to_sort_key(AABB_halfArea(&box), num_children );
3748        varying ushort sort_index = SUBGROUP_get_sort_indices_N6(key);
3749        box          = AABB_sub_group_shuffle(&box, sort_index);
3750        sg_bvh2_node = intel_sub_group_shuffle(sg_bvh2_node, sort_index);
3751
3752        uniform uint node_mask = (args.do_mask_processing) ? BVH2_GetMask( bvh2, bvh2_root ) : 0xff;
3753
3754        subgroup_setQBVHNodeN( offset, child_type, &box, num_children, (struct QBVHNodeN*)qnode, node_mask );
3755
3756        // return child information
3757        *sg_children_out = (uint3)(sg_bvh2_node, qnode_index + offset + get_sub_group_local_id(), num_children );
3758        return num_children;
3759    }
3760}
3761
3762ushort SUBGROUP_BuildFlatTreeNode_2xSIMD8_in_SIMD16(
3763    uniform struct BuildFlatTreeArgs args,
3764    varying uint bvh2_root,
3765    varying struct InternalNode* qnode_base,
3766    varying uint qnode_index,
3767    varying uint3* sg_children_out, // if an inner node is created, receives the indices of the 6 child nodes (x), the QNode position (y), and num_children (z)
3768                                   //  if a leaf is created, receives the number of primrefs (z)
3769    bool active_lane
3770) // return value is the number of child nodes or 0 for a leaf
3771{
3772    global struct BVH2* bvh2 = args.bvh2;
3773    varying ushort SIMD16_lane = get_sub_group_local_id();
3774    varying ushort SIMD8_lane = get_sub_group_local_id() % 8;
3775    varying ushort SIMD8_id = get_sub_group_local_id() / 8;
3776    varying ushort lane = get_sub_group_local_id();
3777    global struct BVHBase* base = args.bvh_base;
3778
3779    struct BVH2NodeMetaData nodeMetaData = BVH2_GetNodeMetaData( bvh2, bvh2_root );
3780
3781    bool is_leaf = active_lane && !BVH2NodeMetaData_IsInnerNode( &nodeMetaData );
3782    bool is_inner = active_lane && BVH2NodeMetaData_IsInnerNode( &nodeMetaData );
3783
3784    uchar mask = BVH_NODE_DEFAULT_MASK;
3785    if(is_inner)
3786        mask = (args.do_mask_processing) ? BVH2NodeMetaData_GetMask( &nodeMetaData ) : 0xff;
3787
3788    int offset;
3789
3790    varying struct InternalNode* qnode = qnode_base + qnode_index;
3791    // TODO: we don't need unions; they are kept only for readability
3792    union {
3793        uint num_prims;
3794        uint num_children;
3795    } lane_num_data;
3796
3797    union {
3798        PrimRef ref; // this is in fact AABB
3799        struct AABB box;
3800    } lane_box_data;
3801
3802    union {
3803        uint ref_id;
3804        uint sg_bvh2_node;
3805    } lane_id_data;
3806
3807    // for leafs
3808    varying uint primref_index;
3809
3810    if(is_leaf)
3811    {
3812        lane_num_data.num_prims   = BVH2NodeMetaData_GetLeafPrimCount( &nodeMetaData );
3813        uint primref_start = BVH2NodeMetaData_GetLeafPrimStart( &nodeMetaData );
3814        primref_index = primref_start + ((SIMD8_lane < lane_num_data.num_prims) ? SIMD8_lane : 0);
3815
3816        lane_id_data.ref_id = args.primref_indices[primref_index];
3817        lane_box_data.ref = args.primref_buffer[lane_id_data.ref_id];
3818        char* leaf_mem_base = (char*)BVHBase_GetQuadLeaves( args.bvh_base );
3819        char* leaf_mem = leaf_mem_base + primref_start * args.leaf_size_in_bytes;
3820
3821        offset = (int)(leaf_mem - (char*)qnode);
3822        offset = offset >> 6;
3823    }
3824
3825
3826    if(intel_sub_group_ballot(is_inner))
3827    {
3828        // collapse BVH2 into BVH6.
3829        // We will spread the root node's children across the subgroup, and keep adding SIMD lanes until we have enough
3830
3831        uint2 kids;
3832        if(is_inner)
3833        {
3834            lane_num_data.num_children = 2;
3835            kids = BVH2_GetChildIndices( bvh2, bvh2_root );
3836
3837            lane_id_data.sg_bvh2_node = kids.x;
3838            if ( SIMD8_lane == 1 )
3839                lane_id_data.sg_bvh2_node = kids.y;
3840        }
3841
3842        bool active = is_inner;
3843        do
3844        {
3845            // choose the inner node with maximum area to replace.
3846            // Its left child goes in its old location.  Its right child goes in a new lane
3847
3848            // TODO_OPT:  We re-read the AABBs again and again to compute area
3849            //   ... store per-lane boxes instead and pre-compute areas
3850
3851            varying float sg_area = 0;
3852            varying bool sg_is_inner = false;
3853            if(active)
3854            {
3855                sg_area = BVH2_GetNodeArea( bvh2, lane_id_data.sg_bvh2_node );
3856                sg_is_inner = BVH2_IsInnerNode( bvh2, lane_id_data.sg_bvh2_node );
3857                sg_area = (sg_is_inner && SIMD8_lane < lane_num_data.num_children) ? sg_area : 0; // prevent early exit if the largest child is a leaf
3858            }
3859
3860            float max_area = sub_group_reduce_max_N6_2xSIMD8_in_SIMD16( sg_area );
3861            varying bool sg_reducable = max_area == sg_area && sg_is_inner && (SIMD8_lane < lane_num_data.num_children);
3862            uint mask = intel_sub_group_ballot( sg_reducable ) & (0xFF << SIMD8_id * 8); // the masking yields a separate ballot for each SIMD8 half of the SIMD16 subgroup
3863
3864            // TODO_OPT:  Some of these ops seem redundant.. look at trimming further
3865
3866            if ( mask == 0 )
3867                active = false;
3868
3869            // choose the inner node with maximum area to replace
3870            ushort victim_child = ctz( mask );
3871            uint victim_node = intel_sub_group_shuffle( lane_id_data.sg_bvh2_node, victim_child );
3872            if(active)
3873            {
3874                kids = BVH2_GetChildIndices( bvh2, victim_node );
3875
3876                if ( SIMD16_lane == victim_child ) // use SIMD16_lane because victim_child was computed in SIMD16 lane space, i.e. the second node's victim lies in lanes 8..13
3877                    lane_id_data.sg_bvh2_node = kids.x;
3878                else if ( SIMD8_lane == lane_num_data.num_children )
3879                    lane_id_data.sg_bvh2_node = kids.y;
3880
3881                lane_num_data.num_children++;
3882
3883                if(lane_num_data.num_children >= TREE_ARITY)
3884                    active = false;
3885            }
3886
3887        } while ( intel_sub_group_ballot(active) ); // if any active, then continue
3888
3889        // sum children from both halves of the SIMD16 to allocate nodes only once per sub_group
3890        uniform ushort num_children = is_inner ? lane_num_data.num_children : 0;
3891        uniform ushort first_SIMD8_num_children = sub_group_broadcast(num_children, 0);
3892        uniform ushort second_SIMD8_num_children = sub_group_broadcast(num_children, 8);
3893
3894        num_children = first_SIMD8_num_children + second_SIMD8_num_children;
3895        uint kids_offset;
3896
3897        // allocate inner node space
3898        if(num_children && SIMD16_lane == 0)
3899            kids_offset = allocate_inner_nodes( args.bvh_base, num_children );
3900        kids_offset = sub_group_broadcast(kids_offset, 0);
3901        if((is_inner))
3902        {
3903            kids_offset += SIMD8_id * first_SIMD8_num_children;
3904
3905            struct QBVHNodeN* kid = (((struct QBVHNodeN*)args.bvh_base) + kids_offset);
3906
3907            offset = (int)((char*)kid - (char*)qnode) >> 6;
3908            lane_box_data.box = BVH2_GetAABB( bvh2, lane_id_data.sg_bvh2_node );
3909        }
3910    }
3911
3912    // sort child nodes in descending order by AABB area
3913    varying uint key          = SUBGROUP_area_to_sort_key_2xSIMD8_in_SIMD16(AABB_halfArea(&lane_box_data.box), lane_num_data.num_children );
3914    varying ushort sort_index = SUBGROUP_get_sort_indices_N6_2xSIMD8_in_SIMD16(key);
3915    lane_box_data.box         = PrimRef_sub_group_shuffle(&lane_box_data.box, sort_index);
3916    lane_id_data.sg_bvh2_node = intel_sub_group_shuffle(lane_id_data.sg_bvh2_node, sort_index);
3917
3918    char* bvh_mem = (char*) args.bvh_base;
3919    if (is_leaf && SIMD8_lane < lane_num_data.num_prims)
3920    {
3921        args.primref_indices[primref_index] = lane_id_data.ref_id;
3922        args.primref_indices[primref_index + args.globals->numPrimitives] = qnode - (struct InternalNode*)bvh_mem;
3923    }
3924
3925    bool degenerated = false;
3926    uint node_type = is_leaf ? args.leaf_type : args.inner_node_type;
3927
3928    if(args.leaf_type == NODE_TYPE_INSTANCE)
3929        degenerated = subgroup_setInstanceBox_2xSIMD8_in_SIMD16(&lane_box_data.box, lane_num_data.num_children, &mask, SIMD8_lane < lane_num_data.num_prims ? PRIMREF_instanceMask(&lane_box_data.ref) : 0, is_leaf);
3930
3931    subgroup_setQBVHNodeN_setFields_2xSIMD8_in_SIMD16(offset, node_type, &lane_box_data.box, lane_num_data.num_children, mask, (struct QBVHNodeN*)(qnode), degenerated, active_lane);
3932
3933    // return child information
3934    if(is_inner)
3935    {
3936        sg_children_out->x = lane_id_data.sg_bvh2_node;
3937        sg_children_out->y = qnode_index + offset + SIMD8_lane;
3938    }
3939
3940    sg_children_out->z = lane_num_data.num_children;
3941
3942    return is_inner ? lane_num_data.num_children : 0;
3943}
3944
3945void check_primref_integrity( global struct SAHBuildGlobals* globals )
3946{
3947    global uint* primref_in = SAHBuildGlobals_GetPrimrefIndices_In( globals, 0 );
3948    global uint* primref_out = SAHBuildGlobals_GetPrimrefIndices_Out( globals, 0 );
3949    dword num_primrefs = SAHBuildGlobals_GetTotalPrimRefs( globals );
3950    if ( get_local_id( 0 ) == 0 )
3951    {
3952        for ( uint i = 0; i < num_primrefs; i++ )
3953        {
3954            primref_out[i] = 0;
3955        }
3956
3957        for ( uint i = 0; i < num_primrefs; i++ )
3958            primref_out[primref_in[i]]++;
3959
3960        for ( uint i = 0; i < num_primrefs; i++ )
3961            if ( primref_out[i] != 1 )
3962                printf( "Foo: %u   %u\n", i, primref_out[i] );
3963    }
3964}
3965
3966
3967
3968
3969void check_bvh2(global struct SAHBuildGlobals* globals )
3970{
3971    global struct BVH2* bvh2 = SAHBuildGlobals_GetBVH2(globals);
3972    global uint* primref_in = SAHBuildGlobals_GetPrimrefIndices_In(globals, 0);
3973    global uint* primref_out = SAHBuildGlobals_GetPrimrefIndices_Out(globals, 0);
3974    dword num_primrefs = SAHBuildGlobals_GetTotalPrimRefs(globals);
3975
3976    if (get_local_id(0) == 0)
3977    {
3978        for (uint i = 0; i < num_primrefs; i++)
3979            primref_out[i] = 0;
3980
3981        uint stack[256];
3982        uint sp=0;
3983        uint r = BVH2_GetRoot(bvh2);
3984        stack[sp++] = r;
3985        while (sp)
3986        {
3987            r = stack[--sp];
3988            if (BVH2_IsInnerNode(bvh2,r))
3989            {
3990                uint2 kids = BVH2_GetChildIndices( bvh2, r);
3991                if (kids.x >= bvh2->num_nodes || kids.y >= bvh2->num_nodes)
3992                {
3993                    printf("BVH2!! Bad node index found!\n");
3994                    return;
3995                }
3996
3997                stack[sp++] = kids.x;
3998                stack[sp++] = kids.y;
3999            }
4000            else
4001            {
4002                uint ref = BVH2_GetLeafPrimStart(bvh2,r);
4003                uint count = BVH2_GetLeafPrimCount(bvh2,r);
4004                if( count == 0 )
4005                {
4006                    printf("BVH2!! Empty leaf found!\n");
4007                    return;
4008                }
4009                for (uint i = 0; i < count; i++)
4010                {
4011                    if (ref + i >= num_primrefs)
4012                    {
4013                        printf("BVH2!! Bad leaf range!\n");
4014                        return;
4015                    }
4016                    uint c = primref_out[ref+i];
4017                    if (c != 0)
4018                    {
4019                        printf("BVH2!! overlapped prim ranges\n");
4020                        return;
4021                    }
4022                    primref_out[ref+i] = 1;
4023                    if (primref_in[ref + i] >= num_primrefs)
4024                    {
4025                        printf("BAD PRIMREF ID FOUND!\n");
4026                        return;
4027                    }
4028                }
4029            }
4030        }
4031    }
4032
4033    printf("bvh2 is ok!\n");
4034}
4035
4036
4037#if 0
4038// TODO_OPT:  Enable larger WGs.  WGSize 512 at SIMD8 hangs on Gen9, but Gen12 can go bigger
4039GRL_ANNOTATE_IGC_DO_NOT_SPILL
4040__attribute__( (reqd_work_group_size(256,1,1)) )
4041__attribute__( (intel_reqd_sub_group_size(8) ) )
4042kernel void
4043build_qnodes( global struct SAHBuildGlobals* globals, global struct VContextScheduler* scheduler )
4044{
4045    globals = globals + (scheduler->num_trivial_builds + scheduler->num_single_builds);
4046    globals = globals + get_group_id(0);
4047
4048
4049    struct BuildFlatTreeArgs args;
4050    args.leaf_size_in_bytes = SAHBuildGlobals_GetLeafSizeInBytes( globals );
4051    args.leaf_type          = SAHBuildGlobals_GetLeafType( globals );
4052    args.inner_node_type    = SAHBuildGlobals_GetInternalNodeType( globals );
4053    args.primref_indices    = SAHBuildGlobals_GetPrimrefIndices_In( globals, 0 );
4054    args.primref_buffer     = SAHBuildGlobals_GetPrimrefs( globals );
4055    args.bvh_base           = SAHBuildGlobals_GetBVHBase( globals );
4056    args.bvh2               = SAHBuildGlobals_GetBVH2( globals );
4057    args.globals            = (global struct Globals*) globals->p_globals;
    args.do_mask_processing = SAHBuildGlobals_NeedMasks( globals );

    dword alloc_backpointers = SAHBuildGlobals_NeedBackPointers( globals );
    global uint2* root_buffer = (global uint2*) globals->p_qnode_root_buffer;
    global struct InternalNode* qnodes = (global struct InternalNode*) BVHBase_GetInternalNodes( args.bvh_base );
    global uint* back_pointers = (global uint*) BVHBase_GetBackPointers( args.bvh_base );

    local uint nodes_produced;
    if ( get_sub_group_id() == 0 )
    {
        // allocate first node
        if (get_sub_group_local_id() == 0)
            allocate_inner_nodes( args.bvh_base, 1 );

        // first subgroup does first node
        varying uint3 children_info;
        uniform ushort num_children = SUBGROUP_BuildFlatTreeNode(args, BVH2_GetRoot(args.bvh2), qnodes, 0, &children_info );

        if ( get_sub_group_local_id() < num_children )
            root_buffer[get_sub_group_local_id()] = children_info.xy;

        if ( alloc_backpointers )
        {
            // set root's backpointer
            if( get_sub_group_local_id() == 0 )
                back_pointers[0] = (0xffffffc0) | (children_info.z << 3);

            // point child backpointers at the parent
            if( get_sub_group_local_id() < num_children )
                back_pointers[children_info.y] = 0;
        }

        if ( get_sub_group_local_id() == 0 )
            nodes_produced = num_children;
    }

    barrier( CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE );


    uniform uint buffer_index = get_sub_group_id();
    uniform bool sg_active    = buffer_index < nodes_produced;

    while ( work_group_any( sg_active ) )
    {
        if( sg_active )
        {
            uniform uint bvh2_node    = root_buffer[buffer_index].x;
            uniform uint qnode_index  = root_buffer[buffer_index].y;

            // build a node
            varying uint3 children_info;
            uniform ushort num_children = SUBGROUP_BuildFlatTreeNode( args, bvh2_node, qnodes + qnode_index, qnode_index, &children_info );

            // handle backpointers
            if ( alloc_backpointers )
            {
                // update this node's backpointer with child count
                if ( get_sub_group_local_id() == 0 )
                    back_pointers[qnode_index] |= (children_info.z << 3);

                // point child backpointers at parent
                if ( get_sub_group_local_id() < num_children )
                    back_pointers[children_info.y] = (qnode_index << 6);
            }

            if ( num_children )
            {
                // allocate space in the child buffer
                uint root_buffer_position = 0;
                if ( get_sub_group_local_id() == 0 )
                    root_buffer_position = atomic_add_local( &nodes_produced, num_children );
                root_buffer_position = sub_group_broadcast( root_buffer_position, 0 );

                // store child indices in root buffer
                if ( get_sub_group_local_id() < num_children )
                    root_buffer[root_buffer_position + get_sub_group_local_id()] = children_info.xy;
            }
        }

        // sync everyone
        work_group_barrier( CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE,
                            memory_scope_work_group );


        if( sg_active )
            buffer_index += get_num_sub_groups();

        sg_active = (buffer_index < nodes_produced);
    }
}
#endif
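
// Note: the loop above flattens the BVH2 breadth-first.  The root buffer acts as a work
// queue of (bvh2_node, qnode_index) pairs: each subgroup pops one pair, emits up to
// TREE_ARITY children via SUBGROUP_BuildFlatTreeNode(), and appends those children back
// onto the queue until no subgroup finds further work.
//
// The backpointer writes above imply the following layout (a reading of this file, not a
// definitive spec):  bits [31:6] hold the parent qnode index, bits [5:3] hold the child
// count, and 0xffffffc0 marks the root.  The helpers below are hypothetical decoders added
// for documentation only; nothing in this file calls them.
inline uint QnodeBackPointer_ParentIndex_sketch( uint bp ) { return bp >> 6; }
inline uint QnodeBackPointer_ChildCount_sketch( uint bp )  { return (bp >> 3) & 0x7; }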



inline bool buffer_may_overflow( uint capacity, uint current_size, uint elements_processed_per_sub_group )
{
    uint num_consumed = min( get_num_sub_groups() * elements_processed_per_sub_group, current_size );
    uint space_available = (capacity - current_size) + num_consumed;
    uint space_needed = TREE_ARITY * num_consumed;
    return space_available < space_needed;
}
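
// Worked example for buffer_may_overflow(), assuming the configuration used by
// build_qnodes_pc_amplify_batched (RING_SIZE = 32, 2 subgroups, 2 elements per subgroup):
//  - current_size = 6:  num_consumed = min(4,6) = 4, space_available = (32-6)+4 = 30,
//    space_needed = 6*4 = 24  ->  no overflow, the loop may continue.
//  - current_size = 20: num_consumed = 4, space_available = (32-20)+4 = 16,
//    space_needed = 24        ->  overflow possible, the caller must spill.
// In other words, the check trips as soon as a full fan-out (TREE_ARITY children per
// consumed node) could no longer fit into the free slots plus the slots vacated this round.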

inline uint build_qnodes_pc(
    global struct SAHBuildGlobals* globals,
    bool alloc_backpointers,
    bool process_masks,
    uint first_qnode,
    uint first_bvh2_node,

    local uint2* SLM_local_root_buffer,
    local uint* SLM_ring_tail,
    const uint  RING_SIZE
)

{
    struct BuildFlatTreeArgs args;
    args.leaf_size_in_bytes = SAHBuildGlobals_GetLeafSizeInBytes( globals );
    args.leaf_type = SAHBuildGlobals_GetLeafType( globals );
    args.inner_node_type = SAHBuildGlobals_GetInternalNodeType( globals );
    args.primref_indices = SAHBuildGlobals_GetPrimrefIndices_In( globals, 0 );
    args.primref_buffer = SAHBuildGlobals_GetPrimrefs( globals );
    args.bvh_base = SAHBuildGlobals_GetBVHBase( globals );
    args.bvh2 = SAHBuildGlobals_GetBVH2( globals );
    args.globals = (global struct Globals*) globals->p_globals;
    args.do_mask_processing = process_masks;

    global struct InternalNode* qnodes = (global struct InternalNode*) BVHBase_GetInternalNodes( args.bvh_base );
    global uint* back_pointers = (global uint*) BVHBase_GetBackPointers( args.bvh_base );

    // first subgroup adds first node
    if ( get_sub_group_id() == 0 && get_sub_group_local_id() == 0)
    {
        SLM_local_root_buffer[0].x = first_bvh2_node;
        SLM_local_root_buffer[0].y = first_qnode;
        *SLM_ring_tail = 1;

    }

    uint ring_head = 0;
    uint ring_tail = 1;
    uint ring_size = 1;

    barrier( CLK_LOCAL_MEM_FENCE );

    const uniform uint elements_processed_in_sg = 2;

    while ( ring_size > 0 && !buffer_may_overflow( RING_SIZE, ring_size, elements_processed_in_sg ) )
    {
        ushort SIMD16_lane = get_sub_group_local_id();

        // SIMD16 as 2xSIMD8
        ushort SIMD8_lane = get_sub_group_local_id() % 8;
        ushort SIMD8_id = get_sub_group_local_id() / 8;
        bool active_lane;

        uniform uint nodes_consumed = min( get_num_sub_groups() * elements_processed_in_sg, ring_size ); // times two because each subgroup processes two nodes
        uniform bool sg_active = get_sub_group_id() * elements_processed_in_sg < nodes_consumed;
        ushort num_children = 0;
        varying uint3 children_info = 0;

        uint bvh2_node = 0;
        uint qnode_index = 0;

        if (sg_active)
        {
            ushort consumed_pos = get_sub_group_id() * elements_processed_in_sg + SIMD8_id;
            active_lane = consumed_pos < nodes_consumed ? true : false;
            consumed_pos = consumed_pos < nodes_consumed ? consumed_pos : consumed_pos-1;

            uint buffer_index = (ring_head + consumed_pos) % RING_SIZE;

            bvh2_node = SLM_local_root_buffer[buffer_index].x;
            qnode_index = SLM_local_root_buffer[buffer_index].y;
        }

        barrier( CLK_LOCAL_MEM_FENCE );

        if (sg_active)
        {
            // build a node
            num_children = SUBGROUP_BuildFlatTreeNode_2xSIMD8_in_SIMD16(args, bvh2_node, qnodes, qnode_index, &children_info, active_lane);

            // handle backpointers
            // TODO_OPT:  This should be separate shaders, not a runtime branch.
            //     doing it this way for now because GRLTLK does not make dynamic shader selection on host very easy.
            //     this needs to change... GRLTLK should

            if (alloc_backpointers && active_lane)
            {
                // update this node's backpointer with child count
                if (SIMD8_lane == 0)
                    back_pointers[qnode_index] |= (children_info.z << 3);

                // point child backpointers at parent
                if (SIMD8_lane < num_children)
                    back_pointers[children_info.y] = (qnode_index << 6);
            }

            // save data

            uniform ushort first_SIMD8_num_children  = sub_group_broadcast(num_children, 0);
            uniform ushort second_SIMD8_num_children = sub_group_broadcast(num_children, 8);
            uniform ushort SIMD16_num_children = first_SIMD8_num_children + second_SIMD8_num_children;

            uint root_buffer_position = 0;

            // allocate space in the child buffer
            if (SIMD16_lane == 0 && SIMD16_num_children)
                root_buffer_position = atomic_add_local(SLM_ring_tail, SIMD16_num_children);

            root_buffer_position = sub_group_broadcast( root_buffer_position, 0 );
            root_buffer_position += SIMD8_id * first_SIMD8_num_children; // update offset for second half of SIMD16

            // store child indices in root buffer
            if (SIMD8_lane < num_children)
            {
                uint store_pos = (root_buffer_position + SIMD8_lane) % RING_SIZE;
                SLM_local_root_buffer[store_pos] = children_info.xy;
            }
        }

        // sync everyone
        barrier( CLK_LOCAL_MEM_FENCE );

        ring_head += nodes_consumed;
        ring_tail = *SLM_ring_tail;
        ring_size = ring_tail - ring_head;
    }

    return ring_head;
}
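
// build_qnodes_pc() returns the final ring_head; the entries in [ring_head, *SLM_ring_tail)
// are nodes that were produced but never consumed before the overflow check fired, and the
// callers in this file spill exactly that range.  Rough worst-case trace for the batched
// amplify configuration (RING_SIZE = 32, 2 subgroups), assuming every node emits the full
// TREE_ARITY = 6 children:  round 1 consumes the seed node and produces 6; round 2 consumes
// 4 of those and produces 24, leaving 26 in the ring, at which point buffer_may_overflow()
// ends the loop and the remainder is handed back to the caller.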



inline void amplify_and_spill(
    global struct SAHBuildGlobals* globals,
    dword alloc_backpointers,
    uint first_qnode,
    uint first_bvh2_node,
    global uint2* global_root_buffer,
    local uint* root_buffer_counter,
    const uint  RING_SIZE
)

{
    struct BuildFlatTreeArgs args;
    args.leaf_size_in_bytes = SAHBuildGlobals_GetLeafSizeInBytes(globals);
    args.leaf_type = SAHBuildGlobals_GetLeafType(globals);
    args.inner_node_type = SAHBuildGlobals_GetInternalNodeType(globals);
    args.primref_indices = SAHBuildGlobals_GetPrimrefIndices_In(globals, 0);
    args.primref_buffer = SAHBuildGlobals_GetPrimrefs(globals);
    args.bvh_base = SAHBuildGlobals_GetBVHBase(globals);
    args.bvh2 = SAHBuildGlobals_GetBVH2(globals);
    args.globals = (global struct Globals*) globals->p_globals;
    // note: args.do_mask_processing is left unset here; the only call site in this file
    //       (the #if 0 amplification path in build_qnodes_pc_kickoff_func) is currently compiled out

    global struct InternalNode* qnodes = (global struct InternalNode*) BVHBase_GetInternalNodes(args.bvh_base);
    global uint* back_pointers = (global uint*) BVHBase_GetBackPointers(args.bvh_base);


    varying uint3 children_info;
    uniform ushort num_children = SUBGROUP_BuildFlatTreeNode(args, first_bvh2_node, qnodes + first_qnode, first_qnode, &children_info);

    if (alloc_backpointers)
    {
        // set first node's backpointer
        if (get_sub_group_local_id() == 0)
        {
            // if first node is root, use root sentinel in backpointer
            //   otherwise, need to merge the child count in with the parent offset (which was already put there by the parent's thread)
            uint bp = 0xffffffc0;
            if (first_qnode != 0)
                bp = back_pointers[first_qnode];
            bp |= (children_info.z << 3);

            back_pointers[first_qnode] = bp;
        }

        // point child backpointers at the parent
        if (get_sub_group_local_id() < num_children)
            back_pointers[children_info.y] = (first_qnode << 6);
    }

    if (num_children)
    {
        uint spill_pos = 0;
        if (get_sub_group_local_id() == 0)
            spill_pos = atomic_add_local(root_buffer_counter, num_children);

        spill_pos = sub_group_broadcast(spill_pos, 0);

        if (get_sub_group_local_id() < num_children)
            global_root_buffer[spill_pos + get_sub_group_local_id()] = children_info.xy;
    }

}



inline void build_qnodes_pc_kickoff_func(
    global struct SAHBuildGlobals* globals,
    global uint2* root_buffer,
    bool alloc_backpointers,
    bool process_masks,

    local uint2* SLM_local_root_buffer,
    local uint* SLM_spill_pos,
    local uint* SLM_ring_tail,
    int RING_SIZE
)
{
    // allocate first node
    if ( get_sub_group_id() == 0 && get_sub_group_local_id() == 0 )
        allocate_inner_nodes( SAHBuildGlobals_GetBVHBase(globals), 1 );

    *SLM_spill_pos = 0;

    uint ring_head = build_qnodes_pc( globals, alloc_backpointers, process_masks,
                     0, BVH2_GetRoot(SAHBuildGlobals_GetBVH2(globals)), SLM_local_root_buffer, SLM_ring_tail, RING_SIZE );


    uint n = *SLM_ring_tail - ring_head;
    if (n > 0)
    {
#if 0
        // do an additional round of amplification so we can get more nodes into the root buffer and go wider in the next phase
        /// JDB TODO: this is causing hangs on DG2 for metro, so disabling for now...
        for (uint i = get_sub_group_id(); i < n; i+= get_num_sub_groups() )
        {
            uint consume_pos = (ring_head + i) % RING_SIZE;
            uniform uint bvh2_root = SLM_local_root_buffer[consume_pos].x;
            uniform uint qnode_root = SLM_local_root_buffer[consume_pos].y;

            amplify_and_spill( globals, alloc_backpointers, qnode_root, bvh2_root, root_buffer, SLM_spill_pos, RING_SIZE );
        }

        barrier( CLK_LOCAL_MEM_FENCE );
#else
        for (uint i = get_local_id(0); i < n; i += get_local_size(0))
            root_buffer[i] = SLM_local_root_buffer[(ring_head + i) % RING_SIZE];
#endif

        if (get_local_id(0) == 0)
        {
            globals->root_buffer_num_produced = n;
            globals->root_buffer_num_produced_hi = 0;
            globals->root_buffer_num_consumed = 0;
            globals->root_buffer_num_consumed_hi = 0;
        }
    }
}
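
// root_buffer_num_produced / root_buffer_num_consumed turn the global root buffer into a
// work queue between the kickoff and amplify phases: the kickoff publishes the spilled
// entry count and clears the consumed count, and later amplify dispatches pop entries
// either by atomically bumping root_buffer_num_consumed (build_qnodes_pc_amplify_func) or
// by offsetting from it with their group id (build_qnodes_pc_amplify).  The *_hi
// counterparts are cleared here as well so the counters start from a known state.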



GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__( (reqd_work_group_size( 256, 1, 1 )) )
__attribute__( (intel_reqd_sub_group_size( 16 )) )
kernel void
build_qnodes_pc_kickoff(
    global struct SAHBuildGlobals* globals,
    global uint2* root_buffer,
    dword sah_flags
)
{
    bool alloc_backpointers = sah_flags & SAH_FLAG_NEED_BACKPOINTERS;
    bool process_masks = sah_flags & SAH_FLAG_NEED_MASKS;


    const int RING_SIZE = 64;

    local uint2 SLM_local_root_buffer[RING_SIZE];
    local uint SLM_spill_pos;
    local uint SLM_ring_tail;

    build_qnodes_pc_kickoff_func(globals,
                                 root_buffer,
                                 alloc_backpointers,
                                 process_masks,
                                 SLM_local_root_buffer,
                                 &SLM_spill_pos,
                                 &SLM_ring_tail,
                                 RING_SIZE
                                 );
}



inline void build_qnodes_pc_amplify_func(
    global struct SAHBuildGlobals* globals,
    global uint2* root_buffer,
    bool alloc_backpointers,
    bool process_masks,

    local uint2* SLM_local_root_buffer,
    local uint*  SLM_broadcast,
    local uint*  SLM_ring_tail,
    int RING_SIZE
    )
{
    // TODO_OPT:  Probably don't need this atomic.. could clear 'num_consumed' every time
    //     and just use get_group_id()
    //

    if (get_local_id(0) == 0)
        *SLM_broadcast = atomic_inc_global(&globals->root_buffer_num_consumed);

    barrier( CLK_LOCAL_MEM_FENCE );

    uniform uint consume_pos = *SLM_broadcast;
    uniform uint bvh2_root = root_buffer[consume_pos].x;
    uniform uint qnode_root = root_buffer[consume_pos].y;

    uint ring_head = build_qnodes_pc(globals, alloc_backpointers, process_masks,
        qnode_root, bvh2_root, SLM_local_root_buffer, SLM_ring_tail, RING_SIZE);

    // TODO_OPT:  Instead of spilling the nodes, do one more round of amplification and write
    //   generated children directly into the root buffer.  This should allow faster amplification

    // spill root buffer contents
    uint n = *SLM_ring_tail - ring_head;
    if (n > 0)
    {

        if (get_local_id(0) == 0)
            *SLM_broadcast = atomic_add_global(&globals->root_buffer_num_produced, n);

        barrier( CLK_LOCAL_MEM_FENCE );
        uint produce_pos = *SLM_broadcast;

        for (uint i = get_local_id(0); i < n; i += get_local_size(0))
            root_buffer[produce_pos + i] = SLM_local_root_buffer[(ring_head + i) % RING_SIZE];
    }
}




// Process two nodes per wg during amplification phase.
// Doing it this way ensures maximum parallelism
GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(16, 1, 1)))
__attribute__((intel_reqd_sub_group_size(16)))
kernel void
build_qnodes_pc_amplify(
    global struct SAHBuildGlobals* globals,
    global uint2* root_buffer,
    dword sah_flags )
{
    bool alloc_backpointers = sah_flags & SAH_FLAG_NEED_BACKPOINTERS;

    struct BuildFlatTreeArgs args;
    args.leaf_size_in_bytes = SAHBuildGlobals_GetLeafSizeInBytes(globals);
    args.leaf_type = SAHBuildGlobals_GetLeafType(globals);
    args.inner_node_type = SAHBuildGlobals_GetInternalNodeType(globals);
    args.primref_indices = SAHBuildGlobals_GetPrimrefIndices_In(globals, 0);
    args.primref_buffer = SAHBuildGlobals_GetPrimrefs(globals);
    args.bvh_base = SAHBuildGlobals_GetBVHBase(globals);
    args.bvh2 = SAHBuildGlobals_GetBVH2(globals);
    args.globals = (global struct Globals*) globals->p_globals;
    args.do_mask_processing = sah_flags & SAH_FLAG_NEED_MASKS;

    global struct InternalNode* qnodes = (global struct InternalNode*) BVHBase_GetInternalNodes(args.bvh_base);
    global uint* back_pointers = (global uint*) BVHBase_GetBackPointers(args.bvh_base);

    ushort SIMD16_lane = get_sub_group_local_id();

    // SIMD16 as 2xSIMD8
    ushort SIMD8_lane = get_sub_group_local_id() % 8;
    ushort SIMD8_id = get_sub_group_local_id() / 8;
    bool active_lane = false;

    uint consume_pos;
    consume_pos = globals->root_buffer_num_consumed + get_group_id(0) * 2; // times 2 because each work group processes two nodes
    consume_pos += SIMD8_id;

    active_lane = consume_pos < globals->root_buffer_num_to_consume ? true : false;
    consume_pos = consume_pos < globals->root_buffer_num_to_consume ? consume_pos : consume_pos-1;

    uint first_bvh2_node = root_buffer[consume_pos].x;
    uint first_qnode = root_buffer[consume_pos].y;

    varying uint3 children_info;
    ushort num_children = SUBGROUP_BuildFlatTreeNode_2xSIMD8_in_SIMD16(args, first_bvh2_node, qnodes, first_qnode, &children_info, active_lane);

    if (alloc_backpointers && active_lane)
    {
        // set first node's backpointer
        if (SIMD8_lane == 0)
        {
            // if first node is root, use root sentinel in backpointer
            //   otherwise, need to merge the child count in with the parent offset (which was already put there by the parent's thread)
            uint bp = 0xffffffc0;
            if (first_qnode != 0)
                bp = back_pointers[first_qnode];
            bp |= (children_info.z << 3);

            back_pointers[first_qnode] = bp;
        }

        // point child backpointers at the parent
        if (SIMD8_lane < num_children)
            back_pointers[children_info.y] = (first_qnode << 6);
    }

    // save data
    {
        // sum children from both halves of SIMD16 to do only one atomic per sub_group
        uint produce_pos;
        uniform ushort first_SIMD8_num_children = sub_group_broadcast(num_children, 0);
        uniform ushort second_SIMD8_num_children = sub_group_broadcast(num_children, 8);
        uniform ushort SIMD16_num_children = first_SIMD8_num_children + second_SIMD8_num_children;

        if (SIMD16_lane == 0 && SIMD16_num_children)
            produce_pos = atomic_add_global(&globals->root_buffer_num_produced, SIMD16_num_children);

        produce_pos = sub_group_broadcast(produce_pos, 0);
        produce_pos += SIMD8_id * first_SIMD8_num_children; // update offset for second half of SIMD16

        if (SIMD8_lane < num_children)
        {
            root_buffer[produce_pos + SIMD8_lane] = children_info.xy;
        }
    }
}
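
// Lane mapping used above (and inside build_qnodes_pc): a SIMD16 subgroup is treated as two
// SIMD8 halves, so lanes 0..7 (SIMD8_id 0) flatten one root-buffer entry while lanes 8..15
// (SIMD8_id 1) flatten the next one.  Output slots are then carved out of a single atomic
// per subgroup: e.g. if the first half emits 3 children and the second half emits 5, lane 0
// reserves 8 slots and the second half starts writing at produce_pos + 3.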

//////////
//
// Batched version of qnode creation
//
//////////



GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(1, 1, 1)))
kernel void
build_qnodes_init_scheduler_batched(global struct QnodeScheduler* scheduler, dword num_builds, dword num_max_qnode_global_root_buffer_entries)
{

    scheduler->batched_build_offset = scheduler->num_trivial_builds + scheduler->num_single_builds;
    scheduler->batched_build_count = num_builds - scheduler->batched_build_offset;
    scheduler->num_max_qnode_global_root_buffer_entries = num_max_qnode_global_root_buffer_entries;

    const uint num_builds_to_process = scheduler->batched_build_count;
    const uint max_qnode_grb_entries = scheduler->num_max_qnode_global_root_buffer_entries;

    scheduler->batched_builds_to_process = num_builds_to_process;
    scheduler->num_qnode_grb_curr_entries = (num_builds_to_process + 15) / 16; // here we store the number of work groups for the "build_qnodes_begin_batchable" kernel
    scheduler->num_qnode_grb_new_entries = num_builds_to_process;
    scheduler->qnode_global_root_buffer.curr_entries_offset = max_qnode_grb_entries;
}
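
// Note on the counter reuse above: num_qnode_grb_curr_entries temporarily holds the
// work-group count for build_qnodes_begin_batchable ((n + 15) / 16, one SIMD16 work group
// per 16 builds); build_qnodes_scheduler later overwrites it with the real entry count.
// curr_entries_offset starts at max_qnode_grb_entries so that the scheduler's first flip
// lands on offset 0, which is where build_qnodes_begin_batchable writes the seed entries.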



GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(16, 1, 1)))
__attribute__( (intel_reqd_sub_group_size( 16 )) )
kernel void
build_qnodes_begin_batchable(global struct QnodeScheduler* scheduler,
                             global struct SAHBuildGlobals* builds_globals)
{
    const uint tid = get_group_id(0) * get_local_size(0) + get_local_id(0);

    const uint num_builds_to_process = scheduler->batched_builds_to_process;

    if(tid < num_builds_to_process)
    {
        const uint build_idx = scheduler->batched_build_offset + tid;

        uint bvh2_node = BVH2_GetRoot(SAHBuildGlobals_GetBVH2(&builds_globals[build_idx]));
        uint qnode = 0;
        struct QNodeGlobalRootBufferEntry entry = { bvh2_node, qnode, build_idx, 1};
        scheduler->qnode_global_root_buffer.entries[tid] = entry;

        builds_globals[build_idx].root_buffer_num_produced = 0;
        builds_globals[build_idx].root_buffer_num_produced_hi = 0;
        builds_globals[build_idx].root_buffer_num_consumed = 0;
        builds_globals[build_idx].root_buffer_num_consumed_hi = 0;

        // allocate first node for this build
        //allocate_inner_nodes( SAHBuildGlobals_GetBVHBase(&builds_globals[build_idx]), 1 );
        SAHBuildGlobals_GetBVHBase(&builds_globals[build_idx])->nodeDataCur++;
    }
}




GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__( (reqd_work_group_size( 1, 1, 1 )) )
kernel void
build_qnodes_scheduler(global struct QnodeScheduler* scheduler)
{
    const uint max_qnode_grb_entries = scheduler->num_max_qnode_global_root_buffer_entries;

    uint new_entries = min(scheduler->num_qnode_grb_new_entries, max_qnode_grb_entries);

    scheduler->num_qnode_grb_curr_entries = new_entries;
    scheduler->num_qnode_grb_new_entries = 0;
    scheduler->qnode_global_root_buffer.curr_entries_offset = scheduler->qnode_global_root_buffer.curr_entries_offset ? 0 : max_qnode_grb_entries;
}
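
// The global root buffer is double-buffered: entries[0 .. max) and entries[max .. 2*max)
// swap roles every iteration.  curr_entries_offset selects the half currently being
// consumed, producers always write into the other half (new_entries_offset =
// curr_entries_offset ? 0 : max), and this kernel flips the offset and promotes the
// new-entry count to the current-entry count between dispatches.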



// TODO_OPT:  Enable larger WGs.  WGSize 512 at SIMD8 hangs on Gen9, but Gen12 can go bigger
GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__( (reqd_work_group_size( 32, 1, 1 )) )
__attribute__( (intel_reqd_sub_group_size( 16 )) )
kernel void
build_qnodes_pc_amplify_batched(
    global struct SAHBuildGlobals* builds_globals,
    global struct QnodeScheduler* scheduler
    )
{
    const uint group_id = get_group_id(0);

    global struct QNodeGlobalRootBuffer* global_root_buffer = &scheduler->qnode_global_root_buffer;
    const uint curr_entries_offset = global_root_buffer->curr_entries_offset;
    struct QNodeGlobalRootBufferEntry entry = global_root_buffer->entries[curr_entries_offset + group_id];

    const uint build_id = entry.build_idx;

    global struct SAHBuildGlobals* globals = &builds_globals[build_id];
    global uint2* root_buffer = (global uint2*)globals->p_qnode_root_buffer;
    bool alloc_backpointers = SAHBuildGlobals_NeedBackPointers(globals);
    bool process_masks = SAHBuildGlobals_NeedMasks(globals);

    const int RING_SIZE = 32; // for 2 SGs, 16 should result in 2 rounds:  one SG produces 6, then 2 SGs consume 2 and produce 12
                              // for 4 SGs, 32 results in 2 rounds:  one SG produces 6, 4 SGs consume 4 and produce 24, resulting in 26

    local uint2 SLM_local_root_buffer[RING_SIZE];
    local uint  SLM_broadcast;
    local uint  SLM_ring_tail;
    local uint  SLM_grb_broadcast;


    //// The code below can be moved to a separate function if needed for TLAS ////

    uniform uint bvh2_root = entry.bvh2_node;
    uniform uint qnode_root = entry.qnode;

    uint ring_head = build_qnodes_pc(globals, alloc_backpointers, process_masks,
        qnode_root, bvh2_root, SLM_local_root_buffer, &SLM_ring_tail, RING_SIZE);

    // spill root buffer contents
    uint n = SLM_ring_tail - ring_head;
    if (n > 0)
    {
        const uint max_qnode_grb_entries = scheduler->num_max_qnode_global_root_buffer_entries;

        if (get_local_id(0) == 0)
        {
            SLM_grb_broadcast = atomic_add_global(&scheduler->num_qnode_grb_new_entries, n);

            if(SLM_grb_broadcast >= max_qnode_grb_entries) // if the global_root_buffer is full, reserve space in the build's root_buffer instead
                SLM_broadcast = atomic_add_global(&globals->root_buffer_num_produced, n);
            else if( (SLM_grb_broadcast + n) >= max_qnode_grb_entries) // if our entries would overflow the global_root_buffer, reserve space for the overflow in the build's root_buffer
                SLM_broadcast = atomic_add_global(&globals->root_buffer_num_produced, n - (max_qnode_grb_entries - SLM_grb_broadcast));
        }

        barrier( CLK_LOCAL_MEM_FENCE );

        uint produce_pos = SLM_broadcast;

        uint grb_produce_num = n; // grb stands for global_root_buffer
        uint lrb_produce_num = 0; // lrb stands for local root buffer, meaning this build's root_buffer

        if(SLM_grb_broadcast >= max_qnode_grb_entries) // if the global_root_buffer is full, don't write to it
        {
            grb_produce_num = 0;
            lrb_produce_num = n;
        }
        else if( (SLM_grb_broadcast + n) >= max_qnode_grb_entries) // if our entries would overflow the global_root_buffer, write what fits and store the rest in the build's root buffer
        {
            grb_produce_num = max_qnode_grb_entries - SLM_grb_broadcast;
            lrb_produce_num = n - grb_produce_num;
        }

        // save data to global_root_buffer
        for(uint i = get_local_id(0); i < grb_produce_num; i += get_local_size(0))
        {
            const uint2 slm_record = SLM_local_root_buffer[(ring_head + i) % RING_SIZE];

            struct QNodeGlobalRootBufferEntry new_entry;
            new_entry.bvh2_node = slm_record.x;
            new_entry.qnode = slm_record.y;
            new_entry.build_idx = entry.build_idx;

            const uint new_entries_offset = curr_entries_offset ? 0 : max_qnode_grb_entries;
            global_root_buffer->entries[new_entries_offset + SLM_grb_broadcast + i] = new_entry;
        }

        // if anything is left, write it to the build's root buffer
        for (uint i = get_local_id(0); i < lrb_produce_num; i += get_local_size(0))
            root_buffer[produce_pos + i] = SLM_local_root_buffer[(ring_head + i + grb_produce_num) % RING_SIZE];
    }
}
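
// Spill-split example for the code above: with max_qnode_grb_entries = 256, a work group
// that reserved SLM_grb_broadcast = 250 and has n = 10 leftover ring entries writes
// grb_produce_num = 256 - 250 = 6 entries into the global root buffer and the remaining
// lrb_produce_num = 4 entries into this build's own root buffer, starting at the position
// reserved in root_buffer_num_produced.  build_qnodes_try_to_fill_grb_batched later moves
// such per-build leftovers back into the global buffer once space frees up.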



GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__( (reqd_work_group_size( 16, 1, 1 )) )
__attribute__( (intel_reqd_sub_group_size( 16 )) )
kernel void
build_qnodes_try_to_fill_grb_batched(
    global struct SAHBuildGlobals* builds_globals,
    global struct QnodeScheduler* scheduler
    )
{
    const uint build_id = scheduler->batched_build_offset + get_group_id(0);
    global struct SAHBuildGlobals* globals = &builds_globals[build_id];
    global uint2* root_buffer = (global uint2*)globals->p_qnode_root_buffer;

    global struct QNodeGlobalRootBuffer* qnode_root_buffer = (global struct QNodeGlobalRootBuffer*)&scheduler->qnode_global_root_buffer;

    const uint num_produced = globals->root_buffer_num_produced;
    const uint num_consumed = globals->root_buffer_num_consumed;
    const uint entries = num_produced - num_consumed; // entries still pending in this build's root buffer

    if(!entries)
        return;

    uint global_root_buffer_offset;
    if(get_local_id(0) == 0)
        global_root_buffer_offset = atomic_add_global(&scheduler->num_qnode_grb_new_entries, entries);

    global_root_buffer_offset = sub_group_broadcast(global_root_buffer_offset, 0);

    const uint max_qnode_grb_entries = scheduler->num_max_qnode_global_root_buffer_entries;

    if(global_root_buffer_offset >= max_qnode_grb_entries) // if the global_root_buffer is full, return
        return;

    uint global_root_buffer_produce_num = entries;
    if(global_root_buffer_offset + entries >= max_qnode_grb_entries) // if our entries would overflow the global_root_buffer, reduce the number of entries to push
        global_root_buffer_produce_num = max_qnode_grb_entries - global_root_buffer_offset;

    for(uint i = get_local_id(0); i < global_root_buffer_produce_num; i += get_local_size(0))
    {
        const uint2 entry = root_buffer[num_consumed + i];

        struct QNodeGlobalRootBufferEntry new_entry;
        new_entry.bvh2_node = entry.x;
        new_entry.qnode = entry.y;
        new_entry.build_idx = build_id;

        const uint new_entries_offset = qnode_root_buffer->curr_entries_offset ? 0 : max_qnode_grb_entries;
        qnode_root_buffer->entries[new_entries_offset + global_root_buffer_offset + i] = new_entry;
    }

    if(get_local_id(0) == 0)
        globals->root_buffer_num_consumed += global_root_buffer_produce_num;
}
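
// Only the entries that actually fit into the global root buffer are marked consumed
// (root_buffer_num_consumed advances by global_root_buffer_produce_num), so anything that
// did not fit stays in this build's root buffer and is retried on a later pass.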