• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2024 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/wgsl/OutputUniformBlocks.h"
8 
9 #include "angle_gl.h"
10 #include "common/mathutil.h"
11 #include "common/utilities.h"
12 #include "compiler/translator/BaseTypes.h"
13 #include "compiler/translator/Common.h"
14 #include "compiler/translator/Compiler.h"
15 #include "compiler/translator/ImmutableString.h"
16 #include "compiler/translator/ImmutableStringBuilder.h"
17 #include "compiler/translator/InfoSink.h"
18 #include "compiler/translator/IntermNode.h"
19 #include "compiler/translator/SymbolUniqueId.h"
20 #include "compiler/translator/tree_util/IntermTraverse.h"
21 #include "compiler/translator/util.h"
22 #include "compiler/translator/wgsl/Utils.h"
23 
24 namespace sh
25 {
26 
27 namespace
28 {
29 
30 // Traverses the AST and finds all structs that are used in the uniform address space (see the
31 // UniformBlockMetadata struct).
32 class FindUniformAddressSpaceStructs : public TIntermTraverser
33 {
34   public:
FindUniformAddressSpaceStructs(UniformBlockMetadata * uniformBlockMetadata)35     FindUniformAddressSpaceStructs(UniformBlockMetadata *uniformBlockMetadata)
36         : TIntermTraverser(true, false, false), mUniformBlockMetadata(uniformBlockMetadata)
37     {}
38 
39     ~FindUniformAddressSpaceStructs() override = default;
40 
visitDeclaration(Visit visit,TIntermDeclaration * node)41     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
42     {
43         const TIntermSequence &sequence = *(node->getSequence());
44 
45         TIntermTyped *variable = sequence.front()->getAsTyped();
46         const TType &type      = variable->getType();
47 
48         // TODO(anglebug.com/376553328): should eventually ASSERT that there are no default uniforms
49         // here.
50         if (type.getQualifier() == EvqUniform)
51         {
52             recordTypesUsedInUniformAddressSpace(&type);
53         }
54 
55         return true;
56     }
57 
58   private:
59     // Recurses through the tree of types referred to be `type` (which is used in the uniform
60     // address space) and fills in the `mUniformBlockMetadata` struct appropriately.
recordTypesUsedInUniformAddressSpace(const TType * type)61     void recordTypesUsedInUniformAddressSpace(const TType *type)
62     {
63         if (type->isArray())
64         {
65             TType innerType = *type;
66             innerType.toArrayBaseType();
67             recordTypesUsedInUniformAddressSpace(&innerType);
68         }
69         else if (type->getStruct() != nullptr)
70         {
71             mUniformBlockMetadata->structsInUniformAddressSpace.insert(
72                 type->getStruct()->uniqueId().get());
73             // Recurse into the types of the fields of this struct type.
74             for (TField *const field : type->getStruct()->fields())
75             {
76                 recordTypesUsedInUniformAddressSpace(field->type());
77             }
78         }
79     }
80 
81     UniformBlockMetadata *const mUniformBlockMetadata;
82 };
83 
84 }  // namespace
85 
RecordUniformBlockMetadata(TIntermBlock * root,UniformBlockMetadata & outMetadata)86 bool RecordUniformBlockMetadata(TIntermBlock *root, UniformBlockMetadata &outMetadata)
87 {
88     FindUniformAddressSpaceStructs traverser(&outMetadata);
89     root->traverse(&traverser);
90     return true;
91 }
92 
OutputUniformWrapperStructsAndConversions(TInfoSinkBase & output,const WGSLGenerationMetadataForUniforms & wgslGenerationMetadataForUniforms)93 bool OutputUniformWrapperStructsAndConversions(
94     TInfoSinkBase &output,
95     const WGSLGenerationMetadataForUniforms &wgslGenerationMetadataForUniforms)
96 {
97 
98     auto generate16AlignedWrapperStruct = [&output](const TType &type) {
99         output << "struct " << MakeUniformWrapperStructName(&type) << "\n{\n";
100         output << "  @align(16) " << kWrappedStructFieldName << " : ";
101         WriteWgslType(output, type, {});
102         output << "\n};\n";
103     };
104 
105     bool generatedVec2WrapperStruct = false;
106 
107     for (const TType &type : wgslGenerationMetadataForUniforms.arrayElementTypesInUniforms)
108     {
109         // Structs don't need wrapper structs.
110         ASSERT(type.getStruct() == nullptr);
111         // Multidimensional arrays not currently supported in uniforms
112         ASSERT(!type.isArray());
113 
114         if (type.isVector() && type.getNominalSize() == 2)
115         {
116             generatedVec2WrapperStruct = true;
117         }
118         generate16AlignedWrapperStruct(type);
119     }
120 
121     // matCx2 is represented as array<ANGLE_wrapped_vec2, C> so if there are matCx2s we need to
122     // generate an ANGLE_wrapped_vec2 struct.
123     if (!wgslGenerationMetadataForUniforms.outputMatCx2Conversion.empty() &&
124         !generatedVec2WrapperStruct)
125     {
126         generate16AlignedWrapperStruct(*new TType(TBasicType::EbtFloat, 2));
127     }
128 
129     for (const TType &type :
130          wgslGenerationMetadataForUniforms.arrayElementTypesThatNeedUnwrappingConversions)
131     {
132         // Should be a subset of the types that have had wrapper structs generated above, otherwise
133         // it's impossible to unwrap them!
134         TType innerType = type;
135         innerType.toArrayElementType();
136         ASSERT(wgslGenerationMetadataForUniforms.arrayElementTypesInUniforms.count(innerType) != 0);
137 
138         // This could take ptr<uniform, typeName>, with the unrestricted_pointer_parameters
139         // extension. This is probably fine.
140         output << "fn " << MakeUnwrappingArrayConversionFunctionName(&type) << "(wrappedArr : ";
141         WriteWgslType(output, type, {WgslAddressSpace::Uniform});
142         output << ") -> ";
143         WriteWgslType(output, type, {WgslAddressSpace::NonUniform});
144         output << "\n{\n";
145         output << "  var retVal : ";
146         WriteWgslType(output, type, {WgslAddressSpace::NonUniform});
147         output << ";\n";
148         output << "  for (var i : u32 = 0; i < " << type.getOutermostArraySize() << "; i++) {;\n";
149         output << "    retVal[i] = wrappedArr[i]." << kWrappedStructFieldName << ";\n";
150         output << "  }\n";
151         output << "  return retVal;\n";
152         output << "}\n";
153     }
154 
155     for (const TType &type : wgslGenerationMetadataForUniforms.outputMatCx2Conversion)
156     {
157         ASSERT(type.isMatrix() && type.getRows() == 2);
158         output << "fn " << MakeMatCx2ConversionFunctionName(&type) << "(mangledMatrix : ";
159 
160         WriteWgslType(output, type, {WgslAddressSpace::Uniform});
161         output << ") -> ";
162         WriteWgslType(output, type, {WgslAddressSpace::NonUniform});
163         output << "\n{\n";
164         output << "  var retVal : ";
165         WriteWgslType(output, type, {WgslAddressSpace::NonUniform});
166         output << ";\n";
167 
168         if (type.isArray())
169         {
170             output << "  for (var i : u32 = 0; i < " << type.getOutermostArraySize()
171                    << "; i++) {;\n";
172             output << "    retVal[i] = ";
173         }
174         else
175         {
176             output << "  retVal = ";
177         }
178 
179         TType baseType = type;
180         baseType.toArrayBaseType();
181         WriteWgslType(output, baseType, {WgslAddressSpace::NonUniform});
182         output << "(";
183         for (uint8_t i = 0; i < type.getCols(); i++)
184         {
185             if (i != 0)
186             {
187                 output << ", ";
188             }
189             // The mangled matrix is an array and the elements are wrapped vec2s, which can be
190             // passed directly to the matCx2 constructor.
191             output << "mangledMatrix" << (type.isArray() ? "[i]" : "") << "[" << static_cast<int>(i)
192                    << "]." << kWrappedStructFieldName;
193         }
194         output << ");\n";
195 
196         if (type.isArray())
197         {
198             // Close the for loop.
199             output << "  }\n";
200         }
201         output << "  return retVal;\n";
202         output << "}\n";
203     }
204 
205     return true;
206 }
207 
MakeUnwrappingArrayConversionFunctionName(const TType * type)208 ImmutableString MakeUnwrappingArrayConversionFunctionName(const TType *type)
209 {
210     ASSERT(type->getNumArraySizes() <= 1);
211     ImmutableString arrStr = type->isArray() ? BuildConcatenatedImmutableString(
212                                                    "Array", type->getOutermostArraySize(), "_")
213                                              : kEmptyImmutableString;
214     return BuildConcatenatedImmutableString("ANGLE_Convert_", arrStr,
215                                             MakeUniformWrapperStructName(type), "_ElementsTo_",
216                                             type->getBuiltInTypeNameString(), "_Elements");
217 }
218 
IsMatCx2(const TType * type)219 bool IsMatCx2(const TType *type)
220 {
221     return type->isMatrix() && type->getRows() == 2;
222 }
223 
MakeMatCx2ConversionFunctionName(const TType * type)224 ImmutableString MakeMatCx2ConversionFunctionName(const TType *type)
225 {
226     ASSERT(type->getNumArraySizes() <= 1);
227     ImmutableString arrStr = type->isArray() ? BuildConcatenatedImmutableString(
228                                                    "Array", type->getOutermostArraySize(), "_")
229                                              : kEmptyImmutableString;
230     return BuildConcatenatedImmutableString("ANGLE_Convert_", arrStr, "Mat", type->getCols(), "x2");
231 }
232 
OutputUniformBlocksAndSamplers(TCompiler * compiler,TIntermBlock * root)233 bool OutputUniformBlocksAndSamplers(TCompiler *compiler, TIntermBlock *root)
234 {
235     // TODO(anglebug.com/42267100): This should eventually just be handled the same way as a regular
236     // UBO, like in Vulkan which create a block out of the default uniforms with a traverser:
237     // https://source.chromium.org/chromium/chromium/src/+/main:third_party/angle/src/compiler/translator/spirv/TranslatorSPIRV.cpp;l=70;drc=451093bbaf7fe812bf67d27d760f3bb64c92830b
238     const std::vector<ShaderVariable> &basicUniforms = compiler->getUniforms();
239     TInfoSinkBase &output                            = compiler->getInfoSink().obj;
240     GlobalVars globalVars                            = FindGlobalVars(root);
241 
242     // Only output a struct at all if there are going to be members.
243     bool outputStructHeader = false;
244     for (const ShaderVariable &shaderVar : basicUniforms)
245     {
246         if (gl::IsOpaqueType(shaderVar.type) || !shaderVar.active)
247         {
248             continue;
249         }
250         if (shaderVar.isBuiltIn())
251         {
252             // gl_DepthRange and also the GLSL 4.2 gl_NumSamples are uniforms.
253             // TODO(anglebug.com/42267100): put gl_DepthRange into default uniform block.
254             continue;
255         }
256 
257         // TODO(anglebug.com/42267100): some types will NOT match std140 layout here, namely matCx2,
258         // bool, and arrays with stride less than 16.
259         // (this check does not cover the unsupported case where there is an array of structs of
260         // size < 16).
261         if (shaderVar.type == GL_BOOL)
262         {
263             return false;
264         }
265 
266         // Some uniform variables might have been deleted, for example if they were structs that
267         // only contained samplers (which are pulled into separate default uniforms).
268         auto globalVarIter = globalVars.find(shaderVar.name);
269         if (globalVarIter == globalVars.end())
270         {
271             continue;
272         }
273 
274         if (!outputStructHeader)
275         {
276             output << "struct ANGLE_DefaultUniformBlock {\n";
277             outputStructHeader = true;
278         }
279         output << "  ";
280         output << shaderVar.name << " : ";
281 
282         TIntermDeclaration *declNode = globalVarIter->second;
283         const TVariable *astVar      = &ViewDeclaration(*declNode).symbol.variable();
284         WriteWgslType(output, astVar->getType(), {WgslAddressSpace::Uniform});
285 
286         output << ",\n";
287     }
288     // TODO(anglebug.com/42267100): might need string replacement for @group(0) and @binding(0)
289     // annotations. All WGSL resources available to shaders share the same (group, binding) ID
290     // space.
291     if (outputStructHeader)
292     {
293         ASSERT(compiler->getShaderType() == GL_VERTEX_SHADER ||
294                compiler->getShaderType() == GL_FRAGMENT_SHADER);
295         const uint32_t bindingIndex = compiler->getShaderType() == GL_VERTEX_SHADER
296                                           ? kDefaultVertexUniformBlockBinding
297                                           : kDefaultFragmentUniformBlockBinding;
298         output << "};\n\n"
299                << "@group(" << kDefaultUniformBlockBindGroup << ") @binding(" << bindingIndex
300                << ") var<uniform> " << kDefaultUniformBlockVarName << " : "
301                << kDefaultUniformBlockVarType << ";\n";
302     }
303 
304     for (const auto &globalVarIter : globalVars)
305     {
306         TIntermDeclaration *declNode = globalVarIter.second;
307         ASSERT(declNode);
308 
309         const TIntermSymbol *declSymbol = &ViewDeclaration(*declNode).symbol;
310         const TType &declType           = declSymbol->getType();
311         if (!declType.isSampler())
312         {
313             continue;
314         }
315 
316         // Note that this may output ignored symbols.
317         output << kTextureSamplerBindingMarker << kAngleSamplerPrefix << declSymbol->getName()
318                << " : ";
319         WriteWgslSamplerType(output, declType, WgslSamplerTypeConfig::Sampler);
320         output << ";\n";
321 
322         output << kTextureSamplerBindingMarker << kAngleTexturePrefix << declSymbol->getName()
323                << " : ";
324         WriteWgslSamplerType(output, declType, WgslSamplerTypeConfig::Texture);
325         output << ";\n";
326     }
327 
328     return true;
329 }
330 
// Maps a GLSL sampler name to the name used for its WGSL bindings:
//  - '.' becomes '_' (samplers in structs are extracted into standalone variables), and
//  - array subscripts ("[...]") are removed, e.g. "s[1].tex[2]" -> "s_tex".
// Single pass over the input. A '[' with no closing ']' drops the rest of the name, so malformed
// input can never read past the end of the string (previously the skip loop dereferenced end()
// when the debug-only ASSERT was compiled out).
std::string WGSLGetMappedSamplerName(const std::string &originalName)
{
    std::string samplerName;
    samplerName.reserve(originalName.size());

    bool insideSubscript = false;
    for (const char c : originalName)
    {
        if (insideSubscript)
        {
            // Skip everything up to and including the first closing bracket.
            if (c == ']')
            {
                insideSubscript = false;
            }
            continue;
        }
        if (c == '[')
        {
            insideSubscript = true;
            continue;
        }
        samplerName.push_back(c == '.' ? '_' : c);
    }

    return samplerName;
}
360 
361 }  // namespace sh
362