1 /*
2 * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "cleanup_inline_module.h"

#include "llvm_ark_interface.h"
#include "inline_ir_utils.h"
#include "llvm_compiler_options.h"

#include <algorithm>
#include <array>
#include <memory>
#include <sstream>
#include <utility>
#include <vector>

#include <llvm/ADT/STLExtras.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/ValueSymbolTable.h>
#include <llvm/Pass.h>
#include <llvm/Support/Debug.h>
#include <llvm/Transforms/IPO/FunctionImport.h>
#include <llvm/Transforms/Utils/FunctionImportUtils.h>

#include "transforms/transform_utils.h"
33
34 #define DEBUG_TYPE "cleanup-inline-module"
35
36 using llvm::Argument;
37 using llvm::BasicBlock;
38 using llvm::cast;
39 using llvm::Constant;
40 using llvm::convertToDeclaration;
41 using llvm::DenseMap;
42 using llvm::DenseSet;
43 using llvm::for_each;
44 using llvm::Function;
45 using llvm::GlobalVariable;
46 using llvm::InlineAsm;
47 using llvm::isa;
48 using llvm::ListSeparator;
49 using llvm::Module;
50 using llvm::SmallVector;
51 using llvm::SmallVectorImpl;
52 using llvm::StringRef;
53 using llvm::User;
54 using llvm::Value;
55
56 namespace {
57
58 /**
59 * 1. Inserts a given element to a vector
60 * 2. Removes the last element from the vector
61 * 3. Checks if the last element was the initially given element
62 *
63 * Used to track a dfs path
64 *
65 * @tparam T type of vector elements
66 */
67 template <typename T>
68 class ScopedVectorElement final {
69 public:
ScopedVectorElement(llvm::SmallVectorImpl<T> * vector,T value)70 explicit ScopedVectorElement(llvm::SmallVectorImpl<T> *vector, T value) : vector_(vector), value_(value)
71 {
72 ASSERT(vector != nullptr);
73 ASSERT(value != nullptr);
74 vector->push_back(value);
75 }
76
~ScopedVectorElement()77 ~ScopedVectorElement()
78 {
79 [[maybe_unused]] auto value = vector_->pop_back_val();
80 ASSERT(value == value_);
81 }
82
83 ScopedVectorElement(const ScopedVectorElement &) = delete;
84 ScopedVectorElement &operator=(const ScopedVectorElement &) = delete;
85 ScopedVectorElement(ScopedVectorElement &&) = delete;
86 ScopedVectorElement &operator=(ScopedVectorElement &&) = delete;
87
88 private:
89 SmallVectorImpl<T> *vector_;
90 T value_;
91 };
92
93 using InlineFailurePath = SmallVector<Value *, 4U>;
94 using ScopedInlineFailurePathElement = ScopedVectorElement<InlineFailurePath::value_type>;
95
96 enum class InlineFailureReason {
97 HAS_ADDRESS_TAKEN,
98 USES_INTERNAL_VARIABLE,
99 };
100
101 class InlineFailure {
102 public:
InlineFailure(InlineFailurePath inlineFailurePath,InlineFailureReason reason)103 InlineFailure(InlineFailurePath inlineFailurePath, InlineFailureReason reason)
104 : inlineFailurePath_(std::move(inlineFailurePath)), reason_(reason)
105 {
106 }
107
GetInlineFailurePath() const108 const InlineFailurePath &GetInlineFailurePath() const
109 {
110 return inlineFailurePath_;
111 }
112
GetReason() const113 InlineFailureReason GetReason() const
114 {
115 return reason_;
116 }
117
Print(llvm::raw_ostream * output)118 void Print(llvm::raw_ostream *output)
119 {
120 ASSERT(output != nullptr);
121
122 if (reason_ == InlineFailureReason::HAS_ADDRESS_TAKEN) {
123 PrintForHasAddressTaken(output);
124 } else if (reason_ == InlineFailureReason::USES_INTERNAL_VARIABLE) {
125 PrintForUsesInternalVariable(output);
126 } else {
127 LLVM_DEBUG(llvm::dbgs() << "static_cast<int>(reason_) = " << static_cast<int>(reason_) << "\n");
128 llvm_unreachable("Unsupported reason");
129 }
130 }
131
132 private:
PrintForHasAddressTaken(llvm::raw_ostream * output)133 void PrintForHasAddressTaken(llvm::raw_ostream *output)
134 {
135 ASSERT(!inlineFailurePath_.empty());
136 ASSERT(reason_ == InlineFailureReason::HAS_ADDRESS_TAKEN);
137 auto function = inlineFailurePath_[inlineFailurePath_.size() - 1];
138 ASSERT(isa<Function>(function));
139 *output << "address of the function = '" << function->getName() << "' is taken. Won't inline chain: ";
140 PrintPath(output);
141 }
142
PrintForUsesInternalVariable(llvm::raw_ostream * output)143 void PrintForUsesInternalVariable(llvm::raw_ostream *output)
144 {
145 ASSERT(reason_ == InlineFailureReason::USES_INTERNAL_VARIABLE);
146 ASSERT(inlineFailurePath_.size() >= 2U);
147 // inline_failure_path_ = [..., functionUsingInternalVariable, ...not function..., internalVariable]
148 auto variable = inlineFailurePath_[inlineFailurePath_.size() - 1];
149 auto function = FindLastFunctionInPath();
150 *output << "Internal variable = '" << variable->getName() << "' is used in function = '" << function->getName()
151 << "'. Won't inline chain: ";
152 PrintPath(output);
153 }
154
FindLastFunctionInPath()155 Function *FindLastFunctionInPath()
156 {
157 for (auto it = inlineFailurePath_.rbegin(); it != inlineFailurePath_.rend(); it++) {
158 if (auto function = llvm::dyn_cast<Function>(*it)) {
159 return function;
160 }
161 }
162 LLVM_DEBUG(llvm::dbgs() << "inline_failure_path_ = ");
163 LLVM_DEBUG(PrintPath(&llvm::dbgs()));
164 LLVM_DEBUG(llvm::dbgs() << "\n");
165 llvm_unreachable("Could not find function in inline_failure_path_");
166 }
167
PrintPath(llvm::raw_ostream * output)168 void PrintPath(llvm::raw_ostream *output)
169 {
170 ListSeparator separator(" -> ");
171 for (const auto &pathElement : inlineFailurePath_) {
172 *output << separator << "'" << pathElement->getName() << "'";
173 }
174 }
175
176 private:
177 InlineFailurePath inlineFailurePath_;
178 InlineFailureReason reason_;
179 };
180 } // namespace
181
182 namespace ark::llvmbackend::passes {
183
184 /**
185 * Remove functions and variables unsuitable for inlining from module.
186 *
187 * Function is unsuitable for inlining if:
188 *
189 * 1. It has local linkage and address taken
190 * 2. Function does not have external linkage, references variable with local linkage and the variable is not a
191 * constant. Examples:
192 * 1. Function has 'static int x' in its body
193 * 2. Function assigns a value to a 'static thread_local'
194 */
195 class CleanupInlineModule::InlineModuleCleaner {
196 public:
Run(Module & module)197 bool Run(Module &module)
198 {
199 RemoveNonInlinableFunctions(module);
200 RemoveObjectFileGlobals(module);
201 RemoveDanglingAliases(module);
202 return true;
203 }
204
205 private:
206 enum class FunctionState {
207 UNKNOWN, // white
208 IN_PROGRESS, // gray
209 NOT_INLINABLE, // black
210 INLINABLE, // black
211 };
212
213 using DfsState = DenseMap<Function *, FunctionState>;
214
RemoveNonInlinableFunctions(Module & module)215 void RemoveNonInlinableFunctions(Module &module)
216 {
217 for (auto &function : module.functions()) {
218 if (function.isDeclaration()) {
219 continue;
220 }
221
222 ScopedInlineFailurePathElement inlineFailurePathElement {&inlineFailurePath_, &function};
223 VisitFunction(&function);
224 }
225 ASSERT(inlineFailurePath_.empty());
226
227 for_each(state_, [](auto entry) -> void {
228 auto functionState = entry.getSecond();
229 ASSERT(functionState == FunctionState::INLINABLE || functionState == FunctionState::NOT_INLINABLE);
230 if (functionState == FunctionState::NOT_INLINABLE) {
231 auto function = entry.getFirst();
232 LLVM_DEBUG(llvm::dbgs() << "InlineModuleCleaner: removed '" << function->getName() << "'\n");
233 convertToDeclaration(*function);
234 }
235 });
236 LLVM_DEBUG(PrintInlineReport(&llvm::errs()));
237 }
238
239 /**
240 * Visit the function during depth first search.
241 *
242 * After returning the dfs_state_ will contain either:
243 * 1. FunctionState::NOT_INLINABLE
244 * 2. FunctionState::INLINABLE
245 * for the function.
246 */
VisitFunction(Function * function)247 void VisitFunction(Function *function)
248 {
249 static_assert(FunctionState() == FunctionState::UNKNOWN, "FunctionState::UNKNOWN must be the default value");
250 auto functionState = state_.lookup(function);
251 if (functionState != FunctionState::UNKNOWN) {
252 return;
253 }
254 if (function->hasLocalLinkage() && function->hasAddressTaken()) {
255 state_[function] = FunctionState::NOT_INLINABLE;
256 ReportInlineFailure(InlineFailureReason::HAS_ADDRESS_TAKEN, inlineFailurePath_);
257 return;
258 }
259
260 state_.insert({function, FunctionState::IN_PROGRESS});
261
262 DenseSet<Value *> visited;
263 for (auto &basicBlock : *function) {
264 if (!IsInlinable(&basicBlock, visited)) {
265 state_[function] = FunctionState::NOT_INLINABLE;
266 return;
267 }
268 }
269 state_[function] = FunctionState::INLINABLE;
270 }
271
IsInlinable(Value * value,DenseSet<Value * > & visited)272 bool IsInlinable(Value *value, DenseSet<Value *> &visited)
273 {
274 ScopedInlineFailurePathElement inlineFailurePathElement {&inlineFailurePath_, value};
275 if (visited.contains(value)) {
276 return true;
277 }
278 visited.insert(value);
279 if (isa<Function>(value)) {
280 auto function = cast<Function>(value);
281 /**
282 * The value operand is a function with external linkage.
283 *
284 * The function's operand from VisitFunction is inlinable because we'll either:
285 * 1. Inline the operand itself
286 * 2. Leave a reference to external function, which linker will resolve
287 *
288 * Example:
289 *
290 * @code
291 * extern void foo();
292 *
293 * void bar() {
294 * foo(); // We're here: the 'bar' function has external linkage
295 * }
296 * @endcode
297 *
298 * Call stack might be:
299 *
300 * @code
301 * 1. IsInlinable(foo);
302 * 2. VisitFunction(bar);
303 * @endcode
304 */
305 if (function->hasExternalLinkage()) {
306 return true;
307 }
308
309 /**
310 * Function does not have external linkage, then we inline it if the function itself is inlinable.
311 *
312 * Example:
313 *
314 * @code
315 * static void foo() {
316 * }
317 *
318 * void bar() {
319 * foo(); // We're here, 'foo' function does not have external linkage
320 * }
321 * @endcode
322 *
323 * Call stack might be:
324 *
325 * @code
326 * 1. IsInlinable(foo);
327 * 2. VisitFunction(bar);
328 * @endcode
329 */
330 VisitFunction(function);
331 auto functionState = state_.lookup(function);
332 return functionState == FunctionState::IN_PROGRESS || functionState == FunctionState::INLINABLE;
333 }
334 if (isa<GlobalVariable>(value)) {
335 auto globalVariable = cast<GlobalVariable>(value);
336 /**
337 * Could be a constant from C++ declared in header.
338 * We don't check for taken address because in different translation units address of such constant
339 * could be different
340 *
341 * https://stackoverflow.com/a/50489130
342 * > constexpr implies const and const on global/namespace scope implies static (internal linkage),
343 * > which means that every translation unit including this header gets its own copy of PI. The memory
344 * > for that static is only going to be allocated if an address or reference to it is taken, and the
345 * > address is going to be different in each translation unit.
346 */
347 if (globalVariable->hasLocalLinkage() && !globalVariable->isConstant()) {
348 /**
349 * Mutable variable with local linkage, can't inline.
350 *
351 * Example:
352 *
353 * @code
354 * void foo() {
355 * static int x = 0;
356 * std::cout << x++ << std::endl;
357 * }
358 * @endcode
359 */
360 ReportInlineFailure(InlineFailureReason::USES_INTERNAL_VARIABLE, inlineFailurePath_);
361 return false;
362 }
363 return true;
364 }
365 if (isa<User>(value)) {
366 auto user = cast<User>(value);
367 for (auto operand : user->operand_values()) {
368 if (!IsInlinable(operand, visited)) {
369 return false;
370 }
371 }
372 return true;
373 }
374 if (isa<BasicBlock>(value)) {
375 for (auto &instruction : *cast<BasicBlock>(value)) {
376 if (!IsInlinable(&instruction, visited)) {
377 return false;
378 }
379 }
380 return true;
381 }
382 if (isa<Argument, InlineAsm, Constant, llvm::MetadataAsValue>(value)) {
383 return true;
384 }
385 LLVM_DEBUG(llvm::dbgs() << "Value = ");
386 LLVM_DEBUG(value->print(llvm::dbgs()));
387 LLVM_DEBUG(llvm::dbgs() << "\n");
388 llvm_unreachable("Unexpected value");
389 }
390
ReportInlineFailure(InlineFailureReason inlineFailureReason,const InlineFailurePath & failurePath)391 void ReportInlineFailure([[maybe_unused]] InlineFailureReason inlineFailureReason,
392 [[maybe_unused]] const InlineFailurePath &failurePath)
393 {
394 LLVM_DEBUG(InlineFailure(failurePath, inlineFailureReason).Print(&llvm::dbgs()));
395 LLVM_DEBUG(llvm::dbgs() << "\n");
396 }
397
PrintInlineReport(llvm::raw_ostream * output) const398 void PrintInlineReport(llvm::raw_ostream *output) const
399 {
400 std::vector<DfsState::value_type> entries {state_.begin(), state_.end()};
401 std::sort(entries.begin(), entries.end(), [](auto a, auto b) {
402 auto aState = a.getSecond();
403 auto bState = b.getSecond();
404 ASSERT(aState == FunctionState::INLINABLE || aState == FunctionState::NOT_INLINABLE);
405 ASSERT(bState == FunctionState::INLINABLE || bState == FunctionState::NOT_INLINABLE);
406 if (aState != bState) {
407 return aState == FunctionState::NOT_INLINABLE;
408 }
409 auto aFunction = a.getFirst();
410 auto bFunction = b.getFirst();
411 return aFunction->getName() < bFunction->getName();
412 });
413 for (size_t i = 0; i < entries.size(); i++) {
414 auto entry = entries[i];
415 auto function = entry.getFirst();
416 auto functionState = entry.getSecond();
417 ASSERT(functionState == FunctionState::INLINABLE || functionState == FunctionState::NOT_INLINABLE);
418 *output << (i + 1) << ". '" << function->getName() << "' is"
419 << (functionState == FunctionState::INLINABLE ? " inlinable" : " not inlinable") << "\n";
420 }
421 }
422
RemoveObjectFileGlobals(const Module & module) const423 void RemoveObjectFileGlobals(const Module &module) const
424 {
425 static constexpr std::array VALUES_TO_ERASE = {
426 StringRef("llvm.used"), //
427 StringRef("llvm.compiler.used"), //
428 StringRef("llvm.global_ctors"), //
429 StringRef("llvm.global_dtors"), //
430 StringRef("llvm.global.annotations"), //
431 };
432 for (const auto &name : VALUES_TO_ERASE) {
433 auto *globalValue = module.getNamedValue(name);
434 if (globalValue != nullptr) {
435 LLVM_DEBUG(llvm::dbgs() << "Erase " << globalValue->getName() << "\n");
436 globalValue->eraseFromParent();
437 }
438 }
439 }
440
441 private:
442 DfsState state_;
443 InlineFailurePath inlineFailurePath_;
444 };
445
ShouldInsert(const ark::llvmbackend::LLVMCompilerOptions * options)446 bool CleanupInlineModule::ShouldInsert(const ark::llvmbackend::LLVMCompilerOptions *options)
447 {
448 return options->doIrtocInline;
449 }
450
CleanupInlineModule()451 CleanupInlineModule::CleanupInlineModule() : cleaner_ {std::make_unique<InlineModuleCleaner>()} {}
452
453 CleanupInlineModule::CleanupInlineModule(CleanupInlineModule &&) = default;
454
455 CleanupInlineModule &CleanupInlineModule::operator=(CleanupInlineModule &&) = default;
456
457 CleanupInlineModule::~CleanupInlineModule() = default;
458
run(llvm::Module & module,llvm::ModuleAnalysisManager &)459 llvm::PreservedAnalyses CleanupInlineModule::run(llvm::Module &module, llvm::ModuleAnalysisManager & /*AM*/)
460 {
461 auto changed = cleaner_->Run(module);
462 return changed ? llvm::PreservedAnalyses::none() : llvm::PreservedAnalyses::all();
463 }
464
465 } // namespace ark::llvmbackend::passes
466