1 // Copyright 2018 The Amber Authors. 2 // Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved. 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 #ifndef SRC_SCRIPT_H_ 17 #define SRC_SCRIPT_H_ 18 19 #include <algorithm> 20 #include <cstdint> 21 #include <map> 22 #include <memory> 23 #include <string> 24 #include <utility> 25 #include <vector> 26 27 #include "amber/recipe.h" 28 #include "amber/result.h" 29 #include "src/acceleration_structure.h" 30 #include "src/buffer.h" 31 #include "src/command.h" 32 #include "src/engine.h" 33 #include "src/format.h" 34 #include "src/pipeline.h" 35 #include "src/sampler.h" 36 #include "src/shader.h" 37 #include "src/virtual_file_store.h" 38 39 namespace amber { 40 41 /// Class representing the script to be run against an engine. 42 class Script : public RecipeImpl { 43 public: 44 Script(); 45 ~Script() override; 46 47 bool IsKnownFeature(const std::string& name) const; 48 bool IsKnownProperty(const std::string& name) const; 49 50 /// Retrieves information on the shaders in the given script. 51 std::vector<ShaderInfo> GetShaderInfo() const override; 52 53 /// Returns required features in the given recipe. GetRequiredFeatures()54 std::vector<std::string> GetRequiredFeatures() const override { 55 return engine_info_.required_features; 56 } 57 GetRequiredProperties()58 std::vector<std::string> GetRequiredProperties() const override { 59 return engine_info_.required_properties; 60 } 61 62 /// Returns required device extensions in the given recipe. GetRequiredDeviceExtensions()63 std::vector<std::string> GetRequiredDeviceExtensions() const override { 64 return engine_info_.required_device_extensions; 65 } 66 67 /// Returns required instance extensions in the given recipe. GetRequiredInstanceExtensions()68 std::vector<std::string> GetRequiredInstanceExtensions() const override { 69 return engine_info_.required_instance_extensions; 70 } 71 72 /// Sets the fence timeout to |timeout_ms|. SetFenceTimeout(uint32_t timeout_ms)73 void SetFenceTimeout(uint32_t timeout_ms) override { 74 engine_data_.fence_timeout_ms = timeout_ms; 75 } 76 77 /// Sets or clears runtime layer bit to |enabled|. SetPipelineRuntimeLayerEnabled(bool enabled)78 void SetPipelineRuntimeLayerEnabled(bool enabled) override { 79 engine_data_.pipeline_runtime_layer_enabled = enabled; 80 } 81 82 /// Adds |pipeline| to the list of known pipelines. The |pipeline| must have 83 /// a unique name over all pipelines in the script. AddPipeline(std::unique_ptr<Pipeline> pipeline)84 Result AddPipeline(std::unique_ptr<Pipeline> pipeline) { 85 if (name_to_pipeline_.count(pipeline->GetName()) > 0) 86 return Result("duplicate pipeline name provided"); 87 88 pipelines_.push_back(std::move(pipeline)); 89 name_to_pipeline_[pipelines_.back()->GetName()] = pipelines_.back().get(); 90 return {}; 91 } 92 93 /// Retrieves the pipeline with |name|, |nullptr| if not found. GetPipeline(const std::string & name)94 Pipeline* GetPipeline(const std::string& name) const { 95 auto it = name_to_pipeline_.find(name); 96 return it == name_to_pipeline_.end() ? nullptr : it->second; 97 } 98 99 /// Retrieves a list of all pipelines. GetPipelines()100 const std::vector<std::unique_ptr<Pipeline>>& GetPipelines() const { 101 return pipelines_; 102 } 103 104 /// Adds |shader| to the list of known shaders. The |shader| must have a 105 /// unique name over all shaders in the script. AddShader(std::unique_ptr<Shader> shader)106 Result AddShader(std::unique_ptr<Shader> shader) { 107 if (name_to_shader_.count(shader->GetName()) > 0) 108 return Result("duplicate shader name provided"); 109 110 shaders_.push_back(std::move(shader)); 111 name_to_shader_[shaders_.back()->GetName()] = shaders_.back().get(); 112 return {}; 113 } 114 115 /// Retrieves the shader with |name|, |nullptr| if not found. GetShader(const std::string & name)116 Shader* GetShader(const std::string& name) const { 117 auto it = name_to_shader_.find(name); 118 return it == name_to_shader_.end() ? nullptr : it->second; 119 } 120 121 /// Retrieves a list of all shaders. GetShaders()122 const std::vector<std::unique_ptr<Shader>>& GetShaders() const { 123 return shaders_; 124 } 125 126 /// Search |pipeline| and all included into pipeline libraries whether shader 127 /// with |name| is present in pipeline groups. Returns shader if found, 128 /// |nullptr| if not found. FindShader(const Pipeline * pipeline,Shader * shader)129 Shader* FindShader(const Pipeline* pipeline, Shader* shader) const { 130 if (shader) { 131 for (auto group : pipeline->GetShaderGroups()) { 132 Shader* test_shader = group->GetShaderByType(shader->GetType()); 133 if (test_shader == shader) 134 return shader; 135 } 136 137 for (auto lib : pipeline->GetPipelineLibraries()) { 138 shader = FindShader(lib, shader); 139 if (shader) 140 return shader; 141 } 142 } 143 144 return nullptr; 145 } 146 147 /// Search |pipeline| and all included into pipeline libraries whether shader 148 /// group with |name| is present. Returns shader group if found, |nullptr| 149 /// if not found. |index| is an shader group index in pipeline or library. FindShaderGroup(const Pipeline * pipeline,const std::string & name,uint32_t * index)150 ShaderGroup* FindShaderGroup(const Pipeline* pipeline, 151 const std::string& name, 152 uint32_t* index) const { 153 ShaderGroup* result = nullptr; 154 uint32_t shader_group_index = pipeline->GetShaderGroupIndex(name); 155 if (shader_group_index != static_cast<uint32_t>(-1)) { 156 (*index) += shader_group_index; 157 result = pipeline->GetShaderGroupByIndex(shader_group_index); 158 return result; 159 } else { 160 (*index) += static_cast<uint32_t>(pipeline->GetShaderGroups().size()); 161 } 162 163 for (auto lib : pipeline->GetPipelineLibraries()) { 164 result = FindShaderGroup(lib, name, index); 165 if (result) 166 return result; 167 } 168 169 *index = static_cast<uint32_t>(-1); 170 171 return nullptr; 172 } 173 174 /// Adds |buffer| to the list of known buffers. The |buffer| must have a 175 /// unique name over all buffers in the script. AddBuffer(std::unique_ptr<Buffer> buffer)176 Result AddBuffer(std::unique_ptr<Buffer> buffer) { 177 if (name_to_buffer_.count(buffer->GetName()) > 0) 178 return Result("duplicate buffer name provided"); 179 180 buffers_.push_back(std::move(buffer)); 181 name_to_buffer_[buffers_.back()->GetName()] = buffers_.back().get(); 182 return {}; 183 } 184 185 /// Retrieves the buffer with |name|, |nullptr| if not found. GetBuffer(const std::string & name)186 Buffer* GetBuffer(const std::string& name) const { 187 auto it = name_to_buffer_.find(name); 188 return it == name_to_buffer_.end() ? nullptr : it->second; 189 } 190 191 /// Retrieves a list of all buffers. GetBuffers()192 const std::vector<std::unique_ptr<Buffer>>& GetBuffers() const { 193 return buffers_; 194 } 195 196 /// Adds |sampler| to the list of known sampler. The |sampler| must have a 197 /// unique name over all samplers in the script. AddSampler(std::unique_ptr<Sampler> sampler)198 Result AddSampler(std::unique_ptr<Sampler> sampler) { 199 if (name_to_sampler_.count(sampler->GetName()) > 0) 200 return Result("duplicate sampler name provided"); 201 202 samplers_.push_back(std::move(sampler)); 203 name_to_sampler_[samplers_.back()->GetName()] = samplers_.back().get(); 204 return {}; 205 } 206 207 /// Retrieves the sampler with |name|, |nullptr| if not found. GetSampler(const std::string & name)208 Sampler* GetSampler(const std::string& name) const { 209 auto it = name_to_sampler_.find(name); 210 return it == name_to_sampler_.end() ? nullptr : it->second; 211 } 212 213 /// Retrieves a list of all samplers. GetSamplers()214 const std::vector<std::unique_ptr<Sampler>>& GetSamplers() const { 215 return samplers_; 216 } 217 218 /// Adds |blas| to the list of known bottom level acceleration structures. 219 /// The |blas| must have a unique name over all BLASes in the script. AddBLAS(std::unique_ptr<BLAS> blas)220 Result AddBLAS(std::unique_ptr<BLAS> blas) { 221 if (name_to_blas_.count(blas->GetName()) > 0) 222 return Result("duplicate BLAS name provided"); 223 224 blases_.push_back(std::move(blas)); 225 name_to_blas_[blases_.back()->GetName()] = blases_.back().get(); 226 227 return {}; 228 } 229 230 /// Retrieves the BLAS with |name|, |nullptr| if not found. GetBLAS(const std::string & name)231 BLAS* GetBLAS(const std::string& name) const { 232 auto it = name_to_blas_.find(name); 233 return it == name_to_blas_.end() ? nullptr : it->second; 234 } 235 236 /// Retrieves a list of all BLASes. GetBLASes()237 const std::vector<std::unique_ptr<BLAS>>& GetBLASes() const { 238 return blases_; 239 } 240 241 /// Adds |tlas| to the list of known top level acceleration structures. 242 /// The |tlas| must have a unique name over all TLASes in the script. AddTLAS(std::unique_ptr<TLAS> tlas)243 Result AddTLAS(std::unique_ptr<TLAS> tlas) { 244 if (name_to_tlas_.count(tlas->GetName()) > 0) 245 return Result("duplicate TLAS name provided"); 246 247 tlases_.push_back(std::move(tlas)); 248 name_to_tlas_[tlases_.back()->GetName()] = tlases_.back().get(); 249 250 return {}; 251 } 252 253 /// Retrieves the TLAS with |name|, |nullptr| if not found. GetTLAS(const std::string & name)254 TLAS* GetTLAS(const std::string& name) const { 255 auto it = name_to_tlas_.find(name); 256 return it == name_to_tlas_.end() ? nullptr : it->second; 257 } 258 259 /// Retrieves a list of all TLASes. GetTLASes()260 const std::vector<std::unique_ptr<TLAS>>& GetTLASes() const { 261 return tlases_; 262 } 263 264 /// Adds |feature| to the list of features that must be supported by the 265 /// engine. AddRequiredFeature(const std::string & feature)266 void AddRequiredFeature(const std::string& feature) { 267 engine_info_.required_features.push_back(feature); 268 } 269 270 /// Adds |prop| to the list of properties that must be supported by the 271 /// engine. AddRequiredProperty(const std::string & prop)272 void AddRequiredProperty(const std::string& prop) { 273 engine_info_.required_properties.push_back(prop); 274 } 275 276 /// Checks if |feature| is in required features IsRequiredFeature(const std::string & feature)277 bool IsRequiredFeature(const std::string& feature) const { 278 return std::find(engine_info_.required_features.begin(), 279 engine_info_.required_features.end(), 280 feature) != engine_info_.required_features.end(); 281 } 282 283 /// Checks if |prop| is in required features IsRequiredProperty(const std::string & prop)284 bool IsRequiredProperty(const std::string& prop) const { 285 return std::find(engine_info_.required_properties.begin(), 286 engine_info_.required_properties.end(), 287 prop) != engine_info_.required_properties.end(); 288 } 289 290 /// Adds |ext| to the list of device extensions that must be supported. AddRequiredDeviceExtension(const std::string & ext)291 void AddRequiredDeviceExtension(const std::string& ext) { 292 engine_info_.required_device_extensions.push_back(ext); 293 } 294 295 /// Adds |ext| to the list of instance extensions that must be supported. AddRequiredInstanceExtension(const std::string & ext)296 void AddRequiredInstanceExtension(const std::string& ext) { 297 engine_info_.required_instance_extensions.push_back(ext); 298 } 299 300 /// Adds |ext| to the list of extensions that must be supported by the engine. 301 /// Note, this should only be used by the VkScript engine where there is no 302 /// differentiation between the types of extensions. 303 void AddRequiredExtension(const std::string& ext); 304 305 /// Retrieves the engine configuration data for this script. GetEngineData()306 EngineData& GetEngineData() { return engine_data_; } 307 /// Retrieves the engine configuration data for this script. GetEngineData()308 const EngineData& GetEngineData() const { return engine_data_; } 309 310 /// Sets |cmds| to the list of commands to execute against the engine. SetCommands(std::vector<std::unique_ptr<Command>> cmds)311 void SetCommands(std::vector<std::unique_ptr<Command>> cmds) { 312 commands_ = std::move(cmds); 313 } 314 315 /// Retrieves the list of commands to execute against the engine. GetCommands()316 const std::vector<std::unique_ptr<Command>>& GetCommands() const { 317 return commands_; 318 } 319 320 /// Sets the SPIR-V target environment. SetSpvTargetEnv(const std::string & env)321 void SetSpvTargetEnv(const std::string& env) { spv_env_ = env; } 322 /// Retrieves the SPIR-V target environment. GetSpvTargetEnv()323 const std::string& GetSpvTargetEnv() const { return spv_env_; } 324 325 /// Assign ownership of the format to the script. RegisterFormat(std::unique_ptr<Format> fmt)326 Format* RegisterFormat(std::unique_ptr<Format> fmt) { 327 formats_.push_back(std::move(fmt)); 328 return formats_.back().get(); 329 } 330 331 /// Assigns ownership of the type to the script. RegisterType(std::unique_ptr<type::Type> type)332 type::Type* RegisterType(std::unique_ptr<type::Type> type) { 333 types_.push_back(std::move(type)); 334 return types_.back().get(); 335 } 336 337 /// Adds |type| to the list of known types. The |type| must have 338 /// a unique name over all types in the script. AddType(const std::string & name,std::unique_ptr<type::Type> type)339 Result AddType(const std::string& name, std::unique_ptr<type::Type> type) { 340 if (name_to_type_.count(name) > 0) 341 return Result("duplicate type name provided"); 342 343 name_to_type_[name] = std::move(type); 344 return {}; 345 } 346 347 /// Retrieves the type with |name|, |nullptr| if not found. GetType(const std::string & name)348 type::Type* GetType(const std::string& name) const { 349 auto it = name_to_type_.find(name); 350 return it == name_to_type_.end() ? nullptr : it->second.get(); 351 } 352 353 // Returns the virtual file store. GetVirtualFiles()354 VirtualFileStore* GetVirtualFiles() const { return virtual_files_.get(); } 355 356 /// Adds the virtual file with content |content| to the virtual file path 357 /// |path|. If there's already a virtual file with the given path, an error is 358 /// returned. AddVirtualFile(const std::string & path,const std::string & content)359 Result AddVirtualFile(const std::string& path, const std::string& content) { 360 return virtual_files_->Add(path, content); 361 } 362 363 /// Look up the virtual file by path. If the file was found, the content is 364 /// assigned to content. GetVirtualFile(const std::string & path,std::string * content)365 Result GetVirtualFile(const std::string& path, std::string* content) const { 366 return virtual_files_->Get(path, content); 367 } 368 369 type::Type* ParseType(const std::string& str); 370 371 private: 372 struct { 373 std::vector<std::string> required_features; 374 std::vector<std::string> required_properties; 375 std::vector<std::string> required_device_extensions; 376 std::vector<std::string> required_instance_extensions; 377 } engine_info_; 378 379 EngineData engine_data_; 380 std::string spv_env_; 381 std::map<std::string, Shader*> name_to_shader_; 382 std::map<std::string, Buffer*> name_to_buffer_; 383 std::map<std::string, Sampler*> name_to_sampler_; 384 std::map<std::string, Pipeline*> name_to_pipeline_; 385 std::map<std::string, BLAS*> name_to_blas_; 386 std::map<std::string, TLAS*> name_to_tlas_; 387 std::map<std::string, std::unique_ptr<type::Type>> name_to_type_; 388 std::vector<std::unique_ptr<Shader>> shaders_; 389 std::vector<std::unique_ptr<Command>> commands_; 390 std::vector<std::unique_ptr<Buffer>> buffers_; 391 std::vector<std::unique_ptr<Sampler>> samplers_; 392 std::vector<std::unique_ptr<Pipeline>> pipelines_; 393 std::vector<std::unique_ptr<BLAS>> blases_; 394 std::vector<std::unique_ptr<TLAS>> tlases_; 395 std::vector<std::unique_ptr<type::Type>> types_; 396 std::vector<std::unique_ptr<Format>> formats_; 397 std::unique_ptr<VirtualFileStore> virtual_files_; 398 }; 399 400 } // namespace amber 401 402 #endif // SRC_SCRIPT_H_ 403