• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // The ArrayReturnValueToOutParameter function changes return values of an array type to out
7 // parameters in function definitions, prototypes, and call sites.
8 
9 #include "compiler/translator/tree_ops/d3d/ArrayReturnValueToOutParameter.h"
10 
11 #include <map>
12 
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17 
18 namespace sh
19 {
20 
21 namespace
22 {
23 
24 constexpr const ImmutableString kReturnValueVariableName("angle_return");
25 
26 class ArrayReturnValueToOutParameterTraverser : private TIntermTraverser
27 {
28   public:
29     ANGLE_NO_DISCARD static bool apply(TCompiler *compiler,
30                                        TIntermNode *root,
31                                        TSymbolTable *symbolTable);
32 
33   private:
34     ArrayReturnValueToOutParameterTraverser(TSymbolTable *symbolTable);
35 
36     void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
37     bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
38     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
39     bool visitBranch(Visit visit, TIntermBranch *node) override;
40     bool visitBinary(Visit visit, TIntermBinary *node) override;
41 
42     TIntermAggregate *createReplacementCall(TIntermAggregate *originalCall,
43                                             TIntermTyped *returnValueTarget);
44 
45     // Set when traversal is inside a function with array return value.
46     TIntermFunctionDefinition *mFunctionWithArrayReturnValue;
47 
48     struct ChangedFunction
49     {
50         const TVariable *returnValueVariable;
51         const TFunction *func;
52     };
53 
54     // Map from function symbol ids to the changed function.
55     std::map<int, ChangedFunction> mChangedFunctions;
56 };
57 
createReplacementCall(TIntermAggregate * originalCall,TIntermTyped * returnValueTarget)58 TIntermAggregate *ArrayReturnValueToOutParameterTraverser::createReplacementCall(
59     TIntermAggregate *originalCall,
60     TIntermTyped *returnValueTarget)
61 {
62     TIntermSequence replacementArguments;
63     TIntermSequence *originalArguments = originalCall->getSequence();
64     for (auto &arg : *originalArguments)
65     {
66         replacementArguments.push_back(arg);
67     }
68     replacementArguments.push_back(returnValueTarget);
69     ASSERT(originalCall->getFunction());
70     const TSymbolUniqueId &originalId = originalCall->getFunction()->uniqueId();
71     TIntermAggregate *replacementCall = TIntermAggregate::CreateFunctionCall(
72         *mChangedFunctions[originalId.get()].func, &replacementArguments);
73     replacementCall->setLine(originalCall->getLine());
74     return replacementCall;
75 }
76 
apply(TCompiler * compiler,TIntermNode * root,TSymbolTable * symbolTable)77 bool ArrayReturnValueToOutParameterTraverser::apply(TCompiler *compiler,
78                                                     TIntermNode *root,
79                                                     TSymbolTable *symbolTable)
80 {
81     ArrayReturnValueToOutParameterTraverser arrayReturnValueToOutParam(symbolTable);
82     root->traverse(&arrayReturnValueToOutParam);
83     return arrayReturnValueToOutParam.updateTree(compiler, root);
84 }
85 
ArrayReturnValueToOutParameterTraverser(TSymbolTable * symbolTable)86 ArrayReturnValueToOutParameterTraverser::ArrayReturnValueToOutParameterTraverser(
87     TSymbolTable *symbolTable)
88     : TIntermTraverser(true, false, true, symbolTable), mFunctionWithArrayReturnValue(nullptr)
89 {}
90 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)91 bool ArrayReturnValueToOutParameterTraverser::visitFunctionDefinition(
92     Visit visit,
93     TIntermFunctionDefinition *node)
94 {
95     if (node->getFunctionPrototype()->isArray() && visit == PreVisit)
96     {
97         // Replacing the function header is done on visitFunctionPrototype().
98         mFunctionWithArrayReturnValue = node;
99     }
100     if (visit == PostVisit)
101     {
102         mFunctionWithArrayReturnValue = nullptr;
103     }
104     return true;
105 }
106 
visitFunctionPrototype(TIntermFunctionPrototype * node)107 void ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
108 {
109     if (node->isArray())
110     {
111         // Replace the whole prototype node with another node that has the out parameter
112         // added. Also set the function to return void.
113         const TSymbolUniqueId &functionId = node->getFunction()->uniqueId();
114         if (mChangedFunctions.find(functionId.get()) == mChangedFunctions.end())
115         {
116             TType *returnValueVariableType = new TType(node->getType());
117             returnValueVariableType->setQualifier(EvqOut);
118             ChangedFunction changedFunction;
119             changedFunction.returnValueVariable =
120                 new TVariable(mSymbolTable, kReturnValueVariableName, returnValueVariableType,
121                               SymbolType::AngleInternal);
122             TFunction *func = new TFunction(mSymbolTable, node->getFunction()->name(),
123                                             node->getFunction()->symbolType(),
124                                             StaticType::GetBasic<EbtVoid>(), false);
125             for (size_t i = 0; i < node->getFunction()->getParamCount(); ++i)
126             {
127                 func->addParameter(node->getFunction()->getParam(i));
128             }
129             func->addParameter(changedFunction.returnValueVariable);
130             changedFunction.func                = func;
131             mChangedFunctions[functionId.get()] = changedFunction;
132         }
133         TIntermFunctionPrototype *replacement =
134             new TIntermFunctionPrototype(mChangedFunctions[functionId.get()].func);
135         replacement->setLine(node->getLine());
136 
137         queueReplacement(replacement, OriginalNode::IS_DROPPED);
138     }
139 }
140 
visitAggregate(Visit visit,TIntermAggregate * node)141 bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
142 {
143     ASSERT(!node->isArray() || node->getOp() != EOpCallInternalRawFunction);
144     if (visit == PreVisit && node->isArray() && node->getOp() == EOpCallFunctionInAST)
145     {
146         // Handle call sites where the returned array is not assigned.
147         // Examples where f() is a function returning an array:
148         // 1. f();
149         // 2. another_array == f();
150         // 3. another_function(f());
151         // 4. return f();
152         // Cases 2 to 4 are already converted to simpler cases by
153         // SeparateExpressionsReturningArrays, so we only need to worry about the case where a
154         // function call returning an array forms an expression by itself.
155         TIntermBlock *parentBlock = getParentNode()->getAsBlock();
156         if (parentBlock)
157         {
158             // replace
159             //   f();
160             // with
161             //   type s0[size]; f(s0);
162             TIntermSequence replacements;
163 
164             // type s0[size];
165             TIntermDeclaration *returnValueDeclaration = nullptr;
166             TVariable *returnValue = DeclareTempVariable(mSymbolTable, new TType(node->getType()),
167                                                          EvqTemporary, &returnValueDeclaration);
168             replacements.push_back(returnValueDeclaration);
169 
170             // f(s0);
171             TIntermSymbol *returnValueSymbol = CreateTempSymbolNode(returnValue);
172             replacements.push_back(createReplacementCall(node, returnValueSymbol));
173             mMultiReplacements.emplace_back(parentBlock, node, std::move(replacements));
174         }
175         return false;
176     }
177     return true;
178 }
179 
visitBranch(Visit visit,TIntermBranch * node)180 bool ArrayReturnValueToOutParameterTraverser::visitBranch(Visit visit, TIntermBranch *node)
181 {
182     if (mFunctionWithArrayReturnValue && node->getFlowOp() == EOpReturn)
183     {
184         // Instead of returning a value, assign to the out parameter and then return.
185         TIntermSequence replacements;
186 
187         TIntermTyped *expression = node->getExpression();
188         ASSERT(expression != nullptr);
189         const TSymbolUniqueId &functionId =
190             mFunctionWithArrayReturnValue->getFunction()->uniqueId();
191         ASSERT(mChangedFunctions.find(functionId.get()) != mChangedFunctions.end());
192         TIntermSymbol *returnValueSymbol =
193             new TIntermSymbol(mChangedFunctions[functionId.get()].returnValueVariable);
194         TIntermBinary *replacementAssignment =
195             new TIntermBinary(EOpAssign, returnValueSymbol, expression);
196         replacementAssignment->setLine(expression->getLine());
197         replacements.push_back(replacementAssignment);
198 
199         TIntermBranch *replacementBranch = new TIntermBranch(EOpReturn, nullptr);
200         replacementBranch->setLine(node->getLine());
201         replacements.push_back(replacementBranch);
202 
203         mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
204                                         std::move(replacements));
205     }
206     return false;
207 }
208 
visitBinary(Visit visit,TIntermBinary * node)209 bool ArrayReturnValueToOutParameterTraverser::visitBinary(Visit visit, TIntermBinary *node)
210 {
211     if (node->getOp() == EOpAssign && node->getLeft()->isArray())
212     {
213         TIntermAggregate *rightAgg = node->getRight()->getAsAggregate();
214         ASSERT(rightAgg == nullptr || rightAgg->getOp() != EOpCallInternalRawFunction);
215         if (rightAgg != nullptr && rightAgg->getOp() == EOpCallFunctionInAST)
216         {
217             TIntermAggregate *replacementCall = createReplacementCall(rightAgg, node->getLeft());
218             queueReplacement(replacementCall, OriginalNode::IS_DROPPED);
219         }
220     }
221     return false;
222 }
223 
224 }  // namespace
225 
ArrayReturnValueToOutParameter(TCompiler * compiler,TIntermNode * root,TSymbolTable * symbolTable)226 bool ArrayReturnValueToOutParameter(TCompiler *compiler,
227                                     TIntermNode *root,
228                                     TSymbolTable *symbolTable)
229 {
230     return ArrayReturnValueToOutParameterTraverser::apply(compiler, root, symbolTable);
231 }
232 
233 }  // namespace sh
234