• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2015, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "slang_rs_foreach_lowering.h"
18 
19 #include "clang/AST/ASTContext.h"
20 #include "clang/AST/Attr.h"
21 #include "llvm/Support/raw_ostream.h"
22 #include "slang_rs_context.h"
23 #include "slang_rs_export_foreach.h"
24 
25 namespace slang {
26 
namespace {

// Names of the user-facing kernel launch APIs that this pass recognizes and
// lowers.
constexpr char KERNEL_LAUNCH_FUNCTION_NAME[] = "rsForEach";
constexpr char KERNEL_LAUNCH_FUNCTION_NAME_WITH_OPTIONS[] = "rsForEachWithOptions";

// Mangled symbol of the low-level launch function that recognized calls are
// rewritten into. The mangling encodes the parameter list
// (int, rs_script_call*, int, int, rs_allocation*).
constexpr char INTERNAL_LAUNCH_FUNCTION_NAME[] =
    "_Z17rsForEachInternaliP14rs_script_calliiP13rs_allocation";

}  // anonymous namespace
35 
RSForEachLowering(RSContext * ctxt)36 RSForEachLowering::RSForEachLowering(RSContext* ctxt)
37     : mCtxt(ctxt), mASTCtxt(ctxt->getASTContext()) {}
38 
39 // Check if the passed-in expr references a kernel function in the following
40 // pattern in the AST.
41 //
42 // ImplicitCastExpr 'void *' <BitCast>
43 //  `-ImplicitCastExpr 'int (*)(int)' <FunctionToPointerDecay>
44 //    `-DeclRefExpr 'int (int)' Function 'foo' 'int (int)'
matchFunctionDesignator(clang::Expr * expr)45 const clang::FunctionDecl* RSForEachLowering::matchFunctionDesignator(
46     clang::Expr* expr) {
47   clang::ImplicitCastExpr* ToVoidPtr =
48       clang::dyn_cast<clang::ImplicitCastExpr>(expr);
49   if (ToVoidPtr == nullptr) {
50     return nullptr;
51   }
52 
53   clang::ImplicitCastExpr* Decay =
54       clang::dyn_cast<clang::ImplicitCastExpr>(ToVoidPtr->getSubExpr());
55 
56   if (Decay == nullptr) {
57     return nullptr;
58   }
59 
60   clang::DeclRefExpr* DRE =
61       clang::dyn_cast<clang::DeclRefExpr>(Decay->getSubExpr());
62 
63   if (DRE == nullptr) {
64     return nullptr;
65   }
66 
67   const clang::FunctionDecl* FD =
68       clang::dyn_cast<clang::FunctionDecl>(DRE->getDecl());
69 
70   if (FD == nullptr) {
71     return nullptr;
72   }
73 
74   return FD;
75 }
76 
77 // Checks if the call expression is a legal rsForEach call by looking for the
78 // following pattern in the AST. On success, returns the first argument that is
79 // a FunctionDecl of a kernel function.
80 //
81 // CallExpr 'void'
82 // |
83 // |-ImplicitCastExpr 'void (*)(void *, ...)' <FunctionToPointerDecay>
84 // | `-DeclRefExpr  'void (void *, ...)'  'rsForEach' 'void (void *, ...)'
85 // |
86 // |-ImplicitCastExpr 'void *' <BitCast>
87 // | `-ImplicitCastExpr 'int (*)(int)' <FunctionToPointerDecay>
88 // |   `-DeclRefExpr 'int (int)' Function 'foo' 'int (int)'
89 // |
90 // |-ImplicitCastExpr 'rs_allocation':'rs_allocation' <LValueToRValue>
91 // | `-DeclRefExpr 'rs_allocation':'rs_allocation' lvalue ParmVar 'in' 'rs_allocation':'rs_allocation'
92 // |
93 // `-ImplicitCastExpr 'rs_allocation':'rs_allocation' <LValueToRValue>
94 //   `-DeclRefExpr  'rs_allocation':'rs_allocation' lvalue ParmVar 'out' 'rs_allocation':'rs_allocation'
matchKernelLaunchCall(clang::CallExpr * CE,int * slot,bool * hasOptions)95 const clang::FunctionDecl* RSForEachLowering::matchKernelLaunchCall(
96     clang::CallExpr* CE, int* slot, bool* hasOptions) {
97   const clang::Decl* D = CE->getCalleeDecl();
98   const clang::FunctionDecl* FD = clang::dyn_cast<clang::FunctionDecl>(D);
99 
100   if (FD == nullptr) {
101     return nullptr;
102   }
103 
104   const clang::StringRef& funcName = FD->getName();
105 
106   if (funcName.equals(KERNEL_LAUNCH_FUNCTION_NAME)) {
107     *hasOptions = false;
108   } else if (funcName.equals(KERNEL_LAUNCH_FUNCTION_NAME_WITH_OPTIONS)) {
109     *hasOptions = true;
110   } else {
111     return nullptr;
112   }
113 
114   if (mInsideKernel) {
115     mCtxt->ReportError(CE->getExprLoc(),
116         "Invalid kernel launch call made from inside another kernel.");
117     return nullptr;
118   }
119 
120   clang::Expr* arg0 = CE->getArg(0);
121   const clang::FunctionDecl* kernel = matchFunctionDesignator(arg0);
122 
123   if (kernel == nullptr) {
124     mCtxt->ReportError(arg0->getExprLoc(),
125                        "Invalid kernel launch call. "
126                        "Expects a function designator for the first argument.");
127     return nullptr;
128   }
129 
130   // Verifies that kernel is indeed a "kernel" function.
131   *slot = mCtxt->getForEachSlotNumber(kernel);
132   if (*slot == -1) {
133     mCtxt->ReportError(CE->getExprLoc(),
134          "%0 applied to function %1 defined without \"kernel\" attribute")
135          << funcName << kernel->getName();
136     return nullptr;
137   }
138 
139   return kernel;
140 }
141 
142 // Create an AST node for the declaration of rsForEachInternal
CreateForEachInternalFunctionDecl()143 clang::FunctionDecl* RSForEachLowering::CreateForEachInternalFunctionDecl() {
144   clang::DeclContext* DC = mASTCtxt.getTranslationUnitDecl();
145   clang::SourceLocation Loc;
146 
147   llvm::StringRef SR(INTERNAL_LAUNCH_FUNCTION_NAME);
148   clang::IdentifierInfo& II = mASTCtxt.Idents.get(SR);
149   clang::DeclarationName N(&II);
150 
151   clang::FunctionProtoType::ExtProtoInfo EPI;
152 
153   const clang::QualType& AllocTy = mCtxt->getAllocationType();
154   clang::QualType AllocPtrTy = mASTCtxt.getPointerType(AllocTy);
155 
156   clang::QualType ScriptCallTy = mCtxt->getScriptCallType();
157   const clang::QualType ScriptCallPtrTy = mASTCtxt.getPointerType(ScriptCallTy);
158 
159   clang::QualType ParamTypes[] = {
160     mASTCtxt.IntTy,   // int slot
161     ScriptCallPtrTy,  // rs_script_call_t* launch_options
162     mASTCtxt.IntTy,   // int numOutput
163     mASTCtxt.IntTy,   // int numInputs
164     AllocPtrTy        // rs_allocation* allocs
165   };
166 
167   clang::QualType T = mASTCtxt.getFunctionType(
168       mASTCtxt.VoidTy,  // Return type
169       ParamTypes,       // Parameter types
170       EPI);
171 
172   clang::FunctionDecl* FD = clang::FunctionDecl::Create(
173       mASTCtxt, DC, Loc, Loc, N, T, nullptr, clang::SC_Extern);
174 
175   static constexpr unsigned kNumParams = sizeof(ParamTypes) / sizeof(ParamTypes[0]);
176   clang::ParmVarDecl *ParamDecls[kNumParams];
177   for (unsigned I = 0; I != kNumParams; ++I) {
178     ParamDecls[I] = clang::ParmVarDecl::Create(mASTCtxt, FD, Loc,
179         Loc, nullptr, ParamTypes[I], nullptr, clang::SC_None, nullptr);
180     // Implicit means that this declaration was created by the compiler, and
181     // not part of the actual source code.
182     ParamDecls[I]->setImplicit();
183   }
184   FD->setParams(llvm::makeArrayRef(ParamDecls, kNumParams));
185 
186   // Implicit means that this declaration was created by the compiler, and
187   // not part of the actual source code.
188   FD->setImplicit();
189 
190   return FD;
191 }
192 
193 // Create an expression like the following that references the rsForEachInternal to
194 // replace the callee in the original call expression that references rsForEach.
195 //
196 // ImplicitCastExpr 'void (*)(int, rs_script_call_t*, int, int, rs_allocation*)' <FunctionToPointerDecay>
197 // `-DeclRefExpr 'void' Function '_Z17rsForEachInternaliP14rs_script_calliiP13rs_allocation' 'void (int, rs_script_call_t*, int, int, rs_allocation*)'
CreateCalleeExprForInternalForEach()198 clang::Expr* RSForEachLowering::CreateCalleeExprForInternalForEach() {
199   clang::FunctionDecl* FDNew = CreateForEachInternalFunctionDecl();
200 
201   const clang::QualType FDNewType = FDNew->getType();
202 
203   clang::DeclRefExpr* refExpr = clang::DeclRefExpr::Create(
204       mASTCtxt, clang::NestedNameSpecifierLoc(), clang::SourceLocation(), FDNew,
205       false, clang::SourceLocation(), FDNewType, clang::VK_RValue);
206 
207   clang::Expr* calleeNew = clang::ImplicitCastExpr::Create(
208       mASTCtxt, mASTCtxt.getPointerType(FDNewType),
209       clang::CK_FunctionToPointerDecay, refExpr, nullptr, clang::VK_RValue);
210 
211   return calleeNew;
212 }
213 
214 // This visit method checks (via pattern matching) if the call expression is to
215 // rsForEach, and the arguments satisfy the restrictions on the
216 // rsForEach API. If so, replace the call with a rsForEachInternal call
217 // with the first argument replaced by the slot number of the kernel function
218 // referenced in the original first argument.
219 //
220 // See comments to the helper methods defined above for details.
VisitCallExpr(clang::CallExpr * CE)221 void RSForEachLowering::VisitCallExpr(clang::CallExpr* CE) {
222   int slot;
223   bool hasOptions;
224   const clang::FunctionDecl* kernel = matchKernelLaunchCall(CE, &slot, &hasOptions);
225   if (kernel == nullptr) {
226     return;
227   }
228 
229   slangAssert(slot >= 0);
230 
231   const unsigned numArgsOrig = CE->getNumArgs();
232 
233   clang::QualType resultType = kernel->getReturnType().getCanonicalType();
234   const unsigned numOutputsExpected = resultType->isVoidType() ? 0 : 1;
235 
236   const unsigned numInputsExpected = RSExportForEach::getNumInputs(mCtxt->getTargetAPI(), kernel);
237 
238   // Verifies that rsForEach takes the right number of input and output allocations.
239   // TODO: Check input/output allocation types match kernel function expectation.
240   const unsigned numAllocations = numArgsOrig - (hasOptions ? 2 : 1);
241   if (numInputsExpected + numOutputsExpected != numAllocations) {
242     mCtxt->ReportError(
243       CE->getExprLoc(),
244       "Number of input and output allocations unexpected for kernel function %0")
245     << kernel->getName();
246     return;
247   }
248 
249   clang::Expr* calleeNew = CreateCalleeExprForInternalForEach();
250   CE->setCallee(calleeNew);
251 
252   const clang::CanQualType IntTy = mASTCtxt.IntTy;
253   const unsigned IntTySize = mASTCtxt.getTypeSize(IntTy);
254   const llvm::APInt APIntSlot(IntTySize, slot);
255   const clang::Expr* arg0 = CE->getArg(0);
256   const clang::SourceLocation Loc(arg0->getLocStart());
257   clang::Expr* IntSlotNum =
258       clang::IntegerLiteral::Create(mASTCtxt, APIntSlot, IntTy, Loc);
259   CE->setArg(0, IntSlotNum);
260 
261   /*
262     The last few arguments to rsForEach or rsForEachWithOptions are allocations.
263     Creates a new compound literal of an array initialized with those values, and
264     passes it to rsForEachInternal as the last (the 5th) argument.
265 
266     For example, rsForEach(foo, ain1, ain2, aout) would be translated into
267     rsForEachInternal(
268         1,                                   // Slot number for kernel
269         NULL,                                // Launch options
270         2,                                   // Number of input allocations
271         1,                                   // Number of output allocations
272         (rs_allocation[]){ain1, ain2, aout)  // Input and output allocations
273     );
274 
275     The AST for the rs_allocation array looks like following:
276 
277     ImplicitCastExpr 0x99575670 'struct rs_allocation *' <ArrayToPointerDecay>
278     `-CompoundLiteralExpr 0x99575648 'struct rs_allocation [3]' lvalue
279       `-InitListExpr 0x99575590 'struct rs_allocation [3]'
280       |-ImplicitCastExpr 0x99574b38 'rs_allocation':'struct rs_allocation' <LValueToRValue>
281       | `-DeclRefExpr 0x99574a08 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c408 'ain1' 'rs_allocation':'struct rs_allocation'
282       |-ImplicitCastExpr 0x99574b50 'rs_allocation':'struct rs_allocation' <LValueToRValue>
283       | `-DeclRefExpr 0x99574a30 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c478 'ain2' 'rs_allocation':'struct rs_allocation'
284       `-ImplicitCastExpr 0x99574b68 'rs_allocation':'struct rs_allocation' <LValueToRValue>
285         `-DeclRefExpr 0x99574a58 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c478 'aout' 'rs_allocation':'struct rs_allocation'
286   */
287 
288   const clang::QualType& AllocTy = mCtxt->getAllocationType();
289   const llvm::APInt APIntNumAllocs(IntTySize, numAllocations);
290   clang::QualType AllocArrayTy = mASTCtxt.getConstantArrayType(
291       AllocTy,
292       APIntNumAllocs,
293       clang::ArrayType::ArraySizeModifier::Normal,
294       0  // index type qualifiers
295   );
296 
297   const int allocArgIndexEnd = numArgsOrig - 1;
298   int allocArgIndexStart = allocArgIndexEnd;
299 
300   clang::Expr** args = CE->getArgs();
301 
302   clang::SourceLocation lparenloc;
303   clang::SourceLocation rparenloc;
304 
305   if (numAllocations > 0) {
306     allocArgIndexStart = hasOptions ? 2 : 1;
307     lparenloc = args[allocArgIndexStart]->getExprLoc();
308     rparenloc = args[allocArgIndexEnd]->getExprLoc();
309   }
310 
311   clang::InitListExpr* init = new (mASTCtxt) clang::InitListExpr(
312       mASTCtxt,
313       lparenloc,
314       llvm::ArrayRef<clang::Expr*>(args + allocArgIndexStart, numAllocations),
315       rparenloc);
316   init->setType(AllocArrayTy);
317 
318   clang::TypeSourceInfo* ti = mASTCtxt.getTrivialTypeSourceInfo(AllocArrayTy);
319   clang::CompoundLiteralExpr* CLE = new (mASTCtxt) clang::CompoundLiteralExpr(
320       lparenloc,
321       ti,
322       AllocArrayTy,
323       clang::VK_LValue,  // A compound literal is an l-value in C.
324       init,
325       false  // Not file scope
326   );
327 
328   const clang::QualType AllocPtrTy = mASTCtxt.getPointerType(AllocTy);
329 
330   clang::ImplicitCastExpr* Decay = clang::ImplicitCastExpr::Create(
331       mASTCtxt,
332       AllocPtrTy,
333       clang::CK_ArrayToPointerDecay,
334       CLE,
335       nullptr,  // C++ cast path
336       clang::VK_RValue
337   );
338 
339   CE->setNumArgs(mASTCtxt, 5);
340 
341   CE->setArg(4, Decay);
342 
343   // Sets the new arguments for NULL launch option (if the user does not set one),
344   // the number of outputs, and the number of inputs.
345 
346   if (!hasOptions) {
347     const llvm::APInt APIntZero(IntTySize, 0);
348     clang::Expr* IntNull =
349         clang::IntegerLiteral::Create(mASTCtxt, APIntZero, IntTy, Loc);
350     clang::QualType ScriptCallTy = mCtxt->getScriptCallType();
351     const clang::QualType ScriptCallPtrTy = mASTCtxt.getPointerType(ScriptCallTy);
352     clang::CStyleCastExpr* Cast =
353         clang::CStyleCastExpr::Create(mASTCtxt,
354                                       ScriptCallPtrTy,
355                                       clang::VK_RValue,
356                                       clang::CK_NullToPointer,
357                                       IntNull,
358                                       nullptr,
359                                       mASTCtxt.getTrivialTypeSourceInfo(ScriptCallPtrTy),
360                                       clang::SourceLocation(),
361                                       clang::SourceLocation());
362     CE->setArg(1, Cast);
363   }
364 
365   const llvm::APInt APIntNumOutput(IntTySize, numOutputsExpected);
366   clang::Expr* IntNumOutput =
367       clang::IntegerLiteral::Create(mASTCtxt, APIntNumOutput, IntTy, Loc);
368   CE->setArg(2, IntNumOutput);
369 
370   const llvm::APInt APIntNumInputs(IntTySize, numInputsExpected);
371   clang::Expr* IntNumInputs =
372       clang::IntegerLiteral::Create(mASTCtxt, APIntNumInputs, IntTy, Loc);
373   CE->setArg(3, IntNumInputs);
374 }
375 
VisitStmt(clang::Stmt * S)376 void RSForEachLowering::VisitStmt(clang::Stmt* S) {
377   for (clang::Stmt* Child : S->children()) {
378     if (Child) {
379       Visit(Child);
380     }
381   }
382 }
383 
handleForEachCalls(clang::FunctionDecl * FD,unsigned int targetAPI)384 void RSForEachLowering::handleForEachCalls(clang::FunctionDecl* FD,
385                                            unsigned int targetAPI) {
386   slangAssert(FD && FD->hasBody());
387 
388   mInsideKernel = FD->hasAttr<clang::KernelAttr>();
389   VisitStmt(FD->getBody());
390 }
391 
392 }  // namespace slang
393