• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/DetectCallDepth.h"
8 #include "compiler/InfoSink.h"
9 
FunctionNode(const TString & fname)10 DetectCallDepth::FunctionNode::FunctionNode(const TString& fname)
11     : name(fname),
12       visit(PreVisit)
13 {
14 }
15 
getName() const16 const TString& DetectCallDepth::FunctionNode::getName() const
17 {
18     return name;
19 }
20 
addCallee(DetectCallDepth::FunctionNode * callee)21 void DetectCallDepth::FunctionNode::addCallee(
22     DetectCallDepth::FunctionNode* callee)
23 {
24     for (size_t i = 0; i < callees.size(); ++i) {
25         if (callees[i] == callee)
26             return;
27     }
28     callees.push_back(callee);
29 }
30 
detectCallDepth(DetectCallDepth * detectCallDepth,int depth)31 int DetectCallDepth::FunctionNode::detectCallDepth(DetectCallDepth* detectCallDepth, int depth)
32 {
33     ASSERT(visit == PreVisit);
34     ASSERT(detectCallDepth);
35 
36     int maxDepth = depth;
37     visit = InVisit;
38     for (size_t i = 0; i < callees.size(); ++i) {
39         switch (callees[i]->visit) {
40             case InVisit:
41                 // cycle detected, i.e., recursion detected.
42                 return kInfiniteCallDepth;
43             case PostVisit:
44                 break;
45             case PreVisit: {
46                 // Check before we recurse so we don't go too depth
47                 if (detectCallDepth->checkExceedsMaxDepth(depth))
48                     return depth;
49                 int callDepth = callees[i]->detectCallDepth(detectCallDepth, depth + 1);
50                 // Check after we recurse so we can exit immediately and provide info.
51                 if (detectCallDepth->checkExceedsMaxDepth(callDepth)) {
52                     detectCallDepth->getInfoSink().info << "<-" << callees[i]->getName();
53                     return callDepth;
54                 }
55                 maxDepth = std::max(callDepth, maxDepth);
56                 break;
57             }
58             default:
59                 UNREACHABLE();
60                 break;
61         }
62     }
63     visit = PostVisit;
64     return maxDepth;
65 }
66 
reset()67 void DetectCallDepth::FunctionNode::reset()
68 {
69     visit = PreVisit;
70 }
71 
DetectCallDepth(TInfoSink & infoSink,bool limitCallStackDepth,int maxCallStackDepth)72 DetectCallDepth::DetectCallDepth(TInfoSink& infoSink, bool limitCallStackDepth, int maxCallStackDepth)
73     : TIntermTraverser(true, false, true, false),
74       currentFunction(NULL),
75       infoSink(infoSink),
76       maxDepth(limitCallStackDepth ? maxCallStackDepth : FunctionNode::kInfiniteCallDepth)
77 {
78 }
79 
~DetectCallDepth()80 DetectCallDepth::~DetectCallDepth()
81 {
82     for (size_t i = 0; i < functions.size(); ++i)
83         delete functions[i];
84 }
85 
visitAggregate(Visit visit,TIntermAggregate * node)86 bool DetectCallDepth::visitAggregate(Visit visit, TIntermAggregate* node)
87 {
88     switch (node->getOp())
89     {
90         case EOpPrototype:
91             // Function declaration.
92             // Don't add FunctionNode here because node->getName() is the
93             // unmangled function name.
94             break;
95         case EOpFunction: {
96             // Function definition.
97             if (visit == PreVisit) {
98                 currentFunction = findFunctionByName(node->getName());
99                 if (currentFunction == NULL) {
100                     currentFunction = new FunctionNode(node->getName());
101                     functions.push_back(currentFunction);
102                 }
103             } else if (visit == PostVisit) {
104                 currentFunction = NULL;
105             }
106             break;
107         }
108         case EOpFunctionCall: {
109             // Function call.
110             if (visit == PreVisit) {
111                 FunctionNode* func = findFunctionByName(node->getName());
112                 if (func == NULL) {
113                     func = new FunctionNode(node->getName());
114                     functions.push_back(func);
115                 }
116                 if (currentFunction)
117                     currentFunction->addCallee(func);
118             }
119             break;
120         }
121         default:
122             break;
123     }
124     return true;
125 }
126 
checkExceedsMaxDepth(int depth)127 bool DetectCallDepth::checkExceedsMaxDepth(int depth)
128 {
129     return depth >= maxDepth;
130 }
131 
resetFunctionNodes()132 void DetectCallDepth::resetFunctionNodes()
133 {
134     for (size_t i = 0; i < functions.size(); ++i) {
135         functions[i]->reset();
136     }
137 }
138 
detectCallDepthForFunction(FunctionNode * func)139 DetectCallDepth::ErrorCode DetectCallDepth::detectCallDepthForFunction(FunctionNode* func)
140 {
141     currentFunction = NULL;
142     resetFunctionNodes();
143 
144     int maxCallDepth = func->detectCallDepth(this, 1);
145 
146     if (maxCallDepth == FunctionNode::kInfiniteCallDepth)
147         return kErrorRecursion;
148 
149     if (maxCallDepth >= maxDepth)
150         return kErrorMaxDepthExceeded;
151 
152     return kErrorNone;
153 }
154 
detectCallDepth()155 DetectCallDepth::ErrorCode DetectCallDepth::detectCallDepth()
156 {
157     if (maxDepth != FunctionNode::kInfiniteCallDepth) {
158         // Check all functions because the driver may fail on them
159         // TODO: Before detectingRecursion, strip unused functions.
160         for (size_t i = 0; i < functions.size(); ++i) {
161             ErrorCode error = detectCallDepthForFunction(functions[i]);
162             if (error != kErrorNone)
163                 return error;
164         }
165     } else {
166         FunctionNode* main = findFunctionByName("main(");
167         if (main == NULL)
168             return kErrorMissingMain;
169 
170         return detectCallDepthForFunction(main);
171     }
172 
173     return kErrorNone;
174 }
175 
findFunctionByName(const TString & name)176 DetectCallDepth::FunctionNode* DetectCallDepth::findFunctionByName(
177     const TString& name)
178 {
179     for (size_t i = 0; i < functions.size(); ++i) {
180         if (functions[i]->getName() == name)
181             return functions[i];
182     }
183     return NULL;
184 }
185 
186