1 //===- unittest/AST/MatchVerifier.h - AST unit test support ---------------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Provides MatchVerifier, a base class to implement gtest matchers that
11 // verify things that can be matched on the AST.
12 //
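// Also implements matchers based on MatchVerifier:
// LocationVerifier and RangeVerifier to verify whether a matched node has
// the expected source location or source range, DumpVerifier to verify that
// a node's dump contains a given substring, and PrintVerifier to verify a
// node's pretty-printed form.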
16 //
17 //===----------------------------------------------------------------------===//
18
19 #include "clang/AST/ASTContext.h"
20 #include "clang/ASTMatchers/ASTMatchFinder.h"
21 #include "clang/ASTMatchers/ASTMatchers.h"
22 #include "clang/Tooling/Tooling.h"
23 #include "gtest/gtest.h"
24
25 namespace clang {
26 namespace ast_matchers {
27
/// \brief Language (and dialect) the test code is compiled as by
/// MatchVerifier::match().
enum Language {
  Lang_C,      // C99: "-std=c99", file "input.c"
  Lang_C89,    // C89: "-std=c89", file "input.c"
  Lang_CXX,    // C++98: "-std=c++98", file "input.cc"
  Lang_CXX11,  // C++11: "-std=c++11", file "input.cc"
  Lang_OpenCL  // OpenCL: no -std flag, file "input.cl"
};
29
30 /// \brief Base class for verifying some property of nodes found by a matcher.
31 template <typename NodeType>
32 class MatchVerifier : public MatchFinder::MatchCallback {
33 public:
34 template <typename MatcherType>
match(const std::string & Code,const MatcherType & AMatcher)35 testing::AssertionResult match(const std::string &Code,
36 const MatcherType &AMatcher) {
37 std::vector<std::string> Args;
38 return match(Code, AMatcher, Args, Lang_CXX);
39 }
40
41 template <typename MatcherType>
match(const std::string & Code,const MatcherType & AMatcher,Language L)42 testing::AssertionResult match(const std::string &Code,
43 const MatcherType &AMatcher,
44 Language L) {
45 std::vector<std::string> Args;
46 return match(Code, AMatcher, Args, L);
47 }
48
49 template <typename MatcherType>
50 testing::AssertionResult match(const std::string &Code,
51 const MatcherType &AMatcher,
52 std::vector<std::string>& Args,
53 Language L);
54
55 protected:
56 virtual void run(const MatchFinder::MatchResult &Result);
verify(const MatchFinder::MatchResult & Result,const NodeType & Node)57 virtual void verify(const MatchFinder::MatchResult &Result,
58 const NodeType &Node) {}
59
setFailure(const Twine & Result)60 void setFailure(const Twine &Result) {
61 Verified = false;
62 VerifyResult = Result.str();
63 }
64
setSuccess()65 void setSuccess() {
66 Verified = true;
67 }
68
69 private:
70 bool Verified;
71 std::string VerifyResult;
72 };
73
74 /// \brief Runs a matcher over some code, and returns the result of the
75 /// verifier for the matched node.
76 template <typename NodeType> template <typename MatcherType>
match(const std::string & Code,const MatcherType & AMatcher,std::vector<std::string> & Args,Language L)77 testing::AssertionResult MatchVerifier<NodeType>::match(
78 const std::string &Code, const MatcherType &AMatcher,
79 std::vector<std::string>& Args, Language L) {
80 MatchFinder Finder;
81 Finder.addMatcher(AMatcher.bind(""), this);
82 std::unique_ptr<tooling::FrontendActionFactory> Factory(
83 tooling::newFrontendActionFactory(&Finder));
84
85 StringRef FileName;
86 switch (L) {
87 case Lang_C:
88 Args.push_back("-std=c99");
89 FileName = "input.c";
90 break;
91 case Lang_C89:
92 Args.push_back("-std=c89");
93 FileName = "input.c";
94 break;
95 case Lang_CXX:
96 Args.push_back("-std=c++98");
97 FileName = "input.cc";
98 break;
99 case Lang_CXX11:
100 Args.push_back("-std=c++11");
101 FileName = "input.cc";
102 break;
103 case Lang_OpenCL:
104 FileName = "input.cl";
105 }
106
107 // Default to failure in case callback is never called
108 setFailure("Could not find match");
109 if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName))
110 return testing::AssertionFailure() << "Parsing error";
111 if (!Verified)
112 return testing::AssertionFailure() << VerifyResult;
113 return testing::AssertionSuccess();
114 }
115
116 template <typename NodeType>
run(const MatchFinder::MatchResult & Result)117 void MatchVerifier<NodeType>::run(const MatchFinder::MatchResult &Result) {
118 const NodeType *Node = Result.Nodes.getNodeAs<NodeType>("");
119 if (!Node) {
120 setFailure("Matched node has wrong type");
121 } else {
122 // Callback has been called, default to success.
123 setSuccess();
124 verify(Result, *Node);
125 }
126 }
127
128 template <>
run(const MatchFinder::MatchResult & Result)129 inline void MatchVerifier<ast_type_traits::DynTypedNode>::run(
130 const MatchFinder::MatchResult &Result) {
131 BoundNodes::IDToNodeMap M = Result.Nodes.getMap();
132 BoundNodes::IDToNodeMap::const_iterator I = M.find("");
133 if (I == M.end()) {
134 setFailure("Node was not bound");
135 } else {
136 // Callback has been called, default to success.
137 setSuccess();
138 verify(Result, I->second);
139 }
140 }
141
142 /// \brief Verify whether a node has the correct source location.
143 ///
144 /// By default, Node.getSourceLocation() is checked. This can be changed
145 /// by overriding getLocation().
146 template <typename NodeType>
147 class LocationVerifier : public MatchVerifier<NodeType> {
148 public:
expectLocation(unsigned Line,unsigned Column)149 void expectLocation(unsigned Line, unsigned Column) {
150 ExpectLine = Line;
151 ExpectColumn = Column;
152 }
153
154 protected:
verify(const MatchFinder::MatchResult & Result,const NodeType & Node)155 void verify(const MatchFinder::MatchResult &Result, const NodeType &Node) {
156 SourceLocation Loc = getLocation(Node);
157 unsigned Line = Result.SourceManager->getSpellingLineNumber(Loc);
158 unsigned Column = Result.SourceManager->getSpellingColumnNumber(Loc);
159 if (Line != ExpectLine || Column != ExpectColumn) {
160 std::string MsgStr;
161 llvm::raw_string_ostream Msg(MsgStr);
162 Msg << "Expected location <" << ExpectLine << ":" << ExpectColumn
163 << ">, found <";
164 Loc.print(Msg, *Result.SourceManager);
165 Msg << '>';
166 this->setFailure(Msg.str());
167 }
168 }
169
getLocation(const NodeType & Node)170 virtual SourceLocation getLocation(const NodeType &Node) {
171 return Node.getLocation();
172 }
173
174 private:
175 unsigned ExpectLine, ExpectColumn;
176 };
177
178 /// \brief Verify whether a node has the correct source range.
179 ///
180 /// By default, Node.getSourceRange() is checked. This can be changed
181 /// by overriding getRange().
182 template <typename NodeType>
183 class RangeVerifier : public MatchVerifier<NodeType> {
184 public:
expectRange(unsigned BeginLine,unsigned BeginColumn,unsigned EndLine,unsigned EndColumn)185 void expectRange(unsigned BeginLine, unsigned BeginColumn,
186 unsigned EndLine, unsigned EndColumn) {
187 ExpectBeginLine = BeginLine;
188 ExpectBeginColumn = BeginColumn;
189 ExpectEndLine = EndLine;
190 ExpectEndColumn = EndColumn;
191 }
192
193 protected:
verify(const MatchFinder::MatchResult & Result,const NodeType & Node)194 void verify(const MatchFinder::MatchResult &Result, const NodeType &Node) {
195 SourceRange R = getRange(Node);
196 SourceLocation Begin = R.getBegin();
197 SourceLocation End = R.getEnd();
198 unsigned BeginLine = Result.SourceManager->getSpellingLineNumber(Begin);
199 unsigned BeginColumn = Result.SourceManager->getSpellingColumnNumber(Begin);
200 unsigned EndLine = Result.SourceManager->getSpellingLineNumber(End);
201 unsigned EndColumn = Result.SourceManager->getSpellingColumnNumber(End);
202 if (BeginLine != ExpectBeginLine || BeginColumn != ExpectBeginColumn ||
203 EndLine != ExpectEndLine || EndColumn != ExpectEndColumn) {
204 std::string MsgStr;
205 llvm::raw_string_ostream Msg(MsgStr);
206 Msg << "Expected range <" << ExpectBeginLine << ":" << ExpectBeginColumn
207 << '-' << ExpectEndLine << ":" << ExpectEndColumn << ">, found <";
208 Begin.print(Msg, *Result.SourceManager);
209 Msg << '-';
210 End.print(Msg, *Result.SourceManager);
211 Msg << '>';
212 this->setFailure(Msg.str());
213 }
214 }
215
getRange(const NodeType & Node)216 virtual SourceRange getRange(const NodeType &Node) {
217 return Node.getSourceRange();
218 }
219
220 private:
221 unsigned ExpectBeginLine, ExpectBeginColumn, ExpectEndLine, ExpectEndColumn;
222 };
223
224 /// \brief Verify whether a node's dump contains a given substring.
225 class DumpVerifier : public MatchVerifier<ast_type_traits::DynTypedNode> {
226 public:
expectSubstring(const std::string & Str)227 void expectSubstring(const std::string &Str) {
228 ExpectSubstring = Str;
229 }
230
231 protected:
verify(const MatchFinder::MatchResult & Result,const ast_type_traits::DynTypedNode & Node)232 void verify(const MatchFinder::MatchResult &Result,
233 const ast_type_traits::DynTypedNode &Node) {
234 std::string DumpStr;
235 llvm::raw_string_ostream Dump(DumpStr);
236 Node.dump(Dump, *Result.SourceManager);
237
238 if (Dump.str().find(ExpectSubstring) == std::string::npos) {
239 std::string MsgStr;
240 llvm::raw_string_ostream Msg(MsgStr);
241 Msg << "Expected dump substring <" << ExpectSubstring << ">, found <"
242 << Dump.str() << '>';
243 this->setFailure(Msg.str());
244 }
245 }
246
247 private:
248 std::string ExpectSubstring;
249 };
250
251 /// \brief Verify whether a node's pretty print matches a given string.
252 class PrintVerifier : public MatchVerifier<ast_type_traits::DynTypedNode> {
253 public:
expectString(const std::string & Str)254 void expectString(const std::string &Str) {
255 ExpectString = Str;
256 }
257
258 protected:
verify(const MatchFinder::MatchResult & Result,const ast_type_traits::DynTypedNode & Node)259 void verify(const MatchFinder::MatchResult &Result,
260 const ast_type_traits::DynTypedNode &Node) {
261 std::string PrintStr;
262 llvm::raw_string_ostream Print(PrintStr);
263 Node.print(Print, Result.Context->getPrintingPolicy());
264
265 if (Print.str() != ExpectString) {
266 std::string MsgStr;
267 llvm::raw_string_ostream Msg(MsgStr);
268 Msg << "Expected pretty print <" << ExpectString << ">, found <"
269 << Print.str() << '>';
270 this->setFailure(Msg.str());
271 }
272 }
273
274 private:
275 std::string ExpectString;
276 };
277
278 } // end namespace ast_matchers
279 } // end namespace clang
280