1 /*
2 * Copyright 2015, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "slang_rs_foreach_lowering.h"
18
19 #include "clang/AST/ASTContext.h"
20 #include "clang/AST/Attr.h"
21 #include "llvm/Support/raw_ostream.h"
22 #include "slang_rs_context.h"
23 #include "slang_rs_export_foreach.h"
24
25 namespace slang {
26
namespace {

// Names of the user-facing kernel launch functions this pass recognizes.
constexpr char KERNEL_LAUNCH_FUNCTION_NAME[] = "rsForEach";
constexpr char KERNEL_LAUNCH_FUNCTION_NAME_WITH_OPTIONS[] =
    "rsForEachWithOptions";

// Itanium-mangled spelling of
//   rsForEachInternal(int, rs_script_call*, int, int, rs_allocation*),
// used verbatim as the identifier of the internal launch function that the
// user-facing launch calls are rewritten into.
constexpr char INTERNAL_LAUNCH_FUNCTION_NAME[] =
    "_Z17rsForEachInternaliP14rs_script_calliiP13rs_allocation";

}  // anonymous namespace
35
RSForEachLowering(RSContext * ctxt)36 RSForEachLowering::RSForEachLowering(RSContext* ctxt)
37 : mCtxt(ctxt), mASTCtxt(ctxt->getASTContext()) {}
38
39 // Check if the passed-in expr references a kernel function in the following
40 // pattern in the AST.
41 //
42 // ImplicitCastExpr 'void *' <BitCast>
43 // `-ImplicitCastExpr 'int (*)(int)' <FunctionToPointerDecay>
44 // `-DeclRefExpr 'int (int)' Function 'foo' 'int (int)'
matchFunctionDesignator(clang::Expr * expr)45 const clang::FunctionDecl* RSForEachLowering::matchFunctionDesignator(
46 clang::Expr* expr) {
47 clang::ImplicitCastExpr* ToVoidPtr =
48 clang::dyn_cast<clang::ImplicitCastExpr>(expr);
49 if (ToVoidPtr == nullptr) {
50 return nullptr;
51 }
52
53 clang::ImplicitCastExpr* Decay =
54 clang::dyn_cast<clang::ImplicitCastExpr>(ToVoidPtr->getSubExpr());
55
56 if (Decay == nullptr) {
57 return nullptr;
58 }
59
60 clang::DeclRefExpr* DRE =
61 clang::dyn_cast<clang::DeclRefExpr>(Decay->getSubExpr());
62
63 if (DRE == nullptr) {
64 return nullptr;
65 }
66
67 const clang::FunctionDecl* FD =
68 clang::dyn_cast<clang::FunctionDecl>(DRE->getDecl());
69
70 if (FD == nullptr) {
71 return nullptr;
72 }
73
74 return FD;
75 }
76
77 // Checks if the call expression is a legal rsForEach call by looking for the
78 // following pattern in the AST. On success, returns the first argument that is
79 // a FunctionDecl of a kernel function.
80 //
81 // CallExpr 'void'
82 // |
83 // |-ImplicitCastExpr 'void (*)(void *, ...)' <FunctionToPointerDecay>
84 // | `-DeclRefExpr 'void (void *, ...)' 'rsForEach' 'void (void *, ...)'
85 // |
86 // |-ImplicitCastExpr 'void *' <BitCast>
87 // | `-ImplicitCastExpr 'int (*)(int)' <FunctionToPointerDecay>
88 // | `-DeclRefExpr 'int (int)' Function 'foo' 'int (int)'
89 // |
90 // |-ImplicitCastExpr 'rs_allocation':'rs_allocation' <LValueToRValue>
91 // | `-DeclRefExpr 'rs_allocation':'rs_allocation' lvalue ParmVar 'in' 'rs_allocation':'rs_allocation'
92 // |
93 // `-ImplicitCastExpr 'rs_allocation':'rs_allocation' <LValueToRValue>
94 // `-DeclRefExpr 'rs_allocation':'rs_allocation' lvalue ParmVar 'out' 'rs_allocation':'rs_allocation'
matchKernelLaunchCall(clang::CallExpr * CE,int * slot,bool * hasOptions)95 const clang::FunctionDecl* RSForEachLowering::matchKernelLaunchCall(
96 clang::CallExpr* CE, int* slot, bool* hasOptions) {
97 const clang::Decl* D = CE->getCalleeDecl();
98 const clang::FunctionDecl* FD = clang::dyn_cast<clang::FunctionDecl>(D);
99
100 if (FD == nullptr) {
101 return nullptr;
102 }
103
104 const clang::StringRef& funcName = FD->getName();
105
106 if (funcName.equals(KERNEL_LAUNCH_FUNCTION_NAME)) {
107 *hasOptions = false;
108 } else if (funcName.equals(KERNEL_LAUNCH_FUNCTION_NAME_WITH_OPTIONS)) {
109 *hasOptions = true;
110 } else {
111 return nullptr;
112 }
113
114 if (mInsideKernel) {
115 mCtxt->ReportError(CE->getExprLoc(),
116 "Invalid kernel launch call made from inside another kernel.");
117 return nullptr;
118 }
119
120 clang::Expr* arg0 = CE->getArg(0);
121 const clang::FunctionDecl* kernel = matchFunctionDesignator(arg0);
122
123 if (kernel == nullptr) {
124 mCtxt->ReportError(arg0->getExprLoc(),
125 "Invalid kernel launch call. "
126 "Expects a function designator for the first argument.");
127 return nullptr;
128 }
129
130 // Verifies that kernel is indeed a "kernel" function.
131 *slot = mCtxt->getForEachSlotNumber(kernel);
132 if (*slot == -1) {
133 mCtxt->ReportError(CE->getExprLoc(),
134 "%0 applied to function %1 defined without \"kernel\" attribute")
135 << funcName << kernel->getName();
136 return nullptr;
137 }
138
139 return kernel;
140 }
141
142 // Create an AST node for the declaration of rsForEachInternal
CreateForEachInternalFunctionDecl()143 clang::FunctionDecl* RSForEachLowering::CreateForEachInternalFunctionDecl() {
144 clang::DeclContext* DC = mASTCtxt.getTranslationUnitDecl();
145 clang::SourceLocation Loc;
146
147 llvm::StringRef SR(INTERNAL_LAUNCH_FUNCTION_NAME);
148 clang::IdentifierInfo& II = mASTCtxt.Idents.get(SR);
149 clang::DeclarationName N(&II);
150
151 clang::FunctionProtoType::ExtProtoInfo EPI;
152
153 const clang::QualType& AllocTy = mCtxt->getAllocationType();
154 clang::QualType AllocPtrTy = mASTCtxt.getPointerType(AllocTy);
155
156 clang::QualType ScriptCallTy = mCtxt->getScriptCallType();
157 const clang::QualType ScriptCallPtrTy = mASTCtxt.getPointerType(ScriptCallTy);
158
159 clang::QualType ParamTypes[] = {
160 mASTCtxt.IntTy, // int slot
161 ScriptCallPtrTy, // rs_script_call_t* launch_options
162 mASTCtxt.IntTy, // int numOutput
163 mASTCtxt.IntTy, // int numInputs
164 AllocPtrTy // rs_allocation* allocs
165 };
166
167 clang::QualType T = mASTCtxt.getFunctionType(
168 mASTCtxt.VoidTy, // Return type
169 ParamTypes, // Parameter types
170 EPI);
171
172 clang::FunctionDecl* FD = clang::FunctionDecl::Create(
173 mASTCtxt, DC, Loc, Loc, N, T, nullptr, clang::SC_Extern);
174
175 static constexpr unsigned kNumParams = sizeof(ParamTypes) / sizeof(ParamTypes[0]);
176 clang::ParmVarDecl *ParamDecls[kNumParams];
177 for (unsigned I = 0; I != kNumParams; ++I) {
178 ParamDecls[I] = clang::ParmVarDecl::Create(mASTCtxt, FD, Loc,
179 Loc, nullptr, ParamTypes[I], nullptr, clang::SC_None, nullptr);
180 // Implicit means that this declaration was created by the compiler, and
181 // not part of the actual source code.
182 ParamDecls[I]->setImplicit();
183 }
184 FD->setParams(llvm::makeArrayRef(ParamDecls, kNumParams));
185
186 // Implicit means that this declaration was created by the compiler, and
187 // not part of the actual source code.
188 FD->setImplicit();
189
190 return FD;
191 }
192
193 // Create an expression like the following that references the rsForEachInternal to
194 // replace the callee in the original call expression that references rsForEach.
195 //
196 // ImplicitCastExpr 'void (*)(int, rs_script_call_t*, int, int, rs_allocation*)' <FunctionToPointerDecay>
197 // `-DeclRefExpr 'void' Function '_Z17rsForEachInternaliP14rs_script_calliiP13rs_allocation' 'void (int, rs_script_call_t*, int, int, rs_allocation*)'
CreateCalleeExprForInternalForEach()198 clang::Expr* RSForEachLowering::CreateCalleeExprForInternalForEach() {
199 clang::FunctionDecl* FDNew = CreateForEachInternalFunctionDecl();
200
201 const clang::QualType FDNewType = FDNew->getType();
202
203 clang::DeclRefExpr* refExpr = clang::DeclRefExpr::Create(
204 mASTCtxt, clang::NestedNameSpecifierLoc(), clang::SourceLocation(), FDNew,
205 false, clang::SourceLocation(), FDNewType, clang::VK_RValue);
206
207 clang::Expr* calleeNew = clang::ImplicitCastExpr::Create(
208 mASTCtxt, mASTCtxt.getPointerType(FDNewType),
209 clang::CK_FunctionToPointerDecay, refExpr, nullptr, clang::VK_RValue);
210
211 return calleeNew;
212 }
213
214 // This visit method checks (via pattern matching) if the call expression is to
215 // rsForEach, and the arguments satisfy the restrictions on the
216 // rsForEach API. If so, replace the call with a rsForEachInternal call
217 // with the first argument replaced by the slot number of the kernel function
218 // referenced in the original first argument.
219 //
220 // See comments to the helper methods defined above for details.
VisitCallExpr(clang::CallExpr * CE)221 void RSForEachLowering::VisitCallExpr(clang::CallExpr* CE) {
222 int slot;
223 bool hasOptions;
224 const clang::FunctionDecl* kernel = matchKernelLaunchCall(CE, &slot, &hasOptions);
225 if (kernel == nullptr) {
226 return;
227 }
228
229 slangAssert(slot >= 0);
230
231 const unsigned numArgsOrig = CE->getNumArgs();
232
233 clang::QualType resultType = kernel->getReturnType().getCanonicalType();
234 const unsigned numOutputsExpected = resultType->isVoidType() ? 0 : 1;
235
236 const unsigned numInputsExpected = RSExportForEach::getNumInputs(mCtxt->getTargetAPI(), kernel);
237
238 // Verifies that rsForEach takes the right number of input and output allocations.
239 // TODO: Check input/output allocation types match kernel function expectation.
240 const unsigned numAllocations = numArgsOrig - (hasOptions ? 2 : 1);
241 if (numInputsExpected + numOutputsExpected != numAllocations) {
242 mCtxt->ReportError(
243 CE->getExprLoc(),
244 "Number of input and output allocations unexpected for kernel function %0")
245 << kernel->getName();
246 return;
247 }
248
249 clang::Expr* calleeNew = CreateCalleeExprForInternalForEach();
250 CE->setCallee(calleeNew);
251
252 const clang::CanQualType IntTy = mASTCtxt.IntTy;
253 const unsigned IntTySize = mASTCtxt.getTypeSize(IntTy);
254 const llvm::APInt APIntSlot(IntTySize, slot);
255 const clang::Expr* arg0 = CE->getArg(0);
256 const clang::SourceLocation Loc(arg0->getLocStart());
257 clang::Expr* IntSlotNum =
258 clang::IntegerLiteral::Create(mASTCtxt, APIntSlot, IntTy, Loc);
259 CE->setArg(0, IntSlotNum);
260
261 /*
262 The last few arguments to rsForEach or rsForEachWithOptions are allocations.
263 Creates a new compound literal of an array initialized with those values, and
264 passes it to rsForEachInternal as the last (the 5th) argument.
265
266 For example, rsForEach(foo, ain1, ain2, aout) would be translated into
267 rsForEachInternal(
268 1, // Slot number for kernel
269 NULL, // Launch options
270 2, // Number of input allocations
271 1, // Number of output allocations
272 (rs_allocation[]){ain1, ain2, aout) // Input and output allocations
273 );
274
275 The AST for the rs_allocation array looks like following:
276
277 ImplicitCastExpr 0x99575670 'struct rs_allocation *' <ArrayToPointerDecay>
278 `-CompoundLiteralExpr 0x99575648 'struct rs_allocation [3]' lvalue
279 `-InitListExpr 0x99575590 'struct rs_allocation [3]'
280 |-ImplicitCastExpr 0x99574b38 'rs_allocation':'struct rs_allocation' <LValueToRValue>
281 | `-DeclRefExpr 0x99574a08 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c408 'ain1' 'rs_allocation':'struct rs_allocation'
282 |-ImplicitCastExpr 0x99574b50 'rs_allocation':'struct rs_allocation' <LValueToRValue>
283 | `-DeclRefExpr 0x99574a30 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c478 'ain2' 'rs_allocation':'struct rs_allocation'
284 `-ImplicitCastExpr 0x99574b68 'rs_allocation':'struct rs_allocation' <LValueToRValue>
285 `-DeclRefExpr 0x99574a58 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c478 'aout' 'rs_allocation':'struct rs_allocation'
286 */
287
288 const clang::QualType& AllocTy = mCtxt->getAllocationType();
289 const llvm::APInt APIntNumAllocs(IntTySize, numAllocations);
290 clang::QualType AllocArrayTy = mASTCtxt.getConstantArrayType(
291 AllocTy,
292 APIntNumAllocs,
293 clang::ArrayType::ArraySizeModifier::Normal,
294 0 // index type qualifiers
295 );
296
297 const int allocArgIndexEnd = numArgsOrig - 1;
298 int allocArgIndexStart = allocArgIndexEnd;
299
300 clang::Expr** args = CE->getArgs();
301
302 clang::SourceLocation lparenloc;
303 clang::SourceLocation rparenloc;
304
305 if (numAllocations > 0) {
306 allocArgIndexStart = hasOptions ? 2 : 1;
307 lparenloc = args[allocArgIndexStart]->getExprLoc();
308 rparenloc = args[allocArgIndexEnd]->getExprLoc();
309 }
310
311 clang::InitListExpr* init = new (mASTCtxt) clang::InitListExpr(
312 mASTCtxt,
313 lparenloc,
314 llvm::ArrayRef<clang::Expr*>(args + allocArgIndexStart, numAllocations),
315 rparenloc);
316 init->setType(AllocArrayTy);
317
318 clang::TypeSourceInfo* ti = mASTCtxt.getTrivialTypeSourceInfo(AllocArrayTy);
319 clang::CompoundLiteralExpr* CLE = new (mASTCtxt) clang::CompoundLiteralExpr(
320 lparenloc,
321 ti,
322 AllocArrayTy,
323 clang::VK_LValue, // A compound literal is an l-value in C.
324 init,
325 false // Not file scope
326 );
327
328 const clang::QualType AllocPtrTy = mASTCtxt.getPointerType(AllocTy);
329
330 clang::ImplicitCastExpr* Decay = clang::ImplicitCastExpr::Create(
331 mASTCtxt,
332 AllocPtrTy,
333 clang::CK_ArrayToPointerDecay,
334 CLE,
335 nullptr, // C++ cast path
336 clang::VK_RValue
337 );
338
339 CE->setNumArgs(mASTCtxt, 5);
340
341 CE->setArg(4, Decay);
342
343 // Sets the new arguments for NULL launch option (if the user does not set one),
344 // the number of outputs, and the number of inputs.
345
346 if (!hasOptions) {
347 const llvm::APInt APIntZero(IntTySize, 0);
348 clang::Expr* IntNull =
349 clang::IntegerLiteral::Create(mASTCtxt, APIntZero, IntTy, Loc);
350 clang::QualType ScriptCallTy = mCtxt->getScriptCallType();
351 const clang::QualType ScriptCallPtrTy = mASTCtxt.getPointerType(ScriptCallTy);
352 clang::CStyleCastExpr* Cast =
353 clang::CStyleCastExpr::Create(mASTCtxt,
354 ScriptCallPtrTy,
355 clang::VK_RValue,
356 clang::CK_NullToPointer,
357 IntNull,
358 nullptr,
359 mASTCtxt.getTrivialTypeSourceInfo(ScriptCallPtrTy),
360 clang::SourceLocation(),
361 clang::SourceLocation());
362 CE->setArg(1, Cast);
363 }
364
365 const llvm::APInt APIntNumOutput(IntTySize, numOutputsExpected);
366 clang::Expr* IntNumOutput =
367 clang::IntegerLiteral::Create(mASTCtxt, APIntNumOutput, IntTy, Loc);
368 CE->setArg(2, IntNumOutput);
369
370 const llvm::APInt APIntNumInputs(IntTySize, numInputsExpected);
371 clang::Expr* IntNumInputs =
372 clang::IntegerLiteral::Create(mASTCtxt, APIntNumInputs, IntTy, Loc);
373 CE->setArg(3, IntNumInputs);
374 }
375
VisitStmt(clang::Stmt * S)376 void RSForEachLowering::VisitStmt(clang::Stmt* S) {
377 for (clang::Stmt* Child : S->children()) {
378 if (Child) {
379 Visit(Child);
380 }
381 }
382 }
383
handleForEachCalls(clang::FunctionDecl * FD,unsigned int targetAPI)384 void RSForEachLowering::handleForEachCalls(clang::FunctionDecl* FD,
385 unsigned int targetAPI) {
386 slangAssert(FD && FD->hasBody());
387
388 mInsideKernel = FD->hasAttr<clang::KernelAttr>();
389 VisitStmt(FD->getBody());
390 }
391
392 } // namespace slang
393