1 //===-- AMDGPUOpenCLImageTypeLoweringPass.cpp -----------------------------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 /// \file
11 /// This pass resolves calls to OpenCL image attribute, image resource ID and
12 /// sampler resource ID getter functions.
13 ///
14 /// Image attributes (size and format) are expected to be passed to the kernel
15 /// as kernel arguments immediately following the image argument itself,
16 /// therefore this pass adds image size and format arguments to the kernel
17 /// functions in the module. The kernel functions with image arguments are
/// re-created using the new signature. The new arguments are added to the
/// kernel metadata with kernel_arg_type and kernel_arg_base_type set to
/// "__llvm_image_size" or "__llvm_image_format".
20 /// Note: this pass may invalidate pointers to functions.
21 ///
22 /// Resource IDs of read-only images, write-only images and samplers are
23 /// defined to be their index among the kernel arguments of the same
24 /// type and access qualifier.
25 //===----------------------------------------------------------------------===//
26
27 #include "AMDGPU.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/Analysis/Passes.h"
31 #include "llvm/IR/Constants.h"
32 #include "llvm/IR/Function.h"
33 #include "llvm/IR/Instructions.h"
34 #include "llvm/IR/Module.h"
35 #include "llvm/Transforms/Utils/Cloning.h"
36
37 using namespace llvm;
38
39 namespace {
40
41 StringRef GetImageSizeFunc = "llvm.OpenCL.image.get.size";
42 StringRef GetImageFormatFunc = "llvm.OpenCL.image.get.format";
43 StringRef GetImageResourceIDFunc = "llvm.OpenCL.image.get.resource.id";
44 StringRef GetSamplerResourceIDFunc = "llvm.OpenCL.sampler.get.resource.id";
45
46 StringRef ImageSizeArgMDType = "__llvm_image_size";
47 StringRef ImageFormatArgMDType = "__llvm_image_format";
48
49 StringRef KernelsMDNodeName = "opencl.kernels";
50 StringRef KernelArgMDNodeNames[] = {
51 "kernel_arg_addr_space",
52 "kernel_arg_access_qual",
53 "kernel_arg_type",
54 "kernel_arg_base_type",
55 "kernel_arg_type_qual"};
56 const unsigned NumKernelArgMDNodes = 5;
57
58 typedef SmallVector<Metadata *, 8> MDVector;
59 struct KernelArgMD {
60 MDVector ArgVector[NumKernelArgMDNodes];
61 };
62
63 } // end anonymous namespace
64
65 static inline bool
IsImageType(StringRef TypeString)66 IsImageType(StringRef TypeString) {
67 return TypeString == "image2d_t" || TypeString == "image3d_t";
68 }
69
70 static inline bool
IsSamplerType(StringRef TypeString)71 IsSamplerType(StringRef TypeString) {
72 return TypeString == "sampler_t";
73 }
74
75 static Function *
GetFunctionFromMDNode(MDNode * Node)76 GetFunctionFromMDNode(MDNode *Node) {
77 if (!Node)
78 return nullptr;
79
80 size_t NumOps = Node->getNumOperands();
81 if (NumOps != NumKernelArgMDNodes + 1)
82 return nullptr;
83
84 auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
85 if (!F)
86 return nullptr;
87
88 // Sanity checks.
89 size_t ExpectNumArgNodeOps = F->arg_size() + 1;
90 for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
91 MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
92 if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
93 return nullptr;
94 if (!ArgNode->getOperand(0))
95 return nullptr;
96
97 // FIXME: It should be possible to do image lowering when some metadata
98 // args missing or not in the expected order.
99 MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
100 if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
101 return nullptr;
102 }
103
104 return F;
105 }
106
107 static StringRef
AccessQualFromMD(MDNode * KernelMDNode,unsigned ArgIdx)108 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
109 MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
110 return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
111 }
112
113 static StringRef
ArgTypeFromMD(MDNode * KernelMDNode,unsigned ArgIdx)114 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
115 MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
116 return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
117 }
118
119 static MDVector
GetArgMD(MDNode * KernelMDNode,unsigned OpIdx)120 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
121 MDVector Res;
122 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
123 MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
124 Res.push_back(Node->getOperand(OpIdx));
125 }
126 return Res;
127 }
128
129 static void
PushArgMD(KernelArgMD & MD,const MDVector & V)130 PushArgMD(KernelArgMD &MD, const MDVector &V) {
131 assert(V.size() == NumKernelArgMDNodes);
132 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
133 MD.ArgVector[i].push_back(V[i]);
134 }
135 }
136
137 namespace {
138
139 class AMDGPUOpenCLImageTypeLoweringPass : public ModulePass {
140 static char ID;
141
142 LLVMContext *Context;
143 Type *Int32Type;
144 Type *ImageSizeType;
145 Type *ImageFormatType;
146 SmallVector<Instruction *, 4> InstsToErase;
147
replaceImageUses(Argument & ImageArg,uint32_t ResourceID,Argument & ImageSizeArg,Argument & ImageFormatArg)148 bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
149 Argument &ImageSizeArg,
150 Argument &ImageFormatArg) {
151 bool Modified = false;
152
153 for (auto &Use : ImageArg.uses()) {
154 auto Inst = dyn_cast<CallInst>(Use.getUser());
155 if (!Inst) {
156 continue;
157 }
158
159 Function *F = Inst->getCalledFunction();
160 if (!F)
161 continue;
162
163 Value *Replacement = nullptr;
164 StringRef Name = F->getName();
165 if (Name.startswith(GetImageResourceIDFunc)) {
166 Replacement = ConstantInt::get(Int32Type, ResourceID);
167 } else if (Name.startswith(GetImageSizeFunc)) {
168 Replacement = &ImageSizeArg;
169 } else if (Name.startswith(GetImageFormatFunc)) {
170 Replacement = &ImageFormatArg;
171 } else {
172 continue;
173 }
174
175 Inst->replaceAllUsesWith(Replacement);
176 InstsToErase.push_back(Inst);
177 Modified = true;
178 }
179
180 return Modified;
181 }
182
replaceSamplerUses(Argument & SamplerArg,uint32_t ResourceID)183 bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
184 bool Modified = false;
185
186 for (const auto &Use : SamplerArg.uses()) {
187 auto Inst = dyn_cast<CallInst>(Use.getUser());
188 if (!Inst) {
189 continue;
190 }
191
192 Function *F = Inst->getCalledFunction();
193 if (!F)
194 continue;
195
196 Value *Replacement = nullptr;
197 StringRef Name = F->getName();
198 if (Name == GetSamplerResourceIDFunc) {
199 Replacement = ConstantInt::get(Int32Type, ResourceID);
200 } else {
201 continue;
202 }
203
204 Inst->replaceAllUsesWith(Replacement);
205 InstsToErase.push_back(Inst);
206 Modified = true;
207 }
208
209 return Modified;
210 }
211
replaceImageAndSamplerUses(Function * F,MDNode * KernelMDNode)212 bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
213 uint32_t NumReadOnlyImageArgs = 0;
214 uint32_t NumWriteOnlyImageArgs = 0;
215 uint32_t NumSamplerArgs = 0;
216
217 bool Modified = false;
218 InstsToErase.clear();
219 for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
220 Argument &Arg = *ArgI;
221 StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
222
223 // Handle image types.
224 if (IsImageType(Type)) {
225 StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
226 uint32_t ResourceID;
227 if (AccessQual == "read_only") {
228 ResourceID = NumReadOnlyImageArgs++;
229 } else if (AccessQual == "write_only") {
230 ResourceID = NumWriteOnlyImageArgs++;
231 } else {
232 llvm_unreachable("Wrong image access qualifier.");
233 }
234
235 Argument &SizeArg = *(++ArgI);
236 Argument &FormatArg = *(++ArgI);
237 Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
238
239 // Handle sampler type.
240 } else if (IsSamplerType(Type)) {
241 uint32_t ResourceID = NumSamplerArgs++;
242 Modified |= replaceSamplerUses(Arg, ResourceID);
243 }
244 }
245 for (unsigned i = 0; i < InstsToErase.size(); ++i) {
246 InstsToErase[i]->eraseFromParent();
247 }
248
249 return Modified;
250 }
251
252 std::tuple<Function *, MDNode *>
addImplicitArgs(Function * F,MDNode * KernelMDNode)253 addImplicitArgs(Function *F, MDNode *KernelMDNode) {
254 bool Modified = false;
255
256 FunctionType *FT = F->getFunctionType();
257 SmallVector<Type *, 8> ArgTypes;
258
259 // Metadata operands for new MDNode.
260 KernelArgMD NewArgMDs;
261 PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
262
263 // Add implicit arguments to the signature.
264 for (unsigned i = 0; i < FT->getNumParams(); ++i) {
265 ArgTypes.push_back(FT->getParamType(i));
266 MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
267 PushArgMD(NewArgMDs, ArgMD);
268
269 if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
270 continue;
271
272 // Add size implicit argument.
273 ArgTypes.push_back(ImageSizeType);
274 ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
275 PushArgMD(NewArgMDs, ArgMD);
276
277 // Add format implicit argument.
278 ArgTypes.push_back(ImageFormatType);
279 ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
280 PushArgMD(NewArgMDs, ArgMD);
281
282 Modified = true;
283 }
284 if (!Modified) {
285 return std::make_tuple(nullptr, nullptr);
286 }
287
288 // Create function with new signature and clone the old body into it.
289 auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
290 auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
291 ValueToValueMapTy VMap;
292 auto NewFArgIt = NewF->arg_begin();
293 for (auto &Arg: F->args()) {
294 auto ArgName = Arg.getName();
295 NewFArgIt->setName(ArgName);
296 VMap[&Arg] = &(*NewFArgIt++);
297 if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
298 (NewFArgIt++)->setName(Twine("__size_") + ArgName);
299 (NewFArgIt++)->setName(Twine("__format_") + ArgName);
300 }
301 }
302 SmallVector<ReturnInst*, 8> Returns;
303 CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns);
304
305 // Build new MDNode.
306 SmallVector<llvm::Metadata *, 6> KernelMDArgs;
307 KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
308 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
309 KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
310 MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
311
312 return std::make_tuple(NewF, NewMDNode);
313 }
314
transformKernels(Module & M)315 bool transformKernels(Module &M) {
316 NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
317 if (!KernelsMDNode)
318 return false;
319
320 bool Modified = false;
321 for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
322 MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
323 Function *F = GetFunctionFromMDNode(KernelMDNode);
324 if (!F)
325 continue;
326
327 Function *NewF;
328 MDNode *NewMDNode;
329 std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
330 if (NewF) {
331 // Replace old function and metadata with new ones.
332 F->eraseFromParent();
333 M.getFunctionList().push_back(NewF);
334 M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
335 NewF->getAttributes());
336 KernelsMDNode->setOperand(i, NewMDNode);
337
338 F = NewF;
339 KernelMDNode = NewMDNode;
340 Modified = true;
341 }
342
343 Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
344 }
345
346 return Modified;
347 }
348
349 public:
AMDGPUOpenCLImageTypeLoweringPass()350 AMDGPUOpenCLImageTypeLoweringPass() : ModulePass(ID) {}
351
runOnModule(Module & M)352 bool runOnModule(Module &M) override {
353 Context = &M.getContext();
354 Int32Type = Type::getInt32Ty(M.getContext());
355 ImageSizeType = ArrayType::get(Int32Type, 3);
356 ImageFormatType = ArrayType::get(Int32Type, 2);
357
358 return transformKernels(M);
359 }
360
getPassName() const361 const char *getPassName() const override {
362 return "AMDGPU OpenCL Image Type Pass";
363 }
364 };
365
366 char AMDGPUOpenCLImageTypeLoweringPass::ID = 0;
367
368 } // end anonymous namespace
369
createAMDGPUOpenCLImageTypeLoweringPass()370 ModulePass *llvm::createAMDGPUOpenCLImageTypeLoweringPass() {
371 return new AMDGPUOpenCLImageTypeLoweringPass();
372 }
373