1 // Copyright 2018 The Amber Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/shader_compiler.h"
16 
#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <cstring>
#include <iterator>
#include <string>
#include <utility>
23 
24 #if AMBER_ENABLE_SPIRV_TOOLS
25 #include "spirv-tools/libspirv.hpp"
26 #include "spirv-tools/linker.hpp"
27 #include "spirv-tools/optimizer.hpp"
28 #endif  // AMBER_ENABLE_SPIRV_TOOLS
29 
30 #if AMBER_ENABLE_SHADERC
31 #pragma clang diagnostic push
32 #pragma clang diagnostic ignored "-Wold-style-cast"
33 #pragma clang diagnostic ignored "-Wshadow-uncaptured-local"
34 #pragma clang diagnostic ignored "-Wweak-vtables"
35 #include "shaderc/shaderc.hpp"
36 #pragma clang diagnostic pop
37 #endif  // AMBER_ENABLE_SHADERC
38 
39 #if AMBER_ENABLE_DXC
40 #include "src/dxc_helper.h"
41 #endif  // AMBER_ENABLE_DXC
42 
43 #if AMBER_ENABLE_CLSPV
44 #include "src/clspv_helper.h"
45 #endif  // AMBER_ENABLE_CLSPV
46 
47 namespace amber {
48 
49 ShaderCompiler::ShaderCompiler() = default;
50 
ShaderCompiler(const std::string & env,bool disable_spirv_validation,VirtualFileStore * virtual_files)51 ShaderCompiler::ShaderCompiler(const std::string& env,
52                                bool disable_spirv_validation,
53                                VirtualFileStore* virtual_files)
54     : spv_env_(env),
55       disable_spirv_validation_(disable_spirv_validation),
56       virtual_files_(virtual_files) {
57   // Do not warn about virtual_files_ not being used.
58   // This is conditionally used based on preprocessor defines.
59   (void)virtual_files_;
60 }
61 
62 ShaderCompiler::~ShaderCompiler() = default;
63 
Compile(Pipeline * pipeline,Pipeline::ShaderInfo * shader_info,const ShaderMap & shader_map) const64 std::pair<Result, std::vector<uint32_t>> ShaderCompiler::Compile(
65     Pipeline* pipeline,
66     Pipeline::ShaderInfo* shader_info,
67     const ShaderMap& shader_map) const {
68   const auto shader = shader_info->GetShader();
69   std::string key = shader->GetName();
70   const std::string pipeline_name = pipeline->GetName();
71   if (pipeline_name != "") {
72     key = pipeline_name + "-" + key;
73   }
74   auto it = shader_map.find(key);
75   if (it != shader_map.end()) {
76 #if AMBER_ENABLE_CLSPV
77     if (shader->GetFormat() == kShaderFormatOpenCLC) {
78       return {Result("OPENCL-C shaders do not support pre-compiled shaders"),
79               {}};
80     }
81 #endif  // AMBER_ENABLE_CLSPV
82     return {{}, it->second};
83   }
84 
85 #if AMBER_ENABLE_SPIRV_TOOLS
86   std::string spv_errors;
87 
88   spv_target_env target_env = SPV_ENV_UNIVERSAL_1_0;
89   if (!spv_env_.empty()) {
90     if (!spvParseTargetEnv(spv_env_.c_str(), &target_env))
91       return {Result("Unable to parse SPIR-V target environment"), {}};
92   }
93 
94   auto msg_consumer = [&spv_errors](spv_message_level_t level, const char*,
95                                     const spv_position_t& position,
96                                     const char* message) {
97     switch (level) {
98       case SPV_MSG_FATAL:
99       case SPV_MSG_INTERNAL_ERROR:
100       case SPV_MSG_ERROR:
101         spv_errors += "error: line " + std::to_string(position.index) + ": " +
102                       message + "\n";
103         break;
104       case SPV_MSG_WARNING:
105         spv_errors += "warning: line " + std::to_string(position.index) + ": " +
106                       message + "\n";
107         break;
108       case SPV_MSG_INFO:
109         spv_errors += "info: line " + std::to_string(position.index) + ": " +
110                       message + "\n";
111         break;
112       case SPV_MSG_DEBUG:
113         break;
114     }
115   };
116 
117   spvtools::SpirvTools tools(target_env);
118   tools.SetMessageConsumer(msg_consumer);
119 #endif  // AMBER_ENABLE_SPIRV_TOOLS
120 
121   std::vector<uint32_t> results;
122 
123   if (shader->GetFormat() == kShaderFormatSpirvHex) {
124     Result r = ParseHex(shader->GetData(), &results);
125     if (!r.IsSuccess())
126       return {Result("Unable to parse shader hex."), {}};
127   } else if (shader->GetFormat() == kShaderFormatSpirvBin) {
128     results.resize(shader->GetData().size() / 4);
129     memcpy(results.data(), shader->GetData().data(), shader->GetData().size());
130 
131 #if AMBER_ENABLE_SHADERC
132   } else if (shader->GetFormat() == kShaderFormatGlsl) {
133     Result r = CompileGlsl(shader, &results);
134     if (!r.IsSuccess())
135       return {r, {}};
136 #endif  // AMBER_ENABLE_SHADERC
137 
138 #if AMBER_ENABLE_DXC
139   } else if (shader->GetFormat() == kShaderFormatHlsl) {
140     Result r = CompileHlsl(shader, &results);
141     if (!r.IsSuccess())
142       return {r, {}};
143 #endif  // AMBER_ENABLE_DXC
144 
145 #if AMBER_ENABLE_SPIRV_TOOLS
146   } else if (shader->GetFormat() == kShaderFormatSpirvAsm) {
147     if (!tools.Assemble(shader->GetData(), &results,
148                         spvtools::SpirvTools::kDefaultAssembleOption)) {
149       return {Result("Shader assembly failed: " + spv_errors), {}};
150     }
151 #endif  // AMBER_ENABLE_SPIRV_TOOLS
152 
153 #if AMBER_ENABLE_CLSPV
154   } else if (shader->GetFormat() == kShaderFormatOpenCLC) {
155     Result r = CompileOpenCLC(shader_info, pipeline, target_env, &results);
156     if (!r.IsSuccess())
157       return {r, {}};
158 #endif  // AMBER_ENABLE_CLSPV
159 
160   } else {
161     return {Result("Invalid shader format"), results};
162   }
163 
164   // Validate the shader, but have an option to disable that.
165   // Always use the data member, to avoid an unused-variable warning
166   // when not using SPIRV-Tools support.
167   if (!disable_spirv_validation_) {
168 #if AMBER_ENABLE_SPIRV_TOOLS
169     spvtools::ValidatorOptions options;
170     if (!tools.Validate(results.data(), results.size(), options))
171       return {Result("Invalid shader: " + spv_errors), {}};
172 #endif  // AMBER_ENABLE_SPIRV_TOOLS
173   }
174 
175 #if AMBER_ENABLE_SPIRV_TOOLS
176   // Optimize the shader if any optimizations were specified.
177   if (!shader_info->GetShaderOptimizations().empty()) {
178     spvtools::Optimizer optimizer(target_env);
179     optimizer.SetMessageConsumer(msg_consumer);
180     if (!optimizer.RegisterPassesFromFlags(
181             shader_info->GetShaderOptimizations())) {
182       return {Result("Invalid optimizations: " + spv_errors), {}};
183     }
184     if (!optimizer.Run(results.data(), results.size(), &results))
185       return {Result("Optimizations failed: " + spv_errors), {}};
186   }
187 #endif  // AMBER_ENABLE_SPIRV_TOOLS
188 
189   return {{}, results};
190 }
191 
ParseHex(const std::string & data,std::vector<uint32_t> * result) const192 Result ShaderCompiler::ParseHex(const std::string& data,
193                                 std::vector<uint32_t>* result) const {
194   size_t used = 0;
195   const char* str = data.c_str();
196   uint8_t converted = 0;
197   uint32_t tmp = 0;
198   while (used < data.length()) {
199     char* new_pos = nullptr;
200     uint64_t v = static_cast<uint64_t>(std::strtol(str, &new_pos, 16));
201 
202     ++converted;
203 
204     // TODO(dsinclair): Is this actually right?
205     tmp = tmp | (static_cast<uint32_t>(v) << (8 * (converted - 1)));
206     if (converted == 4) {
207       result->push_back(tmp);
208       tmp = 0;
209       converted = 0;
210     }
211 
212     used += static_cast<size_t>(new_pos - str);
213     str = new_pos;
214   }
215   return {};
216 }
217 
218 #if AMBER_ENABLE_SHADERC
CompileGlsl(const Shader * shader,std::vector<uint32_t> * result) const219 Result ShaderCompiler::CompileGlsl(const Shader* shader,
220                                    std::vector<uint32_t>* result) const {
221   shaderc::Compiler compiler;
222   shaderc::CompileOptions options;
223 
224   uint32_t env = 0u;
225   uint32_t env_version = 0u;
226   uint32_t spirv_version = 0u;
227   auto r = ParseSpvEnv(spv_env_, &env, &env_version, &spirv_version);
228   if (!r.IsSuccess())
229     return r;
230 
231   options.SetTargetEnvironment(static_cast<shaderc_target_env>(env),
232                                env_version);
233   options.SetTargetSpirv(static_cast<shaderc_spirv_version>(spirv_version));
234 
235   shaderc_shader_kind kind;
236   if (shader->GetType() == kShaderTypeCompute)
237     kind = shaderc_compute_shader;
238   else if (shader->GetType() == kShaderTypeFragment)
239     kind = shaderc_fragment_shader;
240   else if (shader->GetType() == kShaderTypeGeometry)
241     kind = shaderc_geometry_shader;
242   else if (shader->GetType() == kShaderTypeVertex)
243     kind = shaderc_vertex_shader;
244   else if (shader->GetType() == kShaderTypeTessellationControl)
245     kind = shaderc_tess_control_shader;
246   else if (shader->GetType() == kShaderTypeTessellationEvaluation)
247     kind = shaderc_tess_evaluation_shader;
248   else if (shader->GetType() == kShaderTypeRayGeneration)
249     kind = shaderc_raygen_shader;
250   else if (shader->GetType() == kShaderTypeAnyHit)
251     kind = shaderc_anyhit_shader;
252   else if (shader->GetType() == kShaderTypeClosestHit)
253     kind = shaderc_closesthit_shader;
254   else if (shader->GetType() == kShaderTypeMiss)
255     kind = shaderc_miss_shader;
256   else if (shader->GetType() == kShaderTypeIntersection)
257     kind = shaderc_intersection_shader;
258   else if (shader->GetType() == kShaderTypeCall)
259     kind = shaderc_callable_shader;
260   else
261     return Result("Unknown shader type");
262 
263   shaderc::SpvCompilationResult module =
264       compiler.CompileGlslToSpv(shader->GetData(), kind, "-", options);
265 
266   if (module.GetCompilationStatus() != shaderc_compilation_status_success)
267     return Result(module.GetErrorMessage());
268 
269   std::copy(module.cbegin(), module.cend(), std::back_inserter(*result));
270   return {};
271 }
272 #else
CompileGlsl(const Shader *,std::vector<uint32_t> *) const273 Result ShaderCompiler::CompileGlsl(const Shader*,
274                                    std::vector<uint32_t>*) const {
275   return {};
276 }
277 #endif  // AMBER_ENABLE_SHADERC
278 
279 #if AMBER_ENABLE_DXC
CompileHlsl(const Shader * shader,std::vector<uint32_t> * result) const280 Result ShaderCompiler::CompileHlsl(const Shader* shader,
281                                    std::vector<uint32_t>* result) const {
282   std::string target;
283   if (shader->GetType() == kShaderTypeCompute)
284     target = "cs_6_2";
285   else if (shader->GetType() == kShaderTypeFragment)
286     target = "ps_6_2";
287   else if (shader->GetType() == kShaderTypeGeometry)
288     target = "gs_6_2";
289   else if (shader->GetType() == kShaderTypeVertex)
290     target = "vs_6_2";
291   else
292     return Result("Unknown shader type");
293 
294   return dxchelper::Compile(shader->GetData(), "main", target, spv_env_,
295                             shader->GetFilePath(), virtual_files_, result);
296 }
297 #else
CompileHlsl(const Shader *,std::vector<uint32_t> *) const298 Result ShaderCompiler::CompileHlsl(const Shader*,
299                                    std::vector<uint32_t>*) const {
300   return {};
301 }
302 #endif  // AMBER_ENABLE_DXC
303 
304 #if AMBER_ENABLE_CLSPV
CompileOpenCLC(Pipeline::ShaderInfo * shader_info,Pipeline * pipeline,spv_target_env env,std::vector<uint32_t> * result) const305 Result ShaderCompiler::CompileOpenCLC(Pipeline::ShaderInfo* shader_info,
306                                       Pipeline* pipeline,
307                                       spv_target_env env,
308                                       std::vector<uint32_t>* result) const {
309   return clspvhelper::Compile(shader_info, pipeline, env, result);
310 }
311 #endif  // AMBER_ENABLE_CLSPV
312 
namespace {

// Packs a Vulkan API version as the Shaderc API expects it: the major
// number shifted left by 22 bits, the minor number shifted left by 12.
constexpr uint32_t VulkanVersion(uint32_t major, uint32_t minor) {
  return (major << 22) | (minor << 12);
}

// Packs a SPIR-V version as the Shaderc API expects it: the major number
// shifted left by 16 bits, the minor number shifted left by 8.
constexpr uint32_t SpirvVersion(uint32_t major, uint32_t minor) {
  return (major << 16) | (minor << 8);
}

// Shaderc's identifier for the Vulkan target API.
constexpr uint32_t kVulkan = 0;

// Vulkan API versions, encoded for Shaderc.
constexpr uint32_t kVulkan_1_0 = VulkanVersion(1, 0);
constexpr uint32_t kVulkan_1_1 = VulkanVersion(1, 1);
constexpr uint32_t kVulkan_1_2 = VulkanVersion(1, 2);

// SPIR-V versions, encoded for Shaderc.
constexpr uint32_t kSpv_1_0 = SpirvVersion(1, 0);
constexpr uint32_t kSpv_1_1 = SpirvVersion(1, 1);
constexpr uint32_t kSpv_1_2 = SpirvVersion(1, 2);
constexpr uint32_t kSpv_1_3 = SpirvVersion(1, 3);
constexpr uint32_t kSpv_1_4 = SpirvVersion(1, 4);
constexpr uint32_t kSpv_1_5 = SpirvVersion(1, 5);

#if AMBER_ENABLE_SHADERC
// When Shaderc is available, confirm the hand-written encodings match its
// own enum definitions.
static_assert(kVulkan == shaderc_target_env_vulkan,
              "enum vulkan* value mismatch");
static_assert(kVulkan_1_0 == shaderc_env_version_vulkan_1_0,
              "enum vulkan1.0 value mismatch");
static_assert(kVulkan_1_1 == shaderc_env_version_vulkan_1_1,
              "enum vulkan1.1 value mismatch");
static_assert(kVulkan_1_2 == shaderc_env_version_vulkan_1_2,
              "enum vulkan1.2 value mismatch");
static_assert(kSpv_1_0 == shaderc_spirv_version_1_0,
              "enum spv1.0 value mismatch");
static_assert(kSpv_1_1 == shaderc_spirv_version_1_1,
              "enum spv1.1 value mismatch");
static_assert(kSpv_1_2 == shaderc_spirv_version_1_2,
              "enum spv1.2 value mismatch");
static_assert(kSpv_1_3 == shaderc_spirv_version_1_3,
              "enum spv1.3 value mismatch");
static_assert(kSpv_1_4 == shaderc_spirv_version_1_4,
              "enum spv1.4 value mismatch");
static_assert(kSpv_1_5 == shaderc_spirv_version_1_5,
              "enum spv1.5 value mismatch");
#endif

}  // namespace
355 
ParseSpvEnv(const std::string & spv_env,uint32_t * target_env,uint32_t * target_env_version,uint32_t * spirv_version)356 Result ParseSpvEnv(const std::string& spv_env,
357                    uint32_t* target_env,
358                    uint32_t* target_env_version,
359                    uint32_t* spirv_version) {
360   if (!target_env || !target_env_version || !spirv_version)
361     return Result("ParseSpvEnv: null pointer parameter");
362 
363   // Use the same values as in Shaderc's shaderc/env.h
364   struct Values {
365     uint32_t env;
366     uint32_t env_version;
367     uint32_t spirv_version;
368   };
369   Values values{kVulkan, kVulkan_1_0, kSpv_1_0};
370 
371   if (spv_env == "" || spv_env == "spv1.0") {
372     values = {kVulkan, kVulkan_1_0, kSpv_1_0};
373   } else if (spv_env == "spv1.1") {
374     values = {kVulkan, kVulkan_1_1, kSpv_1_1};
375   } else if (spv_env == "spv1.2") {
376     values = {kVulkan, kVulkan_1_1, kSpv_1_2};
377   } else if (spv_env == "spv1.3") {
378     values = {kVulkan, kVulkan_1_1, kSpv_1_3};
379   } else if (spv_env == "spv1.4") {
380     // Vulkan 1.2 requires support for SPIR-V 1.4,
381     // but Vulkan 1.1 permits it with an extension.
382     // So Vulkan 1.2 is the right answer here.
383     values = {kVulkan, kVulkan_1_2, kSpv_1_4};
384   } else if (spv_env == "spv1.5") {
385     values = {kVulkan, kVulkan_1_2, kSpv_1_5};
386   } else if (spv_env == "vulkan1.0") {
387     values = {kVulkan, kVulkan_1_0, kSpv_1_0};
388   } else if (spv_env == "vulkan1.1") {
389     // Vulkan 1.1 requires support for SPIR-V 1.3.
390     values = {kVulkan, kVulkan_1_1, kSpv_1_3};
391   } else if (spv_env == "vulkan1.1spv1.4") {
392     values = {kVulkan, kVulkan_1_1, kSpv_1_4};
393   } else if (spv_env == "vulkan1.2") {
394     values = {kVulkan, kVulkan_1_2, kSpv_1_5};
395   } else {
396     return Result(std::string("Unrecognized environment ") + spv_env);
397   }
398 
399   *target_env = values.env;
400   *target_env_version = values.env_version;
401   *spirv_version = values.spirv_version;
402   return {};
403 }
404 
405 }  // namespace amber
406