1 /*
2 * Copyright 2016-2017, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "GlobalMergePass.h"
18
19 #include "llvm/IR/Constants.h"
20 #include "llvm/IR/DataLayout.h"
21 #include "llvm/IR/GlobalVariable.h"
22 #include "llvm/IR/IRBuilder.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/Module.h"
25 #include "llvm/Pass.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/raw_ostream.h"
28
29 #include "Context.h"
30 #include "RSAllocationUtils.h"
31
32 #include <functional>
33
34 #define DEBUG_TYPE "rs2spirv-global-merge"
35
36 using namespace llvm;
37
38 namespace rs2spirv {
39
40 namespace {
41
42 class GlobalMergePass : public ModulePass {
43 public:
44 static char ID;
GlobalMergePass(bool CPU=false)45 GlobalMergePass(bool CPU = false) : ModulePass(ID), mForCPU(CPU) {}
getPassName() const46 const char *getPassName() const override { return "GlobalMergePass"; }
47
runOnModule(Module & M)48 bool runOnModule(Module &M) override {
49 DEBUG(dbgs() << "RS2SPIRVGlobalMergePass\n");
50
51 SmallVector<GlobalVariable *, 8> Globals;
52 if (!collectGlobals(M, Globals)) {
53 return false; // Module not modified.
54 }
55
56 SmallVector<Type *, 8> Tys;
57 Tys.reserve(Globals.size());
58
59 Context &RS2SPIRVCtxt = Context::getInstance();
60
61 uint32_t index = 0;
62 for (GlobalVariable *GV : Globals) {
63 Tys.push_back(GV->getValueType());
64 const char *name = GV->getName().data();
65 RS2SPIRVCtxt.addExportVarIndex(name, index);
66 index++;
67 }
68
69 LLVMContext &LLVMCtxt = M.getContext();
70
71 StructType *MergedTy = StructType::create(LLVMCtxt, "struct.__GPUBuffer");
72 MergedTy->setBody(Tys, false);
73
74 // Size calculation has to consider data layout
75 const DataLayout &DL = M.getDataLayout();
76 const uint64_t BufferSize = DL.getTypeAllocSize(MergedTy);
77 RS2SPIRVCtxt.setGlobalSize(BufferSize);
78
79 Type *BufferVarTy = mForCPU ? static_cast<Type *>(PointerType::getUnqual(
80 Type::getInt8Ty(M.getContext())))
81 : static_cast<Type *>(MergedTy);
82 GlobalVariable *MergedGV =
83 new GlobalVariable(M, BufferVarTy, false, GlobalValue::ExternalLinkage,
84 nullptr, "__GPUBlock");
85
86 // For CPU, create a constant struct for initial values, which has each of
87 // its fields initialized to the original value of the corresponding global
88 // variable.
89 // During the script initialization, the driver should copy these initial
90 // values to the global buffer.
91 if (mForCPU) {
92 CreateInitFunction(LLVMCtxt, M, MergedGV, MergedTy, BufferSize, Globals);
93 }
94
95 const bool forCPU = mForCPU;
96 IntegerType *const Int32Ty = Type::getInt32Ty(LLVMCtxt);
97 ConstantInt *const Zero = ConstantInt::get(Int32Ty, 0);
98 Value *Idx[] = {Zero, nullptr};
99
100 auto InstMaker = [forCPU, MergedGV, MergedTy,
101 &Idx](Instruction *InsertBefore) {
102 Value *Base = MergedGV;
103 if (forCPU) {
104 LoadInst *Load = new LoadInst(MergedGV, "", InsertBefore);
105 DEBUG(Load->dump());
106 Base = new BitCastInst(Load, PointerType::getUnqual(MergedTy), "",
107 InsertBefore);
108 DEBUG(Base->dump());
109 }
110 GetElementPtrInst *GEP = GetElementPtrInst::CreateInBounds(
111 MergedTy, Base, Idx, "", InsertBefore);
112 DEBUG(GEP->dump());
113 return GEP;
114 };
115
116 for (size_t i = 0, e = Globals.size(); i != e; ++i) {
117 GlobalVariable *G = Globals[i];
118 Idx[1] = ConstantInt::get(Int32Ty, i);
119 ReplaceAllUsesWithNewInstructions(G, std::cref(InstMaker));
120 G->eraseFromParent();
121 }
122
123 // Return true, as the pass modifies module.
124 return true;
125 }
126
127 private:
128 // In the User of Value Old, replaces all references of Old with Value New
ReplaceUse(User * U,Value * Old,Value * New)129 static inline void ReplaceUse(User *U, Value *Old, Value *New) {
130 for (unsigned i = 0, n = U->getNumOperands(); i < n; ++i) {
131 if (U->getOperand(i) == Old) {
132 U->getOperandUse(i) = New;
133 }
134 }
135 }
136
137 // Replaces each use of V with new instructions created by
138 // funcCreateAndInsert and inserted right before that use. In the cases where
139 // the use is not an instruction, but a constant expression, recursively
140 // replaces that constant expression with a newly constructed equivalent
141 // instruction, before replacing V in that new instruction.
ReplaceAllUsesWithNewInstructions(Value * V,std::function<Instruction * (Instruction *)> funcCreateAndInsert)142 static inline void ReplaceAllUsesWithNewInstructions(
143 Value *V,
144 std::function<Instruction *(Instruction *)> funcCreateAndInsert) {
145 SmallVector<User *, 8> Users(V->user_begin(), V->user_end());
146 for (User *U : Users) {
147 if (Instruction *Inst = dyn_cast<Instruction>(U)) {
148 DEBUG(dbgs() << "\nBefore replacement:\n");
149 DEBUG(Inst->dump());
150 DEBUG(dbgs() << "----\n");
151
152 ReplaceUse(U, V, funcCreateAndInsert(Inst));
153
154 DEBUG(Inst->dump());
155 } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U)) {
156 auto InstMaker([CE, V, &funcCreateAndInsert](Instruction *UserOfU) {
157 Instruction *Inst = CE->getAsInstruction();
158 Inst->insertBefore(UserOfU);
159 ReplaceUse(Inst, V, funcCreateAndInsert(Inst));
160
161 DEBUG(Inst->dump());
162 return Inst;
163 });
164 ReplaceAllUsesWithNewInstructions(U, InstMaker);
165 } else {
166 DEBUG(U->dump());
167 llvm_unreachable("Expecting only Instruction or ConstantExpr");
168 }
169 }
170 }
171
172 static inline void
CreateInitFunction(LLVMContext & LLVMCtxt,Module & M,GlobalVariable * MergedGV,StructType * MergedTy,const uint64_t BufferSize,const SmallVectorImpl<GlobalVariable * > & Globals)173 CreateInitFunction(LLVMContext &LLVMCtxt, Module &M, GlobalVariable *MergedGV,
174 StructType *MergedTy, const uint64_t BufferSize,
175 const SmallVectorImpl<GlobalVariable *> &Globals) {
176 SmallVector<Constant *, 8> Initializers;
177 Initializers.reserve(Globals.size());
178 for (size_t i = 0, e = Globals.size(); i != e; ++i) {
179 GlobalVariable *G = Globals[i];
180 Initializers.push_back(G->getInitializer());
181 }
182 ArrayRef<Constant *> ArrInit(Initializers.begin(), Initializers.end());
183 Constant *MergedInitializer = ConstantStruct::get(MergedTy, ArrInit);
184 GlobalVariable *MergedInit =
185 new GlobalVariable(M, MergedTy, true, GlobalValue::InternalLinkage,
186 MergedInitializer, "__GPUBlock0");
187
188 Function *UserInit = M.getFunction("init");
189 // If there is no user-defined init() function, make the new global
190 // initialization function the init().
191 StringRef FName(UserInit ? ".rsov.global_init" : "init");
192 Function *Func;
193 FunctionType *FTy = FunctionType::get(Type::getVoidTy(LLVMCtxt), false);
194 Func = Function::Create(FTy, GlobalValue::ExternalLinkage, FName, &M);
195 BasicBlock *Blk = BasicBlock::Create(LLVMCtxt, "entry", Func);
196 IRBuilder<> LLVMIRBuilder(Blk);
197 LoadInst *Load = LLVMIRBuilder.CreateLoad(MergedGV);
198 LLVMIRBuilder.CreateMemCpy(Load, MergedInit, BufferSize, 0);
199 LLVMIRBuilder.CreateRetVoid();
200
201 // If there is a user-defined init() function, add a call to the global
202 // initialization function in the beginning of that function.
203 if (UserInit) {
204 BasicBlock &EntryBlk = UserInit->getEntryBlock();
205 CallInst::Create(Func, {}, "", &EntryBlk.front());
206 }
207 }
208
collectGlobals(Module & M,SmallVectorImpl<GlobalVariable * > & Globals)209 bool collectGlobals(Module &M, SmallVectorImpl<GlobalVariable *> &Globals) {
210 for (GlobalVariable &GV : M.globals()) {
211 assert(!GV.hasComdat() && "global variable has a comdat section");
212 assert(!GV.hasSection() && "global variable has a non-default section");
213 assert(!GV.isDeclaration() && "global variable is only a declaration");
214 assert(!GV.isThreadLocal() && "global variable is thread-local");
215 assert(GV.getType()->getAddressSpace() == 0 &&
216 "global variable has non-default address space");
217
218 // TODO: Constants accessed by kernels should be handled differently
219 if (GV.isConstant()) {
220 continue;
221 }
222
223 // Global Allocations are handled differently in separate passes
224 if (isRSAllocation(GV)) {
225 continue;
226 }
227
228 Globals.push_back(&GV);
229 }
230
231 return !Globals.empty();
232 }
233
234 bool mForCPU;
235 };
236
237 } // namespace
238
// Storage for the static pass identifier that the GlobalMergePass constructor
// hands to the ModulePass base class.
char GlobalMergePass::ID = 0;
240
createGlobalMergePass(bool CPU)241 ModulePass *createGlobalMergePass(bool CPU) { return new GlobalMergePass(CPU); }
242
243 } // namespace rs2spirv
244