/*
 * Copyright © 2022 Bas Nieuwenhuizen
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#version 460

#extension GL_GOOGLE_include_directive : require

#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_scalar_block_layout : require
#extension GL_EXT_buffer_reference : require
#extension GL_EXT_buffer_reference2 : require
#extension GL_KHR_memory_scope_semantics : require
#extension GL_KHR_shader_subgroup_vote : require
#extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_KHR_shader_subgroup_ballot : require

layout(local_size_x = 1024, local_size_y = 1, local_size_z = 1) in;

#define USE_GLOBAL_SYNC
#include "vk_build_interface.h"
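
/*
 * PLOC (Parallel Locally-Ordered Clustering) pass that builds the internal
 * (box) nodes of the intermediate BVH: each round searches a bounded
 * neighbourhood for the merge partner with the lowest combined surface area
 * (SAH), merges mutually-nearest pairs into new internal nodes and compacts
 * the surviving node ids, until only the root remains.
 */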

TYPE(ploc_prefix_scan_partition, 4);

layout(push_constant) uniform CONSTS
{
   ploc_args args;
};

shared uint32_t exclusive_prefix_sum;
shared uint32_t aggregate_sums[PLOC_SUBGROUPS_PER_WORKGROUP];
shared uint32_t aggregate_sums2[PLOC_SUBGROUPS_PER_WORKGROUP];

/*
 * Global prefix scan over all workgroups to find the index at which to write the compacted node id.
 * See https://research.nvidia.com/sites/default/files/publications/nvr-2016-002.pdf
 * One partition = one workgroup in this case.
 */
uint32_t
prefix_scan(uvec4 ballot, REF(ploc_prefix_scan_partition) partitions, uint32_t task_index)
{
   if (gl_LocalInvocationIndex == 0) {
      exclusive_prefix_sum = 0;
      if (task_index >= gl_WorkGroupSize.x) {
         REF(ploc_prefix_scan_partition) current_partition =
            REF(ploc_prefix_scan_partition)(INDEX(ploc_prefix_scan_partition, partitions, task_index / gl_WorkGroupSize.x));

         REF(ploc_prefix_scan_partition) previous_partition = current_partition - 1;

         while (true) {
            /* Check whether the previous workgroup has already set its inclusive sum */
            if (atomicLoad(DEREF(previous_partition).inclusive_sum, gl_ScopeDevice,
                           gl_StorageSemanticsBuffer,
                           gl_SemanticsAcquire | gl_SemanticsMakeVisible) != 0xFFFFFFFF) {
               atomicAdd(exclusive_prefix_sum, DEREF(previous_partition).inclusive_sum);
               break;
            } else {
               atomicAdd(exclusive_prefix_sum, DEREF(previous_partition).aggregate);
               previous_partition -= 1;
            }
         }
         /* Set the inclusive sum for subsequent workgroups */
         atomicStore(DEREF(current_partition).inclusive_sum,
                     DEREF(current_partition).aggregate + exclusive_prefix_sum, gl_ScopeDevice,
                     gl_StorageSemanticsBuffer, gl_SemanticsRelease | gl_SemanticsMakeAvailable);
      }
   }

   if (subgroupElect())
      aggregate_sums[gl_SubgroupID] = subgroupBallotBitCount(ballot);
   barrier();

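   /* Exclusive scan over the per-subgroup aggregates. If all of them fit into
    * one subgroup, a single subgroupExclusiveAdd suffices; otherwise the
    * shared-memory scan below is used.
    */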
   if (PLOC_SUBGROUPS_PER_WORKGROUP <= SUBGROUP_SIZE) {
      if (gl_LocalInvocationID.x < PLOC_SUBGROUPS_PER_WORKGROUP) {
         aggregate_sums[gl_LocalInvocationID.x] =
            exclusive_prefix_sum + subgroupExclusiveAdd(aggregate_sums[gl_LocalInvocationID.x]);
      }
   } else {
      /* If the length of aggregate_sums[] is larger than SUBGROUP_SIZE,
       * the prefix scan can't be done with a single subgroupExclusiveAdd.
       */
      if (gl_LocalInvocationID.x < PLOC_SUBGROUPS_PER_WORKGROUP)
         aggregate_sums2[gl_LocalInvocationID.x] = aggregate_sums[gl_LocalInvocationID.x];
      barrier();

      /* Hillis-Steele inclusive scan on aggregate_sums2 */
      for (uint32_t stride = 1; stride < PLOC_SUBGROUPS_PER_WORKGROUP; stride *= 2) {
         uint32_t value = 0;
         if (gl_LocalInvocationID.x >= stride && gl_LocalInvocationID.x < PLOC_SUBGROUPS_PER_WORKGROUP)
            value = aggregate_sums2[gl_LocalInvocationID.x - stride];
         barrier();
         if (gl_LocalInvocationID.x < PLOC_SUBGROUPS_PER_WORKGROUP)
            aggregate_sums2[gl_LocalInvocationID.x] += value;
         barrier();
      }

      /* Convert to an exclusive scan and add the prefix sum from previous workgroups */
      if (gl_LocalInvocationID.x < PLOC_SUBGROUPS_PER_WORKGROUP) {
         if (gl_LocalInvocationID.x == 0)
            aggregate_sums[gl_LocalInvocationID.x] = exclusive_prefix_sum;
         else
            aggregate_sums[gl_LocalInvocationID.x] = exclusive_prefix_sum + aggregate_sums2[gl_LocalInvocationID.x - 1];
      }
   }
   barrier();

   return aggregate_sums[gl_SubgroupID] + subgroupBallotExclusiveBitCount(ballot);
}

/* Relative cost of increasing the BVH depth. Deep BVHs will require more backtracking. */
#define BVH_LEVEL_COST 0.2

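/* Allocate a new internal (box) node referencing the two given children, with
 * bounds equal to the union of the child AABBs. Returns the packed ir node id
 * of the new node.
 */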
uint32_t
push_node(uint32_t children[2], vk_aabb bounds[2])
{
   uint32_t internal_node_index = atomicAdd(DEREF(args.header).ir_internal_node_count, 1);
   uint32_t dst_offset = args.internal_node_offset + internal_node_index * SIZEOF(vk_ir_box_node);
   uint32_t dst_id = pack_ir_node_id(dst_offset, vk_ir_node_internal);
   REF(vk_ir_box_node) dst_node = REF(vk_ir_box_node)(OFFSET(args.bvh, dst_offset));

   vk_aabb total_bounds;
   total_bounds.min = vec3(INFINITY);
   total_bounds.max = vec3(-INFINITY);

   for (uint i = 0; i < 2; ++i) {
      VOID_REF node = OFFSET(args.bvh, ir_id_to_offset(children[i]));
      REF(vk_ir_node) child = REF(vk_ir_node)(node);

      total_bounds.min = min(total_bounds.min, bounds[i].min);
      total_bounds.max = max(total_bounds.max, bounds[i].max);

      DEREF(dst_node).children[i] = children[i];
   }

   DEREF(dst_node).base.aabb = total_bounds;
   DEREF(dst_node).bvh_offset = VK_UNKNOWN_BVH_OFFSET;
   return dst_id;
}

#define PLOC_NEIGHBOURHOOD 16
#define PLOC_OFFSET_MASK   ((1 << 5) - 1)

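/* Nearest-neighbour candidates are packed into a single 32-bit word: the low
 * 5 bits hold the relative offset to the candidate (the range [-16, 16]
 * without 0 maps to [0, 31]), while the remaining bits hold the upper bits of
 * the SAH cost. Because the SAH is non-negative, its IEEE-754 bit pattern is
 * ordered under unsigned comparison, so atomicMin on these words selects the
 * lowest-cost candidate. For example, with PLOC_NEIGHBOURHOOD = 16 an offset
 * of -1 encodes to 15 and an offset of +1 encodes to 16.
 */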
uint32_t
encode_neighbour_offset(float sah, uint32_t i, uint32_t j)
{
   int32_t offset = int32_t(j) - int32_t(i);
   uint32_t encoded_offset = offset + PLOC_NEIGHBOURHOOD - (offset > 0 ? 1 : 0);
   return (floatBitsToUint(sah) & (~PLOC_OFFSET_MASK)) | encoded_offset;
}

int32_t
decode_neighbour_offset(uint32_t encoded_offset)
{
   int32_t offset = int32_t(encoded_offset & PLOC_OFFSET_MASK) - PLOC_NEIGHBOURHOOD;
   return offset + (offset >= 0 ? 1 : 0);
}

#define NUM_PLOC_LDS_ITEMS (PLOC_WORKGROUP_SIZE + 4 * PLOC_NEIGHBOURHOOD)

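/* Each workgroup caches the bounds of its own tasks plus up to
 * 2 * PLOC_NEIGHBOURHOOD extra nodes on either side, which is everything the
 * nearest-neighbour search below may touch.
 */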
shared vk_aabb shared_bounds[NUM_PLOC_LDS_ITEMS];
shared uint32_t nearest_neighbour_indices[NUM_PLOC_LDS_ITEMS];

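/* In the first iteration the input is the key/id pair array (presumably the
 * output of the preceding sort pass); later iterations read the plain node ids
 * written by the compaction below.
 */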
uint32_t
load_id(VOID_REF ids, uint32_t iter, uint32_t index)
{
   if (iter == 0)
      return DEREF(REF(key_id_pair)(INDEX(key_id_pair, ids, index))).id;
   else
      return DEREF(REF(uint32_t)(INDEX(uint32_t, ids, index)));
}

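/* Cache the AABBs of every node this workgroup may inspect during the
 * nearest-neighbour search into shared memory.
 */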
void
load_bounds(VOID_REF ids, uint32_t iter, uint32_t task_index, uint32_t lds_base,
            uint32_t neighbourhood_overlap, uint32_t search_bound)
{
   for (uint32_t i = task_index - 2 * neighbourhood_overlap; i < search_bound;
        i += gl_WorkGroupSize.x) {
      uint32_t id = load_id(ids, iter, i);
      if (id == VK_BVH_INVALID_NODE)
         continue;

      VOID_REF addr = OFFSET(args.bvh, ir_id_to_offset(id));
      REF(vk_ir_node) node = REF(vk_ir_node)(addr);

      shared_bounds[i - lds_base] = DEREF(node).aabb;
   }
}

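/* Surface area of the union of the bounds of nodes i and j: the greedy cost
 * metric used to pick merge partners.
 */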
float
combined_node_cost(uint32_t lds_base, uint32_t i, uint32_t j)
{
   vk_aabb combined_bounds;
   combined_bounds.min = min(shared_bounds[i - lds_base].min, shared_bounds[j - lds_base].min);
   combined_bounds.max = max(shared_bounds[i - lds_base].max, shared_bounds[j - lds_base].max);
   return aabb_surface_area(combined_bounds);
}

shared uint32_t shared_aggregate_sum;

void
main(void)
{
   VOID_REF src_ids = args.ids_0;
   VOID_REF dst_ids = args.ids_1;

   /* We try to use LBVH for BVHs where we know there will be fewer than 5 leaves,
    * but sometimes all leaves might be inactive. */
   if (DEREF(args.header).active_leaf_count <= 2) {
      if (gl_GlobalInvocationID.x == 0) {
         uint32_t internal_node_index = atomicAdd(DEREF(args.header).ir_internal_node_count, 1);
         uint32_t dst_offset = args.internal_node_offset + internal_node_index * SIZEOF(vk_ir_box_node);
         REF(vk_ir_box_node) dst_node = REF(vk_ir_box_node)(OFFSET(args.bvh, dst_offset));

         vk_aabb total_bounds;
         total_bounds.min = vec3(INFINITY);
         total_bounds.max = vec3(-INFINITY);

         uint32_t i = 0;
         for (; i < DEREF(args.header).active_leaf_count; i++) {
            uint32_t child_id = DEREF(INDEX(key_id_pair, src_ids, i)).id;

            if (child_id != VK_BVH_INVALID_NODE) {
               VOID_REF node = OFFSET(args.bvh, ir_id_to_offset(child_id));
               REF(vk_ir_node) child = REF(vk_ir_node)(node);
               vk_aabb bounds = DEREF(child).aabb;

               total_bounds.min = min(total_bounds.min, bounds.min);
               total_bounds.max = max(total_bounds.max, bounds.max);
            }

            DEREF(dst_node).children[i] = child_id;
         }
         for (; i < 2; i++)
            DEREF(dst_node).children[i] = VK_BVH_INVALID_NODE;

         DEREF(dst_node).base.aabb = total_bounds;
         DEREF(dst_node).bvh_offset = VK_UNKNOWN_BVH_OFFSET;
      }
      return;
   }

   /* Only initialize sync_data once per workgroup. For intra-workgroup synchronization,
    * fetch_task contains a workgroup-scoped control+memory barrier.
    */
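   /* sync_data is presumably pre-initialized to 0xFFFFFFFF, so the
    * compare-swaps below only take effect for the first workgroup to arrive
    * here.
    */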
   if (gl_LocalInvocationIndex == 0) {
      atomicCompSwap(DEREF(args.header).sync_data.task_counts[0], 0xFFFFFFFF,
                     DEREF(args.header).active_leaf_count);
      atomicCompSwap(DEREF(args.header).sync_data.current_phase_end_counter, 0xFFFFFFFF,
                     DIV_ROUND_UP(DEREF(args.header).active_leaf_count, gl_WorkGroupSize.x));
   }

   REF(ploc_prefix_scan_partition)
   partitions = REF(ploc_prefix_scan_partition)(args.prefix_scan_partitions);

   uint32_t task_index = fetch_task(args.header, false);

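   /* Each loop iteration is one PLOC round, split into two PHASE blocks: find
    * nearest neighbours and merge mutual pairs, then compact the surviving
    * node ids using the global prefix scan.
    */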
   for (uint iter = 0;; ++iter) {
      uint32_t current_task_count = task_count(args.header);
      if (task_index == TASK_INDEX_INVALID)
         break;

      /* Find preferred partners and merge them */
      PHASE (args.header) {
         uint32_t base_index = task_index - gl_LocalInvocationID.x;
         uint32_t neighbourhood_overlap = min(PLOC_NEIGHBOURHOOD, base_index);
         uint32_t double_neighbourhood_overlap = min(2 * PLOC_NEIGHBOURHOOD, base_index);
         /* Upper bound up to which valid nearest-node indices are written. */
         uint32_t write_bound =
            min(current_task_count, base_index + gl_WorkGroupSize.x + PLOC_NEIGHBOURHOOD);
         /* Upper bound up to which valid nearest-node indices are searched. */
         uint32_t search_bound =
            min(current_task_count, base_index + gl_WorkGroupSize.x + 2 * PLOC_NEIGHBOURHOOD);
         uint32_t lds_base = base_index - double_neighbourhood_overlap;

         load_bounds(src_ids, iter, task_index, lds_base, neighbourhood_overlap, search_bound);

         for (uint32_t i = gl_LocalInvocationID.x; i < NUM_PLOC_LDS_ITEMS; i += gl_WorkGroupSize.x)
            nearest_neighbour_indices[i] = 0xFFFFFFFF;
         barrier();

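         /* For every node in the write window, scan up to PLOC_NEIGHBOURHOOD
          * nodes to its right; atomicMin on the packed SAH+offset word
          * propagates candidates in both directions, so each node ends up with
          * its lowest-cost neighbour.
          */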
         for (uint32_t i = task_index - double_neighbourhood_overlap; i < write_bound;
              i += gl_WorkGroupSize.x) {
            uint32_t right_bound = min(search_bound - 1 - i, PLOC_NEIGHBOURHOOD);

            uint32_t fallback_pair = i == 0 ? (i + 1) : (i - 1);
            uint32_t min_offset = encode_neighbour_offset(INFINITY, i, fallback_pair);

            for (uint32_t j = max(i + 1, base_index - neighbourhood_overlap); j <= i + right_bound;
                 ++j) {

               float sah = combined_node_cost(lds_base, i, j);
               uint32_t i_encoded_offset = encode_neighbour_offset(sah, i, j);
               uint32_t j_encoded_offset = encode_neighbour_offset(sah, j, i);
               min_offset = min(min_offset, i_encoded_offset);
               atomicMin(nearest_neighbour_indices[j - lds_base], j_encoded_offset);
            }
            if (i >= base_index - neighbourhood_overlap)
               atomicMin(nearest_neighbour_indices[i - lds_base], min_offset);
         }

         if (gl_LocalInvocationID.x == 0)
            shared_aggregate_sum = 0;
         barrier();

         for (uint32_t i = task_index - neighbourhood_overlap; i < write_bound;
              i += gl_WorkGroupSize.x) {
            uint32_t left_bound = min(i, PLOC_NEIGHBOURHOOD);
            uint32_t right_bound = min(search_bound - 1 - i, PLOC_NEIGHBOURHOOD);
            /*
             * Workaround for a worst-case scenario in PLOC: If the combined area of
             * all nodes (in the neighbourhood) is the same, then the chosen nearest
             * neighbour is the first neighbour. However, this means that no nodes
             * except the first two will find each other as nearest neighbours. Therefore,
             * only one node is contained in each BVH level. By first testing whether the
             * immediate neighbour on one side is the nearest, all immediate neighbours
             * will be merged on every step.
             */
            uint32_t preferred_pair;
            if ((i & 1) != 0)
               preferred_pair = i - min(left_bound, 1);
            else
               preferred_pair = i + min(right_bound, 1);

            if (preferred_pair != i) {
               uint32_t encoded_min_sah =
                  nearest_neighbour_indices[i - lds_base] & (~PLOC_OFFSET_MASK);
               float sah = combined_node_cost(lds_base, i, preferred_pair);
               uint32_t encoded_sah = floatBitsToUint(sah) & (~PLOC_OFFSET_MASK);
               uint32_t encoded_offset = encode_neighbour_offset(sah, i, preferred_pair);
               if (encoded_sah <= encoded_min_sah) {
                  nearest_neighbour_indices[i - lds_base] = encoded_offset;
               }
            }
         }
         barrier();

         bool has_valid_node = true;

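         /* Merge mutually-nearest pairs: the lower-indexed node of each pair
          * allocates the internal node and invalidates its partner's slot;
          * all other nodes are carried over unchanged.
          */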
         if (task_index < current_task_count) {
            uint32_t base_index = task_index - gl_LocalInvocationID.x;

            uint32_t neighbour_index =
               task_index +
               decode_neighbour_offset(nearest_neighbour_indices[task_index - lds_base]);
            uint32_t other_neighbour_index =
               neighbour_index +
               decode_neighbour_offset(nearest_neighbour_indices[neighbour_index - lds_base]);
            uint32_t id = load_id(src_ids, iter, task_index);
            if (other_neighbour_index == task_index) {
               if (task_index < neighbour_index) {
                  uint32_t neighbour_id = load_id(src_ids, iter, neighbour_index);
                  uint32_t children[2] = {id, neighbour_id};
                  vk_aabb bounds[2] = {shared_bounds[task_index - lds_base], shared_bounds[neighbour_index - lds_base]};

                  DEREF(REF(uint32_t)(INDEX(uint32_t, dst_ids, task_index))) = push_node(children, bounds);
                  DEREF(REF(uint32_t)(INDEX(uint32_t, dst_ids, neighbour_index))) =
                     VK_BVH_INVALID_NODE;
               } else {
                  /* We still store in the other case so we don't destroy the node id needed to
                   * create the internal node */
                  has_valid_node = false;
               }
            } else {
               DEREF(REF(uint32_t)(INDEX(uint32_t, dst_ids, task_index))) = id;
            }

            /* Compact - prepare prefix scan */
            uvec4 ballot = subgroupBallot(has_valid_node);

            uint32_t aggregate_sum = subgroupBallotBitCount(ballot);
            if (subgroupElect())
               atomicAdd(shared_aggregate_sum, aggregate_sum);
         }

         barrier();
         /*
          * The paper proposes initializing all partitions to an invalid state
          * and only computing aggregates afterwards. We skip that step and
          * initialize the partitions to a valid state. This also simplifies
          * the look-back, as there will never be any blocking due to invalid
          * partitions.
          */
         if (gl_LocalInvocationIndex == 0) {
            REF(ploc_prefix_scan_partition)
            current_partition = REF(ploc_prefix_scan_partition)(
               INDEX(ploc_prefix_scan_partition, partitions, task_index / gl_WorkGroupSize.x));
            DEREF(current_partition).aggregate = shared_aggregate_sum;
            if (task_index < gl_WorkGroupSize.x) {
               DEREF(current_partition).inclusive_sum = shared_aggregate_sum;
            } else {
               DEREF(current_partition).inclusive_sum = 0xFFFFFFFF;
            }
         }

         if (task_index == 0)
            set_next_task_count(args.header, task_count(args.header));
      }

      /* Compact - prefix scan and update */
      PHASE (args.header) {
         uint32_t current_task_count = task_count(args.header);

         uint32_t id = task_index < current_task_count
                          ? DEREF(REF(uint32_t)(INDEX(uint32_t, dst_ids, task_index)))
                          : VK_BVH_INVALID_NODE;
         uvec4 ballot = subgroupBallot(id != VK_BVH_INVALID_NODE);

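         /* prefix_scan contains workgroup-wide barriers, so every invocation
          * must reach it before out-of-range tasks skip the rest of this
          * phase.
          */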
         uint32_t new_offset = prefix_scan(ballot, partitions, task_index);
         if (task_index >= current_task_count)
            continue;

         if (id != VK_BVH_INVALID_NODE) {
            DEREF(REF(uint32_t)(INDEX(uint32_t, src_ids, new_offset))) = id;
            ++new_offset;
         }

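         /* The last task publishes the compacted count for the next round.
          * Once only a single id remains, the root has been written and the
          * exit flag ends the build loop.
          */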
         if (task_index == current_task_count - 1) {
            set_next_task_count(args.header, new_offset);
            if (new_offset == 1)
               DEREF(args.header).sync_data.next_phase_exit_flag = 1;
         }
      }
   }
}