• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2018 The Amber Authors.
2 // Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 #ifndef SRC_PIPELINE_H_
17 #define SRC_PIPELINE_H_
18 
19 #include <map>
20 #include <memory>
21 #include <string>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25 
26 #include "amber/result.h"
27 #include "src/acceleration_structure.h"
28 #include "src/buffer.h"
29 #include "src/command_data.h"
30 #include "src/pipeline_data.h"
31 #include "src/sampler.h"
32 #include "src/shader.h"
33 
34 namespace amber {
35 
/// The kind of pipeline being described. The explicit values spell out the
/// numbering: kCompute is 0, kGraphics is 1 and kRayTracing is 2.
enum class PipelineType {
  kCompute = 0,
  kGraphics = 1,
  kRayTracing = 2,
};
37 
38 /// Stores all information related to a pipeline.
39 class Pipeline {
40  public:
41   /// Information on a shader attached to this pipeline.
42   class ShaderInfo {
43    public:
44     ShaderInfo(Shader*, ShaderType type);
45     ShaderInfo(const ShaderInfo&);
46     ~ShaderInfo();
47 
48     ShaderInfo& operator=(const ShaderInfo&) = default;
49 
50     // Set the optimization options for this shader. Optimizations are
51     // specified like command-line arguments to spirv-opt (see its --help).
52     // Parsing is done by spvtools::Optimizer::RegisterPassesFromFlags (see
53     // SPIRV-Tools include/spirv-tools/optimizer.hpp).
SetShaderOptimizations(const std::vector<std::string> & opts)54     void SetShaderOptimizations(const std::vector<std::string>& opts) {
55       shader_optimizations_ = opts;
56     }
GetShaderOptimizations()57     const std::vector<std::string>& GetShaderOptimizations() const {
58       return shader_optimizations_;
59     }
60 
SetCompileOptions(const std::vector<std::string> & options)61     void SetCompileOptions(const std::vector<std::string>& options) {
62       compile_options_ = options;
63     }
GetCompileOptions()64     const std::vector<std::string>& GetCompileOptions() const {
65       return compile_options_;
66     }
67 
68     enum class RequiredSubgroupSizeSetting : uint32_t {
69       kNotSet = 0,
70       kSetToSpecificSize,
71       kSetToMinimumSize,
72       kSetToMaximumSize
73     };
74 
SetRequiredSubgroupSizeSetting(RequiredSubgroupSizeSetting setting,uint32_t size)75     void SetRequiredSubgroupSizeSetting(RequiredSubgroupSizeSetting setting,
76                                         uint32_t size) {
77       required_subgroup_size_setting_ = setting;
78       required_subgroup_size_ = size;
79     }
GetRequiredSubgroupSizeSetting()80     RequiredSubgroupSizeSetting GetRequiredSubgroupSizeSetting() const {
81       return required_subgroup_size_setting_;
82     }
GetRequiredSubgroupSize()83     uint32_t GetRequiredSubgroupSize() const { return required_subgroup_size_; }
84 
SetVaryingSubgroupSize(const bool isSet)85     void SetVaryingSubgroupSize(const bool isSet) {
86       varying_subgroup_size_ = isSet;
87     }
GetVaryingSubgroupSize()88     bool GetVaryingSubgroupSize() const { return varying_subgroup_size_; }
89 
SetRequireFullSubgroups(const bool isSet)90     void SetRequireFullSubgroups(const bool isSet) {
91       require_full_subgroups_ = isSet;
92     }
GetRequireFullSubgroups()93     bool GetRequireFullSubgroups() const { return require_full_subgroups_; }
94 
SetShader(Shader * shader)95     void SetShader(Shader* shader) { shader_ = shader; }
GetShader()96     const Shader* GetShader() const { return shader_; }
97 
SetEntryPoint(const std::string & ep)98     void SetEntryPoint(const std::string& ep) { entry_point_ = ep; }
GetEntryPoint()99     const std::string& GetEntryPoint() const { return entry_point_; }
100 
SetShaderType(ShaderType type)101     void SetShaderType(ShaderType type) { shader_type_ = type; }
GetShaderType()102     ShaderType GetShaderType() const { return shader_type_; }
103 
GetData()104     const std::vector<uint32_t> GetData() const { return data_; }
SetData(std::vector<uint32_t> && data)105     void SetData(std::vector<uint32_t>&& data) { data_ = std::move(data); }
106 
GetSpecialization()107     const std::map<uint32_t, uint32_t>& GetSpecialization() const {
108       return specialization_;
109     }
AddSpecialization(uint32_t spec_id,uint32_t value)110     void AddSpecialization(uint32_t spec_id, uint32_t value) {
111       specialization_[spec_id] = value;
112     }
113 
114     /// Descriptor information for an OpenCL-C shader.
115     struct DescriptorMapEntry {
116       std::string arg_name = "";
117 
118       enum class Kind : int {
119         UNKNOWN,
120         SSBO,
121         UBO,
122         POD,
123         POD_UBO,
124         POD_PUSHCONSTANT,
125         RO_IMAGE,
126         WO_IMAGE,
127         SAMPLER,
128       } kind;
129 
130       uint32_t descriptor_set = 0;
131       uint32_t binding = 0;
132       uint32_t arg_ordinal = 0;
133       uint32_t pod_offset = 0;
134       uint32_t pod_arg_size = 0;
135     };
136 
AddDescriptorEntry(const std::string & kernel,DescriptorMapEntry && entry)137     void AddDescriptorEntry(const std::string& kernel,
138                             DescriptorMapEntry&& entry) {
139       descriptor_map_[kernel].emplace_back(std::move(entry));
140     }
141     const std::unordered_map<std::string, std::vector<DescriptorMapEntry>>&
GetDescriptorMap()142     GetDescriptorMap() const {
143       return descriptor_map_;
144     }
145 
146     /// Push constant information for an OpenCL-C shader.
147     struct PushConstant {
148       enum class PushConstantType {
149         kDimensions = 0,
150         kGlobalOffset,
151         kRegionOffset,
152       };
153       PushConstantType type;
154       uint32_t offset = 0;
155       uint32_t size = 0;
156     };
157 
AddPushConstant(PushConstant && pc)158     void AddPushConstant(PushConstant&& pc) {
159       push_constants_.emplace_back(std::move(pc));
160     }
GetPushConstants()161     const std::vector<PushConstant>& GetPushConstants() const {
162       return push_constants_;
163     }
164 
165    private:
166     Shader* shader_ = nullptr;
167     ShaderType shader_type_;
168     std::vector<std::string> shader_optimizations_;
169     std::string entry_point_;
170     std::vector<uint32_t> data_;
171     std::map<uint32_t, uint32_t> specialization_;
172     std::unordered_map<std::string, std::vector<DescriptorMapEntry>>
173         descriptor_map_;
174     std::vector<PushConstant> push_constants_;
175     std::vector<std::string> compile_options_;
176     RequiredSubgroupSizeSetting required_subgroup_size_setting_;
177     uint32_t required_subgroup_size_;
178     bool varying_subgroup_size_;
179     bool require_full_subgroups_;
180   };
181 
182   /// Information on a buffer attached to the pipeline.
183   ///
184   /// The BufferInfo will have either (descriptor_set, binding) or location
185   /// attached.
186   struct BufferInfo {
187     BufferInfo() = default;
BufferInfoBufferInfo188     explicit BufferInfo(Buffer* buf) : buffer(buf) {}
189 
190     Buffer* buffer = nullptr;
191     uint32_t descriptor_set = 0;
192     uint32_t binding = 0;
193     uint32_t location = 0;
194     uint32_t base_mip_level = 0;
195     uint32_t dynamic_offset = 0;
196     std::string arg_name = "";
197     uint32_t arg_no = 0;
198     BufferType type = BufferType::kUnknown;
199     InputRate input_rate = InputRate::kVertex;
200     Format* format;
201     uint32_t offset = 0;
202     uint32_t stride = 0;
203     Sampler* sampler = nullptr;
204     uint64_t descriptor_offset = 0;
205     uint64_t descriptor_range = ~0ULL;  // ~0ULL == VK_WHOLE_SIZE
206   };
207 
208   /// Information on a sampler attached to the pipeline.
209   struct SamplerInfo {
210     SamplerInfo() = default;
SamplerInfoSamplerInfo211     explicit SamplerInfo(Sampler* samp) : sampler(samp) {}
212 
213     Sampler* sampler = nullptr;
214     uint32_t descriptor_set = 0;
215     uint32_t binding = 0;
216     std::string arg_name = "";
217     uint32_t arg_no = 0;
218     uint32_t mask = 0;
219   };
220 
221   /// Information on a top level acceleration structure at the pipeline.
222   struct TLASInfo {
223     TLASInfo() = default;
TLASInfoTLASInfo224     explicit TLASInfo(TLAS* as) : tlas(as) {}
225 
226     TLAS* tlas = nullptr;
227     uint32_t descriptor_set = 0;
228     uint32_t binding = 0;
229   };
230   static const char* kGeneratedColorBuffer;
231   static const char* kGeneratedDepthBuffer;
232   static const char* kGeneratedPushConstantBuffer;
233 
234   explicit Pipeline(PipelineType type);
235   ~Pipeline();
236 
237   std::unique_ptr<Pipeline> Clone() const;
238 
IsGraphics()239   bool IsGraphics() const { return pipeline_type_ == PipelineType::kGraphics; }
IsCompute()240   bool IsCompute() const { return pipeline_type_ == PipelineType::kCompute; }
IsRayTracing()241   bool IsRayTracing() const {
242     return pipeline_type_ == PipelineType::kRayTracing;
243   }
244 
GetType()245   PipelineType GetType() const { return pipeline_type_; }
246 
SetName(const std::string & name)247   void SetName(const std::string& name) { name_ = name; }
GetName()248   const std::string& GetName() const { return name_; }
249 
SetFramebufferWidth(uint32_t fb_width)250   void SetFramebufferWidth(uint32_t fb_width) {
251     fb_width_ = fb_width;
252     UpdateFramebufferSizes();
253   }
GetFramebufferWidth()254   uint32_t GetFramebufferWidth() const { return fb_width_; }
255 
SetFramebufferHeight(uint32_t fb_height)256   void SetFramebufferHeight(uint32_t fb_height) {
257     fb_height_ = fb_height;
258     UpdateFramebufferSizes();
259   }
GetFramebufferHeight()260   uint32_t GetFramebufferHeight() const { return fb_height_; }
261 
262   /// Adds |shader| of |type| to the pipeline.
263   Result AddShader(Shader* shader, ShaderType type);
264   /// Returns information on all bound shaders in this pipeline.
GetShaders()265   std::vector<ShaderInfo>& GetShaders() { return shaders_; }
266   /// Returns information on all bound shaders in this pipeline.
GetShaders()267   const std::vector<ShaderInfo>& GetShaders() const { return shaders_; }
268 
269   /// Returns the ShaderInfo for |shader| or nullptr.
GetShader(Shader * shader)270   const ShaderInfo* GetShader(Shader* shader) const {
271     for (const auto& info : shaders_) {
272       if (info.GetShader() == shader)
273         return &info;
274     }
275     return nullptr;
276   }
277 
278   /// Adds |shaders| to the pipeline.
279   /// Designed to support libraries
AddShaders(const std::vector<ShaderInfo> & lib_shaders)280   Result AddShaders(const std::vector<ShaderInfo>& lib_shaders) {
281     shaders_.reserve(shaders_.size() + lib_shaders.size());
282     shaders_.insert(std::end(shaders_), std::begin(lib_shaders),
283                     std::end(lib_shaders));
284 
285     return {};
286   }
287 
288   /// Returns a success result if |shader| found and the shader index is
289   /// returned in |out|. Returns failure otherwise.
GetShaderIndex(Shader * shader,uint32_t * out)290   Result GetShaderIndex(Shader* shader, uint32_t* out) const {
291     for (size_t index = 0; index < shaders_.size(); index++) {
292       if (shaders_[index].GetShader() == shader) {
293         *out = static_cast<uint32_t>(index);
294         return {};
295       }
296     }
297     return Result("Referred shader not found in group");
298   }
299 
300   /// Sets the |type| of |shader| in the pipeline.
301   Result SetShaderType(const Shader* shader, ShaderType type);
302   /// Sets the entry point |name| for |shader| in this pipeline.
303   Result SetShaderEntryPoint(const Shader* shader, const std::string& name);
304   /// Sets the optimizations (|opts|) for |shader| in this pipeline.
305   Result SetShaderOptimizations(const Shader* shader,
306                                 const std::vector<std::string>& opts);
307   /// Sets the compile options for |shader| in this pipeline.
308   Result SetShaderCompileOptions(const Shader* shader,
309                                  const std::vector<std::string>& options);
310   /// Sets required subgroup size.
311   Result SetShaderRequiredSubgroupSize(const Shader* shader,
312                                        const uint32_t subgroupSize);
313   /// Sets required subgroup size to the device minimum supported subgroup size.
314   Result SetShaderRequiredSubgroupSizeToMinimum(const Shader* shader);
315 
316   /// Sets required subgroup size to the device maximum supported subgroup size.
317   Result SetShaderRequiredSubgroupSizeToMaximum(const Shader* shader);
318 
319   /// Sets varying subgroup size property.
320   Result SetShaderVaryingSubgroupSize(const Shader* shader, const bool isSet);
321 
322   /// Sets require full subgroups property.
323   Result SetShaderRequireFullSubgroups(const Shader* shader, const bool isSet);
324   /// Returns a list of all colour attachments in this pipeline.
GetColorAttachments()325   const std::vector<BufferInfo>& GetColorAttachments() const {
326     return color_attachments_;
327   }
328   /// Adds |buf| as a colour attachment at |location| in the pipeline.
329   /// Uses |base_mip_level| as the mip level for output.
330   Result AddColorAttachment(Buffer* buf,
331                             uint32_t location,
332                             uint32_t base_mip_level);
333   /// Retrieves the location that |buf| is bound to in the pipeline. The
334   /// location will be written to |loc|. An error result will be return if
335   /// something goes wrong.
336   Result GetLocationForColorAttachment(Buffer* buf, uint32_t* loc) const;
337 
338   /// Returns a list of all resolve targets in this pipeline.
GetResolveTargets()339   const std::vector<BufferInfo>& GetResolveTargets() const {
340     return resolve_targets_;
341   }
342 
343   /// Adds |buf| as a multisample resolve target in the pipeline.
344   Result AddResolveTarget(Buffer* buf);
345 
346   /// Sets |buf| as the depth/stencil buffer for this pipeline.
347   Result SetDepthStencilBuffer(Buffer* buf);
348   /// Returns information on the depth/stencil buffer bound to the pipeline. If
349   /// no depth buffer is bound the |BufferInfo::buffer| parameter will be
350   /// nullptr.
GetDepthStencilBuffer()351   const BufferInfo& GetDepthStencilBuffer() const {
352     return depth_stencil_buffer_;
353   }
354 
355   /// Returns pipeline data.
GetPipelineData()356   PipelineData* GetPipelineData() { return &pipeline_data_; }
357 
358   /// Returns information on all vertex buffers bound to the pipeline.
GetVertexBuffers()359   const std::vector<BufferInfo>& GetVertexBuffers() const {
360     return vertex_buffers_;
361   }
362   /// Adds |buf| as a vertex buffer at |location| in the pipeline using |rate|
363   /// as the input rate, |format| as vertex data format, |offset| as a starting
364   /// offset for the vertex buffer data, and |stride| for the data stride in
365   /// bytes.
366   Result AddVertexBuffer(Buffer* buf,
367                          uint32_t location,
368                          InputRate rate,
369                          Format* format,
370                          uint32_t offset,
371                          uint32_t stride);
372 
373   /// Binds |buf| as the index buffer for this pipeline.
374   Result SetIndexBuffer(Buffer* buf);
375   /// Returns the index buffer bound to this pipeline or nullptr if no index
376   /// buffer bound.
GetIndexBuffer()377   Buffer* GetIndexBuffer() const { return index_buffer_; }
378 
379   /// Adds |buf| of |type| to the pipeline at the given |descriptor_set|,
380   /// |binding|, |base_mip_level|, |descriptor_offset|, |descriptor_range| and
381   /// |dynamic_offset|.
382   void AddBuffer(Buffer* buf,
383                  BufferType type,
384                  uint32_t descriptor_set,
385                  uint32_t binding,
386                  uint32_t base_mip_level,
387                  uint32_t dynamic_offset,
388                  uint64_t descriptor_offset,
389                  uint64_t descriptor_range);
390   /// Adds |buf| to the pipeline at the given |arg_name|.
391   void AddBuffer(Buffer* buf, BufferType type, const std::string& arg_name);
392   /// Adds |buf| to the pipeline at the given |arg_no|.
393   void AddBuffer(Buffer* buf, BufferType type, uint32_t arg_no);
394   /// Returns information on all buffers in this pipeline.
GetBuffers()395   const std::vector<BufferInfo>& GetBuffers() const { return buffers_; }
396   /// Clears all buffer bindings for given |descriptor_set| and |binding|.
397   void ClearBuffers(uint32_t descriptor_set, uint32_t binding);
398 
399   /// Adds |sampler| to the pipeline at the given |descriptor_set| and
400   /// |binding|.
401   void AddSampler(Sampler* sampler, uint32_t descriptor_set, uint32_t binding);
402   /// Adds |sampler| to the pipeline at the given |arg_name|.
403   void AddSampler(Sampler* sampler, const std::string& arg_name);
404   /// Adds |sampler| to the pieline at the given |arg_no|.
405   void AddSampler(Sampler* sampler, uint32_t arg_no);
406   /// Adds an entry for an OpenCL literal sampler.
407   void AddSampler(uint32_t sampler_mask,
408                   uint32_t descriptor_set,
409                   uint32_t binding);
410   /// Clears all sampler bindings for given |descriptor_set| and |binding|.
411   void ClearSamplers(uint32_t descriptor_set, uint32_t binding);
412 
413   /// Returns information on all samplers in this pipeline.
GetSamplers()414   const std::vector<SamplerInfo>& GetSamplers() const { return samplers_; }
415 
416   /// Adds |tlas| to the pipeline at the given |descriptor_set| and
417   /// |binding|.
418   void AddTLAS(TLAS* tlas, uint32_t descriptor_set, uint32_t binding);
419 
420   /// Returns information on all bound TLAS in the pipeline.
GetTLASes()421   std::vector<TLASInfo>& GetTLASes() { return tlases_; }
422 
423   /// Adds |sbt| to the list of known shader binding tables.
424   /// The |sbt| must have a unique name within pipeline.
AddSBT(std::unique_ptr<SBT> sbt)425   Result AddSBT(std::unique_ptr<SBT> sbt) {
426     if (name_to_sbt_.count(sbt->GetName()) > 0)
427       return Result("duplicate SBT name provided");
428 
429     sbts_.push_back(std::move(sbt));
430     name_to_sbt_[sbts_.back()->GetName()] = sbts_.back().get();
431 
432     return {};
433   }
434 
435   /// Retrieves the SBT with |name|, |nullptr| if not found.
GetSBT(const std::string & name)436   SBT* GetSBT(const std::string& name) const {
437     auto it = name_to_sbt_.find(name);
438     return it == name_to_sbt_.end() ? nullptr : it->second;
439   }
440 
441   /// Retrieves a list of all SBTs.
GetSBTs()442   const std::vector<std::unique_ptr<SBT>>& GetSBTs() const { return sbts_; }
443 
444   /// Adds |group| to the list of known shader groups.
445   /// The |group| must have a unique name within pipeline.
AddShaderGroup(std::shared_ptr<ShaderGroup> group)446   Result AddShaderGroup(std::shared_ptr<ShaderGroup> group) {
447     if (name_to_shader_group_.count(group->GetName()) > 0)
448       return Result("shader group name already exists");
449 
450     shader_groups_.push_back(std::move(group));
451     name_to_shader_group_[shader_groups_.back()->GetName()] =
452         shader_groups_.back().get();
453 
454     return {};
455   }
456 
457   /// Retrieves the Shader Group with |name|, |nullptr| if not found.
GetShaderGroup(const std::string & name)458   ShaderGroup* GetShaderGroup(const std::string& name) const {
459     auto it = name_to_shader_group_.find(name);
460     return it == name_to_shader_group_.end() ? nullptr : it->second;
461   }
462   /// Retrieves a Shader Group at given |index|.
GetShaderGroupByIndex(uint32_t index)463   ShaderGroup* GetShaderGroupByIndex(uint32_t index) const {
464     return shader_groups_[index].get();
465   }
466   /// Retreives index of shader group specified by |name|
GetShaderGroupIndex(const std::string & name)467   uint32_t GetShaderGroupIndex(const std::string& name) const {
468     ShaderGroup* shader_group = GetShaderGroup(name);
469 
470     for (size_t i = 0; i < shader_groups_.size(); i++) {
471       if (shader_groups_[i].get() == shader_group) {
472         return static_cast<uint32_t>(i);
473       }
474     }
475 
476     return static_cast<uint32_t>(-1);
477   }
478   /// Retrieves a list of all Shader Groups.
GetShaderGroups()479   const std::vector<std::shared_ptr<ShaderGroup>>& GetShaderGroups() const {
480     return shader_groups_;
481   }
482 
483   /// Updates the descriptor set and binding info for the OpenCL-C kernel bound
484   /// to the pipeline. No effect for other shader formats.
485   Result UpdateOpenCLBufferBindings();
486 
487   /// Returns the buffer which is currently bound to this pipeline at
488   /// |descriptor_set| and |binding|.
489   Buffer* GetBufferForBinding(uint32_t descriptor_set, uint32_t binding) const;
490 
491   Result SetPushConstantBuffer(Buffer* buf);
GetPushConstantBuffer()492   const BufferInfo& GetPushConstantBuffer() const {
493     return push_constant_buffer_;
494   }
495 
496   /// Validates that the pipeline has been created correctly.
497   Result Validate() const;
498 
499   /// Generates a default color attachment in B8G8R8A8_UNORM.
500   std::unique_ptr<Buffer> GenerateDefaultColorAttachmentBuffer();
501   /// Generates a default depth/stencil attachment in D32_SFLOAT_S8_UINT format.
502   std::unique_ptr<Buffer> GenerateDefaultDepthStencilAttachmentBuffer();
503 
504   /// Information on values set for OpenCL-C plain-old-data args.
505   struct ArgSetInfo {
506     std::string name;
507     uint32_t ordinal = 0;
508     Format* fmt = nullptr;
509     Value value;
510   };
511 
512   /// Adds value from SET command.
SetArg(ArgSetInfo && info)513   void SetArg(ArgSetInfo&& info) { set_arg_values_.push_back(std::move(info)); }
SetArgValues()514   const std::vector<ArgSetInfo>& SetArgValues() const {
515     return set_arg_values_;
516   }
517 
518   /// Generate the buffers necessary for OpenCL PoD arguments populated via SET
519   /// command. This should be called after all other buffers are bound.
520   Result GenerateOpenCLPodBuffers();
521 
522   /// Generate the samplers necessary for OpenCL literal samplers from the
523   /// descriptor map. This should be called after all other samplers are bound.
524   Result GenerateOpenCLLiteralSamplers();
525 
526   /// Generate the push constant buffers necessary for OpenCL kernels.
527   Result GenerateOpenCLPushConstants();
528 
SetMaxPipelineRayPayloadSize(uint32_t size)529   void SetMaxPipelineRayPayloadSize(uint32_t size) {
530     max_pipeline_ray_payload_size_ = size;
531   }
GetMaxPipelineRayPayloadSize()532   uint32_t GetMaxPipelineRayPayloadSize() {
533     return max_pipeline_ray_payload_size_;
534   }
SetMaxPipelineRayHitAttributeSize(uint32_t size)535   void SetMaxPipelineRayHitAttributeSize(uint32_t size) {
536     max_pipeline_ray_hit_attribute_size_ = size;
537   }
GetMaxPipelineRayHitAttributeSize()538   uint32_t GetMaxPipelineRayHitAttributeSize() {
539     return max_pipeline_ray_hit_attribute_size_;
540   }
SetMaxPipelineRayRecursionDepth(uint32_t depth)541   void SetMaxPipelineRayRecursionDepth(uint32_t depth) {
542     max_pipeline_ray_recursion_depth_ = depth;
543   }
GetMaxPipelineRayRecursionDepth()544   uint32_t GetMaxPipelineRayRecursionDepth() {
545     return max_pipeline_ray_recursion_depth_;
546   }
SetCreateFlags(uint32_t flags)547   void SetCreateFlags(uint32_t flags) {
548     create_flags_ = flags;
549   }
GetCreateFlags()550   uint32_t GetCreateFlags() const {
551     return create_flags_;
552   }
553 
AddPipelineLibrary(Pipeline * pipeline)554   void AddPipelineLibrary(Pipeline* pipeline) { libs_.push_back(pipeline); }
GetPipelineLibraries()555   const std::vector<Pipeline*>& GetPipelineLibraries() const { return libs_; }
556 
557  private:
558   void UpdateFramebufferSizes();
559 
560   Result SetShaderRequiredSubgroupSize(
561       const Shader* shader,
562       const ShaderInfo::RequiredSubgroupSizeSetting setting,
563       const uint32_t subgroupSize);
564 
565   Result CreatePushConstantBuffer();
566 
567   Result ValidateGraphics() const;
568   Result ValidateCompute() const;
569   Result ValidateRayTracing() const;
570 
571   PipelineType pipeline_type_ = PipelineType::kCompute;
572   std::string name_;
573   std::vector<ShaderInfo> shaders_;
574   std::vector<TLASInfo> tlases_;
575   std::vector<BufferInfo> color_attachments_;
576   std::vector<BufferInfo> resolve_targets_;
577   std::vector<BufferInfo> vertex_buffers_;
578   std::vector<BufferInfo> buffers_;
579   std::vector<std::unique_ptr<type::Type>> types_;
580   std::vector<SamplerInfo> samplers_;
581   std::vector<std::unique_ptr<Format>> formats_;
582   BufferInfo depth_stencil_buffer_;
583   BufferInfo push_constant_buffer_;
584   Buffer* index_buffer_ = nullptr;
585   PipelineData pipeline_data_;
586   uint32_t fb_width_ = 250;
587   uint32_t fb_height_ = 250;
588 
589   std::vector<ArgSetInfo> set_arg_values_;
590   std::vector<std::unique_ptr<Buffer>> opencl_pod_buffers_;
591   /// Maps (descriptor set, binding) to the buffer for that binding pair.
592   std::map<std::pair<uint32_t, uint32_t>, Buffer*> opencl_pod_buffer_map_;
593   std::vector<std::unique_ptr<Sampler>> opencl_literal_samplers_;
594   std::unique_ptr<Buffer> opencl_push_constants_;
595 
596   std::map<std::string, ShaderGroup*> name_to_shader_group_;
597   std::vector<std::shared_ptr<ShaderGroup>> shader_groups_;
598   std::map<std::string, SBT*> name_to_sbt_;
599   std::vector<std::unique_ptr<SBT>> sbts_;
600   uint32_t max_pipeline_ray_payload_size_ = 0;
601   uint32_t max_pipeline_ray_hit_attribute_size_ = 0;
602   uint32_t max_pipeline_ray_recursion_depth_ = 1;
603   uint32_t create_flags_ = 0;
604   std::vector<Pipeline*> libs_;
605 };
606 
607 }  // namespace amber
608 
609 #endif  // SRC_PIPELINE_H_
610