1 //
2 // Copyright 2017 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // InitOutputVariables_test.cpp: Tests correctness of the AST pass enabled through
7 // SH_INIT_OUTPUT_VARIABLES.
8 //
9
10 #include "common/angleutils.h"
11
12 #include "compiler/translator/SymbolTable.h"
13 #include "compiler/translator/tree_util/FindMain.h"
14 #include "compiler/translator/tree_util/IntermNode_util.h"
15 #include "compiler/translator/tree_util/IntermTraverse.h"
16 #include "tests/test_utils/ShaderCompileTreeTest.h"
17
18 #include <algorithm>
19
20 namespace sh
21 {
22
23 namespace
24 {
25
26 typedef std::vector<TIntermTyped *> ExpectedLValues;
27
AreSymbolsTheSame(const TIntermSymbol * expected,const TIntermSymbol * candidate)28 bool AreSymbolsTheSame(const TIntermSymbol *expected, const TIntermSymbol *candidate)
29 {
30 if (expected == nullptr || candidate == nullptr)
31 {
32 return false;
33 }
34 const TType &expectedType = expected->getType();
35 const TType &candidateType = candidate->getType();
36 const bool sameTypes = expectedType == candidateType &&
37 expectedType.getPrecision() == candidateType.getPrecision() &&
38 expectedType.getQualifier() == candidateType.getQualifier();
39 const bool sameSymbols = (expected->variable().symbolType() == SymbolType::Empty &&
40 candidate->variable().symbolType() == SymbolType::Empty) ||
41 expected->getName() == candidate->getName();
42 return sameSymbols && sameTypes;
43 }
44
AreLValuesTheSame(TIntermTyped * expected,TIntermTyped * candidate)45 bool AreLValuesTheSame(TIntermTyped *expected, TIntermTyped *candidate)
46 {
47 const TIntermBinary *expectedBinary = expected->getAsBinaryNode();
48 if (expectedBinary)
49 {
50 ASSERT(expectedBinary->getOp() == EOpIndexDirect);
51 const TIntermBinary *candidateBinary = candidate->getAsBinaryNode();
52 if (candidateBinary == nullptr || candidateBinary->getOp() != EOpIndexDirect)
53 {
54 return false;
55 }
56 if (expectedBinary->getRight()->getAsConstantUnion()->getIConst(0) !=
57 candidateBinary->getRight()->getAsConstantUnion()->getIConst(0))
58 {
59 return false;
60 }
61 return AreSymbolsTheSame(expectedBinary->getLeft()->getAsSymbolNode(),
62 candidateBinary->getLeft()->getAsSymbolNode());
63 }
64 return AreSymbolsTheSame(expected->getAsSymbolNode(), candidate->getAsSymbolNode());
65 }
66
CreateLValueNode(const ImmutableString & lValueName,const TType & type)67 TIntermTyped *CreateLValueNode(const ImmutableString &lValueName, const TType &type)
68 {
69 // We're using a mock symbol table here, don't need to assign proper symbol ids to these nodes.
70 TSymbolTable symbolTable;
71 TVariable *variable =
72 new TVariable(&symbolTable, lValueName, new TType(type), SymbolType::UserDefined);
73 return new TIntermSymbol(variable);
74 }
75
CreateIndexedLValueNodeList(const ImmutableString & lValueName,const TType & elementType,unsigned arraySize)76 ExpectedLValues CreateIndexedLValueNodeList(const ImmutableString &lValueName,
77 const TType &elementType,
78 unsigned arraySize)
79 {
80 ASSERT(elementType.isArray() == false);
81 TType *arrayType = new TType(elementType);
82 arrayType->makeArray(arraySize);
83
84 // We're using a mock symbol table here, don't need to assign proper symbol ids to these nodes.
85 TSymbolTable symbolTable;
86 TVariable *variable =
87 new TVariable(&symbolTable, lValueName, arrayType, SymbolType::UserDefined);
88 TIntermSymbol *arraySymbol = new TIntermSymbol(variable);
89
90 ExpectedLValues expected(arraySize);
91 for (unsigned index = 0u; index < arraySize; ++index)
92 {
93 expected[index] = new TIntermBinary(EOpIndexDirect, arraySymbol->deepCopy(),
94 CreateIndexNode(static_cast<int>(index)));
95 }
96 return expected;
97 }
98
99 // VerifyOutputVariableInitializers traverses the subtree covering main and collects the lvalues in
100 // assignments for which the rvalue is an expression containing only zero constants.
101 class VerifyOutputVariableInitializers final : public TIntermTraverser
102 {
103 public:
VerifyOutputVariableInitializers(TIntermBlock * root)104 VerifyOutputVariableInitializers(TIntermBlock *root) : TIntermTraverser(true, false, false)
105 {
106 ASSERT(root != nullptr);
107
108 // The traversal starts in the body of main because this is where the varyings and output
109 // variables are initialized.
110 sh::TIntermFunctionDefinition *main = FindMain(root);
111 ASSERT(main != nullptr);
112 main->traverse(this);
113 }
114
visitBinary(Visit visit,TIntermBinary * node)115 bool visitBinary(Visit visit, TIntermBinary *node) override
116 {
117 if (node->getOp() == EOpAssign && IsZero(node->getRight()))
118 {
119 mCandidateLValues.push_back(node->getLeft());
120 return false;
121 }
122 return true;
123 }
124
125 // The collected lvalues are considered valid if every expected lvalue in expectedLValues is
126 // matched by name and type with any lvalue in mCandidateLValues.
areAllExpectedLValuesFound(const ExpectedLValues & expectedLValues) const127 bool areAllExpectedLValuesFound(const ExpectedLValues &expectedLValues) const
128 {
129 for (size_t i = 0u; i < expectedLValues.size(); ++i)
130 {
131 if (!isExpectedLValueFound(expectedLValues[i]))
132 {
133 return false;
134 }
135 }
136 return true;
137 }
138
isExpectedLValueFound(TIntermTyped * expectedLValue) const139 bool isExpectedLValueFound(TIntermTyped *expectedLValue) const
140 {
141 bool isFound = false;
142 for (size_t j = 0; j < mCandidateLValues.size() && !isFound; ++j)
143 {
144 isFound = AreLValuesTheSame(expectedLValue, mCandidateLValues[j]);
145 }
146 return isFound;
147 }
148
getCandidates() const149 const ExpectedLValues &getCandidates() const { return mCandidateLValues; }
150
151 private:
152 ExpectedLValues mCandidateLValues;
153 };
154
155 // Traverses the AST and records a pointer to a structure with a given name.
156 class FindStructByName final : public TIntermTraverser
157 {
158 public:
FindStructByName(const ImmutableString & structName)159 FindStructByName(const ImmutableString &structName)
160 : TIntermTraverser(true, false, false), mStructName(structName), mStructure(nullptr)
161 {}
162
visitSymbol(TIntermSymbol * symbol)163 void visitSymbol(TIntermSymbol *symbol) override
164 {
165 if (isStructureFound())
166 {
167 return;
168 }
169
170 const TStructure *structure = symbol->getType().getStruct();
171
172 if (structure != nullptr && structure->symbolType() != SymbolType::Empty &&
173 structure->name() == mStructName)
174 {
175 mStructure = structure;
176 }
177 }
178
isStructureFound() const179 bool isStructureFound() const { return mStructure != nullptr; }
getStructure() const180 const TStructure *getStructure() const { return mStructure; }
181
182 private:
183 ImmutableString mStructName;
184 const TStructure *mStructure;
185 };
186
187 } // namespace
188
189 class InitOutputVariablesWebGL2Test : public ShaderCompileTreeTest
190 {
191 public:
SetUp()192 void SetUp() override
193 {
194 mExtraCompileOptions |= SH_VARIABLES;
195 mExtraCompileOptions |= SH_INIT_OUTPUT_VARIABLES;
196 if (getShaderType() == GL_VERTEX_SHADER)
197 {
198 mExtraCompileOptions |= SH_INIT_GL_POSITION;
199 }
200 ShaderCompileTreeTest::SetUp();
201 }
202
203 protected:
getShaderSpec() const204 ShShaderSpec getShaderSpec() const override { return SH_WEBGL2_SPEC; }
205 };
206
207 class InitOutputVariablesWebGL2VertexShaderTest : public InitOutputVariablesWebGL2Test
208 {
209 protected:
getShaderType() const210 ::GLenum getShaderType() const override { return GL_VERTEX_SHADER; }
211 };
212
213 class InitOutputVariablesWebGL2FragmentShaderTest : public InitOutputVariablesWebGL2Test
214 {
215 protected:
getShaderType() const216 ::GLenum getShaderType() const override { return GL_FRAGMENT_SHADER; }
initResources(ShBuiltInResources * resources)217 void initResources(ShBuiltInResources *resources) override
218 {
219 resources->EXT_draw_buffers = 1;
220 resources->MaxDrawBuffers = 2;
221 }
222 };
223
224 class InitOutputVariablesWebGL1FragmentShaderTest : public ShaderCompileTreeTest
225 {
226 public:
InitOutputVariablesWebGL1FragmentShaderTest()227 InitOutputVariablesWebGL1FragmentShaderTest()
228 {
229 mExtraCompileOptions |= SH_VARIABLES;
230 mExtraCompileOptions |= SH_INIT_OUTPUT_VARIABLES;
231 }
232
233 protected:
getShaderType() const234 ::GLenum getShaderType() const override { return GL_FRAGMENT_SHADER; }
getShaderSpec() const235 ShShaderSpec getShaderSpec() const override { return SH_WEBGL_SPEC; }
initResources(ShBuiltInResources * resources)236 void initResources(ShBuiltInResources *resources) override
237 {
238 resources->EXT_draw_buffers = 1;
239 resources->MaxDrawBuffers = 2;
240 }
241 };
242
243 class InitOutputVariablesVertexShaderClipDistanceTest : public ShaderCompileTreeTest
244 {
245 public:
InitOutputVariablesVertexShaderClipDistanceTest()246 InitOutputVariablesVertexShaderClipDistanceTest()
247 {
248 mExtraCompileOptions |= SH_VARIABLES;
249 mExtraCompileOptions |= SH_INIT_OUTPUT_VARIABLES;
250 mExtraCompileOptions |= SH_VALIDATE_AST;
251 }
252
253 protected:
getShaderType() const254 ::GLenum getShaderType() const override { return GL_VERTEX_SHADER; }
getShaderSpec() const255 ShShaderSpec getShaderSpec() const override { return SH_GLES2_SPEC; }
initResources(ShBuiltInResources * resources)256 void initResources(ShBuiltInResources *resources) override
257 {
258 resources->APPLE_clip_distance = 1;
259 resources->MaxClipDistances = 8;
260 }
261 };
262
263 // Test the initialization of output variables with various qualifiers in a vertex shader.
TEST_F(InitOutputVariablesWebGL2VertexShaderTest,OutputAllQualifiers)264 TEST_F(InitOutputVariablesWebGL2VertexShaderTest, OutputAllQualifiers)
265 {
266 const std::string &shaderString =
267 "#version 300 es\n"
268 "precision mediump float;\n"
269 "precision lowp int;\n"
270 "out vec4 out1;\n"
271 "flat out int out2;\n"
272 "centroid out float out3;\n"
273 "smooth out float out4;\n"
274 "void main() {\n"
275 "}\n";
276 compileAssumeSuccess(shaderString);
277 VerifyOutputVariableInitializers verifier(mASTRoot);
278
279 ExpectedLValues expectedLValues = {
280 CreateLValueNode(ImmutableString("out1"), TType(EbtFloat, EbpMedium, EvqVertexOut, 4)),
281 CreateLValueNode(ImmutableString("out2"), TType(EbtInt, EbpLow, EvqFlatOut)),
282 CreateLValueNode(ImmutableString("out3"), TType(EbtFloat, EbpMedium, EvqCentroidOut)),
283 CreateLValueNode(ImmutableString("out4"), TType(EbtFloat, EbpMedium, EvqSmoothOut))};
284 EXPECT_TRUE(verifier.areAllExpectedLValuesFound(expectedLValues));
285 }
286
287 // Test the initialization of an output array in a vertex shader.
TEST_F(InitOutputVariablesWebGL2VertexShaderTest,OutputArray)288 TEST_F(InitOutputVariablesWebGL2VertexShaderTest, OutputArray)
289 {
290 const std::string &shaderString =
291 "#version 300 es\n"
292 "precision mediump float;\n"
293 "out float out1[2];\n"
294 "void main() {\n"
295 "}\n";
296 compileAssumeSuccess(shaderString);
297 VerifyOutputVariableInitializers verifier(mASTRoot);
298
299 ExpectedLValues expectedLValues = CreateIndexedLValueNodeList(
300 ImmutableString("out1"), TType(EbtFloat, EbpMedium, EvqVertexOut), 2);
301 EXPECT_TRUE(verifier.areAllExpectedLValuesFound(expectedLValues));
302 }
303
304 // Test the initialization of a struct output variable in a vertex shader.
TEST_F(InitOutputVariablesWebGL2VertexShaderTest,OutputStruct)305 TEST_F(InitOutputVariablesWebGL2VertexShaderTest, OutputStruct)
306 {
307 const std::string &shaderString =
308 "#version 300 es\n"
309 "precision mediump float;\n"
310 "struct MyS{\n"
311 " float a;\n"
312 " float b;\n"
313 "};\n"
314 "out MyS out1;\n"
315 "void main() {\n"
316 "}\n";
317 compileAssumeSuccess(shaderString);
318 VerifyOutputVariableInitializers verifier(mASTRoot);
319
320 FindStructByName findStruct(ImmutableString("MyS"));
321 mASTRoot->traverse(&findStruct);
322 ASSERT(findStruct.isStructureFound());
323
324 TType type(findStruct.getStructure(), false);
325 type.setQualifier(EvqVertexOut);
326
327 TIntermTyped *expectedLValue = CreateLValueNode(ImmutableString("out1"), type);
328 EXPECT_TRUE(verifier.isExpectedLValueFound(expectedLValue));
329 delete expectedLValue;
330 }
331
332 // Test the initialization of a varying variable in an ESSL1 vertex shader.
TEST_F(InitOutputVariablesWebGL2VertexShaderTest,OutputFromESSL1Shader)333 TEST_F(InitOutputVariablesWebGL2VertexShaderTest, OutputFromESSL1Shader)
334 {
335 const std::string &shaderString =
336 "precision mediump float;\n"
337 "varying vec4 out1;\n"
338 "void main() {\n"
339 "}\n";
340 compileAssumeSuccess(shaderString);
341 VerifyOutputVariableInitializers verifier(mASTRoot);
342
343 TIntermTyped *expectedLValue =
344 CreateLValueNode(ImmutableString("out1"), TType(EbtFloat, EbpMedium, EvqVaryingOut, 4));
345 EXPECT_TRUE(verifier.isExpectedLValueFound(expectedLValue));
346 delete expectedLValue;
347 }
348
349 // Test the initialization of output variables in a fragment shader.
TEST_F(InitOutputVariablesWebGL2FragmentShaderTest,Output)350 TEST_F(InitOutputVariablesWebGL2FragmentShaderTest, Output)
351 {
352 const std::string &shaderString =
353 "#version 300 es\n"
354 "precision mediump float;\n"
355 "out vec4 out1;\n"
356 "void main() {\n"
357 "}\n";
358 compileAssumeSuccess(shaderString);
359 VerifyOutputVariableInitializers verifier(mASTRoot);
360
361 TIntermTyped *expectedLValue =
362 CreateLValueNode(ImmutableString("out1"), TType(EbtFloat, EbpMedium, EvqFragmentOut, 4));
363 EXPECT_TRUE(verifier.isExpectedLValueFound(expectedLValue));
364 delete expectedLValue;
365 }
366
367 // Test the initialization of gl_FragData in a WebGL2 ESSL1 fragment shader. Only writes to
368 // gl_FragData[0] should be found.
TEST_F(InitOutputVariablesWebGL2FragmentShaderTest,FragData)369 TEST_F(InitOutputVariablesWebGL2FragmentShaderTest, FragData)
370 {
371 const std::string &shaderString =
372 "precision mediump float;\n"
373 "void main() {\n"
374 " gl_FragData[0] = vec4(1.);\n"
375 "}\n";
376 compileAssumeSuccess(shaderString);
377 VerifyOutputVariableInitializers verifier(mASTRoot);
378
379 ExpectedLValues expectedLValues = CreateIndexedLValueNodeList(
380 ImmutableString("gl_FragData"), TType(EbtFloat, EbpMedium, EvqFragData, 4), 1);
381 EXPECT_TRUE(verifier.isExpectedLValueFound(expectedLValues[0]));
382 EXPECT_EQ(1u, verifier.getCandidates().size());
383 }
384
385 // Test the initialization of gl_FragData in a WebGL1 ESSL1 fragment shader. Only writes to
386 // gl_FragData[0] should be found.
TEST_F(InitOutputVariablesWebGL1FragmentShaderTest,FragData)387 TEST_F(InitOutputVariablesWebGL1FragmentShaderTest, FragData)
388 {
389 const std::string &shaderString =
390 "precision mediump float;\n"
391 "void main() {\n"
392 " gl_FragData[0] = vec4(1.);\n"
393 "}\n";
394 compileAssumeSuccess(shaderString);
395 VerifyOutputVariableInitializers verifier(mASTRoot);
396
397 // In the symbol table, gl_FragData array has 2 elements. However, only the 1st one should be
398 // initialized.
399 ExpectedLValues expectedLValues = CreateIndexedLValueNodeList(
400 ImmutableString("gl_FragData"), TType(EbtFloat, EbpMedium, EvqFragData, 4), 2);
401 EXPECT_TRUE(verifier.isExpectedLValueFound(expectedLValues[0]));
402 EXPECT_EQ(1u, verifier.getCandidates().size());
403 }
404
405 // Test the initialization of gl_FragData in a WebGL1 ESSL1 fragment shader with GL_EXT_draw_buffers
406 // enabled. All attachment slots should be initialized.
TEST_F(InitOutputVariablesWebGL1FragmentShaderTest,FragDataWithDrawBuffersExtEnabled)407 TEST_F(InitOutputVariablesWebGL1FragmentShaderTest, FragDataWithDrawBuffersExtEnabled)
408 {
409 const std::string &shaderString =
410 "#extension GL_EXT_draw_buffers : enable\n"
411 "precision mediump float;\n"
412 "void main() {\n"
413 " gl_FragData[0] = vec4(1.);\n"
414 "}\n";
415 compileAssumeSuccess(shaderString);
416 VerifyOutputVariableInitializers verifier(mASTRoot);
417
418 ExpectedLValues expectedLValues = CreateIndexedLValueNodeList(
419 ImmutableString("gl_FragData"), TType(EbtFloat, EbpMedium, EvqFragData, 4), 2);
420 EXPECT_TRUE(verifier.isExpectedLValueFound(expectedLValues[0]));
421 EXPECT_TRUE(verifier.isExpectedLValueFound(expectedLValues[1]));
422 EXPECT_EQ(2u, verifier.getCandidates().size());
423 }
424
425 // Test that gl_Position is initialized once in case it is not statically used and both
426 // SH_INIT_OUTPUT_VARIABLES and SH_INIT_GL_POSITION flags are set.
TEST_F(InitOutputVariablesWebGL2VertexShaderTest,InitGLPositionWhenNotStaticallyUsed)427 TEST_F(InitOutputVariablesWebGL2VertexShaderTest, InitGLPositionWhenNotStaticallyUsed)
428 {
429 const std::string &shaderString =
430 "#version 300 es\n"
431 "precision highp float;\n"
432 "void main() {\n"
433 "}\n";
434 compileAssumeSuccess(shaderString);
435 VerifyOutputVariableInitializers verifier(mASTRoot);
436
437 TIntermTyped *glPosition =
438 CreateLValueNode(ImmutableString("gl_Position"), TType(EbtFloat, EbpHigh, EvqPosition, 4));
439 EXPECT_TRUE(verifier.isExpectedLValueFound(glPosition));
440 EXPECT_EQ(1u, verifier.getCandidates().size());
441 }
442
443 // Test that gl_Position is initialized once in case it is statically used and both
444 // SH_INIT_OUTPUT_VARIABLES and SH_INIT_GL_POSITION flags are set.
TEST_F(InitOutputVariablesWebGL2VertexShaderTest,InitGLPositionOnceWhenStaticallyUsed)445 TEST_F(InitOutputVariablesWebGL2VertexShaderTest, InitGLPositionOnceWhenStaticallyUsed)
446 {
447 const std::string &shaderString =
448 "#version 300 es\n"
449 "precision highp float;\n"
450 "void main() {\n"
451 " gl_Position = vec4(1.0);\n"
452 "}\n";
453 compileAssumeSuccess(shaderString);
454 VerifyOutputVariableInitializers verifier(mASTRoot);
455
456 TIntermTyped *glPosition =
457 CreateLValueNode(ImmutableString("gl_Position"), TType(EbtFloat, EbpHigh, EvqPosition, 4));
458 EXPECT_TRUE(verifier.isExpectedLValueFound(glPosition));
459 EXPECT_EQ(1u, verifier.getCandidates().size());
460 }
461
// Mirrors ClipDistanceTest.ThreeClipDistancesRedeclared
// The shader redeclares gl_ClipDistance with 3 elements and writes it from a helper function
// rather than from main. The test only requires compilation to succeed. The fixture compiles with
// SH_INIT_OUTPUT_VARIABLES and SH_VALIDATE_AST, so this also exercises the output-initialization
// pass on the redeclared array under AST validation.
TEST_F(InitOutputVariablesVertexShaderClipDistanceTest, RedeclareClipDistance)
{
    constexpr char shaderString[] = R"(
#extension GL_APPLE_clip_distance : require

varying highp float gl_ClipDistance[3];

void computeClipDistances(in vec4 position, in vec4 plane[3])
{
gl_ClipDistance[0] = dot(position, plane[0]);
gl_ClipDistance[1] = dot(position, plane[1]);
gl_ClipDistance[2] = dot(position, plane[2]);
}

uniform vec4 u_plane[3];

attribute vec2 a_position;

void main()
{
gl_Position = vec4(a_position, 0.0, 1.0);

computeClipDistances(gl_Position, u_plane);
})";

    compileAssumeSuccess(shaderString);
}
490 } // namespace sh
491