1 // Copyright 2024 The Amber Authors.
2 // Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15
16 #include <utility>
17
18 #include "src/vulkan/raytracing_pipeline.h"
19
20 #include "src/vulkan/blas.h"
21 #include "src/vulkan/command_pool.h"
22 #include "src/vulkan/device.h"
23 #include "src/vulkan/sbt.h"
24 #include "src/vulkan/tlas.h"
25
26 namespace amber {
27 namespace vulkan {
28
makeStridedDeviceAddressRegionKHR(VkDeviceAddress deviceAddress,VkDeviceSize stride,VkDeviceSize size)29 inline VkStridedDeviceAddressRegionKHR makeStridedDeviceAddressRegionKHR(
30 VkDeviceAddress deviceAddress,
31 VkDeviceSize stride,
32 VkDeviceSize size) {
33 VkStridedDeviceAddressRegionKHR res;
34 res.deviceAddress = deviceAddress;
35 res.stride = stride;
36 res.size = size;
37 return res;
38 }
39
getBufferDeviceAddress(Device * device,VkBuffer buffer)40 inline VkDeviceAddress getBufferDeviceAddress(Device* device, VkBuffer buffer) {
41 const VkBufferDeviceAddressInfo bufferDeviceAddressInfo = {
42 VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO_KHR,
43 nullptr,
44 buffer,
45 };
46
47 return device->GetPtrs()->vkGetBufferDeviceAddress(device->GetVkDevice(),
48 &bufferDeviceAddressInfo);
49 }
50
RayTracingPipeline(Device * device,BlasesMap * blases,TlasesMap * tlases,uint32_t fence_timeout_ms,bool pipeline_runtime_layer_enabled,const std::vector<VkPipelineShaderStageCreateInfo> & shader_stage_info,VkPipelineCreateFlags create_flags)51 RayTracingPipeline::RayTracingPipeline(
52 Device* device,
53 BlasesMap* blases,
54 TlasesMap* tlases,
55 uint32_t fence_timeout_ms,
56 bool pipeline_runtime_layer_enabled,
57 const std::vector<VkPipelineShaderStageCreateInfo>& shader_stage_info,
58 VkPipelineCreateFlags create_flags)
59 : Pipeline(PipelineType::kRayTracing,
60 device,
61 fence_timeout_ms,
62 pipeline_runtime_layer_enabled,
63 shader_stage_info,
64 create_flags),
65 shader_group_create_info_(),
66 blases_(blases),
67 tlases_(tlases) {}
68
69 RayTracingPipeline::~RayTracingPipeline() = default;
70
Initialize(CommandPool * pool,std::vector<VkRayTracingShaderGroupCreateInfoKHR> & shader_group_create_info)71 Result RayTracingPipeline::Initialize(
72 CommandPool* pool,
73 std::vector<VkRayTracingShaderGroupCreateInfoKHR>&
74 shader_group_create_info) {
75 shader_group_create_info_.swap(shader_group_create_info);
76
77 return Pipeline::Initialize(pool);
78 }
79
CreateVkRayTracingPipeline(const VkPipelineLayout & pipeline_layout,VkPipeline * pipeline,const std::vector<VkPipeline> & libs,uint32_t maxPipelineRayPayloadSize,uint32_t maxPipelineRayHitAttributeSize,uint32_t maxPipelineRayRecursionDepth)80 Result RayTracingPipeline::CreateVkRayTracingPipeline(
81 const VkPipelineLayout& pipeline_layout,
82 VkPipeline* pipeline,
83 const std::vector<VkPipeline>& libs,
84 uint32_t maxPipelineRayPayloadSize,
85 uint32_t maxPipelineRayHitAttributeSize,
86 uint32_t maxPipelineRayRecursionDepth) {
87 std::vector<VkPipelineShaderStageCreateInfo> shader_stage_info =
88 GetVkShaderStageInfo();
89
90 for (auto& info : shader_stage_info)
91 info.pName = GetEntryPointName(info.stage);
92
93 const bool lib = (create_flags_ & VK_PIPELINE_CREATE_LIBRARY_BIT_KHR) != 0;
94 const VkPipelineLibraryCreateInfoKHR libraryInfo = {
95 VK_STRUCTURE_TYPE_PIPELINE_LIBRARY_CREATE_INFO_KHR, nullptr,
96 static_cast<uint32_t>(libs.size()), libs.size() ? &libs[0] : nullptr};
97 const VkRayTracingPipelineInterfaceCreateInfoKHR libraryInterface = {
98 VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_INTERFACE_CREATE_INFO_KHR, nullptr,
99 maxPipelineRayPayloadSize, maxPipelineRayHitAttributeSize};
100
101 VkRayTracingPipelineCreateInfoKHR pipelineCreateInfo{
102 VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR,
103 nullptr,
104 create_flags_,
105 static_cast<uint32_t>(shader_stage_info.size()),
106 shader_stage_info.data(),
107 static_cast<uint32_t>(shader_group_create_info_.size()),
108 shader_group_create_info_.data(),
109 maxPipelineRayRecursionDepth,
110 libs.empty() ? nullptr : &libraryInfo,
111 lib || !libs.empty() ? &libraryInterface : nullptr,
112 nullptr,
113 pipeline_layout,
114 VK_NULL_HANDLE,
115 0,
116 };
117
118 VkResult r = device_->GetPtrs()->vkCreateRayTracingPipelinesKHR(
119 device_->GetVkDevice(), VK_NULL_HANDLE, VK_NULL_HANDLE, 1u,
120 &pipelineCreateInfo, nullptr, pipeline);
121 if (r != VK_SUCCESS)
122 return Result("Vulkan::Calling vkCreateRayTracingPipelinesKHR Fail");
123
124 return {};
125 }
126
getVulkanSBTRegion(VkPipeline pipeline,amber::SBT * aSBT,VkStridedDeviceAddressRegionKHR * region)127 Result RayTracingPipeline::getVulkanSBTRegion(
128 VkPipeline pipeline,
129 amber::SBT* aSBT,
130 VkStridedDeviceAddressRegionKHR* region) {
131 const uint32_t handle_size = device_->GetRayTracingShaderGroupHandleSize();
132 if (aSBT != nullptr) {
133 SBT* vSBT = nullptr;
134 auto x = sbtses_.find(aSBT);
135
136 if (x == sbtses_.end()) {
137 auto p = MakeUnique<amber::vulkan::SBT>(device_);
138 sbts_.push_back(std::move(p));
139 auto sbt_vulkan = sbtses_.emplace(aSBT, sbts_.back().get());
140
141 vSBT = sbt_vulkan.first->second;
142
143 Result r = vSBT->Create(aSBT, pipeline);
144 if (!r.IsSuccess())
145 return r;
146 } else {
147 vSBT = x->second;
148 }
149
150 *region = makeStridedDeviceAddressRegionKHR(
151 getBufferDeviceAddress(device_, vSBT->getBuffer()->GetVkBuffer()),
152 handle_size, handle_size * aSBT->GetSBTSize());
153 } else {
154 *region = makeStridedDeviceAddressRegionKHR(0, 0, 0);
155 }
156
157 return {};
158 }
159
InitLibrary(const std::vector<VkPipeline> & libs,uint32_t maxPipelineRayPayloadSize,uint32_t maxPipelineRayHitAttributeSize,uint32_t maxPipelineRayRecursionDepth)160 Result RayTracingPipeline::InitLibrary(const std::vector<VkPipeline>& libs,
161 uint32_t maxPipelineRayPayloadSize,
162 uint32_t maxPipelineRayHitAttributeSize,
163 uint32_t maxPipelineRayRecursionDepth) {
164 assert(pipeline_layout_ == VK_NULL_HANDLE);
165 Result r = CreateVkPipelineLayout(&pipeline_layout_);
166 if (!r.IsSuccess())
167 return r;
168
169 assert(pipeline_ == VK_NULL_HANDLE);
170 r = CreateVkRayTracingPipeline(
171 pipeline_layout_, &pipeline_, libs, maxPipelineRayPayloadSize,
172 maxPipelineRayHitAttributeSize, maxPipelineRayRecursionDepth);
173 if (!r.IsSuccess())
174 return r;
175
176 return {};
177 }
178
TraceRays(amber::SBT * rSBT,amber::SBT * mSBT,amber::SBT * hSBT,amber::SBT * cSBT,uint32_t x,uint32_t y,uint32_t z,uint32_t maxPipelineRayPayloadSize,uint32_t maxPipelineRayHitAttributeSize,uint32_t maxPipelineRayRecursionDepth,const std::vector<VkPipeline> & libs,bool is_timed_execution)179 Result RayTracingPipeline::TraceRays(amber::SBT* rSBT,
180 amber::SBT* mSBT,
181 amber::SBT* hSBT,
182 amber::SBT* cSBT,
183 uint32_t x,
184 uint32_t y,
185 uint32_t z,
186 uint32_t maxPipelineRayPayloadSize,
187 uint32_t maxPipelineRayHitAttributeSize,
188 uint32_t maxPipelineRayRecursionDepth,
189 const std::vector<VkPipeline>& libs,
190 bool is_timed_execution) {
191 Result r = SendDescriptorDataToDeviceIfNeeded();
192 if (!r.IsSuccess())
193 return r;
194
195 r = InitLibrary(libs, maxPipelineRayPayloadSize,
196 maxPipelineRayHitAttributeSize, maxPipelineRayRecursionDepth);
197 if (!r.IsSuccess())
198 return r;
199
200 // Note that a command updating a descriptor set and a command using
201 // it must be submitted separately, because using a descriptor set
202 // while updating it is not safe.
203 UpdateDescriptorSetsIfNeeded();
204 CreateTimingQueryObjectIfNeeded(is_timed_execution);
205 {
206 CommandBufferGuard guard(GetCommandBuffer());
207 if (!guard.IsRecording())
208 return guard.GetResult();
209
210 for (auto& i : *blases_) {
211 i.second->BuildBLAS(GetCommandBuffer());
212 }
213 for (auto& i : *tlases_) {
214 i.second->BuildTLAS(GetCommandBuffer()->GetVkCommandBuffer());
215 }
216
217 BindVkDescriptorSets(pipeline_layout_);
218
219 r = RecordPushConstant(pipeline_layout_);
220 if (!r.IsSuccess())
221 return r;
222
223 device_->GetPtrs()->vkCmdBindPipeline(
224 command_->GetVkCommandBuffer(), VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR,
225 pipeline_);
226
227 VkStridedDeviceAddressRegionKHR rSBTRegion = {};
228 VkStridedDeviceAddressRegionKHR mSBTRegion = {};
229 VkStridedDeviceAddressRegionKHR hSBTRegion = {};
230 VkStridedDeviceAddressRegionKHR cSBTRegion = {};
231
232 r = getVulkanSBTRegion(pipeline_, rSBT, &rSBTRegion);
233 if (!r.IsSuccess())
234 return r;
235
236 r = getVulkanSBTRegion(pipeline_, mSBT, &mSBTRegion);
237 if (!r.IsSuccess())
238 return r;
239
240 r = getVulkanSBTRegion(pipeline_, hSBT, &hSBTRegion);
241 if (!r.IsSuccess())
242 return r;
243
244 r = getVulkanSBTRegion(pipeline_, cSBT, &cSBTRegion);
245 if (!r.IsSuccess())
246 return r;
247
248 device_->GetPtrs()->vkCmdTraceRaysKHR(command_->GetVkCommandBuffer(),
249 &rSBTRegion, &mSBTRegion, &hSBTRegion,
250 &cSBTRegion, x, y, z);
251 BeginTimerQuery();
252 r = guard.Submit(GetFenceTimeout(), GetPipelineRuntimeLayerEnabled());
253 EndTimerQuery();
254 if (!r.IsSuccess())
255 return r;
256 }
257 DestroyTimingQueryObjectIfNeeded();
258 r = ReadbackDescriptorsToHostDataQueue();
259 if (!r.IsSuccess())
260 return r;
261
262 return {};
263 }
264
265 } // namespace vulkan
266 } // namespace amber
267