1 // Copyright 2015 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "CheckTraceVisitor.h"
6
7 #include <vector>
8
9 #include "Config.h"
10
11 using namespace clang;
12
CheckTraceVisitor(CXXMethodDecl * trace,RecordInfo * info,RecordCache * cache)13 CheckTraceVisitor::CheckTraceVisitor(CXXMethodDecl* trace,
14 RecordInfo* info,
15 RecordCache* cache)
16 : trace_(trace), info_(info), cache_(cache) {}
17
VisitMemberExpr(MemberExpr * member)18 bool CheckTraceVisitor::VisitMemberExpr(MemberExpr* member) {
19 // In weak callbacks, consider any occurrence as a correct usage.
20 // TODO: We really want to require that isAlive is checked on manually
21 // processed weak fields.
22 if (IsWeakCallback()) {
23 if (FieldDecl* field = dyn_cast<FieldDecl>(member->getMemberDecl()))
24 FoundField(field);
25 }
26 return true;
27 }
28
VisitCallExpr(CallExpr * call)29 bool CheckTraceVisitor::VisitCallExpr(CallExpr* call) {
30 // In weak callbacks we don't check calls (see VisitMemberExpr).
31 if (IsWeakCallback())
32 return true;
33
34 Expr* callee = call->getCallee();
35
36 // Trace calls from a templated derived class result in a
37 // DependentScopeMemberExpr because the concrete trace call depends on the
38 // instantiation of any shared template parameters. In this case the call is
39 // "unresolved" and we resort to comparing the syntactic type names.
40 if (CXXDependentScopeMemberExpr* expr =
41 dyn_cast<CXXDependentScopeMemberExpr>(callee)) {
42 CheckCXXDependentScopeMemberExpr(call, expr);
43 return true;
44 }
45
46 // A tracing call will have either a |visitor| or a |m_field| argument.
47 // A registerWeakMembers call will have a |this| argument.
48 if (call->getNumArgs() != 1)
49 return true;
50 Expr* arg = call->getArg(0);
51
52 if (UnresolvedMemberExpr* expr = dyn_cast<UnresolvedMemberExpr>(callee)) {
53 // This could be a trace call of a base class, as explained in the
54 // comments of CheckTraceBaseCall().
55 if (CheckTraceBaseCall(call))
56 return true;
57
58 if (expr->getMemberName().getAsString() == kRegisterWeakMembersName)
59 MarkAllWeakMembersTraced();
60
61 QualType base = expr->getBaseType();
62 if (!base->isPointerType())
63 return true;
64 CXXRecordDecl* decl = base->getPointeeType()->getAsCXXRecordDecl();
65 if (decl)
66 CheckTraceFieldCall(expr->getMemberName().getAsString(), decl, arg);
67 return true;
68 }
69
70 if (CXXMemberCallExpr* expr = dyn_cast<CXXMemberCallExpr>(call)) {
71 if (CheckTraceFieldMemberCall(expr) || CheckRegisterWeakMembers(expr))
72 return true;
73
74 }
75
76 CheckTraceBaseCall(call);
77 return true;
78 }
79
IsTraceCallName(const std::string & name)80 bool CheckTraceVisitor::IsTraceCallName(const std::string& name) {
81 // Currently, a manually dispatched class cannot have mixin bases (having
82 // one would add a vtable which we explicitly check against). This means
83 // that we can only make calls to a trace method of the same name. Revisit
84 // this if our mixin/vtable assumption changes.
85 return name == trace_->getName();
86 }
87
GetDependentTemplatedDecl(CXXDependentScopeMemberExpr * expr)88 CXXRecordDecl* CheckTraceVisitor::GetDependentTemplatedDecl(
89 CXXDependentScopeMemberExpr* expr) {
90 NestedNameSpecifier* qual = expr->getQualifier();
91 if (!qual)
92 return 0;
93
94 const Type* type = qual->getAsType();
95 if (!type)
96 return 0;
97
98 return RecordInfo::GetDependentTemplatedDecl(*type);
99 }
100
101 namespace {
102
103 class FindFieldVisitor : public RecursiveASTVisitor<FindFieldVisitor> {
104 public:
105 FindFieldVisitor();
106 MemberExpr* member() const;
107 FieldDecl* field() const;
108 bool TraverseMemberExpr(MemberExpr* member);
109
110 private:
111 MemberExpr* member_;
112 FieldDecl* field_;
113 };
114
FindFieldVisitor()115 FindFieldVisitor::FindFieldVisitor()
116 : member_(0),
117 field_(0) {
118 }
119
member() const120 MemberExpr* FindFieldVisitor::member() const {
121 return member_;
122 }
123
field() const124 FieldDecl* FindFieldVisitor::field() const {
125 return field_;
126 }
127
TraverseMemberExpr(MemberExpr * member)128 bool FindFieldVisitor::TraverseMemberExpr(MemberExpr* member) {
129 if (FieldDecl* field = dyn_cast<FieldDecl>(member->getMemberDecl())) {
130 member_ = member;
131 field_ = field;
132 return false;
133 }
134 return true;
135 }
136
137 } // namespace
138
CheckCXXDependentScopeMemberExpr(CallExpr * call,CXXDependentScopeMemberExpr * expr)139 void CheckTraceVisitor::CheckCXXDependentScopeMemberExpr(
140 CallExpr* call,
141 CXXDependentScopeMemberExpr* expr) {
142 std::string fn_name = expr->getMember().getAsString();
143
144 // Check for VisitorDispatcher::trace(field) and
145 // VisitorDispatcher::registerWeakMembers.
146 if (!expr->isImplicitAccess()) {
147 if (DeclRefExpr* base_decl = dyn_cast<DeclRefExpr>(expr->getBase())) {
148 if (Config::IsVisitorDispatcherType(base_decl->getType())) {
149 if (call->getNumArgs() == 1 && fn_name == kTraceName) {
150 FindFieldVisitor finder;
151 finder.TraverseStmt(call->getArg(0));
152 if (finder.field())
153 FoundField(finder.field());
154
155 return;
156 } else if (call->getNumArgs() == 1 &&
157 fn_name == kRegisterWeakMembersName) {
158 MarkAllWeakMembersTraced();
159 }
160 }
161 }
162 }
163
164 CXXRecordDecl* tmpl = GetDependentTemplatedDecl(expr);
165 if (!tmpl)
166 return;
167
168 // Check for Super<T>::trace(visitor)
169 if (call->getNumArgs() == 1 && IsTraceCallName(fn_name)) {
170 RecordInfo::Bases::iterator it = info_->GetBases().begin();
171 for (; it != info_->GetBases().end(); ++it) {
172 if (it->first->getName() == tmpl->getName())
173 it->second.MarkTraced();
174 }
175 }
176
177 // Check for TraceIfNeeded<T>::trace(visitor, &field)
178 if (call->getNumArgs() == 2 && fn_name == kTraceName &&
179 tmpl->getName() == kTraceIfNeededName) {
180 FindFieldVisitor finder;
181 finder.TraverseStmt(call->getArg(1));
182 if (finder.field())
183 FoundField(finder.field());
184 }
185 }
186
CheckTraceBaseCall(CallExpr * call)187 bool CheckTraceVisitor::CheckTraceBaseCall(CallExpr* call) {
188 // Checks for "Base::trace(visitor)"-like calls.
189
190 // Checking code for these two variables is shared among MemberExpr* case
191 // and UnresolvedMemberCase* case below.
192 //
193 // For example, if we've got "Base::trace(visitor)" as |call|,
194 // callee_record will be "Base", and func_name will be "trace".
195 CXXRecordDecl* callee_record = nullptr;
196 std::string func_name;
197
198 if (MemberExpr* callee = dyn_cast<MemberExpr>(call->getCallee())) {
199 if (!callee->hasQualifier())
200 return false;
201
202 FunctionDecl* trace_decl =
203 dyn_cast<FunctionDecl>(callee->getMemberDecl());
204 if (!trace_decl || !Config::IsTraceMethod(trace_decl))
205 return false;
206
207 const Type* type = callee->getQualifier()->getAsType();
208 if (!type)
209 return false;
210
211 callee_record = type->getAsCXXRecordDecl();
212 func_name = trace_decl->getName();
213 } else if (UnresolvedMemberExpr* callee =
214 dyn_cast<UnresolvedMemberExpr>(call->getCallee())) {
215 // Callee part may become unresolved if the type of the argument
216 // ("visitor") is a template parameter and the called function is
217 // overloaded.
218 //
219 // Here, we try to find a function that looks like trace() from the
220 // candidate overloaded functions, and if we find one, we assume it is
221 // called here.
222
223 CXXMethodDecl* trace_decl = nullptr;
224 for (NamedDecl* named_decl : callee->decls()) {
225 if (CXXMethodDecl* method_decl = dyn_cast<CXXMethodDecl>(named_decl)) {
226 if (Config::IsTraceMethod(method_decl)) {
227 trace_decl = method_decl;
228 break;
229 }
230 }
231 }
232 if (!trace_decl)
233 return false;
234
235 // Check if the passed argument is named "visitor".
236 if (call->getNumArgs() != 1)
237 return false;
238 DeclRefExpr* arg = dyn_cast<DeclRefExpr>(call->getArg(0));
239 if (!arg || arg->getNameInfo().getAsString() != kVisitorVarName)
240 return false;
241
242 callee_record = trace_decl->getParent();
243 func_name = callee->getMemberName().getAsString();
244 }
245
246 if (!callee_record)
247 return false;
248
249 if (!IsTraceCallName(func_name))
250 return false;
251
252 for (auto& base : info_->GetBases()) {
253 // We want to deal with omitted trace() function in an intermediary
254 // class in the class hierarchy, e.g.:
255 // class A : public GarbageCollected<A> { trace() { ... } };
256 // class B : public A { /* No trace(); have nothing to trace. */ };
257 // class C : public B { trace() { B::trace(visitor); } }
258 // where, B::trace() is actually A::trace(), and in some cases we get
259 // A as |callee_record| instead of B. We somehow need to mark B as
260 // traced if we find A::trace() call.
261 //
262 // To solve this, here we keep going up the class hierarchy as long as
263 // they are not required to have a trace method. The implementation is
264 // a simple DFS, where |base_records| represents the set of base classes
265 // we need to visit.
266
267 std::vector<CXXRecordDecl*> base_records;
268 base_records.push_back(base.first);
269
270 while (!base_records.empty()) {
271 CXXRecordDecl* base_record = base_records.back();
272 base_records.pop_back();
273
274 if (base_record == callee_record) {
275 // If we find a matching trace method, pretend the user has written
276 // a correct trace() method of the base; in the example above, we
277 // find A::trace() here and mark B as correctly traced.
278 base.second.MarkTraced();
279 return true;
280 }
281
282 if (RecordInfo* base_info = cache_->Lookup(base_record)) {
283 if (!base_info->RequiresTraceMethod()) {
284 // If this base class is not required to have a trace method, then
285 // the actual trace method may be defined in an ancestor.
286 for (auto& inner_base : base_info->GetBases())
287 base_records.push_back(inner_base.first);
288 }
289 }
290 }
291 }
292
293 return false;
294 }
295
CheckTraceFieldMemberCall(CXXMemberCallExpr * call)296 bool CheckTraceVisitor::CheckTraceFieldMemberCall(CXXMemberCallExpr* call) {
297 return CheckTraceFieldCall(call->getMethodDecl()->getNameAsString(),
298 call->getRecordDecl(),
299 call->getArg(0));
300 }
301
CheckTraceFieldCall(const std::string & name,CXXRecordDecl * callee,Expr * arg)302 bool CheckTraceVisitor::CheckTraceFieldCall(
303 const std::string& name,
304 CXXRecordDecl* callee,
305 Expr* arg) {
306 if (name != kTraceName || !Config::IsVisitor(callee->getName()))
307 return false;
308
309 FindFieldVisitor finder;
310 finder.TraverseStmt(arg);
311 if (finder.field())
312 FoundField(finder.field());
313
314 return true;
315 }
316
CheckRegisterWeakMembers(CXXMemberCallExpr * call)317 bool CheckTraceVisitor::CheckRegisterWeakMembers(CXXMemberCallExpr* call) {
318 CXXMethodDecl* fn = call->getMethodDecl();
319 if (fn->getName() != kRegisterWeakMembersName)
320 return false;
321
322 if (fn->isTemplateInstantiation()) {
323 const TemplateArgumentList& args =
324 *fn->getTemplateSpecializationInfo()->TemplateArguments;
325 // The second template argument is the callback method.
326 if (args.size() > 1 &&
327 args[1].getKind() == TemplateArgument::Declaration) {
328 if (FunctionDecl* callback =
329 dyn_cast<FunctionDecl>(args[1].getAsDecl())) {
330 if (callback->hasBody()) {
331 CheckTraceVisitor nested_visitor(nullptr, info_, nullptr);
332 nested_visitor.TraverseStmt(callback->getBody());
333 }
334 }
335 // TODO: mark all WeakMember<>s as traced even if
336 // the body isn't available?
337 }
338 }
339 return true;
340 }
341
IsWeakCallback() const342 bool CheckTraceVisitor::IsWeakCallback() const {
343 return !trace_;
344 }
345
MarkTraced(RecordInfo::Fields::iterator it)346 void CheckTraceVisitor::MarkTraced(RecordInfo::Fields::iterator it) {
347 // In a weak callback we can't mark strong fields as traced.
348 if (IsWeakCallback() && !it->second.edge()->IsWeakMember())
349 return;
350 it->second.MarkTraced();
351 }
352
FoundField(FieldDecl * field)353 void CheckTraceVisitor::FoundField(FieldDecl* field) {
354 if (Config::IsTemplateInstantiation(info_->record())) {
355 // Pointer equality on fields does not work for template instantiations.
356 // The trace method refers to fields of the template definition which
357 // are different from the instantiated fields that need to be traced.
358 const std::string& name = field->getNameAsString();
359 for (RecordInfo::Fields::iterator it = info_->GetFields().begin();
360 it != info_->GetFields().end();
361 ++it) {
362 if (it->first->getNameAsString() == name) {
363 MarkTraced(it);
364 break;
365 }
366 }
367 } else {
368 RecordInfo::Fields::iterator it = info_->GetFields().find(field);
369 if (it != info_->GetFields().end())
370 MarkTraced(it);
371 }
372 }
373
MarkAllWeakMembersTraced()374 void CheckTraceVisitor::MarkAllWeakMembersTraced() {
375 // If we find a call to registerWeakMembers which is unresolved we
376 // unsoundly consider all weak members as traced.
377 // TODO: Find out how to validate weak member tracing for unresolved call.
378 for (auto& field : info_->GetFields()) {
379 if (field.second.edge()->IsWeakMember())
380 field.second.MarkTraced();
381 }
382 }
383