• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "spirv_cross_helpers_gles.h"
17 
18 #include <cmath>
19 #include <glcorearb.h>
20 namespace Gles {
21 namespace {
22 constexpr const GLenum UINT_TYPES[5][5] = { { 0, 0, 0, 0, 0 }, { 0, GL_UNSIGNED_INT, 0, 0, 0 },
23     { 0, GL_UNSIGNED_INT_VEC2, 0, 0, 0 }, { 0, GL_UNSIGNED_INT_VEC3, 0, 0, 0 }, { 0, GL_UNSIGNED_INT_VEC4, 0, 0, 0 } };
24 constexpr const GLenum FLOAT_TYPES[5][5] = {
25     { 0, 0, 0, 0, 0 },
26     { 0, GL_FLOAT, 0, 0, 0 },
27     { 0, GL_FLOAT_VEC2, GL_FLOAT_MAT2, GL_FLOAT_MAT3x2, GL_FLOAT_MAT4x2 },
28     { 0, GL_FLOAT_VEC3, GL_FLOAT_MAT2x3, GL_FLOAT_MAT3, GL_FLOAT_MAT4x3 },
29     { 0, GL_FLOAT_VEC4, GL_FLOAT_MAT2x4, GL_FLOAT_MAT3x4, GL_FLOAT_MAT4 },
30 };
31 
// Returns the larger of two values (rhs when they compare equal); usable in constant expressions.
// Both arguments are taken by const reference, so lvalues, rvalues and a mix of the two are all
// accepted (the previous forwarding-reference version failed deduction for mixed value categories
// and deduced a reference type for two lvalues). The result is always returned by value.
template<typename T>
constexpr T Max(const T& lhs, const T& rhs)
{
    return (lhs > rhs) ? lhs : rhs;
}
37 
38 static const spirv_cross::SPIRConstant invalid {};
39 
// FindConstant() result: no existing reflection has the requested name.
constexpr int32_t NOT_FOUND = -1;
// FindConstant() result: a reflection with the requested name exists but its layout differs.
constexpr int32_t INVALID_MATCH = -2;
FindConstant(const std::vector<PushConstantReflection> & reflections,const PushConstantReflection & reflection)42 int32_t FindConstant(const std::vector<PushConstantReflection>& reflections, const PushConstantReflection& reflection)
43 {
44     for (size_t i = 0; i < reflections.size(); i++) {
45         if (reflection.name == reflections[i].name) {
46             // Check that it's actually same and not a conflict!.
47             if (reflection.type != reflections[i].type) {
48                 return INVALID_MATCH;
49             }
50             if (reflection.offset != reflections[i].offset) {
51                 return INVALID_MATCH;
52             }
53             if (reflection.size != reflections[i].size) {
54                 return INVALID_MATCH;
55             }
56             if (reflection.arraySize != reflections[i].arraySize) {
57                 return INVALID_MATCH;
58             }
59             if (reflection.arrayStride != reflections[i].arrayStride) {
60                 return INVALID_MATCH;
61             }
62             if (reflection.matrixStride != reflections[i].matrixStride) {
63                 return INVALID_MATCH;
64             }
65             return (int32_t)i;
66         }
67     }
68     return NOT_FOUND;
69 }
70 
// Finds a constant (of any kind, not only specialization constants) by its debug name.
// Returns a reference to the compiler-owned constant, or to the file-local `invalid` sentinel when
// `name` is null or no constant carries that name.
// The constant list is only reachable through CoreCompiler, so `compiler` must actually be a
// CoreCompiler object (the downcast below is unchecked).
const spirv_cross::SPIRConstant& ConstByName(const spirv_cross::CompilerGLSL& compiler, const char* name)
{
    if (name == nullptr) {
        // Comparing a std::string against a null const char* is undefined behavior.
        return invalid;
    }
    // GetConstants() is a const member, so a const-preserving downcast is sufficient; the former
    // C-style cast to a non-const CoreCompiler& silently discarded const as well.
    const auto& core = static_cast<const CoreCompiler&>(compiler);
    const auto& constants = core.GetConstants(); // temporary copy; lifetime extended by the reference
    for (const auto& c : constants) {
        if (compiler.get_name(c.self) == name) {
            return compiler.get_constant(c.self);
        }
    }
    return invalid;
}
83 
// Finds a specialization constant by its debug name.
// Returns the compiler-owned constant for the first specialization constant whose name equals
// `name`, or the file-local `invalid` sentinel when none matches.
const spirv_cross::SPIRConstant& SpecConstByName(const spirv_cross::CompilerGLSL& compiler, const char* name)
{
    const auto specConstants = compiler.get_specialization_constants();
    const size_t count = specConstants.size();
    for (size_t i = 0; i < count; ++i) {
        const auto constantId = specConstants[i].id;
        if (compiler.get_name(constantId) != name) {
            continue;
        }
        return compiler.get_constant(constantId);
    }
    return invalid;
}
96 } // namespace
97 
98 // inherit from CompilerGLSL to have better access
CoreCompiler(const uint32_t * ir,size_t wordCount)99 CoreCompiler::CoreCompiler(const uint32_t* ir, size_t wordCount) : CompilerGLSL(ir, wordCount) {}
100 
GetConstants() const101 const std::vector<spirv_cross::SPIRConstant> CoreCompiler::GetConstants() const
102 {
103     std::vector<spirv_cross::SPIRConstant> consts;
104     ir.for_each_typed_id<spirv_cross::SPIRConstant>(
105         [&consts](uint32_t, const spirv_cross::SPIRConstant& c) { consts.push_back(c); });
106     return consts;
107 }
108 
GetIr() const109 const spirv_cross::ParsedIR& CoreCompiler::GetIr() const
110 {
111     return ir;
112 }
113 
ReflectPushConstants(spirv_cross::Compiler & compiler,const spirv_cross::ShaderResources & resources,std::vector<PushConstantReflection> & reflections,ShaderStageFlags stage)114 void ReflectPushConstants(spirv_cross::Compiler& compiler, const spirv_cross::ShaderResources& resources,
115     std::vector<PushConstantReflection>& reflections, ShaderStageFlags stage)
116 {
117     static constexpr std::string_view pcName = "CORE_PC_00";
118     static constexpr auto nameBaseSize = pcName.size() - 2U; // length without the 2 digits
119     char ids[64];
120     int id = 0;
121     // There can be only one push_constant_buffer, but since spirv-cross has prepared for this to be relaxed, we will
122     // too.
123     Gles::PushConstantReflection base {};
124     base.stage = stage;
125     base.name = pcName;
126     for (auto& remap : resources.push_constant_buffers) {
127         const auto& blockType = compiler.get_type(remap.base_type_id);
128         (void)(blockType);
129         auto ret = snprintf(ids, sizeof(ids), "%d", id);
130         if (ret < 0) {
131             return;
132         }
133         base.name.resize(nameBaseSize);
134         base.name.append(ids);
135         compiler.set_name(remap.id, base.name);
136         assert((blockType.basetype == spirv_cross::SPIRType::Struct) && "Push constant is not a struct!");
137         ProcessStruct(compiler, base, remap.base_type_id, reflections);
138         id++;
139     }
140 }
141 
142 // Converts specialization constant to normal constant, (to reduce unnecessary clutter in glsl)
ConvertSpecConstToConstant(spirv_cross::CompilerGLSL & compiler,const char * name)143 void ConvertSpecConstToConstant(spirv_cross::CompilerGLSL& compiler, const char* name)
144 {
145     const auto& c = SpecConstByName(compiler, name);
146     if (c.self == invalid.self) {
147         return;
148     }
149     compiler.unset_decoration(c.self, spv::Decoration::DecorationSpecId);
150 }
151 
152 // Converts constant declaration to uniform. (actually only works on spec constants)
ConvertConstantToUniform(const spirv_cross::CompilerGLSL & compiler,std::string & source,const char * name)153 void ConvertConstantToUniform(const spirv_cross::CompilerGLSL& compiler, std::string& source, const char* name)
154 {
155     static constexpr std::string_view constBool = "const bool ";
156     static constexpr std::string_view constUint = "const uint ";
157     static constexpr std::string_view constInt = "const int ";
158     static constexpr std::string_view constFloat = "const float ";
159     static constexpr std::string_view equals = " =";
160     static constexpr auto extraSpace =
161         Max(Max(constBool.size(), constUint.size()), Max(constInt.size(), constFloat.size())) + equals.size();
162     std::string tmp;
163     tmp.reserve(strlen(name) + extraSpace);
164     const auto& constant = ConstByName(compiler, name);
165     if (constant.self == invalid.self) {
166         return;
167     }
168     const auto& type = compiler.get_type(constant.constant_type);
169     if (type.basetype == spirv_cross::SPIRType::Boolean) {
170         tmp += constBool;
171     } else if (type.basetype == spirv_cross::SPIRType::UInt) {
172         tmp += constUint;
173     } else if (type.basetype == spirv_cross::SPIRType::Int) {
174         tmp += constInt;
175     } else if (type.basetype == spirv_cross::SPIRType::Float) {
176         tmp += constFloat;
177     } else {
178         assert(false && "Unhandled specialization constant type");
179     }
180     // We expect spirv_cross to generate them with certain pattern..
181     tmp += name;
182     tmp += equals;
183     const auto p = source.find(tmp);
184     if (p != std::string::npos) {
185         // found it, change it. (changes const to uniform)
186         auto bi = source.begin() + (int64_t)p;
187         auto ei = bi + 6;
188         source.replace(bi, ei, "uniform ");
189 
190         // remove the initializer..
191         const auto p2 = source.find('=', p);
192         const auto p3 = source.find(';', p);
193         if ((p2 != std::string::npos) && (p3 != std::string::npos)) {
194             if (p2 < p3) {
195                 // should be correct (tm)
196                 bi = source.begin() + (int64_t)p2;
197                 ei = source.begin() + (int64_t)p3;
198                 source.erase(bi, ei);
199             }
200         }
201     }
202 }
203 
SetSpecMacro(spirv_cross::CompilerGLSL & compiler,const char * name,uint32_t value)204 void SetSpecMacro(spirv_cross::CompilerGLSL& compiler, const char* name, uint32_t value)
205 {
206     const auto& vc = SpecConstByName(compiler, name);
207     if (vc.self != invalid.self) {
208         const uint32_t constantId = compiler.get_decoration(vc.self, spv::Decoration::DecorationSpecId);
209         char buf[1024];
210         auto ret = snprintf(buf, sizeof(buf), "#define SPIRV_CROSS_CONSTANT_ID_%u %uu", constantId, value);
211         if (ret < 0) {
212             return;
213         }
214         compiler.add_header_line(buf);
215     }
216 }
217 
ProcessStruct(const spirv_cross::Compiler & compiler,const PushConstantReflection & base,uint32_t structTypeId,std::vector<PushConstantReflection> & reflections)218 void ProcessStruct(const spirv_cross::Compiler& compiler, const PushConstantReflection& base, uint32_t structTypeId,
219     std::vector<PushConstantReflection>& reflections)
220 {
221     const auto& structType = compiler.get_type(structTypeId);
222     reflections.reserve(reflections.size() + structType.member_types.size());
223     for (uint32_t bi = 0; bi < structType.member_types.size(); bi++) {
224         const uint32_t memberTypeId = structType.member_types[bi];
225         const auto& memberType = compiler.get_type(memberTypeId);
226         const auto& name = compiler.get_member_name(structTypeId, bi);
227 
228         PushConstantReflection t { base.stage, INVALID_LOCATION, 0u, {},
229             base.offset + compiler.type_struct_member_offset(structType, bi),
230             compiler.get_declared_struct_member_size(structType, bi), 0u, 0u, 0u };
231         t.name.reserve(base.name.size() + 1U + name.size());
232         t.name = base.name;
233         t.name += '.';
234         t.name += name;
235         if (!memberType.array.empty()) {
236             // Get array stride, e.g. float4 foo[]; Will have array stride of 16 bytes.
237             t.arrayStride = compiler.type_struct_member_array_stride(structType, bi);
238             t.arraySize = memberType.array[0]; // We don't support arrays of arrays. just use the size of first.
239         }
240 
241         if (memberType.columns > 1) {
242             // Get bytes stride between columns (if column major), for float4x4 -> 16 bytes.
243             t.matrixStride = compiler.type_struct_member_matrix_stride(structType, bi);
244         }
245 
246         switch (memberType.basetype) {
247             case spirv_cross::SPIRType::Struct:
248                 ProcessStruct(compiler, t, memberTypeId, reflections);
249                 continue;
250                 break;
251             case spirv_cross::SPIRType::UInt:
252                 t.type = UINT_TYPES[memberType.vecsize][memberType.columns];
253                 break;
254             case spirv_cross::SPIRType::Float:
255                 t.type = FLOAT_TYPES[memberType.vecsize][memberType.columns];
256                 break;
257 
258             case spirv_cross::SPIRType::Unknown:
259             case spirv_cross::SPIRType::Void:
260             case spirv_cross::SPIRType::Boolean:
261             case spirv_cross::SPIRType::SByte:
262             case spirv_cross::SPIRType::UByte:
263             case spirv_cross::SPIRType::Short:
264             case spirv_cross::SPIRType::UShort:
265             case spirv_cross::SPIRType::Int:
266             case spirv_cross::SPIRType::Int64:
267             case spirv_cross::SPIRType::UInt64:
268             case spirv_cross::SPIRType::AtomicCounter:
269             case spirv_cross::SPIRType::Half:
270             case spirv_cross::SPIRType::Double:
271             case spirv_cross::SPIRType::Image:
272             case spirv_cross::SPIRType::SampledImage:
273             case spirv_cross::SPIRType::Sampler:
274             case spirv_cross::SPIRType::AccelerationStructure:
275             case spirv_cross::SPIRType::RayQuery:
276             case spirv_cross::SPIRType::ControlPointArray:
277             case spirv_cross::SPIRType::Interpolant:
278             case spirv_cross::SPIRType::Char:
279                 break;
280         }
281         assert((t.type != 0) && "Unhandled Type!");
282         const int32_t res = FindConstant(reflections, t);
283         assert((res >= NOT_FOUND) && "Push constant conflict.");
284         if (res == NOT_FOUND) {
285             reflections.push_back(std::move(t));
286         }
287     }
288 }
289 } // namespace Gles
290