• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2018 The Amber Authors.
2 // Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 #ifndef SRC_COMMAND_H_
17 #define SRC_COMMAND_H_
18 
19 #include <cstdint>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "amber/shader_info.h"
26 #include "amber/value.h"
27 #include "src/acceleration_structure.h"
28 #include "src/buffer.h"
29 #include "src/command_data.h"
30 #include "src/pipeline_data.h"
31 #include "src/sampler.h"
32 
33 namespace amber {
34 
// Forward declarations so that pointers to these types can be named before
// their full definitions appear further down (e.g. by Command::As*()).
class Pipeline;

// Resource-binding commands.
class BufferCommand;
class TLASCommand;

// Buffer-to-buffer operations.
class CompareBufferCommand;
class CopyCommand;

// Clear commands.
class ClearColorCommand;
class ClearCommand;
class ClearDepthCommand;
class ClearStencilCommand;

// Dispatch and draw commands.
class ComputeCommand;
class DrawArraysCommand;
class DrawGridCommand;
class DrawRectCommand;
class RayTracingCommand;

// Pipeline-state commands.
class EntryPointCommand;
class PatchParameterVerticesCommand;

// Probe commands.
class ProbeCommand;
class ProbeSSBOCommand;

// Control-flow commands.
class RepeatCommand;
54 
55 /// Base class for all commands.
56 class Command {
57  public:
58   enum class Type : uint8_t {
59     kClear = 0,
60     kClearColor,
61     kClearDepth,
62     kClearStencil,
63     kCompute,
64     kCompareBuffer,
65     kCopy,
66     kDrawArrays,
67     kDrawRect,
68     kDrawGrid,
69     kEntryPoint,
70     kPatchParameterVertices,
71     kPipelineProperties,
72     kProbe,
73     kProbeSSBO,
74     kBuffer,
75     kRepeat,
76     kSampler,
77     kTLAS,
78     kRayTracing
79   };
80 
81   virtual ~Command();
82 
GetType()83   Command::Type GetType() const { return command_type_; }
84 
IsDrawRect()85   bool IsDrawRect() const { return command_type_ == Type::kDrawRect; }
IsDrawGrid()86   bool IsDrawGrid() const { return command_type_ == Type::kDrawGrid; }
IsDrawArrays()87   bool IsDrawArrays() const { return command_type_ == Type::kDrawArrays; }
IsCompareBuffer()88   bool IsCompareBuffer() const { return command_type_ == Type::kCompareBuffer; }
IsCompute()89   bool IsCompute() const { return command_type_ == Type::kCompute; }
IsRayTracing()90   bool IsRayTracing() const { return command_type_ == Type::kRayTracing; }
IsTLAS()91   bool IsTLAS() const { return command_type_ == Type::kTLAS; }
IsCopy()92   bool IsCopy() const { return command_type_ == Type::kCopy; }
IsProbe()93   bool IsProbe() const { return command_type_ == Type::kProbe; }
IsProbeSSBO()94   bool IsProbeSSBO() const { return command_type_ == Type::kProbeSSBO; }
IsBuffer()95   bool IsBuffer() const { return command_type_ == Type::kBuffer; }
IsClear()96   bool IsClear() const { return command_type_ == Type::kClear; }
IsClearColor()97   bool IsClearColor() const { return command_type_ == Type::kClearColor; }
IsClearDepth()98   bool IsClearDepth() const { return command_type_ == Type::kClearDepth; }
IsClearStencil()99   bool IsClearStencil() const { return command_type_ == Type::kClearStencil; }
IsPatchParameterVertices()100   bool IsPatchParameterVertices() const {
101     return command_type_ == Type::kPatchParameterVertices;
102   }
IsEntryPoint()103   bool IsEntryPoint() const { return command_type_ == Type::kEntryPoint; }
IsRepeat()104   bool IsRepeat() { return command_type_ == Type::kRepeat; }
105 
106   ClearCommand* AsClear();
107   ClearColorCommand* AsClearColor();
108   ClearDepthCommand* AsClearDepth();
109   ClearStencilCommand* AsClearStencil();
110   CompareBufferCommand* AsCompareBuffer();
111   ComputeCommand* AsCompute();
112   RayTracingCommand* AsRayTracing();
113   CopyCommand* AsCopy();
114   DrawArraysCommand* AsDrawArrays();
115   DrawRectCommand* AsDrawRect();
116   DrawGridCommand* AsDrawGrid();
117   EntryPointCommand* AsEntryPoint();
118   PatchParameterVerticesCommand* AsPatchParameterVertices();
119   ProbeCommand* AsProbe();
120   ProbeSSBOCommand* AsProbeSSBO();
121   BufferCommand* AsBuffer();
122   RepeatCommand* AsRepeat();
123 
124   virtual std::string ToString() const = 0;
125 
126   /// Sets the input file line number this command is declared on.
SetLine(size_t line)127   void SetLine(size_t line) { line_ = line; }
128   /// Returns the input file line this command was declared on.
GetLine()129   size_t GetLine() const { return line_; }
130 
131  protected:
132   explicit Command(Type type);
133 
134   Type command_type_;
135   size_t line_ = 1;
136 };
137 
138 /// Base class for commands which contain a pipeline.
139 class PipelineCommand : public Command {
140  public:
141   ~PipelineCommand() override;
142 
GetPipeline()143   Pipeline* GetPipeline() const { return pipeline_; }
144 
SetTimedExecution()145   void SetTimedExecution() { timed_execution_ = true; }
IsTimedExecution()146   bool IsTimedExecution() const { return timed_execution_; }
147 
148  protected:
149   PipelineCommand(Type type, Pipeline* pipeline);
150 
151   Pipeline* pipeline_ = nullptr;
152   bool timed_execution_ = false;
153 };
154 
155 /// Command to draw a rectangle on screen.
156 class DrawRectCommand : public PipelineCommand {
157  public:
158   DrawRectCommand(Pipeline* pipeline, PipelineData data);
159   ~DrawRectCommand() override;
160 
GetPipelineData()161   const PipelineData* GetPipelineData() const { return &data_; }
162 
EnableOrtho()163   void EnableOrtho() { is_ortho_ = true; }
IsOrtho()164   bool IsOrtho() const { return is_ortho_; }
165 
EnablePatch()166   void EnablePatch() { is_patch_ = true; }
IsPatch()167   bool IsPatch() const { return is_patch_; }
168 
SetX(float x)169   void SetX(float x) { x_ = x; }
GetX()170   float GetX() const { return x_; }
171 
SetY(float y)172   void SetY(float y) { y_ = y; }
GetY()173   float GetY() const { return y_; }
174 
SetWidth(float w)175   void SetWidth(float w) { width_ = w; }
GetWidth()176   float GetWidth() const { return width_; }
177 
SetHeight(float h)178   void SetHeight(float h) { height_ = h; }
GetHeight()179   float GetHeight() const { return height_; }
180 
ToString()181   std::string ToString() const override { return "DrawRectCommand"; }
182 
183  private:
184   PipelineData data_;
185   bool is_ortho_ = false;
186   bool is_patch_ = false;
187   float x_ = 0.0;
188   float y_ = 0.0;
189   float width_ = 0.0;
190   float height_ = 0.0;
191 };
192 
193 /// Command to draw a grid of recrangles on screen.
194 class DrawGridCommand : public PipelineCommand {
195  public:
196   DrawGridCommand(Pipeline* pipeline, PipelineData data);
197   ~DrawGridCommand() override;
198 
GetPipelineData()199   const PipelineData* GetPipelineData() const { return &data_; }
200 
SetX(float x)201   void SetX(float x) { x_ = x; }
GetX()202   float GetX() const { return x_; }
203 
SetY(float y)204   void SetY(float y) { y_ = y; }
GetY()205   float GetY() const { return y_; }
206 
SetWidth(float w)207   void SetWidth(float w) { width_ = w; }
GetWidth()208   float GetWidth() const { return width_; }
209 
SetHeight(float h)210   void SetHeight(float h) { height_ = h; }
GetHeight()211   float GetHeight() const { return height_; }
212 
SetColumns(uint32_t c)213   void SetColumns(uint32_t c) { columns_ = c; }
GetColumns()214   uint32_t GetColumns() const { return columns_; }
215 
SetRows(uint32_t r)216   void SetRows(uint32_t r) { rows_ = r; }
GetRows()217   uint32_t GetRows() const { return rows_; }
218 
ToString()219   std::string ToString() const override { return "DrawGridCommand"; }
220 
221  private:
222   PipelineData data_;
223   float x_ = 0.0;
224   float y_ = 0.0;
225   float width_ = 0.0;
226   float height_ = 0.0;
227   uint32_t columns_ = 0;
228   uint32_t rows_ = 0;
229 };
230 
231 /// Command to draw from a vertex and index buffer.
232 class DrawArraysCommand : public PipelineCommand {
233  public:
234   DrawArraysCommand(Pipeline* pipeline, PipelineData data);
235   ~DrawArraysCommand() override;
236 
GetPipelineData()237   const PipelineData* GetPipelineData() const { return &data_; }
238 
EnableIndexed()239   void EnableIndexed() { is_indexed_ = true; }
IsIndexed()240   bool IsIndexed() const { return is_indexed_; }
241 
SetTopology(Topology topo)242   void SetTopology(Topology topo) { topology_ = topo; }
GetTopology()243   Topology GetTopology() const { return topology_; }
244 
SetFirstVertexIndex(uint32_t idx)245   void SetFirstVertexIndex(uint32_t idx) { first_vertex_index_ = idx; }
GetFirstVertexIndex()246   uint32_t GetFirstVertexIndex() const { return first_vertex_index_; }
247 
SetVertexCount(uint32_t count)248   void SetVertexCount(uint32_t count) { vertex_count_ = count; }
GetVertexCount()249   uint32_t GetVertexCount() const { return vertex_count_; }
250 
SetFirstInstance(uint32_t idx)251   void SetFirstInstance(uint32_t idx) { first_instance_ = idx; }
GetFirstInstance()252   uint32_t GetFirstInstance() const { return first_instance_; }
253 
SetInstanceCount(uint32_t count)254   void SetInstanceCount(uint32_t count) { instance_count_ = count; }
GetInstanceCount()255   uint32_t GetInstanceCount() const { return instance_count_; }
256 
ToString()257   std::string ToString() const override { return "DrawArraysCommand"; }
258 
259  private:
260   PipelineData data_;
261   bool is_indexed_ = false;
262   Topology topology_ = Topology::kUnknown;
263   uint32_t first_vertex_index_ = 0;
264   uint32_t vertex_count_ = 0;
265   uint32_t first_instance_ = 0;
266   uint32_t instance_count_ = 1;
267 };
268 
269 /// A command to compare two buffers.
270 class CompareBufferCommand : public Command {
271  public:
272   enum class Comparator { kEq, kRmse, kHistogramEmd };
273 
274   CompareBufferCommand(Buffer* buffer_1, Buffer* buffer_2);
275   ~CompareBufferCommand() override;
276 
GetBuffer1()277   Buffer* GetBuffer1() const { return buffer_1_; }
GetBuffer2()278   Buffer* GetBuffer2() const { return buffer_2_; }
279 
SetComparator(Comparator type)280   void SetComparator(Comparator type) { comparator_ = type; }
GetComparator()281   Comparator GetComparator() const { return comparator_; }
282 
SetTolerance(float tolerance)283   void SetTolerance(float tolerance) { tolerance_ = tolerance; }
GetTolerance()284   float GetTolerance() const { return tolerance_; }
285 
ToString()286   std::string ToString() const override { return "CompareBufferCommand"; }
287 
288  private:
289   Buffer* buffer_1_;
290   Buffer* buffer_2_;
291   float tolerance_ = 0.0;
292   Comparator comparator_ = Comparator::kEq;
293 };
294 
295 /// Command to execute a compute command.
296 class ComputeCommand : public PipelineCommand {
297  public:
298   explicit ComputeCommand(Pipeline* pipeline);
299   ~ComputeCommand() override;
300 
SetX(uint32_t x)301   void SetX(uint32_t x) { x_ = x; }
GetX()302   uint32_t GetX() const { return x_; }
303 
SetY(uint32_t y)304   void SetY(uint32_t y) { y_ = y; }
GetY()305   uint32_t GetY() const { return y_; }
306 
SetZ(uint32_t z)307   void SetZ(uint32_t z) { z_ = z; }
GetZ()308   uint32_t GetZ() const { return z_; }
309 
ToString()310   std::string ToString() const override { return "ComputeCommand"; }
311 
312  private:
313   uint32_t x_ = 0;
314   uint32_t y_ = 0;
315   uint32_t z_ = 0;
316 };
317 
318 /// Command to copy data from one buffer to another.
319 class CopyCommand : public Command {
320  public:
321   CopyCommand(Buffer* buffer_from, Buffer* buffer_to);
322   ~CopyCommand() override;
323 
GetBufferFrom()324   Buffer* GetBufferFrom() const { return buffer_from_; }
GetBufferTo()325   Buffer* GetBufferTo() const { return buffer_to_; }
326 
ToString()327   std::string ToString() const override { return "CopyCommand"; }
328 
329  private:
330   Buffer* buffer_from_;
331   Buffer* buffer_to_;
332 };
333 
334 /// Base class for probe commands.
335 class Probe : public Command {
336  public:
337   /// Wrapper around tolerance information for the probe.
338   struct Tolerance {
ToleranceTolerance339     Tolerance(bool percent, double val) : is_percent(percent), value(val) {}
340 
341     bool is_percent = false;
342     double value = 0.0;
343   };
344 
345   ~Probe() override;
346 
GetBuffer()347   Buffer* GetBuffer() const { return buffer_; }
348 
HasTolerances()349   bool HasTolerances() const { return !tolerances_.empty(); }
SetTolerances(const std::vector<Tolerance> & t)350   void SetTolerances(const std::vector<Tolerance>& t) { tolerances_ = t; }
GetTolerances()351   const std::vector<Tolerance>& GetTolerances() const { return tolerances_; }
352 
353  protected:
354   Probe(Type type, Buffer* buffer);
355 
356  private:
357   Buffer* buffer_;
358   std::vector<Tolerance> tolerances_;
359 };
360 
361 /// Command to probe an image buffer.
362 class ProbeCommand : public Probe {
363  public:
364   explicit ProbeCommand(Buffer* buffer);
365   ~ProbeCommand() override;
366 
SetWholeWindow()367   void SetWholeWindow() { is_whole_window_ = true; }
IsWholeWindow()368   bool IsWholeWindow() const { return is_whole_window_; }
369 
SetProbeRect()370   void SetProbeRect() { is_probe_rect_ = true; }
IsProbeRect()371   bool IsProbeRect() const { return is_probe_rect_; }
372 
SetRelative()373   void SetRelative() { is_relative_ = true; }
IsRelative()374   bool IsRelative() const { return is_relative_; }
375 
SetIsRGBA()376   void SetIsRGBA() { color_format_ = ColorFormat::kRGBA; }
IsRGBA()377   bool IsRGBA() const { return color_format_ == ColorFormat::kRGBA; }
378 
SetX(float x)379   void SetX(float x) { x_ = x; }
GetX()380   float GetX() const { return x_; }
381 
SetY(float y)382   void SetY(float y) { y_ = y; }
GetY()383   float GetY() const { return y_; }
384 
SetWidth(float w)385   void SetWidth(float w) { width_ = w; }
GetWidth()386   float GetWidth() const { return width_; }
387 
SetHeight(float h)388   void SetHeight(float h) { height_ = h; }
GetHeight()389   float GetHeight() const { return height_; }
390 
391   // Colours are stored in the range 0.0 - 1.0
SetR(float r)392   void SetR(float r) { r_ = r; }
GetR()393   float GetR() const { return r_; }
394 
SetG(float g)395   void SetG(float g) { g_ = g; }
GetG()396   float GetG() const { return g_; }
397 
SetB(float b)398   void SetB(float b) { b_ = b; }
GetB()399   float GetB() const { return b_; }
400 
SetA(float a)401   void SetA(float a) { a_ = a; }
GetA()402   float GetA() const { return a_; }
403 
ToString()404   std::string ToString() const override { return "ProbeCommand"; }
405 
406  private:
407   enum class ColorFormat {
408     kRGB = 0,
409     kRGBA,
410   };
411 
412   bool is_whole_window_ = false;
413   bool is_probe_rect_ = false;
414   bool is_relative_ = false;
415   ColorFormat color_format_ = ColorFormat::kRGB;
416 
417   float x_ = 0.0;
418   float y_ = 0.0;
419   float width_ = 1.0;
420   float height_ = 1.0;
421 
422   float r_ = 0.0;
423   float g_ = 0.0;
424   float b_ = 0.0;
425   float a_ = 0.0;
426 };
427 
428 /// Command to probe a data buffer.
429 class ProbeSSBOCommand : public Probe {
430  public:
431   enum class Comparator {
432     kEqual,
433     kNotEqual,
434     kFuzzyEqual,
435     kLess,
436     kLessOrEqual,
437     kGreater,
438     kGreaterOrEqual
439   };
440 
441   explicit ProbeSSBOCommand(Buffer* buffer);
442   ~ProbeSSBOCommand() override;
443 
SetComparator(Comparator comp)444   void SetComparator(Comparator comp) { comparator_ = comp; }
GetComparator()445   Comparator GetComparator() const { return comparator_; }
446 
SetDescriptorSet(uint32_t id)447   void SetDescriptorSet(uint32_t id) { descriptor_set_id_ = id; }
GetDescriptorSet()448   uint32_t GetDescriptorSet() const { return descriptor_set_id_; }
449 
SetBinding(uint32_t id)450   void SetBinding(uint32_t id) { binding_num_ = id; }
GetBinding()451   uint32_t GetBinding() const { return binding_num_; }
452 
SetOffset(uint32_t offset)453   void SetOffset(uint32_t offset) { offset_ = offset; }
GetOffset()454   uint32_t GetOffset() const { return offset_; }
455 
SetFormat(Format * fmt)456   void SetFormat(Format* fmt) { format_ = fmt; }
GetFormat()457   Format* GetFormat() const { return format_; }
458 
SetValues(std::vector<Value> && values)459   void SetValues(std::vector<Value>&& values) { values_ = std::move(values); }
GetValues()460   const std::vector<Value>& GetValues() const { return values_; }
461 
ToString()462   std::string ToString() const override { return "ProbeSSBOCommand"; }
463 
464  private:
465   Comparator comparator_ = Comparator::kEqual;
466   uint32_t descriptor_set_id_ = 0;
467   uint32_t binding_num_ = 0;
468   uint32_t offset_ = 0;
469   Format* format_;
470   std::vector<Value> values_;
471 };
472 
473 /// Base class for BufferCommand and SamplerCommand to handle binding.
474 class BindableResourceCommand : public PipelineCommand {
475  public:
476   BindableResourceCommand(Type type, Pipeline* pipeline);
477   ~BindableResourceCommand() override;
478 
SetDescriptorSet(uint32_t set)479   void SetDescriptorSet(uint32_t set) { descriptor_set_ = set; }
GetDescriptorSet()480   uint32_t GetDescriptorSet() const { return descriptor_set_; }
481 
SetBinding(uint32_t num)482   void SetBinding(uint32_t num) { binding_num_ = num; }
GetBinding()483   uint32_t GetBinding() const { return binding_num_; }
484 
485  private:
486   uint32_t descriptor_set_ = 0;
487   uint32_t binding_num_ = 0;
488 };
489 
490 /// Command to set the size of a buffer, or update a buffers contents.
491 class BufferCommand : public BindableResourceCommand {
492  public:
493   enum class BufferType {
494     kSSBO,
495     kSSBODynamic,
496     kUniform,
497     kUniformDynamic,
498     kPushConstant,
499     kStorageImage,
500     kSampledImage,
501     kCombinedImageSampler,
502     kUniformTexelBuffer,
503     kStorageTexelBuffer
504   };
505 
506   BufferCommand(BufferType type, Pipeline* pipeline);
507   ~BufferCommand() override;
508 
IsSSBO()509   bool IsSSBO() const { return buffer_type_ == BufferType::kSSBO; }
IsSSBODynamic()510   bool IsSSBODynamic() const {
511     return buffer_type_ == BufferType::kSSBODynamic;
512   }
IsUniform()513   bool IsUniform() const { return buffer_type_ == BufferType::kUniform; }
IsUniformDynamic()514   bool IsUniformDynamic() const {
515     return buffer_type_ == BufferType::kUniformDynamic;
516   }
IsStorageImage()517   bool IsStorageImage() const {
518     return buffer_type_ == BufferType::kStorageImage;
519   }
IsSampledImage()520   bool IsSampledImage() const {
521     return buffer_type_ == BufferType::kSampledImage;
522   }
IsCombinedImageSampler()523   bool IsCombinedImageSampler() const {
524     return buffer_type_ == BufferType::kCombinedImageSampler;
525   }
IsUniformTexelBuffer()526   bool IsUniformTexelBuffer() const {
527     return buffer_type_ == BufferType::kUniformTexelBuffer;
528   }
IsStorageTexelBuffer()529   bool IsStorageTexelBuffer() const {
530     return buffer_type_ == BufferType::kStorageTexelBuffer;
531   }
IsPushConstant()532   bool IsPushConstant() const {
533     return buffer_type_ == BufferType::kPushConstant;
534   }
535 
SetIsSubdata()536   void SetIsSubdata() { is_subdata_ = true; }
IsSubdata()537   bool IsSubdata() const { return is_subdata_; }
538 
SetOffset(uint32_t offset)539   void SetOffset(uint32_t offset) { offset_ = offset; }
GetOffset()540   uint32_t GetOffset() const { return offset_; }
541 
SetBaseMipLevel(uint32_t base_mip_level)542   void SetBaseMipLevel(uint32_t base_mip_level) {
543     base_mip_level_ = base_mip_level;
544   }
GetBaseMipLevel()545   uint32_t GetBaseMipLevel() const { return base_mip_level_; }
546 
SetDynamicOffset(uint32_t dynamic_offset)547   void SetDynamicOffset(uint32_t dynamic_offset) {
548     dynamic_offset_ = dynamic_offset;
549   }
GetDynamicOffset()550   uint32_t GetDynamicOffset() const { return dynamic_offset_; }
551 
SetDescriptorOffset(uint64_t descriptor_offset)552   void SetDescriptorOffset(uint64_t descriptor_offset) {
553     descriptor_offset_ = descriptor_offset;
554   }
GetDescriptorOffset()555   uint64_t GetDescriptorOffset() const { return descriptor_offset_; }
556 
SetDescriptorRange(uint64_t descriptor_range)557   void SetDescriptorRange(uint64_t descriptor_range) {
558     descriptor_range_ = descriptor_range;
559   }
GetDescriptorRange()560   uint64_t GetDescriptorRange() const { return descriptor_range_; }
561 
SetValues(std::vector<Value> && values)562   void SetValues(std::vector<Value>&& values) { values_ = std::move(values); }
GetValues()563   const std::vector<Value>& GetValues() const { return values_; }
564 
SetBuffer(Buffer * buffer)565   void SetBuffer(Buffer* buffer) { buffer_ = buffer; }
GetBuffer()566   Buffer* GetBuffer() const { return buffer_; }
567 
SetSampler(Sampler * sampler)568   void SetSampler(Sampler* sampler) { sampler_ = sampler; }
GetSampler()569   Sampler* GetSampler() const { return sampler_; }
570 
ToString()571   std::string ToString() const override { return "BufferCommand"; }
572 
573  private:
574   Buffer* buffer_ = nullptr;
575   Sampler* sampler_ = nullptr;
576   BufferType buffer_type_;
577   bool is_subdata_ = false;
578   uint32_t offset_ = 0;
579   uint32_t base_mip_level_ = 0;
580   uint32_t dynamic_offset_ = 0;
581   uint64_t descriptor_offset_ = 0;
582   uint64_t descriptor_range_ = ~0ULL;
583   std::vector<Value> values_;
584 };
585 
586 /// Command for setting sampler parameters and binding.
587 class SamplerCommand : public BindableResourceCommand {
588  public:
589   explicit SamplerCommand(Pipeline* pipeline);
590   ~SamplerCommand() override;
591 
SetSampler(Sampler * sampler)592   void SetSampler(Sampler* sampler) { sampler_ = sampler; }
GetSampler()593   Sampler* GetSampler() const { return sampler_; }
594 
ToString()595   std::string ToString() const override { return "SamplerCommand"; }
596 
597  private:
598   Sampler* sampler_ = nullptr;
599 };
600 
601 /// Command to clear the colour attachments.
602 class ClearCommand : public PipelineCommand {
603  public:
604   explicit ClearCommand(Pipeline* pipeline);
605   ~ClearCommand() override;
606 
ToString()607   std::string ToString() const override { return "ClearCommand"; }
608 };
609 
610 /// Command to set the colour for the clear command.
611 class ClearColorCommand : public PipelineCommand {
612  public:
613   explicit ClearColorCommand(Pipeline* pipeline);
614   ~ClearColorCommand() override;
615 
616   // Colours are stored in the range 0.0 - 1.0
SetR(float r)617   void SetR(float r) { r_ = r; }
GetR()618   float GetR() const { return r_; }
619 
SetG(float g)620   void SetG(float g) { g_ = g; }
GetG()621   float GetG() const { return g_; }
622 
SetB(float b)623   void SetB(float b) { b_ = b; }
GetB()624   float GetB() const { return b_; }
625 
SetA(float a)626   void SetA(float a) { a_ = a; }
GetA()627   float GetA() const { return a_; }
628 
ToString()629   std::string ToString() const override { return "ClearColorCommand"; }
630 
631  private:
632   float r_ = 0.0;
633   float g_ = 0.0;
634   float b_ = 0.0;
635   float a_ = 0.0;
636 };
637 
638 /// Command to set the depth value for the clear command.
639 class ClearDepthCommand : public PipelineCommand {
640  public:
641   explicit ClearDepthCommand(Pipeline* pipeline);
642   ~ClearDepthCommand() override;
643 
SetValue(float val)644   void SetValue(float val) { value_ = val; }
GetValue()645   float GetValue() const { return value_; }
646 
ToString()647   std::string ToString() const override { return "ClearDepthCommand"; }
648 
649  private:
650   float value_ = 0.0;
651 };
652 
653 /// Command to set the stencil value for the clear command.
654 class ClearStencilCommand : public PipelineCommand {
655  public:
656   explicit ClearStencilCommand(Pipeline* pipeline);
657   ~ClearStencilCommand() override;
658 
SetValue(uint32_t val)659   void SetValue(uint32_t val) { value_ = val; }
GetValue()660   uint32_t GetValue() const { return value_; }
661 
ToString()662   std::string ToString() const override { return "ClearStencilCommand"; }
663 
664  private:
665   uint32_t value_ = 0;
666 };
667 
668 /// Command to set the patch parameter vertices.
669 class PatchParameterVerticesCommand : public PipelineCommand {
670  public:
671   explicit PatchParameterVerticesCommand(Pipeline* pipeline);
672   ~PatchParameterVerticesCommand() override;
673 
SetControlPointCount(uint32_t count)674   void SetControlPointCount(uint32_t count) { control_point_count_ = count; }
GetControlPointCount()675   uint32_t GetControlPointCount() const { return control_point_count_; }
676 
ToString()677   std::string ToString() const override {
678     return "PatchParameterVerticesCommand";
679   }
680 
681  private:
682   uint32_t control_point_count_ = 0;
683 };
684 
685 /// Command to set the entry point to use for a given shader type.
686 class EntryPointCommand : public PipelineCommand {
687  public:
688   explicit EntryPointCommand(Pipeline* pipeline);
689   ~EntryPointCommand() override;
690 
SetShaderType(ShaderType type)691   void SetShaderType(ShaderType type) { shader_type_ = type; }
GetShaderType()692   ShaderType GetShaderType() const { return shader_type_; }
693 
SetEntryPointName(const std::string & name)694   void SetEntryPointName(const std::string& name) { entry_point_name_ = name; }
GetEntryPointName()695   std::string GetEntryPointName() const { return entry_point_name_; }
696 
ToString()697   std::string ToString() const override { return "EntryPointCommand"; }
698 
699  private:
700   ShaderType shader_type_ = kShaderTypeVertex;
701   std::string entry_point_name_;
702 };
703 
704 /// Command to repeat the given set of commands a number of times.
705 class RepeatCommand : public Command {
706  public:
707   explicit RepeatCommand(uint32_t count);
708   ~RepeatCommand() override;
709 
GetCount()710   uint32_t GetCount() const { return count_; }
711 
SetCommands(std::vector<std::unique_ptr<Command>> cmds)712   void SetCommands(std::vector<std::unique_ptr<Command>> cmds) {
713     commands_ = std::move(cmds);
714   }
715 
GetCommands()716   const std::vector<std::unique_ptr<Command>>& GetCommands() const {
717     return commands_;
718   }
719 
ToString()720   std::string ToString() const override { return "RepeatCommand"; }
721 
722  private:
723   uint32_t count_ = 0;
724   std::vector<std::unique_ptr<Command>> commands_;
725 };
726 
727 /// Command for setting TLAS parameters and binding.
728 class TLASCommand : public BindableResourceCommand {
729  public:
730   explicit TLASCommand(Pipeline* pipeline);
731   ~TLASCommand() override;
732 
SetTLAS(TLAS * tlas)733   void SetTLAS(TLAS* tlas) { tlas_ = tlas; }
GetTLAS()734   TLAS* GetTLAS() const { return tlas_; }
735 
ToString()736   std::string ToString() const override { return "TLASCommand"; }
737 
738  private:
739   TLAS* tlas_ = nullptr;
740 };
741 
742 /// Command to execute a ray tracing command.
743 class RayTracingCommand : public PipelineCommand {
744  public:
745   explicit RayTracingCommand(Pipeline* pipeline);
746   ~RayTracingCommand() override;
747 
SetX(uint32_t x)748   void SetX(uint32_t x) { x_ = x; }
GetX()749   uint32_t GetX() const { return x_; }
750 
SetY(uint32_t y)751   void SetY(uint32_t y) { y_ = y; }
GetY()752   uint32_t GetY() const { return y_; }
753 
SetZ(uint32_t z)754   void SetZ(uint32_t z) { z_ = z; }
GetZ()755   uint32_t GetZ() const { return z_; }
756 
SetRGenSBTName(const std::string & name)757   void SetRGenSBTName(const std::string& name) { rgen_sbt_name_ = name; }
GetRayGenSBTName()758   std::string GetRayGenSBTName() const { return rgen_sbt_name_; }
759 
SetMissSBTName(const std::string & name)760   void SetMissSBTName(const std::string& name) { miss_sbt_name_ = name; }
GetMissSBTName()761   std::string GetMissSBTName() const { return miss_sbt_name_; }
762 
SetHitsSBTName(const std::string & name)763   void SetHitsSBTName(const std::string& name) { hits_sbt_name_ = name; }
GetHitsSBTName()764   std::string GetHitsSBTName() const { return hits_sbt_name_; }
765 
SetCallSBTName(const std::string & name)766   void SetCallSBTName(const std::string& name) { call_sbt_name_ = name; }
GetCallSBTName()767   std::string GetCallSBTName() const { return call_sbt_name_; }
768 
ToString()769   std::string ToString() const override { return "RayTracingCommand"; }
770 
771  private:
772   uint32_t x_ = 0;
773   uint32_t y_ = 0;
774   uint32_t z_ = 0;
775   std::string rgen_sbt_name_;
776   std::string miss_sbt_name_;
777   std::string hits_sbt_name_;
778   std::string call_sbt_name_;
779 };
780 
781 }  // namespace amber
782 
783 #endif  // SRC_COMMAND_H_
784