1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Scalarize vector and matrix constructor args, so that vectors built from components don't have
7 // matrix arguments, and matrices built from components don't have vector arguments. This avoids
8 // driver bugs around vector and matrix constructors.
9 //
10
11 #include "compiler/translator/tree_ops/ScalarizeVecAndMatConstructorArgs.h"
12 #include "common/debug.h"
13
14 #include <algorithm>
15
16 #include "angle_gl.h"
17 #include "common/angleutils.h"
18 #include "compiler/translator/Compiler.h"
19 #include "compiler/translator/tree_util/IntermNodePatternMatcher.h"
20 #include "compiler/translator/tree_util/IntermNode_util.h"
21 #include "compiler/translator/tree_util/IntermTraverse.h"
22
23 namespace sh
24 {
25
26 namespace
27 {
28
ConstructVectorIndexBinaryNode(TIntermSymbol * symbolNode,int index)29 TIntermBinary *ConstructVectorIndexBinaryNode(TIntermSymbol *symbolNode, int index)
30 {
31 return new TIntermBinary(EOpIndexDirect, symbolNode, CreateIndexNode(index));
32 }
33
ConstructMatrixIndexBinaryNode(TIntermSymbol * symbolNode,int colIndex,int rowIndex)34 TIntermBinary *ConstructMatrixIndexBinaryNode(TIntermSymbol *symbolNode, int colIndex, int rowIndex)
35 {
36 TIntermBinary *colVectorNode = ConstructVectorIndexBinaryNode(symbolNode, colIndex);
37
38 return new TIntermBinary(EOpIndexDirect, colVectorNode, CreateIndexNode(rowIndex));
39 }
40
41 class ScalarizeArgsTraverser : public TIntermTraverser
42 {
43 public:
ScalarizeArgsTraverser(sh::GLenum shaderType,bool fragmentPrecisionHigh,TSymbolTable * symbolTable)44 ScalarizeArgsTraverser(sh::GLenum shaderType,
45 bool fragmentPrecisionHigh,
46 TSymbolTable *symbolTable)
47 : TIntermTraverser(true, false, false, symbolTable),
48 mShaderType(shaderType),
49 mFragmentPrecisionHigh(fragmentPrecisionHigh),
50 mNodesToScalarize(IntermNodePatternMatcher::kScalarizedVecOrMatConstructor)
51 {}
52
53 protected:
54 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
55 bool visitBlock(Visit visit, TIntermBlock *node) override;
56
57 private:
58 void scalarizeArgs(TIntermAggregate *aggregate, bool scalarizeVector, bool scalarizeMatrix);
59
60 // If we have the following code:
61 // mat4 m(0);
62 // vec4 v(1, m);
63 // We will rewrite to:
64 // mat4 m(0);
65 // mat4 s0 = m;
66 // vec4 v(1, s0[0][0], s0[0][1], s0[0][2]);
67 // This function is to create nodes for "mat4 s0 = m;" and insert it to the code sequence. This
68 // way the possible side effects of the constructor argument will only be evaluated once.
69 TVariable *createTempVariable(TIntermTyped *original);
70
71 std::vector<TIntermSequence> mBlockStack;
72
73 sh::GLenum mShaderType;
74 bool mFragmentPrecisionHigh;
75
76 IntermNodePatternMatcher mNodesToScalarize;
77 };
78
visitAggregate(Visit visit,TIntermAggregate * node)79 bool ScalarizeArgsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
80 {
81 ASSERT(visit == PreVisit);
82 if (mNodesToScalarize.match(node, getParentNode()))
83 {
84 if (node->getType().isVector())
85 {
86 scalarizeArgs(node, false, true);
87 }
88 else
89 {
90 ASSERT(node->getType().isMatrix());
91 scalarizeArgs(node, true, false);
92 }
93 }
94 return true;
95 }
96
visitBlock(Visit visit,TIntermBlock * node)97 bool ScalarizeArgsTraverser::visitBlock(Visit visit, TIntermBlock *node)
98 {
99 mBlockStack.push_back(TIntermSequence());
100 {
101 for (TIntermNode *child : *node->getSequence())
102 {
103 ASSERT(child != nullptr);
104 child->traverse(this);
105 mBlockStack.back().push_back(child);
106 }
107 }
108 if (mBlockStack.back().size() > node->getSequence()->size())
109 {
110 node->getSequence()->clear();
111 *(node->getSequence()) = mBlockStack.back();
112 }
113 mBlockStack.pop_back();
114 return false;
115 }
116
scalarizeArgs(TIntermAggregate * aggregate,bool scalarizeVector,bool scalarizeMatrix)117 void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate,
118 bool scalarizeVector,
119 bool scalarizeMatrix)
120 {
121 ASSERT(aggregate);
122 ASSERT(!aggregate->isArray());
123 int size = static_cast<int>(aggregate->getType().getObjectSize());
124 TIntermSequence *sequence = aggregate->getSequence();
125 TIntermSequence originalArgs(*sequence);
126 sequence->clear();
127 for (TIntermNode *originalArgNode : originalArgs)
128 {
129 ASSERT(size > 0);
130 TIntermTyped *originalArg = originalArgNode->getAsTyped();
131 ASSERT(originalArg);
132 TVariable *argVariable = createTempVariable(originalArg);
133 if (originalArg->isScalar())
134 {
135 sequence->push_back(CreateTempSymbolNode(argVariable));
136 size--;
137 }
138 else if (originalArg->isVector())
139 {
140 if (scalarizeVector)
141 {
142 int repeat = std::min(size, originalArg->getNominalSize());
143 size -= repeat;
144 for (int index = 0; index < repeat; ++index)
145 {
146 TIntermSymbol *symbolNode = CreateTempSymbolNode(argVariable);
147 TIntermBinary *newNode = ConstructVectorIndexBinaryNode(symbolNode, index);
148 sequence->push_back(newNode);
149 }
150 }
151 else
152 {
153 TIntermSymbol *symbolNode = CreateTempSymbolNode(argVariable);
154 sequence->push_back(symbolNode);
155 size -= originalArg->getNominalSize();
156 }
157 }
158 else
159 {
160 ASSERT(originalArg->isMatrix());
161 if (scalarizeMatrix)
162 {
163 int colIndex = 0, rowIndex = 0;
164 int repeat = std::min(size, originalArg->getCols() * originalArg->getRows());
165 size -= repeat;
166 while (repeat > 0)
167 {
168 TIntermSymbol *symbolNode = CreateTempSymbolNode(argVariable);
169 TIntermBinary *newNode =
170 ConstructMatrixIndexBinaryNode(symbolNode, colIndex, rowIndex);
171 sequence->push_back(newNode);
172 rowIndex++;
173 if (rowIndex >= originalArg->getRows())
174 {
175 rowIndex = 0;
176 colIndex++;
177 }
178 repeat--;
179 }
180 }
181 else
182 {
183 TIntermSymbol *symbolNode = CreateTempSymbolNode(argVariable);
184 sequence->push_back(symbolNode);
185 size -= originalArg->getCols() * originalArg->getRows();
186 }
187 }
188 }
189 }
190
createTempVariable(TIntermTyped * original)191 TVariable *ScalarizeArgsTraverser::createTempVariable(TIntermTyped *original)
192 {
193 ASSERT(original);
194
195 TType *type = new TType(original->getType());
196 type->setQualifier(EvqTemporary);
197 if (mShaderType == GL_FRAGMENT_SHADER && type->getBasicType() == EbtFloat &&
198 type->getPrecision() == EbpUndefined)
199 {
200 // We use the highest available precision for the temporary variable
201 // to avoid computing the actual precision using the rules defined
202 // in GLSL ES 1.0 Section 4.5.2.
203 type->setPrecision(mFragmentPrecisionHigh ? EbpHigh : EbpMedium);
204 }
205
206 TVariable *variable = CreateTempVariable(mSymbolTable, type);
207
208 ASSERT(mBlockStack.size() > 0);
209 TIntermSequence &sequence = mBlockStack.back();
210 TIntermDeclaration *declaration = CreateTempInitDeclarationNode(variable, original);
211 sequence.push_back(declaration);
212
213 return variable;
214 }
215
216 } // namespace
217
ScalarizeVecAndMatConstructorArgs(TCompiler * compiler,TIntermBlock * root,sh::GLenum shaderType,bool fragmentPrecisionHigh,TSymbolTable * symbolTable)218 bool ScalarizeVecAndMatConstructorArgs(TCompiler *compiler,
219 TIntermBlock *root,
220 sh::GLenum shaderType,
221 bool fragmentPrecisionHigh,
222 TSymbolTable *symbolTable)
223 {
224 ScalarizeArgsTraverser scalarizer(shaderType, fragmentPrecisionHigh, symbolTable);
225 root->traverse(&scalarizer);
226
227 return compiler->validateAST(root);
228 }
229
230 } // namespace sh
231