1 //===- SPIRVLowerOCLBlocks.cpp - Lower OpenCL blocks ------------*- C++ -*-===//
2 //
3 // The LLVM/SPIR-V Translator
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 // Copyright (c) 2014 Advanced Micro Devices, Inc. All rights reserved.
9 //
10 // Permission is hereby granted, free of charge, to any person obtaining a
11 // copy of this software and associated documentation files (the "Software"),
12 // to deal with the Software without restriction, including without limitation
13 // the rights to use, copy, modify, merge, publish, distribute, sublicense,
14 // and/or sell copies of the Software, and to permit persons to whom the
15 // Software is furnished to do so, subject to the following conditions:
16 //
17 // Redistributions of source code must retain the above copyright notice,
18 // this list of conditions and the following disclaimers.
19 // Redistributions in binary form must reproduce the above copyright notice,
20 // this list of conditions and the following disclaimers in the documentation
21 // and/or other materials provided with the distribution.
22 // Neither the names of Advanced Micro Devices, Inc., nor the names of its
23 // contributors may be used to endorse or promote products derived from this
24 // Software without specific prior written permission.
25 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
26 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
27 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
28 // CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
29 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
30 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH
31 // THE SOFTWARE.
32 //
33 //===----------------------------------------------------------------------===//
34 /// \file
35 ///
36 /// This file implements lowering of OpenCL blocks to functions.
37 ///
38 //===----------------------------------------------------------------------===//
39
40 #ifndef OCLLOWERBLOCKS_H_
41 #define OCLLOWERBLOCKS_H_
42
43 #include "SPIRVInternal.h"
44 #include "OCLUtil.h"
45
46 #include "llvm/ADT/DenseMap.h"
47 #include "llvm/ADT/SetVector.h"
48 #include "llvm/ADT/StringSwitch.h"
49 #include "llvm/ADT/Triple.h"
50 #include "llvm/Analysis/AliasAnalysis.h"
51 #include "llvm/Analysis/AssumptionCache.h"
52 #include "llvm/Analysis/CallGraph.h"
53 #include "llvm/IR/Verifier.h"
54 #include "llvm/Bitcode/ReaderWriter.h"
55 #include "llvm/IR/Constants.h"
56 #include "llvm/IR/DerivedTypes.h"
57 #include "llvm/IR/Function.h"
58 #include "llvm/IR/InstrTypes.h"
59 #include "llvm/IR/Instructions.h"
60 #include "llvm/IR/Module.h"
61 #include "llvm/IR/Operator.h"
62 #include "llvm/Pass.h"
63 #include "llvm/PassSupport.h"
64 #include "llvm/Support/Casting.h"
65 #include "llvm/Support/Debug.h"
66 #include "llvm/Support/raw_ostream.h"
67 #include "llvm/Support/ToolOutputFile.h"
68 #include "llvm/Transforms/Utils/Cloning.h"
69
70 #include <iostream>
71 #include <list>
72 #include <memory>
73 #include <set>
74 #include <sstream>
75 #include <vector>
76
77 #define DEBUG_TYPE "spvblocks"
78
79 using namespace llvm;
80 using namespace SPIRV;
81 using namespace OCLUtil;
82
83 namespace SPIRV{
84
85 /// Lower SPIR2 blocks to function calls.
86 ///
87 /// SPIR2 representation of blocks:
88 ///
89 /// block = spir_block_bind(bitcast(block_func), context_len, context_align,
90 /// context)
91 /// block_func_ptr = bitcast(spir_get_block_invoke(block))
92 /// context_ptr = spir_get_block_context(block)
93 /// ret = block_func_ptr(context_ptr, args)
94 ///
95 /// Propagates block_func to each spir_get_block_invoke through def-use chain of
96 /// spir_block_bind, so that
97 /// ret = block_func(context, args)
98 class SPIRVLowerOCLBlocks: public ModulePass {
99 public:
SPIRVLowerOCLBlocks()100 SPIRVLowerOCLBlocks():ModulePass(ID), M(nullptr){
101 initializeSPIRVLowerOCLBlocksPass(*PassRegistry::getPassRegistry());
102 }
103
getAnalysisUsage(AnalysisUsage & AU) const104 virtual void getAnalysisUsage(AnalysisUsage &AU) const {
105 AU.addRequired<CallGraphWrapperPass>();
106 //AU.addRequired<AliasAnalysis>();
107 AU.addRequired<AssumptionCacheTracker>();
108 }
109
runOnModule(Module & Module)110 virtual bool runOnModule(Module &Module) {
111 M = &Module;
112 lowerBlockBind();
113 lowerGetBlockInvoke();
114 lowerGetBlockContext();
115 erase(M->getFunction(SPIR_INTRINSIC_GET_BLOCK_INVOKE));
116 erase(M->getFunction(SPIR_INTRINSIC_GET_BLOCK_CONTEXT));
117 erase(M->getFunction(SPIR_INTRINSIC_BLOCK_BIND));
118 DEBUG(dbgs() << "------- After OCLLowerBlocks ------------\n" <<
119 *M << '\n');
120 return true;
121 }
122
123 static char ID;
124 private:
125 const static int MaxIter = 1000;
126 Module *M;
127
128 bool
lowerBlockBind()129 lowerBlockBind() {
130 auto F = M->getFunction(SPIR_INTRINSIC_BLOCK_BIND);
131 if (!F)
132 return false;
133 int Iter = MaxIter;
134 while(lowerBlockBind(F) && Iter > 0){
135 Iter--;
136 DEBUG(dbgs() << "-------------- after iteration " << MaxIter - Iter <<
137 " --------------\n" << *M << '\n');
138 }
139 assert(Iter > 0 && "Too many iterations");
140 return true;
141 }
142
143 bool
eraseUselessFunctions()144 eraseUselessFunctions() {
145 bool changed = false;
146 for (auto I = M->begin(), E = M->end(); I != E;) {
147 Function *F = static_cast<Function*>(I++);
148 if (!GlobalValue::isInternalLinkage(F->getLinkage()) &&
149 !F->isDeclaration())
150 continue;
151
152 dumpUsers(F, "[eraseUselessFunctions] ");
153 for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
154 auto U = *UI++;
155 if (auto CE = dyn_cast<ConstantExpr>(U)){
156 if (CE->use_empty()) {
157 CE->dropAllReferences();
158 changed = true;
159 }
160 }
161 }
162 if (F->use_empty()) {
163 erase(F);
164 changed = true;
165 }
166 }
167 return changed;
168 }
169
170 void
lowerGetBlockInvoke()171 lowerGetBlockInvoke() {
172 if (auto F = M->getFunction(SPIR_INTRINSIC_GET_BLOCK_INVOKE)) {
173 for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
174 auto CI = dyn_cast<CallInst>(*UI++);
175 assert(CI && "Invalid usage of spir_get_block_invoke");
176 lowerGetBlockInvoke(CI);
177 }
178 }
179 }
180
181 void
lowerGetBlockContext()182 lowerGetBlockContext() {
183 if (auto F = M->getFunction(SPIR_INTRINSIC_GET_BLOCK_CONTEXT)) {
184 for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
185 auto CI = dyn_cast<CallInst>(*UI++);
186 assert(CI && "Invalid usage of spir_get_block_context");
187 lowerGetBlockContext(CI);
188 }
189 }
190 }
191 /// Lower calls of spir_block_bind.
192 /// Return true if the Module is changed.
193 bool
lowerBlockBind(Function * BlockBindFunc)194 lowerBlockBind(Function *BlockBindFunc) {
195 bool changed = false;
196 for (auto I = BlockBindFunc->user_begin(), E = BlockBindFunc->user_end();
197 I != E;) {
198 DEBUG(dbgs() << "[lowerBlockBind] " << **I << '\n');
199 // Handle spir_block_bind(bitcast(block_func), context_len,
200 // context_align, context)
201 auto CallBlkBind = cast<CallInst>(*I++);
202 Function *InvF = nullptr;
203 Value *Ctx = nullptr;
204 Value *CtxLen = nullptr;
205 Value *CtxAlign = nullptr;
206 getBlockInvokeFuncAndContext(CallBlkBind, &InvF, &Ctx, &CtxLen,
207 &CtxAlign);
208 for (auto II = CallBlkBind->user_begin(), EE = CallBlkBind->user_end();
209 II != EE;) {
210 auto BlkUser = *II++;
211 SPIRVDBG(dbgs() << " Block user: " << *BlkUser << '\n');
212 if (auto Ret = dyn_cast<ReturnInst>(BlkUser)) {
213 bool Inlined = false;
214 changed |= lowerReturnBlock(Ret, CallBlkBind, Inlined);
215 if (Inlined)
216 return true;
217 } else if (auto CI = dyn_cast<CallInst>(BlkUser)){
218 auto CallBindF = CI->getCalledFunction();
219 auto Name = CallBindF->getName();
220 std::string DemangledName;
221 if (Name == SPIR_INTRINSIC_GET_BLOCK_INVOKE) {
222 assert(CI->getArgOperand(0) == CallBlkBind);
223 changed |= lowerGetBlockInvoke(CI, cast<Function>(InvF));
224 } else if (Name == SPIR_INTRINSIC_GET_BLOCK_CONTEXT) {
225 assert(CI->getArgOperand(0) == CallBlkBind);
226 // Handle context_ptr = spir_get_block_context(block)
227 lowerGetBlockContext(CI, Ctx);
228 changed = true;
229 } else if (oclIsBuiltin(Name, &DemangledName)) {
230 lowerBlockBuiltin(CI, InvF, Ctx, CtxLen, CtxAlign, DemangledName);
231 changed = true;
232 } else
233 llvm_unreachable("Invalid block user");
234 }
235 }
236 erase(CallBlkBind);
237 }
238 changed |= eraseUselessFunctions();
239 return changed;
240 }
241
242 void
lowerGetBlockContext(CallInst * CallGetBlkCtx,Value * Ctx=nullptr)243 lowerGetBlockContext(CallInst *CallGetBlkCtx, Value *Ctx = nullptr) {
244 if (!Ctx)
245 getBlockInvokeFuncAndContext(CallGetBlkCtx->getArgOperand(0), nullptr,
246 &Ctx);
247 CallGetBlkCtx->replaceAllUsesWith(Ctx);
248 DEBUG(dbgs() << " [lowerGetBlockContext] " << *CallGetBlkCtx << " => " <<
249 *Ctx << "\n\n");
250 erase(CallGetBlkCtx);
251 }
252
253 bool
lowerGetBlockInvoke(CallInst * CallGetBlkInvoke,Function * InvokeF=nullptr)254 lowerGetBlockInvoke(CallInst *CallGetBlkInvoke,
255 Function *InvokeF = nullptr) {
256 bool changed = false;
257 for (auto UI = CallGetBlkInvoke->user_begin(),
258 UE = CallGetBlkInvoke->user_end();
259 UI != UE;) {
260 // Handle block_func_ptr = bitcast(spir_get_block_invoke(block))
261 auto CallInv = cast<Instruction>(*UI++);
262 auto Cast = dyn_cast<BitCastInst>(CallInv);
263 if (Cast)
264 CallInv = dyn_cast<Instruction>(*CallInv->user_begin());
265 DEBUG(dbgs() << "[lowerGetBlockInvoke] " << *CallInv);
266 // Handle ret = block_func_ptr(context_ptr, args)
267 auto CI = cast<CallInst>(CallInv);
268 auto F = CI->getCalledValue();
269 if (InvokeF == nullptr) {
270 getBlockInvokeFuncAndContext(CallGetBlkInvoke->getArgOperand(0),
271 &InvokeF, nullptr);
272 assert(InvokeF);
273 }
274 assert(F->getType() == InvokeF->getType());
275 CI->replaceUsesOfWith(F, InvokeF);
276 DEBUG(dbgs() << " => " << *CI << "\n\n");
277 erase(Cast);
278 changed = true;
279 }
280 erase(CallGetBlkInvoke);
281 return changed;
282 }
283
284 void
lowerBlockBuiltin(CallInst * CI,Function * InvF,Value * Ctx,Value * CtxLen,Value * CtxAlign,const std::string & DemangledName)285 lowerBlockBuiltin(CallInst *CI, Function *InvF, Value *Ctx, Value *CtxLen,
286 Value *CtxAlign, const std::string& DemangledName) {
287 mutateCallInstSPIRV (M, CI, [=](CallInst *CI, std::vector<Value *> &Args) {
288 size_t I = 0;
289 size_t E = Args.size();
290 for (; I != E; ++I) {
291 if (isPointerToOpaqueStructType(Args[I]->getType(),
292 SPIR_TYPE_NAME_BLOCK_T)) {
293 break;
294 }
295 }
296 assert (I < E);
297 Args[I] = castToVoidFuncPtr(InvF);
298 if (I + 1 == E) {
299 Args.push_back(Ctx);
300 Args.push_back(CtxLen);
301 Args.push_back(CtxAlign);
302 } else {
303 Args.insert(Args.begin() + I + 1, CtxAlign);
304 Args.insert(Args.begin() + I + 1, CtxLen);
305 Args.insert(Args.begin() + I + 1, Ctx);
306 }
307 if (DemangledName == kOCLBuiltinName::EnqueueKernel) {
308 // Insert event arguments if there are not.
309 if (!isa<IntegerType>(Args[3]->getType())) {
310 Args.insert(Args.begin() + 3, getInt32(M, 0));
311 Args.insert(Args.begin() + 4, getOCLNullClkEventPtr());
312 }
313 if (!isOCLClkEventPtrType(Args[5]->getType()))
314 Args.insert(Args.begin() + 5, getOCLNullClkEventPtr());
315 }
316 return getSPIRVFuncName(OCLSPIRVBuiltinMap::map(DemangledName));
317 });
318 }
319 /// Transform return of a block.
320 /// The function returning a block is inlined since the context cannot be
321 /// passed to another function.
322 /// Returns true of module is changed.
323 bool
lowerReturnBlock(ReturnInst * Ret,Value * CallBlkBind,bool & Inlined)324 lowerReturnBlock(ReturnInst *Ret, Value *CallBlkBind, bool &Inlined) {
325 auto F = Ret->getParent()->getParent();
326 auto changed = false;
327 for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
328 auto U = *UI++;
329 dumpUsers(U);
330 auto Inst = dyn_cast<Instruction>(U);
331 if (Inst && Inst->use_empty()) {
332 erase(Inst);
333 changed = true;
334 continue;
335 }
336 auto CI = dyn_cast<CallInst>(U);
337 if(!CI || CI->getCalledFunction() != F)
338 continue;
339
340 DEBUG(dbgs() << "[lowerReturnBlock] inline " << F->getName() << '\n');
341 auto CG = &getAnalysis<CallGraphWrapperPass>().getCallGraph();
342 auto ACT = &getAnalysis<AssumptionCacheTracker>();
343 //auto AA = &getAnalysis<AliasAnalysis>();
344 //InlineFunctionInfo IFI(CG, M->getDataLayout(), AA, ACT);
345 InlineFunctionInfo IFI(CG, ACT);
346 InlineFunction(CI, IFI);
347 Inlined = true;
348 }
349 return changed || Inlined;
350 }
351
352 void
getBlockInvokeFuncAndContext(Value * Blk,Function ** PInvF,Value ** PCtx,Value ** PCtxLen=nullptr,Value ** PCtxAlign=nullptr)353 getBlockInvokeFuncAndContext(Value *Blk, Function **PInvF, Value **PCtx,
354 Value **PCtxLen = nullptr, Value **PCtxAlign = nullptr){
355 Function *InvF = nullptr;
356 Value *Ctx = nullptr;
357 Value *CtxLen = nullptr;
358 Value *CtxAlign = nullptr;
359 if (auto CallBlkBind = dyn_cast<CallInst>(Blk)) {
360 assert(CallBlkBind->getCalledFunction()->getName() ==
361 SPIR_INTRINSIC_BLOCK_BIND && "Invalid block");
362 InvF = dyn_cast<Function>(
363 CallBlkBind->getArgOperand(0)->stripPointerCasts());
364 CtxLen = CallBlkBind->getArgOperand(1);
365 CtxAlign = CallBlkBind->getArgOperand(2);
366 Ctx = CallBlkBind->getArgOperand(3);
367 } else if (auto F = dyn_cast<Function>(Blk->stripPointerCasts())) {
368 InvF = F;
369 Ctx = Constant::getNullValue(IntegerType::getInt8PtrTy(M->getContext()));
370 } else if (auto Load = dyn_cast<LoadInst>(Blk)) {
371 auto Op = Load->getPointerOperand();
372 if (auto GV = dyn_cast<GlobalVariable>(Op)) {
373 if (GV->isConstant()) {
374 InvF = cast<Function>(GV->getInitializer()->stripPointerCasts());
375 Ctx = Constant::getNullValue(IntegerType::getInt8PtrTy(M->getContext()));
376 } else {
377 llvm_unreachable("load non-constant block?");
378 }
379 } else {
380 llvm_unreachable("Loading block from non global?");
381 }
382 } else {
383 llvm_unreachable("Invalid block");
384 }
385 DEBUG(dbgs() << " Block invocation func: " << InvF->getName() << '\n' <<
386 " Block context: " << *Ctx << '\n');
387 assert(InvF && Ctx && "Invalid block");
388 if (PInvF)
389 *PInvF = InvF;
390 if (PCtx)
391 *PCtx = Ctx;
392 if (PCtxLen)
393 *PCtxLen = CtxLen;
394 if (PCtxAlign)
395 *PCtxAlign = CtxAlign;
396 }
397 void
erase(Instruction * I)398 erase(Instruction *I) {
399 if (!I)
400 return;
401 if (I->use_empty()) {
402 I->dropAllReferences();
403 I->eraseFromParent();
404 }
405 else
406 dumpUsers(I);
407 }
408 void
erase(ConstantExpr * I)409 erase(ConstantExpr *I) {
410 if (!I)
411 return;
412 if (I->use_empty()) {
413 I->dropAllReferences();
414 I->destroyConstant();
415 } else
416 dumpUsers(I);
417 }
418 void
erase(Function * F)419 erase(Function *F) {
420 if (!F)
421 return;
422 if (!F->use_empty()) {
423 dumpUsers(F);
424 return;
425 }
426 F->dropAllReferences();
427 auto &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
428 CG.removeFunctionFromModule(new CallGraphNode(F));
429 }
430
getOCLClkEventType()431 llvm::PointerType* getOCLClkEventType() {
432 return getOrCreateOpaquePtrType(M, SPIR_TYPE_NAME_CLK_EVENT_T,
433 SPIRAS_Global);
434 }
435
getOCLClkEventPtrType()436 llvm::PointerType* getOCLClkEventPtrType() {
437 return PointerType::get(getOCLClkEventType(), SPIRAS_Generic);
438 }
439
isOCLClkEventPtrType(Type * T)440 bool isOCLClkEventPtrType(Type *T) {
441 if (auto PT = dyn_cast<PointerType>(T))
442 return isPointerToOpaqueStructType(
443 PT->getElementType(), SPIR_TYPE_NAME_CLK_EVENT_T);
444 return false;
445 }
446
getOCLNullClkEventPtr()447 llvm::Constant* getOCLNullClkEventPtr() {
448 return Constant::getNullValue(getOCLClkEventPtrType());
449 }
450
dumpGetBlockInvokeUsers(StringRef Prompt)451 void dumpGetBlockInvokeUsers(StringRef Prompt) {
452 DEBUG(dbgs() << Prompt);
453 dumpUsers(M->getFunction(SPIR_INTRINSIC_GET_BLOCK_INVOKE));
454 }
455 };
456
457 char SPIRVLowerOCLBlocks::ID = 0;
458 }
459
460 INITIALIZE_PASS_BEGIN(SPIRVLowerOCLBlocks, "spvblocks",
461 "SPIR-V lower OCL blocks", false, false)
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)462 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
463 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
464 //INITIALIZE_AG_DEPENDENCY(AliasAnalysis)
465 INITIALIZE_PASS_END(SPIRVLowerOCLBlocks, "spvblocks",
466 "SPIR-V lower OCL blocks", false, false)
467
468 ModulePass *llvm::createSPIRVLowerOCLBlocks() {
469 return new SPIRVLowerOCLBlocks();
470 }
471
472 #endif /* OCLLOWERBLOCKS_H_ */
473