1/* 2 * Copyright © 2022 Friedrich Vock 3 * 4 * Permission is hereby granted, free of charge, to any person obtaining a 5 * copy of this software and associated documentation files (the "Software"), 6 * to deal in the Software without restriction, including without limitation 7 * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 * and/or sell copies of the Software, and to permit persons to whom the 9 * Software is furnished to do so, subject to the following conditions: 10 * 11 * The above copyright notice and this permission notice (including the next 12 * paragraph) shall be included in all copies or substantial portions of the 13 * Software. 14 * 15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 21 * IN THE SOFTWARE. 22 */ 23 24#version 460 25 26#extension GL_GOOGLE_include_directive : require 27 28#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require 29#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require 30#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require 31#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require 32#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require 33#extension GL_EXT_scalar_block_layout : require 34#extension GL_EXT_buffer_reference : require 35#extension GL_EXT_buffer_reference2 : require 36#extension GL_KHR_memory_scope_semantics : require 37 38layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in; 39 40#include "tu_build_helpers.h" 41#include "tu_build_interface.h" 42 43layout(push_constant) uniform CONSTS { 44 encode_args args; 45}; 46 47void set_parent(uint32_t child, uint32_t parent) 48{ 49 uint64_t addr = args.output_bvh - child * 4 - 4; 50 DEREF(REF(uint32_t)(addr)) = parent; 51} 52 53/* This encoder struct is designed to encode a compressed node without keeping 54 * all the data live at once, making sure register pressure isn't too high. 55 */ 56 57struct tu_encoder { 58 uint32_t cur_value; 59 uint word_offset; 60 uint bit_offset; 61 REF(tu_compressed_node) node; 62}; 63 64void encode_init(out tu_encoder encoder, REF(tu_compressed_node) node) 65{ 66 encoder.cur_value = 0; 67 encoder.word_offset = 0; 68 encoder.bit_offset = 0; 69 encoder.node = node; 70} 71 72void encode(inout tu_encoder encoder, uint32_t val, uint bits) 73{ 74 encoder.cur_value |= val << encoder.bit_offset; 75 if (encoder.bit_offset + bits >= 32) { 76 DEREF(encoder.node).data[encoder.word_offset] = encoder.cur_value; 77 encoder.cur_value = val >> (32 - encoder.bit_offset); 78 encoder.word_offset++; 79 encoder.bit_offset = encoder.bit_offset + bits - 32; 80 } else { 81 encoder.bit_offset += bits; 82 } 83} 84 85void encode_skip(inout tu_encoder encoder, uint bits) 86{ 87 if (encoder.bit_offset + bits >= 32) { 88 DEREF(encoder.node).data[encoder.word_offset] = encoder.cur_value; 89 encoder.word_offset++; 90 encoder.bit_offset = encoder.bit_offset + bits - 32; 91 } else { 92 encoder.bit_offset += bits; 93 } 94} 95 96void encode_finalize(tu_encoder encoder) 97{ 98 DEREF(encoder.node).data[encoder.word_offset] = encoder.cur_value; 99} 100 101void 102encode_leaf_node(uint32_t type, uint64_t src_node, uint64_t dst_node, uint64_t dst_instances, REF(tu_accel_struct_header) dst_header) 103{ 104 float coords[3][3]; 105 uint32_t id; 106 uint32_t geometry_id; 107 uint32_t type_flags = TU_NODE_TYPE_LEAF; 108 109 switch (type) { 110 case vk_ir_node_triangle: { 111 vk_ir_triangle_node src = DEREF(REF(vk_ir_triangle_node)(src_node)); 112 113 coords = src.coords; 114 uint32_t geometry_id_and_flags = src.geometry_id_and_flags; 115 if ((geometry_id_and_flags & VK_GEOMETRY_OPAQUE) != 0) { 116 atomicAnd(DEREF(dst_header).instance_flags, ~TU_INSTANCE_ALL_NONOPAQUE); 117 } else { 118 type_flags |= TU_NODE_TYPE_NONOPAQUE; 119 atomicAnd(DEREF(dst_header).instance_flags, ~TU_INSTANCE_ALL_OPAQUE); 120 } 121 geometry_id = geometry_id_and_flags & 0xffffff; 122 id = src.triangle_id; 123 break; 124 } 125 case vk_ir_node_aabb: { 126 vk_ir_aabb_node src = DEREF(REF(vk_ir_aabb_node)(src_node)); 127 vk_aabb aabb = src.base.aabb; 128 coords[0][0] = aabb.min[0]; 129 coords[0][1] = aabb.min[1]; 130 coords[0][2] = aabb.min[2]; 131 coords[1][0] = aabb.max[0]; 132 coords[1][1] = aabb.max[1]; 133 coords[1][2] = aabb.max[2]; 134 135 type_flags |= TU_NODE_TYPE_AABB; 136 137 if ((src.geometry_id_and_flags & VK_GEOMETRY_OPAQUE) != 0) { 138 atomicAnd(DEREF(dst_header).instance_flags, ~TU_INSTANCE_ALL_NONOPAQUE); 139 } else { 140 type_flags |= TU_NODE_TYPE_NONOPAQUE; 141 atomicAnd(DEREF(dst_header).instance_flags, ~TU_INSTANCE_ALL_OPAQUE); 142 } 143 geometry_id = src.geometry_id_and_flags & 0xffffff; 144 id = src.primitive_id; 145 break; 146 } 147 case vk_ir_node_instance: { 148 vk_ir_instance_node src = DEREF(REF(vk_ir_instance_node)(src_node)); 149 150 id = src.instance_id; 151 geometry_id = 0; 152 REF(tu_instance_descriptor) dst_instance = REF(tu_instance_descriptor)(dst_instances + SIZEOF(tu_instance_descriptor) * id); 153 154 REF(tu_accel_struct_header) blas_header = REF(tu_accel_struct_header)(src.base_ptr); 155 uint64_t bvh_ptr = DEREF(blas_header).bvh_ptr; 156 uint32_t bvh_offset = uint32_t(bvh_ptr - src.base_ptr); 157 158 uint32_t sbt_offset_and_flags = src.sbt_offset_and_flags; 159 uint32_t custom_instance_and_mask = src.custom_instance_and_mask; 160 DEREF(dst_instance).bvh_ptr = bvh_ptr; 161 DEREF(dst_instance).custom_instance_index = custom_instance_and_mask & 0xffffffu; 162 DEREF(dst_instance).sbt_offset_and_flags = sbt_offset_and_flags; 163 DEREF(dst_instance).bvh_offset = bvh_offset; 164 165 mat4 transform = mat4(src.otw_matrix); 166 167 mat4 inv_transform = transpose(inverse(transpose(transform))); 168 DEREF(dst_instance).wto_matrix = mat3x4(inv_transform); 169 DEREF(dst_instance).otw_matrix = mat3x4(transform); 170 171 vk_aabb aabb = src.base.aabb; 172 coords[0][0] = aabb.min[0]; 173 coords[0][1] = aabb.min[1]; 174 coords[0][2] = aabb.min[2]; 175 coords[1][0] = aabb.max[0]; 176 coords[1][1] = aabb.max[1]; 177 coords[1][2] = aabb.max[2]; 178 179 type_flags |= TU_NODE_TYPE_TLAS; 180 181 uint32_t instance_flags = DEREF(blas_header).instance_flags; 182 183 /* Apply VK_GEOMETRY_INSTANCE_FORCE_OPAQUE_BIT_KHR and 184 * VK_GEOMETRY_INSTANCE_FORCE_NO_OPAQUE_BIT_KHR to correct the 185 * ALL_OPAQUE/ALL_NONOPAQUE flags. 186 */ 187 if (((sbt_offset_and_flags >> 24) & (VK_GEOMETRY_INSTANCE_FORCE_OPAQUE_BIT_KHR | 188 VK_GEOMETRY_INSTANCE_FORCE_NO_OPAQUE_BIT_KHR)) != 0) { 189 instance_flags &= ~(VK_GEOMETRY_INSTANCE_FORCE_OPAQUE_BIT_KHR | 190 VK_GEOMETRY_INSTANCE_FORCE_NO_OPAQUE_BIT_KHR); 191 instance_flags |= (sbt_offset_and_flags >> 24) & (VK_GEOMETRY_INSTANCE_FORCE_OPAQUE_BIT_KHR | 192 VK_GEOMETRY_INSTANCE_FORCE_NO_OPAQUE_BIT_KHR); 193 } 194 uint32_t cull_mask_and_flags = ((custom_instance_and_mask >> 16) & 0xff00) | instance_flags; 195 196 coords[2][0] = uintBitsToFloat(cull_mask_and_flags); 197 break; 198 } 199 } 200 201 REF(tu_leaf_node) dst = REF(tu_leaf_node)(dst_node); 202 DEREF(dst).coords = coords; 203 DEREF(dst).id = id; 204 DEREF(dst).geometry_id = geometry_id; 205 DEREF(dst).type_flags = type_flags; 206} 207 208/* Truncate to bfloat16 while rounding down. bfloat16 is used to store the bases. 209 */ 210 211u16vec3 to_bfloat_round_down(vec3 coord) 212{ 213 u32vec3 icoord = floatBitsToUint(coord); 214 return u16vec3(mix(icoord >> 16, (icoord + 0xffff) >> 16, notEqual(icoord & u32vec3(0x80000000), u32vec3(0)))); 215} 216 217/* Approximate subtraction while rounding up. Return a result greater than or 218 * equal to the infinitely-precise result. This just uses the native 219 * subtraction and then shifts one ULP towards infinity. Because the result is 220 * further rounded, it should usually be good enough while being faster than 221 * emulated floating-point math. 222 * 223 * We assume here that the result is always nonnegative, because it's only used 224 * to subtract away the base. 225 */ 226 227vec3 subtract_round_up_approx(vec3 a, vec3 b) 228{ 229 vec3 f = a - b; 230 u32vec3 i = floatBitsToUint(f); 231 232 i++; 233 234 /* Handle infinity/zero special cases */ 235 i = mix(i, floatBitsToUint(f), isinf(f)); 236 i = mix(i, floatBitsToUint(f), equal(f, vec3(0))); 237 238 return uintBitsToFloat(i); 239} 240 241vec3 subtract_round_down_approx(vec3 a, vec3 b) 242{ 243 vec3 f = a - b; 244 u32vec3 i = floatBitsToUint(f); 245 246 i--; 247 248 /* Handle infinity/zero special cases */ 249 i = mix(i, floatBitsToUint(f), isinf(f)); 250 i = mix(i, floatBitsToUint(f), equal(f, vec3(0))); 251 252 return uintBitsToFloat(i); 253} 254 255u32vec3 extract_mantissa(vec3 f) 256{ 257 return mix((floatBitsToUint(f) & 0x7fffff) | 0x800000, u32vec3(0), equal(f, vec3(0))); 258} 259 260void 261encode_internal_node(uint32_t children[8], uint32_t children_offset, uint child_count, 262 vec3 min_offset, vec3 max_offset, uint32_t bvh_offset) 263{ 264 REF(tu_internal_node) dst_node = REF(tu_internal_node)(OFFSET(args.output_bvh, SIZEOF(tu_internal_node) * bvh_offset)); 265 266 DEREF(dst_node).id = children_offset; 267 268 u16vec3 base_bfloat = to_bfloat_round_down(min_offset); 269 vec3 base_float = uintBitsToFloat(u32vec3(base_bfloat) << 16); 270 DEREF(dst_node).bases[0] = base_bfloat.x; 271 DEREF(dst_node).bases[1] = base_bfloat.y; 272 DEREF(dst_node).bases[2] = base_bfloat.z; 273 274 vec3 children_max = subtract_round_up_approx(max_offset, base_float); 275 276 /* The largest child offset will be encoded in 8 bits, including the 277 * explicit leading 1. We need to downcast to this precision while rounding 278 * up to catch cases where the exponent is increased by rounding up, then 279 * extract the exponent. Because children_max is always nonnegative, we can 280 * do the downcast with "(floatBitsToUint(children_max) + 0xffff) >> 16", 281 * and then we further shift to get the rounded exponent. 282 */ 283 u16vec3 exponents = u16vec3((floatBitsToUint(children_max) + 0xffff) >> 23); 284 u8vec3 exponents_u8 = u8vec3(exponents); 285 DEREF(dst_node).exponents[0] = exponents_u8.x; 286 DEREF(dst_node).exponents[1] = exponents_u8.y; 287 DEREF(dst_node).exponents[2] = exponents_u8.z; 288 289 for (uint32_t i = 0; i < child_count; i++) { 290 uint32_t offset = ir_id_to_offset(children[i]); 291 292 vk_aabb child_aabb = 293 DEREF(REF(vk_ir_node)OFFSET(args.intermediate_bvh, offset)).aabb; 294 295 /* Note: because we subtract from the minimum, we should never have a 296 * negative value here. 297 */ 298 vec3 child_min = subtract_round_down_approx(child_aabb.min, base_float); 299 vec3 child_max = subtract_round_up_approx(child_aabb.max, base_float); 300 301 u16vec3 child_min_exponents = u16vec3(floatBitsToUint(child_min) >> 23); 302 u16vec3 child_max_exponents = u16vec3(floatBitsToUint(child_max) >> 23); 303 304 u16vec3 child_min_shift = u16vec3(16) + exponents - child_min_exponents; 305 /* Divide the mantissa by 2**child_min_shift, rounding down */ 306 u8vec3 child_min_mantissas = 307 mix(u8vec3(extract_mantissa(child_min) >> child_min_shift), u8vec3(0), 308 greaterThanEqual(child_min_shift, u16vec3(32))); 309 u16vec3 child_max_shift = u16vec3(16) + exponents - child_max_exponents; 310 /* Divide the mantissa by 2**child_max_shift, rounding up */ 311 u8vec3 child_max_mantissas = 312 mix(u8vec3((extract_mantissa(child_max) + ((u32vec3(1u) << u32vec3(child_max_shift)) - 1)) >> child_max_shift), 313 u8vec3(notEqual(extract_mantissa(child_max), u32vec3(0))), 314 greaterThanEqual(child_max_shift, u16vec3(32))); 315 316 DEREF(dst_node).mantissas[i][0][0] = child_min_mantissas.x; 317 DEREF(dst_node).mantissas[i][0][1] = child_min_mantissas.y; 318 DEREF(dst_node).mantissas[i][0][2] = child_min_mantissas.z; 319 DEREF(dst_node).mantissas[i][1][0] = child_max_mantissas.x; 320 DEREF(dst_node).mantissas[i][1][1] = child_max_mantissas.y; 321 DEREF(dst_node).mantissas[i][1][2] = child_max_mantissas.z; 322 } 323 324 for (uint32_t i = child_count; i < 8; i++) { 325 DEREF(dst_node).mantissas[i][0][0] = uint8_t(0xff); 326 DEREF(dst_node).mantissas[i][0][1] = uint8_t(0xff); 327 DEREF(dst_node).mantissas[i][0][2] = uint8_t(0xff); 328 DEREF(dst_node).mantissas[i][1][0] = uint8_t(0); 329 DEREF(dst_node).mantissas[i][1][1] = uint8_t(0); 330 DEREF(dst_node).mantissas[i][1][2] = uint8_t(0); 331 } 332 333 DEREF(dst_node).child_count = uint8_t(child_count); 334 DEREF(dst_node).type_flags = uint16_t(args.geometry_type == VK_GEOMETRY_TYPE_INSTANCES_KHR ? (TU_NODE_TYPE_TLAS >> 16) : 0); 335} 336 337void 338main() 339{ 340 /* Revert the order so we start at the root */ 341 uint32_t global_id = DEREF(args.header).ir_internal_node_count - 1 - gl_GlobalInvocationID.x; 342 343 uint32_t intermediate_leaf_node_size; 344 switch (args.geometry_type) { 345 case VK_GEOMETRY_TYPE_TRIANGLES_KHR: 346 intermediate_leaf_node_size = SIZEOF(vk_ir_triangle_node); 347 break; 348 case VK_GEOMETRY_TYPE_AABBS_KHR: 349 intermediate_leaf_node_size = SIZEOF(vk_ir_aabb_node); 350 break; 351 default: /* instances */ 352 intermediate_leaf_node_size = SIZEOF(vk_ir_instance_node); 353 break; 354 } 355 356 uint32_t intermediate_leaf_nodes_size = args.leaf_node_count * intermediate_leaf_node_size; 357 358 REF(vk_ir_box_node) intermediate_internal_nodes = 359 REF(vk_ir_box_node)OFFSET(args.intermediate_bvh, intermediate_leaf_nodes_size); 360 REF(vk_ir_box_node) src_node = INDEX(vk_ir_box_node, intermediate_internal_nodes, global_id); 361 vk_ir_box_node src = DEREF(src_node); 362 363 uint64_t dst_instances = args.output_bvh - args.output_bvh_offset + SIZEOF(tu_accel_struct_header); 364 365 bool is_root_node = global_id == DEREF(args.header).ir_internal_node_count - 1; 366 367 REF(tu_accel_struct_header) header = REF(tu_accel_struct_header)(args.output_bvh - args.output_bvh_offset); 368 369 if (is_root_node) { 370 DEREF(header).instance_flags = 371 (args.geometry_type == VK_GEOMETRY_TYPE_AABBS_KHR ? TU_INSTANCE_ALL_AABB : 0) | 372 /* These will be removed when processing leaf nodes */ 373 TU_INSTANCE_ALL_NONOPAQUE | TU_INSTANCE_ALL_OPAQUE; 374 DEREF(args.header).dst_node_offset = 1; 375 } 376 377 for (;;) { 378 /* Make changes to the current node's BVH offset value visible. */ 379 memoryBarrier(gl_ScopeDevice, gl_StorageSemanticsBuffer, 380 gl_SemanticsAcquireRelease | gl_SemanticsMakeAvailable | gl_SemanticsMakeVisible); 381 382 uint32_t bvh_offset = is_root_node ? 0 : DEREF(src_node).bvh_offset; 383 if (bvh_offset == VK_UNKNOWN_BVH_OFFSET) 384 continue; 385 386 if (bvh_offset == VK_NULL_BVH_OFFSET) 387 break; 388 389 uint32_t found_child_count = 0; 390 uint32_t children[8] = {VK_BVH_INVALID_NODE, VK_BVH_INVALID_NODE, 391 VK_BVH_INVALID_NODE, VK_BVH_INVALID_NODE, 392 VK_BVH_INVALID_NODE, VK_BVH_INVALID_NODE, 393 VK_BVH_INVALID_NODE, VK_BVH_INVALID_NODE}; 394 395 for (uint32_t i = 0; i < 2; ++i) 396 if (src.children[i] != VK_BVH_INVALID_NODE) 397 children[found_child_count++] = src.children[i]; 398 399 while (found_child_count < 8) { 400 int32_t collapsed_child_index = -1; 401 float largest_surface_area = -INFINITY; 402 403 for (int32_t i = 0; i < found_child_count; ++i) { 404 if (ir_id_to_type(children[i]) != vk_ir_node_internal) 405 continue; 406 407 vk_aabb bounds = 408 DEREF(REF(vk_ir_node)OFFSET(args.intermediate_bvh, 409 ir_id_to_offset(children[i]))).aabb; 410 411 float surface_area = aabb_surface_area(bounds); 412 if (surface_area > largest_surface_area) { 413 largest_surface_area = surface_area; 414 collapsed_child_index = i; 415 } 416 } 417 418 if (collapsed_child_index != -1) { 419 REF(vk_ir_box_node) child_node = 420 REF(vk_ir_box_node)OFFSET(args.intermediate_bvh, 421 ir_id_to_offset(children[collapsed_child_index])); 422 uint32_t grandchildren[2] = DEREF(child_node).children; 423 uint32_t valid_grandchild_count = 0; 424 425 if (grandchildren[1] != VK_BVH_INVALID_NODE) 426 ++valid_grandchild_count; 427 428 if (grandchildren[0] != VK_BVH_INVALID_NODE) 429 ++valid_grandchild_count; 430 else 431 grandchildren[0] = grandchildren[1]; 432 433 if (valid_grandchild_count > 1) 434 children[found_child_count++] = grandchildren[1]; 435 436 if (valid_grandchild_count > 0) 437 children[collapsed_child_index] = grandchildren[0]; 438 else { 439 found_child_count--; 440 children[collapsed_child_index] = children[found_child_count]; 441 } 442 443 DEREF(child_node).bvh_offset = VK_NULL_BVH_OFFSET; 444 } else 445 break; 446 } 447 448 /* If there is only one child, collapse the current node by setting the 449 * child's offset to this node's offset. Otherwise, use an atomic to 450 * allocate contiguous space for all of the children. 451 */ 452 uint32_t children_offset = bvh_offset; 453 if (found_child_count > 1) { 454 children_offset = atomicAdd(DEREF(args.header).dst_node_offset, found_child_count); 455 } 456 457 vec3 min_offset = vec3(INFINITY); 458 vec3 max_offset = vec3(-INFINITY); 459 for (uint32_t i = 0; i < found_child_count; ++i) { 460 uint32_t type = ir_id_to_type(children[i]); 461 uint32_t offset = ir_id_to_offset(children[i]); 462 uint32_t dst_offset; 463 464 dst_offset = children_offset + i; 465 466 if (type == vk_ir_node_internal) { 467 REF(vk_ir_box_node) child_node = REF(vk_ir_box_node)OFFSET(args.intermediate_bvh, offset); 468 DEREF(child_node).bvh_offset = dst_offset; 469 } else { 470 encode_leaf_node(type, args.intermediate_bvh + offset, 471 args.output_bvh + SIZEOF(tu_internal_node) * dst_offset, dst_instances, 472 header); 473 } 474 475 vk_aabb child_aabb = 476 DEREF(REF(vk_ir_node)OFFSET(args.intermediate_bvh, offset)).aabb; 477 478 min_offset = min(min_offset, child_aabb.min); 479 max_offset = max(max_offset, child_aabb.max); 480 481 if (found_child_count > 1) { 482 set_parent(dst_offset, bvh_offset); 483 } 484 } 485 486 /* Make changes to the children's BVH offset value available to the other invocations. */ 487 memoryBarrier(gl_ScopeDevice, gl_StorageSemanticsBuffer, 488 gl_SemanticsAcquireRelease | gl_SemanticsMakeAvailable | gl_SemanticsMakeVisible); 489 490 if (found_child_count > 1 || found_child_count == 0) 491 encode_internal_node(children, children_offset, found_child_count, min_offset, max_offset, bvh_offset); 492 493 break; 494 } 495 496 if (is_root_node) { 497 DEREF(header).aabb = src.base.aabb; 498 DEREF(header).bvh_ptr = args.output_bvh; 499 500 set_parent(0, VK_BVH_INVALID_NODE); 501 } 502} 503