• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "cleanup_inline_module.h"
17 
18 #include "llvm_ark_interface.h"
19 #include "inline_ir_utils.h"
20 #include "llvm_compiler_options.h"
21 
22 #include <sstream>
23 
24 #include <llvm/IR/Function.h>
25 #include <llvm/IR/Module.h>
26 #include <llvm/IR/ValueSymbolTable.h>
27 #include <llvm/Pass.h>
28 #include <llvm/Support/Debug.h>
29 #include <llvm/Transforms/IPO/FunctionImport.h>
30 #include <llvm/Transforms/Utils/FunctionImportUtils.h>
31 
32 #include "transforms/transform_utils.h"
33 
34 #define DEBUG_TYPE "cleanup-inline-module"
35 
36 using llvm::Argument;
37 using llvm::BasicBlock;
38 using llvm::cast;
39 using llvm::Constant;
40 using llvm::convertToDeclaration;
41 using llvm::DenseMap;
42 using llvm::DenseSet;
43 using llvm::for_each;
44 using llvm::Function;
45 using llvm::GlobalVariable;
46 using llvm::InlineAsm;
47 using llvm::isa;
48 using llvm::ListSeparator;
49 using llvm::Module;
50 using llvm::SmallVector;
51 using llvm::SmallVectorImpl;
52 using llvm::StringRef;
53 using llvm::User;
54 using llvm::Value;
55 
56 namespace {
57 
58 /**
59  * 1. Inserts a given element to a vector
60  * 2. Removes the last element from the vector
61  * 3. Checks if the last element was the initially given element
62  *
63  * Used to track a dfs path
64  *
65  * @tparam T type of vector elements
66  */
67 template <typename T>
68 class ScopedVectorElement final {
69 public:
ScopedVectorElement(llvm::SmallVectorImpl<T> * vector,T value)70     explicit ScopedVectorElement(llvm::SmallVectorImpl<T> *vector, T value) : vector_(vector), value_(value)
71     {
72         ASSERT(vector != nullptr);
73         ASSERT(value != nullptr);
74         vector->push_back(value);
75     }
76 
~ScopedVectorElement()77     ~ScopedVectorElement()
78     {
79         [[maybe_unused]] auto value = vector_->pop_back_val();
80         ASSERT(value == value_);
81     }
82 
83     ScopedVectorElement(const ScopedVectorElement &) = delete;
84     ScopedVectorElement &operator=(const ScopedVectorElement &) = delete;
85     ScopedVectorElement(ScopedVectorElement &&) = delete;
86     ScopedVectorElement &operator=(ScopedVectorElement &&) = delete;
87 
88 private:
89     SmallVectorImpl<T> *vector_;
90     T value_;
91 };
92 
93 using InlineFailurePath = SmallVector<Value *, 4U>;
94 using ScopedInlineFailurePathElement = ScopedVectorElement<InlineFailurePath::value_type>;
95 
/// Reason why a function was found unsuitable for inlining
enum class InlineFailureReason {
    /// The function's address is taken
    HAS_ADDRESS_TAKEN = 0,
    /// The function uses an internal variable
    USES_INTERNAL_VARIABLE = 1,
};
100 
101 class InlineFailure {
102 public:
InlineFailure(InlineFailurePath inlineFailurePath,InlineFailureReason reason)103     InlineFailure(InlineFailurePath inlineFailurePath, InlineFailureReason reason)
104         : inlineFailurePath_(std::move(inlineFailurePath)), reason_(reason)
105     {
106     }
107 
GetInlineFailurePath() const108     const InlineFailurePath &GetInlineFailurePath() const
109     {
110         return inlineFailurePath_;
111     }
112 
GetReason() const113     InlineFailureReason GetReason() const
114     {
115         return reason_;
116     }
117 
Print(llvm::raw_ostream * output)118     void Print(llvm::raw_ostream *output)
119     {
120         ASSERT(output != nullptr);
121 
122         if (reason_ == InlineFailureReason::HAS_ADDRESS_TAKEN) {
123             PrintForHasAddressTaken(output);
124         } else if (reason_ == InlineFailureReason::USES_INTERNAL_VARIABLE) {
125             PrintForUsesInternalVariable(output);
126         } else {
127             LLVM_DEBUG(llvm::dbgs() << "static_cast<int>(reason_) = " << static_cast<int>(reason_) << "\n");
128             llvm_unreachable("Unsupported reason");
129         }
130     }
131 
132 private:
PrintForHasAddressTaken(llvm::raw_ostream * output)133     void PrintForHasAddressTaken(llvm::raw_ostream *output)
134     {
135         ASSERT(!inlineFailurePath_.empty());
136         ASSERT(reason_ == InlineFailureReason::HAS_ADDRESS_TAKEN);
137         auto function = inlineFailurePath_[inlineFailurePath_.size() - 1];
138         ASSERT(isa<Function>(function));
139         *output << "address of the function = '" << function->getName() << "' is taken. Won't inline chain: ";
140         PrintPath(output);
141     }
142 
PrintForUsesInternalVariable(llvm::raw_ostream * output)143     void PrintForUsesInternalVariable(llvm::raw_ostream *output)
144     {
145         ASSERT(reason_ == InlineFailureReason::USES_INTERNAL_VARIABLE);
146         ASSERT(inlineFailurePath_.size() >= 2U);
147         // inline_failure_path_ = [..., functionUsingInternalVariable, ...not function..., internalVariable]
148         auto variable = inlineFailurePath_[inlineFailurePath_.size() - 1];
149         auto function = FindLastFunctionInPath();
150         *output << "Internal variable = '" << variable->getName() << "' is used in function = '" << function->getName()
151                 << "'. Won't inline chain: ";
152         PrintPath(output);
153     }
154 
FindLastFunctionInPath()155     Function *FindLastFunctionInPath()
156     {
157         for (auto it = inlineFailurePath_.rbegin(); it != inlineFailurePath_.rend(); it++) {
158             if (auto function = llvm::dyn_cast<Function>(*it)) {
159                 return function;
160             }
161         }
162         LLVM_DEBUG(llvm::dbgs() << "inline_failure_path_ = ");
163         LLVM_DEBUG(PrintPath(&llvm::dbgs()));
164         LLVM_DEBUG(llvm::dbgs() << "\n");
165         llvm_unreachable("Could not find function in inline_failure_path_");
166     }
167 
PrintPath(llvm::raw_ostream * output)168     void PrintPath(llvm::raw_ostream *output)
169     {
170         ListSeparator separator(" -> ");
171         for (const auto &pathElement : inlineFailurePath_) {
172             *output << separator << "'" << pathElement->getName() << "'";
173         }
174     }
175 
176 private:
177     InlineFailurePath inlineFailurePath_;
178     InlineFailureReason reason_;
179 };
180 }  // namespace
181 
182 namespace ark::llvmbackend::passes {
183 
184 /**
185  * Remove functions and variables unsuitable for inlining from module.
186  *
187  * Function is unsuitable for inlining if:
188  *
189  * 1. It has local linkage and address taken
190  * 2. Function does not have external linkage, references variable with local linkage and the variable is not a
191  * constant. Examples:
192  *     1. Function has 'static int x' in its body
193  *     2. Function assigns a value to a 'static thread_local'
194  */
195 class CleanupInlineModule::InlineModuleCleaner {
196 public:
Run(Module & module)197     bool Run(Module &module)
198     {
199         RemoveNonInlinableFunctions(module);
200         RemoveObjectFileGlobals(module);
201         RemoveDanglingAliases(module);
202         return true;
203     }
204 
205 private:
206     enum class FunctionState {
207         UNKNOWN,        // white
208         IN_PROGRESS,    // gray
209         NOT_INLINABLE,  // black
210         INLINABLE,      // black
211     };
212 
213     using DfsState = DenseMap<Function *, FunctionState>;
214 
RemoveNonInlinableFunctions(Module & module)215     void RemoveNonInlinableFunctions(Module &module)
216     {
217         for (auto &function : module.functions()) {
218             if (function.isDeclaration()) {
219                 continue;
220             }
221 
222             ScopedInlineFailurePathElement inlineFailurePathElement {&inlineFailurePath_, &function};
223             VisitFunction(&function);
224         }
225         ASSERT(inlineFailurePath_.empty());
226 
227         for_each(state_, [](auto entry) -> void {
228             auto functionState = entry.getSecond();
229             ASSERT(functionState == FunctionState::INLINABLE || functionState == FunctionState::NOT_INLINABLE);
230             if (functionState == FunctionState::NOT_INLINABLE) {
231                 auto function = entry.getFirst();
232                 LLVM_DEBUG(llvm::dbgs() << "InlineModuleCleaner: removed '" << function->getName() << "'\n");
233                 convertToDeclaration(*function);
234             }
235         });
236         LLVM_DEBUG(PrintInlineReport(&llvm::errs()));
237     }
238 
239     /**
240      * Visit the function during depth first search.
241      *
242      * After returning the dfs_state_ will contain either:
243      * 1. FunctionState::NOT_INLINABLE
244      * 2. FunctionState::INLINABLE
245      * for the function.
246      */
VisitFunction(Function * function)247     void VisitFunction(Function *function)
248     {
249         static_assert(FunctionState() == FunctionState::UNKNOWN, "FunctionState::UNKNOWN must be the default value");
250         auto functionState = state_.lookup(function);
251         if (functionState != FunctionState::UNKNOWN) {
252             return;
253         }
254         if (function->hasLocalLinkage() && function->hasAddressTaken()) {
255             state_[function] = FunctionState::NOT_INLINABLE;
256             ReportInlineFailure(InlineFailureReason::HAS_ADDRESS_TAKEN, inlineFailurePath_);
257             return;
258         }
259 
260         state_.insert({function, FunctionState::IN_PROGRESS});
261 
262         DenseSet<Value *> visited;
263         for (auto &basicBlock : *function) {
264             if (!IsInlinable(&basicBlock, visited)) {
265                 state_[function] = FunctionState::NOT_INLINABLE;
266                 return;
267             }
268         }
269         state_[function] = FunctionState::INLINABLE;
270     }
271 
IsInlinable(Value * value,DenseSet<Value * > & visited)272     bool IsInlinable(Value *value, DenseSet<Value *> &visited)
273     {
274         ScopedInlineFailurePathElement inlineFailurePathElement {&inlineFailurePath_, value};
275         if (visited.contains(value)) {
276             return true;
277         }
278         visited.insert(value);
279         if (isa<Function>(value)) {
280             auto function = cast<Function>(value);
281             /**
282              * The value operand is a function with external linkage.
283              *
284              * The function's operand from VisitFunction is inlinable because we'll either:
285              * 1. Inline the operand itself
286              * 2. Leave a reference to external function, which linker will resolve
287              *
288              * Example:
289              *
290              * @code
291              * extern void foo();
292              *
293              * void bar() {
294              *     foo(); // We're here: the 'bar' function has external linkage
295              * }
296              * @endcode
297              *
298              * Call stack might be:
299              *
300              * @code
301              * 1. IsInlinable(foo);
302              * 2. VisitFunction(bar);
303              * @endcode
304              */
305             if (function->hasExternalLinkage()) {
306                 return true;
307             }
308 
309             /**
310              * Function does not have external linkage, then we inline it if the function itself is inlinable.
311              *
312              * Example:
313              *
314              * @code
315              * static void foo() {
316              * }
317              *
318              * void bar() {
319              *   foo(); // We're here, 'foo' function does not have external linkage
320              * }
321              * @endcode
322              *
323              * Call stack might be:
324              *
325              * @code
326              * 1. IsInlinable(foo);
327              * 2. VisitFunction(bar);
328              * @endcode
329              */
330             VisitFunction(function);
331             auto functionState = state_.lookup(function);
332             return functionState == FunctionState::IN_PROGRESS || functionState == FunctionState::INLINABLE;
333         }
334         if (isa<GlobalVariable>(value)) {
335             auto globalVariable = cast<GlobalVariable>(value);
336             /**
337              * Could be a constant from C++ declared in header.
338              * We don't check for taken address because in different translation units address of such constant
339              * could be different
340              *
341              * https://stackoverflow.com/a/50489130
342              * > constexpr implies const and const on global/namespace scope implies static (internal linkage),
343              * > which means that every translation unit including this header gets its own copy of PI. The memory
344              * > for that static is only going to be allocated if an address or reference to it is taken, and the
345              * > address is going to be different in each translation unit.
346              */
347             if (globalVariable->hasLocalLinkage() && !globalVariable->isConstant()) {
348                 /**
349                  * Mutable variable with local linkage, can't inline.
350                  *
351                  * Example:
352                  *
353                  * @code
354                  * void foo() {
355                  *   static int x = 0;
356                  *   std::cout << x++ << std::endl;
357                  * }
358                  * @endcode
359                  */
360                 ReportInlineFailure(InlineFailureReason::USES_INTERNAL_VARIABLE, inlineFailurePath_);
361                 return false;
362             }
363             return true;
364         }
365         if (isa<User>(value)) {
366             auto user = cast<User>(value);
367             for (auto operand : user->operand_values()) {
368                 if (!IsInlinable(operand, visited)) {
369                     return false;
370                 }
371             }
372             return true;
373         }
374         if (isa<BasicBlock>(value)) {
375             for (auto &instruction : *cast<BasicBlock>(value)) {
376                 if (!IsInlinable(&instruction, visited)) {
377                     return false;
378                 }
379             }
380             return true;
381         }
382         if (isa<Argument, InlineAsm, Constant, llvm::MetadataAsValue>(value)) {
383             return true;
384         }
385         LLVM_DEBUG(llvm::dbgs() << "Value = ");
386         LLVM_DEBUG(value->print(llvm::dbgs()));
387         LLVM_DEBUG(llvm::dbgs() << "\n");
388         llvm_unreachable("Unexpected value");
389     }
390 
ReportInlineFailure(InlineFailureReason inlineFailureReason,const InlineFailurePath & failurePath)391     void ReportInlineFailure([[maybe_unused]] InlineFailureReason inlineFailureReason,
392                              [[maybe_unused]] const InlineFailurePath &failurePath)
393     {
394         LLVM_DEBUG(InlineFailure(failurePath, inlineFailureReason).Print(&llvm::dbgs()));
395         LLVM_DEBUG(llvm::dbgs() << "\n");
396     }
397 
PrintInlineReport(llvm::raw_ostream * output) const398     void PrintInlineReport(llvm::raw_ostream *output) const
399     {
400         std::vector<DfsState::value_type> entries {state_.begin(), state_.end()};
401         std::sort(entries.begin(), entries.end(), [](auto a, auto b) {
402             auto aState = a.getSecond();
403             auto bState = b.getSecond();
404             ASSERT(aState == FunctionState::INLINABLE || aState == FunctionState::NOT_INLINABLE);
405             ASSERT(bState == FunctionState::INLINABLE || bState == FunctionState::NOT_INLINABLE);
406             if (aState != bState) {
407                 return aState == FunctionState::NOT_INLINABLE;
408             }
409             auto aFunction = a.getFirst();
410             auto bFunction = b.getFirst();
411             return aFunction->getName() < bFunction->getName();
412         });
413         for (size_t i = 0; i < entries.size(); i++) {
414             auto entry = entries[i];
415             auto function = entry.getFirst();
416             auto functionState = entry.getSecond();
417             ASSERT(functionState == FunctionState::INLINABLE || functionState == FunctionState::NOT_INLINABLE);
418             *output << (i + 1) << ". '" << function->getName() << "' is"
419                     << (functionState == FunctionState::INLINABLE ? " inlinable" : " not inlinable") << "\n";
420         }
421     }
422 
RemoveObjectFileGlobals(const Module & module) const423     void RemoveObjectFileGlobals(const Module &module) const
424     {
425         static constexpr std::array VALUES_TO_ERASE = {
426             StringRef("llvm.used"),                //
427             StringRef("llvm.compiler.used"),       //
428             StringRef("llvm.global_ctors"),        //
429             StringRef("llvm.global_dtors"),        //
430             StringRef("llvm.global.annotations"),  //
431         };
432         for (const auto &name : VALUES_TO_ERASE) {
433             auto *globalValue = module.getNamedValue(name);
434             if (globalValue != nullptr) {
435                 LLVM_DEBUG(llvm::dbgs() << "Erase " << globalValue->getName() << "\n");
436                 globalValue->eraseFromParent();
437             }
438         }
439     }
440 
441 private:
442     DfsState state_;
443     InlineFailurePath inlineFailurePath_;
444 };
445 
ShouldInsert(const ark::llvmbackend::LLVMCompilerOptions * options)446 bool CleanupInlineModule::ShouldInsert(const ark::llvmbackend::LLVMCompilerOptions *options)
447 {
448     return options->doIrtocInline;
449 }
450 
CleanupInlineModule()451 CleanupInlineModule::CleanupInlineModule() : cleaner_ {std::make_unique<InlineModuleCleaner>()} {}
452 
453 CleanupInlineModule::CleanupInlineModule(CleanupInlineModule &&) = default;
454 
455 CleanupInlineModule &CleanupInlineModule::operator=(CleanupInlineModule &&) = default;
456 
457 CleanupInlineModule::~CleanupInlineModule() = default;
458 
run(llvm::Module & module,llvm::ModuleAnalysisManager &)459 llvm::PreservedAnalyses CleanupInlineModule::run(llvm::Module &module, llvm::ModuleAnalysisManager & /*AM*/)
460 {
461     auto changed = cleaner_->Run(module);
462     return changed ? llvm::PreservedAnalyses::none() : llvm::PreservedAnalyses::all();
463 }
464 
465 }  // namespace ark::llvmbackend::passes
466