1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // The ArrayReturnValueToOutParameter function changes return values of an array type to out
7 // parameters in function definitions, prototypes, and call sites.
8
9 #include "compiler/translator/tree_ops/d3d/ArrayReturnValueToOutParameter.h"
10
11 #include <map>
12
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17
18 namespace sh
19 {
20
21 namespace
22 {
23
24 constexpr const ImmutableString kReturnValueVariableName("angle_return");
25
26 class ArrayReturnValueToOutParameterTraverser : private TIntermTraverser
27 {
28 public:
29 ANGLE_NO_DISCARD static bool apply(TCompiler *compiler,
30 TIntermNode *root,
31 TSymbolTable *symbolTable);
32
33 private:
34 ArrayReturnValueToOutParameterTraverser(TSymbolTable *symbolTable);
35
36 void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
37 bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
38 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
39 bool visitBranch(Visit visit, TIntermBranch *node) override;
40 bool visitBinary(Visit visit, TIntermBinary *node) override;
41
42 TIntermAggregate *createReplacementCall(TIntermAggregate *originalCall,
43 TIntermTyped *returnValueTarget);
44
45 // Set when traversal is inside a function with array return value.
46 TIntermFunctionDefinition *mFunctionWithArrayReturnValue;
47
48 struct ChangedFunction
49 {
50 const TVariable *returnValueVariable;
51 const TFunction *func;
52 };
53
54 // Map from function symbol ids to the changed function.
55 std::map<int, ChangedFunction> mChangedFunctions;
56 };
57
createReplacementCall(TIntermAggregate * originalCall,TIntermTyped * returnValueTarget)58 TIntermAggregate *ArrayReturnValueToOutParameterTraverser::createReplacementCall(
59 TIntermAggregate *originalCall,
60 TIntermTyped *returnValueTarget)
61 {
62 TIntermSequence replacementArguments;
63 TIntermSequence *originalArguments = originalCall->getSequence();
64 for (auto &arg : *originalArguments)
65 {
66 replacementArguments.push_back(arg);
67 }
68 replacementArguments.push_back(returnValueTarget);
69 ASSERT(originalCall->getFunction());
70 const TSymbolUniqueId &originalId = originalCall->getFunction()->uniqueId();
71 TIntermAggregate *replacementCall = TIntermAggregate::CreateFunctionCall(
72 *mChangedFunctions[originalId.get()].func, &replacementArguments);
73 replacementCall->setLine(originalCall->getLine());
74 return replacementCall;
75 }
76
apply(TCompiler * compiler,TIntermNode * root,TSymbolTable * symbolTable)77 bool ArrayReturnValueToOutParameterTraverser::apply(TCompiler *compiler,
78 TIntermNode *root,
79 TSymbolTable *symbolTable)
80 {
81 ArrayReturnValueToOutParameterTraverser arrayReturnValueToOutParam(symbolTable);
82 root->traverse(&arrayReturnValueToOutParam);
83 return arrayReturnValueToOutParam.updateTree(compiler, root);
84 }
85
ArrayReturnValueToOutParameterTraverser(TSymbolTable * symbolTable)86 ArrayReturnValueToOutParameterTraverser::ArrayReturnValueToOutParameterTraverser(
87 TSymbolTable *symbolTable)
88 : TIntermTraverser(true, false, true, symbolTable), mFunctionWithArrayReturnValue(nullptr)
89 {}
90
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)91 bool ArrayReturnValueToOutParameterTraverser::visitFunctionDefinition(
92 Visit visit,
93 TIntermFunctionDefinition *node)
94 {
95 if (node->getFunctionPrototype()->isArray() && visit == PreVisit)
96 {
97 // Replacing the function header is done on visitFunctionPrototype().
98 mFunctionWithArrayReturnValue = node;
99 }
100 if (visit == PostVisit)
101 {
102 mFunctionWithArrayReturnValue = nullptr;
103 }
104 return true;
105 }
106
visitFunctionPrototype(TIntermFunctionPrototype * node)107 void ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
108 {
109 if (node->isArray())
110 {
111 // Replace the whole prototype node with another node that has the out parameter
112 // added. Also set the function to return void.
113 const TSymbolUniqueId &functionId = node->getFunction()->uniqueId();
114 if (mChangedFunctions.find(functionId.get()) == mChangedFunctions.end())
115 {
116 TType *returnValueVariableType = new TType(node->getType());
117 returnValueVariableType->setQualifier(EvqOut);
118 ChangedFunction changedFunction;
119 changedFunction.returnValueVariable =
120 new TVariable(mSymbolTable, kReturnValueVariableName, returnValueVariableType,
121 SymbolType::AngleInternal);
122 TFunction *func = new TFunction(mSymbolTable, node->getFunction()->name(),
123 node->getFunction()->symbolType(),
124 StaticType::GetBasic<EbtVoid>(), false);
125 for (size_t i = 0; i < node->getFunction()->getParamCount(); ++i)
126 {
127 func->addParameter(node->getFunction()->getParam(i));
128 }
129 func->addParameter(changedFunction.returnValueVariable);
130 changedFunction.func = func;
131 mChangedFunctions[functionId.get()] = changedFunction;
132 }
133 TIntermFunctionPrototype *replacement =
134 new TIntermFunctionPrototype(mChangedFunctions[functionId.get()].func);
135 replacement->setLine(node->getLine());
136
137 queueReplacement(replacement, OriginalNode::IS_DROPPED);
138 }
139 }
140
visitAggregate(Visit visit,TIntermAggregate * node)141 bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
142 {
143 ASSERT(!node->isArray() || node->getOp() != EOpCallInternalRawFunction);
144 if (visit == PreVisit && node->isArray() && node->getOp() == EOpCallFunctionInAST)
145 {
146 // Handle call sites where the returned array is not assigned.
147 // Examples where f() is a function returning an array:
148 // 1. f();
149 // 2. another_array == f();
150 // 3. another_function(f());
151 // 4. return f();
152 // Cases 2 to 4 are already converted to simpler cases by
153 // SeparateExpressionsReturningArrays, so we only need to worry about the case where a
154 // function call returning an array forms an expression by itself.
155 TIntermBlock *parentBlock = getParentNode()->getAsBlock();
156 if (parentBlock)
157 {
158 // replace
159 // f();
160 // with
161 // type s0[size]; f(s0);
162 TIntermSequence replacements;
163
164 // type s0[size];
165 TIntermDeclaration *returnValueDeclaration = nullptr;
166 TVariable *returnValue = DeclareTempVariable(mSymbolTable, new TType(node->getType()),
167 EvqTemporary, &returnValueDeclaration);
168 replacements.push_back(returnValueDeclaration);
169
170 // f(s0);
171 TIntermSymbol *returnValueSymbol = CreateTempSymbolNode(returnValue);
172 replacements.push_back(createReplacementCall(node, returnValueSymbol));
173 mMultiReplacements.emplace_back(parentBlock, node, std::move(replacements));
174 }
175 return false;
176 }
177 return true;
178 }
179
visitBranch(Visit visit,TIntermBranch * node)180 bool ArrayReturnValueToOutParameterTraverser::visitBranch(Visit visit, TIntermBranch *node)
181 {
182 if (mFunctionWithArrayReturnValue && node->getFlowOp() == EOpReturn)
183 {
184 // Instead of returning a value, assign to the out parameter and then return.
185 TIntermSequence replacements;
186
187 TIntermTyped *expression = node->getExpression();
188 ASSERT(expression != nullptr);
189 const TSymbolUniqueId &functionId =
190 mFunctionWithArrayReturnValue->getFunction()->uniqueId();
191 ASSERT(mChangedFunctions.find(functionId.get()) != mChangedFunctions.end());
192 TIntermSymbol *returnValueSymbol =
193 new TIntermSymbol(mChangedFunctions[functionId.get()].returnValueVariable);
194 TIntermBinary *replacementAssignment =
195 new TIntermBinary(EOpAssign, returnValueSymbol, expression);
196 replacementAssignment->setLine(expression->getLine());
197 replacements.push_back(replacementAssignment);
198
199 TIntermBranch *replacementBranch = new TIntermBranch(EOpReturn, nullptr);
200 replacementBranch->setLine(node->getLine());
201 replacements.push_back(replacementBranch);
202
203 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
204 std::move(replacements));
205 }
206 return false;
207 }
208
visitBinary(Visit visit,TIntermBinary * node)209 bool ArrayReturnValueToOutParameterTraverser::visitBinary(Visit visit, TIntermBinary *node)
210 {
211 if (node->getOp() == EOpAssign && node->getLeft()->isArray())
212 {
213 TIntermAggregate *rightAgg = node->getRight()->getAsAggregate();
214 ASSERT(rightAgg == nullptr || rightAgg->getOp() != EOpCallInternalRawFunction);
215 if (rightAgg != nullptr && rightAgg->getOp() == EOpCallFunctionInAST)
216 {
217 TIntermAggregate *replacementCall = createReplacementCall(rightAgg, node->getLeft());
218 queueReplacement(replacementCall, OriginalNode::IS_DROPPED);
219 }
220 }
221 return false;
222 }
223
224 } // namespace
225
ArrayReturnValueToOutParameter(TCompiler * compiler,TIntermNode * root,TSymbolTable * symbolTable)226 bool ArrayReturnValueToOutParameter(TCompiler *compiler,
227 TIntermNode *root,
228 TSymbolTable *symbolTable)
229 {
230 return ArrayReturnValueToOutParameterTraverser::apply(compiler, root, symbolTable);
231 }
232
233 } // namespace sh
234