1 //
2 // Copyright 2024 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
// SeparateStructFromFunctionDeclarations: Separate struct declarations from function return
// types, e.g. `struct S { float x; } f();` becomes `struct S { float x; }; S f();`.
8 //
9
#include "compiler/translator/tree_ops/SeparateStructFromFunctionDeclarations.h"

#include <utility>
#include <vector>

#include "compiler/translator/Compiler.h"
#include "compiler/translator/IntermRebuild.h"
#include "compiler/translator/SymbolTable.h"
14
15 namespace sh
16 {
17 namespace
18 {
19 class SeparateStructFromFunctionDeclarationsTraverser : public TIntermRebuild
20 {
21 public:
SeparateStructFromFunctionDeclarationsTraverser(TCompiler & compiler)22 explicit SeparateStructFromFunctionDeclarationsTraverser(TCompiler &compiler)
23 : TIntermRebuild(compiler, true, true)
24 {}
25
visitFunctionPrototypePre(TIntermFunctionPrototype & node)26 PreResult visitFunctionPrototypePre(TIntermFunctionPrototype &node) override
27 {
28 const TFunction *function = node.getFunction();
29 if (mFunctionsToReplace.count(function) > 0)
30 {
31 TIntermFunctionPrototype *newFuncProto =
32 new TIntermFunctionPrototype(mFunctionsToReplace[function]);
33 return newFuncProto;
34 }
35 else if (node.getType().isStructSpecifier())
36 {
37 const TType &oldType = node.getType();
38 const TStructure *structure = oldType.getStruct();
39 // Name unnamed inline structs
40 if (structure->symbolType() == SymbolType::Empty)
41 {
42 structure = new TStructure(&mSymbolTable, kEmptyImmutableString,
43 &structure->fields(), SymbolType::AngleInternal);
44 }
45
46 TVariable *structVar = new TVariable(&mSymbolTable, ImmutableString(""),
47 new TType(structure, true), SymbolType::Empty);
48 ASSERT(!mStructDeclarations.empty());
49 mStructDeclarations.back().push_back(new TIntermDeclaration({structVar}));
50
51 TType *returnType = new TType(structure, false);
52 if (oldType.isArray())
53 {
54 returnType->makeArrays(oldType.getArraySizes());
55 }
56 returnType->setQualifier(oldType.getQualifier());
57
58 const TFunction *oldFunc = function;
59 ASSERT(oldFunc->symbolType() == SymbolType::UserDefined);
60
61 const TFunction *newFunc = cloneFunctionAndChangeReturnType(oldFunc, returnType);
62 mFunctionsToReplace[oldFunc] = newFunc;
63
64 return new TIntermFunctionPrototype(newFunc);
65 }
66
67 return node;
68 }
69
visitAggregatePre(TIntermAggregate & node)70 PreResult visitAggregatePre(TIntermAggregate &node) override
71 {
72 const TFunction *function = node.getFunction();
73 if (mFunctionsToReplace.count(function) > 0)
74 {
75 TIntermAggregate *replacementNode = TIntermAggregate::CreateFunctionCall(
76 *mFunctionsToReplace[function], node.getSequence());
77
78 return PreResult(replacementNode, VisitBits::Children);
79 }
80
81 return node;
82 }
83
visitBlockPre(TIntermBlock & node)84 PreResult visitBlockPre(TIntermBlock &node) override
85 {
86 mStructDeclarations.push_back({});
87 return node;
88 }
89
visitBlockPost(TIntermBlock & node)90 PostResult visitBlockPost(TIntermBlock &node) override
91 {
92 ASSERT(!mStructDeclarations.empty());
93
94 std::vector<TIntermDeclaration *> declarations = mStructDeclarations.back();
95 mStructDeclarations.pop_back();
96
97 if (!declarations.empty())
98 {
99 TIntermBlock *blockWithStructDeclarations = new TIntermBlock();
100 if (node.isTreeRoot())
101 {
102 blockWithStructDeclarations->setIsTreeRoot();
103 }
104
105 for (TIntermDeclaration *structDecl : declarations)
106 {
107 blockWithStructDeclarations->appendStatement(structDecl);
108 }
109
110 for (TIntermNode *statement : *node.getSequence())
111 {
112 blockWithStructDeclarations->appendStatement(statement);
113 }
114
115 return blockWithStructDeclarations;
116 }
117
118 return node;
119 }
120
121 private:
cloneFunctionAndChangeReturnType(const TFunction * oldFunc,const TType * newReturnType)122 const TFunction *cloneFunctionAndChangeReturnType(const TFunction *oldFunc,
123 const TType *newReturnType)
124
125 {
126 ASSERT(oldFunc->symbolType() == SymbolType::UserDefined);
127
128 TFunction *newFunc = new TFunction(&mSymbolTable, oldFunc->name(), oldFunc->symbolType(),
129 newReturnType, oldFunc->isKnownToNotHaveSideEffects());
130
131 if (oldFunc->isDefined())
132 {
133 newFunc->setDefined();
134 }
135
136 if (oldFunc->hasPrototypeDeclaration())
137 {
138 newFunc->setHasPrototypeDeclaration();
139 }
140
141 const size_t paramCount = oldFunc->getParamCount();
142 for (size_t i = 0; i < paramCount; ++i)
143 {
144 const TVariable *var = oldFunc->getParam(i);
145 newFunc->addParameter(var);
146 }
147
148 return newFunc;
149 }
150
151 using FunctionReplacement = angle::HashMap<const TFunction *, const TFunction *>;
152 FunctionReplacement mFunctionsToReplace;
153
154 // Stack of struct declarations to insert per block
155 std::vector<std::vector<TIntermDeclaration *>> mStructDeclarations;
156 };
157
158 } // anonymous namespace
159
SeparateStructFromFunctionDeclarations(TCompiler & compiler,TIntermBlock & root)160 bool SeparateStructFromFunctionDeclarations(TCompiler &compiler, TIntermBlock &root)
161 {
162 SeparateStructFromFunctionDeclarationsTraverser separateStructDecls(compiler);
163 return separateStructDecls.rebuildRoot(root);
164 }
165 } // namespace sh
166