1 //
2 // Copyright 2018 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Implementation of the function RewriteAtomicFunctionExpressions.
7 // See the header for more details.
8
9 #include "RewriteAtomicFunctionExpressions.h"
10
11 #include "compiler/translator/tree_util/IntermNodePatternMatcher.h"
12 #include "compiler/translator/tree_util/IntermNode_util.h"
13 #include "compiler/translator/tree_util/IntermTraverse.h"
14 #include "compiler/translator/util.h"
15
16 namespace sh
17 {
18 namespace
19 {
20 // Traverser that simplifies all the atomic function expressions into the ones that can be directly
21 // translated into HLSL.
22 //
23 // case 1 (only for atomicExchange and atomicCompSwap):
24 // original:
25 // atomicExchange(counter, newValue);
26 // new:
27 // tempValue = atomicExchange(counter, newValue);
28 //
29 // case 2 (atomic function, temporary variable required):
30 // original:
31 // value = atomicAdd(counter, 1) * otherValue;
32 // someArray[atomicAdd(counter, 1)] = someOtherValue;
33 // new:
34 // value = ((tempValue = atomicAdd(counter, 1)), tempValue) * otherValue;
35 // someArray[((tempValue = atomicAdd(counter, 1)), tempValue)] = someOtherValue;
36 //
37 // case 3 (atomic function used directly initialize a variable):
38 // original:
39 // int value = atomicAdd(counter, 1);
40 // new:
41 // tempValue = atomicAdd(counter, 1);
42 // int value = tempValue;
43 //
44 class RewriteAtomicFunctionExpressionsTraverser : public TIntermTraverser
45 {
46 public:
47 RewriteAtomicFunctionExpressionsTraverser(TSymbolTable *symbolTable, int shaderVersion);
48
49 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
50 bool visitBlock(Visit visit, TIntermBlock *node) override;
51
52 private:
53 static bool IsAtomicExchangeOrCompSwapNoReturnValue(TIntermAggregate *node,
54 TIntermNode *parentNode);
55 static bool IsAtomicFunctionInsideExpression(TIntermAggregate *node, TIntermNode *parentNode);
56
57 void rewriteAtomicFunctionCallNode(TIntermAggregate *oldAtomicFunctionNode);
58
59 const TVariable *getTempVariable(const TType *type);
60
61 int mShaderVersion;
62 TIntermSequence mTempVariables;
63 };
64
RewriteAtomicFunctionExpressionsTraverser(TSymbolTable * symbolTable,int shaderVersion)65 RewriteAtomicFunctionExpressionsTraverser::RewriteAtomicFunctionExpressionsTraverser(
66 TSymbolTable *symbolTable,
67 int shaderVersion)
68 : TIntermTraverser(false, false, true, symbolTable), mShaderVersion(shaderVersion)
69 {}
70
rewriteAtomicFunctionCallNode(TIntermAggregate * oldAtomicFunctionNode)71 void RewriteAtomicFunctionExpressionsTraverser::rewriteAtomicFunctionCallNode(
72 TIntermAggregate *oldAtomicFunctionNode)
73 {
74 ASSERT(oldAtomicFunctionNode);
75
76 const TVariable *returnVariable = getTempVariable(&oldAtomicFunctionNode->getType());
77
78 TIntermBinary *rewrittenNode = new TIntermBinary(
79 TOperator::EOpAssign, CreateTempSymbolNode(returnVariable), oldAtomicFunctionNode);
80
81 auto *parentNode = getParentNode();
82
83 auto *parentBinary = parentNode->getAsBinaryNode();
84 if (parentBinary && parentBinary->getOp() == EOpInitialize)
85 {
86 insertStatementInParentBlock(rewrittenNode);
87 queueReplacement(CreateTempSymbolNode(returnVariable), OriginalNode::IS_DROPPED);
88 }
89 else
90 {
91 // As all atomic function assignment will be converted to the last argument of an
92 // interlocked function, if we need the return value, assignment needs to be wrapped with
93 // the comma operator and the temporary variables.
94 if (!parentNode->getAsBlock())
95 {
96 rewrittenNode = TIntermBinary::CreateComma(
97 rewrittenNode, new TIntermSymbol(returnVariable), mShaderVersion);
98 }
99
100 queueReplacement(rewrittenNode, OriginalNode::IS_DROPPED);
101 }
102 }
103
getTempVariable(const TType * type)104 const TVariable *RewriteAtomicFunctionExpressionsTraverser::getTempVariable(const TType *type)
105 {
106 TIntermDeclaration *variableDeclaration;
107 TVariable *returnVariable =
108 DeclareTempVariable(mSymbolTable, type, EvqTemporary, &variableDeclaration);
109 mTempVariables.push_back(variableDeclaration);
110 return returnVariable;
111 }
112
IsAtomicExchangeOrCompSwapNoReturnValue(TIntermAggregate * node,TIntermNode * parentNode)113 bool RewriteAtomicFunctionExpressionsTraverser::IsAtomicExchangeOrCompSwapNoReturnValue(
114 TIntermAggregate *node,
115 TIntermNode *parentNode)
116 {
117 ASSERT(node);
118 return (node->getOp() == EOpAtomicExchange || node->getOp() == EOpAtomicCompSwap) &&
119 parentNode && parentNode->getAsBlock();
120 }
121
IsAtomicFunctionInsideExpression(TIntermAggregate * node,TIntermNode * parentNode)122 bool RewriteAtomicFunctionExpressionsTraverser::IsAtomicFunctionInsideExpression(
123 TIntermAggregate *node,
124 TIntermNode *parentNode)
125 {
126 ASSERT(node);
127 // We only need to handle atomic functions with a parent that it is not block nodes. If the
128 // parent node is block, it means that the atomic function is not inside an expression.
129 if (!IsAtomicFunction(node->getOp()) || parentNode->getAsBlock())
130 {
131 return false;
132 }
133
134 auto *parentAsBinary = parentNode->getAsBinaryNode();
135 // Assignments are handled in OutputHLSL
136 return !parentAsBinary || parentAsBinary->getOp() != EOpAssign;
137 }
138
visitAggregate(Visit visit,TIntermAggregate * node)139 bool RewriteAtomicFunctionExpressionsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
140 {
141 ASSERT(visit == PostVisit);
142 // Skip atomic memory functions for SSBO. They will be processed in the OutputHLSL traverser.
143 if (IsAtomicFunction(node->getOp()) &&
144 IsInShaderStorageBlock((*node->getSequence())[0]->getAsTyped()))
145 {
146 return false;
147 }
148
149 TIntermNode *parentNode = getParentNode();
150 if (IsAtomicExchangeOrCompSwapNoReturnValue(node, parentNode) ||
151 IsAtomicFunctionInsideExpression(node, parentNode))
152 {
153 rewriteAtomicFunctionCallNode(node);
154 }
155
156 return true;
157 }
158
visitBlock(Visit visit,TIntermBlock * node)159 bool RewriteAtomicFunctionExpressionsTraverser::visitBlock(Visit visit, TIntermBlock *node)
160 {
161 ASSERT(visit == PostVisit);
162
163 if (!mTempVariables.empty() && getParentNode()->getAsFunctionDefinition())
164 {
165 insertStatementsInBlockAtPosition(node, 0, mTempVariables, TIntermSequence());
166 mTempVariables.clear();
167 }
168
169 return true;
170 }
171
172 } // anonymous namespace
173
RewriteAtomicFunctionExpressions(TCompiler * compiler,TIntermNode * root,TSymbolTable * symbolTable,int shaderVersion)174 bool RewriteAtomicFunctionExpressions(TCompiler *compiler,
175 TIntermNode *root,
176 TSymbolTable *symbolTable,
177 int shaderVersion)
178 {
179 RewriteAtomicFunctionExpressionsTraverser traverser(symbolTable, shaderVersion);
180 traverser.traverse(root);
181 return traverser.updateTree(compiler, root);
182 }
183 } // namespace sh
184