• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//===--- unittests/Tooling/RecursiveASTVisitorTests/CallbacksCommon.h -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "TestVisitor.h"
10 
11 using namespace clang;
12 
13 namespace {
14 
/// Chooses whether a RecordingVisitorBase reports post-order traversal from
/// shouldTraversePostOrder(). Using a named enum instead of a plain bool keeps
/// call sites self-describing. The underlying type is bool, so
/// static_cast<bool> yields the corresponding truth value directly.
enum class ShouldTraversePostOrder : bool { No = false, Yes = true };
19 
20 /// Base class for tests for RecursiveASTVisitor tests that validate the
21 /// sequence of calls to user-defined callbacks like Traverse*(), WalkUp*(),
22 /// Visit*().
23 template <typename Derived>
24 class RecordingVisitorBase : public TestVisitor<Derived> {
25   ShouldTraversePostOrder ShouldTraversePostOrderValue;
26 
27 public:
RecordingVisitorBase(ShouldTraversePostOrder ShouldTraversePostOrderValue)28   RecordingVisitorBase(ShouldTraversePostOrder ShouldTraversePostOrderValue)
29       : ShouldTraversePostOrderValue(ShouldTraversePostOrderValue) {}
30 
shouldTraversePostOrder()31   bool shouldTraversePostOrder() const {
32     return static_cast<bool>(ShouldTraversePostOrderValue);
33   }
34 
35   // Callbacks received during traversal.
36   std::string CallbackLog;
37   unsigned CallbackLogIndent = 0;
38 
stmtToString(Stmt * S)39   std::string stmtToString(Stmt *S) {
40     StringRef ClassName = S->getStmtClassName();
41     if (IntegerLiteral *IL = dyn_cast<IntegerLiteral>(S)) {
42       return (ClassName + "(" + IL->getValue().toString(10, false) + ")").str();
43     }
44     if (UnaryOperator *UO = dyn_cast<UnaryOperator>(S)) {
45       return (ClassName + "(" + UnaryOperator::getOpcodeStr(UO->getOpcode()) +
46               ")")
47           .str();
48     }
49     if (BinaryOperator *BO = dyn_cast<BinaryOperator>(S)) {
50       return (ClassName + "(" + BinaryOperator::getOpcodeStr(BO->getOpcode()) +
51               ")")
52           .str();
53     }
54     if (CallExpr *CE = dyn_cast<CallExpr>(S)) {
55       if (FunctionDecl *Callee = CE->getDirectCallee()) {
56         if (Callee->getIdentifier()) {
57           return (ClassName + "(" + Callee->getName() + ")").str();
58         }
59       }
60     }
61     if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(S)) {
62       if (NamedDecl *ND = DRE->getFoundDecl()) {
63         if (ND->getIdentifier()) {
64           return (ClassName + "(" + ND->getName() + ")").str();
65         }
66       }
67     }
68     return ClassName.str();
69   }
70 
71   /// Record the fact that the user-defined callback member function
72   /// \p CallbackName was called with the argument \p S. Then, record the
73   /// effects of calling the default implementation \p CallDefaultFn.
74   template <typename CallDefault>
recordCallback(StringRef CallbackName,Stmt * S,CallDefault CallDefaultFn)75   void recordCallback(StringRef CallbackName, Stmt *S,
76                       CallDefault CallDefaultFn) {
77     for (unsigned i = 0; i != CallbackLogIndent; ++i) {
78       CallbackLog += "  ";
79     }
80     CallbackLog += (CallbackName + " " + stmtToString(S) + "\n").str();
81     ++CallbackLogIndent;
82     CallDefaultFn();
83     --CallbackLogIndent;
84   }
85 };
86 
87 template <typename VisitorTy>
visitorCallbackLogEqual(VisitorTy Visitor,StringRef Code,StringRef ExpectedLog)88 ::testing::AssertionResult visitorCallbackLogEqual(VisitorTy Visitor,
89                                                    StringRef Code,
90                                                    StringRef ExpectedLog) {
91   Visitor.runOver(Code);
92   // EXPECT_EQ shows the diff between the two strings if they are different.
93   EXPECT_EQ(ExpectedLog.trim().str(),
94             StringRef(Visitor.CallbackLog).trim().str());
95   if (ExpectedLog.trim() != StringRef(Visitor.CallbackLog).trim()) {
96     return ::testing::AssertionFailure();
97   }
98   return ::testing::AssertionSuccess();
99 }
100 
101 } // namespace
102