//
// Copyright 2021 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// DeclarePerVertexBlocks: Declare gl_PerVertex blocks if not already declared.
//
8 
9 #include "compiler/translator/tree_ops/vulkan/DeclarePerVertexBlocks.h"
10 
11 #include "compiler/translator/Compiler.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17 #include "compiler/translator/tree_util/ReplaceVariable.h"
18 
19 namespace sh
20 {
21 namespace
22 {
// One boolean per gl_PerVertex member, indexed by the value returned from
// GetPerVertexFieldIndex(): [0] gl_Position, [1] gl_PointSize, [2] gl_ClipDistance,
// [3] gl_CullDistance.
using PerVertexMemberFlags = std::array<bool, 4>;
24 
GetPerVertexFieldIndex(const TQualifier qualifier,const ImmutableString & name)25 int GetPerVertexFieldIndex(const TQualifier qualifier, const ImmutableString &name)
26 {
27     switch (qualifier)
28     {
29         case EvqPosition:
30             ASSERT(name == "gl_Position");
31             return 0;
32         case EvqPointSize:
33             ASSERT(name == "gl_PointSize");
34             return 1;
35         case EvqClipDistance:
36             ASSERT(name == "gl_ClipDistance");
37             return 2;
38         case EvqCullDistance:
39             ASSERT(name == "gl_CullDistance");
40             return 3;
41         default:
42             return -1;
43     }
44 }
45 
46 // Traverser that:
47 //
48 // 1. Inspects global qualifier declarations and extracts whether any of the gl_PerVertex built-ins
49 //    are invariant or precise.  These declarations are then dropped.
50 // 2. Finds the array size of gl_ClipDistance and gl_CullDistance built-in, if any.
51 class InspectPerVertexBuiltInsTraverser : public TIntermTraverser
52 {
53   public:
InspectPerVertexBuiltInsTraverser(TCompiler * compiler,TSymbolTable * symbolTable,PerVertexMemberFlags * invariantFlagsOut,PerVertexMemberFlags * preciseFlagsOut,uint32_t * clipDistanceArraySizeOut,uint32_t * cullDistanceArraySizeOut)54     InspectPerVertexBuiltInsTraverser(TCompiler *compiler,
55                                       TSymbolTable *symbolTable,
56                                       PerVertexMemberFlags *invariantFlagsOut,
57                                       PerVertexMemberFlags *preciseFlagsOut,
58                                       uint32_t *clipDistanceArraySizeOut,
59                                       uint32_t *cullDistanceArraySizeOut)
60         : TIntermTraverser(true, false, false, symbolTable),
61           mInvariantFlagsOut(invariantFlagsOut),
62           mPreciseFlagsOut(preciseFlagsOut),
63           mClipDistanceArraySizeOut(clipDistanceArraySizeOut),
64           mCullDistanceArraySizeOut(cullDistanceArraySizeOut)
65     {}
66 
visitGlobalQualifierDeclaration(Visit visit,TIntermGlobalQualifierDeclaration * node)67     bool visitGlobalQualifierDeclaration(Visit visit,
68                                          TIntermGlobalQualifierDeclaration *node) override
69     {
70         TIntermSymbol *symbol = node->getSymbol();
71 
72         const int fieldIndex =
73             GetPerVertexFieldIndex(symbol->getType().getQualifier(), symbol->getName());
74         if (fieldIndex < 0)
75         {
76             return false;
77         }
78 
79         if (node->isInvariant())
80         {
81             (*mInvariantFlagsOut)[fieldIndex] = true;
82         }
83         else if (node->isPrecise())
84         {
85             (*mPreciseFlagsOut)[fieldIndex] = true;
86         }
87 
88         mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node, TIntermSequence());
89 
90         return false;
91     }
92 
visitSymbol(TIntermSymbol * symbol)93     void visitSymbol(TIntermSymbol *symbol) override
94     {
95         const TType &type = symbol->getType();
96         switch (type.getQualifier())
97         {
98             case EvqClipDistance:
99                 *mClipDistanceArraySizeOut = type.getOutermostArraySize();
100                 break;
101             case EvqCullDistance:
102                 *mCullDistanceArraySizeOut = type.getOutermostArraySize();
103                 break;
104             default:
105                 break;
106         }
107     }
108 
109   private:
110     PerVertexMemberFlags *mInvariantFlagsOut;
111     PerVertexMemberFlags *mPreciseFlagsOut;
112     uint32_t *mClipDistanceArraySizeOut;
113     uint32_t *mCullDistanceArraySizeOut;
114 };
115 
116 // Traverser that:
117 //
118 // 1. Declares the input and output gl_PerVertex types and variables if not already (based on shader
119 //    type).
120 // 2. Turns built-in references into indexes into these variables.
121 class DeclarePerVertexBlocksTraverser : public TIntermTraverser
122 {
123   public:
DeclarePerVertexBlocksTraverser(TCompiler * compiler,TSymbolTable * symbolTable,const PerVertexMemberFlags & invariantFlags,const PerVertexMemberFlags & preciseFlags,uint32_t clipDistanceArraySize,uint32_t cullDistanceArraySize)124     DeclarePerVertexBlocksTraverser(TCompiler *compiler,
125                                     TSymbolTable *symbolTable,
126                                     const PerVertexMemberFlags &invariantFlags,
127                                     const PerVertexMemberFlags &preciseFlags,
128                                     uint32_t clipDistanceArraySize,
129                                     uint32_t cullDistanceArraySize)
130         : TIntermTraverser(true, false, false, symbolTable),
131           mShaderType(compiler->getShaderType()),
132           mShaderVersion(compiler->getShaderVersion()),
133           mResources(compiler->getResources()),
134           mClipDistanceArraySize(clipDistanceArraySize),
135           mCullDistanceArraySize(cullDistanceArraySize),
136           mPerVertexInVar(nullptr),
137           mPerVertexOutVar(nullptr),
138           mPerVertexInVarRedeclared(false),
139           mPerVertexOutVarRedeclared(false),
140           mPerVertexOutInvariantFlags(invariantFlags),
141           mPerVertexOutPreciseFlags(preciseFlags)
142     {}
143 
visitSymbol(TIntermSymbol * symbol)144     void visitSymbol(TIntermSymbol *symbol) override
145     {
146         const TVariable *variable = &symbol->variable();
147         const TType *type         = &variable->getType();
148 
149         // Replace gl_out if necessary.
150         if (mShaderType == GL_TESS_CONTROL_SHADER && type->getQualifier() == EvqPerVertexOut)
151         {
152             ASSERT(variable->name() == "gl_out");
153 
154             // Declare gl_out if not already.
155             if (mPerVertexOutVar == nullptr)
156             {
157                 // Record invariant and precise qualifiers used on the fields so they would be
158                 // applied to the replacement gl_out.
159                 for (const TField *field : type->getInterfaceBlock()->fields())
160                 {
161                     const TType &fieldType = *field->type();
162                     const int fieldIndex =
163                         GetPerVertexFieldIndex(fieldType.getQualifier(), field->name());
164                     ASSERT(fieldIndex >= 0);
165 
166                     if (fieldType.isInvariant())
167                     {
168                         mPerVertexOutInvariantFlags[fieldIndex] = true;
169                     }
170                     if (fieldType.isPrecise())
171                     {
172                         mPerVertexOutPreciseFlags[fieldIndex] = true;
173                     }
174                 }
175 
176                 declareDefaultGlOut();
177             }
178 
179             if (mPerVertexOutVarRedeclared)
180             {
181                 // Traverse the parents and promote the new type.  Replace the root of
182                 // EOpIndex[In]Direct chain.
183                 queueAccessChainReplacement(new TIntermSymbol(mPerVertexOutVar));
184             }
185 
186             return;
187         }
188 
189         // Replace gl_in if necessary.
190         if ((mShaderType == GL_TESS_CONTROL_SHADER || mShaderType == GL_TESS_EVALUATION_SHADER ||
191              mShaderType == GL_GEOMETRY_SHADER) &&
192             type->getQualifier() == EvqPerVertexIn)
193         {
194             ASSERT(variable->name() == "gl_in");
195 
196             // Declare gl_in if not already.
197             if (mPerVertexInVar == nullptr)
198             {
199                 declareDefaultGlIn();
200             }
201 
202             if (mPerVertexInVarRedeclared)
203             {
204                 // Traverse the parents and promote the new type.  Replace the root of
205                 // EOpIndex[In]Direct chain.
206                 queueAccessChainReplacement(new TIntermSymbol(mPerVertexInVar));
207             }
208 
209             return;
210         }
211 
212         // Turn gl_Position, gl_PointSize, gl_ClipDistance and gl_CullDistance into references to
213         // the output gl_PerVertex.  Note that the default gl_PerVertex is declared as follows:
214         //
215         //     out gl_PerVertex
216         //     {
217         //         vec4 gl_Position;
218         //         float gl_PointSize;
219         //         float gl_ClipDistance[];
220         //         float gl_CullDistance[];
221         //     };
222         //
223 
224         if (variable->symbolType() != SymbolType::BuiltIn)
225         {
226             ASSERT(variable->name() != "gl_Position" && variable->name() != "gl_PointSize" &&
227                    variable->name() != "gl_ClipDistance" && variable->name() != "gl_CullDistance" &&
228                    variable->name() != "gl_in" && variable->name() != "gl_out");
229 
230             return;
231         }
232 
233         // If this built-in was already visited, reuse the variable defined for it.
234         auto replacement = mVariableMap.find(variable);
235         if (replacement != mVariableMap.end())
236         {
237             queueReplacement(replacement->second->deepCopy(), OriginalNode::IS_DROPPED);
238             return;
239         }
240 
241         const int fieldIndex = GetPerVertexFieldIndex(type->getQualifier(), variable->name());
242 
243         // Not the built-in we are looking for.
244         if (fieldIndex < 0)
245         {
246             return;
247         }
248 
249         // Declare the output gl_PerVertex if not already.
250         if (mPerVertexOutVar == nullptr)
251         {
252             declareDefaultGlOut();
253         }
254 
255         TType *newType = new TType(*type);
256         newType->setInterfaceBlockField(mPerVertexOutVar->getType().getInterfaceBlock(),
257                                         fieldIndex);
258 
259         TVariable *newVariable = new TVariable(mSymbolTable, variable->name(), newType,
260                                                variable->symbolType(), variable->extensions());
261 
262         TIntermSymbol *newSymbol = new TIntermSymbol(newVariable);
263         mVariableMap[variable]   = newSymbol;
264 
265         queueReplacement(newSymbol, OriginalNode::IS_DROPPED);
266     }
267 
getRedeclaredPerVertexOutVar()268     const TVariable *getRedeclaredPerVertexOutVar()
269     {
270         return mPerVertexOutVarRedeclared ? mPerVertexOutVar : nullptr;
271     }
272 
getRedeclaredPerVertexInVar()273     const TVariable *getRedeclaredPerVertexInVar()
274     {
275         return mPerVertexInVarRedeclared ? mPerVertexInVar : nullptr;
276     }
277 
278   private:
declarePerVertex(TQualifier qualifier,uint32_t arraySize,ImmutableString & variableName)279     const TVariable *declarePerVertex(TQualifier qualifier,
280                                       uint32_t arraySize,
281                                       ImmutableString &variableName)
282     {
283         TFieldList *fields = new TFieldList;
284 
285         const TType *vec4Type  = StaticType::GetBasic<EbtFloat, EbpHigh, 4>();
286         const TType *floatType = StaticType::GetBasic<EbtFloat, EbpHigh, 1>();
287 
288         TType *positionType     = new TType(*vec4Type);
289         TType *pointSizeType    = new TType(*floatType);
290         TType *clipDistanceType = new TType(*floatType);
291         TType *cullDistanceType = new TType(*floatType);
292 
293         positionType->setQualifier(EvqPosition);
294         pointSizeType->setQualifier(EvqPointSize);
295         clipDistanceType->setQualifier(EvqClipDistance);
296         cullDistanceType->setQualifier(EvqCullDistance);
297 
298         TPrecision pointSizePrecision = EbpHigh;
299         if (mShaderType == GL_VERTEX_SHADER)
300         {
301             // gl_PointSize is mediump in ES100 and highp in ES300+.
302             const TVariable *glPointSize = static_cast<const TVariable *>(
303                 mSymbolTable->findBuiltIn(ImmutableString("gl_PointSize"), mShaderVersion));
304             ASSERT(glPointSize);
305 
306             pointSizePrecision = glPointSize->getType().getPrecision();
307         }
308         pointSizeType->setPrecision(pointSizePrecision);
309 
310         // TODO: handle interaction with GS and T*S where the two can have different sizes.  These
311         // values are valid for EvqPerVertexOut only.  For EvqPerVertexIn, the size should come from
312         // the declaration of gl_in.  http://anglebug.com/5466.
313         clipDistanceType->makeArray(std::max(mClipDistanceArraySize, 1u));
314         cullDistanceType->makeArray(std::max(mCullDistanceArraySize, 1u));
315 
316         if (qualifier == EvqPerVertexOut)
317         {
318             positionType->setInvariant(mPerVertexOutInvariantFlags[0]);
319             pointSizeType->setInvariant(mPerVertexOutInvariantFlags[1]);
320             clipDistanceType->setInvariant(mPerVertexOutInvariantFlags[2]);
321             cullDistanceType->setInvariant(mPerVertexOutInvariantFlags[3]);
322 
323             positionType->setPrecise(mPerVertexOutPreciseFlags[0]);
324             pointSizeType->setPrecise(mPerVertexOutPreciseFlags[1]);
325             clipDistanceType->setPrecise(mPerVertexOutPreciseFlags[2]);
326             cullDistanceType->setPrecise(mPerVertexOutPreciseFlags[3]);
327         }
328 
329         fields->push_back(new TField(positionType, ImmutableString("gl_Position"), TSourceLoc(),
330                                      SymbolType::AngleInternal));
331         fields->push_back(new TField(pointSizeType, ImmutableString("gl_PointSize"), TSourceLoc(),
332                                      SymbolType::AngleInternal));
333         fields->push_back(new TField(clipDistanceType, ImmutableString("gl_ClipDistance"),
334                                      TSourceLoc(), SymbolType::AngleInternal));
335         fields->push_back(new TField(cullDistanceType, ImmutableString("gl_CullDistance"),
336                                      TSourceLoc(), SymbolType::AngleInternal));
337 
338         TInterfaceBlock *interfaceBlock =
339             new TInterfaceBlock(mSymbolTable, ImmutableString("gl_PerVertex"), fields,
340                                 TLayoutQualifier::Create(), SymbolType::AngleInternal);
341 
342         TType *interfaceBlockType =
343             new TType(interfaceBlock, qualifier, TLayoutQualifier::Create());
344         if (arraySize > 0)
345         {
346             interfaceBlockType->makeArray(arraySize);
347         }
348 
349         TVariable *interfaceBlockVar =
350             new TVariable(mSymbolTable, variableName, interfaceBlockType,
351                           variableName.empty() ? SymbolType::Empty : SymbolType::AngleInternal);
352 
353         return interfaceBlockVar;
354     }
355 
declareDefaultGlOut()356     void declareDefaultGlOut()
357     {
358         ASSERT(!mPerVertexOutVarRedeclared);
359 
360         // For tessellation control shaders, gl_out is an array of MaxPatchVertices
361         // For other shaders, there's no explicit name or array size
362 
363         ImmutableString varName("");
364         uint32_t arraySize = 0;
365         if (mShaderType == GL_TESS_CONTROL_SHADER)
366         {
367             varName   = ImmutableString("gl_out");
368             arraySize = mResources.MaxPatchVertices;
369         }
370 
371         mPerVertexOutVar           = declarePerVertex(EvqPerVertexOut, arraySize, varName);
372         mPerVertexOutVarRedeclared = true;
373     }
374 
declareDefaultGlIn()375     void declareDefaultGlIn()
376     {
377         ASSERT(!mPerVertexInVarRedeclared);
378 
379         // For tessellation shaders, gl_in is an array of MaxPatchVertices.
380         // For geometry shaders, gl_in is sized based on the primitive type.
381 
382         ImmutableString varName("gl_in");
383         uint32_t arraySize = mResources.MaxPatchVertices;
384         if (mShaderType == GL_GEOMETRY_SHADER)
385         {
386             arraySize =
387                 mSymbolTable->getGlInVariableWithArraySize()->getType().getOutermostArraySize();
388         }
389 
390         mPerVertexInVar           = declarePerVertex(EvqPerVertexIn, arraySize, varName);
391         mPerVertexInVarRedeclared = true;
392     }
393 
394     GLenum mShaderType;
395     int mShaderVersion;
396     const ShBuiltInResources &mResources;
397     uint32_t mClipDistanceArraySize;
398     uint32_t mCullDistanceArraySize;
399 
400     const TVariable *mPerVertexInVar;
401     const TVariable *mPerVertexOutVar;
402 
403     bool mPerVertexInVarRedeclared;
404     bool mPerVertexOutVarRedeclared;
405 
406     // A map of already replaced built-in variables.
407     VariableReplacementMap mVariableMap;
408 
409     // Whether each field is invariant or precise.
410     PerVertexMemberFlags mPerVertexOutInvariantFlags;
411     PerVertexMemberFlags mPerVertexOutPreciseFlags;
412 };
413 
AddPerVertexDecl(TIntermBlock * root,const TVariable * variable)414 void AddPerVertexDecl(TIntermBlock *root, const TVariable *variable)
415 {
416     if (variable == nullptr)
417     {
418         return;
419     }
420 
421     TIntermDeclaration *decl = new TIntermDeclaration;
422     TIntermSymbol *symbol    = new TIntermSymbol(variable);
423     decl->appendDeclarator(symbol);
424 
425     // Insert the declaration before the first function.
426     size_t firstFunctionIndex = FindFirstFunctionDefinitionIndex(root);
427     root->insertChildNodes(firstFunctionIndex, {decl});
428 }
429 }  // anonymous namespace
430 
DeclarePerVertexBlocks(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)431 bool DeclarePerVertexBlocks(TCompiler *compiler, TIntermBlock *root, TSymbolTable *symbolTable)
432 {
433     if (compiler->getShaderType() == GL_COMPUTE_SHADER ||
434         compiler->getShaderType() == GL_FRAGMENT_SHADER)
435     {
436         return true;
437     }
438 
439     // First, visit all global qualifier declarations and find which built-ins are invariant or
440     // precise.  At the same time, find out the size of gl_ClipDistance and gl_CullDistance arrays.
441     PerVertexMemberFlags invariantFlags = {};
442     PerVertexMemberFlags preciseFlags   = {};
443     uint32_t clipDistanceArraySize = 0, cullDistanceArraySize = 0;
444 
445     InspectPerVertexBuiltInsTraverser infoTraverser(compiler, symbolTable, &invariantFlags,
446                                                     &preciseFlags, &clipDistanceArraySize,
447                                                     &cullDistanceArraySize);
448     root->traverse(&infoTraverser);
449     if (!infoTraverser.updateTree(compiler, root))
450     {
451         return false;
452     }
453 
454     // If not specified, take the clip/cull distance size from the resources.
455     if (clipDistanceArraySize == 0)
456     {
457         clipDistanceArraySize = compiler->getResources().MaxClipDistances;
458     }
459     if (cullDistanceArraySize == 0)
460     {
461         cullDistanceArraySize = compiler->getResources().MaxCullDistances;
462     }
463 
464     // If #pragma STDGL invariant(all) is specified, make all outputs invariant.
465     if (compiler->getPragma().stdgl.invariantAll)
466     {
467         std::fill(invariantFlags.begin(), invariantFlags.end(), true);
468     }
469 
470     // Then declare the in and out gl_PerVertex I/O blocks.
471     DeclarePerVertexBlocksTraverser traverser(compiler, symbolTable, invariantFlags, preciseFlags,
472                                               clipDistanceArraySize, cullDistanceArraySize);
473     root->traverse(&traverser);
474     if (!traverser.updateTree(compiler, root))
475     {
476         return false;
477     }
478 
479     AddPerVertexDecl(root, traverser.getRedeclaredPerVertexOutVar());
480     AddPerVertexDecl(root, traverser.getRedeclaredPerVertexInVar());
481 
482     return compiler->validateAST(root);
483 }
484 }  // namespace sh
485