/*
 * Copyright © 2022 Friedrich Vock
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#version 460

#extension GL_GOOGLE_include_directive : require

#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_scalar_block_layout : require
#extension GL_EXT_buffer_reference : require
#extension GL_EXT_buffer_reference2 : require
#extension GL_KHR_memory_scope_semantics : require

layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;

#include "tu_build_helpers.h"
#include "tu_build_interface.h"

layout(push_constant) uniform CONSTS {
   encode_args args;
};

void set_parent(uint32_t child, uint32_t parent)
{
   uint64_t addr = args.output_bvh - child * 4 - 4;
   DEREF(REF(uint32_t)(addr)) = parent;
}
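
/* Illustrative note (not from the original source): parent links are stored
 * as 32-bit entries growing downwards from args.output_bvh, so the parent of
 * the node with id "child" lives at output_bvh - 4 * (child + 1). E.g. with
 * output_bvh = 0x1000, set_parent(0, p) writes to 0xffc and set_parent(5, p)
 * writes to 0xfe8.
 */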

/* This encoder struct encodes a compressed node incrementally, so that not
 * all of the node's data has to be live at once and register pressure stays
 * low.
 */

struct tu_encoder {
   uint32_t cur_value;
   uint word_offset;
   uint bit_offset;
   REF(tu_compressed_node) node;
};

void encode_init(out tu_encoder encoder, REF(tu_compressed_node) node)
{
   encoder.cur_value = 0;
   encoder.word_offset = 0;
   encoder.bit_offset = 0;
   encoder.node = node;
}

/* Append the low "bits" bits of val to the bitstream, writing a word out to
 * the node whenever one fills up.
 */
void encode(inout tu_encoder encoder, uint32_t val, uint bits)
{
   encoder.cur_value |= val << encoder.bit_offset;
   if (encoder.bit_offset + bits >= 32) {
      DEREF(encoder.node).data[encoder.word_offset] = encoder.cur_value;
      encoder.cur_value = val >> (32 - encoder.bit_offset);
      encoder.word_offset++;
      encoder.bit_offset = encoder.bit_offset + bits - 32;
   } else {
      encoder.bit_offset += bits;
   }
}

/* Advance the bitstream by "bits" bits without encoding new data. */
void encode_skip(inout tu_encoder encoder, uint bits)
{
   if (encoder.bit_offset + bits >= 32) {
      DEREF(encoder.node).data[encoder.word_offset] = encoder.cur_value;
      encoder.word_offset++;
      encoder.bit_offset = encoder.bit_offset + bits - 32;
   } else {
      encoder.bit_offset += bits;
   }
}

/* Write out the final, partially-filled word. */
void encode_finalize(tu_encoder encoder)
{
   DEREF(encoder.node).data[encoder.word_offset] = encoder.cur_value;
}
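
/* Illustrative usage sketch (not from the original source; field widths are
 * made up): packing a 24-bit id followed by a 16-bit flags value.
 *
 *    encode_init(enc, node);
 *    encode(enc, id, 24);    // bit_offset 0 -> 24, nothing written yet
 *    encode(enc, flags, 16); // 24 + 16 >= 32: data[0] = id | (flags << 24),
 *                            // cur_value = flags >> 8, bit_offset = 8
 *    encode_finalize(enc);   // data[1] = the remaining 8 bits of flags
 */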

void
encode_leaf_node(uint32_t type, uint64_t src_node, uint64_t dst_node, uint64_t dst_instances, REF(tu_accel_struct_header) dst_header)
{
   float coords[3][3];
   uint32_t id;
   uint32_t geometry_id;
   uint32_t type_flags = TU_NODE_TYPE_LEAF;

   switch (type) {
   case vk_ir_node_triangle: {
      vk_ir_triangle_node src = DEREF(REF(vk_ir_triangle_node)(src_node));

      coords = src.coords;
      uint32_t geometry_id_and_flags = src.geometry_id_and_flags;
      if ((geometry_id_and_flags & VK_GEOMETRY_OPAQUE) != 0) {
         atomicAnd(DEREF(dst_header).instance_flags, ~TU_INSTANCE_ALL_NONOPAQUE);
      } else {
         type_flags |= TU_NODE_TYPE_NONOPAQUE;
         atomicAnd(DEREF(dst_header).instance_flags, ~TU_INSTANCE_ALL_OPAQUE);
      }
      geometry_id = geometry_id_and_flags & 0xffffff;
      id = src.triangle_id;
      break;
   }
   case vk_ir_node_aabb: {
      vk_ir_aabb_node src = DEREF(REF(vk_ir_aabb_node)(src_node));
      vk_aabb aabb = src.base.aabb;
      coords[0][0] = aabb.min[0];
      coords[0][1] = aabb.min[1];
      coords[0][2] = aabb.min[2];
      coords[1][0] = aabb.max[0];
      coords[1][1] = aabb.max[1];
      coords[1][2] = aabb.max[2];

      type_flags |= TU_NODE_TYPE_AABB;

      if ((src.geometry_id_and_flags & VK_GEOMETRY_OPAQUE) != 0) {
         atomicAnd(DEREF(dst_header).instance_flags, ~TU_INSTANCE_ALL_NONOPAQUE);
      } else {
         type_flags |= TU_NODE_TYPE_NONOPAQUE;
         atomicAnd(DEREF(dst_header).instance_flags, ~TU_INSTANCE_ALL_OPAQUE);
      }
      geometry_id = src.geometry_id_and_flags & 0xffffff;
      id = src.primitive_id;
      break;
   }
   case vk_ir_node_instance: {
      vk_ir_instance_node src = DEREF(REF(vk_ir_instance_node)(src_node));

      id = src.instance_id;
      geometry_id = 0;
      REF(tu_instance_descriptor) dst_instance = REF(tu_instance_descriptor)(dst_instances + SIZEOF(tu_instance_descriptor) * id);

      REF(tu_accel_struct_header) blas_header = REF(tu_accel_struct_header)(src.base_ptr);
      uint64_t bvh_ptr = DEREF(blas_header).bvh_ptr;
      uint32_t bvh_offset = uint32_t(bvh_ptr - src.base_ptr);

      uint32_t sbt_offset_and_flags = src.sbt_offset_and_flags;
      uint32_t custom_instance_and_mask = src.custom_instance_and_mask;
      DEREF(dst_instance).bvh_ptr = bvh_ptr;
      DEREF(dst_instance).custom_instance_index = custom_instance_and_mask & 0xffffffu;
      DEREF(dst_instance).sbt_offset_and_flags = sbt_offset_and_flags;
      DEREF(dst_instance).bvh_offset = bvh_offset;

      mat4 transform = mat4(src.otw_matrix);

      mat4 inv_transform = transpose(inverse(transpose(transform)));
      DEREF(dst_instance).wto_matrix = mat3x4(inv_transform);
      DEREF(dst_instance).otw_matrix = mat3x4(transform);

      vk_aabb aabb = src.base.aabb;
      coords[0][0] = aabb.min[0];
      coords[0][1] = aabb.min[1];
      coords[0][2] = aabb.min[2];
      coords[1][0] = aabb.max[0];
      coords[1][1] = aabb.max[1];
      coords[1][2] = aabb.max[2];

      type_flags |= TU_NODE_TYPE_TLAS;

      uint32_t instance_flags = DEREF(blas_header).instance_flags;

      /* Apply VK_GEOMETRY_INSTANCE_FORCE_OPAQUE_BIT_KHR and
       * VK_GEOMETRY_INSTANCE_FORCE_NO_OPAQUE_BIT_KHR to correct the
       * ALL_OPAQUE/ALL_NONOPAQUE flags.
       */
      if (((sbt_offset_and_flags >> 24) & (VK_GEOMETRY_INSTANCE_FORCE_OPAQUE_BIT_KHR |
                                           VK_GEOMETRY_INSTANCE_FORCE_NO_OPAQUE_BIT_KHR)) != 0) {
         instance_flags &= ~(VK_GEOMETRY_INSTANCE_FORCE_OPAQUE_BIT_KHR |
                             VK_GEOMETRY_INSTANCE_FORCE_NO_OPAQUE_BIT_KHR);
         instance_flags |= (sbt_offset_and_flags >> 24) & (VK_GEOMETRY_INSTANCE_FORCE_OPAQUE_BIT_KHR |
                                                           VK_GEOMETRY_INSTANCE_FORCE_NO_OPAQUE_BIT_KHR);
      }
      uint32_t cull_mask_and_flags = ((custom_instance_and_mask >> 16) & 0xff00) | instance_flags;

      /* Stash the cull mask and flags in the otherwise-unused third row of
       * coords. */
      coords[2][0] = uintBitsToFloat(cull_mask_and_flags);
      break;
   }
   }

   REF(tu_leaf_node) dst = REF(tu_leaf_node)(dst_node);
   DEREF(dst).coords = coords;
   DEREF(dst).id = id;
   DEREF(dst).geometry_id = geometry_id;
   DEREF(dst).type_flags = type_flags;
}

/* Truncate to bfloat16 while rounding down. bfloat16 is used to store the bases.
 */

u16vec3 to_bfloat_round_down(vec3 coord)
{
   u32vec3 icoord = floatBitsToUint(coord);
   return u16vec3(mix(icoord >> 16, (icoord + 0xffff) >> 16, notEqual(icoord & u32vec3(0x80000000), u32vec3(0))));
}
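
/* Example (illustrative): 1.5f is 0x3fc00000; the sign bit is clear, so the
 * low 16 bits are simply truncated away, giving 0x3fc0 (exactly 1.5 in
 * bfloat16). For negative inputs the sign bit is set, so 0xffff is added
 * first: any nonzero low bits round the magnitude up by one ULP, which
 * rounds the value itself down, towards -infinity.
 */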

/* Approximate subtraction while rounding up. Return a result greater than or
 * equal to the infinitely-precise result. This just uses the native
 * subtraction and then shifts one ULP towards infinity. Because the result is
 * further rounded, it should usually be good enough while being faster than
 * emulated floating-point math.
 *
 * We assume here that the result is always nonnegative, because it's only used
 * to subtract away the base.
 */

vec3 subtract_round_up_approx(vec3 a, vec3 b)
{
   vec3 f = a - b;
   u32vec3 i = floatBitsToUint(f);

   i++;

   /* Handle infinity/zero special cases */
   i = mix(i, floatBitsToUint(f), isinf(f));
   i = mix(i, floatBitsToUint(f), equal(f, vec3(0)));

   return uintBitsToFloat(i);
}

vec3 subtract_round_down_approx(vec3 a, vec3 b)
{
   vec3 f = a - b;
   u32vec3 i = floatBitsToUint(f);

   i--;

   /* Handle infinity/zero special cases */
   i = mix(i, floatBitsToUint(f), isinf(f));
   i = mix(i, floatBitsToUint(f), equal(f, vec3(0)));

   return uintBitsToFloat(i);
}
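
/* subtract_round_down_approx is the mirror image: the ULP shift moves towards
 * zero, so (for a nonnegative result) the output is less than or equal to the
 * infinitely-precise difference.
 *
 * Example (illustrative): for a = 1.0f and b = 2^-25, the exact difference
 * 1 - 2^-25 is not representable and the native subtraction rounds it up to
 * 1.0f. subtract_round_down_approx nudges that down one ULP to 1 - 2^-24,
 * which is <= the exact result, while subtract_round_up_approx would return
 * 1 + 2^-23, which is >= it: both stay within one ULP of the true value.
 */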

u32vec3 extract_mantissa(vec3 f)
{
   return mix((floatBitsToUint(f) & 0x7fffff) | 0x800000, u32vec3(0), equal(f, vec3(0)));
}
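
/* Illustrative: for 1.5f (0x3fc00000) this returns 0xc00000, i.e. the 23
 * stored mantissa bits with the implicit leading 1 made explicit in bit 23.
 * Zero is special-cased because it has no implicit leading 1.
 */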

void
encode_internal_node(uint32_t children[8], uint32_t children_offset, uint child_count,
                     vec3 min_offset, vec3 max_offset, uint32_t bvh_offset)
{
   REF(tu_internal_node) dst_node = REF(tu_internal_node)(OFFSET(args.output_bvh, SIZEOF(tu_internal_node) * bvh_offset));

   DEREF(dst_node).id = children_offset;

   u16vec3 base_bfloat = to_bfloat_round_down(min_offset);
   vec3 base_float = uintBitsToFloat(u32vec3(base_bfloat) << 16);
   DEREF(dst_node).bases[0] = base_bfloat.x;
   DEREF(dst_node).bases[1] = base_bfloat.y;
   DEREF(dst_node).bases[2] = base_bfloat.z;

   vec3 children_max = subtract_round_up_approx(max_offset, base_float);

   /* The largest child offset will be encoded in 8 bits, including the
    * explicit leading 1. We need to downcast to this precision while rounding
    * up to catch cases where the exponent is increased by rounding up, then
    * extract the exponent. Because children_max is always nonnegative, we can
    * do the downcast with "(floatBitsToUint(children_max) + 0xffff) >> 16",
    * and then we further shift to get the rounded exponent.
    */
   u16vec3 exponents = u16vec3((floatBitsToUint(children_max) + 0xffff) >> 23);
   u8vec3 exponents_u8 = u8vec3(exponents);
   DEREF(dst_node).exponents[0] = exponents_u8.x;
   DEREF(dst_node).exponents[1] = exponents_u8.y;
   DEREF(dst_node).exponents[2] = exponents_u8.z;

   for (uint32_t i = 0; i < child_count; i++) {
      uint32_t offset = ir_id_to_offset(children[i]);

      vk_aabb child_aabb =
         DEREF(REF(vk_ir_node)OFFSET(args.intermediate_bvh, offset)).aabb;

      /* Note: because the base was rounded down from the overall minimum,
       * these subtractions should never produce a negative value.
       */
      vec3 child_min = subtract_round_down_approx(child_aabb.min, base_float);
      vec3 child_max = subtract_round_up_approx(child_aabb.max, base_float);

      u16vec3 child_min_exponents = u16vec3(floatBitsToUint(child_min) >> 23);
      u16vec3 child_max_exponents = u16vec3(floatBitsToUint(child_max) >> 23);

      u16vec3 child_min_shift = u16vec3(16) + exponents - child_min_exponents;
      /* Divide the mantissa by 2**child_min_shift, rounding down */
      u8vec3 child_min_mantissas =
         mix(u8vec3(extract_mantissa(child_min) >> child_min_shift), u8vec3(0),
             greaterThanEqual(child_min_shift, u16vec3(32)));
      u16vec3 child_max_shift = u16vec3(16) + exponents - child_max_exponents;
      /* Divide the mantissa by 2**child_max_shift, rounding up */
      u8vec3 child_max_mantissas =
         mix(u8vec3((extract_mantissa(child_max) + ((u32vec3(1u) << u32vec3(child_max_shift)) - 1)) >> child_max_shift),
             u8vec3(notEqual(extract_mantissa(child_max), u32vec3(0))),
             greaterThanEqual(child_max_shift, u16vec3(32)));

      DEREF(dst_node).mantissas[i][0][0] = child_min_mantissas.x;
      DEREF(dst_node).mantissas[i][0][1] = child_min_mantissas.y;
      DEREF(dst_node).mantissas[i][0][2] = child_min_mantissas.z;
      DEREF(dst_node).mantissas[i][1][0] = child_max_mantissas.x;
      DEREF(dst_node).mantissas[i][1][1] = child_max_mantissas.y;
      DEREF(dst_node).mantissas[i][1][2] = child_max_mantissas.z;
   }

   /* Fill the unused child slots with an inverted (min > max) box so they
    * can never be hit.
    */
   for (uint32_t i = child_count; i < 8; i++) {
      DEREF(dst_node).mantissas[i][0][0] = uint8_t(0xff);
      DEREF(dst_node).mantissas[i][0][1] = uint8_t(0xff);
      DEREF(dst_node).mantissas[i][0][2] = uint8_t(0xff);
      DEREF(dst_node).mantissas[i][1][0] = uint8_t(0);
      DEREF(dst_node).mantissas[i][1][1] = uint8_t(0);
      DEREF(dst_node).mantissas[i][1][2] = uint8_t(0);
   }

   DEREF(dst_node).child_count = uint8_t(child_count);
   DEREF(dst_node).type_flags = uint16_t(args.geometry_type == VK_GEOMETRY_TYPE_INSTANCES_KHR ? (TU_NODE_TYPE_TLAS >> 16) : 0);
}
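
/* Worked example (illustrative, not from the original source): if the common
 * exponent byte is E and a child's rounded max offset has exponent E - 2,
 * then child_max_shift = 16 + E - (E - 2) = 18 and the 24-bit mantissa
 * (explicit leading 1 included) is shifted right by 18 with rounding up,
 * leaving at most 6 bits, which fits the 8-bit mantissa field. A shift of 32
 * or more means the offset is negligible at this precision: the min mantissa
 * becomes 0 and the max mantissa becomes 1 if the offset was nonzero at all.
 */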

void
main()
{
   /* Reverse the order so we start at the root */
   uint32_t global_id = DEREF(args.header).ir_internal_node_count - 1 - gl_GlobalInvocationID.x;

   uint32_t intermediate_leaf_node_size;
   switch (args.geometry_type) {
   case VK_GEOMETRY_TYPE_TRIANGLES_KHR:
      intermediate_leaf_node_size = SIZEOF(vk_ir_triangle_node);
      break;
   case VK_GEOMETRY_TYPE_AABBS_KHR:
      intermediate_leaf_node_size = SIZEOF(vk_ir_aabb_node);
      break;
   default: /* instances */
      intermediate_leaf_node_size = SIZEOF(vk_ir_instance_node);
      break;
   }

   uint32_t intermediate_leaf_nodes_size = args.leaf_node_count * intermediate_leaf_node_size;

   REF(vk_ir_box_node) intermediate_internal_nodes =
      REF(vk_ir_box_node)OFFSET(args.intermediate_bvh, intermediate_leaf_nodes_size);
   REF(vk_ir_box_node) src_node = INDEX(vk_ir_box_node, intermediate_internal_nodes, global_id);
   vk_ir_box_node src = DEREF(src_node);

   uint64_t dst_instances = args.output_bvh - args.output_bvh_offset + SIZEOF(tu_accel_struct_header);

   bool is_root_node = global_id == DEREF(args.header).ir_internal_node_count - 1;

   REF(tu_accel_struct_header) header = REF(tu_accel_struct_header)(args.output_bvh - args.output_bvh_offset);

   if (is_root_node) {
      DEREF(header).instance_flags =
         (args.geometry_type == VK_GEOMETRY_TYPE_AABBS_KHR ? TU_INSTANCE_ALL_AABB : 0) |
         /* These will be removed when processing leaf nodes */
         TU_INSTANCE_ALL_NONOPAQUE | TU_INSTANCE_ALL_OPAQUE;
      DEREF(args.header).dst_node_offset = 1;
   }

   for (;;) {
      /* Make changes to the current node's BVH offset value visible. */
      memoryBarrier(gl_ScopeDevice, gl_StorageSemanticsBuffer,
                    gl_SemanticsAcquireRelease | gl_SemanticsMakeAvailable | gl_SemanticsMakeVisible);

      uint32_t bvh_offset = is_root_node ? 0 : DEREF(src_node).bvh_offset;
      /* Spin until the parent invocation has assigned this node's offset. */
      if (bvh_offset == VK_UNKNOWN_BVH_OFFSET)
         continue;

      /* This node was collapsed into its parent; nothing to do. */
      if (bvh_offset == VK_NULL_BVH_OFFSET)
         break;

      uint32_t found_child_count = 0;
      uint32_t children[8] = {VK_BVH_INVALID_NODE, VK_BVH_INVALID_NODE,
                              VK_BVH_INVALID_NODE, VK_BVH_INVALID_NODE,
                              VK_BVH_INVALID_NODE, VK_BVH_INVALID_NODE,
                              VK_BVH_INVALID_NODE, VK_BVH_INVALID_NODE};

      for (uint32_t i = 0; i < 2; ++i)
         if (src.children[i] != VK_BVH_INVALID_NODE)
            children[found_child_count++] = src.children[i];

      while (found_child_count < 8) {
         int32_t collapsed_child_index = -1;
         float largest_surface_area = -INFINITY;

         for (int32_t i = 0; i < found_child_count; ++i) {
            if (ir_id_to_type(children[i]) != vk_ir_node_internal)
               continue;

            vk_aabb bounds =
               DEREF(REF(vk_ir_node)OFFSET(args.intermediate_bvh,
                                           ir_id_to_offset(children[i]))).aabb;

            float surface_area = aabb_surface_area(bounds);
            if (surface_area > largest_surface_area) {
               largest_surface_area = surface_area;
               collapsed_child_index = i;
            }
         }

         if (collapsed_child_index != -1) {
            REF(vk_ir_box_node) child_node =
               REF(vk_ir_box_node)OFFSET(args.intermediate_bvh,
                                         ir_id_to_offset(children[collapsed_child_index]));
            uint32_t grandchildren[2] = DEREF(child_node).children;
            uint32_t valid_grandchild_count = 0;

            if (grandchildren[1] != VK_BVH_INVALID_NODE)
               ++valid_grandchild_count;

            if (grandchildren[0] != VK_BVH_INVALID_NODE)
               ++valid_grandchild_count;
            else
               grandchildren[0] = grandchildren[1];

            if (valid_grandchild_count > 1)
               children[found_child_count++] = grandchildren[1];

            if (valid_grandchild_count > 0)
               children[collapsed_child_index] = grandchildren[0];
            else {
               found_child_count--;
               children[collapsed_child_index] = children[found_child_count];
            }

            DEREF(child_node).bvh_offset = VK_NULL_BVH_OFFSET;
         } else
            break;
      }
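
      /* Illustrative walkthrough (not from the original source): the loop
       * above widens the 2-wide IR node into an up-to-8-wide node by
       * repeatedly replacing the internal child with the largest surface
       * area by its two grandchildren: {A, B} with A internal becomes
       * {A0, B, A1}, and so on until eight children are reached or only
       * leaves remain. Each collapsed child is marked VK_NULL_BVH_OFFSET so
       * its own invocation exits without encoding it.
       */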

      /* If there is only one child, collapse the current node by setting the
       * child's offset to this node's offset. Otherwise, use an atomic to
       * allocate contiguous space for all of the children.
       */
      uint32_t children_offset = bvh_offset;
      if (found_child_count > 1) {
         children_offset = atomicAdd(DEREF(args.header).dst_node_offset, found_child_count);
      }

      vec3 min_offset = vec3(INFINITY);
      vec3 max_offset = vec3(-INFINITY);
      for (uint32_t i = 0; i < found_child_count; ++i) {
         uint32_t type = ir_id_to_type(children[i]);
         uint32_t offset = ir_id_to_offset(children[i]);
         uint32_t dst_offset = children_offset + i;

         if (type == vk_ir_node_internal) {
            REF(vk_ir_box_node) child_node = REF(vk_ir_box_node)OFFSET(args.intermediate_bvh, offset);
            DEREF(child_node).bvh_offset = dst_offset;
         } else {
            encode_leaf_node(type, args.intermediate_bvh + offset,
                             args.output_bvh + SIZEOF(tu_internal_node) * dst_offset, dst_instances,
                             header);
         }

         vk_aabb child_aabb =
            DEREF(REF(vk_ir_node)OFFSET(args.intermediate_bvh, offset)).aabb;

         min_offset = min(min_offset, child_aabb.min);
         max_offset = max(max_offset, child_aabb.max);

         if (found_child_count > 1) {
            set_parent(dst_offset, bvh_offset);
         }
      }

      /* Make the changes to the children's BVH offset values visible to the
       * other invocations. */
      memoryBarrier(gl_ScopeDevice, gl_StorageSemanticsBuffer,
                    gl_SemanticsAcquireRelease | gl_SemanticsMakeAvailable | gl_SemanticsMakeVisible);

      /* A single child was collapsed into this node's slot above, so only
       * encode an internal node if there are zero children or more than one.
       */
      if (found_child_count > 1 || found_child_count == 0)
         encode_internal_node(children, children_offset, found_child_count, min_offset, max_offset, bvh_offset);

      break;
   }

   if (is_root_node) {
      DEREF(header).aabb = src.base.aabb;
      DEREF(header).bvh_ptr = args.output_bvh;

      set_parent(0, VK_BVH_INVALID_NODE);
   }
}