• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2018 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Implementation of the function RewriteAtomicFunctionExpressions.
7 // See the header for more details.
8 
9 #include "RewriteAtomicFunctionExpressions.h"
10 
11 #include "compiler/translator/tree_util/IntermNodePatternMatcher.h"
12 #include "compiler/translator/tree_util/IntermNode_util.h"
13 #include "compiler/translator/tree_util/IntermTraverse.h"
14 #include "compiler/translator/util.h"
15 
16 namespace sh
17 {
18 namespace
19 {
20 // Traverser that simplifies all the atomic function expressions into the ones that can be directly
21 // translated into HLSL.
22 //
23 // case 1 (only for atomicExchange and atomicCompSwap):
24 //  original:
25 //      atomicExchange(counter, newValue);
26 //  new:
27 //      tempValue = atomicExchange(counter, newValue);
28 //
29 // case 2 (atomic function, temporary variable required):
30 //  original:
31 //      value = atomicAdd(counter, 1) * otherValue;
32 //      someArray[atomicAdd(counter, 1)] = someOtherValue;
33 //  new:
34 //      value = ((tempValue = atomicAdd(counter, 1)), tempValue) * otherValue;
35 //      someArray[((tempValue = atomicAdd(counter, 1)), tempValue)] = someOtherValue;
36 //
37 // case 3 (atomic function used directly initialize a variable):
38 //  original:
39 //      int value = atomicAdd(counter, 1);
40 //  new:
41 //      tempValue = atomicAdd(counter, 1);
42 //      int value = tempValue;
43 //
44 class RewriteAtomicFunctionExpressionsTraverser : public TIntermTraverser
45 {
46   public:
47     RewriteAtomicFunctionExpressionsTraverser(TSymbolTable *symbolTable, int shaderVersion);
48 
49     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
50     bool visitBlock(Visit visit, TIntermBlock *node) override;
51 
52   private:
53     static bool IsAtomicExchangeOrCompSwapNoReturnValue(TIntermAggregate *node,
54                                                         TIntermNode *parentNode);
55     static bool IsAtomicFunctionInsideExpression(TIntermAggregate *node, TIntermNode *parentNode);
56 
57     void rewriteAtomicFunctionCallNode(TIntermAggregate *oldAtomicFunctionNode);
58 
59     const TVariable *getTempVariable(const TType *type);
60 
61     int mShaderVersion;
62     TIntermSequence mTempVariables;
63 };
64 
RewriteAtomicFunctionExpressionsTraverser(TSymbolTable * symbolTable,int shaderVersion)65 RewriteAtomicFunctionExpressionsTraverser::RewriteAtomicFunctionExpressionsTraverser(
66     TSymbolTable *symbolTable,
67     int shaderVersion)
68     : TIntermTraverser(false, false, true, symbolTable), mShaderVersion(shaderVersion)
69 {}
70 
rewriteAtomicFunctionCallNode(TIntermAggregate * oldAtomicFunctionNode)71 void RewriteAtomicFunctionExpressionsTraverser::rewriteAtomicFunctionCallNode(
72     TIntermAggregate *oldAtomicFunctionNode)
73 {
74     ASSERT(oldAtomicFunctionNode);
75 
76     const TVariable *returnVariable = getTempVariable(&oldAtomicFunctionNode->getType());
77 
78     TIntermBinary *rewrittenNode = new TIntermBinary(
79         TOperator::EOpAssign, CreateTempSymbolNode(returnVariable), oldAtomicFunctionNode);
80 
81     auto *parentNode = getParentNode();
82 
83     auto *parentBinary = parentNode->getAsBinaryNode();
84     if (parentBinary && parentBinary->getOp() == EOpInitialize)
85     {
86         insertStatementInParentBlock(rewrittenNode);
87         queueReplacement(CreateTempSymbolNode(returnVariable), OriginalNode::IS_DROPPED);
88     }
89     else
90     {
91         // As all atomic function assignment will be converted to the last argument of an
92         // interlocked function, if we need the return value, assignment needs to be wrapped with
93         // the comma operator and the temporary variables.
94         if (!parentNode->getAsBlock())
95         {
96             rewrittenNode = TIntermBinary::CreateComma(
97                 rewrittenNode, new TIntermSymbol(returnVariable), mShaderVersion);
98         }
99 
100         queueReplacement(rewrittenNode, OriginalNode::IS_DROPPED);
101     }
102 }
103 
getTempVariable(const TType * type)104 const TVariable *RewriteAtomicFunctionExpressionsTraverser::getTempVariable(const TType *type)
105 {
106     TIntermDeclaration *variableDeclaration;
107     TVariable *returnVariable =
108         DeclareTempVariable(mSymbolTable, type, EvqTemporary, &variableDeclaration);
109     mTempVariables.push_back(variableDeclaration);
110     return returnVariable;
111 }
112 
IsAtomicExchangeOrCompSwapNoReturnValue(TIntermAggregate * node,TIntermNode * parentNode)113 bool RewriteAtomicFunctionExpressionsTraverser::IsAtomicExchangeOrCompSwapNoReturnValue(
114     TIntermAggregate *node,
115     TIntermNode *parentNode)
116 {
117     ASSERT(node);
118     return (node->getOp() == EOpAtomicExchange || node->getOp() == EOpAtomicCompSwap) &&
119            parentNode && parentNode->getAsBlock();
120 }
121 
IsAtomicFunctionInsideExpression(TIntermAggregate * node,TIntermNode * parentNode)122 bool RewriteAtomicFunctionExpressionsTraverser::IsAtomicFunctionInsideExpression(
123     TIntermAggregate *node,
124     TIntermNode *parentNode)
125 {
126     ASSERT(node);
127     // We only need to handle atomic functions with a parent that it is not block nodes. If the
128     // parent node is block, it means that the atomic function is not inside an expression.
129     if (!IsAtomicFunction(node->getOp()) || parentNode->getAsBlock())
130     {
131         return false;
132     }
133 
134     auto *parentAsBinary = parentNode->getAsBinaryNode();
135     // Assignments are handled in OutputHLSL
136     return !parentAsBinary || parentAsBinary->getOp() != EOpAssign;
137 }
138 
visitAggregate(Visit visit,TIntermAggregate * node)139 bool RewriteAtomicFunctionExpressionsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
140 {
141     ASSERT(visit == PostVisit);
142     // Skip atomic memory functions for SSBO. They will be processed in the OutputHLSL traverser.
143     if (IsAtomicFunction(node->getOp()) &&
144         IsInShaderStorageBlock((*node->getSequence())[0]->getAsTyped()))
145     {
146         return false;
147     }
148 
149     TIntermNode *parentNode = getParentNode();
150     if (IsAtomicExchangeOrCompSwapNoReturnValue(node, parentNode) ||
151         IsAtomicFunctionInsideExpression(node, parentNode))
152     {
153         rewriteAtomicFunctionCallNode(node);
154     }
155 
156     return true;
157 }
158 
visitBlock(Visit visit,TIntermBlock * node)159 bool RewriteAtomicFunctionExpressionsTraverser::visitBlock(Visit visit, TIntermBlock *node)
160 {
161     ASSERT(visit == PostVisit);
162 
163     if (!mTempVariables.empty() && getParentNode()->getAsFunctionDefinition())
164     {
165         insertStatementsInBlockAtPosition(node, 0, mTempVariables, TIntermSequence());
166         mTempVariables.clear();
167     }
168 
169     return true;
170 }
171 
172 }  // anonymous namespace
173 
RewriteAtomicFunctionExpressions(TCompiler * compiler,TIntermNode * root,TSymbolTable * symbolTable,int shaderVersion)174 bool RewriteAtomicFunctionExpressions(TCompiler *compiler,
175                                       TIntermNode *root,
176                                       TSymbolTable *symbolTable,
177                                       int shaderVersion)
178 {
179     RewriteAtomicFunctionExpressionsTraverser traverser(symbolTable, shaderVersion);
180     traverser.traverse(root);
181     return traverser.updateTree(compiler, root);
182 }
183 }  // namespace sh
184