1 //
2 // Copyright 2024 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include "compiler/translator/wgsl/OutputUniformBlocks.h"
8
9 #include "angle_gl.h"
10 #include "common/mathutil.h"
11 #include "common/utilities.h"
12 #include "compiler/translator/BaseTypes.h"
13 #include "compiler/translator/Common.h"
14 #include "compiler/translator/Compiler.h"
15 #include "compiler/translator/ImmutableString.h"
16 #include "compiler/translator/ImmutableStringBuilder.h"
17 #include "compiler/translator/InfoSink.h"
18 #include "compiler/translator/IntermNode.h"
19 #include "compiler/translator/SymbolUniqueId.h"
20 #include "compiler/translator/tree_util/IntermTraverse.h"
21 #include "compiler/translator/util.h"
22 #include "compiler/translator/wgsl/Utils.h"
23
24 namespace sh
25 {
26
27 namespace
28 {
29
30 // Traverses the AST and finds all structs that are used in the uniform address space (see the
31 // UniformBlockMetadata struct).
32 class FindUniformAddressSpaceStructs : public TIntermTraverser
33 {
34 public:
FindUniformAddressSpaceStructs(UniformBlockMetadata * uniformBlockMetadata)35 FindUniformAddressSpaceStructs(UniformBlockMetadata *uniformBlockMetadata)
36 : TIntermTraverser(true, false, false), mUniformBlockMetadata(uniformBlockMetadata)
37 {}
38
39 ~FindUniformAddressSpaceStructs() override = default;
40
visitDeclaration(Visit visit,TIntermDeclaration * node)41 bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
42 {
43 const TIntermSequence &sequence = *(node->getSequence());
44
45 TIntermTyped *variable = sequence.front()->getAsTyped();
46 const TType &type = variable->getType();
47
48 // TODO(anglebug.com/376553328): should eventually ASSERT that there are no default uniforms
49 // here.
50 if (type.getQualifier() == EvqUniform)
51 {
52 recordTypesUsedInUniformAddressSpace(&type);
53 }
54
55 return true;
56 }
57
58 private:
59 // Recurses through the tree of types referred to be `type` (which is used in the uniform
60 // address space) and fills in the `mUniformBlockMetadata` struct appropriately.
recordTypesUsedInUniformAddressSpace(const TType * type)61 void recordTypesUsedInUniformAddressSpace(const TType *type)
62 {
63 if (type->isArray())
64 {
65 TType innerType = *type;
66 innerType.toArrayBaseType();
67 recordTypesUsedInUniformAddressSpace(&innerType);
68 }
69 else if (type->getStruct() != nullptr)
70 {
71 mUniformBlockMetadata->structsInUniformAddressSpace.insert(
72 type->getStruct()->uniqueId().get());
73 // Recurse into the types of the fields of this struct type.
74 for (TField *const field : type->getStruct()->fields())
75 {
76 recordTypesUsedInUniformAddressSpace(field->type());
77 }
78 }
79 }
80
81 UniformBlockMetadata *const mUniformBlockMetadata;
82 };
83
84 } // namespace
85
RecordUniformBlockMetadata(TIntermBlock * root,UniformBlockMetadata & outMetadata)86 bool RecordUniformBlockMetadata(TIntermBlock *root, UniformBlockMetadata &outMetadata)
87 {
88 FindUniformAddressSpaceStructs traverser(&outMetadata);
89 root->traverse(&traverser);
90 return true;
91 }
92
OutputUniformWrapperStructsAndConversions(TInfoSinkBase & output,const WGSLGenerationMetadataForUniforms & wgslGenerationMetadataForUniforms)93 bool OutputUniformWrapperStructsAndConversions(
94 TInfoSinkBase &output,
95 const WGSLGenerationMetadataForUniforms &wgslGenerationMetadataForUniforms)
96 {
97
98 auto generate16AlignedWrapperStruct = [&output](const TType &type) {
99 output << "struct " << MakeUniformWrapperStructName(&type) << "\n{\n";
100 output << " @align(16) " << kWrappedStructFieldName << " : ";
101 WriteWgslType(output, type, {});
102 output << "\n};\n";
103 };
104
105 bool generatedVec2WrapperStruct = false;
106
107 for (const TType &type : wgslGenerationMetadataForUniforms.arrayElementTypesInUniforms)
108 {
109 // Structs don't need wrapper structs.
110 ASSERT(type.getStruct() == nullptr);
111 // Multidimensional arrays not currently supported in uniforms
112 ASSERT(!type.isArray());
113
114 if (type.isVector() && type.getNominalSize() == 2)
115 {
116 generatedVec2WrapperStruct = true;
117 }
118 generate16AlignedWrapperStruct(type);
119 }
120
121 // matCx2 is represented as array<ANGLE_wrapped_vec2, C> so if there are matCx2s we need to
122 // generate an ANGLE_wrapped_vec2 struct.
123 if (!wgslGenerationMetadataForUniforms.outputMatCx2Conversion.empty() &&
124 !generatedVec2WrapperStruct)
125 {
126 generate16AlignedWrapperStruct(*new TType(TBasicType::EbtFloat, 2));
127 }
128
129 for (const TType &type :
130 wgslGenerationMetadataForUniforms.arrayElementTypesThatNeedUnwrappingConversions)
131 {
132 // Should be a subset of the types that have had wrapper structs generated above, otherwise
133 // it's impossible to unwrap them!
134 TType innerType = type;
135 innerType.toArrayElementType();
136 ASSERT(wgslGenerationMetadataForUniforms.arrayElementTypesInUniforms.count(innerType) != 0);
137
138 // This could take ptr<uniform, typeName>, with the unrestricted_pointer_parameters
139 // extension. This is probably fine.
140 output << "fn " << MakeUnwrappingArrayConversionFunctionName(&type) << "(wrappedArr : ";
141 WriteWgslType(output, type, {WgslAddressSpace::Uniform});
142 output << ") -> ";
143 WriteWgslType(output, type, {WgslAddressSpace::NonUniform});
144 output << "\n{\n";
145 output << " var retVal : ";
146 WriteWgslType(output, type, {WgslAddressSpace::NonUniform});
147 output << ";\n";
148 output << " for (var i : u32 = 0; i < " << type.getOutermostArraySize() << "; i++) {;\n";
149 output << " retVal[i] = wrappedArr[i]." << kWrappedStructFieldName << ";\n";
150 output << " }\n";
151 output << " return retVal;\n";
152 output << "}\n";
153 }
154
155 for (const TType &type : wgslGenerationMetadataForUniforms.outputMatCx2Conversion)
156 {
157 ASSERT(type.isMatrix() && type.getRows() == 2);
158 output << "fn " << MakeMatCx2ConversionFunctionName(&type) << "(mangledMatrix : ";
159
160 WriteWgslType(output, type, {WgslAddressSpace::Uniform});
161 output << ") -> ";
162 WriteWgslType(output, type, {WgslAddressSpace::NonUniform});
163 output << "\n{\n";
164 output << " var retVal : ";
165 WriteWgslType(output, type, {WgslAddressSpace::NonUniform});
166 output << ";\n";
167
168 if (type.isArray())
169 {
170 output << " for (var i : u32 = 0; i < " << type.getOutermostArraySize()
171 << "; i++) {;\n";
172 output << " retVal[i] = ";
173 }
174 else
175 {
176 output << " retVal = ";
177 }
178
179 TType baseType = type;
180 baseType.toArrayBaseType();
181 WriteWgslType(output, baseType, {WgslAddressSpace::NonUniform});
182 output << "(";
183 for (uint8_t i = 0; i < type.getCols(); i++)
184 {
185 if (i != 0)
186 {
187 output << ", ";
188 }
189 // The mangled matrix is an array and the elements are wrapped vec2s, which can be
190 // passed directly to the matCx2 constructor.
191 output << "mangledMatrix" << (type.isArray() ? "[i]" : "") << "[" << static_cast<int>(i)
192 << "]." << kWrappedStructFieldName;
193 }
194 output << ");\n";
195
196 if (type.isArray())
197 {
198 // Close the for loop.
199 output << " }\n";
200 }
201 output << " return retVal;\n";
202 output << "}\n";
203 }
204
205 return true;
206 }
207
MakeUnwrappingArrayConversionFunctionName(const TType * type)208 ImmutableString MakeUnwrappingArrayConversionFunctionName(const TType *type)
209 {
210 ASSERT(type->getNumArraySizes() <= 1);
211 ImmutableString arrStr = type->isArray() ? BuildConcatenatedImmutableString(
212 "Array", type->getOutermostArraySize(), "_")
213 : kEmptyImmutableString;
214 return BuildConcatenatedImmutableString("ANGLE_Convert_", arrStr,
215 MakeUniformWrapperStructName(type), "_ElementsTo_",
216 type->getBuiltInTypeNameString(), "_Elements");
217 }
218
IsMatCx2(const TType * type)219 bool IsMatCx2(const TType *type)
220 {
221 return type->isMatrix() && type->getRows() == 2;
222 }
223
MakeMatCx2ConversionFunctionName(const TType * type)224 ImmutableString MakeMatCx2ConversionFunctionName(const TType *type)
225 {
226 ASSERT(type->getNumArraySizes() <= 1);
227 ImmutableString arrStr = type->isArray() ? BuildConcatenatedImmutableString(
228 "Array", type->getOutermostArraySize(), "_")
229 : kEmptyImmutableString;
230 return BuildConcatenatedImmutableString("ANGLE_Convert_", arrStr, "Mat", type->getCols(), "x2");
231 }
232
OutputUniformBlocksAndSamplers(TCompiler * compiler,TIntermBlock * root)233 bool OutputUniformBlocksAndSamplers(TCompiler *compiler, TIntermBlock *root)
234 {
235 // TODO(anglebug.com/42267100): This should eventually just be handled the same way as a regular
236 // UBO, like in Vulkan which create a block out of the default uniforms with a traverser:
237 // https://source.chromium.org/chromium/chromium/src/+/main:third_party/angle/src/compiler/translator/spirv/TranslatorSPIRV.cpp;l=70;drc=451093bbaf7fe812bf67d27d760f3bb64c92830b
238 const std::vector<ShaderVariable> &basicUniforms = compiler->getUniforms();
239 TInfoSinkBase &output = compiler->getInfoSink().obj;
240 GlobalVars globalVars = FindGlobalVars(root);
241
242 // Only output a struct at all if there are going to be members.
243 bool outputStructHeader = false;
244 for (const ShaderVariable &shaderVar : basicUniforms)
245 {
246 if (gl::IsOpaqueType(shaderVar.type) || !shaderVar.active)
247 {
248 continue;
249 }
250 if (shaderVar.isBuiltIn())
251 {
252 // gl_DepthRange and also the GLSL 4.2 gl_NumSamples are uniforms.
253 // TODO(anglebug.com/42267100): put gl_DepthRange into default uniform block.
254 continue;
255 }
256
257 // TODO(anglebug.com/42267100): some types will NOT match std140 layout here, namely matCx2,
258 // bool, and arrays with stride less than 16.
259 // (this check does not cover the unsupported case where there is an array of structs of
260 // size < 16).
261 if (shaderVar.type == GL_BOOL)
262 {
263 return false;
264 }
265
266 // Some uniform variables might have been deleted, for example if they were structs that
267 // only contained samplers (which are pulled into separate default uniforms).
268 auto globalVarIter = globalVars.find(shaderVar.name);
269 if (globalVarIter == globalVars.end())
270 {
271 continue;
272 }
273
274 if (!outputStructHeader)
275 {
276 output << "struct ANGLE_DefaultUniformBlock {\n";
277 outputStructHeader = true;
278 }
279 output << " ";
280 output << shaderVar.name << " : ";
281
282 TIntermDeclaration *declNode = globalVarIter->second;
283 const TVariable *astVar = &ViewDeclaration(*declNode).symbol.variable();
284 WriteWgslType(output, astVar->getType(), {WgslAddressSpace::Uniform});
285
286 output << ",\n";
287 }
288 // TODO(anglebug.com/42267100): might need string replacement for @group(0) and @binding(0)
289 // annotations. All WGSL resources available to shaders share the same (group, binding) ID
290 // space.
291 if (outputStructHeader)
292 {
293 ASSERT(compiler->getShaderType() == GL_VERTEX_SHADER ||
294 compiler->getShaderType() == GL_FRAGMENT_SHADER);
295 const uint32_t bindingIndex = compiler->getShaderType() == GL_VERTEX_SHADER
296 ? kDefaultVertexUniformBlockBinding
297 : kDefaultFragmentUniformBlockBinding;
298 output << "};\n\n"
299 << "@group(" << kDefaultUniformBlockBindGroup << ") @binding(" << bindingIndex
300 << ") var<uniform> " << kDefaultUniformBlockVarName << " : "
301 << kDefaultUniformBlockVarType << ";\n";
302 }
303
304 for (const auto &globalVarIter : globalVars)
305 {
306 TIntermDeclaration *declNode = globalVarIter.second;
307 ASSERT(declNode);
308
309 const TIntermSymbol *declSymbol = &ViewDeclaration(*declNode).symbol;
310 const TType &declType = declSymbol->getType();
311 if (!declType.isSampler())
312 {
313 continue;
314 }
315
316 // Note that this may output ignored symbols.
317 output << kTextureSamplerBindingMarker << kAngleSamplerPrefix << declSymbol->getName()
318 << " : ";
319 WriteWgslSamplerType(output, declType, WgslSamplerTypeConfig::Sampler);
320 output << ";\n";
321
322 output << kTextureSamplerBindingMarker << kAngleTexturePrefix << declSymbol->getName()
323 << " : ";
324 WriteWgslSamplerType(output, declType, WgslSamplerTypeConfig::Texture);
325 output << ";\n";
326 }
327
328 return true;
329 }
330
WGSLGetMappedSamplerName(const std::string & originalName)331 std::string WGSLGetMappedSamplerName(const std::string &originalName)
332 {
333 std::string samplerName = originalName;
334
335 // Samplers in structs are extracted.
336 std::replace(samplerName.begin(), samplerName.end(), '.', '_');
337
338 // Remove array elements
339 auto out = samplerName.begin();
340 for (auto in = samplerName.begin(); in != samplerName.end(); in++)
341 {
342 if (*in == '[')
343 {
344 while (*in != ']')
345 {
346 in++;
347 ASSERT(in != samplerName.end());
348 }
349 }
350 else
351 {
352 *out++ = *in;
353 }
354 }
355
356 samplerName.erase(out, samplerName.end());
357
358 return samplerName;
359 }
360
361 } // namespace sh
362