1 //===--- TestVisitor.h ------------------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// \brief Defines utility templates for RecursiveASTVisitor related tests. 11 /// 12 //===----------------------------------------------------------------------===// 13 14 #ifndef LLVM_CLANG_UNITTESTS_TOOLING_TESTVISITOR_H 15 #define LLVM_CLANG_UNITTESTS_TOOLING_TESTVISITOR_H 16 17 #include "clang/AST/ASTConsumer.h" 18 #include "clang/AST/ASTContext.h" 19 #include "clang/AST/RecursiveASTVisitor.h" 20 #include "clang/Frontend/CompilerInstance.h" 21 #include "clang/Frontend/FrontendAction.h" 22 #include "clang/Tooling/Tooling.h" 23 #include "gtest/gtest.h" 24 #include <vector> 25 26 namespace clang { 27 28 /// \brief Base class for simple RecursiveASTVisitor based tests. 29 /// 30 /// This is a drop-in replacement for RecursiveASTVisitor itself, with the 31 /// additional capability of running it over a snippet of code. 32 /// 33 /// Visits template instantiations and implicit code by default. 34 template <typename T> 35 class TestVisitor : public RecursiveASTVisitor<T> { 36 public: TestVisitor()37 TestVisitor() { } 38 ~TestVisitor()39 virtual ~TestVisitor() { } 40 41 enum Language { 42 Lang_C, 43 Lang_CXX98, 44 Lang_CXX11, 45 Lang_CXX14, 46 Lang_CXX17, 47 Lang_CXX2a, 48 Lang_OBJC, 49 Lang_OBJCXX11, 50 Lang_CXX = Lang_CXX98 51 }; 52 53 /// \brief Runs the current AST visitor over the given code. 54 bool runOver(StringRef Code, Language L = Lang_CXX) { 55 std::vector<std::string> Args; 56 switch (L) { 57 case Lang_C: 58 Args.push_back("-x"); 59 Args.push_back("c"); 60 break; 61 case Lang_CXX98: Args.push_back("-std=c++98"); break; 62 case Lang_CXX11: Args.push_back("-std=c++11"); break; 63 case Lang_CXX14: Args.push_back("-std=c++14"); break; 64 case Lang_CXX17: Args.push_back("-std=c++17"); break; 65 case Lang_CXX2a: Args.push_back("-std=c++2a"); break; 66 case Lang_OBJC: 67 Args.push_back("-ObjC"); 68 Args.push_back("-fobjc-runtime=macosx-10.12.0"); 69 break; 70 case Lang_OBJCXX11: 71 Args.push_back("-ObjC++"); 72 Args.push_back("-std=c++11"); 73 Args.push_back("-fblocks"); 74 break; 75 } 76 return tooling::runToolOnCodeWithArgs(CreateTestAction(), Code, Args); 77 } 78 shouldVisitTemplateInstantiations()79 bool shouldVisitTemplateInstantiations() const { 80 return true; 81 } 82 shouldVisitImplicitCode()83 bool shouldVisitImplicitCode() const { 84 return true; 85 } 86 87 protected: CreateTestAction()88 virtual std::unique_ptr<ASTFrontendAction> CreateTestAction() { 89 return std::make_unique<TestAction>(this); 90 } 91 92 class FindConsumer : public ASTConsumer { 93 public: FindConsumer(TestVisitor * Visitor)94 FindConsumer(TestVisitor *Visitor) : Visitor(Visitor) {} 95 HandleTranslationUnit(clang::ASTContext & Context)96 void HandleTranslationUnit(clang::ASTContext &Context) override { 97 Visitor->Context = &Context; 98 Visitor->TraverseDecl(Context.getTranslationUnitDecl()); 99 } 100 101 private: 102 TestVisitor *Visitor; 103 }; 104 105 class TestAction : public ASTFrontendAction { 106 public: TestAction(TestVisitor * Visitor)107 TestAction(TestVisitor *Visitor) : Visitor(Visitor) {} 108 109 std::unique_ptr<clang::ASTConsumer> CreateASTConsumer(CompilerInstance &,llvm::StringRef dummy)110 CreateASTConsumer(CompilerInstance &, llvm::StringRef dummy) override { 111 /// TestConsumer will be deleted by the framework calling us. 112 return std::make_unique<FindConsumer>(Visitor); 113 } 114 115 protected: 116 TestVisitor *Visitor; 117 }; 118 119 ASTContext *Context; 120 }; 121 122 /// \brief A RecursiveASTVisitor to check that certain matches are (or are 123 /// not) observed during visitation. 124 /// 125 /// This is a RecursiveASTVisitor for testing the RecursiveASTVisitor itself, 126 /// and allows simple creation of test visitors running matches on only a small 127 /// subset of the Visit* methods. 128 template <typename T, template <typename> class Visitor = TestVisitor> 129 class ExpectedLocationVisitor : public Visitor<T> { 130 public: 131 /// \brief Expect 'Match' *not* to occur at the given 'Line' and 'Column'. 132 /// 133 /// Any number of matches can be disallowed. DisallowMatch(Twine Match,unsigned Line,unsigned Column)134 void DisallowMatch(Twine Match, unsigned Line, unsigned Column) { 135 DisallowedMatches.push_back(MatchCandidate(Match, Line, Column)); 136 } 137 138 /// \brief Expect 'Match' to occur at the given 'Line' and 'Column'. 139 /// 140 /// Any number of expected matches can be set by calling this repeatedly. 141 /// Each is expected to be matched 'Times' number of times. (This is useful in 142 /// cases in which different AST nodes can match at the same source code 143 /// location.) 144 void ExpectMatch(Twine Match, unsigned Line, unsigned Column, 145 unsigned Times = 1) { 146 ExpectedMatches.push_back(ExpectedMatch(Match, Line, Column, Times)); 147 } 148 149 /// \brief Checks that all expected matches have been found. ~ExpectedLocationVisitor()150 ~ExpectedLocationVisitor() override { 151 for (typename std::vector<ExpectedMatch>::const_iterator 152 It = ExpectedMatches.begin(), End = ExpectedMatches.end(); 153 It != End; ++It) { 154 It->ExpectFound(); 155 } 156 } 157 158 protected: 159 /// \brief Checks an actual match against expected and disallowed matches. 160 /// 161 /// Implementations are required to call this with appropriate values 162 /// for 'Name' during visitation. Match(StringRef Name,SourceLocation Location)163 void Match(StringRef Name, SourceLocation Location) { 164 const FullSourceLoc FullLocation = this->Context->getFullLoc(Location); 165 166 for (typename std::vector<MatchCandidate>::const_iterator 167 It = DisallowedMatches.begin(), End = DisallowedMatches.end(); 168 It != End; ++It) { 169 EXPECT_FALSE(It->Matches(Name, FullLocation)) 170 << "Matched disallowed " << *It; 171 } 172 173 for (typename std::vector<ExpectedMatch>::iterator 174 It = ExpectedMatches.begin(), End = ExpectedMatches.end(); 175 It != End; ++It) { 176 It->UpdateFor(Name, FullLocation, this->Context->getSourceManager()); 177 } 178 } 179 180 private: 181 struct MatchCandidate { 182 std::string ExpectedName; 183 unsigned LineNumber; 184 unsigned ColumnNumber; 185 MatchCandidateMatchCandidate186 MatchCandidate(Twine Name, unsigned LineNumber, unsigned ColumnNumber) 187 : ExpectedName(Name.str()), LineNumber(LineNumber), 188 ColumnNumber(ColumnNumber) { 189 } 190 MatchesMatchCandidate191 bool Matches(StringRef Name, FullSourceLoc const &Location) const { 192 return MatchesName(Name) && MatchesLocation(Location); 193 } 194 PartiallyMatchesMatchCandidate195 bool PartiallyMatches(StringRef Name, FullSourceLoc const &Location) const { 196 return MatchesName(Name) || MatchesLocation(Location); 197 } 198 MatchesNameMatchCandidate199 bool MatchesName(StringRef Name) const { 200 return Name == ExpectedName; 201 } 202 MatchesLocationMatchCandidate203 bool MatchesLocation(FullSourceLoc const &Location) const { 204 return Location.isValid() && 205 Location.getSpellingLineNumber() == LineNumber && 206 Location.getSpellingColumnNumber() == ColumnNumber; 207 } 208 209 friend std::ostream &operator<<(std::ostream &Stream, 210 MatchCandidate const &Match) { 211 return Stream << Match.ExpectedName 212 << " at " << Match.LineNumber << ":" << Match.ColumnNumber; 213 } 214 }; 215 216 struct ExpectedMatch { ExpectedMatchExpectedMatch217 ExpectedMatch(Twine Name, unsigned LineNumber, unsigned ColumnNumber, 218 unsigned Times) 219 : Candidate(Name, LineNumber, ColumnNumber), TimesExpected(Times), 220 TimesSeen(0) {} 221 UpdateForExpectedMatch222 void UpdateFor(StringRef Name, FullSourceLoc Location, SourceManager &SM) { 223 if (Candidate.Matches(Name, Location)) { 224 EXPECT_LT(TimesSeen, TimesExpected); 225 ++TimesSeen; 226 } else if (TimesSeen < TimesExpected && 227 Candidate.PartiallyMatches(Name, Location)) { 228 llvm::raw_string_ostream Stream(PartialMatches); 229 Stream << ", partial match: \"" << Name << "\" at "; 230 Location.print(Stream, SM); 231 } 232 } 233 ExpectFoundExpectedMatch234 void ExpectFound() const { 235 EXPECT_EQ(TimesExpected, TimesSeen) 236 << "Expected \"" << Candidate.ExpectedName 237 << "\" at " << Candidate.LineNumber 238 << ":" << Candidate.ColumnNumber << PartialMatches; 239 } 240 241 MatchCandidate Candidate; 242 std::string PartialMatches; 243 unsigned TimesExpected; 244 unsigned TimesSeen; 245 }; 246 247 std::vector<MatchCandidate> DisallowedMatches; 248 std::vector<ExpectedMatch> ExpectedMatches; 249 }; 250 } 251 252 #endif 253