• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2017 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RunAtTheEndOfShader.cpp: Add code to be run at the end of the shader. If main() contains a
7 // return statement, this is done by replacing the main() function with another function that calls
8 // the old main, like this:
9 //
10 // void main() { body }
11 // =>
12 // void main0() { body }
13 // void main()
14 // {
15 //     main0();
16 //     codeToRun
17 // }
18 //
19 // This way the code will get run even if the return statement inside main is executed.
20 //
21 // The same wrapping is done if main() ends in an unconditional |discard|, to help with SPIR-V
22 // generation, which expects no dead code after a branch in a block.  To avoid bugs when |discard|
23 // is nested inside unconditional blocks, any |discard| in main() is used as a signal to wrap it.
24 //
25 
26 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
27 
28 #include "compiler/translator/Compiler.h"
29 #include "compiler/translator/IntermNode.h"
30 #include "compiler/translator/StaticType.h"
31 #include "compiler/translator/SymbolTable.h"
32 #include "compiler/translator/tree_util/FindMain.h"
33 #include "compiler/translator/tree_util/IntermNode_util.h"
34 #include "compiler/translator/tree_util/IntermTraverse.h"
35 
36 namespace sh
37 {
38 
39 namespace
40 {
41 
42 constexpr const ImmutableString kMainString("main");
43 
44 class ContainsReturnOrDiscardTraverser : public TIntermTraverser
45 {
46   public:
ContainsReturnOrDiscardTraverser()47     ContainsReturnOrDiscardTraverser()
48         : TIntermTraverser(true, false, false), mContainsReturnOrDiscard(false)
49     {}
50 
visitBranch(Visit visit,TIntermBranch * node)51     bool visitBranch(Visit visit, TIntermBranch *node) override
52     {
53         if (node->getFlowOp() == EOpReturn || node->getFlowOp() == EOpKill)
54         {
55             mContainsReturnOrDiscard = true;
56         }
57         return false;
58     }
59 
containsReturnOrDiscard()60     bool containsReturnOrDiscard() { return mContainsReturnOrDiscard; }
61 
62   private:
63     bool mContainsReturnOrDiscard;
64 };
65 
ContainsReturnOrDiscard(TIntermNode * node)66 bool ContainsReturnOrDiscard(TIntermNode *node)
67 {
68     ContainsReturnOrDiscardTraverser traverser;
69     node->traverse(&traverser);
70     return traverser.containsReturnOrDiscard();
71 }
72 
WrapMainAndAppend(TIntermBlock * root,TIntermFunctionDefinition * main,TIntermNode * codeToRun,TSymbolTable * symbolTable)73 void WrapMainAndAppend(TIntermBlock *root,
74                        TIntermFunctionDefinition *main,
75                        TIntermNode *codeToRun,
76                        TSymbolTable *symbolTable)
77 {
78     // Replace main() with main0() with the same body.
79     TFunction *oldMain =
80         new TFunction(symbolTable, kEmptyImmutableString, SymbolType::AngleInternal,
81                       StaticType::GetBasic<EbtVoid, EbpUndefined>(), false);
82     TIntermFunctionDefinition *oldMainDefinition =
83         CreateInternalFunctionDefinitionNode(*oldMain, main->getBody());
84 
85     bool replaced = root->replaceChildNode(main, oldMainDefinition);
86     ASSERT(replaced);
87 
88     // void main()
89     TFunction *newMain = new TFunction(symbolTable, kMainString, SymbolType::UserDefined,
90                                        StaticType::GetBasic<EbtVoid, EbpUndefined>(), false);
91     TIntermFunctionPrototype *newMainProto = new TIntermFunctionPrototype(newMain);
92 
93     // {
94     //     main0();
95     //     codeToRun
96     // }
97     TIntermBlock *newMainBody = new TIntermBlock();
98     TIntermSequence emptySequence;
99     TIntermAggregate *oldMainCall = TIntermAggregate::CreateFunctionCall(*oldMain, &emptySequence);
100     newMainBody->appendStatement(oldMainCall);
101     newMainBody->appendStatement(codeToRun);
102 
103     // Add the new main() to the root node.
104     TIntermFunctionDefinition *newMainDefinition =
105         new TIntermFunctionDefinition(newMainProto, newMainBody);
106     root->appendStatement(newMainDefinition);
107 }
108 
109 }  // anonymous namespace
110 
RunAtTheEndOfShader(TCompiler * compiler,TIntermBlock * root,TIntermNode * codeToRun,TSymbolTable * symbolTable)111 bool RunAtTheEndOfShader(TCompiler *compiler,
112                          TIntermBlock *root,
113                          TIntermNode *codeToRun,
114                          TSymbolTable *symbolTable)
115 {
116     TIntermFunctionDefinition *main = FindMain(root);
117     if (ContainsReturnOrDiscard(main))
118     {
119         WrapMainAndAppend(root, main, codeToRun, symbolTable);
120     }
121     else
122     {
123         main->getBody()->appendStatement(codeToRun);
124     }
125 
126     return compiler->validateAST(root);
127 }
128 
129 }  // namespace sh
130