1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "transforms/passes/mem_barriers.h"
17 #include "llvm_ark_interface.h"
18 #include "llvm_compiler_options.h"
19
20 #include <llvm/IR/IRBuilder.h>
21 #include <llvm/IR/IntrinsicsAArch64.h>
22 #include <llvm/Transforms/Utils/BasicBlockUtils.h>
23
24 using llvm::BasicBlock;
25 using llvm::CallInst;
26 using llvm::Function;
27 using llvm::FunctionAnalysisManager;
28 using llvm::Instruction;
29
30 namespace ark::llvmbackend::passes {
31
Create(LLVMArkInterface * arkInterface,const ark::llvmbackend::LLVMCompilerOptions * options)32 MemBarriers MemBarriers::Create([[maybe_unused]] LLVMArkInterface *arkInterface,
33 const ark::llvmbackend::LLVMCompilerOptions *options)
34 {
35 return MemBarriers(arkInterface, options->optimize);
36 }
37
MemBarriers(LLVMArkInterface * arkInterface,bool optimize)38 MemBarriers::MemBarriers(LLVMArkInterface *arkInterface, bool optimize)
39 : arkInterface_ {arkInterface}, optimize_ {optimize}
40 {
41 }
42
run(Function & function,FunctionAnalysisManager &)43 llvm::PreservedAnalyses MemBarriers::run(Function &function, FunctionAnalysisManager & /*analysisManager*/)
44 {
45 bool changed = false;
46 for (BasicBlock &block : function) {
47 llvm::SmallVector<llvm::Instruction *> needsBarrier;
48 llvm::SmallVector<llvm::Instruction *> mergeSet;
49 for (auto &inst : block) {
50 auto callInst = llvm::dyn_cast<llvm::CallInst>(&inst);
51 if (callInst != nullptr && callInst->hasFnAttr("needs-mem-barrier")) {
52 mergeSet.push_back(callInst);
53 continue;
54 }
55 if (optimize_ && inst.mayWriteToMemory() && GrabsGuarded(&inst, mergeSet)) {
56 MergeBarriers(mergeSet, needsBarrier);
57 }
58 }
59
60 MergeBarriers(mergeSet, needsBarrier);
61 for (auto inst : needsBarrier) {
62 auto builder = llvm::IRBuilder<>(inst->getNextNode());
63 builder.CreateFence(llvm::AtomicOrdering::Release);
64 changed = true;
65 }
66 }
67 changed |= RelaxBarriers(function);
68
69 return changed ? llvm::PreservedAnalyses::none() : llvm::PreservedAnalyses::all();
70 }
71
GrabsGuarded(llvm::Instruction * inst,llvm::SmallVector<llvm::Instruction * > & mergeSet)72 bool MemBarriers::GrabsGuarded(llvm::Instruction *inst, llvm::SmallVector<llvm::Instruction *> &mergeSet)
73 {
74 llvm::SmallVector<llvm::Value *> inputs;
75 if (auto storeInst = llvm::dyn_cast<llvm::StoreInst>(inst)) {
76 inputs.push_back(storeInst->getValueOperand());
77 } else if (auto callInst = llvm::dyn_cast<llvm::CallInst>(inst)) {
78 inputs.append(callInst->arg_begin(), callInst->arg_end());
79 } else {
80 inputs.append(inst->value_op_begin(), inst->value_op_end());
81 }
82 for (auto input : inputs) {
83 auto inputInst = llvm::dyn_cast<llvm::Instruction>(input);
84 if (inputInst != nullptr && std::find(mergeSet.begin(), mergeSet.end(), inputInst) != mergeSet.end()) {
85 return true;
86 }
87 }
88 return false;
89 }
90
MergeBarriers(llvm::SmallVector<llvm::Instruction * > & mergeSet,llvm::SmallVector<llvm::Instruction * > & needsBarrier)91 void MemBarriers::MergeBarriers(llvm::SmallVector<llvm::Instruction *> &mergeSet,
92 llvm::SmallVector<llvm::Instruction *> &needsBarrier)
93 {
94 if (mergeSet.empty()) {
95 return;
96 }
97 if (optimize_) {
98 needsBarrier.push_back(mergeSet.back());
99 } else {
100 needsBarrier.append(mergeSet);
101 }
102 mergeSet.clear();
103 }
104
RelaxBarriers(llvm::Function & function)105 bool MemBarriers::RelaxBarriers(llvm::Function &function)
106 {
107 if (!arkInterface_->IsArm64() || !optimize_) {
108 return false;
109 }
110
111 auto opcode = llvm::Intrinsic::AARCH64Intrinsics::aarch64_dmb;
112 auto dmb = llvm::Intrinsic::getDeclaration(function.getParent(), opcode, {});
113 static constexpr uint32_t ISHST = 10U;
114 auto ishst = llvm::ConstantInt::get(llvm::Type::getInt32Ty(function.getContext()), ISHST);
115
116 bool changed = false;
117 for (auto &basicBlock : function) {
118 llvm::SmallVector<llvm::FenceInst *> fences;
119 for (auto &instruction : basicBlock) {
120 auto fence = llvm::dyn_cast<llvm::FenceInst>(&instruction);
121 if (fence != nullptr && fence->getOrdering() == llvm::AtomicOrdering::Release) {
122 fences.push_back(fence);
123 }
124 }
125 for (auto fence : fences) {
126 auto upgraded = llvm::CallInst::Create(dmb, {ishst}, llvm::None);
127 llvm::ReplaceInstWithInst(fence, upgraded);
128 }
129 changed |= !fences.empty();
130 }
131 return changed;
132 }
133
134 } // namespace ark::llvmbackend::passes
135