// Copyright 2018 The Amber Authors.
// Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/vulkan/pipeline.h"

#include <algorithm>
#include <array>
#include <limits>
#include <utility>

#include "src/command.h"
#include "src/engine.h"
#include "src/make_unique.h"
#include "src/vulkan/buffer_descriptor.h"
#include "src/vulkan/compute_pipeline.h"
#include "src/vulkan/device.h"
#include "src/vulkan/graphics_pipeline.h"
#include "src/vulkan/image_descriptor.h"
#include "src/vulkan/raytracing_pipeline.h"
#include "src/vulkan/sampler_descriptor.h"
#include "src/vulkan/tlas_descriptor.h"

namespace amber {
namespace vulkan {
namespace {

const char* kDefaultEntryPointName = "main";

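// A full memory barrier covering all memory reads and writes.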
constexpr VkMemoryBarrier kMemoryBarrierFull = {
    VK_STRUCTURE_TYPE_MEMORY_BARRIER, nullptr,
    VK_ACCESS_2_MEMORY_READ_BIT_KHR | VK_ACCESS_2_MEMORY_WRITE_BIT_KHR,
    VK_ACCESS_2_MEMORY_READ_BIT_KHR | VK_ACCESS_2_MEMORY_WRITE_BIT_KHR};

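// Two timestamp queries: one written at the start and one at the end of a
// timed execution.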
constexpr uint32_t kNumQueryObjects = 2;

}  // namespace

Pipeline::Pipeline(
    PipelineType type,
    Device* device,
    uint32_t fence_timeout_ms,
    bool pipeline_runtime_layer_enabled,
    const std::vector<VkPipelineShaderStageCreateInfo>& shader_stage_info,
    VkPipelineCreateFlags create_flags)
    : device_(device),
      create_flags_(create_flags),
      pipeline_type_(type),
      shader_stage_info_(shader_stage_info),
      fence_timeout_ms_(fence_timeout_ms),
      pipeline_runtime_layer_enabled_(pipeline_runtime_layer_enabled) {}

Pipeline::~Pipeline() {
  // Command must be reset before we destroy descriptors or we get a validation
  // error.
  command_ = nullptr;

  for (auto& info : descriptor_set_info_) {
    if (info.layout != VK_NULL_HANDLE) {
      device_->GetPtrs()->vkDestroyDescriptorSetLayout(device_->GetVkDevice(),
                                                       info.layout, nullptr);
    }

    if (info.empty)
      continue;

    if (info.pool != VK_NULL_HANDLE) {
      device_->GetPtrs()->vkDestroyDescriptorPool(device_->GetVkDevice(),
                                                  info.pool, nullptr);
    }
  }

  if (pipeline_layout_ != VK_NULL_HANDLE) {
    device_->GetPtrs()->vkDestroyPipelineLayout(device_->GetVkDevice(),
                                                pipeline_layout_, nullptr);
    pipeline_layout_ = VK_NULL_HANDLE;
  }

  if (pipeline_ != VK_NULL_HANDLE) {
    device_->GetPtrs()->vkDestroyPipeline(device_->GetVkDevice(), pipeline_,
                                          nullptr);
    pipeline_ = VK_NULL_HANDLE;
  }
}

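// The As*() helpers below are unchecked static_cast downcasts; callers are
// expected to consult the pipeline type before using them.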
GraphicsPipeline* Pipeline::AsGraphics() {
  return static_cast<GraphicsPipeline*>(this);
}

ComputePipeline* Pipeline::AsCompute() {
  return static_cast<ComputePipeline*>(this);
}

RayTracingPipeline* Pipeline::AsRayTracingPipeline() {
  return static_cast<RayTracingPipeline*>(this);
}

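// Creates the push constant tracker and the command buffer this pipeline
// records into.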
Result Pipeline::Initialize(CommandPool* pool) {
  push_constant_ = MakeUnique<PushConstant>(device_);

  command_ = MakeUnique<CommandBuffer>(device_, pool);
  return command_->Initialize();
}

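// Creates a VkDescriptorSetLayout for every descriptor set used by the
// pipeline, including empty sets.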
Result Pipeline::CreateDescriptorSetLayouts() {
  for (auto& info : descriptor_set_info_) {
    VkDescriptorSetLayoutCreateInfo desc_info =
        VkDescriptorSetLayoutCreateInfo();
    desc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;

    // If there are no descriptors for this descriptor set, we only need to
    // create its layout, and there will be no bindings.
    std::vector<VkDescriptorSetLayoutBinding> bindings;
    for (auto& desc : info.descriptors) {
      bindings.emplace_back();
      bindings.back().binding = desc->GetBinding();
      bindings.back().descriptorType = desc->GetVkDescriptorType();
      bindings.back().descriptorCount = desc->GetDescriptorCount();
      bindings.back().stageFlags = VK_SHADER_STAGE_ALL;
    }
    desc_info.bindingCount = static_cast<uint32_t>(bindings.size());
    desc_info.pBindings = bindings.data();

    if (device_->GetPtrs()->vkCreateDescriptorSetLayout(
            device_->GetVkDevice(), &desc_info, nullptr, &info.layout) !=
        VK_SUCCESS) {
      return Result("Vulkan::Calling vkCreateDescriptorSetLayout Fail");
    }
  }

  return {};
}

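// Creates one descriptor pool per non-empty descriptor set, aggregating pool
// sizes by descriptor type so each type appears only once.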
Result Pipeline::CreateDescriptorPools() {
  for (auto& info : descriptor_set_info_) {
    if (info.empty)
      continue;

    std::vector<VkDescriptorPoolSize> pool_sizes;
    for (auto& desc : info.descriptors) {
      VkDescriptorType type = desc->GetVkDescriptorType();
      auto it = std::find_if(pool_sizes.begin(), pool_sizes.end(),
                             [&type](const VkDescriptorPoolSize& size) {
                               return size.type == type;
                             });
      if (it != pool_sizes.end()) {
        it->descriptorCount += desc->GetDescriptorCount();
        continue;
      }

      pool_sizes.emplace_back();
      pool_sizes.back().type = type;
      pool_sizes.back().descriptorCount = desc->GetDescriptorCount();
    }

    VkDescriptorPoolCreateInfo pool_info = VkDescriptorPoolCreateInfo();
    pool_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
    pool_info.maxSets = 1;
    pool_info.poolSizeCount = static_cast<uint32_t>(pool_sizes.size());
    pool_info.pPoolSizes = pool_sizes.data();

    if (device_->GetPtrs()->vkCreateDescriptorPool(device_->GetVkDevice(),
                                                   &pool_info, nullptr,
                                                   &info.pool) != VK_SUCCESS) {
      return Result("Vulkan::Calling vkCreateDescriptorPool Fail");
    }
  }

  return {};
}

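// Allocates one VkDescriptorSet from each non-empty set's pool, using the
// layout created earlier.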
Result Pipeline::CreateDescriptorSets() {
  for (size_t i = 0; i < descriptor_set_info_.size(); ++i) {
    if (descriptor_set_info_[i].empty)
      continue;

    VkDescriptorSetAllocateInfo desc_set_info = VkDescriptorSetAllocateInfo();
    desc_set_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
    desc_set_info.descriptorPool = descriptor_set_info_[i].pool;
    desc_set_info.descriptorSetCount = 1;
    desc_set_info.pSetLayouts = &descriptor_set_info_[i].layout;

    VkDescriptorSet desc_set = VK_NULL_HANDLE;
    if (device_->GetPtrs()->vkAllocateDescriptorSets(
            device_->GetVkDevice(), &desc_set_info, &desc_set) != VK_SUCCESS) {
      return Result("Vulkan::Calling vkAllocateDescriptorSets Fail");
    }
    descriptor_set_info_[i].vk_desc_set = desc_set;
  }

  return {};
}

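// Builds the pipeline layout from all descriptor set layouts plus, when one
// is present, a single push constant range.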
Result Pipeline::CreateVkPipelineLayout(VkPipelineLayout* pipeline_layout) {
  Result r = CreateVkDescriptorRelatedObjectsIfNeeded();
  if (!r.IsSuccess())
    return r;

  std::vector<VkDescriptorSetLayout> descriptor_set_layouts;
  for (const auto& desc_set : descriptor_set_info_)
    descriptor_set_layouts.push_back(desc_set.layout);

  VkPipelineLayoutCreateInfo pipeline_layout_info =
      VkPipelineLayoutCreateInfo();
  pipeline_layout_info.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
  pipeline_layout_info.setLayoutCount =
      static_cast<uint32_t>(descriptor_set_layouts.size());
  pipeline_layout_info.pSetLayouts = descriptor_set_layouts.data();

  VkPushConstantRange push_const_range =
      push_constant_->GetVkPushConstantRange();
  if (push_const_range.size > 0) {
    pipeline_layout_info.pushConstantRangeCount = 1U;
    pipeline_layout_info.pPushConstantRanges = &push_const_range;
  }

  if (device_->GetPtrs()->vkCreatePipelineLayout(
          device_->GetVkDevice(), &pipeline_layout_info, nullptr,
          pipeline_layout) != VK_SUCCESS) {
    return Result("Vulkan::Calling vkCreatePipelineLayout Fail");
  }

  return {};
}

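// Creates descriptor set layouts, pools, and sets exactly once; subsequent
// calls are no-ops.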
Result Pipeline::CreateVkDescriptorRelatedObjectsIfNeeded() {
  if (descriptor_related_objects_already_created_)
    return {};

  Result r = CreateDescriptorSetLayouts();
  if (!r.IsSuccess())
    return r;

  r = CreateDescriptorPools();
  if (!r.IsSuccess())
    return r;

  r = CreateDescriptorSets();
  if (!r.IsSuccess())
    return r;

  descriptor_related_objects_already_created_ = true;
  return {};
}

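// Asks each descriptor to write itself into its VkDescriptorSet if needed.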
void Pipeline::UpdateDescriptorSetsIfNeeded() {
  for (auto& info : descriptor_set_info_) {
    for (auto& desc : info.descriptors)
      desc->UpdateDescriptorSetIfNeeded(info.vk_desc_set);
  }
}

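// Creates a two-entry timestamp query pool when timed execution is requested
// and the device supports timestamps on compute and graphics queues.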
void Pipeline::CreateTimingQueryObjectIfNeeded(bool is_timed_execution) {
  if (!is_timed_execution ||
      !device_->IsTimestampComputeAndGraphicsSupported()) {
    return;
  }
  in_timed_execution_ = true;
  VkQueryPoolCreateInfo pool_create_info{
      VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO,
      nullptr,
      0,
      VK_QUERY_TYPE_TIMESTAMP,
      kNumQueryObjects,
      0};
  device_->GetPtrs()->vkCreateQueryPool(
      device_->GetVkDevice(), &pool_create_info, nullptr, &query_pool_);
}

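// Reads back the two timestamps, reports the elapsed time in milliseconds,
// and destroys the query pool.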
void Pipeline::DestroyTimingQueryObjectIfNeeded() {
  if (!in_timed_execution_) {
    return;
  }

  // These flags make the call wait on the CPU for the availability of our
  // queries.
  const VkQueryResultFlags flags =
      VK_QUERY_RESULT_WAIT_BIT | VK_QUERY_RESULT_64_BIT;
  std::array<uint64_t, kNumQueryObjects> time_stamps = {};
  constexpr VkDeviceSize kStrideBytes = sizeof(uint64_t);

  device_->GetPtrs()->vkGetQueryPoolResults(
      device_->GetVkDevice(), query_pool_, 0, kNumQueryObjects,
      sizeof(time_stamps), time_stamps.data(), kStrideBytes, flags);
  double time_in_ns = static_cast<double>(time_stamps[1] - time_stamps[0]) *
                      static_cast<double>(device_->GetTimestampPeriod());

  constexpr double kNsToMsTime = 1.0 / 1000000.0;
  device_->ReportExecutionTiming(time_in_ns * kNsToMsTime);
  device_->GetPtrs()->vkDestroyQueryPool(device_->GetVkDevice(), query_pool_,
                                         nullptr);
  in_timed_execution_ = false;
}

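// Records a reset of the query pool and writes the starting timestamp.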
void Pipeline::BeginTimerQuery() {
  if (!in_timed_execution_) {
    return;
  }

  device_->GetPtrs()->vkCmdResetQueryPool(command_->GetVkCommandBuffer(),
                                          query_pool_, 0, kNumQueryObjects);
  // A full barrier ensures that no work from before this point is still in
  // the pipeline.
  device_->GetPtrs()->vkCmdPipelineBarrier(
      command_->GetVkCommandBuffer(), VK_PIPELINE_STAGE_ALL_COMMANDS_BIT,
      VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, 0, 1, &kMemoryBarrierFull, 0, nullptr,
      0, nullptr);
  constexpr uint32_t kBeginQueryIndexOffset = 0;
  device_->GetPtrs()->vkCmdWriteTimestamp(command_->GetVkCommandBuffer(),
                                          VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
                                          query_pool_, kBeginQueryIndexOffset);
}

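// Writes the ending timestamp once all timed work has drained.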
void Pipeline::EndTimerQuery() {
  if (!in_timed_execution_) {
    return;
  }

  // A full barrier ensures that the work included in our timing is executed
  // before the timestamp.
  device_->GetPtrs()->vkCmdPipelineBarrier(
      command_->GetVkCommandBuffer(), VK_PIPELINE_STAGE_ALL_COMMANDS_BIT,
      VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, 0, 1, &kMemoryBarrierFull, 0, nullptr,
      0, nullptr);
  constexpr uint32_t kEndQueryIndexOffset = 1;
  device_->GetPtrs()->vkCmdWriteTimestamp(command_->GetVkCommandBuffer(),
                                          VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT,
                                          query_pool_, kEndQueryIndexOffset);
}

Result Pipeline::RecordPushConstant(const VkPipelineLayout& pipeline_layout) {
  return push_constant_->RecordPushConstantVkCommand(command_.get(),
                                                     pipeline_layout);
}

Result Pipeline::AddPushConstantBuffer(const Buffer* buf, uint32_t offset) {
  if (!buf)
    return Result("Missing push constant buffer data");
  return push_constant_->AddBuffer(buf, offset);
}

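// Looks up the descriptor at |desc_set|/|binding|, growing the descriptor set
// list as needed. On success, *desc is the matching descriptor, or nullptr if
// no descriptor exists at that binding yet.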
Result Pipeline::GetDescriptorSlot(uint32_t desc_set,
                                   uint32_t binding,
                                   Descriptor** desc) {
  *desc = nullptr;

  if (desc_set >= descriptor_set_info_.size()) {
    for (size_t i = descriptor_set_info_.size();
         i <= static_cast<size_t>(desc_set); ++i) {
      descriptor_set_info_.emplace_back();
    }
  }

  if (descriptor_set_info_[desc_set].empty &&
      descriptor_related_objects_already_created_) {
    return Result(
        "Vulkan: Pipeline descriptor related objects were already created, "
        "but an attempt was made to put data in the empty descriptor set '" +
        std::to_string(desc_set) +
        "'. Note that all used descriptor sets must be allocated before the "
        "first compute or draw.");
  }
  descriptor_set_info_[desc_set].empty = false;

  auto& descriptors = descriptor_set_info_[desc_set].descriptors;
  for (auto& descriptor : descriptors) {
    if (descriptor->GetBinding() == binding)
      *desc = descriptor.get();
  }

  return {};
}

Result Pipeline::AddDescriptorBuffer(Buffer* amber_buffer) {
  // Don't add the buffer if it's already added.
  const auto& buffer =
      std::find_if(descriptor_buffers_.begin(), descriptor_buffers_.end(),
                   [&](const Buffer* buf) { return buf == amber_buffer; });
  if (buffer != descriptor_buffers_.end()) {
    return {};
  }
  descriptor_buffers_.push_back(amber_buffer);
  return {};
}

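// Registers the buffer- or image-backed descriptor described by |cmd|, either
// creating a new descriptor for the binding or appending the buffer to an
// existing descriptor of the same type.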
Result Pipeline::AddBufferDescriptor(const BufferCommand* cmd) {
  if (cmd == nullptr)
    return Result("Pipeline::AddBufferDescriptor BufferCommand is nullptr");
  if (!cmd->IsSSBO() && !cmd->IsUniform() && !cmd->IsStorageImage() &&
      !cmd->IsSampledImage() && !cmd->IsCombinedImageSampler() &&
      !cmd->IsUniformTexelBuffer() && !cmd->IsStorageTexelBuffer() &&
      !cmd->IsUniformDynamic() && !cmd->IsSSBODynamic()) {
    return Result("Pipeline::AddBufferDescriptor unsupported buffer type");
  }

  Descriptor* desc;
  Result r =
      GetDescriptorSlot(cmd->GetDescriptorSet(), cmd->GetBinding(), &desc);
  if (!r.IsSuccess())
    return r;

  auto& descriptors = descriptor_set_info_[cmd->GetDescriptorSet()].descriptors;

  bool is_image = false;
  DescriptorType desc_type = DescriptorType::kUniformBuffer;

  if (cmd->IsStorageImage()) {
    desc_type = DescriptorType::kStorageImage;
    is_image = true;
  } else if (cmd->IsSampledImage()) {
    desc_type = DescriptorType::kSampledImage;
    is_image = true;
  } else if (cmd->IsCombinedImageSampler()) {
    desc_type = DescriptorType::kCombinedImageSampler;
    is_image = true;
  } else if (cmd->IsUniformTexelBuffer()) {
    desc_type = DescriptorType::kUniformTexelBuffer;
  } else if (cmd->IsStorageTexelBuffer()) {
    desc_type = DescriptorType::kStorageTexelBuffer;
  } else if (cmd->IsSSBO()) {
    desc_type = DescriptorType::kStorageBuffer;
  } else if (cmd->IsUniformDynamic()) {
    desc_type = DescriptorType::kUniformBufferDynamic;
  } else if (cmd->IsSSBODynamic()) {
    desc_type = DescriptorType::kStorageBufferDynamic;
  }

  if (desc == nullptr) {
    if (is_image) {
      auto image_desc = MakeUnique<ImageDescriptor>(
          cmd->GetBuffer(), desc_type, device_, cmd->GetBaseMipLevel(),
          cmd->GetDescriptorSet(), cmd->GetBinding(), this);
      if (cmd->IsCombinedImageSampler())
        image_desc->SetAmberSampler(cmd->GetSampler());

      descriptors.push_back(std::move(image_desc));
    } else {
      auto buffer_desc = MakeUnique<BufferDescriptor>(
          cmd->GetBuffer(), desc_type, device_, cmd->GetDescriptorSet(),
          cmd->GetBinding(), this);
      descriptors.push_back(std::move(buffer_desc));
    }
    AddDescriptorBuffer(cmd->GetBuffer());
    desc = descriptors.back().get();
  } else {
    if (desc->GetDescriptorType() != desc_type) {
      return Result(
          "Descriptors bound to the same binding need to have matching "
          "descriptor types");
    }
    desc->AsBufferBackedDescriptor()->AddAmberBuffer(cmd->GetBuffer());
    AddDescriptorBuffer(cmd->GetBuffer());
  }

  if (cmd->IsUniformDynamic() || cmd->IsSSBODynamic())
    desc->AsBufferDescriptor()->AddDynamicOffset(cmd->GetDynamicOffset());

  if (cmd->IsUniform() || cmd->IsUniformDynamic() || cmd->IsSSBO() ||
      cmd->IsSSBODynamic()) {
    desc->AsBufferDescriptor()->AddDescriptorOffset(cmd->GetDescriptorOffset());
    desc->AsBufferDescriptor()->AddDescriptorRange(cmd->GetDescriptorRange());
  }

  if (cmd->IsSSBO() && !desc->IsStorageBuffer()) {
    return Result(
        "Vulkan::AddBufferDescriptor BufferCommand for SSBO uses wrong "
        "descriptor set and binding");
  }

  if (cmd->IsUniform() && !desc->IsUniformBuffer()) {
    return Result(
        "Vulkan::AddBufferDescriptor BufferCommand for UBO uses wrong "
        "descriptor set and binding");
  }

  return {};
}

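// Registers a sampler descriptor, creating it on first use of the binding or
// appending the sampler to the existing descriptor.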
Result Pipeline::AddSamplerDescriptor(const SamplerCommand* cmd) {
  if (cmd == nullptr)
    return Result("Pipeline::AddSamplerDescriptor SamplerCommand is nullptr");

  Descriptor* desc;
  Result r =
      GetDescriptorSlot(cmd->GetDescriptorSet(), cmd->GetBinding(), &desc);
  if (!r.IsSuccess())
    return r;

  auto& descriptors = descriptor_set_info_[cmd->GetDescriptorSet()].descriptors;

  if (desc == nullptr) {
    auto sampler_desc = MakeUnique<SamplerDescriptor>(
        cmd->GetSampler(), DescriptorType::kSampler, device_,
        cmd->GetDescriptorSet(), cmd->GetBinding());
    descriptors.push_back(std::move(sampler_desc));
  } else {
    if (desc->GetDescriptorType() != DescriptorType::kSampler) {
      return Result(
          "Descriptors bound to the same binding need to have matching "
          "descriptor types");
    }
    desc->AsSamplerDescriptor()->AddAmberSampler(cmd->GetSampler());
  }

  return {};
}

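// Registers a top-level acceleration structure descriptor in the same way.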
Result Pipeline::AddTLASDescriptor(const TLASCommand* cmd) {
  if (cmd == nullptr)
    return Result("Pipeline::AddTLASDescriptor TLASCommand is nullptr");

  Descriptor* desc;
  Result r =
      GetDescriptorSlot(cmd->GetDescriptorSet(), cmd->GetBinding(), &desc);
  if (!r.IsSuccess())
    return r;

  auto& descriptors = descriptor_set_info_[cmd->GetDescriptorSet()].descriptors;

  if (desc == nullptr) {
    auto tlas_desc = MakeUnique<TLASDescriptor>(
        cmd->GetTLAS(), DescriptorType::kTLAS, device_, GetBlases(),
        GetTlases(), cmd->GetDescriptorSet(), cmd->GetBinding());
    descriptors.push_back(std::move(tlas_desc));
  } else {
    if (desc->GetDescriptorType() != DescriptorType::kTLAS) {
      return Result(
          "Descriptors bound to the same binding need to have matching "
          "descriptor types");
    }
    desc->AsTLASDescriptor()->AddAmberTLAS(cmd->GetTLAS());
  }

  return {};
}

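// Creates the transfer resources for all descriptors and records and submits
// the commands that copy descriptor data to the device. Two submissions are
// used so that host-visible buffers can be written directly after any
// backing-buffer resize has completed.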
Result Pipeline::SendDescriptorDataToDeviceIfNeeded() {
  {
    CommandBufferGuard guard(GetCommandBuffer());
    if (!guard.IsRecording())
      return guard.GetResult();

    for (auto& info : descriptor_set_info_) {
      for (auto& desc : info.descriptors) {
        Result r = desc->CreateResourceIfNeeded();
        if (!r.IsSuccess())
          return r;
      }
    }

    // Initialize transfer buffers / images.
    for (auto buffer : descriptor_buffers_) {
      if (descriptor_transfer_resources_.count(buffer) == 0) {
        return Result(
            "Vulkan: Pipeline::SendDescriptorDataToDeviceIfNeeded() "
            "descriptor's transfer resource is not found");
      }
      Result r = descriptor_transfer_resources_[buffer]->Initialize();
      if (!r.IsSuccess())
        return r;
    }

    // Note that if a buffer for a descriptor is host accessible and does not
    // need a command to copy its data to the device, the data is written
    // directly to the buffer. The direct write must happen after the backing
    // buffer has been resized, i.e., after the data has been copied to the new
    // buffer from the old one. Thus, we must submit the commands here to
    // guarantee this ordering.
    Result r =
        guard.Submit(GetFenceTimeout(), GetPipelineRuntimeLayerEnabled());
    if (!r.IsSuccess())
      return r;
  }

  CommandBufferGuard guard(GetCommandBuffer());
  if (!guard.IsRecording())
    return guard.GetResult();

  // Copy descriptor data to transfer resources.
  for (auto& buffer : descriptor_buffers_) {
    if (auto transfer_buffer =
            descriptor_transfer_resources_[buffer]->AsTransferBuffer()) {
      BufferBackedDescriptor::RecordCopyBufferDataToTransferResourceIfNeeded(
          GetCommandBuffer(), buffer, transfer_buffer);
    } else if (auto transfer_image =
                   descriptor_transfer_resources_[buffer]->AsTransferImage()) {
      transfer_image->ImageBarrier(GetCommandBuffer(),
                                   VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
                                   VK_PIPELINE_STAGE_TRANSFER_BIT);

      BufferBackedDescriptor::RecordCopyBufferDataToTransferResourceIfNeeded(
          GetCommandBuffer(), buffer, transfer_image);

      transfer_image->ImageBarrier(GetCommandBuffer(), VK_IMAGE_LAYOUT_GENERAL,
                                   VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT);
    } else {
      return Result(
          "Vulkan: Pipeline::SendDescriptorDataToDeviceIfNeeded() "
          "this should be unreachable");
    }
  }
  return guard.Submit(GetFenceTimeout(), GetPipelineRuntimeLayerEnabled());
}

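// Binds every non-empty descriptor set at the bind point matching the
// pipeline type, passing dynamic offsets sorted by binding number.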
void Pipeline::BindVkDescriptorSets(const VkPipelineLayout& pipeline_layout) {
  for (size_t i = 0; i < descriptor_set_info_.size(); ++i) {
    if (descriptor_set_info_[i].empty)
      continue;

    // Sort descriptors by binding number to get the correct order of dynamic
    // offsets.
    using binding_offsets_pair = std::pair<uint32_t, std::vector<uint32_t>>;
    std::vector<binding_offsets_pair> binding_offsets;
    for (const auto& desc : descriptor_set_info_[i].descriptors) {
      binding_offsets.push_back(
          {desc->GetBinding(), desc->GetDynamicOffsets()});
    }

    std::sort(std::begin(binding_offsets), std::end(binding_offsets),
              [](const binding_offsets_pair& a, const binding_offsets_pair& b) {
                return a.first < b.first;
              });

    // Add the sorted dynamic offsets.
    std::vector<uint32_t> dynamic_offsets;
    for (const auto& binding_offset : binding_offsets) {
      for (auto offset : binding_offset.second) {
        dynamic_offsets.push_back(offset);
      }
    }

    device_->GetPtrs()->vkCmdBindDescriptorSets(
        command_->GetVkCommandBuffer(),
        IsGraphics()     ? VK_PIPELINE_BIND_POINT_GRAPHICS
        : IsCompute()    ? VK_PIPELINE_BIND_POINT_COMPUTE
        : IsRayTracing() ? VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR
                         : VK_PIPELINE_BIND_POINT_MAX_ENUM,
        pipeline_layout, static_cast<uint32_t>(i), 1,
        &descriptor_set_info_[i].vk_desc_set,
        static_cast<uint32_t>(dynamic_offsets.size()), dynamic_offsets.data());
  }
}

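// Copies descriptor data back from the device: records and submits the
// transfer-to-host copies, then moves the results from the transfer
// resources into the output buffers.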
Result Pipeline::ReadbackDescriptorsToHostDataQueue() {
  if (descriptor_buffers_.empty())
    return Result{};

  // Record required commands to copy the data to a host visible buffer.
  {
    CommandBufferGuard guard(GetCommandBuffer());
    if (!guard.IsRecording())
      return guard.GetResult();

    for (auto& buffer : descriptor_buffers_) {
      if (descriptor_transfer_resources_.count(buffer) == 0) {
        return Result(
            "Vulkan: Pipeline::ReadbackDescriptorsToHostDataQueue() "
            "descriptor's transfer resource is not found");
      }
      if (auto transfer_buffer =
              descriptor_transfer_resources_[buffer]->AsTransferBuffer()) {
        Result r = BufferBackedDescriptor::RecordCopyTransferResourceToHost(
            GetCommandBuffer(), transfer_buffer);
        if (!r.IsSuccess())
          return r;
      } else if (auto transfer_image = descriptor_transfer_resources_[buffer]
                                           ->AsTransferImage()) {
        transfer_image->ImageBarrier(GetCommandBuffer(),
                                     VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL,
                                     VK_PIPELINE_STAGE_TRANSFER_BIT);
        Result r = BufferBackedDescriptor::RecordCopyTransferResourceToHost(
            GetCommandBuffer(), transfer_image);
        if (!r.IsSuccess())
          return r;
      } else {
        return Result(
            "Vulkan: Pipeline::ReadbackDescriptorsToHostDataQueue() "
            "this should be unreachable");
      }
    }

    Result r =
        guard.Submit(GetFenceTimeout(), GetPipelineRuntimeLayerEnabled());
    if (!r.IsSuccess())
      return r;
  }

  // Move data from transfer buffers to output buffers.
  for (auto& buffer : descriptor_buffers_) {
    auto& transfer_resource = descriptor_transfer_resources_[buffer];
    Result r = BufferBackedDescriptor::MoveTransferResourceToBufferOutput(
        transfer_resource.get(), buffer);
    if (!r.IsSuccess())
      return r;
  }
  descriptor_transfer_resources_.clear();
  return {};
}

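// Returns the entry point name registered for |stage|, falling back to
// "main" when none was set.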
const char* Pipeline::GetEntryPointName(VkShaderStageFlagBits stage) const {
  auto it = entry_points_.find(stage);
  if (it != entry_points_.end())
    return it->second.c_str();

  return kDefaultEntryPointName;
}

}  // namespace vulkan
}  // namespace amber