• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 The Amber Authors.
2 // Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 #ifndef SRC_ACCELERATION_STRUCTURE_H_
17 #define SRC_ACCELERATION_STRUCTURE_H_
18 
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "amber/amber.h"
#include "amber/result.h"
#include "amber/value.h"
#include "src/format.h"
#include "src/image.h"
30 
31 namespace amber {
32 
/// Kind of primitive data held by a Geometry.
enum class GeometryType : int8_t {
  kUnknown = 0,
  kTriangle,
  kAABB,
};

// Only pointers to Shader are used in this header (see ShaderGroup).
class Shader;

/// One geometry entry of a bottom-level acceleration structure: a primitive
/// type, a flat array of vertex coordinates (three floats per vertex) and a
/// flags word.
class Geometry {
 public:
  Geometry();
  ~Geometry();

  void SetType(GeometryType type) { type_ = type; }
  GeometryType GetType() const { return type_; }

  /// Swaps |data| into this geometry. On return |data| holds whatever this
  /// geometry contained before the call.
  void SetData(std::vector<float>& data) { data_.swap(data); }
  std::vector<float>& GetData() { return data_; }

  void SetFlags(uint32_t flags) { flags_ = flags; }
  uint32_t GetFlags() const { return flags_; }

  /// Number of complete vertices in the data; each vertex is three floats.
  /// Trailing floats that do not form a full vertex are ignored.
  size_t getVertexCount() const {
    return data_.size() / 3;
  }

  /// Number of complete primitives: 3 vertices per triangle, 2 vertices per
  /// AABB, and 0 for any other geometry type. Leftover vertices that do not
  /// form a full primitive are ignored.
  size_t getPrimitiveCount() const {
    switch (type_) {
      case GeometryType::kTriangle:
        return getVertexCount() / 3;
      case GeometryType::kAABB:
        return getVertexCount() / 2;
      case GeometryType::kUnknown:
        break;
    }
    return 0;
  }

  bool IsTriangle() const { return type_ == GeometryType::kTriangle; }
  bool IsAABB() const { return type_ == GeometryType::kAABB; }

 private:
  GeometryType type_ = GeometryType::kUnknown;
  std::vector<float> data_;
  uint32_t flags_ = 0u;
};
73 
74 class BLAS {
75  public:
76   BLAS();
77   ~BLAS();
78 
SetName(const std::string & name)79   void SetName(const std::string& name) { name_ = name; }
GetName()80   std::string GetName() const { return name_; }
81 
AddGeometry(std::unique_ptr<Geometry> * geometry)82   void AddGeometry(std::unique_ptr<Geometry>* geometry) {
83     geometry_.push_back(std::move(*geometry));
84   }
GetGeometrySize()85   size_t GetGeometrySize() { return geometry_.size(); }
GetGeometries()86   std::vector<std::unique_ptr<Geometry>>& GetGeometries() { return geometry_; }
87 
88  private:
89   std::string name_;
90   std::vector<std::unique_ptr<Geometry>> geometry_;
91 };
92 
93 class BLASInstance {
94  public:
BLASInstance()95   BLASInstance()
96       : used_blas_name_(),
97         used_blas_(nullptr),
98         transform_(0),
99         instance_custom_index_(0),
100         mask_(0xFF),
101         instanceShaderBindingTableRecordOffset_(0),
102         flags_(0) {}
103   ~BLASInstance();
104 
SetUsedBLAS(const std::string & name,BLAS * blas)105   void SetUsedBLAS(const std::string& name, BLAS* blas) {
106     used_blas_name_ = name;
107     used_blas_ = blas;
108   }
GetUsedBLASName()109   std::string GetUsedBLASName() const { return used_blas_name_; }
GetUsedBLAS()110   BLAS* GetUsedBLAS() const { return used_blas_; }
111 
SetTransform(const std::vector<float> & transform)112   void SetTransform(const std::vector<float>& transform) {
113     transform_ = transform;
114   }
GetTransform()115   const float* GetTransform() const { return transform_.data(); }
116 
SetInstanceIndex(uint32_t instance_custom_index)117   void SetInstanceIndex(uint32_t instance_custom_index) {
118     instance_custom_index_ = instance_custom_index;
119     // Make sure argument was not cut off
120     assert(instance_custom_index_ == instance_custom_index);
121   }
GetInstanceIndex()122   uint32_t GetInstanceIndex() const { return instance_custom_index_; }
123 
SetMask(uint32_t mask)124   void SetMask(uint32_t mask) {
125     mask_ = mask;
126     // Make sure argument was not cut off
127     assert(mask_ == mask);
128   }
GetMask()129   uint32_t GetMask() const { return mask_; }
130 
SetOffset(uint32_t offset)131   void SetOffset(uint32_t offset) {
132     instanceShaderBindingTableRecordOffset_ = offset;
133     // Make sure argument was not cut off
134     assert(instanceShaderBindingTableRecordOffset_ == offset);
135   }
GetOffset()136   uint32_t GetOffset() const { return instanceShaderBindingTableRecordOffset_; }
137 
SetFlags(uint32_t flags)138   void SetFlags(uint32_t flags) {
139     flags_ = flags;
140     // Make sure argument was not cut off
141     assert(flags_ == flags);
142   }
GetFlags()143   uint32_t GetFlags() const { return flags_; }
144 
145  private:
146   std::string used_blas_name_;
147   BLAS* used_blas_;
148   std::vector<float> transform_;
149   uint32_t instance_custom_index_ : 24;
150   uint32_t mask_ : 8;
151   uint32_t instanceShaderBindingTableRecordOffset_ : 24;
152   uint32_t flags_ : 8;
153 };
154 
155 class TLAS {
156  public:
157   TLAS();
158   ~TLAS();
159 
SetName(const std::string & name)160   void SetName(const std::string& name) { name_ = name; }
GetName()161   std::string GetName() const { return name_; }
162 
AddInstance(std::unique_ptr<BLASInstance> instance)163   void AddInstance(std::unique_ptr<BLASInstance> instance) {
164     blas_instances_.push_back(
165         std::unique_ptr<BLASInstance>(instance.release()));
166   }
GetInstanceSize()167   size_t GetInstanceSize() { return blas_instances_.size(); }
GetInstances()168   std::vector<std::unique_ptr<BLASInstance>>& GetInstances() {
169     return blas_instances_;
170   }
171 
172  private:
173   std::string name_;
174   std::vector<std::unique_ptr<BLASInstance>> blas_instances_;
175 };
176 
177 class ShaderGroup {
178  public:
179   ShaderGroup();
180   ~ShaderGroup();
181 
SetName(const std::string & name)182   void SetName(const std::string& name) { name_ = name; }
GetName()183   std::string GetName() const { return name_; }
184 
SetGeneralShader(Shader * shader)185   void SetGeneralShader(Shader* shader) { generalShader_ = shader; }
GetGeneralShader()186   Shader* GetGeneralShader() const { return generalShader_; }
187 
SetClosestHitShader(Shader * shader)188   void SetClosestHitShader(Shader* shader) { closestHitShader_ = shader; }
GetClosestHitShader()189   Shader* GetClosestHitShader() const { return closestHitShader_; }
190 
SetAnyHitShader(Shader * shader)191   void SetAnyHitShader(Shader* shader) { anyHitShader_ = shader; }
GetAnyHitShader()192   Shader* GetAnyHitShader() const { return anyHitShader_; }
193 
SetIntersectionShader(Shader * shader)194   void SetIntersectionShader(Shader* shader) { intersectionShader_ = shader; }
GetIntersectionShader()195   Shader* GetIntersectionShader() const { return intersectionShader_; }
196 
IsGeneralGroup()197   bool IsGeneralGroup() const { return generalShader_ != nullptr; }
IsHitGroup()198   bool IsHitGroup() const {
199     return closestHitShader_ != nullptr || anyHitShader_ != nullptr ||
200            intersectionShader_ != nullptr;
201   }
GetShaderByType(ShaderType type)202   Shader* GetShaderByType(ShaderType type) const {
203     switch (type) {
204       case kShaderTypeRayGeneration:
205       case kShaderTypeMiss:
206       case kShaderTypeCall:
207         return generalShader_;
208       case kShaderTypeAnyHit:
209         return anyHitShader_;
210       case kShaderTypeClosestHit:
211         return closestHitShader_;
212       case kShaderTypeIntersection:
213         return intersectionShader_;
214       default:
215         assert(0 && "Unsupported shader type");
216         return nullptr;
217     }
218   }
219 
220  private:
221   std::string name_;
222   Shader* generalShader_;
223   Shader* closestHitShader_;
224   Shader* anyHitShader_;
225   Shader* intersectionShader_;
226 };
227 
/// A single shader binding table record that names the shader group it uses.
///
/// |count_| starts at 1 and is the per-record value that SBT::GetSBTSize()
/// adds up. |index_| starts at 0xFFFFFFFF, which looks like an "unassigned"
/// sentinel (TODO: confirm against the code that calls SetIndex()).
class SBTRecord {
 public:
  SBTRecord();
  ~SBTRecord();

  void SetUsedShaderGroupName(const std::string& shader_group_name) {
    used_shader_group_name_.assign(shader_group_name);
  }
  std::string GetUsedShaderGroupName() const {
    return used_shader_group_name_;
  }

  void SetCount(const uint32_t count) {
    count_ = count;
  }
  uint32_t GetCount() const {
    return count_;
  }

  void SetIndex(const uint32_t index) {
    index_ = index;
  }
  uint32_t GetIndex() const {
    return index_;
  }

 private:
  std::string used_shader_group_name_;
  uint32_t count_{1};
  uint32_t index_{UINT32_MAX};
};
249 
250 class SBT {
251  public:
252   SBT();
253   ~SBT();
254 
SetName(const std::string & name)255   void SetName(const std::string& name) { name_ = name; }
GetName()256   std::string GetName() const { return name_; }
257 
AddSBTRecord(std::unique_ptr<SBTRecord> record)258   void AddSBTRecord(std::unique_ptr<SBTRecord> record) {
259     records_.push_back(std::move(record));
260   }
GetSBTRecordCount()261   size_t GetSBTRecordCount() { return records_.size(); }
GetSBTRecords()262   std::vector<std::unique_ptr<SBTRecord>>& GetSBTRecords() { return records_; }
GetSBTSize()263   uint32_t GetSBTSize() {
264     uint32_t size = 0;
265     for (auto& x : records_)
266       size += x->GetCount();
267 
268     return size;
269   }
270 
271  private:
272   std::string name_;
273   std::vector<std::unique_ptr<SBTRecord>> records_;
274 };
275 
276 }  // namespace amber
277 
278 #endif  // SRC_ACCELERATION_STRUCTURE_H_
279