1 //===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// This pass resolves calls to OpenCL image attribute, image resource ID and
11 /// sampler resource ID getter functions.
12 ///
13 /// Image attributes (size and format) are expected to be passed to the kernel
14 /// as kernel arguments immediately following the image argument itself,
15 /// therefore this pass adds image size and format arguments to the kernel
16 /// functions in the module. The kernel functions with image arguments are
17 /// re-created using the new signature. The new arguments are added to the
18 /// kernel metadata with kernel_arg_type set to "image_size" or "image_format".
19 /// Note: this pass may invalidate pointers to functions.
20 ///
21 /// Resource IDs of read-only images, write-only images and samplers are
22 /// defined to be their index among the kernel arguments of the same
23 /// type and access qualifier.
24 //
25 //===----------------------------------------------------------------------===//
26
27 #include "AMDGPU.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/StringRef.h"
30 #include "llvm/ADT/Twine.h"
31 #include "llvm/IR/Argument.h"
32 #include "llvm/IR/DerivedTypes.h"
33 #include "llvm/IR/Constants.h"
34 #include "llvm/IR/Function.h"
35 #include "llvm/IR/Instruction.h"
36 #include "llvm/IR/Instructions.h"
37 #include "llvm/IR/Metadata.h"
38 #include "llvm/IR/Module.h"
39 #include "llvm/IR/Type.h"
40 #include "llvm/IR/Use.h"
41 #include "llvm/IR/User.h"
42 #include "llvm/Pass.h"
43 #include "llvm/Support/Casting.h"
44 #include "llvm/Support/ErrorHandling.h"
45 #include "llvm/Transforms/Utils/Cloning.h"
46 #include "llvm/Transforms/Utils/ValueMapper.h"
47 #include <cassert>
48 #include <cstddef>
49 #include <cstdint>
50 #include <tuple>
51
52 using namespace llvm;
53
54 static StringRef GetImageSizeFunc = "llvm.OpenCL.image.get.size";
55 static StringRef GetImageFormatFunc = "llvm.OpenCL.image.get.format";
56 static StringRef GetImageResourceIDFunc = "llvm.OpenCL.image.get.resource.id";
57 static StringRef GetSamplerResourceIDFunc =
58 "llvm.OpenCL.sampler.get.resource.id";
59
60 static StringRef ImageSizeArgMDType = "__llvm_image_size";
61 static StringRef ImageFormatArgMDType = "__llvm_image_format";
62
63 static StringRef KernelsMDNodeName = "opencl.kernels";
64 static StringRef KernelArgMDNodeNames[] = {
65 "kernel_arg_addr_space",
66 "kernel_arg_access_qual",
67 "kernel_arg_type",
68 "kernel_arg_base_type",
69 "kernel_arg_type_qual"};
70 static const unsigned NumKernelArgMDNodes = 5;
71
72 namespace {
73
74 using MDVector = SmallVector<Metadata *, 8>;
75 struct KernelArgMD {
76 MDVector ArgVector[NumKernelArgMDNodes];
77 };
78
79 } // end anonymous namespace
80
81 static inline bool
IsImageType(StringRef TypeString)82 IsImageType(StringRef TypeString) {
83 return TypeString == "image2d_t" || TypeString == "image3d_t";
84 }
85
86 static inline bool
IsSamplerType(StringRef TypeString)87 IsSamplerType(StringRef TypeString) {
88 return TypeString == "sampler_t";
89 }
90
91 static Function *
GetFunctionFromMDNode(MDNode * Node)92 GetFunctionFromMDNode(MDNode *Node) {
93 if (!Node)
94 return nullptr;
95
96 size_t NumOps = Node->getNumOperands();
97 if (NumOps != NumKernelArgMDNodes + 1)
98 return nullptr;
99
100 auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
101 if (!F)
102 return nullptr;
103
104 // Sanity checks.
105 size_t ExpectNumArgNodeOps = F->arg_size() + 1;
106 for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
107 MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
108 if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
109 return nullptr;
110 if (!ArgNode->getOperand(0))
111 return nullptr;
112
113 // FIXME: It should be possible to do image lowering when some metadata
114 // args missing or not in the expected order.
115 MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
116 if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
117 return nullptr;
118 }
119
120 return F;
121 }
122
123 static StringRef
AccessQualFromMD(MDNode * KernelMDNode,unsigned ArgIdx)124 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
125 MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
126 return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
127 }
128
129 static StringRef
ArgTypeFromMD(MDNode * KernelMDNode,unsigned ArgIdx)130 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
131 MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
132 return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
133 }
134
135 static MDVector
GetArgMD(MDNode * KernelMDNode,unsigned OpIdx)136 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
137 MDVector Res;
138 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
139 MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
140 Res.push_back(Node->getOperand(OpIdx));
141 }
142 return Res;
143 }
144
145 static void
PushArgMD(KernelArgMD & MD,const MDVector & V)146 PushArgMD(KernelArgMD &MD, const MDVector &V) {
147 assert(V.size() == NumKernelArgMDNodes);
148 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
149 MD.ArgVector[i].push_back(V[i]);
150 }
151 }
152
153 namespace {
154
155 class R600OpenCLImageTypeLoweringPass : public ModulePass {
156 static char ID;
157
158 LLVMContext *Context;
159 Type *Int32Type;
160 Type *ImageSizeType;
161 Type *ImageFormatType;
162 SmallVector<Instruction *, 4> InstsToErase;
163
replaceImageUses(Argument & ImageArg,uint32_t ResourceID,Argument & ImageSizeArg,Argument & ImageFormatArg)164 bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
165 Argument &ImageSizeArg,
166 Argument &ImageFormatArg) {
167 bool Modified = false;
168
169 for (auto &Use : ImageArg.uses()) {
170 auto Inst = dyn_cast<CallInst>(Use.getUser());
171 if (!Inst) {
172 continue;
173 }
174
175 Function *F = Inst->getCalledFunction();
176 if (!F)
177 continue;
178
179 Value *Replacement = nullptr;
180 StringRef Name = F->getName();
181 if (Name.startswith(GetImageResourceIDFunc)) {
182 Replacement = ConstantInt::get(Int32Type, ResourceID);
183 } else if (Name.startswith(GetImageSizeFunc)) {
184 Replacement = &ImageSizeArg;
185 } else if (Name.startswith(GetImageFormatFunc)) {
186 Replacement = &ImageFormatArg;
187 } else {
188 continue;
189 }
190
191 Inst->replaceAllUsesWith(Replacement);
192 InstsToErase.push_back(Inst);
193 Modified = true;
194 }
195
196 return Modified;
197 }
198
replaceSamplerUses(Argument & SamplerArg,uint32_t ResourceID)199 bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
200 bool Modified = false;
201
202 for (const auto &Use : SamplerArg.uses()) {
203 auto Inst = dyn_cast<CallInst>(Use.getUser());
204 if (!Inst) {
205 continue;
206 }
207
208 Function *F = Inst->getCalledFunction();
209 if (!F)
210 continue;
211
212 Value *Replacement = nullptr;
213 StringRef Name = F->getName();
214 if (Name == GetSamplerResourceIDFunc) {
215 Replacement = ConstantInt::get(Int32Type, ResourceID);
216 } else {
217 continue;
218 }
219
220 Inst->replaceAllUsesWith(Replacement);
221 InstsToErase.push_back(Inst);
222 Modified = true;
223 }
224
225 return Modified;
226 }
227
replaceImageAndSamplerUses(Function * F,MDNode * KernelMDNode)228 bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
229 uint32_t NumReadOnlyImageArgs = 0;
230 uint32_t NumWriteOnlyImageArgs = 0;
231 uint32_t NumSamplerArgs = 0;
232
233 bool Modified = false;
234 InstsToErase.clear();
235 for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
236 Argument &Arg = *ArgI;
237 StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
238
239 // Handle image types.
240 if (IsImageType(Type)) {
241 StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
242 uint32_t ResourceID;
243 if (AccessQual == "read_only") {
244 ResourceID = NumReadOnlyImageArgs++;
245 } else if (AccessQual == "write_only") {
246 ResourceID = NumWriteOnlyImageArgs++;
247 } else {
248 llvm_unreachable("Wrong image access qualifier.");
249 }
250
251 Argument &SizeArg = *(++ArgI);
252 Argument &FormatArg = *(++ArgI);
253 Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
254
255 // Handle sampler type.
256 } else if (IsSamplerType(Type)) {
257 uint32_t ResourceID = NumSamplerArgs++;
258 Modified |= replaceSamplerUses(Arg, ResourceID);
259 }
260 }
261 for (unsigned i = 0; i < InstsToErase.size(); ++i) {
262 InstsToErase[i]->eraseFromParent();
263 }
264
265 return Modified;
266 }
267
268 std::tuple<Function *, MDNode *>
addImplicitArgs(Function * F,MDNode * KernelMDNode)269 addImplicitArgs(Function *F, MDNode *KernelMDNode) {
270 bool Modified = false;
271
272 FunctionType *FT = F->getFunctionType();
273 SmallVector<Type *, 8> ArgTypes;
274
275 // Metadata operands for new MDNode.
276 KernelArgMD NewArgMDs;
277 PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
278
279 // Add implicit arguments to the signature.
280 for (unsigned i = 0; i < FT->getNumParams(); ++i) {
281 ArgTypes.push_back(FT->getParamType(i));
282 MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
283 PushArgMD(NewArgMDs, ArgMD);
284
285 if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
286 continue;
287
288 // Add size implicit argument.
289 ArgTypes.push_back(ImageSizeType);
290 ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
291 PushArgMD(NewArgMDs, ArgMD);
292
293 // Add format implicit argument.
294 ArgTypes.push_back(ImageFormatType);
295 ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
296 PushArgMD(NewArgMDs, ArgMD);
297
298 Modified = true;
299 }
300 if (!Modified) {
301 return std::make_tuple(nullptr, nullptr);
302 }
303
304 // Create function with new signature and clone the old body into it.
305 auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
306 auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
307 ValueToValueMapTy VMap;
308 auto NewFArgIt = NewF->arg_begin();
309 for (auto &Arg: F->args()) {
310 auto ArgName = Arg.getName();
311 NewFArgIt->setName(ArgName);
312 VMap[&Arg] = &(*NewFArgIt++);
313 if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
314 (NewFArgIt++)->setName(Twine("__size_") + ArgName);
315 (NewFArgIt++)->setName(Twine("__format_") + ArgName);
316 }
317 }
318 SmallVector<ReturnInst*, 8> Returns;
319 CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns);
320
321 // Build new MDNode.
322 SmallVector<Metadata *, 6> KernelMDArgs;
323 KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
324 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
325 KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
326 MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
327
328 return std::make_tuple(NewF, NewMDNode);
329 }
330
transformKernels(Module & M)331 bool transformKernels(Module &M) {
332 NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
333 if (!KernelsMDNode)
334 return false;
335
336 bool Modified = false;
337 for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
338 MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
339 Function *F = GetFunctionFromMDNode(KernelMDNode);
340 if (!F)
341 continue;
342
343 Function *NewF;
344 MDNode *NewMDNode;
345 std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
346 if (NewF) {
347 // Replace old function and metadata with new ones.
348 F->eraseFromParent();
349 M.getFunctionList().push_back(NewF);
350 M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
351 NewF->getAttributes());
352 KernelsMDNode->setOperand(i, NewMDNode);
353
354 F = NewF;
355 KernelMDNode = NewMDNode;
356 Modified = true;
357 }
358
359 Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
360 }
361
362 return Modified;
363 }
364
365 public:
R600OpenCLImageTypeLoweringPass()366 R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {}
367
runOnModule(Module & M)368 bool runOnModule(Module &M) override {
369 Context = &M.getContext();
370 Int32Type = Type::getInt32Ty(M.getContext());
371 ImageSizeType = ArrayType::get(Int32Type, 3);
372 ImageFormatType = ArrayType::get(Int32Type, 2);
373
374 return transformKernels(M);
375 }
376
getPassName() const377 StringRef getPassName() const override {
378 return "R600 OpenCL Image Type Pass";
379 }
380 };
381
382 } // end anonymous namespace
383
384 char R600OpenCLImageTypeLoweringPass::ID = 0;
385
createR600OpenCLImageTypeLoweringPass()386 ModulePass *llvm::createR600OpenCLImageTypeLoweringPass() {
387 return new R600OpenCLImageTypeLoweringPass();
388 }
389