1 /*
2 * Copyright 2017, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "Wrapper.h"

#include "llvm/IR/Module.h"

#include "Builtin.h"
#include "Context.h"
#include "GlobalAllocSPIRITPass.h"
#include "RSAllocationUtils.h"
#include "bcinfo/MetadataExtractor.h"
#include "builder.h"
#include "instructions.h"
#include "module.h"
#include "pass.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
33
34 using bcinfo::MetadataExtractor;
35
36 namespace android {
37 namespace spirit {
38
AddBuffer(Instruction * elementType,uint32_t binding,Builder & b,Module * m)39 VariableInst *AddBuffer(Instruction *elementType, uint32_t binding, Builder &b,
40 Module *m) {
41 auto ArrTy = m->getRuntimeArrayType(elementType);
42 const size_t stride = m->getSize(elementType);
43 ArrTy->decorate(Decoration::ArrayStride)->addExtraOperand(stride);
44 auto StructTy = m->getStructType(ArrTy);
45 StructTy->decorate(Decoration::BufferBlock);
46 StructTy->memberDecorate(0, Decoration::Offset)->addExtraOperand(0);
47
48 auto StructPtrTy = m->getPointerType(StorageClass::Uniform, StructTy);
49
50 VariableInst *bufferVar = b.MakeVariable(StructPtrTy, StorageClass::Uniform);
51 bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
52 bufferVar->decorate(Decoration::Binding)->addExtraOperand(binding);
53 m->addVariable(bufferVar);
54
55 return bufferVar;
56 }
57
AddWrapper(const char * name,const uint32_t signature,const uint32_t numInput,Builder & b,Module * m)58 bool AddWrapper(const char *name, const uint32_t signature,
59 const uint32_t numInput, Builder &b, Module *m) {
60 FunctionDefinition *kernel = m->lookupFunctionDefinitionByName(name);
61 if (kernel == nullptr) {
62 // In the metadata for RenderScript LLVM bitcode, the first foreach kernel
63 // is always reserved for the root kernel, even though in the most recent RS
64 // apps it does not exist. Simply bypass wrapper generation here, and return
65 // true for this case.
66 // Otherwise, if a non-root kernel function cannot be found, it is a
67 // fatal internal error which is really unexpected.
68 return (strncmp(name, "root", 4) == 0);
69 }
70
71 // The following three cases are not supported
72 if (!MetadataExtractor::hasForEachSignatureKernel(signature)) {
73 // Not handling old-style kernel
74 return false;
75 }
76
77 if (MetadataExtractor::hasForEachSignatureUsrData(signature)) {
78 // Not handling the user argument
79 return false;
80 }
81
82 if (MetadataExtractor::hasForEachSignatureCtxt(signature)) {
83 // Not handling the context argument
84 return false;
85 }
86
87 TypeVoidInst *VoidTy = m->getVoidType();
88 TypeFunctionInst *FuncTy = m->getFunctionType(VoidTy, nullptr, 0);
89 FunctionDefinition *Func =
90 b.MakeFunctionDefinition(VoidTy, FunctionControl::None, FuncTy);
91 m->addFunctionDefinition(Func);
92
93 Block *Blk = b.MakeBlock();
94 Func->addBlock(Blk);
95
96 Blk->addInstruction(b.MakeLabel());
97
98 TypeIntInst *UIntTy = m->getUnsignedIntType(32);
99
100 Instruction *XValue = nullptr;
101 Instruction *YValue = nullptr;
102 Instruction *ZValue = nullptr;
103 Instruction *Index = nullptr;
104 VariableInst *InvocationId = nullptr;
105 VariableInst *NumWorkgroups = nullptr;
106
107 if (MetadataExtractor::hasForEachSignatureIn(signature) ||
108 MetadataExtractor::hasForEachSignatureOut(signature) ||
109 MetadataExtractor::hasForEachSignatureX(signature) ||
110 MetadataExtractor::hasForEachSignatureY(signature) ||
111 MetadataExtractor::hasForEachSignatureZ(signature)) {
112 TypeVectorInst *V3UIntTy = m->getVectorType(UIntTy, 3);
113 InvocationId = m->getInvocationId();
114 auto IID = b.MakeLoad(V3UIntTy, InvocationId);
115 Blk->addInstruction(IID);
116
117 XValue = b.MakeCompositeExtract(UIntTy, IID, {0});
118 Blk->addInstruction(XValue);
119
120 YValue = b.MakeCompositeExtract(UIntTy, IID, {1});
121 Blk->addInstruction(YValue);
122
123 ZValue = b.MakeCompositeExtract(UIntTy, IID, {2});
124 Blk->addInstruction(ZValue);
125
126 // TODO: Use SpecConstant for workgroup size
127 auto ConstOne = m->getConstant(UIntTy, 1U);
128 auto GroupSize =
129 m->getConstantComposite(V3UIntTy, ConstOne, ConstOne, ConstOne);
130
131 auto GroupSizeX = b.MakeCompositeExtract(UIntTy, GroupSize, {0});
132 Blk->addInstruction(GroupSizeX);
133
134 auto GroupSizeY = b.MakeCompositeExtract(UIntTy, GroupSize, {1});
135 Blk->addInstruction(GroupSizeY);
136
137 NumWorkgroups = m->getNumWorkgroups();
138 auto NumGroup = b.MakeLoad(V3UIntTy, NumWorkgroups);
139 Blk->addInstruction(NumGroup);
140
141 auto NumGroupX = b.MakeCompositeExtract(UIntTy, NumGroup, {0});
142 Blk->addInstruction(NumGroupX);
143
144 auto NumGroupY = b.MakeCompositeExtract(UIntTy, NumGroup, {1});
145 Blk->addInstruction(NumGroupY);
146
147 auto GlobalSizeX = b.MakeIMul(UIntTy, GroupSizeX, NumGroupX);
148 Blk->addInstruction(GlobalSizeX);
149
150 auto GlobalSizeY = b.MakeIMul(UIntTy, GroupSizeY, NumGroupY);
151 Blk->addInstruction(GlobalSizeY);
152
153 auto RowsAlongZ = b.MakeIMul(UIntTy, GlobalSizeY, ZValue);
154 Blk->addInstruction(RowsAlongZ);
155
156 auto NumRows = b.MakeIAdd(UIntTy, YValue, RowsAlongZ);
157 Blk->addInstruction(NumRows);
158
159 auto NumCellsFromYZ = b.MakeIMul(UIntTy, GlobalSizeX, NumRows);
160 Blk->addInstruction(NumCellsFromYZ);
161
162 Index = b.MakeIAdd(UIntTy, NumCellsFromYZ, XValue);
163 Blk->addInstruction(Index);
164 }
165
166 std::vector<IdRef> inputs;
167
168 ConstantInst *ConstZero = m->getConstant(UIntTy, 0);
169
170 for (uint32_t i = 0; i < numInput; i++) {
171 FunctionParameterInst *param = kernel->getParameter(i);
172 Instruction *elementType = param->mResultType.mInstruction;
173 VariableInst *inputBuffer = AddBuffer(elementType, i + 3, b, m);
174
175 TypePointerInst *PtrTy =
176 m->getPointerType(StorageClass::Function, elementType);
177 AccessChainInst *Ptr =
178 b.MakeAccessChain(PtrTy, inputBuffer, {ConstZero, Index});
179 Blk->addInstruction(Ptr);
180
181 Instruction *input = b.MakeLoad(elementType, Ptr);
182 Blk->addInstruction(input);
183
184 inputs.push_back(IdRef(input));
185 }
186
187 // TODO: Convert from unsigned int to signed int if that is what the kernel
188 // function takes for the coordinate parameters
189 if (MetadataExtractor::hasForEachSignatureX(signature)) {
190 inputs.push_back(XValue);
191 if (MetadataExtractor::hasForEachSignatureY(signature)) {
192 inputs.push_back(YValue);
193 if (MetadataExtractor::hasForEachSignatureZ(signature)) {
194 inputs.push_back(ZValue);
195 }
196 }
197 }
198
199 auto resultType = kernel->getReturnType();
200 auto kernelCall =
201 b.MakeFunctionCall(resultType, kernel->getInstruction(), inputs);
202 Blk->addInstruction(kernelCall);
203
204 if (MetadataExtractor::hasForEachSignatureOut(signature)) {
205 VariableInst *OutputBuffer = AddBuffer(resultType, 2, b, m);
206 auto resultPtrType = m->getPointerType(StorageClass::Function, resultType);
207 AccessChainInst *OutPtr =
208 b.MakeAccessChain(resultPtrType, OutputBuffer, {ConstZero, Index});
209 Blk->addInstruction(OutPtr);
210 Blk->addInstruction(b.MakeStore(OutPtr, kernelCall));
211 }
212
213 Blk->addInstruction(b.MakeReturn());
214
215 std::string wrapperName("entry_");
216 wrapperName.append(name);
217
218 EntryPointDefinition *entry = b.MakeEntryPointDefinition(
219 ExecutionModel::GLCompute, Func, wrapperName.c_str());
220
221 entry->setLocalSize(1, 1, 1);
222
223 if (Index != nullptr) {
224 entry->addToInterface(InvocationId);
225 entry->addToInterface(NumWorkgroups);
226 }
227
228 m->addEntryPoint(entry);
229
230 return true;
231 }
232
DecorateGlobalBuffer(llvm::Module & LM,Builder & b,Module * m)233 bool DecorateGlobalBuffer(llvm::Module &LM, Builder &b, Module *m) {
234 Instruction *inst = m->lookupByName("__GPUBlock");
235 if (inst == nullptr) {
236 return true;
237 }
238
239 VariableInst *bufferVar = static_cast<VariableInst *>(inst);
240 bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
241 bufferVar->decorate(Decoration::Binding)->addExtraOperand(0);
242
243 TypePointerInst *StructPtrTy =
244 static_cast<TypePointerInst *>(bufferVar->mResultType.mInstruction);
245 TypeStructInst *StructTy =
246 static_cast<TypeStructInst *>(StructPtrTy->mOperand2.mInstruction);
247 StructTy->decorate(Decoration::BufferBlock);
248
249 // Decorate each member with proper offsets
250
251 const auto GlobalsB = LM.globals().begin();
252 const auto GlobalsE = LM.globals().end();
253 const auto Found =
254 std::find_if(GlobalsB, GlobalsE, [](const llvm::GlobalVariable &GV) {
255 return GV.getName() == "__GPUBlock";
256 });
257
258 if (Found == GlobalsE) {
259 return true; // GPUBlock not found - not an error by itself.
260 }
261
262 const llvm::GlobalVariable &G = *Found;
263
264 rs2spirv::Context &Ctxt = rs2spirv::Context::getInstance();
265 bool IsCorrectTy = false;
266 if (const auto *LPtrTy = llvm::dyn_cast<llvm::PointerType>(G.getType())) {
267 if (auto *LStructTy =
268 llvm::dyn_cast<llvm::StructType>(LPtrTy->getElementType())) {
269 IsCorrectTy = true;
270
271 const auto &DLayout = LM.getDataLayout();
272 const auto *SLayout = DLayout.getStructLayout(LStructTy);
273 assert(SLayout);
274 if (SLayout == nullptr) {
275 std::cerr << "struct layout is null" << std::endl;
276 return false;
277 }
278 std::vector<uint32_t> offsets;
279 for (uint32_t i = 0, e = LStructTy->getNumElements(); i != e; ++i) {
280 auto decor = StructTy->memberDecorate(i, Decoration::Offset);
281 if (!decor) {
282 std::cerr << "failed creating member decoration for field " << i
283 << std::endl;
284 return false;
285 }
286 const uint32_t offset = (uint32_t)SLayout->getElementOffset(i);
287 decor->addExtraOperand(offset);
288 offsets.push_back(offset);
289 }
290 std::stringstream ssOffsets;
291 // TODO: define this string in a central place
292 ssOffsets << ".rsov.ExportedVars:";
293 for(uint32_t slot = 0; slot < Ctxt.getNumExportVar(); slot++) {
294 const uint32_t index = Ctxt.getExportVarIndex(slot);
295 const uint32_t offset = offsets[index];
296 ssOffsets << offset << ';';
297 }
298 m->addString(ssOffsets.str().c_str());
299
300 std::stringstream ssGlobalSize;
301 ssGlobalSize << ".rsov.GlobalSize:" << Ctxt.getGlobalSize();
302 m->addString(ssGlobalSize.str().c_str());
303 }
304 }
305
306 if (!IsCorrectTy) {
307 return false;
308 }
309
310 llvm::SmallVector<rs2spirv::RSAllocationInfo, 2> RSAllocs;
311 if (!getRSAllocationInfo(LM, RSAllocs)) {
312 // llvm::errs() << "Extracting rs_allocation info failed\n";
313 return true;
314 }
315
316 // TODO: clean up the binding number assignment
317 size_t BindingNum = 3;
318 for (const auto &A : RSAllocs) {
319 Instruction *inst = m->lookupByName(A.VarName.c_str());
320 if (inst == nullptr) {
321 return false;
322 }
323 VariableInst *bufferVar = static_cast<VariableInst *>(inst);
324 bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
325 bufferVar->decorate(Decoration::Binding)->addExtraOperand(BindingNum++);
326 }
327
328 return true;
329 }
330
AddHeader(Module * m)331 void AddHeader(Module *m) {
332 m->addCapability(Capability::Shader);
333 m->setMemoryModel(AddressingModel::Logical, MemoryModel::GLSL450);
334
335 m->addSource(SourceLanguage::GLSL, 450);
336 m->addSourceExtension("GL_ARB_separate_shader_objects");
337 m->addSourceExtension("GL_ARB_shading_language_420pack");
338 m->addSourceExtension("GL_GOOGLE_cpp_style_line_directive");
339 m->addSourceExtension("GL_GOOGLE_include_directive");
340 }
341
342 namespace {
343
344 class StorageClassVisitor : public DoNothingVisitor {
345 public:
visit(TypePointerInst * inst)346 void visit(TypePointerInst *inst) override {
347 matchAndReplace(inst->mOperand1);
348 }
349
visit(TypeForwardPointerInst * inst)350 void visit(TypeForwardPointerInst *inst) override {
351 matchAndReplace(inst->mOperand2);
352 }
353
visit(VariableInst * inst)354 void visit(VariableInst *inst) override { matchAndReplace(inst->mOperand1); }
355
356 private:
matchAndReplace(StorageClass & storage)357 void matchAndReplace(StorageClass &storage) {
358 if (storage == StorageClass::Function) {
359 storage = StorageClass::Uniform;
360 }
361 }
362 };
363
FixGlobalStorageClass(Module * m)364 void FixGlobalStorageClass(Module *m) {
365 StorageClassVisitor v;
366 m->getGlobalSection()->accept(&v);
367 }
368
369 } // anonymous namespace
370
AddWrappers(llvm::Module & LM,android::spirit::Module * m)371 bool AddWrappers(llvm::Module &LM,
372 android::spirit::Module *m) {
373 rs2spirv::Context &Ctxt = rs2spirv::Context::getInstance();
374 const bcinfo::MetadataExtractor &metadata = Ctxt.getMetadata();
375 android::spirit::Builder b;
376
377 m->setBuilder(&b);
378
379 FixGlobalStorageClass(m);
380
381 AddHeader(m);
382
383 DecorateGlobalBuffer(LM, b, m);
384
385 const size_t numKernel = metadata.getExportForEachSignatureCount();
386 const char **kernelName = metadata.getExportForEachNameList();
387 const uint32_t *kernelSigature = metadata.getExportForEachSignatureList();
388 const uint32_t *inputCount = metadata.getExportForEachInputCountList();
389
390 for (size_t i = 0; i < numKernel; i++) {
391 bool success =
392 AddWrapper(kernelName[i], kernelSigature[i], inputCount[i], b, m);
393 if (!success) {
394 return false;
395 }
396 }
397
398 m->consolidateAnnotations();
399 return true;
400 }
401
402 class WrapperPass : public Pass {
403 public:
WrapperPass(const llvm::Module & LM)404 WrapperPass(const llvm::Module &LM) : mLLVMModule(const_cast<llvm::Module&>(LM)) {}
405
run(Module * m,int * error)406 Module *run(Module *m, int *error) override {
407 bool success = AddWrappers(mLLVMModule, m);
408 if (error) {
409 *error = success ? 0 : -1;
410 }
411 return m;
412 }
413
414 private:
415 llvm::Module &mLLVMModule;
416 };
417
418 } // namespace spirit
419 } // namespace android
420
421 namespace rs2spirv {
422
CreateWrapperPass(const llvm::Module & LLVMModule)423 android::spirit::Pass* CreateWrapperPass(const llvm::Module &LLVMModule) {
424 return new android::spirit::WrapperPass(LLVMModule);
425 }
426
427 } // namespace rs2spirv
428