• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "check_tail_calls.h"
17 #include <llvm/CodeGen/TargetInstrInfo.h>
18 #include <llvm/CodeGen/MachineFunctionPass.h>
19 #include <llvm/IR/Instructions.h>
20 
21 #include "transforms/transform_utils.h"
22 
23 #define DEBUG_TYPE "check-tail-calls"
24 
25 namespace {
26 
27 /**
28  * Used to detect cycles while walking on CFG
29  *
30  * @tparam T type of set elements
31  */
32 template <typename T>
33 class ScopedSetElement final {
34 public:
ScopedSetElement(llvm::SmallPtrSetImpl<T> * set,T value)35     explicit ScopedSetElement(llvm::SmallPtrSetImpl<T> *set, T value) : set_(set), value_(value)
36     {
37         ASSERT(set != nullptr);
38         ASSERT(value != nullptr);
39         set->insert(value);
40     }
41 
~ScopedSetElement()42     ~ScopedSetElement()
43     {
44         set_->erase(value_);
45     }
46 
47     ScopedSetElement(const ScopedSetElement &) = delete;
48     ScopedSetElement &operator=(const ScopedSetElement &) = delete;
49     ScopedSetElement(ScopedSetElement &&) = delete;
50     ScopedSetElement &operator=(ScopedSetElement &&) = delete;
51 
52 private:
53     llvm::SmallPtrSetImpl<T> *set_;
54     T value_;
55 };
56 
57 using VisitedBasicBlocks = llvm::SmallPtrSet<llvm::MachineBasicBlock *, 4U>;
58 using VisitedBasicBlockElement = ScopedSetElement<VisitedBasicBlocks::value_type>;
59 
60 // This pass checks: 1) tail calls in interpreter handlers 2) calls to SlowPathes made in Irtoc FastPathes (they
61 // should be tail calls to avoid miscompilation). `CheckTailCallsPass` checks that llvm was able to lower
62 // them into machine code correspondingly, not falling back to regular calls.
63 class CheckTailCallsPass : public llvm::MachineFunctionPass {
64 public:
65     static constexpr llvm::StringRef PASS_NAME = "Check ARK Tail Calls";
66     static constexpr llvm::StringRef ARG_NAME = "check-ark-tail-calls";
67 
CheckTailCallsPass()68     explicit CheckTailCallsPass() : MachineFunctionPass(ID) {}
69 
getPassName() const70     llvm::StringRef getPassName() const override
71     {
72         return PASS_NAME;
73     }
74 
75     // Almost exact copy of `getTerminatingMustTailCall` from llvm
GetTerminatingTailCall(const llvm::BasicBlock * bb) const76     const llvm::CallInst *GetTerminatingTailCall(const llvm::BasicBlock *bb) const
77     {
78         if (bb->empty()) {
79             return nullptr;
80         }
81 
82         auto returnInst = llvm::dyn_cast<llvm::ReturnInst>(&bb->back());
83         if (returnInst == nullptr || returnInst == &bb->front()) {
84             return nullptr;
85         }
86 
87         const llvm::Instruction *prev = returnInst->getPrevNode();
88         if (prev == nullptr) {
89             return nullptr;
90         }
91 
92         if (llvm::Value *returnValue = returnInst->getReturnValue()) {
93             if (returnValue != prev) {
94                 return nullptr;
95             }
96             // Look through the optional bitcast
97             if (auto *bitcastInst = llvm::dyn_cast<llvm::BitCastInst>(prev)) {
98                 returnValue = bitcastInst->getOperand(0);
99                 prev = bitcastInst->getPrevNode();
100                 if (prev == nullptr || returnValue != prev) {
101                     return nullptr;
102                 }
103             }
104         }
105         if (auto *callInst = llvm::dyn_cast<llvm::CallInst>(prev)) {
106             if (callInst->isTailCall()) {
107                 return callInst;
108             }
109         }
110         return nullptr;
111     }
112 
IsRealTailCall(llvm::MachineBasicBlock * basicBlock,VisitedBasicBlocks * visitedBasicBlocks)113     static bool IsRealTailCall(llvm::MachineBasicBlock *basicBlock, VisitedBasicBlocks *visitedBasicBlocks)
114     {
115         auto *instInfo = basicBlock->getParent()->getSubtarget().getInstrInfo();
116         if (llvm::all_of(basicBlock->terminators(),
117                          [&instInfo](llvm::MachineInstr &term) { return instInfo->isTailCall(term); })) {
118             return true;
119         }
120         if (visitedBasicBlocks->contains(basicBlock)) {
121             llvm::report_fatal_error("Cycle in CFG in '" + basicBlock->getParent()->getName() +
122                                      "' prevents tail call check");
123         }
124         VisitedBasicBlockElement visitedBasicBlockElement {visitedBasicBlocks, basicBlock};
125         return llvm::all_of(basicBlock->successors(), [&visitedBasicBlocks](llvm::MachineBasicBlock *succ) {
126             return IsRealTailCall(succ, visitedBasicBlocks);
127         });
128     }
129 
runOnMachineFunction(llvm::MachineFunction & machineFunction)130     bool runOnMachineFunction(llvm::MachineFunction &machineFunction) override
131     {
132         llvm::SmallSet<const llvm::Instruction *, 4U> confirmedTailCalls;
133         for (auto &basicBlock : machineFunction) {
134             auto irBasicBlock = basicBlock.getBasicBlock();
135             if (irBasicBlock == nullptr) {
136                 continue;
137             }
138             auto callInst = GetTerminatingTailCall(irBasicBlock);
139             if (callInst != nullptr && callInst->hasFnAttr("ark-tail-call")) {
140                 VisitedBasicBlocks visitedBasicBlocks;
141                 if (IsRealTailCall(&basicBlock, &visitedBasicBlocks)) {
142                     confirmedTailCalls.insert(callInst);
143                 } else {
144                     llvm::report_fatal_error("Cannot find tail call for  '" + machineFunction.getName() + "'");
145                 }
146                 ASSERT(visitedBasicBlocks.empty());
147             }
148         }
149         for (auto &irBasicBlock : machineFunction.getFunction()) {
150             for (auto &irInst : irBasicBlock) {
151                 auto *callInst = llvm::dyn_cast<llvm::CallInst>(&irInst);
152                 if (callInst != nullptr && callInst->hasFnAttr("ark-tail-call") &&
153                     !confirmedTailCalls.contains(callInst)) {
154                     llvm::report_fatal_error("Missing tail call in '" + machineFunction.getName() + "'");
155                 }
156             }
157         }
158         return false;
159     }
160 
161     static inline char ID = 0;  // NOLINT(readability-identifier-naming)
162 };
163 }  // namespace
164 
CreateCheckTailCallsPass()165 llvm::MachineFunctionPass *ark::llvmbackend::CreateCheckTailCallsPass()
166 {
167     return new CheckTailCallsPass();
168 }
169 
170 // NOLINTNEXTLINE(fuchsia-statically-constructed-objects)
171 static llvm::RegisterPass<CheckTailCallsPass> g_ctc(CheckTailCallsPass::ARG_NAME, CheckTailCallsPass::PASS_NAME, false,
172                                                     false);
173