1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "spirv_cross_helpers_gles.h"

#include <cassert>
#include <cmath>
#include <cstring>
#include <string>
#include <string_view>

#include <glcorearb.h>

20 namespace Gles {
21 namespace {
22 constexpr const GLenum UINT_TYPES[5][5] = { { 0, 0, 0, 0, 0 }, { 0, GL_UNSIGNED_INT, 0, 0, 0 },
23 { 0, GL_UNSIGNED_INT_VEC2, 0, 0, 0 }, { 0, GL_UNSIGNED_INT_VEC3, 0, 0, 0 }, { 0, GL_UNSIGNED_INT_VEC4, 0, 0, 0 } };
24 constexpr const GLenum FLOAT_TYPES[5][5] = {
25 { 0, 0, 0, 0, 0 },
26 { 0, GL_FLOAT, 0, 0, 0 },
27 { 0, GL_FLOAT_VEC2, GL_FLOAT_MAT2, GL_FLOAT_MAT3x2, GL_FLOAT_MAT4x2 },
28 { 0, GL_FLOAT_VEC3, GL_FLOAT_MAT2x3, GL_FLOAT_MAT3, GL_FLOAT_MAT4x3 },
29 { 0, GL_FLOAT_VEC4, GL_FLOAT_MAT2x4, GL_FLOAT_MAT3x4, GL_FLOAT_MAT4 },
30 };
31
// Returns the larger of two values of the same type (rhs when they compare equal).
// Usable in constant expressions. Both arguments are taken by const reference so lvalues and rvalues can be
// mixed freely (e.g. Max(x, 1)); the result is always returned by value.
template<typename T>
constexpr T Max(const T& lhs, const T& rhs)
{
    return (lhs > rhs) ? lhs : rhs;
}
37
38 static const spirv_cross::SPIRConstant invalid {};
39
40 constexpr const int32_t NOT_FOUND = -1;
41 constexpr const int32_t INVALID_MATCH = -2;
FindConstant(const std::vector<PushConstantReflection> & reflections,const PushConstantReflection & reflection)42 int32_t FindConstant(const std::vector<PushConstantReflection>& reflections, const PushConstantReflection& reflection)
43 {
44 for (size_t i = 0; i < reflections.size(); i++) {
45 if (reflection.name == reflections[i].name) {
46 // Check that it's actually same and not a conflict!.
47 if (reflection.type != reflections[i].type) {
48 return INVALID_MATCH;
49 }
50 if (reflection.offset != reflections[i].offset) {
51 return INVALID_MATCH;
52 }
53 if (reflection.size != reflections[i].size) {
54 return INVALID_MATCH;
55 }
56 if (reflection.arraySize != reflections[i].arraySize) {
57 return INVALID_MATCH;
58 }
59 if (reflection.arrayStride != reflections[i].arrayStride) {
60 return INVALID_MATCH;
61 }
62 if (reflection.matrixStride != reflections[i].matrixStride) {
63 return INVALID_MATCH;
64 }
65 return (int32_t)i;
66 }
67 }
68 return NOT_FOUND;
69 }
70
// Looks up a constant in the compiler's IR by its debug name.
// 'compiler' must actually be a CoreCompiler: the constant list is obtained through CoreCompiler::GetConstants(),
// which reads the IR held by the compiler.
// Returns a reference to the matching constant owned by 'compiler', or 'invalid' when no constant has that name.
const spirv_cross::SPIRConstant& ConstByName(const spirv_cross::CompilerGLSL& compiler, const char* name)
{
    // Named downcast that keeps constness; GetConstants() is a const member, so no const needs to be cast away.
    const auto& coreCompiler = static_cast<const CoreCompiler&>(compiler);
    // GetConstants() returns copies; hold them for the duration of the search only.
    const auto constants = coreCompiler.GetConstants();
    for (const auto& constant : constants) {
        if (compiler.get_name(constant.self) == name) {
            // Return the compiler-owned instance, not the local copy.
            return compiler.get_constant(constant.self);
        }
    }
    return invalid;
}
83
// Looks up a specialization constant by its debug name.
// Returns a reference to the matching constant owned by 'compiler', or 'invalid' when no specialization constant
// has that name.
const spirv_cross::SPIRConstant& SpecConstByName(const spirv_cross::CompilerGLSL& compiler, const char* name)
{
    // The list is returned by value; the range-for keeps that temporary alive for the whole loop.
    for (const auto& specConstant : compiler.get_specialization_constants()) {
        if (compiler.get_name(specConstant.id) != name) {
            continue;
        }
        return compiler.get_constant(specConstant.id);
    }
    return invalid;
}
96 } // namespace
97
98 // inherit from CompilerGLSL to have better access
CoreCompiler(const uint32_t * ir,size_t wordCount)99 CoreCompiler::CoreCompiler(const uint32_t* ir, size_t wordCount) : CompilerGLSL(ir, wordCount) {}
100
GetConstants() const101 const std::vector<spirv_cross::SPIRConstant> CoreCompiler::GetConstants() const
102 {
103 std::vector<spirv_cross::SPIRConstant> consts;
104 ir.for_each_typed_id<spirv_cross::SPIRConstant>(
105 [&consts](uint32_t, const spirv_cross::SPIRConstant& c) { consts.push_back(c); });
106 return consts;
107 }
108
GetIr() const109 const spirv_cross::ParsedIR& CoreCompiler::GetIr() const
110 {
111 return ir;
112 }
113
ReflectPushConstants(spirv_cross::Compiler & compiler,const spirv_cross::ShaderResources & resources,std::vector<PushConstantReflection> & reflections,ShaderStageFlags stage)114 void ReflectPushConstants(spirv_cross::Compiler& compiler, const spirv_cross::ShaderResources& resources,
115 std::vector<PushConstantReflection>& reflections, ShaderStageFlags stage)
116 {
117 static constexpr std::string_view pcName = "CORE_PC_00";
118 static constexpr auto nameBaseSize = pcName.size() - 2U; // length without the 2 digits
119 char ids[64];
120 int id = 0;
121 // There can be only one push_constant_buffer, but since spirv-cross has prepared for this to be relaxed, we will
122 // too.
123 Gles::PushConstantReflection base {};
124 base.stage = stage;
125 base.name = pcName;
126 for (auto& remap : resources.push_constant_buffers) {
127 const auto& blockType = compiler.get_type(remap.base_type_id);
128 (void)(blockType);
129 auto ret = snprintf(ids, sizeof(ids), "%d", id);
130 if (ret < 0) {
131 return;
132 }
133 base.name.resize(nameBaseSize);
134 base.name.append(ids);
135 compiler.set_name(remap.id, base.name);
136 assert((blockType.basetype == spirv_cross::SPIRType::Struct) && "Push constant is not a struct!");
137 ProcessStruct(compiler, base, remap.base_type_id, reflections);
138 id++;
139 }
140 }
141
142 // Converts specialization constant to normal constant, (to reduce unnecessary clutter in glsl)
ConvertSpecConstToConstant(spirv_cross::CompilerGLSL & compiler,const char * name)143 void ConvertSpecConstToConstant(spirv_cross::CompilerGLSL& compiler, const char* name)
144 {
145 const auto& c = SpecConstByName(compiler, name);
146 if (c.self == invalid.self) {
147 return;
148 }
149 compiler.unset_decoration(c.self, spv::Decoration::DecorationSpecId);
150 }
151
152 // Converts constant declaration to uniform. (actually only works on spec constants)
ConvertConstantToUniform(const spirv_cross::CompilerGLSL & compiler,std::string & source,const char * name)153 void ConvertConstantToUniform(const spirv_cross::CompilerGLSL& compiler, std::string& source, const char* name)
154 {
155 static constexpr std::string_view constBool = "const bool ";
156 static constexpr std::string_view constUint = "const uint ";
157 static constexpr std::string_view constInt = "const int ";
158 static constexpr std::string_view constFloat = "const float ";
159 static constexpr std::string_view equals = " =";
160 static constexpr auto extraSpace =
161 Max(Max(constBool.size(), constUint.size()), Max(constInt.size(), constFloat.size())) + equals.size();
162 std::string tmp;
163 tmp.reserve(strlen(name) + extraSpace);
164 const auto& constant = ConstByName(compiler, name);
165 if (constant.self == invalid.self) {
166 return;
167 }
168 const auto& type = compiler.get_type(constant.constant_type);
169 if (type.basetype == spirv_cross::SPIRType::Boolean) {
170 tmp += constBool;
171 } else if (type.basetype == spirv_cross::SPIRType::UInt) {
172 tmp += constUint;
173 } else if (type.basetype == spirv_cross::SPIRType::Int) {
174 tmp += constInt;
175 } else if (type.basetype == spirv_cross::SPIRType::Float) {
176 tmp += constFloat;
177 } else {
178 assert(false && "Unhandled specialization constant type");
179 }
180 // We expect spirv_cross to generate them with certain pattern..
181 tmp += name;
182 tmp += equals;
183 const auto p = source.find(tmp);
184 if (p != std::string::npos) {
185 // found it, change it. (changes const to uniform)
186 auto bi = source.begin() + (int64_t)p;
187 auto ei = bi + 6;
188 source.replace(bi, ei, "uniform ");
189
190 // remove the initializer..
191 const auto p2 = source.find('=', p);
192 const auto p3 = source.find(';', p);
193 if ((p2 != std::string::npos) && (p3 != std::string::npos)) {
194 if (p2 < p3) {
195 // should be correct (tm)
196 bi = source.begin() + (int64_t)p2;
197 ei = source.begin() + (int64_t)p3;
198 source.erase(bi, ei);
199 }
200 }
201 }
202 }
203
SetSpecMacro(spirv_cross::CompilerGLSL & compiler,const char * name,uint32_t value)204 void SetSpecMacro(spirv_cross::CompilerGLSL& compiler, const char* name, uint32_t value)
205 {
206 const auto& vc = SpecConstByName(compiler, name);
207 if (vc.self != invalid.self) {
208 const uint32_t constantId = compiler.get_decoration(vc.self, spv::Decoration::DecorationSpecId);
209 char buf[1024];
210 auto ret = snprintf(buf, sizeof(buf), "#define SPIRV_CROSS_CONSTANT_ID_%u %uu", constantId, value);
211 if (ret < 0) {
212 return;
213 }
214 compiler.add_header_line(buf);
215 }
216 }
217
ProcessStruct(const spirv_cross::Compiler & compiler,const PushConstantReflection & base,uint32_t structTypeId,std::vector<PushConstantReflection> & reflections)218 void ProcessStruct(const spirv_cross::Compiler& compiler, const PushConstantReflection& base, uint32_t structTypeId,
219 std::vector<PushConstantReflection>& reflections)
220 {
221 const auto& structType = compiler.get_type(structTypeId);
222 reflections.reserve(reflections.size() + structType.member_types.size());
223 for (uint32_t bi = 0; bi < structType.member_types.size(); bi++) {
224 const uint32_t memberTypeId = structType.member_types[bi];
225 const auto& memberType = compiler.get_type(memberTypeId);
226 const auto& name = compiler.get_member_name(structTypeId, bi);
227
228 PushConstantReflection t { base.stage, INVALID_LOCATION, 0u, {},
229 base.offset + compiler.type_struct_member_offset(structType, bi),
230 compiler.get_declared_struct_member_size(structType, bi), 0u, 0u, 0u };
231 t.name.reserve(base.name.size() + 1U + name.size());
232 t.name = base.name;
233 t.name += '.';
234 t.name += name;
235 if (!memberType.array.empty()) {
236 // Get array stride, e.g. float4 foo[]; Will have array stride of 16 bytes.
237 t.arrayStride = compiler.type_struct_member_array_stride(structType, bi);
238 t.arraySize = memberType.array[0]; // We don't support arrays of arrays. just use the size of first.
239 }
240
241 if (memberType.columns > 1) {
242 // Get bytes stride between columns (if column major), for float4x4 -> 16 bytes.
243 t.matrixStride = compiler.type_struct_member_matrix_stride(structType, bi);
244 }
245
246 switch (memberType.basetype) {
247 case spirv_cross::SPIRType::Struct:
248 ProcessStruct(compiler, t, memberTypeId, reflections);
249 continue;
250 break;
251 case spirv_cross::SPIRType::UInt:
252 t.type = UINT_TYPES[memberType.vecsize][memberType.columns];
253 break;
254 case spirv_cross::SPIRType::Float:
255 t.type = FLOAT_TYPES[memberType.vecsize][memberType.columns];
256 break;
257
258 case spirv_cross::SPIRType::Unknown:
259 case spirv_cross::SPIRType::Void:
260 case spirv_cross::SPIRType::Boolean:
261 case spirv_cross::SPIRType::SByte:
262 case spirv_cross::SPIRType::UByte:
263 case spirv_cross::SPIRType::Short:
264 case spirv_cross::SPIRType::UShort:
265 case spirv_cross::SPIRType::Int:
266 case spirv_cross::SPIRType::Int64:
267 case spirv_cross::SPIRType::UInt64:
268 case spirv_cross::SPIRType::AtomicCounter:
269 case spirv_cross::SPIRType::Half:
270 case spirv_cross::SPIRType::Double:
271 case spirv_cross::SPIRType::Image:
272 case spirv_cross::SPIRType::SampledImage:
273 case spirv_cross::SPIRType::Sampler:
274 case spirv_cross::SPIRType::AccelerationStructure:
275 case spirv_cross::SPIRType::RayQuery:
276 case spirv_cross::SPIRType::ControlPointArray:
277 case spirv_cross::SPIRType::Interpolant:
278 case spirv_cross::SPIRType::Char:
279 break;
280 }
281 assert((t.type != 0) && "Unhandled Type!");
282 const int32_t res = FindConstant(reflections, t);
283 assert((res >= NOT_FOUND) && "Push constant conflict.");
284 if (res == NOT_FOUND) {
285 reflections.push_back(std::move(t));
286 }
287 }
288 }
289 } // namespace Gles
290