1 //
2 // Copyright 2017 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RunAtTheEndOfShader.cpp: Add code to be run at the end of the shader. In case main() contains a
7 // return statement, this is done by replacing the main() function with another function that calls
8 // the old main, like this:
9 //
10 // void main() { body }
11 // =>
12 // void main0() { body }
13 // void main()
14 // {
15 // main0();
16 // codeToRun
17 // }
18 //
19 // This way the code will get run even if the return statement inside main is executed.
20 //
21 // This is done if main ends in an unconditional |discard| as well, to help with SPIR-V generation
22 // that expects no dead-code to be present after branches in a block. To avoid bugs when |discard|
23 // is wrapped in unconditional blocks, any |discard| in main() is used as a signal to wrap it.
24 //
25
26 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
27
28 #include "compiler/translator/Compiler.h"
29 #include "compiler/translator/IntermNode.h"
30 #include "compiler/translator/StaticType.h"
31 #include "compiler/translator/SymbolTable.h"
32 #include "compiler/translator/tree_util/FindMain.h"
33 #include "compiler/translator/tree_util/IntermNode_util.h"
34 #include "compiler/translator/tree_util/IntermTraverse.h"
35
36 namespace sh
37 {
38
39 namespace
40 {
41
42 constexpr const ImmutableString kMainString("main");
43
44 class ContainsReturnOrDiscardTraverser : public TIntermTraverser
45 {
46 public:
ContainsReturnOrDiscardTraverser()47 ContainsReturnOrDiscardTraverser()
48 : TIntermTraverser(true, false, false), mContainsReturnOrDiscard(false)
49 {}
50
visitBranch(Visit visit,TIntermBranch * node)51 bool visitBranch(Visit visit, TIntermBranch *node) override
52 {
53 if (node->getFlowOp() == EOpReturn || node->getFlowOp() == EOpKill)
54 {
55 mContainsReturnOrDiscard = true;
56 }
57 return false;
58 }
59
containsReturnOrDiscard()60 bool containsReturnOrDiscard() { return mContainsReturnOrDiscard; }
61
62 private:
63 bool mContainsReturnOrDiscard;
64 };
65
ContainsReturnOrDiscard(TIntermNode * node)66 bool ContainsReturnOrDiscard(TIntermNode *node)
67 {
68 ContainsReturnOrDiscardTraverser traverser;
69 node->traverse(&traverser);
70 return traverser.containsReturnOrDiscard();
71 }
72
WrapMainAndAppend(TIntermBlock * root,TIntermFunctionDefinition * main,TIntermNode * codeToRun,TSymbolTable * symbolTable)73 void WrapMainAndAppend(TIntermBlock *root,
74 TIntermFunctionDefinition *main,
75 TIntermNode *codeToRun,
76 TSymbolTable *symbolTable)
77 {
78 // Replace main() with main0() with the same body.
79 TFunction *oldMain =
80 new TFunction(symbolTable, kEmptyImmutableString, SymbolType::AngleInternal,
81 StaticType::GetBasic<EbtVoid, EbpUndefined>(), false);
82 TIntermFunctionDefinition *oldMainDefinition =
83 CreateInternalFunctionDefinitionNode(*oldMain, main->getBody());
84
85 bool replaced = root->replaceChildNode(main, oldMainDefinition);
86 ASSERT(replaced);
87
88 // void main()
89 TFunction *newMain = new TFunction(symbolTable, kMainString, SymbolType::UserDefined,
90 StaticType::GetBasic<EbtVoid, EbpUndefined>(), false);
91 TIntermFunctionPrototype *newMainProto = new TIntermFunctionPrototype(newMain);
92
93 // {
94 // main0();
95 // codeToRun
96 // }
97 TIntermBlock *newMainBody = new TIntermBlock();
98 TIntermSequence emptySequence;
99 TIntermAggregate *oldMainCall = TIntermAggregate::CreateFunctionCall(*oldMain, &emptySequence);
100 newMainBody->appendStatement(oldMainCall);
101 newMainBody->appendStatement(codeToRun);
102
103 // Add the new main() to the root node.
104 TIntermFunctionDefinition *newMainDefinition =
105 new TIntermFunctionDefinition(newMainProto, newMainBody);
106 root->appendStatement(newMainDefinition);
107 }
108
109 } // anonymous namespace
110
RunAtTheEndOfShader(TCompiler * compiler,TIntermBlock * root,TIntermNode * codeToRun,TSymbolTable * symbolTable)111 bool RunAtTheEndOfShader(TCompiler *compiler,
112 TIntermBlock *root,
113 TIntermNode *codeToRun,
114 TSymbolTable *symbolTable)
115 {
116 TIntermFunctionDefinition *main = FindMain(root);
117 if (ContainsReturnOrDiscard(main))
118 {
119 WrapMainAndAppend(root, main, codeToRun, symbolTable);
120 }
121 else
122 {
123 main->getBody()->appendStatement(codeToRun);
124 }
125
126 return compiler->validateAST(root);
127 }
128
129 } // namespace sh
130