1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "check_tail_calls.h"
17 #include <llvm/CodeGen/TargetInstrInfo.h>
18 #include <llvm/CodeGen/MachineFunctionPass.h>
19 #include <llvm/IR/Instructions.h>
20
21 #include "transforms/transform_utils.h"
22
23 #define DEBUG_TYPE "check-tail-calls"
24
25 namespace {
26
27 /**
28 * Used to detect cycles while walking on CFG
29 *
30 * @tparam T type of set elements
31 */
32 template <typename T>
33 class ScopedSetElement final {
34 public:
ScopedSetElement(llvm::SmallPtrSetImpl<T> * set,T value)35 explicit ScopedSetElement(llvm::SmallPtrSetImpl<T> *set, T value) : set_(set), value_(value)
36 {
37 ASSERT(set != nullptr);
38 ASSERT(value != nullptr);
39 set->insert(value);
40 }
41
~ScopedSetElement()42 ~ScopedSetElement()
43 {
44 set_->erase(value_);
45 }
46
47 ScopedSetElement(const ScopedSetElement &) = delete;
48 ScopedSetElement &operator=(const ScopedSetElement &) = delete;
49 ScopedSetElement(ScopedSetElement &&) = delete;
50 ScopedSetElement &operator=(ScopedSetElement &&) = delete;
51
52 private:
53 llvm::SmallPtrSetImpl<T> *set_;
54 T value_;
55 };
56
57 using VisitedBasicBlocks = llvm::SmallPtrSet<llvm::MachineBasicBlock *, 4U>;
58 using VisitedBasicBlockElement = ScopedSetElement<VisitedBasicBlocks::value_type>;
59
60 // This pass checks: 1) tail calls in interpreter handlers 2) calls to SlowPathes made in Irtoc FastPathes (they
61 // should be tail calls to avoid miscompilation). `CheckTailCallsPass` checks that llvm was able to lower
62 // them into machine code correspondingly, not falling back to regular calls.
63 class CheckTailCallsPass : public llvm::MachineFunctionPass {
64 public:
65 static constexpr llvm::StringRef PASS_NAME = "Check ARK Tail Calls";
66 static constexpr llvm::StringRef ARG_NAME = "check-ark-tail-calls";
67
CheckTailCallsPass()68 explicit CheckTailCallsPass() : MachineFunctionPass(ID) {}
69
getPassName() const70 llvm::StringRef getPassName() const override
71 {
72 return PASS_NAME;
73 }
74
75 // Almost exact copy of `getTerminatingMustTailCall` from llvm
GetTerminatingTailCall(const llvm::BasicBlock * bb) const76 const llvm::CallInst *GetTerminatingTailCall(const llvm::BasicBlock *bb) const
77 {
78 if (bb->empty()) {
79 return nullptr;
80 }
81
82 auto returnInst = llvm::dyn_cast<llvm::ReturnInst>(&bb->back());
83 if (returnInst == nullptr || returnInst == &bb->front()) {
84 return nullptr;
85 }
86
87 const llvm::Instruction *prev = returnInst->getPrevNode();
88 if (prev == nullptr) {
89 return nullptr;
90 }
91
92 if (llvm::Value *returnValue = returnInst->getReturnValue()) {
93 if (returnValue != prev) {
94 return nullptr;
95 }
96 // Look through the optional bitcast
97 if (auto *bitcastInst = llvm::dyn_cast<llvm::BitCastInst>(prev)) {
98 returnValue = bitcastInst->getOperand(0);
99 prev = bitcastInst->getPrevNode();
100 if (prev == nullptr || returnValue != prev) {
101 return nullptr;
102 }
103 }
104 }
105 if (auto *callInst = llvm::dyn_cast<llvm::CallInst>(prev)) {
106 if (callInst->isTailCall()) {
107 return callInst;
108 }
109 }
110 return nullptr;
111 }
112
IsRealTailCall(llvm::MachineBasicBlock * basicBlock,VisitedBasicBlocks * visitedBasicBlocks)113 static bool IsRealTailCall(llvm::MachineBasicBlock *basicBlock, VisitedBasicBlocks *visitedBasicBlocks)
114 {
115 auto *instInfo = basicBlock->getParent()->getSubtarget().getInstrInfo();
116 if (llvm::all_of(basicBlock->terminators(),
117 [&instInfo](llvm::MachineInstr &term) { return instInfo->isTailCall(term); })) {
118 return true;
119 }
120 if (visitedBasicBlocks->contains(basicBlock)) {
121 llvm::report_fatal_error("Cycle in CFG in '" + basicBlock->getParent()->getName() +
122 "' prevents tail call check");
123 }
124 VisitedBasicBlockElement visitedBasicBlockElement {visitedBasicBlocks, basicBlock};
125 return llvm::all_of(basicBlock->successors(), [&visitedBasicBlocks](llvm::MachineBasicBlock *succ) {
126 return IsRealTailCall(succ, visitedBasicBlocks);
127 });
128 }
129
runOnMachineFunction(llvm::MachineFunction & machineFunction)130 bool runOnMachineFunction(llvm::MachineFunction &machineFunction) override
131 {
132 llvm::SmallSet<const llvm::Instruction *, 4U> confirmedTailCalls;
133 for (auto &basicBlock : machineFunction) {
134 auto irBasicBlock = basicBlock.getBasicBlock();
135 if (irBasicBlock == nullptr) {
136 continue;
137 }
138 auto callInst = GetTerminatingTailCall(irBasicBlock);
139 if (callInst != nullptr && callInst->hasFnAttr("ark-tail-call")) {
140 VisitedBasicBlocks visitedBasicBlocks;
141 if (IsRealTailCall(&basicBlock, &visitedBasicBlocks)) {
142 confirmedTailCalls.insert(callInst);
143 } else {
144 llvm::report_fatal_error("Cannot find tail call for '" + machineFunction.getName() + "'");
145 }
146 ASSERT(visitedBasicBlocks.empty());
147 }
148 }
149 for (auto &irBasicBlock : machineFunction.getFunction()) {
150 for (auto &irInst : irBasicBlock) {
151 auto *callInst = llvm::dyn_cast<llvm::CallInst>(&irInst);
152 if (callInst != nullptr && callInst->hasFnAttr("ark-tail-call") &&
153 !confirmedTailCalls.contains(callInst)) {
154 llvm::report_fatal_error("Missing tail call in '" + machineFunction.getName() + "'");
155 }
156 }
157 }
158 return false;
159 }
160
161 static inline char ID = 0; // NOLINT(readability-identifier-naming)
162 };
163 } // namespace
164
CreateCheckTailCallsPass()165 llvm::MachineFunctionPass *ark::llvmbackend::CreateCheckTailCallsPass()
166 {
167 return new CheckTailCallsPass();
168 }
169
170 // NOLINTNEXTLINE(fuchsia-statically-constructed-objects)
171 static llvm::RegisterPass<CheckTailCallsPass> g_ctc(CheckTailCallsPass::ARG_NAME, CheckTailCallsPass::PASS_NAME, false,
172 false);
173