• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 The Amber Authors.
2 // Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 #include "src/vulkan/tlas.h"
17 #include "src/vulkan/blas.h"
18 
19 namespace amber {
20 namespace vulkan {
21 
makeVkMatrix(const float * m)22 static VkTransformMatrixKHR makeVkMatrix(const float* m) {
23   const VkTransformMatrixKHR identityMatrix3x4 = {{{1.0f, 0.0f, 0.0f, 0.0f},
24                                                    {0.0f, 1.0f, 0.0f, 0.0f},
25                                                    {0.0f, 0.0f, 1.0f, 0.0f}}};
26   VkTransformMatrixKHR v;
27 
28   if (m == nullptr)
29     return identityMatrix3x4;
30 
31   for (size_t i = 0; i < 12; i++) {
32     const size_t r = i / 4;
33     const size_t c = i % 4;
34     v.matrix[r][c] = m[i];
35   }
36 
37   return v;
38 }
39 
TLAS(Device * device)40 TLAS::TLAS(Device* device) : device_(device) {}
41 
CreateTLAS(amber::TLAS * tlas,BlasesMap * blases)42 Result TLAS::CreateTLAS(amber::TLAS* tlas,
43                         BlasesMap* blases) {
44   if (tlas_ != VK_NULL_HANDLE)
45     return {};
46 
47   assert(tlas != nullptr);
48 
49   VkDeviceOrHostAddressConstKHR const_default_ptr;
50   VkDeviceOrHostAddressKHR default_ptr;
51 
52   const_default_ptr.hostAddress = nullptr;
53   default_ptr.hostAddress = nullptr;
54 
55   instances_count_ = static_cast<uint32_t>(tlas->GetInstances().size());
56 
57   const uint32_t ib_size =
58       uint32_t(instances_count_ * sizeof(VkAccelerationStructureInstanceKHR));
59 
60   instance_buffer_ = MakeUnique<TransferBuffer>(device_, ib_size, nullptr);
61   instance_buffer_->AddUsageFlags(
62       VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR |
63       VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
64   instance_buffer_->AddAllocateFlags(VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT);
65   instance_buffer_->Initialize();
66 
67   VkAccelerationStructureInstanceKHR* instances_ptr =
68       reinterpret_cast<VkAccelerationStructureInstanceKHR*>
69           (instance_buffer_->HostAccessibleMemoryPtr());
70 
71   for (auto& instance : tlas->GetInstances()) {
72     auto blas = instance->GetUsedBLAS();
73 
74     assert(blas != nullptr);
75 
76     auto blas_vulkan_it = blases->find(blas);
77     amber::vulkan::BLAS* blas_vulkan_ptr = nullptr;
78 
79     if (blas_vulkan_it == blases->end()) {
80       auto blas_vulkan =
81           blases->emplace(blas, new amber::vulkan::BLAS(device_));
82       blas_vulkan_ptr = blas_vulkan.first->second.get();
83 
84       Result r = blas_vulkan_ptr->CreateBLAS(blas);
85 
86       if (!r.IsSuccess())
87         return r;
88     } else {
89       blas_vulkan_ptr = blas_vulkan_it->second.get();
90     }
91 
92     VkDeviceAddress accelerationStructureAddress =
93         blas_vulkan_ptr->getVkBLASDeviceAddress();
94 
95     *instances_ptr = VkAccelerationStructureInstanceKHR{
96         makeVkMatrix(instance->GetTransform()),
97         instance->GetInstanceIndex(),
98         instance->GetMask(),
99         instance->GetOffset(),
100         instance->GetFlags(),
101         static_cast<uint64_t>(accelerationStructureAddress)};
102 
103     instances_ptr++;
104   }
105 
106   VkAccelerationStructureGeometryInstancesDataKHR
107       accelerationStructureGeometryInstancesDataKHR = {
108           VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_INSTANCES_DATA_KHR,
109           nullptr,
110           VK_FALSE,
111           const_default_ptr,
112       };
113   VkAccelerationStructureGeometryDataKHR geometry = {};
114   geometry.instances = accelerationStructureGeometryInstancesDataKHR;
115 
116   accelerationStructureGeometryKHR_ = {
117       VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_KHR,
118       nullptr,
119       VK_GEOMETRY_TYPE_INSTANCES_KHR,
120       geometry,
121       0,
122   };
123 
124   accelerationStructureBuildGeometryInfoKHR_ = {
125       VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR,
126       nullptr,
127       VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR,
128       0,
129       VK_BUILD_ACCELERATION_STRUCTURE_MODE_BUILD_KHR,
130       VK_NULL_HANDLE,
131       VK_NULL_HANDLE,
132       1,
133       &accelerationStructureGeometryKHR_,
134       nullptr,
135       default_ptr,
136   };
137 
138   VkAccelerationStructureBuildSizesInfoKHR sizeInfo = {
139       VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR,
140       nullptr,
141       0,
142       0,
143       0,
144   };
145 
146   device_->GetPtrs()->vkGetAccelerationStructureBuildSizesKHR(
147       device_->GetVkDevice(), VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR,
148       &accelerationStructureBuildGeometryInfoKHR_, &instances_count_,
149       &sizeInfo);
150 
151   const uint32_t as_size =
152       static_cast<uint32_t>(sizeInfo.accelerationStructureSize);
153 
154   buffer_ = MakeUnique<TransferBuffer>(device_, as_size, nullptr);
155   buffer_->AddUsageFlags(
156       VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR |
157       VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
158   buffer_->AddAllocateFlags(VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT);
159   buffer_->Initialize();
160 
161   const VkAccelerationStructureCreateInfoKHR
162       accelerationStructureCreateInfoKHR = {
163           VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_CREATE_INFO_KHR,
164           nullptr,
165           0,
166           buffer_->GetVkBuffer(),
167           0,
168           as_size,
169           VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR,
170           0,
171       };
172 
173   if (device_->GetPtrs()->vkCreateAccelerationStructureKHR(
174           device_->GetVkDevice(), &accelerationStructureCreateInfoKHR, nullptr,
175           &tlas_) != VK_SUCCESS) {
176     return Result(
177         "Vulkan::Calling vkCreateAccelerationStructureKHR "
178         "failed");
179   }
180 
181   accelerationStructureBuildGeometryInfoKHR_.dstAccelerationStructure = tlas_;
182 
183   if (sizeInfo.buildScratchSize > 0) {
184     scratch_buffer_ = MakeUnique<TransferBuffer>(
185         device_, static_cast<uint32_t>(sizeInfo.buildScratchSize), nullptr);
186     scratch_buffer_->AddUsageFlags(VK_BUFFER_USAGE_STORAGE_BUFFER_BIT |
187                                    VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
188     scratch_buffer_->AddAllocateFlags(VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT);
189     scratch_buffer_->Initialize();
190 
191     accelerationStructureBuildGeometryInfoKHR_.scratchData.deviceAddress =
192         scratch_buffer_->getBufferDeviceAddress();
193   }
194 
195   accelerationStructureGeometryKHR_.geometry.instances.data.deviceAddress =
196       instance_buffer_->getBufferDeviceAddress();
197 
198   return {};
199 }
200 
BuildTLAS(VkCommandBuffer cmdBuffer)201 Result TLAS::BuildTLAS(VkCommandBuffer cmdBuffer) {
202   if (tlas_ == VK_NULL_HANDLE)
203     return Result("Acceleration structure should be created first");
204   if (built_)
205     return {};
206 
207   VkAccelerationStructureBuildRangeInfoKHR
208       accelerationStructureBuildRangeInfoKHR = {instances_count_, 0, 0, 0};
209   VkAccelerationStructureBuildRangeInfoKHR*
210       accelerationStructureBuildRangeInfoKHRPtr =
211           &accelerationStructureBuildRangeInfoKHR;
212 
213   device_->GetPtrs()->vkCmdBuildAccelerationStructuresKHR(
214       cmdBuffer, 1, &accelerationStructureBuildGeometryInfoKHR_,
215       &accelerationStructureBuildRangeInfoKHRPtr);
216 
217   const VkAccessFlags accessMasks =
218       VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR |
219       VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR;
220   const VkMemoryBarrier memBarrier{
221       VK_STRUCTURE_TYPE_MEMORY_BARRIER,
222       nullptr,
223       accessMasks,
224       accessMasks,
225   };
226 
227   device_->GetPtrs()->vkCmdPipelineBarrier(
228       cmdBuffer, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR,
229       VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, 0, 1, &memBarrier, 0, nullptr, 0,
230       nullptr);
231 
232   built_ = true;
233 
234   return {};
235 }
236 
~TLAS()237 TLAS::~TLAS() {
238   if (tlas_ != VK_NULL_HANDLE) {
239     device_->GetPtrs()->vkDestroyAccelerationStructureKHR(
240         device_->GetVkDevice(), tlas_, nullptr);
241   }
242 }
243 
244 }  // namespace vulkan
245 }  // namespace amber
246