• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2012, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "Config.h"
18 #include "bcc/bcc_assert.h"
19 
20 #include "DebugHelper.h"
21 
22 #include "llvm/DerivedTypes.h"
23 #include "llvm/Function.h"
24 #include "llvm/Instructions.h"
25 #include "llvm/Module.h"
26 #include "llvm/Pass.h"
27 #include "llvm/Type.h"
28 #include "llvm/Support/IRBuilder.h"
29 
30 namespace {
31   /* ForEachExpandPass - This pass operates on functions that are able to be
32    * called via rsForEach() or "foreach_<NAME>". We create an inner loop for
33    * the ForEach-able function to be invoked over the appropriate data cells
34    * of the input/output allocations (adjusting other relevant parameters as
35    * we go). We support doing this for any ForEach-able compute kernels.
36    * The new function name is the original function name followed by
37    * ".expand". Note that we still generate code for the original function.
38    */
39   class ForEachExpandPass : public llvm::ModulePass {
40   private:
41   static char ID;
42 
43   llvm::Module *M;
44   llvm::LLVMContext *C;
45 
46   std::vector<std::string>& mNames;
47   std::vector<uint32_t>& mSignatures;
48 
getRootSignature(llvm::Function * F)49   uint32_t getRootSignature(llvm::Function *F) {
50     const llvm::NamedMDNode *ExportForEachMetadata =
51         M->getNamedMetadata("#rs_export_foreach");
52 
53     if (!ExportForEachMetadata) {
54       llvm::SmallVector<llvm::Type*, 8> RootArgTys;
55       for (llvm::Function::arg_iterator B = F->arg_begin(),
56                                         E = F->arg_end();
57            B != E;
58            ++B) {
59         RootArgTys.push_back(B->getType());
60       }
61 
62       // For pre-ICS bitcode, we may not have signature information. In that
63       // case, we use the size of the RootArgTys to select the number of
64       // arguments.
65       return (1 << RootArgTys.size()) - 1;
66     }
67 
68     bccAssert(ExportForEachMetadata->getNumOperands() > 0);
69 
70     // We only handle the case for legacy root() functions here, so this is
71     // hard-coded to look at only the first such function.
72     llvm::MDNode *SigNode = ExportForEachMetadata->getOperand(0);
73     if (SigNode != NULL && SigNode->getNumOperands() == 1) {
74       llvm::Value *SigVal = SigNode->getOperand(0);
75       if (SigVal->getValueID() == llvm::Value::MDStringVal) {
76         llvm::StringRef SigString =
77             static_cast<llvm::MDString*>(SigVal)->getString();
78         uint32_t Signature = 0;
79         if (SigString.getAsInteger(10, Signature)) {
80           ALOGE("Non-integer signature value '%s'", SigString.str().c_str());
81           return 0;
82         }
83         return Signature;
84       }
85     }
86 
87     return 0;
88   }
89 
hasIn(uint32_t Signature)90   static bool hasIn(uint32_t Signature) {
91     return Signature & 1;
92   }
93 
hasOut(uint32_t Signature)94   static bool hasOut(uint32_t Signature) {
95     return Signature & 2;
96   }
97 
hasUsrData(uint32_t Signature)98   static bool hasUsrData(uint32_t Signature) {
99     return Signature & 4;
100   }
101 
hasX(uint32_t Signature)102   static bool hasX(uint32_t Signature) {
103     return Signature & 8;
104   }
105 
hasY(uint32_t Signature)106   static bool hasY(uint32_t Signature) {
107     return Signature & 16;
108   }
109 
110   public:
ForEachExpandPass(std::vector<std::string> & Names,std::vector<uint32_t> & Signatures)111   ForEachExpandPass(std::vector<std::string>& Names,
112                     std::vector<uint32_t>& Signatures)
113       : ModulePass(ID), M(NULL), C(NULL), mNames(Names),
114         mSignatures(Signatures) {
115   }
116 
117   /* Performs the actual optimization on a selected function. On success, the
118    * Module will contain a new function of the name "<NAME>.expand" that
119    * invokes <NAME>() in a loop with the appropriate parameters.
120    */
ExpandFunction(llvm::Function * F,uint32_t Signature)121   bool ExpandFunction(llvm::Function *F, uint32_t Signature) {
122     ALOGV("Expanding ForEach-able Function %s", F->getName().str().c_str());
123 
124     if (!Signature) {
125       Signature = getRootSignature(F);
126       if (!Signature) {
127         // We couldn't determine how to expand this function based on its
128         // function signature.
129         return false;
130       }
131     }
132 
133     llvm::Type *VoidPtrTy = llvm::Type::getInt8PtrTy(*C);
134     llvm::Type *Int32Ty = llvm::Type::getInt32Ty(*C);
135     llvm::Type *SizeTy = Int32Ty;
136 
137     /* Defined in frameworks/base/libs/rs/rs_hal.h:
138      *
139      * struct RsForEachStubParamStruct {
140      *   const void *in;
141      *   void *out;
142      *   const void *usr;
143      *   size_t usr_len;
144      *   uint32_t x;
145      *   uint32_t y;
146      *   uint32_t z;
147      *   uint32_t lod;
148      *   enum RsAllocationCubemapFace face;
149      *   uint32_t ar[16];
150      * };
151      */
152     llvm::SmallVector<llvm::Type*, 9> StructTys;
153     StructTys.push_back(VoidPtrTy);  // const void *in
154     StructTys.push_back(VoidPtrTy);  // void *out
155     StructTys.push_back(VoidPtrTy);  // const void *usr
156     StructTys.push_back(SizeTy);     // size_t usr_len
157     StructTys.push_back(Int32Ty);    // uint32_t x
158     StructTys.push_back(Int32Ty);    // uint32_t y
159     StructTys.push_back(Int32Ty);    // uint32_t z
160     StructTys.push_back(Int32Ty);    // uint32_t lod
161     StructTys.push_back(Int32Ty);    // enum RsAllocationCubemapFace
162     StructTys.push_back(llvm::ArrayType::get(Int32Ty, 16));  // uint32_t ar[16]
163 
164     llvm::Type *ForEachStubPtrTy = llvm::StructType::create(
165         StructTys, "RsForEachStubParamStruct")->getPointerTo();
166 
167     /* Create the function signature for our expanded function.
168      * void (const RsForEachStubParamStruct *p, uint32_t x1, uint32_t x2,
169      *       uint32_t instep, uint32_t outstep)
170      */
171     llvm::SmallVector<llvm::Type*, 8> ParamTys;
172     ParamTys.push_back(ForEachStubPtrTy);  // const RsForEachStubParamStruct *p
173     ParamTys.push_back(Int32Ty);           // uint32_t x1
174     ParamTys.push_back(Int32Ty);           // uint32_t x2
175     ParamTys.push_back(Int32Ty);           // uint32_t instep
176     ParamTys.push_back(Int32Ty);           // uint32_t outstep
177 
178     llvm::FunctionType *FT =
179         llvm::FunctionType::get(llvm::Type::getVoidTy(*C), ParamTys, false);
180     llvm::Function *ExpandedFunc =
181         llvm::Function::Create(FT,
182                                llvm::GlobalValue::ExternalLinkage,
183                                F->getName() + ".expand", M);
184 
185     // Create and name the actual arguments to this expanded function.
186     llvm::SmallVector<llvm::Argument*, 8> ArgVec;
187     for (llvm::Function::arg_iterator B = ExpandedFunc->arg_begin(),
188                                       E = ExpandedFunc->arg_end();
189          B != E;
190          ++B) {
191       ArgVec.push_back(B);
192     }
193 
194     if (ArgVec.size() != 5) {
195       ALOGE("Incorrect number of arguments to function: %zu",
196             ArgVec.size());
197       return false;
198     }
199     llvm::Value *Arg_p = ArgVec[0];
200     llvm::Value *Arg_x1 = ArgVec[1];
201     llvm::Value *Arg_x2 = ArgVec[2];
202     llvm::Value *Arg_instep = ArgVec[3];
203     llvm::Value *Arg_outstep = ArgVec[4];
204 
205     Arg_p->setName("p");
206     Arg_x1->setName("x1");
207     Arg_x2->setName("x2");
208     Arg_instep->setName("instep");
209     Arg_outstep->setName("outstep");
210 
211     // Construct the actual function body.
212     llvm::BasicBlock *Begin =
213         llvm::BasicBlock::Create(*C, "Begin", ExpandedFunc);
214     llvm::IRBuilder<> Builder(Begin);
215 
216     // uint32_t X = x1;
217     llvm::AllocaInst *AX = Builder.CreateAlloca(Int32Ty, 0, "AX");
218     Builder.CreateStore(Arg_x1, AX);
219 
220     // Collect and construct the arguments for the kernel().
221     // Note that we load any loop-invariant arguments before entering the Loop.
222     llvm::Function::arg_iterator Args = F->arg_begin();
223 
224     llvm::Type *InTy = NULL;
225     llvm::AllocaInst *AIn = NULL;
226     if (hasIn(Signature)) {
227       InTy = Args->getType();
228       AIn = Builder.CreateAlloca(InTy, 0, "AIn");
229       Builder.CreateStore(Builder.CreatePointerCast(Builder.CreateLoad(
230           Builder.CreateStructGEP(Arg_p, 0)), InTy), AIn);
231       Args++;
232     }
233 
234     llvm::Type *OutTy = NULL;
235     llvm::AllocaInst *AOut = NULL;
236     if (hasOut(Signature)) {
237       OutTy = Args->getType();
238       AOut = Builder.CreateAlloca(OutTy, 0, "AOut");
239       Builder.CreateStore(Builder.CreatePointerCast(Builder.CreateLoad(
240           Builder.CreateStructGEP(Arg_p, 1)), OutTy), AOut);
241       Args++;
242     }
243 
244     llvm::Value *UsrData = NULL;
245     if (hasUsrData(Signature)) {
246       llvm::Type *UsrDataTy = Args->getType();
247       UsrData = Builder.CreatePointerCast(Builder.CreateLoad(
248           Builder.CreateStructGEP(Arg_p, 2)), UsrDataTy);
249       UsrData->setName("UsrData");
250       Args++;
251     }
252 
253     if (hasX(Signature)) {
254       Args++;
255     }
256 
257     llvm::Value *Y = NULL;
258     if (hasY(Signature)) {
259       Y = Builder.CreateLoad(Builder.CreateStructGEP(Arg_p, 5), "Y");
260       Args++;
261     }
262 
263     bccAssert(Args == F->arg_end());
264 
265     llvm::BasicBlock *Loop = llvm::BasicBlock::Create(*C, "Loop", ExpandedFunc);
266     llvm::BasicBlock *Exit = llvm::BasicBlock::Create(*C, "Exit", ExpandedFunc);
267 
268     // if (x1 < x2) goto Loop; else goto Exit;
269     llvm::Value *Cond = Builder.CreateICmpSLT(Arg_x1, Arg_x2);
270     Builder.CreateCondBr(Cond, Loop, Exit);
271 
272     // Loop:
273     Builder.SetInsertPoint(Loop);
274 
275     // Populate the actual call to kernel().
276     llvm::SmallVector<llvm::Value*, 8> RootArgs;
277 
278     llvm::Value *In = NULL;
279     llvm::Value *Out = NULL;
280 
281     if (AIn) {
282       In = Builder.CreateLoad(AIn, "In");
283       RootArgs.push_back(In);
284     }
285 
286     if (AOut) {
287       Out = Builder.CreateLoad(AOut, "Out");
288       RootArgs.push_back(Out);
289     }
290 
291     if (UsrData) {
292       RootArgs.push_back(UsrData);
293     }
294 
295     // We always have to load X, since it is used to iterate through the loop.
296     llvm::Value *X = Builder.CreateLoad(AX, "X");
297     if (hasX(Signature)) {
298       RootArgs.push_back(X);
299     }
300 
301     if (Y) {
302       RootArgs.push_back(Y);
303     }
304 
305     Builder.CreateCall(F, RootArgs);
306 
307     if (In) {
308       // In += instep
309       llvm::Value *NewIn = Builder.CreateIntToPtr(Builder.CreateNUWAdd(
310           Builder.CreatePtrToInt(In, Int32Ty), Arg_instep), InTy);
311       Builder.CreateStore(NewIn, AIn);
312     }
313 
314     if (Out) {
315       // Out += outstep
316       llvm::Value *NewOut = Builder.CreateIntToPtr(Builder.CreateNUWAdd(
317           Builder.CreatePtrToInt(Out, Int32Ty), Arg_outstep), OutTy);
318       Builder.CreateStore(NewOut, AOut);
319     }
320 
321     // X++;
322     llvm::Value *XPlusOne =
323         Builder.CreateNUWAdd(X, llvm::ConstantInt::get(Int32Ty, 1));
324     Builder.CreateStore(XPlusOne, AX);
325 
326     // If (X < x2) goto Loop; else goto Exit;
327     Cond = Builder.CreateICmpSLT(XPlusOne, Arg_x2);
328     Builder.CreateCondBr(Cond, Loop, Exit);
329 
330     // Exit:
331     Builder.SetInsertPoint(Exit);
332     Builder.CreateRetVoid();
333 
334     return true;
335   }
336 
runOnModule(llvm::Module & M)337   virtual bool runOnModule(llvm::Module &M) {
338     bool Changed = false;
339     this->M = &M;
340     C = &M.getContext();
341 
342     bccAssert(mNames.size() == mSignatures.size());
343     for (int i = 0, e = mNames.size(); i != e; i++) {
344       llvm::Function *kernel = M.getFunction(mNames[i]);
345       if (kernel && kernel->getReturnType()->isVoidTy()) {
346         Changed |= ExpandFunction(kernel, mSignatures[i]);
347       }
348     }
349 
350     return Changed;
351   }
352 
getPassName() const353   virtual const char *getPassName() const {
354     return "ForEach-able Function Expansion";
355   }
356 
357   };
358 }  // end anonymous namespace
359 
360 char ForEachExpandPass::ID = 0;
361 
362 namespace bcc {
363 
createForEachExpandPass(std::vector<std::string> & Names,std::vector<uint32_t> & Signatures)364   llvm::ModulePass *createForEachExpandPass(std::vector<std::string>& Names,
365                                             std::vector<uint32_t>& Signatures) {
366     return new ForEachExpandPass(Names, Signatures);
367   }
368 
369 }  // namespace bcc
370