• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2018 The Amber Authors.
2 // Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 #ifndef SRC_SCRIPT_H_
17 #define SRC_SCRIPT_H_
18 
19 #include <algorithm>
20 #include <cstdint>
21 #include <map>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include "amber/recipe.h"
28 #include "amber/result.h"
29 #include "src/acceleration_structure.h"
30 #include "src/buffer.h"
31 #include "src/command.h"
32 #include "src/engine.h"
33 #include "src/format.h"
34 #include "src/pipeline.h"
35 #include "src/sampler.h"
36 #include "src/shader.h"
37 #include "src/virtual_file_store.h"
38 
39 namespace amber {
40 
41 /// Class representing the script to be run against an engine.
42 class Script : public RecipeImpl {
43  public:
44   Script();
45   ~Script() override;
46 
47   bool IsKnownFeature(const std::string& name) const;
48   bool IsKnownProperty(const std::string& name) const;
49 
50   /// Retrieves information on the shaders in the given script.
51   std::vector<ShaderInfo> GetShaderInfo() const override;
52 
53   /// Returns required features in the given recipe.
GetRequiredFeatures()54   std::vector<std::string> GetRequiredFeatures() const override {
55     return engine_info_.required_features;
56   }
57 
GetRequiredProperties()58   std::vector<std::string> GetRequiredProperties() const override {
59     return engine_info_.required_properties;
60   }
61 
62   /// Returns required device extensions in the given recipe.
GetRequiredDeviceExtensions()63   std::vector<std::string> GetRequiredDeviceExtensions() const override {
64     return engine_info_.required_device_extensions;
65   }
66 
67   /// Returns required instance extensions in the given recipe.
GetRequiredInstanceExtensions()68   std::vector<std::string> GetRequiredInstanceExtensions() const override {
69     return engine_info_.required_instance_extensions;
70   }
71 
72   /// Sets the fence timeout to |timeout_ms|.
SetFenceTimeout(uint32_t timeout_ms)73   void SetFenceTimeout(uint32_t timeout_ms) override {
74     engine_data_.fence_timeout_ms = timeout_ms;
75   }
76 
77   /// Sets or clears runtime layer bit to |enabled|.
SetPipelineRuntimeLayerEnabled(bool enabled)78   void SetPipelineRuntimeLayerEnabled(bool enabled) override {
79     engine_data_.pipeline_runtime_layer_enabled = enabled;
80   }
81 
82   /// Adds |pipeline| to the list of known pipelines. The |pipeline| must have
83   /// a unique name over all pipelines in the script.
AddPipeline(std::unique_ptr<Pipeline> pipeline)84   Result AddPipeline(std::unique_ptr<Pipeline> pipeline) {
85     if (name_to_pipeline_.count(pipeline->GetName()) > 0)
86       return Result("duplicate pipeline name provided");
87 
88     pipelines_.push_back(std::move(pipeline));
89     name_to_pipeline_[pipelines_.back()->GetName()] = pipelines_.back().get();
90     return {};
91   }
92 
93   /// Retrieves the pipeline with |name|, |nullptr| if not found.
GetPipeline(const std::string & name)94   Pipeline* GetPipeline(const std::string& name) const {
95     auto it = name_to_pipeline_.find(name);
96     return it == name_to_pipeline_.end() ? nullptr : it->second;
97   }
98 
99   /// Retrieves a list of all pipelines.
GetPipelines()100   const std::vector<std::unique_ptr<Pipeline>>& GetPipelines() const {
101     return pipelines_;
102   }
103 
104   /// Adds |shader| to the list of known shaders. The |shader| must have a
105   /// unique name over all shaders in the script.
AddShader(std::unique_ptr<Shader> shader)106   Result AddShader(std::unique_ptr<Shader> shader) {
107     if (name_to_shader_.count(shader->GetName()) > 0)
108       return Result("duplicate shader name provided");
109 
110     shaders_.push_back(std::move(shader));
111     name_to_shader_[shaders_.back()->GetName()] = shaders_.back().get();
112     return {};
113   }
114 
115   /// Retrieves the shader with |name|, |nullptr| if not found.
GetShader(const std::string & name)116   Shader* GetShader(const std::string& name) const {
117     auto it = name_to_shader_.find(name);
118     return it == name_to_shader_.end() ? nullptr : it->second;
119   }
120 
121   /// Retrieves a list of all shaders.
GetShaders()122   const std::vector<std::unique_ptr<Shader>>& GetShaders() const {
123     return shaders_;
124   }
125 
126   /// Search |pipeline| and all included into pipeline libraries whether shader
127   /// with |name| is present in pipeline groups. Returns shader if found,
128   /// |nullptr| if not found.
FindShader(const Pipeline * pipeline,Shader * shader)129   Shader* FindShader(const Pipeline* pipeline, Shader* shader) const {
130     if (shader) {
131       for (auto group : pipeline->GetShaderGroups()) {
132         Shader* test_shader = group->GetShaderByType(shader->GetType());
133         if (test_shader == shader)
134           return shader;
135       }
136 
137       for (auto lib : pipeline->GetPipelineLibraries()) {
138         shader = FindShader(lib, shader);
139         if (shader)
140           return shader;
141       }
142     }
143 
144     return nullptr;
145   }
146 
147   /// Search |pipeline| and all included into pipeline libraries whether shader
148   /// group with |name| is present. Returns shader group if found, |nullptr|
149   /// if not found. |index| is an shader group index in pipeline or library.
FindShaderGroup(const Pipeline * pipeline,const std::string & name,uint32_t * index)150   ShaderGroup* FindShaderGroup(const Pipeline* pipeline,
151                                const std::string& name,
152                                uint32_t* index) const {
153     ShaderGroup* result = nullptr;
154     uint32_t shader_group_index = pipeline->GetShaderGroupIndex(name);
155     if (shader_group_index != static_cast<uint32_t>(-1)) {
156       (*index) += shader_group_index;
157       result = pipeline->GetShaderGroupByIndex(shader_group_index);
158       return result;
159     } else {
160       (*index) += static_cast<uint32_t>(pipeline->GetShaderGroups().size());
161     }
162 
163     for (auto lib : pipeline->GetPipelineLibraries()) {
164       result = FindShaderGroup(lib, name, index);
165       if (result)
166         return result;
167     }
168 
169     *index = static_cast<uint32_t>(-1);
170 
171     return nullptr;
172   }
173 
174   /// Adds |buffer| to the list of known buffers. The |buffer| must have a
175   /// unique name over all buffers in the script.
AddBuffer(std::unique_ptr<Buffer> buffer)176   Result AddBuffer(std::unique_ptr<Buffer> buffer) {
177     if (name_to_buffer_.count(buffer->GetName()) > 0)
178       return Result("duplicate buffer name provided");
179 
180     buffers_.push_back(std::move(buffer));
181     name_to_buffer_[buffers_.back()->GetName()] = buffers_.back().get();
182     return {};
183   }
184 
185   /// Retrieves the buffer with |name|, |nullptr| if not found.
GetBuffer(const std::string & name)186   Buffer* GetBuffer(const std::string& name) const {
187     auto it = name_to_buffer_.find(name);
188     return it == name_to_buffer_.end() ? nullptr : it->second;
189   }
190 
191   /// Retrieves a list of all buffers.
GetBuffers()192   const std::vector<std::unique_ptr<Buffer>>& GetBuffers() const {
193     return buffers_;
194   }
195 
196   /// Adds |sampler| to the list of known sampler. The |sampler| must have a
197   /// unique name over all samplers in the script.
AddSampler(std::unique_ptr<Sampler> sampler)198   Result AddSampler(std::unique_ptr<Sampler> sampler) {
199     if (name_to_sampler_.count(sampler->GetName()) > 0)
200       return Result("duplicate sampler name provided");
201 
202     samplers_.push_back(std::move(sampler));
203     name_to_sampler_[samplers_.back()->GetName()] = samplers_.back().get();
204     return {};
205   }
206 
207   /// Retrieves the sampler with |name|, |nullptr| if not found.
GetSampler(const std::string & name)208   Sampler* GetSampler(const std::string& name) const {
209     auto it = name_to_sampler_.find(name);
210     return it == name_to_sampler_.end() ? nullptr : it->second;
211   }
212 
213   /// Retrieves a list of all samplers.
GetSamplers()214   const std::vector<std::unique_ptr<Sampler>>& GetSamplers() const {
215     return samplers_;
216   }
217 
218   /// Adds |blas| to the list of known bottom level acceleration structures.
219   /// The |blas| must have a unique name over all BLASes in the script.
AddBLAS(std::unique_ptr<BLAS> blas)220   Result AddBLAS(std::unique_ptr<BLAS> blas) {
221     if (name_to_blas_.count(blas->GetName()) > 0)
222       return Result("duplicate BLAS name provided");
223 
224     blases_.push_back(std::move(blas));
225     name_to_blas_[blases_.back()->GetName()] = blases_.back().get();
226 
227     return {};
228   }
229 
230   /// Retrieves the BLAS with |name|, |nullptr| if not found.
GetBLAS(const std::string & name)231   BLAS* GetBLAS(const std::string& name) const {
232     auto it = name_to_blas_.find(name);
233     return it == name_to_blas_.end() ? nullptr : it->second;
234   }
235 
236   /// Retrieves a list of all BLASes.
GetBLASes()237   const std::vector<std::unique_ptr<BLAS>>& GetBLASes() const {
238     return blases_;
239   }
240 
241   /// Adds |tlas| to the list of known top level acceleration structures.
242   /// The |tlas| must have a unique name over all TLASes in the script.
AddTLAS(std::unique_ptr<TLAS> tlas)243   Result AddTLAS(std::unique_ptr<TLAS> tlas) {
244     if (name_to_tlas_.count(tlas->GetName()) > 0)
245       return Result("duplicate TLAS name provided");
246 
247     tlases_.push_back(std::move(tlas));
248     name_to_tlas_[tlases_.back()->GetName()] = tlases_.back().get();
249 
250     return {};
251   }
252 
253   /// Retrieves the TLAS with |name|, |nullptr| if not found.
GetTLAS(const std::string & name)254   TLAS* GetTLAS(const std::string& name) const {
255     auto it = name_to_tlas_.find(name);
256     return it == name_to_tlas_.end() ? nullptr : it->second;
257   }
258 
259   /// Retrieves a list of all TLASes.
GetTLASes()260   const std::vector<std::unique_ptr<TLAS>>& GetTLASes() const {
261     return tlases_;
262   }
263 
264   /// Adds |feature| to the list of features that must be supported by the
265   /// engine.
AddRequiredFeature(const std::string & feature)266   void AddRequiredFeature(const std::string& feature) {
267     engine_info_.required_features.push_back(feature);
268   }
269 
270   /// Adds |prop| to the list of properties that must be supported by the
271   /// engine.
AddRequiredProperty(const std::string & prop)272   void AddRequiredProperty(const std::string& prop) {
273     engine_info_.required_properties.push_back(prop);
274   }
275 
276   /// Checks if |feature| is in required features
IsRequiredFeature(const std::string & feature)277   bool IsRequiredFeature(const std::string& feature) const {
278     return std::find(engine_info_.required_features.begin(),
279                      engine_info_.required_features.end(),
280                      feature) != engine_info_.required_features.end();
281   }
282 
283   /// Checks if |prop| is in required features
IsRequiredProperty(const std::string & prop)284   bool IsRequiredProperty(const std::string& prop) const {
285     return std::find(engine_info_.required_properties.begin(),
286                      engine_info_.required_properties.end(),
287                      prop) != engine_info_.required_properties.end();
288   }
289 
290   /// Adds |ext| to the list of device extensions that must be supported.
AddRequiredDeviceExtension(const std::string & ext)291   void AddRequiredDeviceExtension(const std::string& ext) {
292     engine_info_.required_device_extensions.push_back(ext);
293   }
294 
295   /// Adds |ext| to the list of instance extensions that must be supported.
AddRequiredInstanceExtension(const std::string & ext)296   void AddRequiredInstanceExtension(const std::string& ext) {
297     engine_info_.required_instance_extensions.push_back(ext);
298   }
299 
300   /// Adds |ext| to the list of extensions that must be supported by the engine.
301   /// Note, this should only be used by the VkScript engine where there is no
302   /// differentiation between the types of extensions.
303   void AddRequiredExtension(const std::string& ext);
304 
305   /// Retrieves the engine configuration data for this script.
GetEngineData()306   EngineData& GetEngineData() { return engine_data_; }
307   /// Retrieves the engine configuration data for this script.
GetEngineData()308   const EngineData& GetEngineData() const { return engine_data_; }
309 
310   /// Sets |cmds| to the list of commands to execute against the engine.
SetCommands(std::vector<std::unique_ptr<Command>> cmds)311   void SetCommands(std::vector<std::unique_ptr<Command>> cmds) {
312     commands_ = std::move(cmds);
313   }
314 
315   /// Retrieves the list of commands to execute against the engine.
GetCommands()316   const std::vector<std::unique_ptr<Command>>& GetCommands() const {
317     return commands_;
318   }
319 
320   /// Sets the SPIR-V target environment.
SetSpvTargetEnv(const std::string & env)321   void SetSpvTargetEnv(const std::string& env) { spv_env_ = env; }
322   /// Retrieves the SPIR-V target environment.
GetSpvTargetEnv()323   const std::string& GetSpvTargetEnv() const { return spv_env_; }
324 
325   /// Assign ownership of the format to the script.
RegisterFormat(std::unique_ptr<Format> fmt)326   Format* RegisterFormat(std::unique_ptr<Format> fmt) {
327     formats_.push_back(std::move(fmt));
328     return formats_.back().get();
329   }
330 
331   /// Assigns ownership of the type to the script.
RegisterType(std::unique_ptr<type::Type> type)332   type::Type* RegisterType(std::unique_ptr<type::Type> type) {
333     types_.push_back(std::move(type));
334     return types_.back().get();
335   }
336 
337   /// Adds |type| to the list of known types. The |type| must have
338   /// a unique name over all types in the script.
AddType(const std::string & name,std::unique_ptr<type::Type> type)339   Result AddType(const std::string& name, std::unique_ptr<type::Type> type) {
340     if (name_to_type_.count(name) > 0)
341       return Result("duplicate type name provided");
342 
343     name_to_type_[name] = std::move(type);
344     return {};
345   }
346 
347   /// Retrieves the type with |name|, |nullptr| if not found.
GetType(const std::string & name)348   type::Type* GetType(const std::string& name) const {
349     auto it = name_to_type_.find(name);
350     return it == name_to_type_.end() ? nullptr : it->second.get();
351   }
352 
353   // Returns the virtual file store.
GetVirtualFiles()354   VirtualFileStore* GetVirtualFiles() const { return virtual_files_.get(); }
355 
356   /// Adds the virtual file with content |content| to the virtual file path
357   /// |path|. If there's already a virtual file with the given path, an error is
358   /// returned.
AddVirtualFile(const std::string & path,const std::string & content)359   Result AddVirtualFile(const std::string& path, const std::string& content) {
360     return virtual_files_->Add(path, content);
361   }
362 
363   /// Look up the virtual file by path. If the file was found, the content is
364   /// assigned to content.
GetVirtualFile(const std::string & path,std::string * content)365   Result GetVirtualFile(const std::string& path, std::string* content) const {
366     return virtual_files_->Get(path, content);
367   }
368 
369   type::Type* ParseType(const std::string& str);
370 
371  private:
372   struct {
373     std::vector<std::string> required_features;
374     std::vector<std::string> required_properties;
375     std::vector<std::string> required_device_extensions;
376     std::vector<std::string> required_instance_extensions;
377   } engine_info_;
378 
379   EngineData engine_data_;
380   std::string spv_env_;
381   std::map<std::string, Shader*> name_to_shader_;
382   std::map<std::string, Buffer*> name_to_buffer_;
383   std::map<std::string, Sampler*> name_to_sampler_;
384   std::map<std::string, Pipeline*> name_to_pipeline_;
385   std::map<std::string, BLAS*> name_to_blas_;
386   std::map<std::string, TLAS*> name_to_tlas_;
387   std::map<std::string, std::unique_ptr<type::Type>> name_to_type_;
388   std::vector<std::unique_ptr<Shader>> shaders_;
389   std::vector<std::unique_ptr<Command>> commands_;
390   std::vector<std::unique_ptr<Buffer>> buffers_;
391   std::vector<std::unique_ptr<Sampler>> samplers_;
392   std::vector<std::unique_ptr<Pipeline>> pipelines_;
393   std::vector<std::unique_ptr<BLAS>> blases_;
394   std::vector<std::unique_ptr<TLAS>> tlases_;
395   std::vector<std::unique_ptr<type::Type>> types_;
396   std::vector<std::unique_ptr<Format>> formats_;
397   std::unique_ptr<VirtualFileStore> virtual_files_;
398 };
399 
400 }  // namespace amber
401 
402 #endif  // SRC_SCRIPT_H_
403