• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2010, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "slang_rs_backend.h"
18 
19 #include <string>
20 #include <vector>
21 
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/ADT/StringExtras.h"
24 
25 #include "llvm/Constant.h"
26 #include "llvm/Constants.h"
27 #include "llvm/DerivedTypes.h"
28 #include "llvm/Function.h"
29 #include "llvm/Metadata.h"
30 #include "llvm/Module.h"
31 
32 #include "llvm/Support/IRBuilder.h"
33 
34 #include "slang_assert.h"
35 #include "slang_rs.h"
36 #include "slang_rs_context.h"
37 #include "slang_rs_export_foreach.h"
38 #include "slang_rs_export_func.h"
39 #include "slang_rs_export_type.h"
40 #include "slang_rs_export_var.h"
41 #include "slang_rs_metadata.h"
42 
43 namespace slang {
44 
RSBackend(RSContext * Context,clang::Diagnostic * Diags,const clang::CodeGenOptions & CodeGenOpts,const clang::TargetOptions & TargetOpts,PragmaList * Pragmas,llvm::raw_ostream * OS,Slang::OutputType OT,clang::SourceManager & SourceMgr,bool AllowRSPrefix)45 RSBackend::RSBackend(RSContext *Context,
46                      clang::Diagnostic *Diags,
47                      const clang::CodeGenOptions &CodeGenOpts,
48                      const clang::TargetOptions &TargetOpts,
49                      PragmaList *Pragmas,
50                      llvm::raw_ostream *OS,
51                      Slang::OutputType OT,
52                      clang::SourceManager &SourceMgr,
53                      bool AllowRSPrefix)
54     : Backend(Diags,
55               CodeGenOpts,
56               TargetOpts,
57               Pragmas,
58               OS,
59               OT),
60       mContext(Context),
61       mSourceMgr(SourceMgr),
62       mAllowRSPrefix(AllowRSPrefix),
63       mExportVarMetadata(NULL),
64       mExportFuncMetadata(NULL),
65       mExportForEachMetadata(NULL),
66       mExportTypeMetadata(NULL),
67       mRSObjectSlotsMetadata(NULL),
68       mRefCount(mContext->getASTContext()) {
69   return;
70 }
71 
72 // 1) Add zero initialization of local RS object types
AnnotateFunction(clang::FunctionDecl * FD)73 void RSBackend::AnnotateFunction(clang::FunctionDecl *FD) {
74   if (FD &&
75       FD->hasBody() &&
76       !SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr)) {
77     mRefCount.Init();
78     mRefCount.Visit(FD->getBody());
79   }
80   return;
81 }
82 
HandleTopLevelDecl(clang::DeclGroupRef D)83 void RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
84   // Disallow user-defined functions with prefix "rs"
85   if (!mAllowRSPrefix) {
86     // Iterate all function declarations in the program.
87     for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
88          I != E; I++) {
89       clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
90       if (FD == NULL)
91         continue;
92       if (!FD->getName().startswith("rs"))  // Check prefix
93         continue;
94       if (!SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
95         mDiags.Report(clang::FullSourceLoc(FD->getLocation(), mSourceMgr),
96                       mDiags.getCustomDiagID(clang::Diagnostic::Error,
97                                              "invalid function name prefix, "
98                                              "\"rs\" is reserved: '%0'"))
99             << FD->getName();
100     }
101   }
102 
103   // Process any non-static function declarations
104   for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); I != E; I++) {
105     clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
106     if (FD && FD->isGlobal()) {
107       AnnotateFunction(FD);
108     }
109   }
110 
111   Backend::HandleTopLevelDecl(D);
112   return;
113 }
114 
115 namespace {
116 
ValidateVarDecl(clang::VarDecl * VD)117 static bool ValidateVarDecl(clang::VarDecl *VD) {
118   if (!VD) {
119     return true;
120   }
121 
122   clang::ASTContext &C = VD->getASTContext();
123   const clang::Type *T = VD->getType().getTypePtr();
124   bool valid = true;
125 
126   if (VD->getLinkage() == clang::ExternalLinkage) {
127     llvm::StringRef TypeName;
128     if (!RSExportType::NormalizeType(T, TypeName, &C.getDiagnostics(), VD)) {
129       valid = false;
130     }
131   }
132   valid &= RSExportType::ValidateVarDecl(VD);
133 
134   return valid;
135 }
136 
ValidateASTContext(clang::ASTContext & C)137 static bool ValidateASTContext(clang::ASTContext &C) {
138   bool valid = true;
139   clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
140   for (clang::DeclContext::decl_iterator DI = TUDecl->decls_begin(),
141           DE = TUDecl->decls_end();
142        DI != DE;
143        DI++) {
144     clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*DI);
145     if (VD && !ValidateVarDecl(VD)) {
146       valid = false;
147     }
148   }
149 
150   return valid;
151 }
152 
153 }  // namespace
154 
HandleTranslationUnitPre(clang::ASTContext & C)155 void RSBackend::HandleTranslationUnitPre(clang::ASTContext &C) {
156   clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
157 
158   if (!ValidateASTContext(C)) {
159     return;
160   }
161 
162   int version = mContext->getVersion();
163   if (version == 0) {
164     // Not setting a version is an error
165     mDiags.Report(mDiags.getCustomDiagID(clang::Diagnostic::Error,
166                       "Missing pragma for version in source file"));
167   } else if (version > 1) {
168     mDiags.Report(mDiags.getCustomDiagID(clang::Diagnostic::Error,
169                       "Pragma for version in source file must be set to 1"));
170   }
171 
172   // Create a static global destructor if necessary (to handle RS object
173   // runtime cleanup).
174   clang::FunctionDecl *FD = mRefCount.CreateStaticGlobalDtor();
175   if (FD) {
176     HandleTopLevelDecl(clang::DeclGroupRef(FD));
177   }
178 
179   // Process any static function declarations
180   for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
181           E = TUDecl->decls_end(); I != E; I++) {
182     if ((I->getKind() >= clang::Decl::firstFunction) &&
183         (I->getKind() <= clang::Decl::lastFunction)) {
184       clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
185       if (FD && !FD->isGlobal()) {
186         AnnotateFunction(FD);
187       }
188     }
189   }
190 
191   return;
192 }
193 
194 ///////////////////////////////////////////////////////////////////////////////
HandleTranslationUnitPost(llvm::Module * M)195 void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
196   if (!mContext->processExport()) {
197     return;
198   }
199 
200   // Dump export variable info
201   if (mContext->hasExportVar()) {
202     int slotCount = 0;
203     if (mExportVarMetadata == NULL)
204       mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
205 
206     llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
207 
208     // We emit slot information (#rs_object_slots) for any reference counted
209     // RS type or pointer (which can also be bound).
210 
211     for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
212             E = mContext->export_vars_end();
213          I != E;
214          I++) {
215       const RSExportVar *EV = *I;
216       const RSExportType *ET = EV->getType();
217       bool countsAsRSObject = false;
218 
219       // Variable name
220       ExportVarInfo.push_back(
221           llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
222 
223       // Type name
224       switch (ET->getClass()) {
225         case RSExportType::ExportClassPrimitive: {
226           const RSExportPrimitiveType *PT =
227               static_cast<const RSExportPrimitiveType*>(ET);
228           ExportVarInfo.push_back(
229               llvm::MDString::get(
230                 mLLVMContext, llvm::utostr_32(PT->getType())));
231           if (PT->isRSObjectType()) {
232             countsAsRSObject = true;
233           }
234           break;
235         }
236         case RSExportType::ExportClassPointer: {
237           ExportVarInfo.push_back(
238               llvm::MDString::get(
239                 mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
240                   ->getPointeeType()->getName()).c_str()));
241           break;
242         }
243         case RSExportType::ExportClassMatrix: {
244           ExportVarInfo.push_back(
245               llvm::MDString::get(
246                 mLLVMContext, llvm::utostr_32(
247                   RSExportPrimitiveType::DataTypeRSMatrix2x2 +
248                   static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
249           break;
250         }
251         case RSExportType::ExportClassVector:
252         case RSExportType::ExportClassConstantArray:
253         case RSExportType::ExportClassRecord: {
254           ExportVarInfo.push_back(
255               llvm::MDString::get(mLLVMContext,
256                 EV->getType()->getName().c_str()));
257           break;
258         }
259       }
260 
261       mExportVarMetadata->addOperand(
262           llvm::MDNode::get(mLLVMContext, ExportVarInfo));
263       ExportVarInfo.clear();
264 
265       if (mRSObjectSlotsMetadata == NULL) {
266         mRSObjectSlotsMetadata =
267             M->getOrInsertNamedMetadata(RS_OBJECT_SLOTS_MN);
268       }
269 
270       if (countsAsRSObject) {
271         mRSObjectSlotsMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
272             llvm::MDString::get(mLLVMContext, llvm::utostr_32(slotCount))));
273       }
274 
275       slotCount++;
276     }
277   }
278 
279   // Dump export function info
280   if (mContext->hasExportFunc()) {
281     if (mExportFuncMetadata == NULL)
282       mExportFuncMetadata =
283           M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
284 
285     llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
286 
287     for (RSContext::const_export_func_iterator
288             I = mContext->export_funcs_begin(),
289             E = mContext->export_funcs_end();
290          I != E;
291          I++) {
292       const RSExportFunc *EF = *I;
293 
294       // Function name
295       if (!EF->hasParam()) {
296         ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
297                                                      EF->getName().c_str()));
298       } else {
299         llvm::Function *F = M->getFunction(EF->getName());
300         llvm::Function *HelperFunction;
301         const std::string HelperFunctionName(".helper_" + EF->getName());
302 
303         slangAssert(F && "Function marked as exported disappeared in Bitcode");
304 
305         // Create helper function
306         {
307           llvm::StructType *HelperFunctionParameterTy = NULL;
308 
309           if (!F->getArgumentList().empty()) {
310             std::vector<llvm::Type*> HelperFunctionParameterTys;
311             for (llvm::Function::arg_iterator AI = F->arg_begin(),
312                  AE = F->arg_end(); AI != AE; AI++)
313               HelperFunctionParameterTys.push_back(AI->getType());
314 
315             HelperFunctionParameterTy =
316                 llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
317           }
318 
319           if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
320             fprintf(stderr, "Failed to export function %s: parameter type "
321                             "mismatch during creation of helper function.\n",
322                     EF->getName().c_str());
323 
324             const RSExportRecordType *Expected = EF->getParamPacketType();
325             if (Expected) {
326               fprintf(stderr, "Expected:\n");
327               Expected->getLLVMType()->dump();
328             }
329             if (HelperFunctionParameterTy) {
330               fprintf(stderr, "Got:\n");
331               HelperFunctionParameterTy->dump();
332             }
333           }
334 
335           std::vector<llvm::Type*> Params;
336           if (HelperFunctionParameterTy) {
337             llvm::PointerType *HelperFunctionParameterTyP =
338                 llvm::PointerType::getUnqual(HelperFunctionParameterTy);
339             Params.push_back(HelperFunctionParameterTyP);
340           }
341 
342           llvm::FunctionType * HelperFunctionType =
343               llvm::FunctionType::get(F->getReturnType(),
344                                       Params,
345                                       /* IsVarArgs = */false);
346 
347           HelperFunction =
348               llvm::Function::Create(HelperFunctionType,
349                                      llvm::GlobalValue::ExternalLinkage,
350                                      HelperFunctionName,
351                                      M);
352 
353           HelperFunction->addFnAttr(llvm::Attribute::NoInline);
354           HelperFunction->setCallingConv(F->getCallingConv());
355 
356           // Create helper function body
357           {
358             llvm::Argument *HelperFunctionParameter =
359                 &(*HelperFunction->arg_begin());
360             llvm::BasicBlock *BB =
361                 llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
362             llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
363             llvm::SmallVector<llvm::Value*, 6> Params;
364             llvm::Value *Idx[2];
365 
366             Idx[0] =
367                 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
368 
369             // getelementptr and load instruction for all elements in
370             // parameter .p
371             for (size_t i = 0; i < EF->getNumParameters(); i++) {
372               // getelementptr
373               Idx[1] =
374                   llvm::ConstantInt::get(
375                       llvm::Type::getInt32Ty(mLLVMContext), i);
376               llvm::Value *Ptr = IB->CreateInBoundsGEP(HelperFunctionParameter,
377                                                        Idx,
378                                                        Idx + 2);
379 
380               // load
381               llvm::Value *V = IB->CreateLoad(Ptr);
382               Params.push_back(V);
383             }
384 
385             // Call and pass the all elements as parameter to F
386             llvm::CallInst *CI = IB->CreateCall(F, Params);
387 
388             CI->setCallingConv(F->getCallingConv());
389 
390             if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
391               IB->CreateRetVoid();
392             else
393               IB->CreateRet(CI);
394 
395             delete IB;
396           }
397         }
398 
399         ExportFuncInfo.push_back(
400             llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
401       }
402 
403       mExportFuncMetadata->addOperand(
404           llvm::MDNode::get(mLLVMContext, ExportFuncInfo));
405       ExportFuncInfo.clear();
406     }
407   }
408 
409   // Dump export function info
410   if (mContext->hasExportForEach()) {
411     if (mExportForEachMetadata == NULL)
412       mExportForEachMetadata =
413           M->getOrInsertNamedMetadata(RS_EXPORT_FOREACH_MN);
414 
415     llvm::SmallVector<llvm::Value*, 1> ExportForEachInfo;
416 
417     for (RSContext::const_export_foreach_iterator
418             I = mContext->export_foreach_begin(),
419             E = mContext->export_foreach_end();
420          I != E;
421          I++) {
422       const RSExportForEach *EFE = *I;
423 
424       ExportForEachInfo.push_back(
425           llvm::MDString::get(mLLVMContext,
426                               llvm::utostr_32(EFE->getMetadataEncoding())));
427 
428       mExportForEachMetadata->addOperand(
429           llvm::MDNode::get(mLLVMContext, ExportForEachInfo));
430       ExportForEachInfo.clear();
431     }
432   }
433 
434   // Dump export type info
435   if (mContext->hasExportType()) {
436     llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
437 
438     for (RSContext::const_export_type_iterator
439             I = mContext->export_types_begin(),
440             E = mContext->export_types_end();
441          I != E;
442          I++) {
443       // First, dump type name list to export
444       const RSExportType *ET = I->getValue();
445 
446       ExportTypeInfo.clear();
447       // Type name
448       ExportTypeInfo.push_back(
449           llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
450 
451       if (ET->getClass() == RSExportType::ExportClassRecord) {
452         const RSExportRecordType *ERT =
453             static_cast<const RSExportRecordType*>(ET);
454 
455         if (mExportTypeMetadata == NULL)
456           mExportTypeMetadata =
457               M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
458 
459         mExportTypeMetadata->addOperand(
460             llvm::MDNode::get(mLLVMContext, ExportTypeInfo));
461 
462         // Now, export struct field information to %[struct name]
463         std::string StructInfoMetadataName("%");
464         StructInfoMetadataName.append(ET->getName());
465         llvm::NamedMDNode *StructInfoMetadata =
466             M->getOrInsertNamedMetadata(StructInfoMetadataName);
467         llvm::SmallVector<llvm::Value*, 3> FieldInfo;
468 
469         slangAssert(StructInfoMetadata->getNumOperands() == 0 &&
470                     "Metadata with same name was created before");
471         for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
472                 FE = ERT->fields_end();
473              FI != FE;
474              FI++) {
475           const RSExportRecordType::Field *F = *FI;
476 
477           // 1. field name
478           FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
479                                                   F->getName().c_str()));
480 
481           // 2. field type name
482           FieldInfo.push_back(
483               llvm::MDString::get(mLLVMContext,
484                                   F->getType()->getName().c_str()));
485 
486           // 3. field kind
487           switch (F->getType()->getClass()) {
488             case RSExportType::ExportClassPrimitive:
489             case RSExportType::ExportClassVector: {
490               const RSExportPrimitiveType *EPT =
491                   static_cast<const RSExportPrimitiveType*>(F->getType());
492               FieldInfo.push_back(
493                   llvm::MDString::get(mLLVMContext,
494                                       llvm::itostr(EPT->getKind())));
495               break;
496             }
497 
498             default: {
499               FieldInfo.push_back(
500                   llvm::MDString::get(mLLVMContext,
501                                       llvm::itostr(
502                                         RSExportPrimitiveType::DataKindUser)));
503               break;
504             }
505           }
506 
507           StructInfoMetadata->addOperand(
508               llvm::MDNode::get(mLLVMContext, FieldInfo));
509           FieldInfo.clear();
510         }
511       }   // ET->getClass() == RSExportType::ExportClassRecord
512     }
513   }
514 
515   return;
516 }
517 
~RSBackend()518 RSBackend::~RSBackend() {
519   return;
520 }
521 
522 }  // namespace slang
523