• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- unittest/AST/MatchVerifier.h - AST unit test support ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //  Provides MatchVerifier, a base class to implement gtest matchers that
10 //  verify things that can be matched on the AST.
11 //
12 //  Also implements matchers based on MatchVerifier:
13 //  LocationVerifier and RangeVerifier to verify whether a matched node has
14 //  the expected source location or source range.
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #ifndef LLVM_CLANG_UNITTESTS_AST_MATCHVERIFIER_H
19 #define LLVM_CLANG_UNITTESTS_AST_MATCHVERIFIER_H
20 
21 #include "clang/AST/ASTContext.h"
22 #include "clang/ASTMatchers/ASTMatchFinder.h"
23 #include "clang/ASTMatchers/ASTMatchers.h"
24 #include "clang/Testing/CommandLineArgs.h"
25 #include "clang/Tooling/Tooling.h"
26 #include "gtest/gtest.h"
27 
28 namespace clang {
29 namespace ast_matchers {
30 
31 /// \brief Base class for verifying some property of nodes found by a matcher.
32 template <typename NodeType>
33 class MatchVerifier : public MatchFinder::MatchCallback {
34 public:
35   template <typename MatcherType>
match(const std::string & Code,const MatcherType & AMatcher)36   testing::AssertionResult match(const std::string &Code,
37                                  const MatcherType &AMatcher) {
38     std::vector<std::string> Args;
39     return match(Code, AMatcher, Args, Lang_CXX03);
40   }
41 
42   template <typename MatcherType>
match(const std::string & Code,const MatcherType & AMatcher,TestLanguage L)43   testing::AssertionResult match(const std::string &Code,
44                                  const MatcherType &AMatcher, TestLanguage L) {
45     std::vector<std::string> Args;
46     return match(Code, AMatcher, Args, L);
47   }
48 
49   template <typename MatcherType>
50   testing::AssertionResult
51   match(const std::string &Code, const MatcherType &AMatcher,
52         std::vector<std::string> &Args, TestLanguage L);
53 
54   template <typename MatcherType>
55   testing::AssertionResult match(const Decl *D, const MatcherType &AMatcher);
56 
57 protected:
58   void run(const MatchFinder::MatchResult &Result) override;
verify(const MatchFinder::MatchResult & Result,const NodeType & Node)59   virtual void verify(const MatchFinder::MatchResult &Result,
60                       const NodeType &Node) {}
61 
setFailure(const Twine & Result)62   void setFailure(const Twine &Result) {
63     Verified = false;
64     VerifyResult = Result.str();
65   }
66 
setSuccess()67   void setSuccess() {
68     Verified = true;
69   }
70 
71 private:
72   bool Verified;
73   std::string VerifyResult;
74 };
75 
76 /// \brief Runs a matcher over some code, and returns the result of the
77 /// verifier for the matched node.
78 template <typename NodeType>
79 template <typename MatcherType>
80 testing::AssertionResult
match(const std::string & Code,const MatcherType & AMatcher,std::vector<std::string> & Args,TestLanguage L)81 MatchVerifier<NodeType>::match(const std::string &Code,
82                                const MatcherType &AMatcher,
83                                std::vector<std::string> &Args, TestLanguage L) {
84   MatchFinder Finder;
85   Finder.addMatcher(AMatcher.bind(""), this);
86   std::unique_ptr<tooling::FrontendActionFactory> Factory(
87       tooling::newFrontendActionFactory(&Finder));
88 
89   StringRef FileName;
90   switch (L) {
91   case Lang_C89:
92     Args.push_back("-std=c89");
93     FileName = "input.c";
94     break;
95   case Lang_C99:
96     Args.push_back("-std=c99");
97     FileName = "input.c";
98     break;
99   case Lang_CXX03:
100     Args.push_back("-std=c++03");
101     FileName = "input.cc";
102     break;
103   case Lang_CXX11:
104     Args.push_back("-std=c++11");
105     FileName = "input.cc";
106     break;
107   case Lang_CXX14:
108     Args.push_back("-std=c++14");
109     FileName = "input.cc";
110     break;
111   case Lang_CXX17:
112     Args.push_back("-std=c++17");
113     FileName = "input.cc";
114     break;
115   case Lang_CXX20:
116     Args.push_back("-std=c++20");
117     FileName = "input.cc";
118     break;
119   case Lang_OpenCL:
120     FileName = "input.cl";
121     break;
122   case Lang_OBJCXX:
123     FileName = "input.mm";
124     break;
125   }
126 
127   // Default to failure in case callback is never called
128   setFailure("Could not find match");
129   if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName))
130     return testing::AssertionFailure() << "Parsing error";
131   if (!Verified)
132     return testing::AssertionFailure() << VerifyResult;
133   return testing::AssertionSuccess();
134 }
135 
136 /// \brief Runs a matcher over some AST, and returns the result of the
137 /// verifier for the matched node.
138 template <typename NodeType> template <typename MatcherType>
match(const Decl * D,const MatcherType & AMatcher)139 testing::AssertionResult MatchVerifier<NodeType>::match(
140     const Decl *D, const MatcherType &AMatcher) {
141   MatchFinder Finder;
142   Finder.addMatcher(AMatcher.bind(""), this);
143 
144   setFailure("Could not find match");
145   Finder.match(*D, D->getASTContext());
146 
147   if (!Verified)
148     return testing::AssertionFailure() << VerifyResult;
149   return testing::AssertionSuccess();
150 }
151 
152 template <typename NodeType>
run(const MatchFinder::MatchResult & Result)153 void MatchVerifier<NodeType>::run(const MatchFinder::MatchResult &Result) {
154   const NodeType *Node = Result.Nodes.getNodeAs<NodeType>("");
155   if (!Node) {
156     setFailure("Matched node has wrong type");
157   } else {
158     // Callback has been called, default to success.
159     setSuccess();
160     verify(Result, *Node);
161   }
162 }
163 
164 template <>
165 inline void
run(const MatchFinder::MatchResult & Result)166 MatchVerifier<DynTypedNode>::run(const MatchFinder::MatchResult &Result) {
167   BoundNodes::IDToNodeMap M = Result.Nodes.getMap();
168   BoundNodes::IDToNodeMap::const_iterator I = M.find("");
169   if (I == M.end()) {
170     setFailure("Node was not bound");
171   } else {
172     // Callback has been called, default to success.
173     setSuccess();
174     verify(Result, I->second);
175   }
176 }
177 
178 /// \brief Verify whether a node has the correct source location.
179 ///
180 /// By default, Node.getSourceLocation() is checked. This can be changed
181 /// by overriding getLocation().
182 template <typename NodeType>
183 class LocationVerifier : public MatchVerifier<NodeType> {
184 public:
expectLocation(unsigned Line,unsigned Column)185   void expectLocation(unsigned Line, unsigned Column) {
186     ExpectLine = Line;
187     ExpectColumn = Column;
188   }
189 
190 protected:
verify(const MatchFinder::MatchResult & Result,const NodeType & Node)191   void verify(const MatchFinder::MatchResult &Result,
192               const NodeType &Node) override {
193     SourceLocation Loc = getLocation(Node);
194     unsigned Line = Result.SourceManager->getSpellingLineNumber(Loc);
195     unsigned Column = Result.SourceManager->getSpellingColumnNumber(Loc);
196     if (Line != ExpectLine || Column != ExpectColumn) {
197       std::string MsgStr;
198       llvm::raw_string_ostream Msg(MsgStr);
199       Msg << "Expected location <" << ExpectLine << ":" << ExpectColumn
200           << ">, found <";
201       Loc.print(Msg, *Result.SourceManager);
202       Msg << '>';
203       this->setFailure(Msg.str());
204     }
205   }
206 
getLocation(const NodeType & Node)207   virtual SourceLocation getLocation(const NodeType &Node) {
208     return Node.getLocation();
209   }
210 
211 private:
212   unsigned ExpectLine, ExpectColumn;
213 };
214 
215 /// \brief Verify whether a node has the correct source range.
216 ///
217 /// By default, Node.getSourceRange() is checked. This can be changed
218 /// by overriding getRange().
219 template <typename NodeType>
220 class RangeVerifier : public MatchVerifier<NodeType> {
221 public:
expectRange(unsigned BeginLine,unsigned BeginColumn,unsigned EndLine,unsigned EndColumn)222   void expectRange(unsigned BeginLine, unsigned BeginColumn,
223                    unsigned EndLine, unsigned EndColumn) {
224     ExpectBeginLine = BeginLine;
225     ExpectBeginColumn = BeginColumn;
226     ExpectEndLine = EndLine;
227     ExpectEndColumn = EndColumn;
228   }
229 
230 protected:
verify(const MatchFinder::MatchResult & Result,const NodeType & Node)231   void verify(const MatchFinder::MatchResult &Result,
232               const NodeType &Node) override {
233     SourceRange R = getRange(Node);
234     SourceLocation Begin = R.getBegin();
235     SourceLocation End = R.getEnd();
236     unsigned BeginLine = Result.SourceManager->getSpellingLineNumber(Begin);
237     unsigned BeginColumn = Result.SourceManager->getSpellingColumnNumber(Begin);
238     unsigned EndLine = Result.SourceManager->getSpellingLineNumber(End);
239     unsigned EndColumn = Result.SourceManager->getSpellingColumnNumber(End);
240     if (BeginLine != ExpectBeginLine || BeginColumn != ExpectBeginColumn ||
241         EndLine != ExpectEndLine || EndColumn != ExpectEndColumn) {
242       std::string MsgStr;
243       llvm::raw_string_ostream Msg(MsgStr);
244       Msg << "Expected range <" << ExpectBeginLine << ":" << ExpectBeginColumn
245           << '-' << ExpectEndLine << ":" << ExpectEndColumn << ">, found <";
246       Begin.print(Msg, *Result.SourceManager);
247       Msg << '-';
248       End.print(Msg, *Result.SourceManager);
249       Msg << '>';
250       this->setFailure(Msg.str());
251     }
252   }
253 
getRange(const NodeType & Node)254   virtual SourceRange getRange(const NodeType &Node) {
255     return Node.getSourceRange();
256   }
257 
258 private:
259   unsigned ExpectBeginLine, ExpectBeginColumn, ExpectEndLine, ExpectEndColumn;
260 };
261 
262 /// \brief Verify whether a node's dump contains a given substring.
263 class DumpVerifier : public MatchVerifier<DynTypedNode> {
264 public:
expectSubstring(const std::string & Str)265   void expectSubstring(const std::string &Str) {
266     ExpectSubstring = Str;
267   }
268 
269 protected:
verify(const MatchFinder::MatchResult & Result,const DynTypedNode & Node)270   void verify(const MatchFinder::MatchResult &Result,
271               const DynTypedNode &Node) override {
272     std::string DumpStr;
273     llvm::raw_string_ostream Dump(DumpStr);
274     Node.dump(Dump, *Result.Context);
275 
276     if (Dump.str().find(ExpectSubstring) == std::string::npos) {
277       std::string MsgStr;
278       llvm::raw_string_ostream Msg(MsgStr);
279       Msg << "Expected dump substring <" << ExpectSubstring << ">, found <"
280           << Dump.str() << '>';
281       this->setFailure(Msg.str());
282     }
283   }
284 
285 private:
286   std::string ExpectSubstring;
287 };
288 
289 /// \brief Verify whether a node's pretty print matches a given string.
290 class PrintVerifier : public MatchVerifier<DynTypedNode> {
291 public:
expectString(const std::string & Str)292   void expectString(const std::string &Str) {
293     ExpectString = Str;
294   }
295 
296 protected:
verify(const MatchFinder::MatchResult & Result,const DynTypedNode & Node)297   void verify(const MatchFinder::MatchResult &Result,
298               const DynTypedNode &Node) override {
299     std::string PrintStr;
300     llvm::raw_string_ostream Print(PrintStr);
301     Node.print(Print, Result.Context->getPrintingPolicy());
302 
303     if (Print.str() != ExpectString) {
304       std::string MsgStr;
305       llvm::raw_string_ostream Msg(MsgStr);
306       Msg << "Expected pretty print <" << ExpectString << ">, found <"
307           << Print.str() << '>';
308       this->setFailure(Msg.str());
309     }
310   }
311 
312 private:
313   std::string ExpectString;
314 };
315 
316 } // end namespace ast_matchers
317 } // end namespace clang
318 
319 #endif
320