1 // Copyright 2024 The Amber Authors.
2 // Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15
16 #include <cstring>
17
18 #include "src/vulkan/sbt.h"
19 #include "src/vulkan/pipeline.h"
20
21 namespace amber {
22 namespace vulkan {
23
SBT(Device * device)24 SBT::SBT(Device* device) : device_(device) {}
25
Create(amber::SBT * sbt,VkPipeline pipeline)26 Result SBT::Create(amber::SBT* sbt, VkPipeline pipeline) {
27 uint32_t handles_count = 0;
28 for (auto& x : sbt->GetSBTRecords())
29 handles_count += x->GetCount();
30
31 if (handles_count == 0)
32 return Result("SBT must contain at least one record");
33
34 const uint32_t handle_size = device_->GetRayTracingShaderGroupHandleSize();
35 const uint32_t buffer_size = handle_size * handles_count;
36 std::vector<uint8_t> handles(buffer_size);
37
38 buffer_ = MakeUnique<TransferBuffer>(device_, buffer_size, nullptr);
39 buffer_->AddUsageFlags(VK_BUFFER_USAGE_TRANSFER_DST_BIT |
40 VK_BUFFER_USAGE_SHADER_BINDING_TABLE_BIT_KHR |
41 VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
42 buffer_->AddAllocateFlags(VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT);
43 Result r = buffer_->Initialize();
44 if (!r.IsSuccess())
45 return r;
46
47 size_t start = 0;
48 for (auto& x : sbt->GetSBTRecords()) {
49 const uint32_t index = x->GetIndex();
50 const uint32_t count = x->GetCount();
51 if (index != static_cast<uint32_t>(-1)) {
52 VkResult vr = device_->GetPtrs()->vkGetRayTracingShaderGroupHandlesKHR(
53 device_->GetVkDevice(), pipeline, index, count, count * handle_size,
54 &handles[start * handle_size]);
55
56 if (vr != VK_SUCCESS)
57 return Result("vkGetRayTracingShaderGroupHandlesKHR has failed");
58 }
59
60 start += count;
61 }
62
63 memcpy(buffer_->HostAccessibleMemoryPtr(), handles.data(), handles.size());
64
65 // Skip flush as memory allocated for buffer is coherent
66
67 return r;
68 }
69
// Defaulted destructor, defined out of line in this translation unit.
// NOTE(review): presumably kept here so that member types are complete where
// the destructor is instantiated -- confirm against the header.
SBT::~SBT() = default;
71
72 } // namespace vulkan
73 } // namespace amber
74