1 //
2 // Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include "compiler/DetectCallDepth.h"
8 #include "compiler/InfoSink.h"
9
FunctionNode(const TString & fname)10 DetectCallDepth::FunctionNode::FunctionNode(const TString& fname)
11 : name(fname),
12 visit(PreVisit)
13 {
14 }
15
getName() const16 const TString& DetectCallDepth::FunctionNode::getName() const
17 {
18 return name;
19 }
20
addCallee(DetectCallDepth::FunctionNode * callee)21 void DetectCallDepth::FunctionNode::addCallee(
22 DetectCallDepth::FunctionNode* callee)
23 {
24 for (size_t i = 0; i < callees.size(); ++i) {
25 if (callees[i] == callee)
26 return;
27 }
28 callees.push_back(callee);
29 }
30
detectCallDepth(DetectCallDepth * detectCallDepth,int depth)31 int DetectCallDepth::FunctionNode::detectCallDepth(DetectCallDepth* detectCallDepth, int depth)
32 {
33 ASSERT(visit == PreVisit);
34 ASSERT(detectCallDepth);
35
36 int maxDepth = depth;
37 visit = InVisit;
38 for (size_t i = 0; i < callees.size(); ++i) {
39 switch (callees[i]->visit) {
40 case InVisit:
41 // cycle detected, i.e., recursion detected.
42 return kInfiniteCallDepth;
43 case PostVisit:
44 break;
45 case PreVisit: {
46 // Check before we recurse so we don't go too depth
47 if (detectCallDepth->checkExceedsMaxDepth(depth))
48 return depth;
49 int callDepth = callees[i]->detectCallDepth(detectCallDepth, depth + 1);
50 // Check after we recurse so we can exit immediately and provide info.
51 if (detectCallDepth->checkExceedsMaxDepth(callDepth)) {
52 detectCallDepth->getInfoSink().info << "<-" << callees[i]->getName();
53 return callDepth;
54 }
55 maxDepth = std::max(callDepth, maxDepth);
56 break;
57 }
58 default:
59 UNREACHABLE();
60 break;
61 }
62 }
63 visit = PostVisit;
64 return maxDepth;
65 }
66
reset()67 void DetectCallDepth::FunctionNode::reset()
68 {
69 visit = PreVisit;
70 }
71
DetectCallDepth(TInfoSink & infoSink,bool limitCallStackDepth,int maxCallStackDepth)72 DetectCallDepth::DetectCallDepth(TInfoSink& infoSink, bool limitCallStackDepth, int maxCallStackDepth)
73 : TIntermTraverser(true, false, true, false),
74 currentFunction(NULL),
75 infoSink(infoSink),
76 maxDepth(limitCallStackDepth ? maxCallStackDepth : FunctionNode::kInfiniteCallDepth)
77 {
78 }
79
~DetectCallDepth()80 DetectCallDepth::~DetectCallDepth()
81 {
82 for (size_t i = 0; i < functions.size(); ++i)
83 delete functions[i];
84 }
85
visitAggregate(Visit visit,TIntermAggregate * node)86 bool DetectCallDepth::visitAggregate(Visit visit, TIntermAggregate* node)
87 {
88 switch (node->getOp())
89 {
90 case EOpPrototype:
91 // Function declaration.
92 // Don't add FunctionNode here because node->getName() is the
93 // unmangled function name.
94 break;
95 case EOpFunction: {
96 // Function definition.
97 if (visit == PreVisit) {
98 currentFunction = findFunctionByName(node->getName());
99 if (currentFunction == NULL) {
100 currentFunction = new FunctionNode(node->getName());
101 functions.push_back(currentFunction);
102 }
103 } else if (visit == PostVisit) {
104 currentFunction = NULL;
105 }
106 break;
107 }
108 case EOpFunctionCall: {
109 // Function call.
110 if (visit == PreVisit) {
111 FunctionNode* func = findFunctionByName(node->getName());
112 if (func == NULL) {
113 func = new FunctionNode(node->getName());
114 functions.push_back(func);
115 }
116 if (currentFunction)
117 currentFunction->addCallee(func);
118 }
119 break;
120 }
121 default:
122 break;
123 }
124 return true;
125 }
126
checkExceedsMaxDepth(int depth)127 bool DetectCallDepth::checkExceedsMaxDepth(int depth)
128 {
129 return depth >= maxDepth;
130 }
131
resetFunctionNodes()132 void DetectCallDepth::resetFunctionNodes()
133 {
134 for (size_t i = 0; i < functions.size(); ++i) {
135 functions[i]->reset();
136 }
137 }
138
detectCallDepthForFunction(FunctionNode * func)139 DetectCallDepth::ErrorCode DetectCallDepth::detectCallDepthForFunction(FunctionNode* func)
140 {
141 currentFunction = NULL;
142 resetFunctionNodes();
143
144 int maxCallDepth = func->detectCallDepth(this, 1);
145
146 if (maxCallDepth == FunctionNode::kInfiniteCallDepth)
147 return kErrorRecursion;
148
149 if (maxCallDepth >= maxDepth)
150 return kErrorMaxDepthExceeded;
151
152 return kErrorNone;
153 }
154
detectCallDepth()155 DetectCallDepth::ErrorCode DetectCallDepth::detectCallDepth()
156 {
157 if (maxDepth != FunctionNode::kInfiniteCallDepth) {
158 // Check all functions because the driver may fail on them
159 // TODO: Before detectingRecursion, strip unused functions.
160 for (size_t i = 0; i < functions.size(); ++i) {
161 ErrorCode error = detectCallDepthForFunction(functions[i]);
162 if (error != kErrorNone)
163 return error;
164 }
165 } else {
166 FunctionNode* main = findFunctionByName("main(");
167 if (main == NULL)
168 return kErrorMissingMain;
169
170 return detectCallDepthForFunction(main);
171 }
172
173 return kErrorNone;
174 }
175
findFunctionByName(const TString & name)176 DetectCallDepth::FunctionNode* DetectCallDepth::findFunctionByName(
177 const TString& name)
178 {
179 for (size_t i = 0; i < functions.size(); ++i) {
180 if (functions[i]->getName() == name)
181 return functions[i];
182 }
183 return NULL;
184 }
185
186