• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2017 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// VectorizeVectorScalarArithmetic.cpp: Turn some arithmetic operations that operate on a float
// vector-scalar pair into vector-vector operations. This is done recursively. Some scalar binary
// operations inside vector constructors are also turned into vector operations.
//
// This is targeted to work around a bug in NVIDIA OpenGL drivers that was reproducible on NVIDIA
// driver version 387.92. It works around the most common occurrences of the bug.
11 
12 #include "compiler/translator/tree_ops/VectorizeVectorScalarArithmetic.h"
13 
14 #include <set>
15 
16 #include "compiler/translator/IntermNode.h"
17 #include "compiler/translator/tree_util/IntermNode_util.h"
18 #include "compiler/translator/tree_util/IntermTraverse.h"
19 
20 namespace sh
21 {
22 
23 namespace
24 {
25 
26 class VectorizeVectorScalarArithmeticTraverser : public TIntermTraverser
27 {
28   public:
VectorizeVectorScalarArithmeticTraverser(TSymbolTable * symbolTable)29     VectorizeVectorScalarArithmeticTraverser(TSymbolTable *symbolTable)
30         : TIntermTraverser(true, false, false, symbolTable), mReplaced(false)
31     {}
32 
didReplaceScalarsWithVectors()33     bool didReplaceScalarsWithVectors() { return mReplaced; }
nextIteration()34     void nextIteration()
35     {
36         mReplaced = false;
37         mModifiedBlocks.clear();
38     }
39 
40   protected:
41     bool visitBinary(Visit visit, TIntermBinary *node) override;
42     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
43 
44   private:
45     // These helpers should only be called from visitAggregate when visiting a constructor.
46     // argBinary is the only argument of the constructor.
47     void replaceMathInsideConstructor(TIntermAggregate *node, TIntermBinary *argBinary);
48     void replaceAssignInsideConstructor(const TIntermAggregate *node,
49                                         const TIntermBinary *argBinary);
50 
51     static TIntermTyped *Vectorize(TIntermTyped *node,
52                                    TType vectorType,
53                                    TIntermTraverser::OriginalNode *originalNodeFate);
54 
55     bool mReplaced;
56     std::set<const TIntermBlock *> mModifiedBlocks;
57 };
58 
Vectorize(TIntermTyped * node,TType vectorType,TIntermTraverser::OriginalNode * originalNodeFate)59 TIntermTyped *VectorizeVectorScalarArithmeticTraverser::Vectorize(
60     TIntermTyped *node,
61     TType vectorType,
62     TIntermTraverser::OriginalNode *originalNodeFate)
63 {
64     ASSERT(node->isScalar());
65     vectorType.setQualifier(EvqTemporary);
66     TIntermSequence vectorConstructorArgs;
67     vectorConstructorArgs.push_back(node);
68     TIntermAggregate *vectorized =
69         TIntermAggregate::CreateConstructor(vectorType, &vectorConstructorArgs);
70     TIntermTyped *vectorizedFolded = vectorized->fold(nullptr);
71     if (originalNodeFate != nullptr)
72     {
73         if (vectorizedFolded != vectorized)
74         {
75             *originalNodeFate = OriginalNode::IS_DROPPED;
76         }
77         else
78         {
79             *originalNodeFate = OriginalNode::BECOMES_CHILD;
80         }
81     }
82     return vectorizedFolded;
83 }
84 
visitBinary(Visit,TIntermBinary * node)85 bool VectorizeVectorScalarArithmeticTraverser::visitBinary(Visit /*visit*/, TIntermBinary *node)
86 {
87     TIntermTyped *left  = node->getLeft();
88     TIntermTyped *right = node->getRight();
89     ASSERT(left);
90     ASSERT(right);
91     switch (node->getOp())
92     {
93         case EOpAdd:
94         case EOpAddAssign:
95             // Only these specific ops are necessary to turn into vector ops.
96             break;
97         default:
98             return true;
99     }
100     if (node->getBasicType() != EbtFloat)
101     {
102         // Only float ops have reproduced the bug.
103         return true;
104     }
105     if (left->isScalar() && right->isVector())
106     {
107         ASSERT(!node->isAssignment());
108         ASSERT(!right->isArray());
109         OriginalNode originalNodeFate;
110         TIntermTyped *leftVectorized = Vectorize(left, right->getType(), &originalNodeFate);
111         queueReplacementWithParent(node, left, leftVectorized, originalNodeFate);
112         mReplaced = true;
113         // Don't replace more nodes in the same subtree on this traversal. However, nodes elsewhere
114         // in the tree may still be replaced.
115         return false;
116     }
117     else if (left->isVector() && right->isScalar())
118     {
119         OriginalNode originalNodeFate;
120         TIntermTyped *rightVectorized = Vectorize(right, left->getType(), &originalNodeFate);
121         queueReplacementWithParent(node, right, rightVectorized, originalNodeFate);
122         mReplaced = true;
123         // Don't replace more nodes in the same subtree on this traversal. However, nodes elsewhere
124         // in the tree may still be replaced.
125         return false;
126     }
127     return true;
128 }
129 
replaceMathInsideConstructor(TIntermAggregate * node,TIntermBinary * argBinary)130 void VectorizeVectorScalarArithmeticTraverser::replaceMathInsideConstructor(
131     TIntermAggregate *node,
132     TIntermBinary *argBinary)
133 {
134     // Turn:
135     //   a * b
136     // into:
137     //   gvec(a) * gvec(b)
138 
139     TIntermTyped *left  = argBinary->getLeft();
140     TIntermTyped *right = argBinary->getRight();
141     ASSERT(left->isScalar() && right->isScalar());
142 
143     TType leftVectorizedType = left->getType();
144     leftVectorizedType.setPrimarySize(static_cast<unsigned char>(node->getType().getNominalSize()));
145     TIntermTyped *leftVectorized = Vectorize(left, leftVectorizedType, nullptr);
146     TType rightVectorizedType    = right->getType();
147     rightVectorizedType.setPrimarySize(
148         static_cast<unsigned char>(node->getType().getNominalSize()));
149     TIntermTyped *rightVectorized = Vectorize(right, rightVectorizedType, nullptr);
150 
151     TIntermBinary *newArg = new TIntermBinary(argBinary->getOp(), leftVectorized, rightVectorized);
152     queueReplacementWithParent(node, argBinary, newArg, OriginalNode::IS_DROPPED);
153 }
154 
replaceAssignInsideConstructor(const TIntermAggregate * node,const TIntermBinary * argBinary)155 void VectorizeVectorScalarArithmeticTraverser::replaceAssignInsideConstructor(
156     const TIntermAggregate *node,
157     const TIntermBinary *argBinary)
158 {
159     // Turn:
160     //   gvec(a *= b);
161     // into:
162     //   // This is inserted into the parent block:
163     //   gvec s0 = gvec(a);
164     //
165     //   // This goes where the gvec constructor used to be:
166     //   ((s0 *= b, a = s0.x), s0);
167 
168     TIntermTyped *left  = argBinary->getLeft();
169     TIntermTyped *right = argBinary->getRight();
170     ASSERT(left->isScalar() && right->isScalar());
171     ASSERT(!left->hasSideEffects());
172 
173     TType vecType = node->getType();
174     vecType.setQualifier(EvqTemporary);
175 
176     // gvec s0 = gvec(a);
177     // s0 is called "tempAssignmentTarget" below.
178     TIntermTyped *tempAssignmentTargetInitializer = Vectorize(left->deepCopy(), vecType, nullptr);
179     TIntermDeclaration *tempAssignmentTargetDeclaration = nullptr;
180     TVariable *tempAssignmentTarget =
181         DeclareTempVariable(mSymbolTable, tempAssignmentTargetInitializer, EvqTemporary,
182                             &tempAssignmentTargetDeclaration);
183 
184     // s0 *= b
185     TOperator compoundAssignmentOp = argBinary->getOp();
186     if (compoundAssignmentOp == EOpMulAssign)
187     {
188         compoundAssignmentOp = EOpVectorTimesScalarAssign;
189     }
190     TIntermBinary *replacementCompoundAssignment = new TIntermBinary(
191         compoundAssignmentOp, CreateTempSymbolNode(tempAssignmentTarget), right->deepCopy());
192 
193     // s0.x
194     TVector<int> swizzleXOffset;
195     swizzleXOffset.push_back(0);
196     TIntermSwizzle *tempAssignmentTargetX =
197         new TIntermSwizzle(CreateTempSymbolNode(tempAssignmentTarget), swizzleXOffset);
198     // a = s0.x
199     TIntermBinary *replacementAssignBackToTarget =
200         new TIntermBinary(EOpAssign, left->deepCopy(), tempAssignmentTargetX);
201 
202     // s0 *= b, a = s0.x
203     TIntermBinary *replacementSequenceLeft =
204         new TIntermBinary(EOpComma, replacementCompoundAssignment, replacementAssignBackToTarget);
205     // (s0 *= b, a = s0.x), s0
206     // Note that the created comma node is not const qualified in any case, so we can always pass
207     // shader version 300 here.
208     TIntermBinary *replacementSequence = TIntermBinary::CreateComma(
209         replacementSequenceLeft, CreateTempSymbolNode(tempAssignmentTarget), 300);
210 
211     insertStatementInParentBlock(tempAssignmentTargetDeclaration);
212     queueReplacement(replacementSequence, OriginalNode::IS_DROPPED);
213 }
214 
visitAggregate(Visit,TIntermAggregate * node)215 bool VectorizeVectorScalarArithmeticTraverser::visitAggregate(Visit /*visit*/,
216                                                               TIntermAggregate *node)
217 {
218     // Transform scalar binary expressions inside vector constructors.
219     if (!node->isConstructor() || !node->isVector() || node->getSequence()->size() != 1)
220     {
221         return true;
222     }
223     TIntermTyped *argument = node->getSequence()->back()->getAsTyped();
224     ASSERT(argument);
225     if (!argument->isScalar() || argument->getBasicType() != EbtFloat)
226     {
227         return true;
228     }
229     TIntermBinary *argBinary = argument->getAsBinaryNode();
230     if (!argBinary)
231     {
232         return true;
233     }
234 
235     // Only specific ops are necessary to change.
236     switch (argBinary->getOp())
237     {
238         case EOpMul:
239         case EOpDiv:
240         {
241             replaceMathInsideConstructor(node, argBinary);
242             mReplaced = true;
243             // Don't replace more nodes in the same subtree on this traversal. However, nodes
244             // elsewhere in the tree may still be replaced.
245             return false;
246         }
247         case EOpMulAssign:
248         case EOpDivAssign:
249         {
250             // The case where the left side has side effects is too complicated to deal with, so we
251             // leave that be.
252             if (!argBinary->getLeft()->hasSideEffects())
253             {
254                 const TIntermBlock *parentBlock = getParentBlock();
255                 // We can't do more than one insertion to the same block on the same traversal.
256                 if (mModifiedBlocks.find(parentBlock) == mModifiedBlocks.end())
257                 {
258                     replaceAssignInsideConstructor(node, argBinary);
259                     mModifiedBlocks.insert(parentBlock);
260                     mReplaced = true;
261                     // Don't replace more nodes in the same subtree on this traversal.
262                     // However, nodes elsewhere in the tree may still be replaced.
263                     return false;
264                 }
265             }
266             break;
267         }
268         default:
269             return true;
270     }
271     return true;
272 }
273 
274 }  // anonymous namespace
275 
VectorizeVectorScalarArithmetic(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)276 bool VectorizeVectorScalarArithmetic(TCompiler *compiler,
277                                      TIntermBlock *root,
278                                      TSymbolTable *symbolTable)
279 {
280     VectorizeVectorScalarArithmeticTraverser traverser(symbolTable);
281     do
282     {
283         traverser.nextIteration();
284         root->traverse(&traverser);
285         if (!traverser.updateTree(compiler, root))
286         {
287             return false;
288         }
289     } while (traverser.didReplaceScalarsWithVectors());
290 
291     return true;
292 }
293 
294 }  // namespace sh
295