1 //
2 // Copyright 2017 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RunAtTheEndOfShader.cpp: Add code to be run at the end of the shader. In case main() contains a
7 // return statement, this is done by replacing the main() function with another function that calls
8 // the old main, like this:
9 //
10 // void main() { body }
11 // =>
12 // void main0() { body }
13 // void main()
14 // {
15 // main0();
16 // codeToRun
17 // }
18 //
19 // This way the code will get run even if the return statement inside main is executed.
20 //
21
22 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
23
24 #include "compiler/translator/Compiler.h"
25 #include "compiler/translator/IntermNode.h"
26 #include "compiler/translator/StaticType.h"
27 #include "compiler/translator/SymbolTable.h"
28 #include "compiler/translator/tree_util/FindMain.h"
29 #include "compiler/translator/tree_util/IntermNode_util.h"
30 #include "compiler/translator/tree_util/IntermTraverse.h"
31
32 namespace sh
33 {
34
35 namespace
36 {
37
38 constexpr const ImmutableString kMainString("main");
39
40 class ContainsReturnTraverser : public TIntermTraverser
41 {
42 public:
ContainsReturnTraverser()43 ContainsReturnTraverser() : TIntermTraverser(true, false, false), mContainsReturn(false) {}
44
visitBranch(Visit visit,TIntermBranch * node)45 bool visitBranch(Visit visit, TIntermBranch *node) override
46 {
47 if (node->getFlowOp() == EOpReturn)
48 {
49 mContainsReturn = true;
50 }
51 return false;
52 }
53
containsReturn()54 bool containsReturn() { return mContainsReturn; }
55
56 private:
57 bool mContainsReturn;
58 };
59
ContainsReturn(TIntermNode * node)60 bool ContainsReturn(TIntermNode *node)
61 {
62 ContainsReturnTraverser traverser;
63 node->traverse(&traverser);
64 return traverser.containsReturn();
65 }
66
WrapMainAndAppend(TIntermBlock * root,TIntermFunctionDefinition * main,TIntermNode * codeToRun,TSymbolTable * symbolTable)67 void WrapMainAndAppend(TIntermBlock *root,
68 TIntermFunctionDefinition *main,
69 TIntermNode *codeToRun,
70 TSymbolTable *symbolTable)
71 {
72 // Replace main() with main0() with the same body.
73 TFunction *oldMain =
74 new TFunction(symbolTable, kEmptyImmutableString, SymbolType::AngleInternal,
75 StaticType::GetBasic<EbtVoid>(), false);
76 TIntermFunctionDefinition *oldMainDefinition =
77 CreateInternalFunctionDefinitionNode(*oldMain, main->getBody());
78
79 bool replaced = root->replaceChildNode(main, oldMainDefinition);
80 ASSERT(replaced);
81
82 // void main()
83 TFunction *newMain = new TFunction(symbolTable, kMainString, SymbolType::UserDefined,
84 StaticType::GetBasic<EbtVoid>(), false);
85 TIntermFunctionPrototype *newMainProto = new TIntermFunctionPrototype(newMain);
86
87 // {
88 // main0();
89 // codeToRun
90 // }
91 TIntermBlock *newMainBody = new TIntermBlock();
92 TIntermSequence emptySequence;
93 TIntermAggregate *oldMainCall = TIntermAggregate::CreateFunctionCall(*oldMain, &emptySequence);
94 newMainBody->appendStatement(oldMainCall);
95 newMainBody->appendStatement(codeToRun);
96
97 // Add the new main() to the root node.
98 TIntermFunctionDefinition *newMainDefinition =
99 new TIntermFunctionDefinition(newMainProto, newMainBody);
100 root->appendStatement(newMainDefinition);
101 }
102
103 } // anonymous namespace
104
RunAtTheEndOfShader(TCompiler * compiler,TIntermBlock * root,TIntermNode * codeToRun,TSymbolTable * symbolTable)105 bool RunAtTheEndOfShader(TCompiler *compiler,
106 TIntermBlock *root,
107 TIntermNode *codeToRun,
108 TSymbolTable *symbolTable)
109 {
110 TIntermFunctionDefinition *main = FindMain(root);
111 if (!ContainsReturn(main))
112 {
113 main->getBody()->appendStatement(codeToRun);
114 }
115 else
116 {
117 WrapMainAndAppend(root, main, codeToRun, symbolTable);
118 }
119
120 return compiler->validateAST(root);
121 }
122
123 } // namespace sh
124