1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
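// RewriteSampleMaskVariable.cpp: Find any references to gl_SampleMask and gl_SampleMaskIn and
// rewrite non-constant indices into them as the constant index 0 (ANGLE assumes both arrays have
// a single element). For gl_SampleMask, also append code at the end of the shader that writes a
// full mask when the number of samples is one.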
8 //
9
10 #include "compiler/translator/tree_util/RewriteSampleMaskVariable.h"
11
12 #include "common/bitset_utils.h"
13 #include "common/debug.h"
14 #include "common/utilities.h"
15 #include "compiler/translator/Compiler.h"
16 #include "compiler/translator/SymbolTable.h"
17 #include "compiler/translator/tree_util/BuiltIn.h"
18 #include "compiler/translator/tree_util/IntermNode_util.h"
19 #include "compiler/translator/tree_util/IntermTraverse.h"
20 #include "compiler/translator/tree_util/RunAtTheBeginningOfShader.h"
21 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
22
23 namespace sh
24 {
25 namespace
26 {
// Index used for every access to gl_SampleMask / gl_SampleMaskIn. ANGLE assumes at most 32
// samples, so both built-in arrays have exactly one element.
constexpr int kMaxIndexForSampleMaskVar = 0;
// All 32 mask bits set. gl_SampleMask is an int array, so the bit pattern is reinterpreted as a
// signed int explicitly rather than through an implicit out-of-range conversion (which compilers
// warn about).
constexpr int kFullSampleMask = static_cast<int>(0xFFFFFFFFu);
29
30 // Traverse the tree and collect the redeclaration and replace all non constant index references of
31 // gl_SampleMask or gl_SampleMaskIn with constant index references
32 class GLSampleMaskRelatedReferenceTraverser : public TIntermTraverser
33 {
34 public:
GLSampleMaskRelatedReferenceTraverser(const TIntermSymbol ** redeclaredSymOut,const ImmutableString & targetStr)35 GLSampleMaskRelatedReferenceTraverser(const TIntermSymbol **redeclaredSymOut,
36 const ImmutableString &targetStr)
37 : TIntermTraverser(true, false, false),
38 mRedeclaredSym(redeclaredSymOut),
39 mTargetStr(targetStr)
40 {
41 *mRedeclaredSym = nullptr;
42 }
43
visitDeclaration(Visit visit,TIntermDeclaration * node)44 bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
45 {
46 // If gl_SampleMask is redeclared, we need to collect its information
47 const TIntermSequence &sequence = *(node->getSequence());
48
49 if (sequence.size() != 1)
50 {
51 return true;
52 }
53
54 TIntermTyped *variable = sequence.front()->getAsTyped();
55 TIntermSymbol *symbol = variable->getAsSymbolNode();
56 if (symbol == nullptr || symbol->getName() != mTargetStr)
57 {
58 return true;
59 }
60
61 *mRedeclaredSym = symbol;
62
63 return true;
64 }
65
visitBinary(Visit visit,TIntermBinary * node)66 bool visitBinary(Visit visit, TIntermBinary *node) override
67 {
68 TOperator op = node->getOp();
69 if (op != EOpIndexDirect && op != EOpIndexIndirect)
70 {
71 return true;
72 }
73 TIntermSymbol *left = node->getLeft()->getAsSymbolNode();
74 if (!left)
75 {
76 return true;
77 }
78 if (left->getName() != mTargetStr)
79 {
80 return true;
81 }
82 const TConstantUnion *constIdx = node->getRight()->getConstantValue();
83 if (!constIdx)
84 {
85 if (node->getRight()->hasSideEffects())
86 {
87 insertStatementInParentBlock(node->getRight());
88 }
89
90 queueReplacementWithParent(node, node->getRight(),
91 CreateIndexNode(kMaxIndexForSampleMaskVar),
92 OriginalNode::IS_DROPPED);
93 }
94
95 return true;
96 }
97
98 private:
99 const TIntermSymbol **mRedeclaredSym;
100 const ImmutableString mTargetStr;
101 };
102
103 } // anonymous namespace
104
RewriteSampleMask(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const TIntermTyped * numSamplesUniform)105 ANGLE_NO_DISCARD bool RewriteSampleMask(TCompiler *compiler,
106 TIntermBlock *root,
107 TSymbolTable *symbolTable,
108 const TIntermTyped *numSamplesUniform)
109 {
110 const TIntermSymbol *redeclaredGLSampleMask = nullptr;
111 GLSampleMaskRelatedReferenceTraverser indexTraverser(&redeclaredGLSampleMask,
112 ImmutableString("gl_SampleMask"));
113
114 root->traverse(&indexTraverser);
115 if (!indexTraverser.updateTree(compiler, root))
116 {
117 return false;
118 }
119
120 // Retrieve gl_SampleMask variable reference
121 // Search user redeclared it first
122 const TVariable *glSampleMaskVar = nullptr;
123 if (redeclaredGLSampleMask)
124 {
125 glSampleMaskVar = &redeclaredGLSampleMask->variable();
126 }
127 else
128 {
129 // User defined not found, find in built-in table
130 glSampleMaskVar = static_cast<const TVariable *>(symbolTable->findBuiltIn(
131 ImmutableString("gl_SampleMask"), compiler->getShaderVersion()));
132 }
133 if (!glSampleMaskVar)
134 {
135 return false;
136 }
137
138 // Current ANGLE assumes that the maximum number of samples is less than or equal to
139 // VK_SAMPLE_COUNT_32_BIT. So, the size of gl_SampleMask array is always one.
140 const unsigned int arraySizeOfSampleMask = glSampleMaskVar->getType().getOutermostArraySize();
141 ASSERT(arraySizeOfSampleMask == 1);
142
143 TIntermSymbol *glSampleMaskSymbol = new TIntermSymbol(glSampleMaskVar);
144
145 // if (ANGLEUniforms.numSamples == 1)
146 // {
147 // gl_SampleMask[0] = int(0xFFFFFFFF);
148 // }
149 TIntermConstantUnion *singleSampleCount = CreateUIntNode(1);
150 TIntermBinary *equalTo =
151 new TIntermBinary(EOpEqual, numSamplesUniform->deepCopy(), singleSampleCount);
152
153 TIntermBlock *trueBlock = new TIntermBlock();
154
155 TIntermBinary *sampleMaskVar = new TIntermBinary(EOpIndexDirect, glSampleMaskSymbol->deepCopy(),
156 CreateIndexNode(kMaxIndexForSampleMaskVar));
157 TIntermConstantUnion *fullSampleMask = CreateIndexNode(kFullSampleMask);
158 TIntermBinary *assignment = new TIntermBinary(EOpAssign, sampleMaskVar, fullSampleMask);
159
160 trueBlock->appendStatement(assignment);
161
162 TIntermIfElse *multiSampleOrNot = new TIntermIfElse(equalTo, trueBlock, nullptr);
163
164 return RunAtTheEndOfShader(compiler, root, multiSampleOrNot, symbolTable);
165 }
166
RewriteSampleMaskIn(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)167 ANGLE_NO_DISCARD bool RewriteSampleMaskIn(TCompiler *compiler,
168 TIntermBlock *root,
169 TSymbolTable *symbolTable)
170 {
171 const TIntermSymbol *redeclaredGLSampleMaskIn = nullptr;
172 GLSampleMaskRelatedReferenceTraverser indexTraverser(&redeclaredGLSampleMaskIn,
173 ImmutableString("gl_SampleMaskIn"));
174
175 root->traverse(&indexTraverser);
176 if (!indexTraverser.updateTree(compiler, root))
177 {
178 return false;
179 }
180
181 // Retrieve gl_SampleMaskIn variable reference
182 const TVariable *glSampleMaskInVar = nullptr;
183 glSampleMaskInVar = static_cast<const TVariable *>(
184 symbolTable->findBuiltIn(ImmutableString("gl_SampleMaskIn"), compiler->getShaderVersion()));
185 if (!glSampleMaskInVar)
186 {
187 return false;
188 }
189
190 // Current ANGLE assumes that the maximum number of samples is less than or equal to
191 // VK_SAMPLE_COUNT_32_BIT. So, the size of gl_SampleMask array is always one.
192 const unsigned int arraySizeOfSampleMaskIn =
193 glSampleMaskInVar->getType().getOutermostArraySize();
194 ASSERT(arraySizeOfSampleMaskIn == 1);
195
196 return true;
197 }
198
199 } // namespace sh
200