• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Scalarize vector and matrix constructor args, so that vectors built from components don't have
7 // matrix arguments, and matrices built from components don't have vector arguments. This avoids
8 // driver bugs around vector and matrix constructors.
9 //
10 
11 #include "compiler/translator/tree_ops/ScalarizeVecAndMatConstructorArgs.h"
12 #include "common/debug.h"
13 
14 #include <algorithm>
15 
16 #include "angle_gl.h"
17 #include "common/angleutils.h"
18 #include "compiler/translator/Compiler.h"
19 #include "compiler/translator/tree_util/IntermNodePatternMatcher.h"
20 #include "compiler/translator/tree_util/IntermNode_util.h"
21 #include "compiler/translator/tree_util/IntermTraverse.h"
22 #include "compiler/translator/util.h"
23 
24 namespace sh
25 {
26 
27 namespace
28 {
29 
ConstructVectorIndexBinaryNode(TIntermTyped * symbolNode,int index)30 TIntermBinary *ConstructVectorIndexBinaryNode(TIntermTyped *symbolNode, int index)
31 {
32     return new TIntermBinary(EOpIndexDirect, symbolNode, CreateIndexNode(index));
33 }
34 
ConstructMatrixIndexBinaryNode(TIntermTyped * symbolNode,int colIndex,int rowIndex)35 TIntermBinary *ConstructMatrixIndexBinaryNode(TIntermTyped *symbolNode, int colIndex, int rowIndex)
36 {
37     TIntermBinary *colVectorNode = ConstructVectorIndexBinaryNode(symbolNode, colIndex);
38 
39     return new TIntermBinary(EOpIndexDirect, colVectorNode, CreateIndexNode(rowIndex));
40 }
41 
42 class ScalarizeArgsTraverser : public TIntermTraverser
43 {
44   public:
ScalarizeArgsTraverser(TSymbolTable * symbolTable)45     ScalarizeArgsTraverser(TSymbolTable *symbolTable)
46         : TIntermTraverser(true, false, false, symbolTable),
47           mNodesToScalarize(IntermNodePatternMatcher::kScalarizedVecOrMatConstructor)
48     {}
49 
50   protected:
51     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
52     bool visitBlock(Visit visit, TIntermBlock *node) override;
53 
54   private:
55     void scalarizeArgs(TIntermAggregate *aggregate, bool scalarizeVector, bool scalarizeMatrix);
56 
57     // If we have the following code:
58     //   mat4 m(0);
59     //   vec4 v(1, m);
60     // We will rewrite to:
61     //   mat4 m(0);
62     //   mat4 s0 = m;
63     //   vec4 v(1, s0[0][0], s0[0][1], s0[0][2]);
64     // This function is to create nodes for "mat4 s0 = m;" and insert it to the code sequence. This
65     // way the possible side effects of the constructor argument will only be evaluated once.
66     TIntermTyped *createTempVariable(TIntermTyped *original);
67 
68     std::vector<TIntermSequence> mBlockStack;
69 
70     IntermNodePatternMatcher mNodesToScalarize;
71 };
72 
visitAggregate(Visit visit,TIntermAggregate * node)73 bool ScalarizeArgsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
74 {
75     ASSERT(visit == PreVisit);
76     if (mNodesToScalarize.match(node, getParentNode()))
77     {
78         if (node->getType().isVector())
79         {
80             scalarizeArgs(node, false, true);
81         }
82         else
83         {
84             ASSERT(node->getType().isMatrix());
85             scalarizeArgs(node, true, false);
86         }
87     }
88     return true;
89 }
90 
visitBlock(Visit visit,TIntermBlock * node)91 bool ScalarizeArgsTraverser::visitBlock(Visit visit, TIntermBlock *node)
92 {
93     mBlockStack.push_back(TIntermSequence());
94     {
95         for (TIntermNode *child : *node->getSequence())
96         {
97             ASSERT(child != nullptr);
98             child->traverse(this);
99             mBlockStack.back().push_back(child);
100         }
101     }
102     if (mBlockStack.back().size() > node->getSequence()->size())
103     {
104         node->getSequence()->clear();
105         *(node->getSequence()) = mBlockStack.back();
106     }
107     mBlockStack.pop_back();
108     return false;
109 }
110 
scalarizeArgs(TIntermAggregate * aggregate,bool scalarizeVector,bool scalarizeMatrix)111 void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate,
112                                            bool scalarizeVector,
113                                            bool scalarizeMatrix)
114 {
115     ASSERT(aggregate);
116     ASSERT(!aggregate->isArray());
117     int size                  = static_cast<int>(aggregate->getType().getObjectSize());
118     TIntermSequence *sequence = aggregate->getSequence();
119     TIntermSequence originalArgs(*sequence);
120     sequence->clear();
121     for (TIntermNode *originalArgNode : originalArgs)
122     {
123         ASSERT(size > 0);
124         TIntermTyped *originalArg = originalArgNode->getAsTyped();
125         ASSERT(originalArg);
126         TIntermTyped *argVariable = createTempVariable(originalArg);
127         if (originalArg->isScalar())
128         {
129             sequence->push_back(argVariable);
130             size--;
131         }
132         else if (originalArg->isVector())
133         {
134             if (scalarizeVector)
135             {
136                 int repeat = std::min<int>(size, originalArg->getNominalSize());
137                 size -= repeat;
138                 for (int index = 0; index < repeat; ++index)
139                 {
140                     TIntermBinary *newNode =
141                         ConstructVectorIndexBinaryNode(argVariable->deepCopy(), index);
142                     sequence->push_back(newNode);
143                 }
144             }
145             else
146             {
147                 sequence->push_back(argVariable);
148                 size -= originalArg->getNominalSize();
149             }
150         }
151         else
152         {
153             ASSERT(originalArg->isMatrix());
154             if (scalarizeMatrix)
155             {
156                 int colIndex = 0, rowIndex = 0;
157                 int repeat = std::min<int>(size, originalArg->getCols() * originalArg->getRows());
158                 size -= repeat;
159                 while (repeat > 0)
160                 {
161                     TIntermBinary *newNode =
162                         ConstructMatrixIndexBinaryNode(argVariable->deepCopy(), colIndex, rowIndex);
163                     sequence->push_back(newNode);
164                     rowIndex++;
165                     if (rowIndex >= originalArg->getRows())
166                     {
167                         rowIndex = 0;
168                         colIndex++;
169                     }
170                     repeat--;
171                 }
172             }
173             else
174             {
175                 sequence->push_back(argVariable);
176                 size -= originalArg->getCols() * originalArg->getRows();
177             }
178         }
179     }
180 }
181 
createTempVariable(TIntermTyped * original)182 TIntermTyped *ScalarizeArgsTraverser::createTempVariable(TIntermTyped *original)
183 {
184     ASSERT(original);
185 
186     TType *type = new TType(original->getType());
187     type->setQualifier(EvqTemporary);
188 
189     // The precision of the constant must have been retained (or derived), which will now apply to
190     // the temp variable.  In some cases, the precision cannot be derived, so use the constant as
191     // is.  For example, in the following standalone statement, the precision of the constant 0
192     // cannot be determined:
193     //
194     //      mat2(0, bvec3(m));
195     //
196     if (IsPrecisionApplicableToType(type->getBasicType()) && type->getPrecision() == EbpUndefined)
197     {
198         return original;
199     }
200 
201     TVariable *variable = CreateTempVariable(mSymbolTable, type);
202 
203     ASSERT(mBlockStack.size() > 0);
204     TIntermSequence &sequence       = mBlockStack.back();
205     TIntermDeclaration *declaration = CreateTempInitDeclarationNode(variable, original);
206     sequence.push_back(declaration);
207 
208     return CreateTempSymbolNode(variable);
209 }
210 
211 }  // namespace
212 
ScalarizeVecAndMatConstructorArgs(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)213 bool ScalarizeVecAndMatConstructorArgs(TCompiler *compiler,
214                                        TIntermBlock *root,
215                                        TSymbolTable *symbolTable)
216 {
217     ScalarizeArgsTraverser scalarizer(symbolTable);
218     root->traverse(&scalarizer);
219 
220     return compiler->validateAST(root);
221 }
222 
223 }  // namespace sh
224