• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// This pass resolves calls to OpenCL image attribute, image resource ID and
11 /// sampler resource ID getter functions.
12 ///
13 /// Image attributes (size and format) are expected to be passed to the kernel
14 /// as kernel arguments immediately following the image argument itself,
15 /// therefore this pass adds image size and format arguments to the kernel
16 /// functions in the module. The kernel functions with image arguments are
17 /// re-created using the new signature. The new arguments are added to the
/// kernel metadata with kernel_arg_type and kernel_arg_base_type set to
/// "__llvm_image_size" or "__llvm_image_format".
19 /// Note: this pass may invalidate pointers to functions.
20 ///
21 /// Resource IDs of read-only images, write-only images and samplers are
22 /// defined to be their index among the kernel arguments of the same
23 /// type and access qualifier.
24 //
25 //===----------------------------------------------------------------------===//
26 
27 #include "AMDGPU.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/StringRef.h"
30 #include "llvm/ADT/Twine.h"
31 #include "llvm/IR/Argument.h"
32 #include "llvm/IR/DerivedTypes.h"
33 #include "llvm/IR/Constants.h"
34 #include "llvm/IR/Function.h"
35 #include "llvm/IR/Instruction.h"
36 #include "llvm/IR/Instructions.h"
37 #include "llvm/IR/Metadata.h"
38 #include "llvm/IR/Module.h"
39 #include "llvm/IR/Type.h"
40 #include "llvm/IR/Use.h"
41 #include "llvm/IR/User.h"
42 #include "llvm/Pass.h"
43 #include "llvm/Support/Casting.h"
44 #include "llvm/Support/ErrorHandling.h"
45 #include "llvm/Transforms/Utils/Cloning.h"
46 #include "llvm/Transforms/Utils/ValueMapper.h"
47 #include <cassert>
48 #include <cstddef>
49 #include <cstdint>
50 #include <tuple>
51 
52 using namespace llvm;
53 
54 static StringRef GetImageSizeFunc =         "llvm.OpenCL.image.get.size";
55 static StringRef GetImageFormatFunc =       "llvm.OpenCL.image.get.format";
56 static StringRef GetImageResourceIDFunc =   "llvm.OpenCL.image.get.resource.id";
57 static StringRef GetSamplerResourceIDFunc =
58     "llvm.OpenCL.sampler.get.resource.id";
59 
60 static StringRef ImageSizeArgMDType =   "__llvm_image_size";
61 static StringRef ImageFormatArgMDType = "__llvm_image_format";
62 
63 static StringRef KernelsMDNodeName = "opencl.kernels";
64 static StringRef KernelArgMDNodeNames[] = {
65   "kernel_arg_addr_space",
66   "kernel_arg_access_qual",
67   "kernel_arg_type",
68   "kernel_arg_base_type",
69   "kernel_arg_type_qual"};
70 static const unsigned NumKernelArgMDNodes = 5;
71 
72 namespace {
73 
74 using MDVector = SmallVector<Metadata *, 8>;
75 struct KernelArgMD {
76   MDVector ArgVector[NumKernelArgMDNodes];
77 };
78 
79 } // end anonymous namespace
80 
81 static inline bool
IsImageType(StringRef TypeString)82 IsImageType(StringRef TypeString) {
83   return TypeString == "image2d_t" || TypeString == "image3d_t";
84 }
85 
86 static inline bool
IsSamplerType(StringRef TypeString)87 IsSamplerType(StringRef TypeString) {
88   return TypeString == "sampler_t";
89 }
90 
91 static Function *
GetFunctionFromMDNode(MDNode * Node)92 GetFunctionFromMDNode(MDNode *Node) {
93   if (!Node)
94     return nullptr;
95 
96   size_t NumOps = Node->getNumOperands();
97   if (NumOps != NumKernelArgMDNodes + 1)
98     return nullptr;
99 
100   auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
101   if (!F)
102     return nullptr;
103 
104   // Sanity checks.
105   size_t ExpectNumArgNodeOps = F->arg_size() + 1;
106   for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
107     MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
108     if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
109       return nullptr;
110     if (!ArgNode->getOperand(0))
111       return nullptr;
112 
113     // FIXME: It should be possible to do image lowering when some metadata
114     // args missing or not in the expected order.
115     MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
116     if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
117       return nullptr;
118   }
119 
120   return F;
121 }
122 
123 static StringRef
AccessQualFromMD(MDNode * KernelMDNode,unsigned ArgIdx)124 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
125   MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
126   return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
127 }
128 
129 static StringRef
ArgTypeFromMD(MDNode * KernelMDNode,unsigned ArgIdx)130 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
131   MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
132   return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
133 }
134 
135 static MDVector
GetArgMD(MDNode * KernelMDNode,unsigned OpIdx)136 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
137   MDVector Res;
138   for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
139     MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
140     Res.push_back(Node->getOperand(OpIdx));
141   }
142   return Res;
143 }
144 
145 static void
PushArgMD(KernelArgMD & MD,const MDVector & V)146 PushArgMD(KernelArgMD &MD, const MDVector &V) {
147   assert(V.size() == NumKernelArgMDNodes);
148   for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
149     MD.ArgVector[i].push_back(V[i]);
150   }
151 }
152 
153 namespace {
154 
155 class R600OpenCLImageTypeLoweringPass : public ModulePass {
156   static char ID;
157 
158   LLVMContext *Context;
159   Type *Int32Type;
160   Type *ImageSizeType;
161   Type *ImageFormatType;
162   SmallVector<Instruction *, 4> InstsToErase;
163 
replaceImageUses(Argument & ImageArg,uint32_t ResourceID,Argument & ImageSizeArg,Argument & ImageFormatArg)164   bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
165                         Argument &ImageSizeArg,
166                         Argument &ImageFormatArg) {
167     bool Modified = false;
168 
169     for (auto &Use : ImageArg.uses()) {
170       auto Inst = dyn_cast<CallInst>(Use.getUser());
171       if (!Inst) {
172         continue;
173       }
174 
175       Function *F = Inst->getCalledFunction();
176       if (!F)
177         continue;
178 
179       Value *Replacement = nullptr;
180       StringRef Name = F->getName();
181       if (Name.startswith(GetImageResourceIDFunc)) {
182         Replacement = ConstantInt::get(Int32Type, ResourceID);
183       } else if (Name.startswith(GetImageSizeFunc)) {
184         Replacement = &ImageSizeArg;
185       } else if (Name.startswith(GetImageFormatFunc)) {
186         Replacement = &ImageFormatArg;
187       } else {
188         continue;
189       }
190 
191       Inst->replaceAllUsesWith(Replacement);
192       InstsToErase.push_back(Inst);
193       Modified = true;
194     }
195 
196     return Modified;
197   }
198 
replaceSamplerUses(Argument & SamplerArg,uint32_t ResourceID)199   bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
200     bool Modified = false;
201 
202     for (const auto &Use : SamplerArg.uses()) {
203       auto Inst = dyn_cast<CallInst>(Use.getUser());
204       if (!Inst) {
205         continue;
206       }
207 
208       Function *F = Inst->getCalledFunction();
209       if (!F)
210         continue;
211 
212       Value *Replacement = nullptr;
213       StringRef Name = F->getName();
214       if (Name == GetSamplerResourceIDFunc) {
215         Replacement = ConstantInt::get(Int32Type, ResourceID);
216       } else {
217         continue;
218       }
219 
220       Inst->replaceAllUsesWith(Replacement);
221       InstsToErase.push_back(Inst);
222       Modified = true;
223     }
224 
225     return Modified;
226   }
227 
replaceImageAndSamplerUses(Function * F,MDNode * KernelMDNode)228   bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
229     uint32_t NumReadOnlyImageArgs = 0;
230     uint32_t NumWriteOnlyImageArgs = 0;
231     uint32_t NumSamplerArgs = 0;
232 
233     bool Modified = false;
234     InstsToErase.clear();
235     for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
236       Argument &Arg = *ArgI;
237       StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
238 
239       // Handle image types.
240       if (IsImageType(Type)) {
241         StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
242         uint32_t ResourceID;
243         if (AccessQual == "read_only") {
244           ResourceID = NumReadOnlyImageArgs++;
245         } else if (AccessQual == "write_only") {
246           ResourceID = NumWriteOnlyImageArgs++;
247         } else {
248           llvm_unreachable("Wrong image access qualifier.");
249         }
250 
251         Argument &SizeArg = *(++ArgI);
252         Argument &FormatArg = *(++ArgI);
253         Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
254 
255       // Handle sampler type.
256       } else if (IsSamplerType(Type)) {
257         uint32_t ResourceID = NumSamplerArgs++;
258         Modified |= replaceSamplerUses(Arg, ResourceID);
259       }
260     }
261     for (unsigned i = 0; i < InstsToErase.size(); ++i) {
262       InstsToErase[i]->eraseFromParent();
263     }
264 
265     return Modified;
266   }
267 
268   std::tuple<Function *, MDNode *>
addImplicitArgs(Function * F,MDNode * KernelMDNode)269   addImplicitArgs(Function *F, MDNode *KernelMDNode) {
270     bool Modified = false;
271 
272     FunctionType *FT = F->getFunctionType();
273     SmallVector<Type *, 8> ArgTypes;
274 
275     // Metadata operands for new MDNode.
276     KernelArgMD NewArgMDs;
277     PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
278 
279     // Add implicit arguments to the signature.
280     for (unsigned i = 0; i < FT->getNumParams(); ++i) {
281       ArgTypes.push_back(FT->getParamType(i));
282       MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
283       PushArgMD(NewArgMDs, ArgMD);
284 
285       if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
286         continue;
287 
288       // Add size implicit argument.
289       ArgTypes.push_back(ImageSizeType);
290       ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
291       PushArgMD(NewArgMDs, ArgMD);
292 
293       // Add format implicit argument.
294       ArgTypes.push_back(ImageFormatType);
295       ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
296       PushArgMD(NewArgMDs, ArgMD);
297 
298       Modified = true;
299     }
300     if (!Modified) {
301       return std::make_tuple(nullptr, nullptr);
302     }
303 
304     // Create function with new signature and clone the old body into it.
305     auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
306     auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
307     ValueToValueMapTy VMap;
308     auto NewFArgIt = NewF->arg_begin();
309     for (auto &Arg: F->args()) {
310       auto ArgName = Arg.getName();
311       NewFArgIt->setName(ArgName);
312       VMap[&Arg] = &(*NewFArgIt++);
313       if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
314         (NewFArgIt++)->setName(Twine("__size_") + ArgName);
315         (NewFArgIt++)->setName(Twine("__format_") + ArgName);
316       }
317     }
318     SmallVector<ReturnInst*, 8> Returns;
319     CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns);
320 
321     // Build new MDNode.
322     SmallVector<Metadata *, 6> KernelMDArgs;
323     KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
324     for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
325       KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
326     MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
327 
328     return std::make_tuple(NewF, NewMDNode);
329   }
330 
transformKernels(Module & M)331   bool transformKernels(Module &M) {
332     NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
333     if (!KernelsMDNode)
334       return false;
335 
336     bool Modified = false;
337     for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
338       MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
339       Function *F = GetFunctionFromMDNode(KernelMDNode);
340       if (!F)
341         continue;
342 
343       Function *NewF;
344       MDNode *NewMDNode;
345       std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
346       if (NewF) {
347         // Replace old function and metadata with new ones.
348         F->eraseFromParent();
349         M.getFunctionList().push_back(NewF);
350         M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
351                               NewF->getAttributes());
352         KernelsMDNode->setOperand(i, NewMDNode);
353 
354         F = NewF;
355         KernelMDNode = NewMDNode;
356         Modified = true;
357       }
358 
359       Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
360     }
361 
362     return Modified;
363   }
364 
365 public:
R600OpenCLImageTypeLoweringPass()366   R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {}
367 
runOnModule(Module & M)368   bool runOnModule(Module &M) override {
369     Context = &M.getContext();
370     Int32Type = Type::getInt32Ty(M.getContext());
371     ImageSizeType = ArrayType::get(Int32Type, 3);
372     ImageFormatType = ArrayType::get(Int32Type, 2);
373 
374     return transformKernels(M);
375   }
376 
getPassName() const377   StringRef getPassName() const override {
378     return "R600 OpenCL Image Type Pass";
379   }
380 };
381 
382 } // end anonymous namespace
383 
384 char R600OpenCLImageTypeLoweringPass::ID = 0;
385 
createR600OpenCLImageTypeLoweringPass()386 ModulePass *llvm::createR600OpenCLImageTypeLoweringPass() {
387   return new R600OpenCLImageTypeLoweringPass();
388 }
389