• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2010, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "slang_rs_object_ref_count.h"
18 
19 #include <list>
20 
21 #include "clang/AST/DeclGroup.h"
22 #include "clang/AST/Expr.h"
23 #include "clang/AST/NestedNameSpecifier.h"
24 #include "clang/AST/OperationKinds.h"
25 #include "clang/AST/Stmt.h"
26 #include "clang/AST/StmtVisitor.h"
27 
28 #include "slang_assert.h"
29 #include "slang_rs.h"
30 #include "slang_rs_ast_replace.h"
31 #include "slang_rs_export_type.h"
32 
33 namespace slang {
34 
35 clang::FunctionDecl *RSObjectRefCount::
36     RSSetObjectFD[RSExportPrimitiveType::LastRSObjectType -
37                   RSExportPrimitiveType::FirstRSObjectType + 1];
38 clang::FunctionDecl *RSObjectRefCount::
39     RSClearObjectFD[RSExportPrimitiveType::LastRSObjectType -
40                     RSExportPrimitiveType::FirstRSObjectType + 1];
41 
GetRSRefCountingFunctions(clang::ASTContext & C)42 void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
43   for (unsigned i = 0;
44        i < (sizeof(RSClearObjectFD) / sizeof(clang::FunctionDecl*));
45        i++) {
46     RSSetObjectFD[i] = NULL;
47     RSClearObjectFD[i] = NULL;
48   }
49 
50   clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
51 
52   for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
53           E = TUDecl->decls_end(); I != E; I++) {
54     if ((I->getKind() >= clang::Decl::firstFunction) &&
55         (I->getKind() <= clang::Decl::lastFunction)) {
56       clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
57 
58       // points to RSSetObjectFD or RSClearObjectFD
59       clang::FunctionDecl **RSObjectFD;
60 
61       if (FD->getName() == "rsSetObject") {
62         slangAssert((FD->getNumParams() == 2) &&
63                     "Invalid rsSetObject function prototype (# params)");
64         RSObjectFD = RSSetObjectFD;
65       } else if (FD->getName() == "rsClearObject") {
66         slangAssert((FD->getNumParams() == 1) &&
67                     "Invalid rsClearObject function prototype (# params)");
68         RSObjectFD = RSClearObjectFD;
69       } else {
70         continue;
71       }
72 
73       const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
74       clang::QualType PVT = PVD->getOriginalType();
75       // The first parameter must be a pointer like rs_allocation*
76       slangAssert(PVT->isPointerType() &&
77           "Invalid rs{Set,Clear}Object function prototype (pointer param)");
78 
79       // The rs object type passed to the FD
80       clang::QualType RST = PVT->getPointeeType();
81       RSExportPrimitiveType::DataType DT =
82           RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
83       slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
84              && "must be RS object type");
85 
86       RSObjectFD[(DT - RSExportPrimitiveType::FirstRSObjectType)] = FD;
87     }
88   }
89 }
90 
91 namespace {
92 
93 // This function constructs a new CompoundStmt from the input StmtList.
BuildCompoundStmt(clang::ASTContext & C,std::list<clang::Stmt * > & StmtList,clang::SourceLocation Loc)94 static clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
95       std::list<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
96   unsigned NewStmtCount = StmtList.size();
97   unsigned CompoundStmtCount = 0;
98 
99   clang::Stmt **CompoundStmtList;
100   CompoundStmtList = new clang::Stmt*[NewStmtCount];
101 
102   std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
103   std::list<clang::Stmt*>::const_iterator E = StmtList.end();
104   for ( ; I != E; I++) {
105     CompoundStmtList[CompoundStmtCount++] = *I;
106   }
107   slangAssert(CompoundStmtCount == NewStmtCount);
108 
109   clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
110       C, llvm::makeArrayRef(CompoundStmtList, CompoundStmtCount), Loc, Loc);
111 
112   delete [] CompoundStmtList;
113 
114   return CS;
115 }
116 
AppendAfterStmt(clang::ASTContext & C,clang::CompoundStmt * CS,clang::Stmt * S,std::list<clang::Stmt * > & StmtList)117 static void AppendAfterStmt(clang::ASTContext &C,
118                             clang::CompoundStmt *CS,
119                             clang::Stmt *S,
120                             std::list<clang::Stmt*> &StmtList) {
121   slangAssert(CS);
122   clang::CompoundStmt::body_iterator bI = CS->body_begin();
123   clang::CompoundStmt::body_iterator bE = CS->body_end();
124   clang::Stmt **UpdatedStmtList =
125       new clang::Stmt*[CS->size() + StmtList.size()];
126 
127   unsigned UpdatedStmtCount = 0;
128   unsigned Once = 0;
129   for ( ; bI != bE; bI++) {
130     if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
131       // If we come across a return here, we don't have anything we can
132       // reasonably replace. We should have already inserted our destructor
133       // code in the proper spot, so we just clean up and return.
134       delete [] UpdatedStmtList;
135 
136       return;
137     }
138 
139     UpdatedStmtList[UpdatedStmtCount++] = *bI;
140 
141     if ((*bI == S) && !Once) {
142       Once++;
143       std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
144       std::list<clang::Stmt*>::const_iterator E = StmtList.end();
145       for ( ; I != E; I++) {
146         UpdatedStmtList[UpdatedStmtCount++] = *I;
147       }
148     }
149   }
150   slangAssert(Once <= 1);
151 
152   // When S is NULL, we are appending to the end of the CompoundStmt.
153   if (!S) {
154     slangAssert(Once == 0);
155     std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
156     std::list<clang::Stmt*>::const_iterator E = StmtList.end();
157     for ( ; I != E; I++) {
158       UpdatedStmtList[UpdatedStmtCount++] = *I;
159     }
160   }
161 
162   CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
163 
164   delete [] UpdatedStmtList;
165 
166   return;
167 }
168 
169 // This class visits a compound statement and inserts DtorStmt
170 // in proper locations. This includes inserting it before any
171 // return statement in any sub-block, at the end of the logical enclosing
172 // scope (compound statement), and/or before any break/continue statement that
173 // would resume outside the declared scope. We will not handle the case for
174 // goto statements that leave a local scope.
175 //
176 // To accomplish these goals, it collects a list of sub-Stmt's that
177 // correspond to scope exit points. It then uses an RSASTReplace visitor to
178 // transform the AST, inserting appropriate destructors before each of those
179 // sub-Stmt's (and also before the exit of the outermost containing Stmt for
180 // the scope).
181 class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
182  private:
183   clang::ASTContext &mCtx;
184 
185   // The loop depth of the currently visited node.
186   int mLoopDepth;
187 
188   // The switch statement depth of the currently visited node.
189   // Note that this is tracked separately from the loop depth because
190   // SwitchStmt-contained ContinueStmt's should have destructors for the
191   // corresponding loop scope.
192   int mSwitchDepth;
193 
194   // The outermost statement block that we are currently visiting.
195   // This should always be a CompoundStmt.
196   clang::Stmt *mOuterStmt;
197 
198   // The destructor to execute for this scope/variable.
199   clang::Stmt* mDtorStmt;
200 
201   // The stack of statements which should be replaced by a compound statement
202   // containing the new destructor call followed by the original Stmt.
203   std::stack<clang::Stmt*> mReplaceStmtStack;
204 
205   // The source location for the variable declaration that we are trying to
206   // insert destructors for. Note that InsertDestructors() will not generate
207   // destructor calls for source locations that occur lexically before this
208   // location.
209   clang::SourceLocation mVarLoc;
210 
211  public:
212   DestructorVisitor(clang::ASTContext &C,
213                     clang::Stmt* OuterStmt,
214                     clang::Stmt* DtorStmt,
215                     clang::SourceLocation VarLoc);
216 
217   // This code walks the collected list of Stmts to replace and actually does
218   // the replacement. It also finishes up by appending the destructor to the
219   // current outermost CompoundStmt.
InsertDestructors()220   void InsertDestructors() {
221     clang::Stmt *S = NULL;
222     clang::SourceManager &SM = mCtx.getSourceManager();
223     std::list<clang::Stmt *> StmtList;
224     StmtList.push_back(mDtorStmt);
225 
226     while (!mReplaceStmtStack.empty()) {
227       S = mReplaceStmtStack.top();
228       mReplaceStmtStack.pop();
229 
230       // Skip all source locations that occur before the variable's
231       // declaration, since it won't have been initialized yet.
232       if (SM.isBeforeInTranslationUnit(S->getLocStart(), mVarLoc)) {
233         continue;
234       }
235 
236       StmtList.push_back(S);
237       clang::CompoundStmt *CS =
238           BuildCompoundStmt(mCtx, StmtList, S->getLocEnd());
239       StmtList.pop_back();
240 
241       RSASTReplace R(mCtx);
242       R.ReplaceStmt(mOuterStmt, S, CS);
243     }
244     clang::CompoundStmt *CS =
245       llvm::dyn_cast<clang::CompoundStmt>(mOuterStmt);
246     slangAssert(CS);
247     AppendAfterStmt(mCtx, CS, NULL, StmtList);
248   }
249 
250   void VisitStmt(clang::Stmt *S);
251   void VisitCompoundStmt(clang::CompoundStmt *CS);
252 
253   void VisitBreakStmt(clang::BreakStmt *BS);
254   void VisitCaseStmt(clang::CaseStmt *CS);
255   void VisitContinueStmt(clang::ContinueStmt *CS);
256   void VisitDefaultStmt(clang::DefaultStmt *DS);
257   void VisitDoStmt(clang::DoStmt *DS);
258   void VisitForStmt(clang::ForStmt *FS);
259   void VisitIfStmt(clang::IfStmt *IS);
260   void VisitReturnStmt(clang::ReturnStmt *RS);
261   void VisitSwitchCase(clang::SwitchCase *SC);
262   void VisitSwitchStmt(clang::SwitchStmt *SS);
263   void VisitWhileStmt(clang::WhileStmt *WS);
264 };
265 
DestructorVisitor(clang::ASTContext & C,clang::Stmt * OuterStmt,clang::Stmt * DtorStmt,clang::SourceLocation VarLoc)266 DestructorVisitor::DestructorVisitor(clang::ASTContext &C,
267                          clang::Stmt *OuterStmt,
268                          clang::Stmt *DtorStmt,
269                          clang::SourceLocation VarLoc)
270   : mCtx(C),
271     mLoopDepth(0),
272     mSwitchDepth(0),
273     mOuterStmt(OuterStmt),
274     mDtorStmt(DtorStmt),
275     mVarLoc(VarLoc) {
276   return;
277 }
278 
VisitStmt(clang::Stmt * S)279 void DestructorVisitor::VisitStmt(clang::Stmt *S) {
280   for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
281        I != E;
282        I++) {
283     if (clang::Stmt *Child = *I) {
284       Visit(Child);
285     }
286   }
287   return;
288 }
289 
VisitCompoundStmt(clang::CompoundStmt * CS)290 void DestructorVisitor::VisitCompoundStmt(clang::CompoundStmt *CS) {
291   VisitStmt(CS);
292   return;
293 }
294 
VisitBreakStmt(clang::BreakStmt * BS)295 void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
296   VisitStmt(BS);
297   if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
298     mReplaceStmtStack.push(BS);
299   }
300   return;
301 }
302 
VisitCaseStmt(clang::CaseStmt * CS)303 void DestructorVisitor::VisitCaseStmt(clang::CaseStmt *CS) {
304   VisitStmt(CS);
305   return;
306 }
307 
VisitContinueStmt(clang::ContinueStmt * CS)308 void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
309   VisitStmt(CS);
310   if (mLoopDepth == 0) {
311     // Switch statements can have nested continues.
312     mReplaceStmtStack.push(CS);
313   }
314   return;
315 }
316 
VisitDefaultStmt(clang::DefaultStmt * DS)317 void DestructorVisitor::VisitDefaultStmt(clang::DefaultStmt *DS) {
318   VisitStmt(DS);
319   return;
320 }
321 
VisitDoStmt(clang::DoStmt * DS)322 void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
323   mLoopDepth++;
324   VisitStmt(DS);
325   mLoopDepth--;
326   return;
327 }
328 
VisitForStmt(clang::ForStmt * FS)329 void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
330   mLoopDepth++;
331   VisitStmt(FS);
332   mLoopDepth--;
333   return;
334 }
335 
VisitIfStmt(clang::IfStmt * IS)336 void DestructorVisitor::VisitIfStmt(clang::IfStmt *IS) {
337   VisitStmt(IS);
338   return;
339 }
340 
VisitReturnStmt(clang::ReturnStmt * RS)341 void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
342   mReplaceStmtStack.push(RS);
343   return;
344 }
345 
VisitSwitchCase(clang::SwitchCase * SC)346 void DestructorVisitor::VisitSwitchCase(clang::SwitchCase *SC) {
347   slangAssert(false && "Both case and default have specialized handlers");
348   VisitStmt(SC);
349   return;
350 }
351 
VisitSwitchStmt(clang::SwitchStmt * SS)352 void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
353   mSwitchDepth++;
354   VisitStmt(SS);
355   mSwitchDepth--;
356   return;
357 }
358 
VisitWhileStmt(clang::WhileStmt * WS)359 void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
360   mLoopDepth++;
361   VisitStmt(WS);
362   mLoopDepth--;
363   return;
364 }
365 
ClearSingleRSObject(clang::ASTContext & C,clang::Expr * RefRSVar,clang::SourceLocation Loc)366 clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
367                                  clang::Expr *RefRSVar,
368                                  clang::SourceLocation Loc) {
369   slangAssert(RefRSVar);
370   const clang::Type *T = RefRSVar->getType().getTypePtr();
371   slangAssert(!T->isArrayType() &&
372               "Should not be destroying arrays with this function");
373 
374   clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
375   slangAssert((ClearObjectFD != NULL) &&
376               "rsClearObject doesn't cover all RS object types");
377 
378   clang::QualType ClearObjectFDType = ClearObjectFD->getType();
379   clang::QualType ClearObjectFDArgType =
380       ClearObjectFD->getParamDecl(0)->getOriginalType();
381 
382   // Example destructor for "rs_font localFont;"
383   //
384   // (CallExpr 'void'
385   //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
386   //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
387   //   (UnaryOperator 'rs_font *' prefix '&'
388   //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
389 
390   // Get address of targeted RS object
391   clang::Expr *AddrRefRSVar =
392       new(C) clang::UnaryOperator(RefRSVar,
393                                   clang::UO_AddrOf,
394                                   ClearObjectFDArgType,
395                                   clang::VK_RValue,
396                                   clang::OK_Ordinary,
397                                   Loc);
398 
399   clang::Expr *RefRSClearObjectFD =
400       clang::DeclRefExpr::Create(C,
401                                  clang::NestedNameSpecifierLoc(),
402                                  clang::SourceLocation(),
403                                  ClearObjectFD,
404                                  false,
405                                  ClearObjectFD->getLocation(),
406                                  ClearObjectFDType,
407                                  clang::VK_RValue,
408                                  NULL);
409 
410   clang::Expr *RSClearObjectFP =
411       clang::ImplicitCastExpr::Create(C,
412                                       C.getPointerType(ClearObjectFDType),
413                                       clang::CK_FunctionToPointerDecay,
414                                       RefRSClearObjectFD,
415                                       NULL,
416                                       clang::VK_RValue);
417 
418   llvm::SmallVector<clang::Expr*, 1> ArgList;
419   ArgList.push_back(AddrRefRSVar);
420 
421   clang::CallExpr *RSClearObjectCall =
422       new(C) clang::CallExpr(C,
423                              RSClearObjectFP,
424                              ArgList,
425                              ClearObjectFD->getCallResultType(),
426                              clang::VK_RValue,
427                              Loc);
428 
429   return RSClearObjectCall;
430 }
431 
ArrayDim(const clang::Type * T)432 static int ArrayDim(const clang::Type *T) {
433   if (!T || !T->isArrayType()) {
434     return 0;
435   }
436 
437   const clang::ConstantArrayType *CAT =
438     static_cast<const clang::ConstantArrayType *>(T);
439   return static_cast<int>(CAT->getSize().getSExtValue());
440 }
441 
442 static clang::Stmt *ClearStructRSObject(
443     clang::ASTContext &C,
444     clang::DeclContext *DC,
445     clang::Expr *RefRSStruct,
446     clang::SourceLocation StartLoc,
447     clang::SourceLocation Loc);
448 
ClearArrayRSObject(clang::ASTContext & C,clang::DeclContext * DC,clang::Expr * RefRSArr,clang::SourceLocation StartLoc,clang::SourceLocation Loc)449 static clang::Stmt *ClearArrayRSObject(
450     clang::ASTContext &C,
451     clang::DeclContext *DC,
452     clang::Expr *RefRSArr,
453     clang::SourceLocation StartLoc,
454     clang::SourceLocation Loc) {
455   const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
456   slangAssert(BaseType->isArrayType());
457 
458   int NumArrayElements = ArrayDim(BaseType);
459   // Actually extract out the base RS object type for use later
460   BaseType = BaseType->getArrayElementTypeNoTypeQual();
461 
462   clang::Stmt *StmtArray[2] = {NULL};
463   int StmtCtr = 0;
464 
465   if (NumArrayElements <= 0) {
466     return NULL;
467   }
468 
469   // Example destructor loop for "rs_font fontArr[10];"
470   //
471   // (CompoundStmt
472   //   (DeclStmt "int rsIntIter")
473   //   (ForStmt
474   //     (BinaryOperator 'int' '='
475   //       (DeclRefExpr 'int' Var='rsIntIter')
476   //       (IntegerLiteral 'int' 0))
477   //     (BinaryOperator 'int' '<'
478   //       (DeclRefExpr 'int' Var='rsIntIter')
479   //       (IntegerLiteral 'int' 10)
480   //     NULL << CondVar >>
481   //     (UnaryOperator 'int' postfix '++'
482   //       (DeclRefExpr 'int' Var='rsIntIter'))
483   //     (CallExpr 'void'
484   //       (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
485   //         (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
486   //       (UnaryOperator 'rs_font *' prefix '&'
487   //         (ArraySubscriptExpr 'rs_font':'rs_font'
488   //           (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
489   //             (DeclRefExpr 'rs_font [10]' Var='fontArr'))
490   //           (DeclRefExpr 'int' Var='rsIntIter')))))))
491 
492   // Create helper variable for iterating through elements
493   clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
494   clang::VarDecl *IIVD =
495       clang::VarDecl::Create(C,
496                              DC,
497                              StartLoc,
498                              Loc,
499                              &II,
500                              C.IntTy,
501                              C.getTrivialTypeSourceInfo(C.IntTy),
502                              clang::SC_None,
503                              clang::SC_None);
504   clang::Decl *IID = (clang::Decl *)IIVD;
505 
506   clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
507   StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
508 
509   // Form the actual destructor loop
510   // for (Init; Cond; Inc)
511   //   RSClearObjectCall;
512 
513   // Init -> "rsIntIter = 0"
514   clang::DeclRefExpr *RefrsIntIter =
515       clang::DeclRefExpr::Create(C,
516                                  clang::NestedNameSpecifierLoc(),
517                                  clang::SourceLocation(),
518                                  IIVD,
519                                  false,
520                                  Loc,
521                                  C.IntTy,
522                                  clang::VK_RValue,
523                                  NULL);
524 
525   clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
526       llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
527 
528   clang::BinaryOperator *Init =
529       new(C) clang::BinaryOperator(RefrsIntIter,
530                                    Int0,
531                                    clang::BO_Assign,
532                                    C.IntTy,
533                                    clang::VK_RValue,
534                                    clang::OK_Ordinary,
535                                    Loc,
536                                    false);
537 
538   // Cond -> "rsIntIter < NumArrayElements"
539   clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
540       llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
541 
542   clang::BinaryOperator *Cond =
543       new(C) clang::BinaryOperator(RefrsIntIter,
544                                    NumArrayElementsExpr,
545                                    clang::BO_LT,
546                                    C.IntTy,
547                                    clang::VK_RValue,
548                                    clang::OK_Ordinary,
549                                    Loc,
550                                    false);
551 
552   // Inc -> "rsIntIter++"
553   clang::UnaryOperator *Inc =
554       new(C) clang::UnaryOperator(RefrsIntIter,
555                                   clang::UO_PostInc,
556                                   C.IntTy,
557                                   clang::VK_RValue,
558                                   clang::OK_Ordinary,
559                                   Loc);
560 
561   // Body -> "rsClearObject(&VD[rsIntIter]);"
562   // Destructor loop operates on individual array elements
563 
564   clang::Expr *RefRSArrPtr =
565       clang::ImplicitCastExpr::Create(C,
566           C.getPointerType(BaseType->getCanonicalTypeInternal()),
567           clang::CK_ArrayToPointerDecay,
568           RefRSArr,
569           NULL,
570           clang::VK_RValue);
571 
572   clang::Expr *RefRSArrPtrSubscript =
573       new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
574                                        RefrsIntIter,
575                                        BaseType->getCanonicalTypeInternal(),
576                                        clang::VK_RValue,
577                                        clang::OK_Ordinary,
578                                        Loc);
579 
580   RSExportPrimitiveType::DataType DT =
581       RSExportPrimitiveType::GetRSSpecificType(BaseType);
582 
583   clang::Stmt *RSClearObjectCall = NULL;
584   if (BaseType->isArrayType()) {
585     RSClearObjectCall =
586         ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
587   } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
588     RSClearObjectCall =
589         ClearStructRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
590   } else {
591     RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
592   }
593 
594   clang::ForStmt *DestructorLoop =
595       new(C) clang::ForStmt(C,
596                             Init,
597                             Cond,
598                             NULL,  // no condVar
599                             Inc,
600                             RSClearObjectCall,
601                             Loc,
602                             Loc,
603                             Loc);
604 
605   StmtArray[StmtCtr++] = DestructorLoop;
606   slangAssert(StmtCtr == 2);
607 
608   clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
609       C, llvm::makeArrayRef(StmtArray, StmtCtr), Loc, Loc);
610 
611   return CS;
612 }
613 
CountRSObjectTypes(clang::ASTContext & C,const clang::Type * T,clang::SourceLocation Loc)614 static unsigned CountRSObjectTypes(clang::ASTContext &C,
615                                    const clang::Type *T,
616                                    clang::SourceLocation Loc) {
617   slangAssert(T);
618   unsigned RSObjectCount = 0;
619 
620   if (T->isArrayType()) {
621     return CountRSObjectTypes(C, T->getArrayElementTypeNoTypeQual(), Loc);
622   }
623 
624   RSExportPrimitiveType::DataType DT =
625       RSExportPrimitiveType::GetRSSpecificType(T);
626   if (DT != RSExportPrimitiveType::DataTypeUnknown) {
627     return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
628   }
629 
630   if (T->isUnionType()) {
631     clang::RecordDecl *RD = T->getAsUnionType()->getDecl();
632     RD = RD->getDefinition();
633     for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
634            FE = RD->field_end();
635          FI != FE;
636          FI++) {
637       const clang::FieldDecl *FD = *FI;
638       const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
639       if (CountRSObjectTypes(C, FT, Loc)) {
640         slangAssert(false && "can't have unions with RS object types!");
641         return 0;
642       }
643     }
644   }
645 
646   if (!T->isStructureType()) {
647     return 0;
648   }
649 
650   clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
651   RD = RD->getDefinition();
652   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
653          FE = RD->field_end();
654        FI != FE;
655        FI++) {
656     const clang::FieldDecl *FD = *FI;
657     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
658     if (CountRSObjectTypes(C, FT, Loc)) {
659       // Sub-structs should only count once (as should arrays, etc.)
660       RSObjectCount++;
661     }
662   }
663 
664   return RSObjectCount;
665 }
666 
ClearStructRSObject(clang::ASTContext & C,clang::DeclContext * DC,clang::Expr * RefRSStruct,clang::SourceLocation StartLoc,clang::SourceLocation Loc)667 static clang::Stmt *ClearStructRSObject(
668     clang::ASTContext &C,
669     clang::DeclContext *DC,
670     clang::Expr *RefRSStruct,
671     clang::SourceLocation StartLoc,
672     clang::SourceLocation Loc) {
673   const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
674 
675   slangAssert(!BaseType->isArrayType());
676 
677   // Structs should show up as unknown primitive types
678   slangAssert(RSExportPrimitiveType::GetRSSpecificType(BaseType) ==
679               RSExportPrimitiveType::DataTypeUnknown);
680 
681   unsigned FieldsToDestroy = CountRSObjectTypes(C, BaseType, Loc);
682   slangAssert(FieldsToDestroy != 0);
683 
684   unsigned StmtCount = 0;
685   clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
686   for (unsigned i = 0; i < FieldsToDestroy; i++) {
687     StmtArray[i] = NULL;
688   }
689 
690   // Populate StmtArray by creating a destructor for each RS object field
691   clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
692   RD = RD->getDefinition();
693   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
694          FE = RD->field_end();
695        FI != FE;
696        FI++) {
697     // We just look through all field declarations to see if we find a
698     // declaration for an RS object type (or an array of one).
699     bool IsArrayType = false;
700     clang::FieldDecl *FD = *FI;
701     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
702     const clang::Type *OrigType = FT;
703     while (FT && FT->isArrayType()) {
704       FT = FT->getArrayElementTypeNoTypeQual();
705       IsArrayType = true;
706     }
707 
708     if (RSExportPrimitiveType::IsRSObjectType(FT)) {
709       clang::DeclAccessPair FoundDecl =
710           clang::DeclAccessPair::make(FD, clang::AS_none);
711       clang::MemberExpr *RSObjectMember =
712           clang::MemberExpr::Create(C,
713                                     RefRSStruct,
714                                     false,
715                                     clang::NestedNameSpecifierLoc(),
716                                     clang::SourceLocation(),
717                                     FD,
718                                     FoundDecl,
719                                     clang::DeclarationNameInfo(),
720                                     NULL,
721                                     OrigType->getCanonicalTypeInternal(),
722                                     clang::VK_RValue,
723                                     clang::OK_Ordinary);
724 
725       slangAssert(StmtCount < FieldsToDestroy);
726 
727       if (IsArrayType) {
728         StmtArray[StmtCount++] = ClearArrayRSObject(C,
729                                                     DC,
730                                                     RSObjectMember,
731                                                     StartLoc,
732                                                     Loc);
733       } else {
734         StmtArray[StmtCount++] = ClearSingleRSObject(C,
735                                                      RSObjectMember,
736                                                      Loc);
737       }
738     } else if (FT->isStructureType() && CountRSObjectTypes(C, FT, Loc)) {
739       // In this case, we have a nested struct. We may not end up filling all
740       // of the spaces in StmtArray (sub-structs should handle themselves
741       // with separate compound statements).
742       clang::DeclAccessPair FoundDecl =
743           clang::DeclAccessPair::make(FD, clang::AS_none);
744       clang::MemberExpr *RSObjectMember =
745           clang::MemberExpr::Create(C,
746                                     RefRSStruct,
747                                     false,
748                                     clang::NestedNameSpecifierLoc(),
749                                     clang::SourceLocation(),
750                                     FD,
751                                     FoundDecl,
752                                     clang::DeclarationNameInfo(),
753                                     NULL,
754                                     OrigType->getCanonicalTypeInternal(),
755                                     clang::VK_RValue,
756                                     clang::OK_Ordinary);
757 
758       if (IsArrayType) {
759         StmtArray[StmtCount++] = ClearArrayRSObject(C,
760                                                     DC,
761                                                     RSObjectMember,
762                                                     StartLoc,
763                                                     Loc);
764       } else {
765         StmtArray[StmtCount++] = ClearStructRSObject(C,
766                                                      DC,
767                                                      RSObjectMember,
768                                                      StartLoc,
769                                                      Loc);
770       }
771     }
772   }
773 
774   slangAssert(StmtCount > 0);
775   clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
776       C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
777 
778   delete [] StmtArray;
779 
780   return CS;
781 }
782 
CreateSingleRSSetObject(clang::ASTContext & C,clang::Expr * DstExpr,clang::Expr * SrcExpr,clang::SourceLocation StartLoc,clang::SourceLocation Loc)783 static clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
784                                             clang::Expr *DstExpr,
785                                             clang::Expr *SrcExpr,
786                                             clang::SourceLocation StartLoc,
787                                             clang::SourceLocation Loc) {
788   const clang::Type *T = DstExpr->getType().getTypePtr();
789   clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
790   slangAssert((SetObjectFD != NULL) &&
791               "rsSetObject doesn't cover all RS object types");
792 
793   clang::QualType SetObjectFDType = SetObjectFD->getType();
794   clang::QualType SetObjectFDArgType[2];
795   SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
796   SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
797 
798   clang::Expr *RefRSSetObjectFD =
799       clang::DeclRefExpr::Create(C,
800                                  clang::NestedNameSpecifierLoc(),
801                                  clang::SourceLocation(),
802                                  SetObjectFD,
803                                  false,
804                                  Loc,
805                                  SetObjectFDType,
806                                  clang::VK_RValue,
807                                  NULL);
808 
809   clang::Expr *RSSetObjectFP =
810       clang::ImplicitCastExpr::Create(C,
811                                       C.getPointerType(SetObjectFDType),
812                                       clang::CK_FunctionToPointerDecay,
813                                       RefRSSetObjectFD,
814                                       NULL,
815                                       clang::VK_RValue);
816 
817   llvm::SmallVector<clang::Expr*, 2> ArgList;
818   ArgList.push_back(new(C) clang::UnaryOperator(DstExpr,
819                                                 clang::UO_AddrOf,
820                                                 SetObjectFDArgType[0],
821                                                 clang::VK_RValue,
822                                                 clang::OK_Ordinary,
823                                                 Loc));
824   ArgList.push_back(SrcExpr);
825 
826   clang::CallExpr *RSSetObjectCall =
827       new(C) clang::CallExpr(C,
828                              RSSetObjectFP,
829                              ArgList,
830                              SetObjectFD->getCallResultType(),
831                              clang::VK_RValue,
832                              Loc);
833 
834   return RSSetObjectCall;
835 }
836 
837 static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
838                                             clang::Expr *LHS,
839                                             clang::Expr *RHS,
840                                             clang::SourceLocation StartLoc,
841                                             clang::SourceLocation Loc);
842 
843 /*static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
844                                            clang::Expr *DstArr,
845                                            clang::Expr *SrcArr,
846                                            clang::SourceLocation StartLoc,
847                                            clang::SourceLocation Loc) {
848   clang::DeclContext *DC = NULL;
849   const clang::Type *BaseType = DstArr->getType().getTypePtr();
850   slangAssert(BaseType->isArrayType());
851 
852   int NumArrayElements = ArrayDim(BaseType);
853   // Actually extract out the base RS object type for use later
854   BaseType = BaseType->getArrayElementTypeNoTypeQual();
855 
856   clang::Stmt *StmtArray[2] = {NULL};
857   int StmtCtr = 0;
858 
859   if (NumArrayElements <= 0) {
860     return NULL;
861   }
862 
863   // Create helper variable for iterating through elements
864   clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
865   clang::VarDecl *IIVD =
866       clang::VarDecl::Create(C,
867                              DC,
868                              StartLoc,
869                              Loc,
870                              &II,
871                              C.IntTy,
872                              C.getTrivialTypeSourceInfo(C.IntTy),
873                              clang::SC_None,
874                              clang::SC_None);
875   clang::Decl *IID = (clang::Decl *)IIVD;
876 
877   clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
878   StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
879 
880   // Form the actual loop
881   // for (Init; Cond; Inc)
882   //   RSSetObjectCall;
883 
884   // Init -> "rsIntIter = 0"
885   clang::DeclRefExpr *RefrsIntIter =
886       clang::DeclRefExpr::Create(C,
887                                  clang::NestedNameSpecifierLoc(),
888                                  IIVD,
889                                  Loc,
890                                  C.IntTy,
891                                  clang::VK_RValue,
892                                  NULL);
893 
894   clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
895       llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
896 
897   clang::BinaryOperator *Init =
898       new(C) clang::BinaryOperator(RefrsIntIter,
899                                    Int0,
900                                    clang::BO_Assign,
901                                    C.IntTy,
902                                    clang::VK_RValue,
903                                    clang::OK_Ordinary,
904                                    Loc);
905 
906   // Cond -> "rsIntIter < NumArrayElements"
907   clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
908       llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
909 
910   clang::BinaryOperator *Cond =
911       new(C) clang::BinaryOperator(RefrsIntIter,
912                                    NumArrayElementsExpr,
913                                    clang::BO_LT,
914                                    C.IntTy,
915                                    clang::VK_RValue,
916                                    clang::OK_Ordinary,
917                                    Loc);
918 
919   // Inc -> "rsIntIter++"
920   clang::UnaryOperator *Inc =
921       new(C) clang::UnaryOperator(RefrsIntIter,
922                                   clang::UO_PostInc,
923                                   C.IntTy,
924                                   clang::VK_RValue,
925                                   clang::OK_Ordinary,
926                                   Loc);
927 
928   // Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
929   // Loop operates on individual array elements
930 
931   clang::Expr *DstArrPtr =
932       clang::ImplicitCastExpr::Create(C,
933           C.getPointerType(BaseType->getCanonicalTypeInternal()),
934           clang::CK_ArrayToPointerDecay,
935           DstArr,
936           NULL,
937           clang::VK_RValue);
938 
939   clang::Expr *DstArrPtrSubscript =
940       new(C) clang::ArraySubscriptExpr(DstArrPtr,
941                                        RefrsIntIter,
942                                        BaseType->getCanonicalTypeInternal(),
943                                        clang::VK_RValue,
944                                        clang::OK_Ordinary,
945                                        Loc);
946 
947   clang::Expr *SrcArrPtr =
948       clang::ImplicitCastExpr::Create(C,
949           C.getPointerType(BaseType->getCanonicalTypeInternal()),
950           clang::CK_ArrayToPointerDecay,
951           SrcArr,
952           NULL,
953           clang::VK_RValue);
954 
955   clang::Expr *SrcArrPtrSubscript =
956       new(C) clang::ArraySubscriptExpr(SrcArrPtr,
957                                        RefrsIntIter,
958                                        BaseType->getCanonicalTypeInternal(),
959                                        clang::VK_RValue,
960                                        clang::OK_Ordinary,
961                                        Loc);
962 
963   RSExportPrimitiveType::DataType DT =
964       RSExportPrimitiveType::GetRSSpecificType(BaseType);
965 
966   clang::Stmt *RSSetObjectCall = NULL;
967   if (BaseType->isArrayType()) {
968     RSSetObjectCall = CreateArrayRSSetObject(C, DstArrPtrSubscript,
969                                              SrcArrPtrSubscript,
970                                              StartLoc, Loc);
971   } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
972     RSSetObjectCall = CreateStructRSSetObject(C, DstArrPtrSubscript,
973                                               SrcArrPtrSubscript,
974                                               StartLoc, Loc);
975   } else {
976     RSSetObjectCall = CreateSingleRSSetObject(C, DstArrPtrSubscript,
977                                               SrcArrPtrSubscript,
978                                               StartLoc, Loc);
979   }
980 
981   clang::ForStmt *DestructorLoop =
982       new(C) clang::ForStmt(C,
983                             Init,
984                             Cond,
985                             NULL,  // no condVar
986                             Inc,
987                             RSSetObjectCall,
988                             Loc,
989                             Loc,
990                             Loc);
991 
992   StmtArray[StmtCtr++] = DestructorLoop;
993   slangAssert(StmtCtr == 2);
994 
995   clang::CompoundStmt *CS =
996       new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
997 
998   return CS;
999 } */
1000 
CreateStructRSSetObject(clang::ASTContext & C,clang::Expr * LHS,clang::Expr * RHS,clang::SourceLocation StartLoc,clang::SourceLocation Loc)1001 static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
1002                                             clang::Expr *LHS,
1003                                             clang::Expr *RHS,
1004                                             clang::SourceLocation StartLoc,
1005                                             clang::SourceLocation Loc) {
1006   clang::QualType QT = LHS->getType();
1007   const clang::Type *T = QT.getTypePtr();
1008   slangAssert(T->isStructureType());
1009   slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
1010 
1011   // Keep an extra slot for the original copy (memcpy)
1012   unsigned FieldsToSet = CountRSObjectTypes(C, T, Loc) + 1;
1013 
1014   unsigned StmtCount = 0;
1015   clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
1016   for (unsigned i = 0; i < FieldsToSet; i++) {
1017     StmtArray[i] = NULL;
1018   }
1019 
1020   clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
1021   RD = RD->getDefinition();
1022   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1023          FE = RD->field_end();
1024        FI != FE;
1025        FI++) {
1026     bool IsArrayType = false;
1027     clang::FieldDecl *FD = *FI;
1028     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1029     const clang::Type *OrigType = FT;
1030 
1031     if (!CountRSObjectTypes(C, FT, Loc)) {
1032       // Skip to next if we don't have any viable RS object types
1033       continue;
1034     }
1035 
1036     clang::DeclAccessPair FoundDecl =
1037         clang::DeclAccessPair::make(FD, clang::AS_none);
1038     clang::MemberExpr *DstMember =
1039         clang::MemberExpr::Create(C,
1040                                   LHS,
1041                                   false,
1042                                   clang::NestedNameSpecifierLoc(),
1043                                   clang::SourceLocation(),
1044                                   FD,
1045                                   FoundDecl,
1046                                   clang::DeclarationNameInfo(),
1047                                   NULL,
1048                                   OrigType->getCanonicalTypeInternal(),
1049                                   clang::VK_RValue,
1050                                   clang::OK_Ordinary);
1051 
1052     clang::MemberExpr *SrcMember =
1053         clang::MemberExpr::Create(C,
1054                                   RHS,
1055                                   false,
1056                                   clang::NestedNameSpecifierLoc(),
1057                                   clang::SourceLocation(),
1058                                   FD,
1059                                   FoundDecl,
1060                                   clang::DeclarationNameInfo(),
1061                                   NULL,
1062                                   OrigType->getCanonicalTypeInternal(),
1063                                   clang::VK_RValue,
1064                                   clang::OK_Ordinary);
1065 
1066     if (FT->isArrayType()) {
1067       FT = FT->getArrayElementTypeNoTypeQual();
1068       IsArrayType = true;
1069     }
1070 
1071     RSExportPrimitiveType::DataType DT =
1072         RSExportPrimitiveType::GetRSSpecificType(FT);
1073 
1074     if (IsArrayType) {
1075       clang::DiagnosticsEngine &DiagEngine = C.getDiagnostics();
1076       DiagEngine.Report(
1077         clang::FullSourceLoc(Loc, C.getSourceManager()),
1078         DiagEngine.getCustomDiagID(
1079           clang::DiagnosticsEngine::Error,
1080           "Arrays of RS object types within structures cannot be copied"));
1081       // TODO(srhines): Support setting arrays of RS objects
1082       // StmtArray[StmtCount++] =
1083       //    CreateArrayRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1084     } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
1085       StmtArray[StmtCount++] =
1086           CreateStructRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1087     } else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
1088       StmtArray[StmtCount++] =
1089           CreateSingleRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1090     } else {
1091       slangAssert(false);
1092     }
1093   }
1094 
1095   slangAssert(StmtCount < FieldsToSet);
1096 
1097   // We still need to actually do the overall struct copy. For simplicity,
1098   // we just do a straight-up assignment (which will still preserve all
1099   // the proper RS object reference counts).
1100   clang::BinaryOperator *CopyStruct =
1101       new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT,
1102                                    clang::VK_RValue, clang::OK_Ordinary, Loc,
1103                                    false);
1104   StmtArray[StmtCount++] = CopyStruct;
1105 
1106   clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
1107       C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
1108 
1109   delete [] StmtArray;
1110 
1111   return CS;
1112 }
1113 
1114 }  // namespace
1115 
ReplaceRSObjectAssignment(clang::BinaryOperator * AS)1116 void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
1117     clang::BinaryOperator *AS) {
1118 
1119   clang::QualType QT = AS->getType();
1120 
1121   clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1122       RSExportPrimitiveType::DataTypeRSFont)->getASTContext();
1123 
1124   clang::SourceLocation Loc = AS->getExprLoc();
1125   clang::SourceLocation StartLoc = AS->getLHS()->getExprLoc();
1126   clang::Stmt *UpdatedStmt = NULL;
1127 
1128   if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
1129     // By definition, this is a struct assignment if we get here
1130     UpdatedStmt =
1131         CreateStructRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1132   } else {
1133     UpdatedStmt =
1134         CreateSingleRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1135   }
1136 
1137   RSASTReplace R(C);
1138   R.ReplaceStmt(mCS, AS, UpdatedStmt);
1139   return;
1140 }
1141 
AppendRSObjectInit(clang::VarDecl * VD,clang::DeclStmt * DS,RSExportPrimitiveType::DataType DT,clang::Expr * InitExpr)1142 void RSObjectRefCount::Scope::AppendRSObjectInit(
1143     clang::VarDecl *VD,
1144     clang::DeclStmt *DS,
1145     RSExportPrimitiveType::DataType DT,
1146     clang::Expr *InitExpr) {
1147   slangAssert(VD);
1148 
1149   if (!InitExpr) {
1150     return;
1151   }
1152 
1153   clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1154       RSExportPrimitiveType::DataTypeRSFont)->getASTContext();
1155   clang::SourceLocation Loc = RSObjectRefCount::GetRSSetObjectFD(
1156       RSExportPrimitiveType::DataTypeRSFont)->getLocation();
1157   clang::SourceLocation StartLoc = RSObjectRefCount::GetRSSetObjectFD(
1158       RSExportPrimitiveType::DataTypeRSFont)->getInnerLocStart();
1159 
1160   if (DT == RSExportPrimitiveType::DataTypeIsStruct) {
1161     const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1162     clang::DeclRefExpr *RefRSVar =
1163         clang::DeclRefExpr::Create(C,
1164                                    clang::NestedNameSpecifierLoc(),
1165                                    clang::SourceLocation(),
1166                                    VD,
1167                                    false,
1168                                    Loc,
1169                                    T->getCanonicalTypeInternal(),
1170                                    clang::VK_RValue,
1171                                    NULL);
1172 
1173     clang::Stmt *RSSetObjectOps =
1174         CreateStructRSSetObject(C, RefRSVar, InitExpr, StartLoc, Loc);
1175 
1176     std::list<clang::Stmt*> StmtList;
1177     StmtList.push_back(RSSetObjectOps);
1178     AppendAfterStmt(C, mCS, DS, StmtList);
1179     return;
1180   }
1181 
1182   clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
1183   slangAssert((SetObjectFD != NULL) &&
1184               "rsSetObject doesn't cover all RS object types");
1185 
1186   clang::QualType SetObjectFDType = SetObjectFD->getType();
1187   clang::QualType SetObjectFDArgType[2];
1188   SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
1189   SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
1190 
1191   clang::Expr *RefRSSetObjectFD =
1192       clang::DeclRefExpr::Create(C,
1193                                  clang::NestedNameSpecifierLoc(),
1194                                  clang::SourceLocation(),
1195                                  SetObjectFD,
1196                                  false,
1197                                  Loc,
1198                                  SetObjectFDType,
1199                                  clang::VK_RValue,
1200                                  NULL);
1201 
1202   clang::Expr *RSSetObjectFP =
1203       clang::ImplicitCastExpr::Create(C,
1204                                       C.getPointerType(SetObjectFDType),
1205                                       clang::CK_FunctionToPointerDecay,
1206                                       RefRSSetObjectFD,
1207                                       NULL,
1208                                       clang::VK_RValue);
1209 
1210   const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1211   clang::DeclRefExpr *RefRSVar =
1212       clang::DeclRefExpr::Create(C,
1213                                  clang::NestedNameSpecifierLoc(),
1214                                  clang::SourceLocation(),
1215                                  VD,
1216                                  false,
1217                                  Loc,
1218                                  T->getCanonicalTypeInternal(),
1219                                  clang::VK_RValue,
1220                                  NULL);
1221 
1222   llvm::SmallVector<clang::Expr*, 2> ArgList;
1223   ArgList.push_back(new(C) clang::UnaryOperator(RefRSVar,
1224                                                 clang::UO_AddrOf,
1225                                                 SetObjectFDArgType[0],
1226                                                 clang::VK_RValue,
1227                                                 clang::OK_Ordinary,
1228                                                 Loc));
1229   ArgList.push_back(InitExpr);
1230 
1231   clang::CallExpr *RSSetObjectCall =
1232       new(C) clang::CallExpr(C,
1233                              RSSetObjectFP,
1234                              ArgList,
1235                              SetObjectFD->getCallResultType(),
1236                              clang::VK_RValue,
1237                              Loc);
1238 
1239   std::list<clang::Stmt*> StmtList;
1240   StmtList.push_back(RSSetObjectCall);
1241   AppendAfterStmt(C, mCS, DS, StmtList);
1242 
1243   return;
1244 }
1245 
InsertLocalVarDestructors()1246 void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
1247   for (std::list<clang::VarDecl*>::const_iterator I = mRSO.begin(),
1248           E = mRSO.end();
1249         I != E;
1250         I++) {
1251     clang::VarDecl *VD = *I;
1252     clang::Stmt *RSClearObjectCall = ClearRSObject(VD, VD->getDeclContext());
1253     if (RSClearObjectCall) {
1254       DestructorVisitor DV((*mRSO.begin())->getASTContext(),
1255                            mCS,
1256                            RSClearObjectCall,
1257                            VD->getSourceRange().getBegin());
1258       DV.Visit(mCS);
1259       DV.InsertDestructors();
1260     }
1261   }
1262   return;
1263 }
1264 
ClearRSObject(clang::VarDecl * VD,clang::DeclContext * DC)1265 clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(
1266     clang::VarDecl *VD,
1267     clang::DeclContext *DC) {
1268   slangAssert(VD);
1269   clang::ASTContext &C = VD->getASTContext();
1270   clang::SourceLocation Loc = VD->getLocation();
1271   clang::SourceLocation StartLoc = VD->getInnerLocStart();
1272   const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1273 
1274   // Reference expr to target RS object variable
1275   clang::DeclRefExpr *RefRSVar =
1276       clang::DeclRefExpr::Create(C,
1277                                  clang::NestedNameSpecifierLoc(),
1278                                  clang::SourceLocation(),
1279                                  VD,
1280                                  false,
1281                                  Loc,
1282                                  T->getCanonicalTypeInternal(),
1283                                  clang::VK_RValue,
1284                                  NULL);
1285 
1286   if (T->isArrayType()) {
1287     return ClearArrayRSObject(C, DC, RefRSVar, StartLoc, Loc);
1288   }
1289 
1290   RSExportPrimitiveType::DataType DT =
1291       RSExportPrimitiveType::GetRSSpecificType(T);
1292 
1293   if (DT == RSExportPrimitiveType::DataTypeUnknown ||
1294       DT == RSExportPrimitiveType::DataTypeIsStruct) {
1295     return ClearStructRSObject(C, DC, RefRSVar, StartLoc, Loc);
1296   }
1297 
1298   slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
1299               "Should be RS object");
1300 
1301   return ClearSingleRSObject(C, RefRSVar, Loc);
1302 }
1303 
InitializeRSObject(clang::VarDecl * VD,RSExportPrimitiveType::DataType * DT,clang::Expr ** InitExpr)1304 bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
1305                                           RSExportPrimitiveType::DataType *DT,
1306                                           clang::Expr **InitExpr) {
1307   slangAssert(VD && DT && InitExpr);
1308   const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1309 
1310   // Loop through array types to get to base type
1311   while (T && T->isArrayType()) {
1312     T = T->getArrayElementTypeNoTypeQual();
1313   }
1314 
1315   bool DataTypeIsStructWithRSObject = false;
1316   *DT = RSExportPrimitiveType::GetRSSpecificType(T);
1317 
1318   if (*DT == RSExportPrimitiveType::DataTypeUnknown) {
1319     if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
1320       *DT = RSExportPrimitiveType::DataTypeIsStruct;
1321       DataTypeIsStructWithRSObject = true;
1322     } else {
1323       return false;
1324     }
1325   }
1326 
1327   bool DataTypeIsRSObject = false;
1328   if (DataTypeIsStructWithRSObject) {
1329     DataTypeIsRSObject = true;
1330   } else {
1331     DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
1332   }
1333   *InitExpr = VD->getInit();
1334 
1335   if (!DataTypeIsRSObject && *InitExpr) {
1336     // If we already have an initializer for a matrix type, we are done.
1337     return DataTypeIsRSObject;
1338   }
1339 
1340   clang::Expr *ZeroInitializer =
1341       CreateZeroInitializerForRSSpecificType(*DT,
1342                                              VD->getASTContext(),
1343                                              VD->getLocation());
1344 
1345   if (ZeroInitializer) {
1346     ZeroInitializer->setType(T->getCanonicalTypeInternal());
1347     VD->setInit(ZeroInitializer);
1348   }
1349 
1350   return DataTypeIsRSObject;
1351 }
1352 
CreateZeroInitializerForRSSpecificType(RSExportPrimitiveType::DataType DT,clang::ASTContext & C,const clang::SourceLocation & Loc)1353 clang::Expr *RSObjectRefCount::CreateZeroInitializerForRSSpecificType(
1354     RSExportPrimitiveType::DataType DT,
1355     clang::ASTContext &C,
1356     const clang::SourceLocation &Loc) {
1357   clang::Expr *Res = NULL;
1358   switch (DT) {
1359     case RSExportPrimitiveType::DataTypeIsStruct:
1360     case RSExportPrimitiveType::DataTypeRSElement:
1361     case RSExportPrimitiveType::DataTypeRSType:
1362     case RSExportPrimitiveType::DataTypeRSAllocation:
1363     case RSExportPrimitiveType::DataTypeRSSampler:
1364     case RSExportPrimitiveType::DataTypeRSScript:
1365     case RSExportPrimitiveType::DataTypeRSMesh:
1366     case RSExportPrimitiveType::DataTypeRSPath:
1367     case RSExportPrimitiveType::DataTypeRSProgramFragment:
1368     case RSExportPrimitiveType::DataTypeRSProgramVertex:
1369     case RSExportPrimitiveType::DataTypeRSProgramRaster:
1370     case RSExportPrimitiveType::DataTypeRSProgramStore:
1371     case RSExportPrimitiveType::DataTypeRSFont: {
1372       //    (ImplicitCastExpr 'nullptr_t'
1373       //      (IntegerLiteral 0)))
1374       llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
1375       clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
1376       clang::Expr *CastToNull =
1377           clang::ImplicitCastExpr::Create(C,
1378                                           C.NullPtrTy,
1379                                           clang::CK_IntegralToPointer,
1380                                           Int0,
1381                                           NULL,
1382                                           clang::VK_RValue);
1383 
1384       llvm::SmallVector<clang::Expr*, 1>InitList;
1385       InitList.push_back(CastToNull);
1386 
1387       Res = new(C) clang::InitListExpr(C, Loc, InitList, Loc);
1388       break;
1389     }
1390     case RSExportPrimitiveType::DataTypeRSMatrix2x2:
1391     case RSExportPrimitiveType::DataTypeRSMatrix3x3:
1392     case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
1393       // RS matrix is not completely an RS object. They hold data by themselves.
1394       // (InitListExpr rs_matrix2x2
1395       //   (InitListExpr float[4]
1396       //     (FloatingLiteral 0)
1397       //     (FloatingLiteral 0)
1398       //     (FloatingLiteral 0)
1399       //     (FloatingLiteral 0)))
1400       clang::QualType FloatTy = C.FloatTy;
1401       // Constructor sets value to 0.0f by default
1402       llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
1403       clang::FloatingLiteral *Float0Val =
1404           clang::FloatingLiteral::Create(C,
1405                                          Val,
1406                                          /* isExact = */true,
1407                                          FloatTy,
1408                                          Loc);
1409 
1410       unsigned N = 0;
1411       if (DT == RSExportPrimitiveType::DataTypeRSMatrix2x2)
1412         N = 2;
1413       else if (DT == RSExportPrimitiveType::DataTypeRSMatrix3x3)
1414         N = 3;
1415       else if (DT == RSExportPrimitiveType::DataTypeRSMatrix4x4)
1416         N = 4;
1417       unsigned N_2 = N * N;
1418 
1419       // Assume we are going to be allocating 16 elements, since 4x4 is max.
1420       llvm::SmallVector<clang::Expr*, 16> InitVals;
1421       for (unsigned i = 0; i < N_2; i++)
1422         InitVals.push_back(Float0Val);
1423       clang::Expr *InitExpr =
1424           new(C) clang::InitListExpr(C, Loc, InitVals, Loc);
1425       InitExpr->setType(C.getConstantArrayType(FloatTy,
1426                                                llvm::APInt(32, N_2),
1427                                                clang::ArrayType::Normal,
1428                                                /* EltTypeQuals = */0));
1429       llvm::SmallVector<clang::Expr*, 1> InitExprVec;
1430       InitExprVec.push_back(InitExpr);
1431 
1432       Res = new(C) clang::InitListExpr(C, Loc, InitExprVec, Loc);
1433       break;
1434     }
1435     case RSExportPrimitiveType::DataTypeUnknown:
1436     case RSExportPrimitiveType::DataTypeFloat16:
1437     case RSExportPrimitiveType::DataTypeFloat32:
1438     case RSExportPrimitiveType::DataTypeFloat64:
1439     case RSExportPrimitiveType::DataTypeSigned8:
1440     case RSExportPrimitiveType::DataTypeSigned16:
1441     case RSExportPrimitiveType::DataTypeSigned32:
1442     case RSExportPrimitiveType::DataTypeSigned64:
1443     case RSExportPrimitiveType::DataTypeUnsigned8:
1444     case RSExportPrimitiveType::DataTypeUnsigned16:
1445     case RSExportPrimitiveType::DataTypeUnsigned32:
1446     case RSExportPrimitiveType::DataTypeUnsigned64:
1447     case RSExportPrimitiveType::DataTypeBoolean:
1448     case RSExportPrimitiveType::DataTypeUnsigned565:
1449     case RSExportPrimitiveType::DataTypeUnsigned5551:
1450     case RSExportPrimitiveType::DataTypeUnsigned4444:
1451     case RSExportPrimitiveType::DataTypeMax: {
1452       slangAssert(false && "Not RS object type!");
1453     }
1454     // No default case will enable compiler detecting the missing cases
1455   }
1456 
1457   return Res;
1458 }
1459 
VisitDeclStmt(clang::DeclStmt * DS)1460 void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
1461   for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
1462        I != E;
1463        I++) {
1464     clang::Decl *D = *I;
1465     if (D->getKind() == clang::Decl::Var) {
1466       clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
1467       RSExportPrimitiveType::DataType DT =
1468           RSExportPrimitiveType::DataTypeUnknown;
1469       clang::Expr *InitExpr = NULL;
1470       if (InitializeRSObject(VD, &DT, &InitExpr)) {
1471         // We need to zero-init all RS object types (including matrices), ...
1472         getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
1473         // ... but, only add to the list of RS objects if we have some
1474         // non-matrix RS object fields.
1475         if (CountRSObjectTypes(mCtx, VD->getType().getTypePtr(),
1476                                VD->getLocation())) {
1477           getCurrentScope()->addRSObject(VD);
1478         }
1479       }
1480     }
1481   }
1482   return;
1483 }
1484 
VisitCompoundStmt(clang::CompoundStmt * CS)1485 void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
1486   if (!CS->body_empty()) {
1487     // Push a new scope
1488     Scope *S = new Scope(CS);
1489     mScopeStack.push(S);
1490 
1491     VisitStmt(CS);
1492 
1493     // Destroy the scope
1494     slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
1495     S->InsertLocalVarDestructors();
1496     mScopeStack.pop();
1497     delete S;
1498   }
1499   return;
1500 }
1501 
VisitBinAssign(clang::BinaryOperator * AS)1502 void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
1503   clang::QualType QT = AS->getType();
1504 
1505   if (CountRSObjectTypes(mCtx, QT.getTypePtr(), AS->getExprLoc())) {
1506     getCurrentScope()->ReplaceRSObjectAssignment(AS);
1507   }
1508 
1509   return;
1510 }
1511 
VisitStmt(clang::Stmt * S)1512 void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
1513   for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
1514        I != E;
1515        I++) {
1516     if (clang::Stmt *Child = *I) {
1517       Visit(Child);
1518     }
1519   }
1520   return;
1521 }
1522 
1523 // This function walks the list of global variables and (potentially) creates
1524 // a single global static destructor function that properly decrements
1525 // reference counts on the contained RS object types.
CreateStaticGlobalDtor()1526 clang::FunctionDecl *RSObjectRefCount::CreateStaticGlobalDtor() {
1527   Init();
1528 
1529   clang::DeclContext *DC = mCtx.getTranslationUnitDecl();
1530   clang::SourceLocation loc;
1531 
1532   llvm::StringRef SR(".rs.dtor");
1533   clang::IdentifierInfo &II = mCtx.Idents.get(SR);
1534   clang::DeclarationName N(&II);
1535   clang::FunctionProtoType::ExtProtoInfo EPI;
1536   clang::QualType T = mCtx.getFunctionType(mCtx.VoidTy,
1537       llvm::ArrayRef<clang::QualType>(), EPI);
1538   clang::FunctionDecl *FD = NULL;
1539 
1540   // Generate rsClearObject() call chains for every global variable
1541   // (whether static or extern).
1542   std::list<clang::Stmt *> StmtList;
1543   for (clang::DeclContext::decl_iterator I = DC->decls_begin(),
1544           E = DC->decls_end(); I != E; I++) {
1545     clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*I);
1546     if (VD) {
1547       if (CountRSObjectTypes(mCtx, VD->getType().getTypePtr(), loc)) {
1548         if (!FD) {
1549           // Only create FD if we are going to use it.
1550           FD = clang::FunctionDecl::Create(mCtx, DC, loc, loc, N, T, NULL);
1551         }
1552         // Make sure to create any helpers within the function's DeclContext,
1553         // not the one associated with the global translation unit.
1554         clang::Stmt *RSClearObjectCall = Scope::ClearRSObject(VD, FD);
1555         StmtList.push_back(RSClearObjectCall);
1556       }
1557     }
1558   }
1559 
1560   // Nothing needs to be destroyed, so don't emit a dtor.
1561   if (StmtList.empty()) {
1562     return NULL;
1563   }
1564 
1565   clang::CompoundStmt *CS = BuildCompoundStmt(mCtx, StmtList, loc);
1566 
1567   FD->setBody(CS);
1568 
1569   return FD;
1570 }
1571 
1572 }  // namespace slang
1573