• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===-- AMDGPUOpenCLImageTypeLoweringPass.cpp -----------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 /// \file
11 /// This pass resolves calls to OpenCL image attribute, image resource ID and
12 /// sampler resource ID getter functions.
13 ///
14 /// Image attributes (size and format) are expected to be passed to the kernel
15 /// as kernel arguments immediately following the image argument itself,
16 /// therefore this pass adds image size and format arguments to the kernel
17 /// functions in the module. The kernel functions with image arguments are
18 /// re-created using the new signature. The new arguments are added to the
/// kernel metadata with kernel_arg_type and kernel_arg_base_type set to
/// "__llvm_image_size" or "__llvm_image_format".
20 /// Note: this pass may invalidate pointers to functions.
21 ///
22 /// Resource IDs of read-only images, write-only images and samplers are
23 /// defined to be their index among the kernel arguments of the same
24 /// type and access qualifier.
25 //===----------------------------------------------------------------------===//
26 
27 #include "AMDGPU.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/Analysis/Passes.h"
31 #include "llvm/IR/Constants.h"
32 #include "llvm/IR/Function.h"
33 #include "llvm/IR/Instructions.h"
34 #include "llvm/IR/Module.h"
35 #include "llvm/Transforms/Utils/Cloning.h"
36 
37 using namespace llvm;
38 
39 namespace {
40 
41 StringRef GetImageSizeFunc =         "llvm.OpenCL.image.get.size";
42 StringRef GetImageFormatFunc =       "llvm.OpenCL.image.get.format";
43 StringRef GetImageResourceIDFunc =   "llvm.OpenCL.image.get.resource.id";
44 StringRef GetSamplerResourceIDFunc = "llvm.OpenCL.sampler.get.resource.id";
45 
46 StringRef ImageSizeArgMDType =   "__llvm_image_size";
47 StringRef ImageFormatArgMDType = "__llvm_image_format";
48 
49 StringRef KernelsMDNodeName = "opencl.kernels";
50 StringRef KernelArgMDNodeNames[] = {
51   "kernel_arg_addr_space",
52   "kernel_arg_access_qual",
53   "kernel_arg_type",
54   "kernel_arg_base_type",
55   "kernel_arg_type_qual"};
56 const unsigned NumKernelArgMDNodes = 5;
57 
58 typedef SmallVector<Metadata *, 8> MDVector;
59 struct KernelArgMD {
60   MDVector ArgVector[NumKernelArgMDNodes];
61 };
62 
63 } // end anonymous namespace
64 
65 static inline bool
IsImageType(StringRef TypeString)66 IsImageType(StringRef TypeString) {
67   return TypeString == "image2d_t" || TypeString == "image3d_t";
68 }
69 
70 static inline bool
IsSamplerType(StringRef TypeString)71 IsSamplerType(StringRef TypeString) {
72   return TypeString == "sampler_t";
73 }
74 
75 static Function *
GetFunctionFromMDNode(MDNode * Node)76 GetFunctionFromMDNode(MDNode *Node) {
77   if (!Node)
78     return nullptr;
79 
80   size_t NumOps = Node->getNumOperands();
81   if (NumOps != NumKernelArgMDNodes + 1)
82     return nullptr;
83 
84   auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
85   if (!F)
86     return nullptr;
87 
88   // Sanity checks.
89   size_t ExpectNumArgNodeOps = F->arg_size() + 1;
90   for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
91     MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
92     if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
93       return nullptr;
94     if (!ArgNode->getOperand(0))
95       return nullptr;
96 
97     // FIXME: It should be possible to do image lowering when some metadata
98     // args missing or not in the expected order.
99     MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
100     if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
101       return nullptr;
102   }
103 
104   return F;
105 }
106 
107 static StringRef
AccessQualFromMD(MDNode * KernelMDNode,unsigned ArgIdx)108 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
109   MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
110   return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
111 }
112 
113 static StringRef
ArgTypeFromMD(MDNode * KernelMDNode,unsigned ArgIdx)114 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
115   MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
116   return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
117 }
118 
119 static MDVector
GetArgMD(MDNode * KernelMDNode,unsigned OpIdx)120 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
121   MDVector Res;
122   for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
123     MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
124     Res.push_back(Node->getOperand(OpIdx));
125   }
126   return Res;
127 }
128 
129 static void
PushArgMD(KernelArgMD & MD,const MDVector & V)130 PushArgMD(KernelArgMD &MD, const MDVector &V) {
131   assert(V.size() == NumKernelArgMDNodes);
132   for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
133     MD.ArgVector[i].push_back(V[i]);
134   }
135 }
136 
137 namespace {
138 
139 class AMDGPUOpenCLImageTypeLoweringPass : public ModulePass {
140   static char ID;
141 
142   LLVMContext *Context;
143   Type *Int32Type;
144   Type *ImageSizeType;
145   Type *ImageFormatType;
146   SmallVector<Instruction *, 4> InstsToErase;
147 
replaceImageUses(Argument & ImageArg,uint32_t ResourceID,Argument & ImageSizeArg,Argument & ImageFormatArg)148   bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
149                         Argument &ImageSizeArg,
150                         Argument &ImageFormatArg) {
151     bool Modified = false;
152 
153     for (auto &Use : ImageArg.uses()) {
154       auto Inst = dyn_cast<CallInst>(Use.getUser());
155       if (!Inst) {
156         continue;
157       }
158 
159       Function *F = Inst->getCalledFunction();
160       if (!F)
161         continue;
162 
163       Value *Replacement = nullptr;
164       StringRef Name = F->getName();
165       if (Name.startswith(GetImageResourceIDFunc)) {
166         Replacement = ConstantInt::get(Int32Type, ResourceID);
167       } else if (Name.startswith(GetImageSizeFunc)) {
168         Replacement = &ImageSizeArg;
169       } else if (Name.startswith(GetImageFormatFunc)) {
170         Replacement = &ImageFormatArg;
171       } else {
172         continue;
173       }
174 
175       Inst->replaceAllUsesWith(Replacement);
176       InstsToErase.push_back(Inst);
177       Modified = true;
178     }
179 
180     return Modified;
181   }
182 
replaceSamplerUses(Argument & SamplerArg,uint32_t ResourceID)183   bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
184     bool Modified = false;
185 
186     for (const auto &Use : SamplerArg.uses()) {
187       auto Inst = dyn_cast<CallInst>(Use.getUser());
188       if (!Inst) {
189         continue;
190       }
191 
192       Function *F = Inst->getCalledFunction();
193       if (!F)
194         continue;
195 
196       Value *Replacement = nullptr;
197       StringRef Name = F->getName();
198       if (Name == GetSamplerResourceIDFunc) {
199         Replacement = ConstantInt::get(Int32Type, ResourceID);
200       } else {
201         continue;
202       }
203 
204       Inst->replaceAllUsesWith(Replacement);
205       InstsToErase.push_back(Inst);
206       Modified = true;
207     }
208 
209     return Modified;
210   }
211 
replaceImageAndSamplerUses(Function * F,MDNode * KernelMDNode)212   bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
213     uint32_t NumReadOnlyImageArgs = 0;
214     uint32_t NumWriteOnlyImageArgs = 0;
215     uint32_t NumSamplerArgs = 0;
216 
217     bool Modified = false;
218     InstsToErase.clear();
219     for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
220       Argument &Arg = *ArgI;
221       StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
222 
223       // Handle image types.
224       if (IsImageType(Type)) {
225         StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
226         uint32_t ResourceID;
227         if (AccessQual == "read_only") {
228           ResourceID = NumReadOnlyImageArgs++;
229         } else if (AccessQual == "write_only") {
230           ResourceID = NumWriteOnlyImageArgs++;
231         } else {
232           llvm_unreachable("Wrong image access qualifier.");
233         }
234 
235         Argument &SizeArg = *(++ArgI);
236         Argument &FormatArg = *(++ArgI);
237         Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
238 
239       // Handle sampler type.
240       } else if (IsSamplerType(Type)) {
241         uint32_t ResourceID = NumSamplerArgs++;
242         Modified |= replaceSamplerUses(Arg, ResourceID);
243       }
244     }
245     for (unsigned i = 0; i < InstsToErase.size(); ++i) {
246       InstsToErase[i]->eraseFromParent();
247     }
248 
249     return Modified;
250   }
251 
252   std::tuple<Function *, MDNode *>
addImplicitArgs(Function * F,MDNode * KernelMDNode)253   addImplicitArgs(Function *F, MDNode *KernelMDNode) {
254     bool Modified = false;
255 
256     FunctionType *FT = F->getFunctionType();
257     SmallVector<Type *, 8> ArgTypes;
258 
259     // Metadata operands for new MDNode.
260     KernelArgMD NewArgMDs;
261     PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
262 
263     // Add implicit arguments to the signature.
264     for (unsigned i = 0; i < FT->getNumParams(); ++i) {
265       ArgTypes.push_back(FT->getParamType(i));
266       MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
267       PushArgMD(NewArgMDs, ArgMD);
268 
269       if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
270         continue;
271 
272       // Add size implicit argument.
273       ArgTypes.push_back(ImageSizeType);
274       ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
275       PushArgMD(NewArgMDs, ArgMD);
276 
277       // Add format implicit argument.
278       ArgTypes.push_back(ImageFormatType);
279       ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
280       PushArgMD(NewArgMDs, ArgMD);
281 
282       Modified = true;
283     }
284     if (!Modified) {
285       return std::make_tuple(nullptr, nullptr);
286     }
287 
288     // Create function with new signature and clone the old body into it.
289     auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
290     auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
291     ValueToValueMapTy VMap;
292     auto NewFArgIt = NewF->arg_begin();
293     for (auto &Arg: F->args()) {
294       auto ArgName = Arg.getName();
295       NewFArgIt->setName(ArgName);
296       VMap[&Arg] = &(*NewFArgIt++);
297       if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
298         (NewFArgIt++)->setName(Twine("__size_") + ArgName);
299         (NewFArgIt++)->setName(Twine("__format_") + ArgName);
300       }
301     }
302     SmallVector<ReturnInst*, 8> Returns;
303     CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns);
304 
305     // Build new MDNode.
306     SmallVector<llvm::Metadata *, 6> KernelMDArgs;
307     KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
308     for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
309       KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
310     MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
311 
312     return std::make_tuple(NewF, NewMDNode);
313   }
314 
transformKernels(Module & M)315   bool transformKernels(Module &M) {
316     NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
317     if (!KernelsMDNode)
318       return false;
319 
320     bool Modified = false;
321     for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
322       MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
323       Function *F = GetFunctionFromMDNode(KernelMDNode);
324       if (!F)
325         continue;
326 
327       Function *NewF;
328       MDNode *NewMDNode;
329       std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
330       if (NewF) {
331         // Replace old function and metadata with new ones.
332         F->eraseFromParent();
333         M.getFunctionList().push_back(NewF);
334         M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
335                               NewF->getAttributes());
336         KernelsMDNode->setOperand(i, NewMDNode);
337 
338         F = NewF;
339         KernelMDNode = NewMDNode;
340         Modified = true;
341       }
342 
343       Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
344     }
345 
346     return Modified;
347   }
348 
349  public:
AMDGPUOpenCLImageTypeLoweringPass()350   AMDGPUOpenCLImageTypeLoweringPass() : ModulePass(ID) {}
351 
runOnModule(Module & M)352   bool runOnModule(Module &M) override {
353     Context = &M.getContext();
354     Int32Type = Type::getInt32Ty(M.getContext());
355     ImageSizeType = ArrayType::get(Int32Type, 3);
356     ImageFormatType = ArrayType::get(Int32Type, 2);
357 
358     return transformKernels(M);
359   }
360 
getPassName() const361   const char *getPassName() const override {
362     return "AMDGPU OpenCL Image Type Pass";
363   }
364 };
365 
366 char AMDGPUOpenCLImageTypeLoweringPass::ID = 0;
367 
368 } // end anonymous namespace
369 
createAMDGPUOpenCLImageTypeLoweringPass()370 ModulePass *llvm::createAMDGPUOpenCLImageTypeLoweringPass() {
371   return new AMDGPUOpenCLImageTypeLoweringPass();
372 }
373