1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // The ArrayReturnValueToOutParameter function changes return values of an array type to out
7 // parameters in function definitions, prototypes, and call sites.
8
9 #include "compiler/translator/tree_ops/ArrayReturnValueToOutParameter.h"
10
11 #include <map>
12
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17
18 namespace sh
19 {
20
21 namespace
22 {
23
24 constexpr const ImmutableString kReturnValueVariableName("angle_return");
25
26 class ArrayReturnValueToOutParameterTraverser : private TIntermTraverser
27 {
28 public:
29 ANGLE_NO_DISCARD static bool apply(TCompiler *compiler,
30 TIntermNode *root,
31 TSymbolTable *symbolTable);
32
33 private:
34 ArrayReturnValueToOutParameterTraverser(TSymbolTable *symbolTable);
35
36 void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
37 bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
38 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
39 bool visitBranch(Visit visit, TIntermBranch *node) override;
40 bool visitBinary(Visit visit, TIntermBinary *node) override;
41
42 TIntermAggregate *createReplacementCall(TIntermAggregate *originalCall,
43 TIntermTyped *returnValueTarget);
44
45 // Set when traversal is inside a function with array return value.
46 TIntermFunctionDefinition *mFunctionWithArrayReturnValue;
47
48 struct ChangedFunction
49 {
50 const TVariable *returnValueVariable;
51 const TFunction *func;
52 };
53
54 // Map from function symbol ids to the changed function.
55 std::map<int, ChangedFunction> mChangedFunctions;
56 };
57
createReplacementCall(TIntermAggregate * originalCall,TIntermTyped * returnValueTarget)58 TIntermAggregate *ArrayReturnValueToOutParameterTraverser::createReplacementCall(
59 TIntermAggregate *originalCall,
60 TIntermTyped *returnValueTarget)
61 {
62 TIntermSequence *replacementArguments = new TIntermSequence();
63 TIntermSequence *originalArguments = originalCall->getSequence();
64 for (auto &arg : *originalArguments)
65 {
66 replacementArguments->push_back(arg);
67 }
68 replacementArguments->push_back(returnValueTarget);
69 ASSERT(originalCall->getFunction());
70 const TSymbolUniqueId &originalId = originalCall->getFunction()->uniqueId();
71 TIntermAggregate *replacementCall = TIntermAggregate::CreateFunctionCall(
72 *mChangedFunctions[originalId.get()].func, replacementArguments);
73 replacementCall->setLine(originalCall->getLine());
74 return replacementCall;
75 }
76
apply(TCompiler * compiler,TIntermNode * root,TSymbolTable * symbolTable)77 bool ArrayReturnValueToOutParameterTraverser::apply(TCompiler *compiler,
78 TIntermNode *root,
79 TSymbolTable *symbolTable)
80 {
81 ArrayReturnValueToOutParameterTraverser arrayReturnValueToOutParam(symbolTable);
82 root->traverse(&arrayReturnValueToOutParam);
83 return arrayReturnValueToOutParam.updateTree(compiler, root);
84 }
85
ArrayReturnValueToOutParameterTraverser(TSymbolTable * symbolTable)86 ArrayReturnValueToOutParameterTraverser::ArrayReturnValueToOutParameterTraverser(
87 TSymbolTable *symbolTable)
88 : TIntermTraverser(true, false, true, symbolTable), mFunctionWithArrayReturnValue(nullptr)
89 {}
90
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)91 bool ArrayReturnValueToOutParameterTraverser::visitFunctionDefinition(
92 Visit visit,
93 TIntermFunctionDefinition *node)
94 {
95 if (node->getFunctionPrototype()->isArray() && visit == PreVisit)
96 {
97 // Replacing the function header is done on visitFunctionPrototype().
98 mFunctionWithArrayReturnValue = node;
99 }
100 if (visit == PostVisit)
101 {
102 mFunctionWithArrayReturnValue = nullptr;
103 }
104 return true;
105 }
106
visitFunctionPrototype(TIntermFunctionPrototype * node)107 void ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
108 {
109 if (node->isArray())
110 {
111 // Replace the whole prototype node with another node that has the out parameter
112 // added. Also set the function to return void.
113 const TSymbolUniqueId &functionId = node->getFunction()->uniqueId();
114 if (mChangedFunctions.find(functionId.get()) == mChangedFunctions.end())
115 {
116 TType *returnValueVariableType = new TType(node->getType());
117 returnValueVariableType->setQualifier(EvqOut);
118 ChangedFunction changedFunction;
119 changedFunction.returnValueVariable =
120 new TVariable(mSymbolTable, kReturnValueVariableName, returnValueVariableType,
121 SymbolType::AngleInternal);
122 TFunction *func = new TFunction(mSymbolTable, node->getFunction()->name(),
123 node->getFunction()->symbolType(),
124 StaticType::GetBasic<EbtVoid>(), false);
125 for (size_t i = 0; i < node->getFunction()->getParamCount(); ++i)
126 {
127 func->addParameter(node->getFunction()->getParam(i));
128 }
129 func->addParameter(changedFunction.returnValueVariable);
130 changedFunction.func = func;
131 mChangedFunctions[functionId.get()] = changedFunction;
132 }
133 TIntermFunctionPrototype *replacement =
134 new TIntermFunctionPrototype(mChangedFunctions[functionId.get()].func);
135 replacement->setLine(node->getLine());
136
137 queueReplacement(replacement, OriginalNode::IS_DROPPED);
138 }
139 }
140
visitAggregate(Visit visit,TIntermAggregate * node)141 bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
142 {
143 ASSERT(!node->isArray() || node->getOp() != EOpCallInternalRawFunction);
144 if (visit == PreVisit && node->isArray() && node->getOp() == EOpCallFunctionInAST)
145 {
146 // Handle call sites where the returned array is not assigned.
147 // Examples where f() is a function returning an array:
148 // 1. f();
149 // 2. another_array == f();
150 // 3. another_function(f());
151 // 4. return f();
152 // Cases 2 to 4 are already converted to simpler cases by
153 // SeparateExpressionsReturningArrays, so we only need to worry about the case where a
154 // function call returning an array forms an expression by itself.
155 TIntermBlock *parentBlock = getParentNode()->getAsBlock();
156 if (parentBlock)
157 {
158 // replace
159 // f();
160 // with
161 // type s0[size]; f(s0);
162 TIntermSequence replacements;
163
164 // type s0[size];
165 TIntermDeclaration *returnValueDeclaration = nullptr;
166 TVariable *returnValue = DeclareTempVariable(mSymbolTable, new TType(node->getType()),
167 EvqTemporary, &returnValueDeclaration);
168 replacements.push_back(returnValueDeclaration);
169
170 // f(s0);
171 TIntermSymbol *returnValueSymbol = CreateTempSymbolNode(returnValue);
172 replacements.push_back(createReplacementCall(node, returnValueSymbol));
173 mMultiReplacements.push_back(
174 NodeReplaceWithMultipleEntry(parentBlock, node, replacements));
175 }
176 return false;
177 }
178 return true;
179 }
180
visitBranch(Visit visit,TIntermBranch * node)181 bool ArrayReturnValueToOutParameterTraverser::visitBranch(Visit visit, TIntermBranch *node)
182 {
183 if (mFunctionWithArrayReturnValue && node->getFlowOp() == EOpReturn)
184 {
185 // Instead of returning a value, assign to the out parameter and then return.
186 TIntermSequence replacements;
187
188 TIntermTyped *expression = node->getExpression();
189 ASSERT(expression != nullptr);
190 const TSymbolUniqueId &functionId =
191 mFunctionWithArrayReturnValue->getFunction()->uniqueId();
192 ASSERT(mChangedFunctions.find(functionId.get()) != mChangedFunctions.end());
193 TIntermSymbol *returnValueSymbol =
194 new TIntermSymbol(mChangedFunctions[functionId.get()].returnValueVariable);
195 TIntermBinary *replacementAssignment =
196 new TIntermBinary(EOpAssign, returnValueSymbol, expression);
197 replacementAssignment->setLine(expression->getLine());
198 replacements.push_back(replacementAssignment);
199
200 TIntermBranch *replacementBranch = new TIntermBranch(EOpReturn, nullptr);
201 replacementBranch->setLine(node->getLine());
202 replacements.push_back(replacementBranch);
203
204 mMultiReplacements.push_back(
205 NodeReplaceWithMultipleEntry(getParentNode()->getAsBlock(), node, replacements));
206 }
207 return false;
208 }
209
visitBinary(Visit visit,TIntermBinary * node)210 bool ArrayReturnValueToOutParameterTraverser::visitBinary(Visit visit, TIntermBinary *node)
211 {
212 if (node->getOp() == EOpAssign && node->getLeft()->isArray())
213 {
214 TIntermAggregate *rightAgg = node->getRight()->getAsAggregate();
215 ASSERT(rightAgg == nullptr || rightAgg->getOp() != EOpCallInternalRawFunction);
216 if (rightAgg != nullptr && rightAgg->getOp() == EOpCallFunctionInAST)
217 {
218 TIntermAggregate *replacementCall = createReplacementCall(rightAgg, node->getLeft());
219 queueReplacement(replacementCall, OriginalNode::IS_DROPPED);
220 }
221 }
222 return false;
223 }
224
225 } // namespace
226
ArrayReturnValueToOutParameter(TCompiler * compiler,TIntermNode * root,TSymbolTable * symbolTable)227 bool ArrayReturnValueToOutParameter(TCompiler *compiler,
228 TIntermNode *root,
229 TSymbolTable *symbolTable)
230 {
231 return ArrayReturnValueToOutParameterTraverser::apply(compiler, root, symbolTable);
232 }
233
234 } // namespace sh
235