1 //===-- IncludeFixer.cpp - Include inserter based on sema callbacks -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "IncludeFixer.h"
10 #include "clang/Format/Format.h"
11 #include "clang/Frontend/CompilerInstance.h"
12 #include "clang/Lex/HeaderSearch.h"
13 #include "clang/Lex/Preprocessor.h"
14 #include "clang/Parse/ParseAST.h"
15 #include "clang/Sema/Sema.h"
16 #include "llvm/Support/Debug.h"
17 #include "llvm/Support/raw_ostream.h"
18
19 #define DEBUG_TYPE "clang-include-fixer"
20
21 using namespace clang;
22
23 namespace clang {
24 namespace include_fixer {
25 namespace {
26 /// Manages the parse, gathers include suggestions.
27 class Action : public clang::ASTFrontendAction {
28 public:
Action(SymbolIndexManager & SymbolIndexMgr,bool MinimizeIncludePaths)29 explicit Action(SymbolIndexManager &SymbolIndexMgr, bool MinimizeIncludePaths)
30 : SemaSource(SymbolIndexMgr, MinimizeIncludePaths,
31 /*GenerateDiagnostics=*/false) {}
32
33 std::unique_ptr<clang::ASTConsumer>
CreateASTConsumer(clang::CompilerInstance & Compiler,StringRef InFile)34 CreateASTConsumer(clang::CompilerInstance &Compiler,
35 StringRef InFile) override {
36 SemaSource.setFilePath(InFile);
37 return std::make_unique<clang::ASTConsumer>();
38 }
39
ExecuteAction()40 void ExecuteAction() override {
41 clang::CompilerInstance *Compiler = &getCompilerInstance();
42 assert(!Compiler->hasSema() && "CI already has Sema");
43
44 // Set up our hooks into sema and parse the AST.
45 if (hasCodeCompletionSupport() &&
46 !Compiler->getFrontendOpts().CodeCompletionAt.FileName.empty())
47 Compiler->createCodeCompletionConsumer();
48
49 clang::CodeCompleteConsumer *CompletionConsumer = nullptr;
50 if (Compiler->hasCodeCompletionConsumer())
51 CompletionConsumer = &Compiler->getCodeCompletionConsumer();
52
53 Compiler->createSema(getTranslationUnitKind(), CompletionConsumer);
54 SemaSource.setCompilerInstance(Compiler);
55 Compiler->getSema().addExternalSource(&SemaSource);
56
57 clang::ParseAST(Compiler->getSema(), Compiler->getFrontendOpts().ShowStats,
58 Compiler->getFrontendOpts().SkipFunctionBodies);
59 }
60
61 IncludeFixerContext
getIncludeFixerContext(const clang::SourceManager & SourceManager,clang::HeaderSearch & HeaderSearch) const62 getIncludeFixerContext(const clang::SourceManager &SourceManager,
63 clang::HeaderSearch &HeaderSearch) const {
64 return SemaSource.getIncludeFixerContext(SourceManager, HeaderSearch,
65 SemaSource.getMatchedSymbols());
66 }
67
68 private:
69 IncludeFixerSemaSource SemaSource;
70 };
71
72 } // namespace
73
IncludeFixerActionFactory(SymbolIndexManager & SymbolIndexMgr,std::vector<IncludeFixerContext> & Contexts,StringRef StyleName,bool MinimizeIncludePaths)74 IncludeFixerActionFactory::IncludeFixerActionFactory(
75 SymbolIndexManager &SymbolIndexMgr,
76 std::vector<IncludeFixerContext> &Contexts, StringRef StyleName,
77 bool MinimizeIncludePaths)
78 : SymbolIndexMgr(SymbolIndexMgr), Contexts(Contexts),
79 MinimizeIncludePaths(MinimizeIncludePaths) {}
80
81 IncludeFixerActionFactory::~IncludeFixerActionFactory() = default;
82
runInvocation(std::shared_ptr<clang::CompilerInvocation> Invocation,clang::FileManager * Files,std::shared_ptr<clang::PCHContainerOperations> PCHContainerOps,clang::DiagnosticConsumer * Diagnostics)83 bool IncludeFixerActionFactory::runInvocation(
84 std::shared_ptr<clang::CompilerInvocation> Invocation,
85 clang::FileManager *Files,
86 std::shared_ptr<clang::PCHContainerOperations> PCHContainerOps,
87 clang::DiagnosticConsumer *Diagnostics) {
88 assert(Invocation->getFrontendOpts().Inputs.size() == 1);
89
90 // Set up Clang.
91 clang::CompilerInstance Compiler(PCHContainerOps);
92 Compiler.setInvocation(std::move(Invocation));
93 Compiler.setFileManager(Files);
94
95 // Create the compiler's actual diagnostics engine. We want to drop all
96 // diagnostics here.
97 Compiler.createDiagnostics(new clang::IgnoringDiagConsumer,
98 /*ShouldOwnClient=*/true);
99 Compiler.createSourceManager(*Files);
100
101 // We abort on fatal errors so don't let a large number of errors become
102 // fatal. A missing #include can cause thousands of errors.
103 Compiler.getDiagnostics().setErrorLimit(0);
104
105 // Run the parser, gather missing includes.
106 auto ScopedToolAction =
107 std::make_unique<Action>(SymbolIndexMgr, MinimizeIncludePaths);
108 Compiler.ExecuteAction(*ScopedToolAction);
109
110 Contexts.push_back(ScopedToolAction->getIncludeFixerContext(
111 Compiler.getSourceManager(),
112 Compiler.getPreprocessor().getHeaderSearchInfo()));
113
114 // Technically this should only return true if we're sure that we have a
115 // parseable file. We don't know that though. Only inform users of fatal
116 // errors.
117 return !Compiler.getDiagnostics().hasFatalErrorOccurred();
118 }
119
addDiagnosticsForContext(TypoCorrection & Correction,const IncludeFixerContext & Context,StringRef Code,SourceLocation StartOfFile,ASTContext & Ctx)120 static bool addDiagnosticsForContext(TypoCorrection &Correction,
121 const IncludeFixerContext &Context,
122 StringRef Code, SourceLocation StartOfFile,
123 ASTContext &Ctx) {
124 auto Reps = createIncludeFixerReplacements(
125 Code, Context, format::getLLVMStyle(), /*AddQualifiers=*/false);
126 if (!Reps || Reps->size() != 1)
127 return false;
128
129 unsigned DiagID = Ctx.getDiagnostics().getCustomDiagID(
130 DiagnosticsEngine::Note, "Add '#include %0' to provide the missing "
131 "declaration [clang-include-fixer]");
132
133 // FIXME: Currently we only generate a diagnostic for the first header. Give
134 // the user choices.
135 const tooling::Replacement &Placed = *Reps->begin();
136
137 auto Begin = StartOfFile.getLocWithOffset(Placed.getOffset());
138 auto End = Begin.getLocWithOffset(std::max(0, (int)Placed.getLength() - 1));
139 PartialDiagnostic PD(DiagID, Ctx.getDiagAllocator());
140 PD << Context.getHeaderInfos().front().Header
141 << FixItHint::CreateReplacement(CharSourceRange::getCharRange(Begin, End),
142 Placed.getReplacementText());
143 Correction.addExtraDiagnostic(std::move(PD));
144 return true;
145 }
146
147 /// Callback for incomplete types. If we encounter a forward declaration we
148 /// have the fully qualified name ready. Just query that.
MaybeDiagnoseMissingCompleteType(clang::SourceLocation Loc,clang::QualType T)149 bool IncludeFixerSemaSource::MaybeDiagnoseMissingCompleteType(
150 clang::SourceLocation Loc, clang::QualType T) {
151 // Ignore spurious callbacks from SFINAE contexts.
152 if (CI->getSema().isSFINAEContext())
153 return false;
154
155 clang::ASTContext &context = CI->getASTContext();
156 std::string QueryString = QualType(T->getUnqualifiedDesugaredType(), 0)
157 .getAsString(context.getPrintingPolicy());
158 LLVM_DEBUG(llvm::dbgs() << "Query missing complete type '" << QueryString
159 << "'");
160 // Pass an empty range here since we don't add qualifier in this case.
161 std::vector<find_all_symbols::SymbolInfo> MatchedSymbols =
162 query(QueryString, "", tooling::Range());
163
164 if (!MatchedSymbols.empty() && GenerateDiagnostics) {
165 TypoCorrection Correction;
166 FileID FID = CI->getSourceManager().getFileID(Loc);
167 StringRef Code = CI->getSourceManager().getBufferData(FID);
168 SourceLocation StartOfFile =
169 CI->getSourceManager().getLocForStartOfFile(FID);
170 addDiagnosticsForContext(
171 Correction,
172 getIncludeFixerContext(CI->getSourceManager(),
173 CI->getPreprocessor().getHeaderSearchInfo(),
174 MatchedSymbols),
175 Code, StartOfFile, CI->getASTContext());
176 for (const PartialDiagnostic &PD : Correction.getExtraDiagnostics())
177 CI->getSema().Diag(Loc, PD);
178 }
179 return true;
180 }
181
182 /// Callback for unknown identifiers. Try to piece together as much
183 /// qualification as we can get and do a query.
CorrectTypo(const DeclarationNameInfo & Typo,int LookupKind,Scope * S,CXXScopeSpec * SS,CorrectionCandidateCallback & CCC,DeclContext * MemberContext,bool EnteringContext,const ObjCObjectPointerType * OPT)184 clang::TypoCorrection IncludeFixerSemaSource::CorrectTypo(
185 const DeclarationNameInfo &Typo, int LookupKind, Scope *S, CXXScopeSpec *SS,
186 CorrectionCandidateCallback &CCC, DeclContext *MemberContext,
187 bool EnteringContext, const ObjCObjectPointerType *OPT) {
188 // Ignore spurious callbacks from SFINAE contexts.
189 if (CI->getSema().isSFINAEContext())
190 return clang::TypoCorrection();
191
192 // We currently ignore the unidentified symbol which is not from the
193 // main file.
194 //
195 // However, this is not always true due to templates in a non-self contained
196 // header, consider the case:
197 //
198 // // header.h
199 // template <typename T>
200 // class Foo {
201 // T t;
202 // };
203 //
204 // // test.cc
205 // // We need to add <bar.h> in test.cc instead of header.h.
206 // class Bar;
207 // Foo<Bar> foo;
208 //
209 // FIXME: Add the missing header to the header file where the symbol comes
210 // from.
211 if (!CI->getSourceManager().isWrittenInMainFile(Typo.getLoc()))
212 return clang::TypoCorrection();
213
214 std::string TypoScopeString;
215 if (S) {
216 // FIXME: Currently we only use namespace contexts. Use other context
217 // types for query.
218 for (const auto *Context = S->getEntity(); Context;
219 Context = Context->getParent()) {
220 if (const auto *ND = dyn_cast<NamespaceDecl>(Context)) {
221 if (!ND->getName().empty())
222 TypoScopeString = ND->getNameAsString() + "::" + TypoScopeString;
223 }
224 }
225 }
226
227 auto ExtendNestedNameSpecifier = [this](CharSourceRange Range) {
228 StringRef Source =
229 Lexer::getSourceText(Range, CI->getSourceManager(), CI->getLangOpts());
230
231 // Skip forward until we find a character that's neither identifier nor
232 // colon. This is a bit of a hack around the fact that we will only get a
233 // single callback for a long nested name if a part of the beginning is
234 // unknown. For example:
235 //
236 // llvm::sys::path::parent_path(...)
237 // ^~~~ ^~~
238 // known
239 // ^~~~
240 // unknown, last callback
241 // ^~~~~~~~~~~
242 // no callback
243 //
244 // With the extension we get the full nested name specifier including
245 // parent_path.
246 // FIXME: Don't rely on source text.
247 const char *End = Source.end();
248 while (isIdentifierBody(*End) || *End == ':')
249 ++End;
250
251 return std::string(Source.begin(), End);
252 };
253
254 /// If we have a scope specification, use that to get more precise results.
255 std::string QueryString;
256 tooling::Range SymbolRange;
257 const auto &SM = CI->getSourceManager();
258 auto CreateToolingRange = [&QueryString, &SM](SourceLocation BeginLoc) {
259 return tooling::Range(SM.getDecomposedLoc(BeginLoc).second,
260 QueryString.size());
261 };
262 if (SS && SS->getRange().isValid()) {
263 auto Range = CharSourceRange::getTokenRange(SS->getRange().getBegin(),
264 Typo.getLoc());
265
266 QueryString = ExtendNestedNameSpecifier(Range);
267 SymbolRange = CreateToolingRange(Range.getBegin());
268 } else if (Typo.getName().isIdentifier() && !Typo.getLoc().isMacroID()) {
269 auto Range =
270 CharSourceRange::getTokenRange(Typo.getBeginLoc(), Typo.getEndLoc());
271
272 QueryString = ExtendNestedNameSpecifier(Range);
273 SymbolRange = CreateToolingRange(Range.getBegin());
274 } else {
275 QueryString = Typo.getAsString();
276 SymbolRange = CreateToolingRange(Typo.getLoc());
277 }
278
279 LLVM_DEBUG(llvm::dbgs() << "TypoScopeQualifiers: " << TypoScopeString
280 << "\n");
281 std::vector<find_all_symbols::SymbolInfo> MatchedSymbols =
282 query(QueryString, TypoScopeString, SymbolRange);
283
284 if (!MatchedSymbols.empty() && GenerateDiagnostics) {
285 TypoCorrection Correction(Typo.getName());
286 Correction.setCorrectionRange(SS, Typo);
287 FileID FID = SM.getFileID(Typo.getLoc());
288 StringRef Code = SM.getBufferData(FID);
289 SourceLocation StartOfFile = SM.getLocForStartOfFile(FID);
290 if (addDiagnosticsForContext(
291 Correction, getIncludeFixerContext(
292 SM, CI->getPreprocessor().getHeaderSearchInfo(),
293 MatchedSymbols),
294 Code, StartOfFile, CI->getASTContext()))
295 return Correction;
296 }
297 return TypoCorrection();
298 }
299
300 /// Get the minimal include for a given path.
minimizeInclude(StringRef Include,const clang::SourceManager & SourceManager,clang::HeaderSearch & HeaderSearch) const301 std::string IncludeFixerSemaSource::minimizeInclude(
302 StringRef Include, const clang::SourceManager &SourceManager,
303 clang::HeaderSearch &HeaderSearch) const {
304 if (!MinimizeIncludePaths)
305 return std::string(Include);
306
307 // Get the FileEntry for the include.
308 StringRef StrippedInclude = Include.trim("\"<>");
309 auto Entry = SourceManager.getFileManager().getFile(StrippedInclude);
310
311 // If the file doesn't exist return the path from the database.
312 // FIXME: This should never happen.
313 if (!Entry)
314 return std::string(Include);
315
316 bool IsSystem = false;
317 std::string Suggestion =
318 HeaderSearch.suggestPathToFileForDiagnostics(*Entry, "", &IsSystem);
319
320 return IsSystem ? '<' + Suggestion + '>' : '"' + Suggestion + '"';
321 }
322
323 /// Get the include fixer context for the queried symbol.
getIncludeFixerContext(const clang::SourceManager & SourceManager,clang::HeaderSearch & HeaderSearch,ArrayRef<find_all_symbols::SymbolInfo> MatchedSymbols) const324 IncludeFixerContext IncludeFixerSemaSource::getIncludeFixerContext(
325 const clang::SourceManager &SourceManager,
326 clang::HeaderSearch &HeaderSearch,
327 ArrayRef<find_all_symbols::SymbolInfo> MatchedSymbols) const {
328 std::vector<find_all_symbols::SymbolInfo> SymbolCandidates;
329 for (const auto &Symbol : MatchedSymbols) {
330 std::string FilePath = Symbol.getFilePath().str();
331 std::string MinimizedFilePath = minimizeInclude(
332 ((FilePath[0] == '"' || FilePath[0] == '<') ? FilePath
333 : "\"" + FilePath + "\""),
334 SourceManager, HeaderSearch);
335 SymbolCandidates.emplace_back(Symbol.getName(), Symbol.getSymbolKind(),
336 MinimizedFilePath, Symbol.getContexts());
337 }
338 return IncludeFixerContext(FilePath, QuerySymbolInfos, SymbolCandidates);
339 }
340
341 std::vector<find_all_symbols::SymbolInfo>
query(StringRef Query,StringRef ScopedQualifiers,tooling::Range Range)342 IncludeFixerSemaSource::query(StringRef Query, StringRef ScopedQualifiers,
343 tooling::Range Range) {
344 assert(!Query.empty() && "Empty query!");
345
346 // Save all instances of an unidentified symbol.
347 //
348 // We use conservative behavior for detecting the same unidentified symbol
349 // here. The symbols which have the same ScopedQualifier and RawIdentifier
350 // are considered equal. So that clang-include-fixer avoids false positives,
351 // and always adds missing qualifiers to correct symbols.
352 if (!GenerateDiagnostics && !QuerySymbolInfos.empty()) {
353 if (ScopedQualifiers == QuerySymbolInfos.front().ScopedQualifiers &&
354 Query == QuerySymbolInfos.front().RawIdentifier) {
355 QuerySymbolInfos.push_back(
356 {Query.str(), std::string(ScopedQualifiers), Range});
357 }
358 return {};
359 }
360
361 LLVM_DEBUG(llvm::dbgs() << "Looking up '" << Query << "' at ");
362 LLVM_DEBUG(CI->getSourceManager()
363 .getLocForStartOfFile(CI->getSourceManager().getMainFileID())
364 .getLocWithOffset(Range.getOffset())
365 .print(llvm::dbgs(), CI->getSourceManager()));
366 LLVM_DEBUG(llvm::dbgs() << " ...");
367 llvm::StringRef FileName = CI->getSourceManager().getFilename(
368 CI->getSourceManager().getLocForStartOfFile(
369 CI->getSourceManager().getMainFileID()));
370
371 QuerySymbolInfos.push_back(
372 {Query.str(), std::string(ScopedQualifiers), Range});
373
374 // Query the symbol based on C++ name Lookup rules.
375 // Firstly, lookup the identifier with scoped namespace contexts;
376 // If that fails, falls back to look up the identifier directly.
377 //
378 // For example:
379 //
380 // namespace a {
381 // b::foo f;
382 // }
383 //
384 // 1. lookup a::b::foo.
385 // 2. lookup b::foo.
386 std::string QueryString = ScopedQualifiers.str() + Query.str();
387 // It's unsafe to do nested search for the identifier with scoped namespace
388 // context, it might treat the identifier as a nested class of the scoped
389 // namespace.
390 std::vector<find_all_symbols::SymbolInfo> MatchedSymbols =
391 SymbolIndexMgr.search(QueryString, /*IsNestedSearch=*/false, FileName);
392 if (MatchedSymbols.empty())
393 MatchedSymbols =
394 SymbolIndexMgr.search(Query, /*IsNestedSearch=*/true, FileName);
395 LLVM_DEBUG(llvm::dbgs() << "Having found " << MatchedSymbols.size()
396 << " symbols\n");
397 // We store a copy of MatchedSymbols in a place where it's globally reachable.
398 // This is used by the standalone version of the tool.
399 this->MatchedSymbols = MatchedSymbols;
400 return MatchedSymbols;
401 }
402
createIncludeFixerReplacements(StringRef Code,const IncludeFixerContext & Context,const clang::format::FormatStyle & Style,bool AddQualifiers)403 llvm::Expected<tooling::Replacements> createIncludeFixerReplacements(
404 StringRef Code, const IncludeFixerContext &Context,
405 const clang::format::FormatStyle &Style, bool AddQualifiers) {
406 if (Context.getHeaderInfos().empty())
407 return tooling::Replacements();
408 StringRef FilePath = Context.getFilePath();
409 std::string IncludeName =
410 "#include " + Context.getHeaderInfos().front().Header + "\n";
411 // Create replacements for the new header.
412 clang::tooling::Replacements Insertions;
413 auto Err =
414 Insertions.add(tooling::Replacement(FilePath, UINT_MAX, 0, IncludeName));
415 if (Err)
416 return std::move(Err);
417
418 auto CleanReplaces = cleanupAroundReplacements(Code, Insertions, Style);
419 if (!CleanReplaces)
420 return CleanReplaces;
421
422 auto Replaces = std::move(*CleanReplaces);
423 if (AddQualifiers) {
424 for (const auto &Info : Context.getQuerySymbolInfos()) {
425 // Ignore the empty range.
426 if (Info.Range.getLength() > 0) {
427 auto R = tooling::Replacement(
428 {FilePath, Info.Range.getOffset(), Info.Range.getLength(),
429 Context.getHeaderInfos().front().QualifiedName});
430 auto Err = Replaces.add(R);
431 if (Err) {
432 llvm::consumeError(std::move(Err));
433 R = tooling::Replacement(
434 R.getFilePath(), Replaces.getShiftedCodePosition(R.getOffset()),
435 R.getLength(), R.getReplacementText());
436 Replaces = Replaces.merge(tooling::Replacements(R));
437 }
438 }
439 }
440 }
441 return formatReplacements(Code, Replaces, Style);
442 }
443
444 } // namespace include_fixer
445 } // namespace clang
446