• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 The Amber Authors.
2 // Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 #include <cstring>
17 
18 #include "src/vulkan/sbt.h"
19 #include "src/vulkan/pipeline.h"
20 
21 namespace amber {
22 namespace vulkan {
23 
SBT(Device * device)24 SBT::SBT(Device* device) : device_(device) {}
25 
Create(amber::SBT * sbt,VkPipeline pipeline)26 Result SBT::Create(amber::SBT* sbt, VkPipeline pipeline) {
27   uint32_t handles_count = 0;
28   for (auto& x : sbt->GetSBTRecords())
29     handles_count += x->GetCount();
30 
31   if (handles_count == 0)
32     return Result("SBT must contain at least one record");
33 
34   const uint32_t handle_size = device_->GetRayTracingShaderGroupHandleSize();
35   const uint32_t buffer_size = handle_size * handles_count;
36   std::vector<uint8_t> handles(buffer_size);
37 
38   buffer_ = MakeUnique<TransferBuffer>(device_, buffer_size, nullptr);
39   buffer_->AddUsageFlags(VK_BUFFER_USAGE_TRANSFER_DST_BIT |
40                          VK_BUFFER_USAGE_SHADER_BINDING_TABLE_BIT_KHR |
41                          VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
42   buffer_->AddAllocateFlags(VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT);
43   Result r = buffer_->Initialize();
44   if (!r.IsSuccess())
45     return r;
46 
47   size_t start = 0;
48   for (auto& x : sbt->GetSBTRecords()) {
49     const uint32_t index = x->GetIndex();
50     const uint32_t count = x->GetCount();
51     if (index != static_cast<uint32_t>(-1)) {
52       VkResult vr = device_->GetPtrs()->vkGetRayTracingShaderGroupHandlesKHR(
53           device_->GetVkDevice(), pipeline, index, count, count * handle_size,
54           &handles[start * handle_size]);
55 
56       if (vr != VK_SUCCESS)
57         return Result("vkGetRayTracingShaderGroupHandlesKHR has failed");
58     }
59 
60     start += count;
61   }
62 
63   memcpy(buffer_->HostAccessibleMemoryPtr(), handles.data(), handles.size());
64 
65   // Skip flush as memory allocated for buffer is coherent
66 
67   return r;
68 }
69 
70 SBT::~SBT() = default;
71 
72 }  // namespace vulkan
73 }  // namespace amber
74