• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2017 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
// RunAtTheEndOfShader.cpp: Add code to be run at the end of the shader. If main() contains a
// return statement, this is done by replacing the main() function with another function that
// calls the old main, like this:
//
// void main() { body }
// =>
// void main0() { body }
// void main()
// {
//     main0();
//     codeToRun
// }
//
// This way the code still runs even when a return statement inside the original main executes.
20 //
21 
22 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
23 
24 #include "compiler/translator/Compiler.h"
25 #include "compiler/translator/IntermNode.h"
26 #include "compiler/translator/StaticType.h"
27 #include "compiler/translator/SymbolTable.h"
28 #include "compiler/translator/tree_util/FindMain.h"
29 #include "compiler/translator/tree_util/IntermNode_util.h"
30 #include "compiler/translator/tree_util/IntermTraverse.h"
31 
32 namespace sh
33 {
34 
35 namespace
36 {
37 
38 constexpr const ImmutableString kMainString("main");
39 
40 class ContainsReturnTraverser : public TIntermTraverser
41 {
42   public:
ContainsReturnTraverser()43     ContainsReturnTraverser() : TIntermTraverser(true, false, false), mContainsReturn(false) {}
44 
visitBranch(Visit visit,TIntermBranch * node)45     bool visitBranch(Visit visit, TIntermBranch *node) override
46     {
47         if (node->getFlowOp() == EOpReturn)
48         {
49             mContainsReturn = true;
50         }
51         return false;
52     }
53 
containsReturn()54     bool containsReturn() { return mContainsReturn; }
55 
56   private:
57     bool mContainsReturn;
58 };
59 
ContainsReturn(TIntermNode * node)60 bool ContainsReturn(TIntermNode *node)
61 {
62     ContainsReturnTraverser traverser;
63     node->traverse(&traverser);
64     return traverser.containsReturn();
65 }
66 
WrapMainAndAppend(TIntermBlock * root,TIntermFunctionDefinition * main,TIntermNode * codeToRun,TSymbolTable * symbolTable)67 void WrapMainAndAppend(TIntermBlock *root,
68                        TIntermFunctionDefinition *main,
69                        TIntermNode *codeToRun,
70                        TSymbolTable *symbolTable)
71 {
72     // Replace main() with main0() with the same body.
73     TFunction *oldMain =
74         new TFunction(symbolTable, kEmptyImmutableString, SymbolType::AngleInternal,
75                       StaticType::GetBasic<EbtVoid>(), false);
76     TIntermFunctionDefinition *oldMainDefinition =
77         CreateInternalFunctionDefinitionNode(*oldMain, main->getBody());
78 
79     bool replaced = root->replaceChildNode(main, oldMainDefinition);
80     ASSERT(replaced);
81 
82     // void main()
83     TFunction *newMain = new TFunction(symbolTable, kMainString, SymbolType::UserDefined,
84                                        StaticType::GetBasic<EbtVoid>(), false);
85     TIntermFunctionPrototype *newMainProto = new TIntermFunctionPrototype(newMain);
86 
87     // {
88     //     main0();
89     //     codeToRun
90     // }
91     TIntermBlock *newMainBody = new TIntermBlock();
92     TIntermSequence emptySequence;
93     TIntermAggregate *oldMainCall = TIntermAggregate::CreateFunctionCall(*oldMain, &emptySequence);
94     newMainBody->appendStatement(oldMainCall);
95     newMainBody->appendStatement(codeToRun);
96 
97     // Add the new main() to the root node.
98     TIntermFunctionDefinition *newMainDefinition =
99         new TIntermFunctionDefinition(newMainProto, newMainBody);
100     root->appendStatement(newMainDefinition);
101 }
102 
103 }  // anonymous namespace
104 
RunAtTheEndOfShader(TCompiler * compiler,TIntermBlock * root,TIntermNode * codeToRun,TSymbolTable * symbolTable)105 bool RunAtTheEndOfShader(TCompiler *compiler,
106                          TIntermBlock *root,
107                          TIntermNode *codeToRun,
108                          TSymbolTable *symbolTable)
109 {
110     TIntermFunctionDefinition *main = FindMain(root);
111     if (!ContainsReturn(main))
112     {
113         main->getBody()->appendStatement(codeToRun);
114     }
115     else
116     {
117         WrapMainAndAppend(root, main, codeToRun, symbolTable);
118     }
119 
120     return compiler->validateAST(root);
121 }
122 
123 }  // namespace sh
124