1 //
2 // Copyright 2019 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
// EmulateGLDrawID uses an AST traverser to convert the gl_DrawID builtin
// to a uniform int
//
// EmulateGLBaseVertexBaseInstance uses AST traversers to convert the gl_BaseVertex
// and gl_BaseInstance builtins to uniform ints. It can optionally also rewrite
// gl_VertexID as (gl_VertexID + gl_BaseVertex).
14 //
15
16 #include "compiler/translator/tree_ops/EmulateMultiDrawShaderBuiltins.h"
17
18 #include "angle_gl.h"
19 #include "compiler/translator/StaticType.h"
20 #include "compiler/translator/Symbol.h"
21 #include "compiler/translator/SymbolTable.h"
22 #include "compiler/translator/tree_util/BuiltIn.h"
23 #include "compiler/translator/tree_util/IntermTraverse.h"
24 #include "compiler/translator/tree_util/ReplaceVariable.h"
25 #include "compiler/translator/util.h"
26
27 namespace sh
28 {
29
30 namespace
31 {
32
33 constexpr const ImmutableString kEmulatedGLDrawIDName("angle_DrawID");
34
35 class FindGLDrawIDTraverser : public TIntermTraverser
36 {
37 public:
FindGLDrawIDTraverser()38 FindGLDrawIDTraverser() : TIntermTraverser(true, false, false), mVariable(nullptr) {}
39
getGLDrawIDBuiltinVariable()40 const TVariable *getGLDrawIDBuiltinVariable() { return mVariable; }
41
42 protected:
visitSymbol(TIntermSymbol * node)43 void visitSymbol(TIntermSymbol *node) override
44 {
45 if (&node->variable() == BuiltInVariable::gl_DrawID())
46 {
47 mVariable = &node->variable();
48 }
49 }
50
51 private:
52 const TVariable *mVariable;
53 };
54
55 class AddBaseVertexToGLVertexIDTraverser : public TIntermTraverser
56 {
57 public:
AddBaseVertexToGLVertexIDTraverser()58 AddBaseVertexToGLVertexIDTraverser() : TIntermTraverser(true, false, false) {}
59
60 protected:
visitSymbol(TIntermSymbol * node)61 void visitSymbol(TIntermSymbol *node) override
62 {
63 if (&node->variable() == BuiltInVariable::gl_VertexID())
64 {
65
66 TIntermSymbol *baseVertexRef = new TIntermSymbol(BuiltInVariable::gl_BaseVertex());
67
68 TIntermBinary *addBaseVertex = new TIntermBinary(EOpAdd, node, baseVertexRef);
69 queueReplacement(addBaseVertex, OriginalNode::BECOMES_CHILD);
70 }
71 }
72 };
73
74 constexpr const ImmutableString kEmulatedGLBaseVertexName("angle_BaseVertex");
75
76 class FindGLBaseVertexTraverser : public TIntermTraverser
77 {
78 public:
FindGLBaseVertexTraverser()79 FindGLBaseVertexTraverser() : TIntermTraverser(true, false, false), mVariable(nullptr) {}
80
getGLBaseVertexBuiltinVariable()81 const TVariable *getGLBaseVertexBuiltinVariable() { return mVariable; }
82
83 protected:
visitSymbol(TIntermSymbol * node)84 void visitSymbol(TIntermSymbol *node) override
85 {
86 if (&node->variable() == BuiltInVariable::gl_BaseVertex())
87 {
88 mVariable = &node->variable();
89 }
90 }
91
92 private:
93 const TVariable *mVariable;
94 };
95
96 constexpr const ImmutableString kEmulatedGLBaseInstanceName("angle_BaseInstance");
97
98 class FindGLBaseInstanceTraverser : public TIntermTraverser
99 {
100 public:
FindGLBaseInstanceTraverser()101 FindGLBaseInstanceTraverser() : TIntermTraverser(true, false, false), mVariable(nullptr) {}
102
getGLBaseInstanceBuiltinVariable()103 const TVariable *getGLBaseInstanceBuiltinVariable() { return mVariable; }
104
105 protected:
visitSymbol(TIntermSymbol * node)106 void visitSymbol(TIntermSymbol *node) override
107 {
108 if (&node->variable() == BuiltInVariable::gl_BaseInstance())
109 {
110 mVariable = &node->variable();
111 }
112 }
113
114 private:
115 const TVariable *mVariable;
116 };
117
118 } // namespace
119
EmulateGLDrawID(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,std::vector<sh::ShaderVariable> * uniforms,bool shouldCollect)120 bool EmulateGLDrawID(TCompiler *compiler,
121 TIntermBlock *root,
122 TSymbolTable *symbolTable,
123 std::vector<sh::ShaderVariable> *uniforms,
124 bool shouldCollect)
125 {
126 FindGLDrawIDTraverser traverser;
127 root->traverse(&traverser);
128 const TVariable *builtInVariable = traverser.getGLDrawIDBuiltinVariable();
129 if (builtInVariable)
130 {
131 const TType *type = StaticType::Get<EbtInt, EbpHigh, EvqUniform, 1, 1>();
132 const TVariable *drawID =
133 new TVariable(symbolTable, kEmulatedGLDrawIDName, type, SymbolType::AngleInternal);
134 const TIntermSymbol *drawIDSymbol = new TIntermSymbol(drawID);
135
136 // AngleInternal variables don't get collected
137 if (shouldCollect)
138 {
139 ShaderVariable uniform;
140 uniform.name = kEmulatedGLDrawIDName.data();
141 uniform.mappedName = kEmulatedGLDrawIDName.data();
142 uniform.type = GLVariableType(*type);
143 uniform.precision = GLVariablePrecision(*type);
144 uniform.staticUse = symbolTable->isStaticallyUsed(*builtInVariable);
145 uniform.active = true;
146 uniform.binding = type->getLayoutQualifier().binding;
147 uniform.location = type->getLayoutQualifier().location;
148 uniform.offset = type->getLayoutQualifier().offset;
149 uniform.readonly = type->getMemoryQualifier().readonly;
150 uniform.writeonly = type->getMemoryQualifier().writeonly;
151 uniforms->push_back(uniform);
152 }
153
154 DeclareGlobalVariable(root, drawID);
155 if (!ReplaceVariableWithTyped(compiler, root, builtInVariable, drawIDSymbol))
156 {
157 return false;
158 }
159 }
160
161 return true;
162 }
163
EmulateGLBaseVertexBaseInstance(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,std::vector<sh::ShaderVariable> * uniforms,bool shouldCollect,bool addBaseVertexToVertexID)164 bool EmulateGLBaseVertexBaseInstance(TCompiler *compiler,
165 TIntermBlock *root,
166 TSymbolTable *symbolTable,
167 std::vector<sh::ShaderVariable> *uniforms,
168 bool shouldCollect,
169 bool addBaseVertexToVertexID)
170 {
171 bool addBaseVertex = false, addBaseInstance = false;
172 ShaderVariable uniformBaseVertex, uniformBaseInstance;
173
174 if (addBaseVertexToVertexID)
175 {
176 // This is a workaround for Mac AMD GPU
177 // Replace gl_VertexID with (gl_VertexID + gl_BaseVertex)
178 AddBaseVertexToGLVertexIDTraverser traverserVertexID;
179 root->traverse(&traverserVertexID);
180 if (!traverserVertexID.updateTree(compiler, root))
181 {
182 return false;
183 }
184 }
185
186 FindGLBaseVertexTraverser traverserBaseVertex;
187 root->traverse(&traverserBaseVertex);
188 const TVariable *builtInVariableBaseVertex =
189 traverserBaseVertex.getGLBaseVertexBuiltinVariable();
190
191 if (builtInVariableBaseVertex)
192 {
193 const TVariable *baseVertex = BuiltInVariable::angle_BaseVertex();
194 const TType &type = baseVertex->getType();
195 const TIntermSymbol *baseVertexSymbol = new TIntermSymbol(baseVertex);
196
197 // AngleInternal variables don't get collected
198 if (shouldCollect)
199 {
200 uniformBaseVertex.name = kEmulatedGLBaseVertexName.data();
201 uniformBaseVertex.mappedName = kEmulatedGLBaseVertexName.data();
202 uniformBaseVertex.type = GLVariableType(type);
203 uniformBaseVertex.precision = GLVariablePrecision(type);
204 uniformBaseVertex.staticUse = symbolTable->isStaticallyUsed(*builtInVariableBaseVertex);
205 uniformBaseVertex.active = true;
206 uniformBaseVertex.binding = type.getLayoutQualifier().binding;
207 uniformBaseVertex.location = type.getLayoutQualifier().location;
208 uniformBaseVertex.offset = type.getLayoutQualifier().offset;
209 uniformBaseVertex.readonly = type.getMemoryQualifier().readonly;
210 uniformBaseVertex.writeonly = type.getMemoryQualifier().writeonly;
211 addBaseVertex = true;
212 }
213
214 DeclareGlobalVariable(root, baseVertex);
215 if (!ReplaceVariableWithTyped(compiler, root, builtInVariableBaseVertex, baseVertexSymbol))
216 {
217 return false;
218 }
219 }
220
221 FindGLBaseInstanceTraverser traverserInstance;
222 root->traverse(&traverserInstance);
223 const TVariable *builtInVariableBaseInstance =
224 traverserInstance.getGLBaseInstanceBuiltinVariable();
225
226 if (builtInVariableBaseInstance)
227 {
228 const TVariable *baseInstance = BuiltInVariable::angle_BaseInstance();
229 const TType &type = baseInstance->getType();
230 const TIntermSymbol *baseInstanceSymbol = new TIntermSymbol(baseInstance);
231
232 // AngleInternal variables don't get collected
233 if (shouldCollect)
234 {
235 uniformBaseInstance.name = kEmulatedGLBaseInstanceName.data();
236 uniformBaseInstance.mappedName = kEmulatedGLBaseInstanceName.data();
237 uniformBaseInstance.type = GLVariableType(type);
238 uniformBaseInstance.precision = GLVariablePrecision(type);
239 uniformBaseInstance.staticUse =
240 symbolTable->isStaticallyUsed(*builtInVariableBaseInstance);
241 uniformBaseInstance.active = true;
242 uniformBaseInstance.binding = type.getLayoutQualifier().binding;
243 uniformBaseInstance.location = type.getLayoutQualifier().location;
244 uniformBaseInstance.offset = type.getLayoutQualifier().offset;
245 uniformBaseInstance.readonly = type.getMemoryQualifier().readonly;
246 uniformBaseInstance.writeonly = type.getMemoryQualifier().writeonly;
247 addBaseInstance = true;
248 }
249
250 DeclareGlobalVariable(root, baseInstance);
251 if (!ReplaceVariableWithTyped(compiler, root, builtInVariableBaseInstance,
252 baseInstanceSymbol))
253 {
254 return false;
255 }
256 }
257
258 // Make sure the order in uniforms is the same as the traverse order
259 if (addBaseInstance)
260 {
261 uniforms->push_back(uniformBaseInstance);
262 }
263 if (addBaseVertex)
264 {
265 uniforms->push_back(uniformBaseVertex);
266 }
267
268 return true;
269 }
270
271 } // namespace sh
272