• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 The Amber Authors.
2 // Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 #include <utility>
17 
18 #include "src/vulkan/raytracing_pipeline.h"
19 
20 #include "src/vulkan/blas.h"
21 #include "src/vulkan/command_pool.h"
22 #include "src/vulkan/device.h"
23 #include "src/vulkan/sbt.h"
24 #include "src/vulkan/tlas.h"
25 
26 namespace amber {
27 namespace vulkan {
28 
makeStridedDeviceAddressRegionKHR(VkDeviceAddress deviceAddress,VkDeviceSize stride,VkDeviceSize size)29 inline VkStridedDeviceAddressRegionKHR makeStridedDeviceAddressRegionKHR(
30     VkDeviceAddress deviceAddress,
31     VkDeviceSize stride,
32     VkDeviceSize size) {
33   VkStridedDeviceAddressRegionKHR res;
34   res.deviceAddress = deviceAddress;
35   res.stride = stride;
36   res.size = size;
37   return res;
38 }
39 
getBufferDeviceAddress(Device * device,VkBuffer buffer)40 inline VkDeviceAddress getBufferDeviceAddress(Device* device, VkBuffer buffer) {
41   const VkBufferDeviceAddressInfo bufferDeviceAddressInfo = {
42       VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO_KHR,
43       nullptr,
44       buffer,
45   };
46 
47   return device->GetPtrs()->vkGetBufferDeviceAddress(device->GetVkDevice(),
48                                                      &bufferDeviceAddressInfo);
49 }
50 
RayTracingPipeline(Device * device,BlasesMap * blases,TlasesMap * tlases,uint32_t fence_timeout_ms,bool pipeline_runtime_layer_enabled,const std::vector<VkPipelineShaderStageCreateInfo> & shader_stage_info,VkPipelineCreateFlags create_flags)51 RayTracingPipeline::RayTracingPipeline(
52     Device* device,
53     BlasesMap* blases,
54     TlasesMap* tlases,
55     uint32_t fence_timeout_ms,
56     bool pipeline_runtime_layer_enabled,
57     const std::vector<VkPipelineShaderStageCreateInfo>& shader_stage_info,
58     VkPipelineCreateFlags create_flags)
59     : Pipeline(PipelineType::kRayTracing,
60                device,
61                fence_timeout_ms,
62                pipeline_runtime_layer_enabled,
63                shader_stage_info,
64                create_flags),
65       shader_group_create_info_(),
66       blases_(blases),
67       tlases_(tlases) {}
68 
69 RayTracingPipeline::~RayTracingPipeline() = default;
70 
Initialize(CommandPool * pool,std::vector<VkRayTracingShaderGroupCreateInfoKHR> & shader_group_create_info)71 Result RayTracingPipeline::Initialize(
72     CommandPool* pool,
73     std::vector<VkRayTracingShaderGroupCreateInfoKHR>&
74         shader_group_create_info) {
75   shader_group_create_info_.swap(shader_group_create_info);
76 
77   return Pipeline::Initialize(pool);
78 }
79 
CreateVkRayTracingPipeline(const VkPipelineLayout & pipeline_layout,VkPipeline * pipeline,const std::vector<VkPipeline> & libs,uint32_t maxPipelineRayPayloadSize,uint32_t maxPipelineRayHitAttributeSize,uint32_t maxPipelineRayRecursionDepth)80 Result RayTracingPipeline::CreateVkRayTracingPipeline(
81     const VkPipelineLayout& pipeline_layout,
82     VkPipeline* pipeline,
83     const std::vector<VkPipeline>& libs,
84     uint32_t maxPipelineRayPayloadSize,
85     uint32_t maxPipelineRayHitAttributeSize,
86     uint32_t maxPipelineRayRecursionDepth) {
87   std::vector<VkPipelineShaderStageCreateInfo> shader_stage_info =
88       GetVkShaderStageInfo();
89 
90   for (auto& info : shader_stage_info)
91     info.pName = GetEntryPointName(info.stage);
92 
93   const bool lib = (create_flags_ & VK_PIPELINE_CREATE_LIBRARY_BIT_KHR) != 0;
94   const VkPipelineLibraryCreateInfoKHR libraryInfo = {
95       VK_STRUCTURE_TYPE_PIPELINE_LIBRARY_CREATE_INFO_KHR, nullptr,
96       static_cast<uint32_t>(libs.size()), libs.size() ? &libs[0] : nullptr};
97   const VkRayTracingPipelineInterfaceCreateInfoKHR libraryInterface = {
98       VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_INTERFACE_CREATE_INFO_KHR, nullptr,
99       maxPipelineRayPayloadSize, maxPipelineRayHitAttributeSize};
100 
101   VkRayTracingPipelineCreateInfoKHR pipelineCreateInfo{
102       VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR,
103       nullptr,
104       create_flags_,
105       static_cast<uint32_t>(shader_stage_info.size()),
106       shader_stage_info.data(),
107       static_cast<uint32_t>(shader_group_create_info_.size()),
108       shader_group_create_info_.data(),
109       maxPipelineRayRecursionDepth,
110       libs.empty() ? nullptr : &libraryInfo,
111       lib || !libs.empty() ? &libraryInterface : nullptr,
112       nullptr,
113       pipeline_layout,
114       VK_NULL_HANDLE,
115       0,
116   };
117 
118   VkResult r = device_->GetPtrs()->vkCreateRayTracingPipelinesKHR(
119       device_->GetVkDevice(), VK_NULL_HANDLE, VK_NULL_HANDLE, 1u,
120       &pipelineCreateInfo, nullptr, pipeline);
121   if (r != VK_SUCCESS)
122     return Result("Vulkan::Calling vkCreateRayTracingPipelinesKHR Fail");
123 
124   return {};
125 }
126 
getVulkanSBTRegion(VkPipeline pipeline,amber::SBT * aSBT,VkStridedDeviceAddressRegionKHR * region)127 Result RayTracingPipeline::getVulkanSBTRegion(
128     VkPipeline pipeline,
129     amber::SBT* aSBT,
130     VkStridedDeviceAddressRegionKHR* region) {
131   const uint32_t handle_size = device_->GetRayTracingShaderGroupHandleSize();
132   if (aSBT != nullptr) {
133     SBT* vSBT = nullptr;
134     auto x = sbtses_.find(aSBT);
135 
136     if (x == sbtses_.end()) {
137       auto p = MakeUnique<amber::vulkan::SBT>(device_);
138       sbts_.push_back(std::move(p));
139       auto sbt_vulkan = sbtses_.emplace(aSBT, sbts_.back().get());
140 
141       vSBT = sbt_vulkan.first->second;
142 
143       Result r = vSBT->Create(aSBT, pipeline);
144       if (!r.IsSuccess())
145         return r;
146     } else {
147       vSBT = x->second;
148     }
149 
150     *region = makeStridedDeviceAddressRegionKHR(
151         getBufferDeviceAddress(device_, vSBT->getBuffer()->GetVkBuffer()),
152         handle_size, handle_size * aSBT->GetSBTSize());
153   } else {
154     *region = makeStridedDeviceAddressRegionKHR(0, 0, 0);
155   }
156 
157   return {};
158 }
159 
InitLibrary(const std::vector<VkPipeline> & libs,uint32_t maxPipelineRayPayloadSize,uint32_t maxPipelineRayHitAttributeSize,uint32_t maxPipelineRayRecursionDepth)160 Result RayTracingPipeline::InitLibrary(const std::vector<VkPipeline>& libs,
161                                        uint32_t maxPipelineRayPayloadSize,
162                                        uint32_t maxPipelineRayHitAttributeSize,
163                                        uint32_t maxPipelineRayRecursionDepth) {
164   assert(pipeline_layout_ == VK_NULL_HANDLE);
165   Result r = CreateVkPipelineLayout(&pipeline_layout_);
166   if (!r.IsSuccess())
167     return r;
168 
169   assert(pipeline_ == VK_NULL_HANDLE);
170   r = CreateVkRayTracingPipeline(
171       pipeline_layout_, &pipeline_, libs, maxPipelineRayPayloadSize,
172       maxPipelineRayHitAttributeSize, maxPipelineRayRecursionDepth);
173   if (!r.IsSuccess())
174     return r;
175 
176   return {};
177 }
178 
TraceRays(amber::SBT * rSBT,amber::SBT * mSBT,amber::SBT * hSBT,amber::SBT * cSBT,uint32_t x,uint32_t y,uint32_t z,uint32_t maxPipelineRayPayloadSize,uint32_t maxPipelineRayHitAttributeSize,uint32_t maxPipelineRayRecursionDepth,const std::vector<VkPipeline> & libs,bool is_timed_execution)179 Result RayTracingPipeline::TraceRays(amber::SBT* rSBT,
180                                      amber::SBT* mSBT,
181                                      amber::SBT* hSBT,
182                                      amber::SBT* cSBT,
183                                      uint32_t x,
184                                      uint32_t y,
185                                      uint32_t z,
186                                      uint32_t maxPipelineRayPayloadSize,
187                                      uint32_t maxPipelineRayHitAttributeSize,
188                                      uint32_t maxPipelineRayRecursionDepth,
189                                      const std::vector<VkPipeline>& libs,
190                                      bool is_timed_execution) {
191   Result r = SendDescriptorDataToDeviceIfNeeded();
192   if (!r.IsSuccess())
193     return r;
194 
195   r = InitLibrary(libs, maxPipelineRayPayloadSize,
196                   maxPipelineRayHitAttributeSize, maxPipelineRayRecursionDepth);
197   if (!r.IsSuccess())
198     return r;
199 
200   // Note that a command updating a descriptor set and a command using
201   // it must be submitted separately, because using a descriptor set
202   // while updating it is not safe.
203   UpdateDescriptorSetsIfNeeded();
204   CreateTimingQueryObjectIfNeeded(is_timed_execution);
205   {
206     CommandBufferGuard guard(GetCommandBuffer());
207     if (!guard.IsRecording())
208       return guard.GetResult();
209 
210     for (auto& i : *blases_) {
211       i.second->BuildBLAS(GetCommandBuffer());
212     }
213     for (auto& i : *tlases_) {
214       i.second->BuildTLAS(GetCommandBuffer()->GetVkCommandBuffer());
215     }
216 
217     BindVkDescriptorSets(pipeline_layout_);
218 
219     r = RecordPushConstant(pipeline_layout_);
220     if (!r.IsSuccess())
221       return r;
222 
223     device_->GetPtrs()->vkCmdBindPipeline(
224         command_->GetVkCommandBuffer(), VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR,
225         pipeline_);
226 
227     VkStridedDeviceAddressRegionKHR rSBTRegion = {};
228     VkStridedDeviceAddressRegionKHR mSBTRegion = {};
229     VkStridedDeviceAddressRegionKHR hSBTRegion = {};
230     VkStridedDeviceAddressRegionKHR cSBTRegion = {};
231 
232     r = getVulkanSBTRegion(pipeline_, rSBT, &rSBTRegion);
233     if (!r.IsSuccess())
234       return r;
235 
236     r = getVulkanSBTRegion(pipeline_, mSBT, &mSBTRegion);
237     if (!r.IsSuccess())
238       return r;
239 
240     r = getVulkanSBTRegion(pipeline_, hSBT, &hSBTRegion);
241     if (!r.IsSuccess())
242       return r;
243 
244     r = getVulkanSBTRegion(pipeline_, cSBT, &cSBTRegion);
245     if (!r.IsSuccess())
246       return r;
247 
248     device_->GetPtrs()->vkCmdTraceRaysKHR(command_->GetVkCommandBuffer(),
249                                           &rSBTRegion, &mSBTRegion, &hSBTRegion,
250                                           &cSBTRegion, x, y, z);
251     BeginTimerQuery();
252     r = guard.Submit(GetFenceTimeout(), GetPipelineRuntimeLayerEnabled());
253     EndTimerQuery();
254     if (!r.IsSuccess())
255       return r;
256   }
257   DestroyTimingQueryObjectIfNeeded();
258   r = ReadbackDescriptorsToHostDataQueue();
259   if (!r.IsSuccess())
260     return r;
261 
262   return {};
263 }
264 
265 }  // namespace vulkan
266 }  // namespace amber
267