• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2018 The Amber Authors.
2 // Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 #include "src/executor.h"
17 
18 #include <cassert>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "src/engine.h"
24 #include "src/make_unique.h"
25 #include "src/script.h"
26 #include "src/shader_compiler.h"
27 
28 namespace amber {
29 
// The constructor is compiler-generated; it is defined out of line here
// rather than in the header.
30 Executor::Executor() = default;
31 
// The destructor is likewise compiler-generated and defined out of line.
32 Executor::~Executor() = default;
33 
CompileShaders(const amber::Script * script,const ShaderMap & shader_map,Options * options)34 Result Executor::CompileShaders(const amber::Script* script,
35                                 const ShaderMap& shader_map,
36                                 Options* options) {
37   for (auto& pipeline : script->GetPipelines()) {
38     for (auto& shader_info : pipeline->GetShaders()) {
39       std::string target_env = shader_info.GetShader()->GetTargetEnv();
40       if (target_env.empty())
41         target_env = script->GetSpvTargetEnv();
42 
43       ShaderCompiler sc(target_env, options->disable_spirv_validation,
44                         script->GetVirtualFiles());
45 
46       Result r;
47       std::vector<uint32_t> data;
48       std::tie(r, data) = sc.Compile(pipeline.get(), &shader_info, shader_map);
49       if (!r.IsSuccess())
50         return r;
51 
52       shader_info.SetData(std::move(data));
53     }
54   }
55   return {};
56 }
57 
Execute(Engine * engine,const amber::Script * script,const ShaderMap & shader_map,Options * options,Delegate * delegate)58 Result Executor::Execute(Engine* engine,
59                          const amber::Script* script,
60                          const ShaderMap& shader_map,
61                          Options* options,
62                          Delegate* delegate) {
63   engine->SetEngineData(script->GetEngineData());
64 
65   if (!script->GetPipelines().empty()) {
66     Result r = CompileShaders(script, shader_map, options);
67     if (!r.IsSuccess())
68       return r;
69 
70     // OpenCL specific pipeline updates.
71     for (auto& pipeline : script->GetPipelines()) {
72       r = pipeline->UpdateOpenCLBufferBindings();
73       if (!r.IsSuccess())
74         return r;
75       r = pipeline->GenerateOpenCLPodBuffers();
76       if (!r.IsSuccess())
77         return r;
78       r = pipeline->GenerateOpenCLLiteralSamplers();
79       if (!r.IsSuccess())
80         return r;
81       r = pipeline->GenerateOpenCLPushConstants();
82       if (!r.IsSuccess())
83         return r;
84     }
85 
86     for (auto& pipeline : script->GetPipelines()) {
87       r = engine->CreatePipeline(pipeline.get());
88       if (!r.IsSuccess())
89         return r;
90     }
91   }
92 
93   if (options->execution_type == ExecutionType::kPipelineCreateOnly)
94     return {};
95 
96   // Process Commands
97   for (const auto& cmd : script->GetCommands()) {
98     if (delegate && delegate->LogExecuteCalls()) {
99       delegate->Log(std::to_string(cmd->GetLine()) + ": " + cmd->ToString());
100     }
101 
102     Result r = ExecuteCommand(engine, cmd.get());
103     if (!r.IsSuccess())
104       return r;
105   }
106   return {};
107 }
108 
ExecuteCommand(Engine * engine,Command * cmd)109 Result Executor::ExecuteCommand(Engine* engine, Command* cmd) {
110   if (cmd->IsProbe()) {
111     auto* buffer = cmd->AsProbe()->GetBuffer();
112     assert(buffer);
113 
114     Format* fmt = buffer->GetFormat();
115     return verifier_.Probe(cmd->AsProbe(), fmt, buffer->GetElementStride(),
116                            buffer->GetRowStride(), buffer->GetWidth(),
117                            buffer->GetHeight(), buffer->ValuePtr()->data());
118   }
119   if (cmd->IsProbeSSBO()) {
120     auto probe_ssbo = cmd->AsProbeSSBO();
121 
122     const auto* buffer = cmd->AsProbe()->GetBuffer();
123     assert(buffer);
124 
125     return verifier_.ProbeSSBO(probe_ssbo, buffer->ElementCount(),
126                                buffer->ValuePtr()->data());
127   }
128   if (cmd->IsClear())
129     return engine->DoClear(cmd->AsClear());
130   if (cmd->IsClearColor())
131     return engine->DoClearColor(cmd->AsClearColor());
132   if (cmd->IsClearDepth())
133     return engine->DoClearDepth(cmd->AsClearDepth());
134   if (cmd->IsClearStencil())
135     return engine->DoClearStencil(cmd->AsClearStencil());
136   if (cmd->IsCompareBuffer()) {
137     auto compare = cmd->AsCompareBuffer();
138     auto buffer_1 = compare->GetBuffer1();
139     auto buffer_2 = compare->GetBuffer2();
140     switch (compare->GetComparator()) {
141       case CompareBufferCommand::Comparator::kRmse:
142         return buffer_1->CompareRMSE(buffer_2, compare->GetTolerance());
143       case CompareBufferCommand::Comparator::kHistogramEmd:
144         return buffer_1->CompareHistogramEMD(buffer_2, compare->GetTolerance());
145       case CompareBufferCommand::Comparator::kEq:
146         return buffer_1->IsEqual(buffer_2);
147     }
148   }
149   if (cmd->IsCopy()) {
150     auto copy = cmd->AsCopy();
151     auto buffer_from = copy->GetBufferFrom();
152     auto buffer_to = copy->GetBufferTo();
153     return buffer_from->CopyTo(buffer_to);
154   }
155   if (cmd->IsDrawRect())
156     return engine->DoDrawRect(cmd->AsDrawRect());
157   if (cmd->IsDrawGrid())
158     return engine->DoDrawGrid(cmd->AsDrawGrid());
159   if (cmd->IsDrawArrays())
160     return engine->DoDrawArrays(cmd->AsDrawArrays());
161   if (cmd->IsCompute())
162     return engine->DoCompute(cmd->AsCompute());
163   if (cmd->IsRayTracing())
164     return engine->DoTraceRays(cmd->AsRayTracing());
165   if (cmd->IsEntryPoint())
166     return engine->DoEntryPoint(cmd->AsEntryPoint());
167   if (cmd->IsPatchParameterVertices())
168     return engine->DoPatchParameterVertices(cmd->AsPatchParameterVertices());
169   if (cmd->IsBuffer())
170     return engine->DoBuffer(cmd->AsBuffer());
171   if (cmd->IsRepeat()) {
172     for (uint32_t i = 0; i < cmd->AsRepeat()->GetCount(); ++i) {
173       for (const auto& sub_cmd : cmd->AsRepeat()->GetCommands()) {
174         Result r = ExecuteCommand(engine, sub_cmd.get());
175         if (!r.IsSuccess())
176           return r;
177       }
178     }
179     return {};
180   }
181   return Result("Unknown command type: " +
182                 std::to_string(static_cast<uint32_t>(cmd->GetType())));
183 }
184 
185 }  // namespace amber
186