• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 The Amber Authors.
2 // Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 #ifndef SRC_VULKAN_RAYTRACING_PIPELINE_H_
17 #define SRC_VULKAN_RAYTRACING_PIPELINE_H_
18 
19 #include <memory>
20 #include <vector>
21 
22 #include "amber/result.h"
23 #include "amber/vulkan_header.h"
24 #include "src/vulkan/pipeline.h"
25 
26 namespace amber {
27 namespace vulkan {
28 
29 /// Pipepline to handle compute commands.
30 class RayTracingPipeline : public Pipeline {
31  public:
32   RayTracingPipeline(
33       Device* device,
34       BlasesMap* blases,
35       TlasesMap* tlases,
36       uint32_t fence_timeout_ms,
37       bool pipeline_runtime_layer_enabled,
38       const std::vector<VkPipelineShaderStageCreateInfo>& shader_stage_info,
39       VkPipelineCreateFlags create_flags);
40   ~RayTracingPipeline() override;
41 
42   Result AddTLASDescriptor(const TLASCommand* cmd);
43 
44   Result Initialize(CommandPool* pool,
45                     std::vector<VkRayTracingShaderGroupCreateInfoKHR>&
46                         shader_group_create_info);
47 
48   Result getVulkanSBTRegion(VkPipeline pipeline,
49                             amber::SBT* aSBT,
50                             VkStridedDeviceAddressRegionKHR* region);
51 
52   Result InitLibrary(const std::vector<VkPipeline>& lib,
53                      uint32_t maxPipelineRayPayloadSize,
54                      uint32_t maxPipelineRayHitAttributeSize,
55                      uint32_t maxPipelineRayRecursionDepth);
56 
57   Result TraceRays(amber::SBT* rSBT,
58                    amber::SBT* mSBT,
59                    amber::SBT* hSBT,
60                    amber::SBT* cSBT,
61                    uint32_t x,
62                    uint32_t y,
63                    uint32_t z,
64                    uint32_t maxPipelineRayPayloadSize,
65                    uint32_t maxPipelineRayHitAttributeSize,
66                    uint32_t maxPipelineRayRecursionDepth,
67                    const std::vector<VkPipeline>& lib,
68                    bool is_timed_execution);
69 
GetBlases()70   BlasesMap* GetBlases() override { return blases_; }
GetTlases()71   TlasesMap* GetTlases() override { return tlases_; }
72 
73  private:
74   Result CreateVkRayTracingPipeline(const VkPipelineLayout& pipeline_layout,
75                                     VkPipeline* pipeline,
76                                     const std::vector<VkPipeline>& libs,
77                                     uint32_t maxPipelineRayPayloadSize,
78                                     uint32_t maxPipelineRayHitAttributeSize,
79                                     uint32_t maxPipelineRayRecursionDepth);
80 
81   std::vector<VkRayTracingShaderGroupCreateInfoKHR> shader_group_create_info_;
82   BlasesMap* blases_;
83   TlasesMap* tlases_;
84   SbtsMap sbtses_;
85   std::vector<std::unique_ptr<amber::vulkan::SBT>> sbts_;
86 };
87 
88 }  // namespace vulkan
89 }  // namespace amber
90 
91 #endif  // SRC_VULKAN_RAYTRACING_PIPELINE_H_
92