1 // Copyright 2024 The Amber Authors. 2 // Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved. 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 #ifndef SRC_ACCELERATION_STRUCTURE_H_ 17 #define SRC_ACCELERATION_STRUCTURE_H_ 18 19 #include <cstdint> 20 #include <memory> 21 #include <string> 22 #include <utility> 23 #include <vector> 24 25 #include "amber/amber.h" 26 #include "amber/result.h" 27 #include "amber/value.h" 28 #include "src/format.h" 29 #include "src/image.h" 30 31 namespace amber { 32 33 enum class GeometryType : int8_t { 34 kUnknown = 0, 35 kTriangle, 36 kAABB, 37 }; 38 39 class Shader; 40 41 class Geometry { 42 public: 43 Geometry(); 44 ~Geometry(); 45 SetType(GeometryType type)46 void SetType(GeometryType type) { type_ = type; } GetType()47 GeometryType GetType() { return type_; } 48 SetData(std::vector<float> & data)49 void SetData(std::vector<float>& data) { data_.swap(data); } GetData()50 std::vector<float>& GetData() { return data_; } 51 SetFlags(uint32_t flags)52 void SetFlags(uint32_t flags) { flags_ = flags; } GetFlags()53 uint32_t GetFlags() { return flags_; } 54 getVertexCount()55 size_t getVertexCount() const { 56 return data_.size() / 3; // Three floats to define vertex 57 } 58 getPrimitiveCount()59 size_t getPrimitiveCount() const { 60 return IsTriangle() ? (getVertexCount() / 3) // 3 vertices per triangle 61 : IsAABB() ? (getVertexCount() / 2) // 2 vertices per AABB 62 : 0; 63 } 64 IsTriangle()65 bool IsTriangle() const { return type_ == GeometryType::kTriangle; } IsAABB()66 bool IsAABB() const { return type_ == GeometryType::kAABB; } 67 68 private: 69 GeometryType type_ = GeometryType::kUnknown; 70 std::vector<float> data_; 71 uint32_t flags_ = 0u; 72 }; 73 74 class BLAS { 75 public: 76 BLAS(); 77 ~BLAS(); 78 SetName(const std::string & name)79 void SetName(const std::string& name) { name_ = name; } GetName()80 std::string GetName() const { return name_; } 81 AddGeometry(std::unique_ptr<Geometry> * geometry)82 void AddGeometry(std::unique_ptr<Geometry>* geometry) { 83 geometry_.push_back(std::move(*geometry)); 84 } GetGeometrySize()85 size_t GetGeometrySize() { return geometry_.size(); } GetGeometries()86 std::vector<std::unique_ptr<Geometry>>& GetGeometries() { return geometry_; } 87 88 private: 89 std::string name_; 90 std::vector<std::unique_ptr<Geometry>> geometry_; 91 }; 92 93 class BLASInstance { 94 public: BLASInstance()95 BLASInstance() 96 : used_blas_name_(), 97 used_blas_(nullptr), 98 transform_(0), 99 instance_custom_index_(0), 100 mask_(0xFF), 101 instanceShaderBindingTableRecordOffset_(0), 102 flags_(0) {} 103 ~BLASInstance(); 104 SetUsedBLAS(const std::string & name,BLAS * blas)105 void SetUsedBLAS(const std::string& name, BLAS* blas) { 106 used_blas_name_ = name; 107 used_blas_ = blas; 108 } GetUsedBLASName()109 std::string GetUsedBLASName() const { return used_blas_name_; } GetUsedBLAS()110 BLAS* GetUsedBLAS() const { return used_blas_; } 111 SetTransform(const std::vector<float> & transform)112 void SetTransform(const std::vector<float>& transform) { 113 transform_ = transform; 114 } GetTransform()115 const float* GetTransform() const { return transform_.data(); } 116 SetInstanceIndex(uint32_t instance_custom_index)117 void SetInstanceIndex(uint32_t instance_custom_index) { 118 instance_custom_index_ = instance_custom_index; 119 // Make sure argument was not cut off 120 assert(instance_custom_index_ == instance_custom_index); 121 } GetInstanceIndex()122 uint32_t GetInstanceIndex() const { return instance_custom_index_; } 123 SetMask(uint32_t mask)124 void SetMask(uint32_t mask) { 125 mask_ = mask; 126 // Make sure argument was not cut off 127 assert(mask_ == mask); 128 } GetMask()129 uint32_t GetMask() const { return mask_; } 130 SetOffset(uint32_t offset)131 void SetOffset(uint32_t offset) { 132 instanceShaderBindingTableRecordOffset_ = offset; 133 // Make sure argument was not cut off 134 assert(instanceShaderBindingTableRecordOffset_ == offset); 135 } GetOffset()136 uint32_t GetOffset() const { return instanceShaderBindingTableRecordOffset_; } 137 SetFlags(uint32_t flags)138 void SetFlags(uint32_t flags) { 139 flags_ = flags; 140 // Make sure argument was not cut off 141 assert(flags_ == flags); 142 } GetFlags()143 uint32_t GetFlags() const { return flags_; } 144 145 private: 146 std::string used_blas_name_; 147 BLAS* used_blas_; 148 std::vector<float> transform_; 149 uint32_t instance_custom_index_ : 24; 150 uint32_t mask_ : 8; 151 uint32_t instanceShaderBindingTableRecordOffset_ : 24; 152 uint32_t flags_ : 8; 153 }; 154 155 class TLAS { 156 public: 157 TLAS(); 158 ~TLAS(); 159 SetName(const std::string & name)160 void SetName(const std::string& name) { name_ = name; } GetName()161 std::string GetName() const { return name_; } 162 AddInstance(std::unique_ptr<BLASInstance> instance)163 void AddInstance(std::unique_ptr<BLASInstance> instance) { 164 blas_instances_.push_back( 165 std::unique_ptr<BLASInstance>(instance.release())); 166 } GetInstanceSize()167 size_t GetInstanceSize() { return blas_instances_.size(); } GetInstances()168 std::vector<std::unique_ptr<BLASInstance>>& GetInstances() { 169 return blas_instances_; 170 } 171 172 private: 173 std::string name_; 174 std::vector<std::unique_ptr<BLASInstance>> blas_instances_; 175 }; 176 177 class ShaderGroup { 178 public: 179 ShaderGroup(); 180 ~ShaderGroup(); 181 SetName(const std::string & name)182 void SetName(const std::string& name) { name_ = name; } GetName()183 std::string GetName() const { return name_; } 184 SetGeneralShader(Shader * shader)185 void SetGeneralShader(Shader* shader) { generalShader_ = shader; } GetGeneralShader()186 Shader* GetGeneralShader() const { return generalShader_; } 187 SetClosestHitShader(Shader * shader)188 void SetClosestHitShader(Shader* shader) { closestHitShader_ = shader; } GetClosestHitShader()189 Shader* GetClosestHitShader() const { return closestHitShader_; } 190 SetAnyHitShader(Shader * shader)191 void SetAnyHitShader(Shader* shader) { anyHitShader_ = shader; } GetAnyHitShader()192 Shader* GetAnyHitShader() const { return anyHitShader_; } 193 SetIntersectionShader(Shader * shader)194 void SetIntersectionShader(Shader* shader) { intersectionShader_ = shader; } GetIntersectionShader()195 Shader* GetIntersectionShader() const { return intersectionShader_; } 196 IsGeneralGroup()197 bool IsGeneralGroup() const { return generalShader_ != nullptr; } IsHitGroup()198 bool IsHitGroup() const { 199 return closestHitShader_ != nullptr || anyHitShader_ != nullptr || 200 intersectionShader_ != nullptr; 201 } GetShaderByType(ShaderType type)202 Shader* GetShaderByType(ShaderType type) const { 203 switch (type) { 204 case kShaderTypeRayGeneration: 205 case kShaderTypeMiss: 206 case kShaderTypeCall: 207 return generalShader_; 208 case kShaderTypeAnyHit: 209 return anyHitShader_; 210 case kShaderTypeClosestHit: 211 return closestHitShader_; 212 case kShaderTypeIntersection: 213 return intersectionShader_; 214 default: 215 assert(0 && "Unsupported shader type"); 216 return nullptr; 217 } 218 } 219 220 private: 221 std::string name_; 222 Shader* generalShader_; 223 Shader* closestHitShader_; 224 Shader* anyHitShader_; 225 Shader* intersectionShader_; 226 }; 227 228 class SBTRecord { 229 public: 230 SBTRecord(); 231 ~SBTRecord(); 232 SetUsedShaderGroupName(const std::string & shader_group_name)233 void SetUsedShaderGroupName(const std::string& shader_group_name) { 234 used_shader_group_name_ = shader_group_name; 235 } GetUsedShaderGroupName()236 std::string GetUsedShaderGroupName() const { return used_shader_group_name_; } 237 SetCount(const uint32_t count)238 void SetCount(const uint32_t count) { count_ = count; } GetCount()239 uint32_t GetCount() const { return count_; } 240 SetIndex(const uint32_t index)241 void SetIndex(const uint32_t index) { index_ = index; } GetIndex()242 uint32_t GetIndex() const { return index_; } 243 244 private: 245 std::string used_shader_group_name_; 246 uint32_t count_ = 1; 247 uint32_t index_ = static_cast<uint32_t>(-1); 248 }; 249 250 class SBT { 251 public: 252 SBT(); 253 ~SBT(); 254 SetName(const std::string & name)255 void SetName(const std::string& name) { name_ = name; } GetName()256 std::string GetName() const { return name_; } 257 AddSBTRecord(std::unique_ptr<SBTRecord> record)258 void AddSBTRecord(std::unique_ptr<SBTRecord> record) { 259 records_.push_back(std::move(record)); 260 } GetSBTRecordCount()261 size_t GetSBTRecordCount() { return records_.size(); } GetSBTRecords()262 std::vector<std::unique_ptr<SBTRecord>>& GetSBTRecords() { return records_; } GetSBTSize()263 uint32_t GetSBTSize() { 264 uint32_t size = 0; 265 for (auto& x : records_) 266 size += x->GetCount(); 267 268 return size; 269 } 270 271 private: 272 std::string name_; 273 std::vector<std::unique_ptr<SBTRecord>> records_; 274 }; 275 276 } // namespace amber 277 278 #endif // SRC_ACCELERATION_STRUCTURE_H_ 279