1 // Copyright 2024 The Amber Authors. 2 // Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved. 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 #ifndef SRC_VULKAN_RAYTRACING_PIPELINE_H_ 17 #define SRC_VULKAN_RAYTRACING_PIPELINE_H_ 18 19 #include <memory> 20 #include <vector> 21 22 #include "amber/result.h" 23 #include "amber/vulkan_header.h" 24 #include "src/vulkan/pipeline.h" 25 26 namespace amber { 27 namespace vulkan { 28 29 /// Pipepline to handle compute commands. 30 class RayTracingPipeline : public Pipeline { 31 public: 32 RayTracingPipeline( 33 Device* device, 34 BlasesMap* blases, 35 TlasesMap* tlases, 36 uint32_t fence_timeout_ms, 37 bool pipeline_runtime_layer_enabled, 38 const std::vector<VkPipelineShaderStageCreateInfo>& shader_stage_info, 39 VkPipelineCreateFlags create_flags); 40 ~RayTracingPipeline() override; 41 42 Result AddTLASDescriptor(const TLASCommand* cmd); 43 44 Result Initialize(CommandPool* pool, 45 std::vector<VkRayTracingShaderGroupCreateInfoKHR>& 46 shader_group_create_info); 47 48 Result getVulkanSBTRegion(VkPipeline pipeline, 49 amber::SBT* aSBT, 50 VkStridedDeviceAddressRegionKHR* region); 51 52 Result InitLibrary(const std::vector<VkPipeline>& lib, 53 uint32_t maxPipelineRayPayloadSize, 54 uint32_t maxPipelineRayHitAttributeSize, 55 uint32_t maxPipelineRayRecursionDepth); 56 57 Result TraceRays(amber::SBT* rSBT, 58 amber::SBT* mSBT, 59 amber::SBT* hSBT, 60 amber::SBT* cSBT, 61 uint32_t x, 62 uint32_t y, 63 uint32_t z, 64 uint32_t maxPipelineRayPayloadSize, 65 uint32_t maxPipelineRayHitAttributeSize, 66 uint32_t maxPipelineRayRecursionDepth, 67 const std::vector<VkPipeline>& lib, 68 bool is_timed_execution); 69 GetBlases()70 BlasesMap* GetBlases() override { return blases_; } GetTlases()71 TlasesMap* GetTlases() override { return tlases_; } 72 73 private: 74 Result CreateVkRayTracingPipeline(const VkPipelineLayout& pipeline_layout, 75 VkPipeline* pipeline, 76 const std::vector<VkPipeline>& libs, 77 uint32_t maxPipelineRayPayloadSize, 78 uint32_t maxPipelineRayHitAttributeSize, 79 uint32_t maxPipelineRayRecursionDepth); 80 81 std::vector<VkRayTracingShaderGroupCreateInfoKHR> shader_group_create_info_; 82 BlasesMap* blases_; 83 TlasesMap* tlases_; 84 SbtsMap sbtses_; 85 std::vector<std::unique_ptr<amber::vulkan::SBT>> sbts_; 86 }; 87 88 } // namespace vulkan 89 } // namespace amber 90 91 #endif // SRC_VULKAN_RAYTRACING_PIPELINE_H_ 92