• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2010, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "slang_rs_object_ref_count.h"

#include <list>
#include <stack>
#include <vector>

#include "clang/AST/DeclGroup.h"
#include "clang/AST/Expr.h"
#include "clang/AST/NestedNameSpecifier.h"
#include "clang/AST/OperationKinds.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/StmtVisitor.h"

#include "slang_assert.h"
#include "slang_rs.h"
#include "slang_rs_ast_replace.h"
#include "slang_rs_export_type.h"
32 
33 namespace slang {
34 
35 clang::FunctionDecl *RSObjectRefCount::
36     RSSetObjectFD[RSExportPrimitiveType::LastRSObjectType -
37                   RSExportPrimitiveType::FirstRSObjectType + 1];
38 clang::FunctionDecl *RSObjectRefCount::
39     RSClearObjectFD[RSExportPrimitiveType::LastRSObjectType -
40                     RSExportPrimitiveType::FirstRSObjectType + 1];
41 
GetRSRefCountingFunctions(clang::ASTContext & C)42 void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
43   for (unsigned i = 0;
44        i < (sizeof(RSClearObjectFD) / sizeof(clang::FunctionDecl*));
45        i++) {
46     RSSetObjectFD[i] = NULL;
47     RSClearObjectFD[i] = NULL;
48   }
49 
50   clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
51 
52   for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
53           E = TUDecl->decls_end(); I != E; I++) {
54     if ((I->getKind() >= clang::Decl::firstFunction) &&
55         (I->getKind() <= clang::Decl::lastFunction)) {
56       clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
57 
58       // points to RSSetObjectFD or RSClearObjectFD
59       clang::FunctionDecl **RSObjectFD;
60 
61       if (FD->getName() == "rsSetObject") {
62         slangAssert((FD->getNumParams() == 2) &&
63                     "Invalid rsSetObject function prototype (# params)");
64         RSObjectFD = RSSetObjectFD;
65       } else if (FD->getName() == "rsClearObject") {
66         slangAssert((FD->getNumParams() == 1) &&
67                     "Invalid rsClearObject function prototype (# params)");
68         RSObjectFD = RSClearObjectFD;
69       } else {
70         continue;
71       }
72 
73       const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
74       clang::QualType PVT = PVD->getOriginalType();
75       // The first parameter must be a pointer like rs_allocation*
76       slangAssert(PVT->isPointerType() &&
77           "Invalid rs{Set,Clear}Object function prototype (pointer param)");
78 
79       // The rs object type passed to the FD
80       clang::QualType RST = PVT->getPointeeType();
81       RSExportPrimitiveType::DataType DT =
82           RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
83       slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
84              && "must be RS object type");
85 
86       RSObjectFD[(DT - RSExportPrimitiveType::FirstRSObjectType)] = FD;
87     }
88   }
89 }
90 
91 namespace {
92 
93 // This function constructs a new CompoundStmt from the input StmtList.
BuildCompoundStmt(clang::ASTContext & C,std::list<clang::Stmt * > & StmtList,clang::SourceLocation Loc)94 static clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
95       std::list<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
96   unsigned NewStmtCount = StmtList.size();
97   unsigned CompoundStmtCount = 0;
98 
99   clang::Stmt **CompoundStmtList;
100   CompoundStmtList = new clang::Stmt*[NewStmtCount];
101 
102   std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
103   std::list<clang::Stmt*>::const_iterator E = StmtList.end();
104   for ( ; I != E; I++) {
105     CompoundStmtList[CompoundStmtCount++] = *I;
106   }
107   slangAssert(CompoundStmtCount == NewStmtCount);
108 
109   clang::CompoundStmt *CS = new(C) clang::CompoundStmt(C,
110                                                        CompoundStmtList,
111                                                        CompoundStmtCount,
112                                                        Loc,
113                                                        Loc);
114 
115   delete [] CompoundStmtList;
116 
117   return CS;
118 }
119 
AppendAfterStmt(clang::ASTContext & C,clang::CompoundStmt * CS,clang::Stmt * S,std::list<clang::Stmt * > & StmtList)120 static void AppendAfterStmt(clang::ASTContext &C,
121                             clang::CompoundStmt *CS,
122                             clang::Stmt *S,
123                             std::list<clang::Stmt*> &StmtList) {
124   slangAssert(CS);
125   clang::CompoundStmt::body_iterator bI = CS->body_begin();
126   clang::CompoundStmt::body_iterator bE = CS->body_end();
127   clang::Stmt **UpdatedStmtList =
128       new clang::Stmt*[CS->size() + StmtList.size()];
129 
130   unsigned UpdatedStmtCount = 0;
131   unsigned Once = 0;
132   for ( ; bI != bE; bI++) {
133     if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
134       // If we come across a return here, we don't have anything we can
135       // reasonably replace. We should have already inserted our destructor
136       // code in the proper spot, so we just clean up and return.
137       delete [] UpdatedStmtList;
138 
139       return;
140     }
141 
142     UpdatedStmtList[UpdatedStmtCount++] = *bI;
143 
144     if ((*bI == S) && !Once) {
145       Once++;
146       std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
147       std::list<clang::Stmt*>::const_iterator E = StmtList.end();
148       for ( ; I != E; I++) {
149         UpdatedStmtList[UpdatedStmtCount++] = *I;
150       }
151     }
152   }
153   slangAssert(Once <= 1);
154 
155   // When S is NULL, we are appending to the end of the CompoundStmt.
156   if (!S) {
157     slangAssert(Once == 0);
158     std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
159     std::list<clang::Stmt*>::const_iterator E = StmtList.end();
160     for ( ; I != E; I++) {
161       UpdatedStmtList[UpdatedStmtCount++] = *I;
162     }
163   }
164 
165   CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
166 
167   delete [] UpdatedStmtList;
168 
169   return;
170 }
171 
172 // This class visits a compound statement and inserts DtorStmt
173 // in proper locations. This includes inserting it before any
174 // return statement in any sub-block, at the end of the logical enclosing
175 // scope (compound statement), and/or before any break/continue statement that
176 // would resume outside the declared scope. We will not handle the case for
177 // goto statements that leave a local scope.
178 //
179 // To accomplish these goals, it collects a list of sub-Stmt's that
180 // correspond to scope exit points. It then uses an RSASTReplace visitor to
181 // transform the AST, inserting appropriate destructors before each of those
182 // sub-Stmt's (and also before the exit of the outermost containing Stmt for
183 // the scope).
184 class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
185  private:
186   clang::ASTContext &mCtx;
187 
188   // The loop depth of the currently visited node.
189   int mLoopDepth;
190 
191   // The switch statement depth of the currently visited node.
192   // Note that this is tracked separately from the loop depth because
193   // SwitchStmt-contained ContinueStmt's should have destructors for the
194   // corresponding loop scope.
195   int mSwitchDepth;
196 
197   // The outermost statement block that we are currently visiting.
198   // This should always be a CompoundStmt.
199   clang::Stmt *mOuterStmt;
200 
201   // The destructor to execute for this scope/variable.
202   clang::Stmt* mDtorStmt;
203 
204   // The stack of statements which should be replaced by a compound statement
205   // containing the new destructor call followed by the original Stmt.
206   std::stack<clang::Stmt*> mReplaceStmtStack;
207 
208   // The source location for the variable declaration that we are trying to
209   // insert destructors for. Note that InsertDestructors() will not generate
210   // destructor calls for source locations that occur lexically before this
211   // location.
212   clang::SourceLocation mVarLoc;
213 
214  public:
215   DestructorVisitor(clang::ASTContext &C,
216                     clang::Stmt* OuterStmt,
217                     clang::Stmt* DtorStmt,
218                     clang::SourceLocation VarLoc);
219 
220   // This code walks the collected list of Stmts to replace and actually does
221   // the replacement. It also finishes up by appending the destructor to the
222   // current outermost CompoundStmt.
InsertDestructors()223   void InsertDestructors() {
224     clang::Stmt *S = NULL;
225     clang::SourceManager &SM = mCtx.getSourceManager();
226     std::list<clang::Stmt *> StmtList;
227     StmtList.push_back(mDtorStmt);
228 
229     while (!mReplaceStmtStack.empty()) {
230       S = mReplaceStmtStack.top();
231       mReplaceStmtStack.pop();
232 
233       // Skip all source locations that occur before the variable's
234       // declaration, since it won't have been initialized yet.
235       if (SM.isBeforeInTranslationUnit(S->getLocStart(), mVarLoc)) {
236         continue;
237       }
238 
239       StmtList.push_back(S);
240       clang::CompoundStmt *CS =
241           BuildCompoundStmt(mCtx, StmtList, S->getLocEnd());
242       StmtList.pop_back();
243 
244       RSASTReplace R(mCtx);
245       R.ReplaceStmt(mOuterStmt, S, CS);
246     }
247     clang::CompoundStmt *CS =
248       llvm::dyn_cast<clang::CompoundStmt>(mOuterStmt);
249     slangAssert(CS);
250     AppendAfterStmt(mCtx, CS, NULL, StmtList);
251   }
252 
253   void VisitStmt(clang::Stmt *S);
254   void VisitCompoundStmt(clang::CompoundStmt *CS);
255 
256   void VisitBreakStmt(clang::BreakStmt *BS);
257   void VisitCaseStmt(clang::CaseStmt *CS);
258   void VisitContinueStmt(clang::ContinueStmt *CS);
259   void VisitDefaultStmt(clang::DefaultStmt *DS);
260   void VisitDoStmt(clang::DoStmt *DS);
261   void VisitForStmt(clang::ForStmt *FS);
262   void VisitIfStmt(clang::IfStmt *IS);
263   void VisitReturnStmt(clang::ReturnStmt *RS);
264   void VisitSwitchCase(clang::SwitchCase *SC);
265   void VisitSwitchStmt(clang::SwitchStmt *SS);
266   void VisitWhileStmt(clang::WhileStmt *WS);
267 };
268 
DestructorVisitor(clang::ASTContext & C,clang::Stmt * OuterStmt,clang::Stmt * DtorStmt,clang::SourceLocation VarLoc)269 DestructorVisitor::DestructorVisitor(clang::ASTContext &C,
270                          clang::Stmt *OuterStmt,
271                          clang::Stmt *DtorStmt,
272                          clang::SourceLocation VarLoc)
273   : mCtx(C),
274     mLoopDepth(0),
275     mSwitchDepth(0),
276     mOuterStmt(OuterStmt),
277     mDtorStmt(DtorStmt),
278     mVarLoc(VarLoc) {
279   return;
280 }
281 
VisitStmt(clang::Stmt * S)282 void DestructorVisitor::VisitStmt(clang::Stmt *S) {
283   for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
284        I != E;
285        I++) {
286     if (clang::Stmt *Child = *I) {
287       Visit(Child);
288     }
289   }
290   return;
291 }
292 
VisitCompoundStmt(clang::CompoundStmt * CS)293 void DestructorVisitor::VisitCompoundStmt(clang::CompoundStmt *CS) {
294   VisitStmt(CS);
295   return;
296 }
297 
VisitBreakStmt(clang::BreakStmt * BS)298 void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
299   VisitStmt(BS);
300   if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
301     mReplaceStmtStack.push(BS);
302   }
303   return;
304 }
305 
VisitCaseStmt(clang::CaseStmt * CS)306 void DestructorVisitor::VisitCaseStmt(clang::CaseStmt *CS) {
307   VisitStmt(CS);
308   return;
309 }
310 
VisitContinueStmt(clang::ContinueStmt * CS)311 void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
312   VisitStmt(CS);
313   if (mLoopDepth == 0) {
314     // Switch statements can have nested continues.
315     mReplaceStmtStack.push(CS);
316   }
317   return;
318 }
319 
VisitDefaultStmt(clang::DefaultStmt * DS)320 void DestructorVisitor::VisitDefaultStmt(clang::DefaultStmt *DS) {
321   VisitStmt(DS);
322   return;
323 }
324 
VisitDoStmt(clang::DoStmt * DS)325 void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
326   mLoopDepth++;
327   VisitStmt(DS);
328   mLoopDepth--;
329   return;
330 }
331 
VisitForStmt(clang::ForStmt * FS)332 void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
333   mLoopDepth++;
334   VisitStmt(FS);
335   mLoopDepth--;
336   return;
337 }
338 
VisitIfStmt(clang::IfStmt * IS)339 void DestructorVisitor::VisitIfStmt(clang::IfStmt *IS) {
340   VisitStmt(IS);
341   return;
342 }
343 
VisitReturnStmt(clang::ReturnStmt * RS)344 void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
345   mReplaceStmtStack.push(RS);
346   return;
347 }
348 
VisitSwitchCase(clang::SwitchCase * SC)349 void DestructorVisitor::VisitSwitchCase(clang::SwitchCase *SC) {
350   slangAssert(false && "Both case and default have specialized handlers");
351   VisitStmt(SC);
352   return;
353 }
354 
VisitSwitchStmt(clang::SwitchStmt * SS)355 void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
356   mSwitchDepth++;
357   VisitStmt(SS);
358   mSwitchDepth--;
359   return;
360 }
361 
VisitWhileStmt(clang::WhileStmt * WS)362 void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
363   mLoopDepth++;
364   VisitStmt(WS);
365   mLoopDepth--;
366   return;
367 }
368 
ClearSingleRSObject(clang::ASTContext & C,clang::Expr * RefRSVar,clang::SourceLocation Loc)369 clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
370                                  clang::Expr *RefRSVar,
371                                  clang::SourceLocation Loc) {
372   slangAssert(RefRSVar);
373   const clang::Type *T = RefRSVar->getType().getTypePtr();
374   slangAssert(!T->isArrayType() &&
375               "Should not be destroying arrays with this function");
376 
377   clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
378   slangAssert((ClearObjectFD != NULL) &&
379               "rsClearObject doesn't cover all RS object types");
380 
381   clang::QualType ClearObjectFDType = ClearObjectFD->getType();
382   clang::QualType ClearObjectFDArgType =
383       ClearObjectFD->getParamDecl(0)->getOriginalType();
384 
385   // Example destructor for "rs_font localFont;"
386   //
387   // (CallExpr 'void'
388   //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
389   //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
390   //   (UnaryOperator 'rs_font *' prefix '&'
391   //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
392 
393   // Get address of targeted RS object
394   clang::Expr *AddrRefRSVar =
395       new(C) clang::UnaryOperator(RefRSVar,
396                                   clang::UO_AddrOf,
397                                   ClearObjectFDArgType,
398                                   clang::VK_RValue,
399                                   clang::OK_Ordinary,
400                                   Loc);
401 
402   clang::Expr *RefRSClearObjectFD =
403       clang::DeclRefExpr::Create(C,
404                                  clang::NestedNameSpecifierLoc(),
405                                  clang::SourceLocation(),
406                                  ClearObjectFD,
407                                  false,
408                                  ClearObjectFD->getLocation(),
409                                  ClearObjectFDType,
410                                  clang::VK_RValue,
411                                  NULL);
412 
413   clang::Expr *RSClearObjectFP =
414       clang::ImplicitCastExpr::Create(C,
415                                       C.getPointerType(ClearObjectFDType),
416                                       clang::CK_FunctionToPointerDecay,
417                                       RefRSClearObjectFD,
418                                       NULL,
419                                       clang::VK_RValue);
420 
421   clang::CallExpr *RSClearObjectCall =
422       new(C) clang::CallExpr(C,
423                              RSClearObjectFP,
424                              &AddrRefRSVar,
425                              1,
426                              ClearObjectFD->getCallResultType(),
427                              clang::VK_RValue,
428                              Loc);
429 
430   return RSClearObjectCall;
431 }
432 
ArrayDim(const clang::Type * T)433 static int ArrayDim(const clang::Type *T) {
434   if (!T || !T->isArrayType()) {
435     return 0;
436   }
437 
438   const clang::ConstantArrayType *CAT =
439     static_cast<const clang::ConstantArrayType *>(T);
440   return static_cast<int>(CAT->getSize().getSExtValue());
441 }
442 
443 static clang::Stmt *ClearStructRSObject(
444     clang::ASTContext &C,
445     clang::DeclContext *DC,
446     clang::Expr *RefRSStruct,
447     clang::SourceLocation StartLoc,
448     clang::SourceLocation Loc);
449 
ClearArrayRSObject(clang::ASTContext & C,clang::DeclContext * DC,clang::Expr * RefRSArr,clang::SourceLocation StartLoc,clang::SourceLocation Loc)450 static clang::Stmt *ClearArrayRSObject(
451     clang::ASTContext &C,
452     clang::DeclContext *DC,
453     clang::Expr *RefRSArr,
454     clang::SourceLocation StartLoc,
455     clang::SourceLocation Loc) {
456   const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
457   slangAssert(BaseType->isArrayType());
458 
459   int NumArrayElements = ArrayDim(BaseType);
460   // Actually extract out the base RS object type for use later
461   BaseType = BaseType->getArrayElementTypeNoTypeQual();
462 
463   clang::Stmt *StmtArray[2] = {NULL};
464   int StmtCtr = 0;
465 
466   if (NumArrayElements <= 0) {
467     return NULL;
468   }
469 
470   // Example destructor loop for "rs_font fontArr[10];"
471   //
472   // (CompoundStmt
473   //   (DeclStmt "int rsIntIter")
474   //   (ForStmt
475   //     (BinaryOperator 'int' '='
476   //       (DeclRefExpr 'int' Var='rsIntIter')
477   //       (IntegerLiteral 'int' 0))
478   //     (BinaryOperator 'int' '<'
479   //       (DeclRefExpr 'int' Var='rsIntIter')
480   //       (IntegerLiteral 'int' 10)
481   //     NULL << CondVar >>
482   //     (UnaryOperator 'int' postfix '++'
483   //       (DeclRefExpr 'int' Var='rsIntIter'))
484   //     (CallExpr 'void'
485   //       (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
486   //         (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
487   //       (UnaryOperator 'rs_font *' prefix '&'
488   //         (ArraySubscriptExpr 'rs_font':'rs_font'
489   //           (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
490   //             (DeclRefExpr 'rs_font [10]' Var='fontArr'))
491   //           (DeclRefExpr 'int' Var='rsIntIter')))))))
492 
493   // Create helper variable for iterating through elements
494   clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
495   clang::VarDecl *IIVD =
496       clang::VarDecl::Create(C,
497                              DC,
498                              StartLoc,
499                              Loc,
500                              &II,
501                              C.IntTy,
502                              C.getTrivialTypeSourceInfo(C.IntTy),
503                              clang::SC_None,
504                              clang::SC_None);
505   clang::Decl *IID = (clang::Decl *)IIVD;
506 
507   clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
508   StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
509 
510   // Form the actual destructor loop
511   // for (Init; Cond; Inc)
512   //   RSClearObjectCall;
513 
514   // Init -> "rsIntIter = 0"
515   clang::DeclRefExpr *RefrsIntIter =
516       clang::DeclRefExpr::Create(C,
517                                  clang::NestedNameSpecifierLoc(),
518                                  clang::SourceLocation(),
519                                  IIVD,
520                                  false,
521                                  Loc,
522                                  C.IntTy,
523                                  clang::VK_RValue,
524                                  NULL);
525 
526   clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
527       llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
528 
529   clang::BinaryOperator *Init =
530       new(C) clang::BinaryOperator(RefrsIntIter,
531                                    Int0,
532                                    clang::BO_Assign,
533                                    C.IntTy,
534                                    clang::VK_RValue,
535                                    clang::OK_Ordinary,
536                                    Loc);
537 
538   // Cond -> "rsIntIter < NumArrayElements"
539   clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
540       llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
541 
542   clang::BinaryOperator *Cond =
543       new(C) clang::BinaryOperator(RefrsIntIter,
544                                    NumArrayElementsExpr,
545                                    clang::BO_LT,
546                                    C.IntTy,
547                                    clang::VK_RValue,
548                                    clang::OK_Ordinary,
549                                    Loc);
550 
551   // Inc -> "rsIntIter++"
552   clang::UnaryOperator *Inc =
553       new(C) clang::UnaryOperator(RefrsIntIter,
554                                   clang::UO_PostInc,
555                                   C.IntTy,
556                                   clang::VK_RValue,
557                                   clang::OK_Ordinary,
558                                   Loc);
559 
560   // Body -> "rsClearObject(&VD[rsIntIter]);"
561   // Destructor loop operates on individual array elements
562 
563   clang::Expr *RefRSArrPtr =
564       clang::ImplicitCastExpr::Create(C,
565           C.getPointerType(BaseType->getCanonicalTypeInternal()),
566           clang::CK_ArrayToPointerDecay,
567           RefRSArr,
568           NULL,
569           clang::VK_RValue);
570 
571   clang::Expr *RefRSArrPtrSubscript =
572       new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
573                                        RefrsIntIter,
574                                        BaseType->getCanonicalTypeInternal(),
575                                        clang::VK_RValue,
576                                        clang::OK_Ordinary,
577                                        Loc);
578 
579   RSExportPrimitiveType::DataType DT =
580       RSExportPrimitiveType::GetRSSpecificType(BaseType);
581 
582   clang::Stmt *RSClearObjectCall = NULL;
583   if (BaseType->isArrayType()) {
584     RSClearObjectCall =
585         ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
586   } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
587     RSClearObjectCall =
588         ClearStructRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
589   } else {
590     RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
591   }
592 
593   clang::ForStmt *DestructorLoop =
594       new(C) clang::ForStmt(C,
595                             Init,
596                             Cond,
597                             NULL,  // no condVar
598                             Inc,
599                             RSClearObjectCall,
600                             Loc,
601                             Loc,
602                             Loc);
603 
604   StmtArray[StmtCtr++] = DestructorLoop;
605   slangAssert(StmtCtr == 2);
606 
607   clang::CompoundStmt *CS =
608       new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
609 
610   return CS;
611 }
612 
CountRSObjectTypes(clang::ASTContext & C,const clang::Type * T,clang::SourceLocation Loc)613 static unsigned CountRSObjectTypes(clang::ASTContext &C,
614                                    const clang::Type *T,
615                                    clang::SourceLocation Loc) {
616   slangAssert(T);
617   unsigned RSObjectCount = 0;
618 
619   if (T->isArrayType()) {
620     return CountRSObjectTypes(C, T->getArrayElementTypeNoTypeQual(), Loc);
621   }
622 
623   RSExportPrimitiveType::DataType DT =
624       RSExportPrimitiveType::GetRSSpecificType(T);
625   if (DT != RSExportPrimitiveType::DataTypeUnknown) {
626     return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
627   }
628 
629   if (T->isUnionType()) {
630     clang::RecordDecl *RD = T->getAsUnionType()->getDecl();
631     RD = RD->getDefinition();
632     for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
633            FE = RD->field_end();
634          FI != FE;
635          FI++) {
636       const clang::FieldDecl *FD = *FI;
637       const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
638       if (CountRSObjectTypes(C, FT, Loc)) {
639         slangAssert(false && "can't have unions with RS object types!");
640         return 0;
641       }
642     }
643   }
644 
645   if (!T->isStructureType()) {
646     return 0;
647   }
648 
649   clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
650   RD = RD->getDefinition();
651   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
652          FE = RD->field_end();
653        FI != FE;
654        FI++) {
655     const clang::FieldDecl *FD = *FI;
656     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
657     if (CountRSObjectTypes(C, FT, Loc)) {
658       // Sub-structs should only count once (as should arrays, etc.)
659       RSObjectCount++;
660     }
661   }
662 
663   return RSObjectCount;
664 }
665 
ClearStructRSObject(clang::ASTContext & C,clang::DeclContext * DC,clang::Expr * RefRSStruct,clang::SourceLocation StartLoc,clang::SourceLocation Loc)666 static clang::Stmt *ClearStructRSObject(
667     clang::ASTContext &C,
668     clang::DeclContext *DC,
669     clang::Expr *RefRSStruct,
670     clang::SourceLocation StartLoc,
671     clang::SourceLocation Loc) {
672   const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
673 
674   slangAssert(!BaseType->isArrayType());
675 
676   // Structs should show up as unknown primitive types
677   slangAssert(RSExportPrimitiveType::GetRSSpecificType(BaseType) ==
678               RSExportPrimitiveType::DataTypeUnknown);
679 
680   unsigned FieldsToDestroy = CountRSObjectTypes(C, BaseType, Loc);
681 
682   unsigned StmtCount = 0;
683   clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
684   for (unsigned i = 0; i < FieldsToDestroy; i++) {
685     StmtArray[i] = NULL;
686   }
687 
688   // Populate StmtArray by creating a destructor for each RS object field
689   clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
690   RD = RD->getDefinition();
691   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
692          FE = RD->field_end();
693        FI != FE;
694        FI++) {
695     // We just look through all field declarations to see if we find a
696     // declaration for an RS object type (or an array of one).
697     bool IsArrayType = false;
698     clang::FieldDecl *FD = *FI;
699     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
700     const clang::Type *OrigType = FT;
701     while (FT && FT->isArrayType()) {
702       FT = FT->getArrayElementTypeNoTypeQual();
703       IsArrayType = true;
704     }
705 
706     if (RSExportPrimitiveType::IsRSObjectType(FT)) {
707       clang::DeclAccessPair FoundDecl =
708           clang::DeclAccessPair::make(FD, clang::AS_none);
709       clang::MemberExpr *RSObjectMember =
710           clang::MemberExpr::Create(C,
711                                     RefRSStruct,
712                                     false,
713                                     clang::NestedNameSpecifierLoc(),
714                                     clang::SourceLocation(),
715                                     FD,
716                                     FoundDecl,
717                                     clang::DeclarationNameInfo(),
718                                     NULL,
719                                     OrigType->getCanonicalTypeInternal(),
720                                     clang::VK_RValue,
721                                     clang::OK_Ordinary);
722 
723       slangAssert(StmtCount < FieldsToDestroy);
724 
725       if (IsArrayType) {
726         StmtArray[StmtCount++] = ClearArrayRSObject(C,
727                                                     DC,
728                                                     RSObjectMember,
729                                                     StartLoc,
730                                                     Loc);
731       } else {
732         StmtArray[StmtCount++] = ClearSingleRSObject(C,
733                                                      RSObjectMember,
734                                                      Loc);
735       }
736     } else if (FT->isStructureType() && CountRSObjectTypes(C, FT, Loc)) {
737       // In this case, we have a nested struct. We may not end up filling all
738       // of the spaces in StmtArray (sub-structs should handle themselves
739       // with separate compound statements).
740       clang::DeclAccessPair FoundDecl =
741           clang::DeclAccessPair::make(FD, clang::AS_none);
742       clang::MemberExpr *RSObjectMember =
743           clang::MemberExpr::Create(C,
744                                     RefRSStruct,
745                                     false,
746                                     clang::NestedNameSpecifierLoc(),
747                                     clang::SourceLocation(),
748                                     FD,
749                                     FoundDecl,
750                                     clang::DeclarationNameInfo(),
751                                     NULL,
752                                     OrigType->getCanonicalTypeInternal(),
753                                     clang::VK_RValue,
754                                     clang::OK_Ordinary);
755 
756       if (IsArrayType) {
757         StmtArray[StmtCount++] = ClearArrayRSObject(C,
758                                                     DC,
759                                                     RSObjectMember,
760                                                     StartLoc,
761                                                     Loc);
762       } else {
763         StmtArray[StmtCount++] = ClearStructRSObject(C,
764                                                      DC,
765                                                      RSObjectMember,
766                                                      StartLoc,
767                                                      Loc);
768       }
769     }
770   }
771 
772   slangAssert(StmtCount > 0);
773   clang::CompoundStmt *CS =
774       new(C) clang::CompoundStmt(C, StmtArray, StmtCount, Loc, Loc);
775 
776   delete [] StmtArray;
777 
778   return CS;
779 }
780 
CreateSingleRSSetObject(clang::ASTContext & C,clang::Expr * DstExpr,clang::Expr * SrcExpr,clang::SourceLocation StartLoc,clang::SourceLocation Loc)781 static clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
782                                             clang::Expr *DstExpr,
783                                             clang::Expr *SrcExpr,
784                                             clang::SourceLocation StartLoc,
785                                             clang::SourceLocation Loc) {
786   const clang::Type *T = DstExpr->getType().getTypePtr();
787   clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
788   slangAssert((SetObjectFD != NULL) &&
789               "rsSetObject doesn't cover all RS object types");
790 
791   clang::QualType SetObjectFDType = SetObjectFD->getType();
792   clang::QualType SetObjectFDArgType[2];
793   SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
794   SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
795 
796   clang::Expr *RefRSSetObjectFD =
797       clang::DeclRefExpr::Create(C,
798                                  clang::NestedNameSpecifierLoc(),
799                                  clang::SourceLocation(),
800                                  SetObjectFD,
801                                  false,
802                                  Loc,
803                                  SetObjectFDType,
804                                  clang::VK_RValue,
805                                  NULL);
806 
807   clang::Expr *RSSetObjectFP =
808       clang::ImplicitCastExpr::Create(C,
809                                       C.getPointerType(SetObjectFDType),
810                                       clang::CK_FunctionToPointerDecay,
811                                       RefRSSetObjectFD,
812                                       NULL,
813                                       clang::VK_RValue);
814 
815   clang::Expr *ArgList[2];
816   ArgList[0] = new(C) clang::UnaryOperator(DstExpr,
817                                            clang::UO_AddrOf,
818                                            SetObjectFDArgType[0],
819                                            clang::VK_RValue,
820                                            clang::OK_Ordinary,
821                                            Loc);
822   ArgList[1] = SrcExpr;
823 
824   clang::CallExpr *RSSetObjectCall =
825       new(C) clang::CallExpr(C,
826                              RSSetObjectFP,
827                              ArgList,
828                              2,
829                              SetObjectFD->getCallResultType(),
830                              clang::VK_RValue,
831                              Loc);
832 
833   return RSSetObjectCall;
834 }
835 
836 static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
837                                             clang::Expr *LHS,
838                                             clang::Expr *RHS,
839                                             clang::SourceLocation StartLoc,
840                                             clang::SourceLocation Loc);
841 
842 /*static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
843                                            clang::Expr *DstArr,
844                                            clang::Expr *SrcArr,
845                                            clang::SourceLocation StartLoc,
846                                            clang::SourceLocation Loc) {
847   clang::DeclContext *DC = NULL;
848   const clang::Type *BaseType = DstArr->getType().getTypePtr();
849   slangAssert(BaseType->isArrayType());
850 
851   int NumArrayElements = ArrayDim(BaseType);
852   // Actually extract out the base RS object type for use later
853   BaseType = BaseType->getArrayElementTypeNoTypeQual();
854 
855   clang::Stmt *StmtArray[2] = {NULL};
856   int StmtCtr = 0;
857 
858   if (NumArrayElements <= 0) {
859     return NULL;
860   }
861 
862   // Create helper variable for iterating through elements
863   clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
864   clang::VarDecl *IIVD =
865       clang::VarDecl::Create(C,
866                              DC,
867                              StartLoc,
868                              Loc,
869                              &II,
870                              C.IntTy,
871                              C.getTrivialTypeSourceInfo(C.IntTy),
872                              clang::SC_None,
873                              clang::SC_None);
874   clang::Decl *IID = (clang::Decl *)IIVD;
875 
876   clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
877   StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
878 
879   // Form the actual loop
880   // for (Init; Cond; Inc)
881   //   RSSetObjectCall;
882 
883   // Init -> "rsIntIter = 0"
884   clang::DeclRefExpr *RefrsIntIter =
885       clang::DeclRefExpr::Create(C,
886                                  clang::NestedNameSpecifierLoc(),
887                                  IIVD,
888                                  Loc,
889                                  C.IntTy,
890                                  clang::VK_RValue,
891                                  NULL);
892 
893   clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
894       llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
895 
896   clang::BinaryOperator *Init =
897       new(C) clang::BinaryOperator(RefrsIntIter,
898                                    Int0,
899                                    clang::BO_Assign,
900                                    C.IntTy,
901                                    clang::VK_RValue,
902                                    clang::OK_Ordinary,
903                                    Loc);
904 
905   // Cond -> "rsIntIter < NumArrayElements"
906   clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
907       llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
908 
909   clang::BinaryOperator *Cond =
910       new(C) clang::BinaryOperator(RefrsIntIter,
911                                    NumArrayElementsExpr,
912                                    clang::BO_LT,
913                                    C.IntTy,
914                                    clang::VK_RValue,
915                                    clang::OK_Ordinary,
916                                    Loc);
917 
918   // Inc -> "rsIntIter++"
919   clang::UnaryOperator *Inc =
920       new(C) clang::UnaryOperator(RefrsIntIter,
921                                   clang::UO_PostInc,
922                                   C.IntTy,
923                                   clang::VK_RValue,
924                                   clang::OK_Ordinary,
925                                   Loc);
926 
927   // Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
928   // Loop operates on individual array elements
929 
930   clang::Expr *DstArrPtr =
931       clang::ImplicitCastExpr::Create(C,
932           C.getPointerType(BaseType->getCanonicalTypeInternal()),
933           clang::CK_ArrayToPointerDecay,
934           DstArr,
935           NULL,
936           clang::VK_RValue);
937 
938   clang::Expr *DstArrPtrSubscript =
939       new(C) clang::ArraySubscriptExpr(DstArrPtr,
940                                        RefrsIntIter,
941                                        BaseType->getCanonicalTypeInternal(),
942                                        clang::VK_RValue,
943                                        clang::OK_Ordinary,
944                                        Loc);
945 
946   clang::Expr *SrcArrPtr =
947       clang::ImplicitCastExpr::Create(C,
948           C.getPointerType(BaseType->getCanonicalTypeInternal()),
949           clang::CK_ArrayToPointerDecay,
950           SrcArr,
951           NULL,
952           clang::VK_RValue);
953 
954   clang::Expr *SrcArrPtrSubscript =
955       new(C) clang::ArraySubscriptExpr(SrcArrPtr,
956                                        RefrsIntIter,
957                                        BaseType->getCanonicalTypeInternal(),
958                                        clang::VK_RValue,
959                                        clang::OK_Ordinary,
960                                        Loc);
961 
962   RSExportPrimitiveType::DataType DT =
963       RSExportPrimitiveType::GetRSSpecificType(BaseType);
964 
965   clang::Stmt *RSSetObjectCall = NULL;
966   if (BaseType->isArrayType()) {
967     RSSetObjectCall = CreateArrayRSSetObject(C, DstArrPtrSubscript,
968                                              SrcArrPtrSubscript,
969                                              StartLoc, Loc);
970   } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
971     RSSetObjectCall = CreateStructRSSetObject(C, DstArrPtrSubscript,
972                                               SrcArrPtrSubscript,
973                                               StartLoc, Loc);
974   } else {
975     RSSetObjectCall = CreateSingleRSSetObject(C, DstArrPtrSubscript,
976                                               SrcArrPtrSubscript,
977                                               StartLoc, Loc);
978   }
979 
980   clang::ForStmt *DestructorLoop =
981       new(C) clang::ForStmt(C,
982                             Init,
983                             Cond,
984                             NULL,  // no condVar
985                             Inc,
986                             RSSetObjectCall,
987                             Loc,
988                             Loc,
989                             Loc);
990 
991   StmtArray[StmtCtr++] = DestructorLoop;
992   slangAssert(StmtCtr == 2);
993 
994   clang::CompoundStmt *CS =
995       new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
996 
997   return CS;
998 } */
999 
CreateStructRSSetObject(clang::ASTContext & C,clang::Expr * LHS,clang::Expr * RHS,clang::SourceLocation StartLoc,clang::SourceLocation Loc)1000 static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
1001                                             clang::Expr *LHS,
1002                                             clang::Expr *RHS,
1003                                             clang::SourceLocation StartLoc,
1004                                             clang::SourceLocation Loc) {
1005   clang::QualType QT = LHS->getType();
1006   const clang::Type *T = QT.getTypePtr();
1007   slangAssert(T->isStructureType());
1008   slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
1009 
1010   // Keep an extra slot for the original copy (memcpy)
1011   unsigned FieldsToSet = CountRSObjectTypes(C, T, Loc) + 1;
1012 
1013   unsigned StmtCount = 0;
1014   clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
1015   for (unsigned i = 0; i < FieldsToSet; i++) {
1016     StmtArray[i] = NULL;
1017   }
1018 
1019   clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
1020   RD = RD->getDefinition();
1021   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1022          FE = RD->field_end();
1023        FI != FE;
1024        FI++) {
1025     bool IsArrayType = false;
1026     clang::FieldDecl *FD = *FI;
1027     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1028     const clang::Type *OrigType = FT;
1029 
1030     if (!CountRSObjectTypes(C, FT, Loc)) {
1031       // Skip to next if we don't have any viable RS object types
1032       continue;
1033     }
1034 
1035     clang::DeclAccessPair FoundDecl =
1036         clang::DeclAccessPair::make(FD, clang::AS_none);
1037     clang::MemberExpr *DstMember =
1038         clang::MemberExpr::Create(C,
1039                                   LHS,
1040                                   false,
1041                                   clang::NestedNameSpecifierLoc(),
1042                                   clang::SourceLocation(),
1043                                   FD,
1044                                   FoundDecl,
1045                                   clang::DeclarationNameInfo(),
1046                                   NULL,
1047                                   OrigType->getCanonicalTypeInternal(),
1048                                   clang::VK_RValue,
1049                                   clang::OK_Ordinary);
1050 
1051     clang::MemberExpr *SrcMember =
1052         clang::MemberExpr::Create(C,
1053                                   RHS,
1054                                   false,
1055                                   clang::NestedNameSpecifierLoc(),
1056                                   clang::SourceLocation(),
1057                                   FD,
1058                                   FoundDecl,
1059                                   clang::DeclarationNameInfo(),
1060                                   NULL,
1061                                   OrigType->getCanonicalTypeInternal(),
1062                                   clang::VK_RValue,
1063                                   clang::OK_Ordinary);
1064 
1065     if (FT->isArrayType()) {
1066       FT = FT->getArrayElementTypeNoTypeQual();
1067       IsArrayType = true;
1068     }
1069 
1070     RSExportPrimitiveType::DataType DT =
1071         RSExportPrimitiveType::GetRSSpecificType(FT);
1072 
1073     if (IsArrayType) {
1074       clang::DiagnosticsEngine &DiagEngine = C.getDiagnostics();
1075       DiagEngine.Report(
1076         clang::FullSourceLoc(Loc, C.getSourceManager()),
1077         DiagEngine.getCustomDiagID(
1078           clang::DiagnosticsEngine::Error,
1079           "Arrays of RS object types within structures cannot be copied"));
1080       // TODO(srhines): Support setting arrays of RS objects
1081       // StmtArray[StmtCount++] =
1082       //    CreateArrayRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1083     } else if (DT == RSExportPrimitiveType::DataTypeUnknown) {
1084       StmtArray[StmtCount++] =
1085           CreateStructRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1086     } else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
1087       StmtArray[StmtCount++] =
1088           CreateSingleRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
1089     } else {
1090       slangAssert(false);
1091     }
1092   }
1093 
1094   slangAssert(StmtCount > 0 && StmtCount < FieldsToSet);
1095 
1096   // We still need to actually do the overall struct copy. For simplicity,
1097   // we just do a straight-up assignment (which will still preserve all
1098   // the proper RS object reference counts).
1099   clang::BinaryOperator *CopyStruct =
1100       new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT,
1101                                    clang::VK_RValue, clang::OK_Ordinary, Loc);
1102   StmtArray[StmtCount++] = CopyStruct;
1103 
1104   clang::CompoundStmt *CS =
1105       new(C) clang::CompoundStmt(C, StmtArray, StmtCount, Loc, Loc);
1106 
1107   delete [] StmtArray;
1108 
1109   return CS;
1110 }
1111 
1112 }  // namespace
1113 
ReplaceRSObjectAssignment(clang::BinaryOperator * AS)1114 void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
1115     clang::BinaryOperator *AS) {
1116 
1117   clang::QualType QT = AS->getType();
1118 
1119   clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1120       RSExportPrimitiveType::DataTypeRSFont)->getASTContext();
1121 
1122   clang::SourceLocation Loc = AS->getExprLoc();
1123   clang::SourceLocation StartLoc = AS->getLHS()->getExprLoc();
1124   clang::Stmt *UpdatedStmt = NULL;
1125 
1126   if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
1127     // By definition, this is a struct assignment if we get here
1128     UpdatedStmt =
1129         CreateStructRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1130   } else {
1131     UpdatedStmt =
1132         CreateSingleRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
1133   }
1134 
1135   RSASTReplace R(C);
1136   R.ReplaceStmt(mCS, AS, UpdatedStmt);
1137   return;
1138 }
1139 
AppendRSObjectInit(clang::VarDecl * VD,clang::DeclStmt * DS,RSExportPrimitiveType::DataType DT,clang::Expr * InitExpr)1140 void RSObjectRefCount::Scope::AppendRSObjectInit(
1141     clang::VarDecl *VD,
1142     clang::DeclStmt *DS,
1143     RSExportPrimitiveType::DataType DT,
1144     clang::Expr *InitExpr) {
1145   slangAssert(VD);
1146 
1147   if (!InitExpr) {
1148     return;
1149   }
1150 
1151   clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
1152       RSExportPrimitiveType::DataTypeRSFont)->getASTContext();
1153   clang::SourceLocation Loc = RSObjectRefCount::GetRSSetObjectFD(
1154       RSExportPrimitiveType::DataTypeRSFont)->getLocation();
1155   clang::SourceLocation StartLoc = RSObjectRefCount::GetRSSetObjectFD(
1156       RSExportPrimitiveType::DataTypeRSFont)->getInnerLocStart();
1157 
1158   if (DT == RSExportPrimitiveType::DataTypeIsStruct) {
1159     const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1160     clang::DeclRefExpr *RefRSVar =
1161         clang::DeclRefExpr::Create(C,
1162                                    clang::NestedNameSpecifierLoc(),
1163                                    clang::SourceLocation(),
1164                                    VD,
1165                                    false,
1166                                    Loc,
1167                                    T->getCanonicalTypeInternal(),
1168                                    clang::VK_RValue,
1169                                    NULL);
1170 
1171     clang::Stmt *RSSetObjectOps =
1172         CreateStructRSSetObject(C, RefRSVar, InitExpr, StartLoc, Loc);
1173 
1174     std::list<clang::Stmt*> StmtList;
1175     StmtList.push_back(RSSetObjectOps);
1176     AppendAfterStmt(C, mCS, DS, StmtList);
1177     return;
1178   }
1179 
1180   clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
1181   slangAssert((SetObjectFD != NULL) &&
1182               "rsSetObject doesn't cover all RS object types");
1183 
1184   clang::QualType SetObjectFDType = SetObjectFD->getType();
1185   clang::QualType SetObjectFDArgType[2];
1186   SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
1187   SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
1188 
1189   clang::Expr *RefRSSetObjectFD =
1190       clang::DeclRefExpr::Create(C,
1191                                  clang::NestedNameSpecifierLoc(),
1192                                  clang::SourceLocation(),
1193                                  SetObjectFD,
1194                                  false,
1195                                  Loc,
1196                                  SetObjectFDType,
1197                                  clang::VK_RValue,
1198                                  NULL);
1199 
1200   clang::Expr *RSSetObjectFP =
1201       clang::ImplicitCastExpr::Create(C,
1202                                       C.getPointerType(SetObjectFDType),
1203                                       clang::CK_FunctionToPointerDecay,
1204                                       RefRSSetObjectFD,
1205                                       NULL,
1206                                       clang::VK_RValue);
1207 
1208   const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1209   clang::DeclRefExpr *RefRSVar =
1210       clang::DeclRefExpr::Create(C,
1211                                  clang::NestedNameSpecifierLoc(),
1212                                  clang::SourceLocation(),
1213                                  VD,
1214                                  false,
1215                                  Loc,
1216                                  T->getCanonicalTypeInternal(),
1217                                  clang::VK_RValue,
1218                                  NULL);
1219 
1220   clang::Expr *ArgList[2];
1221   ArgList[0] = new(C) clang::UnaryOperator(RefRSVar,
1222                                            clang::UO_AddrOf,
1223                                            SetObjectFDArgType[0],
1224                                            clang::VK_RValue,
1225                                            clang::OK_Ordinary,
1226                                            Loc);
1227   ArgList[1] = InitExpr;
1228 
1229   clang::CallExpr *RSSetObjectCall =
1230       new(C) clang::CallExpr(C,
1231                              RSSetObjectFP,
1232                              ArgList,
1233                              2,
1234                              SetObjectFD->getCallResultType(),
1235                              clang::VK_RValue,
1236                              Loc);
1237 
1238   std::list<clang::Stmt*> StmtList;
1239   StmtList.push_back(RSSetObjectCall);
1240   AppendAfterStmt(C, mCS, DS, StmtList);
1241 
1242   return;
1243 }
1244 
InsertLocalVarDestructors()1245 void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
1246   for (std::list<clang::VarDecl*>::const_iterator I = mRSO.begin(),
1247           E = mRSO.end();
1248         I != E;
1249         I++) {
1250     clang::VarDecl *VD = *I;
1251     clang::Stmt *RSClearObjectCall = ClearRSObject(VD, VD->getDeclContext());
1252     if (RSClearObjectCall) {
1253       DestructorVisitor DV((*mRSO.begin())->getASTContext(),
1254                            mCS,
1255                            RSClearObjectCall,
1256                            VD->getSourceRange().getBegin());
1257       DV.Visit(mCS);
1258       DV.InsertDestructors();
1259     }
1260   }
1261   return;
1262 }
1263 
ClearRSObject(clang::VarDecl * VD,clang::DeclContext * DC)1264 clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(
1265     clang::VarDecl *VD,
1266     clang::DeclContext *DC) {
1267   slangAssert(VD);
1268   clang::ASTContext &C = VD->getASTContext();
1269   clang::SourceLocation Loc = VD->getLocation();
1270   clang::SourceLocation StartLoc = VD->getInnerLocStart();
1271   const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1272 
1273   // Reference expr to target RS object variable
1274   clang::DeclRefExpr *RefRSVar =
1275       clang::DeclRefExpr::Create(C,
1276                                  clang::NestedNameSpecifierLoc(),
1277                                  clang::SourceLocation(),
1278                                  VD,
1279                                  false,
1280                                  Loc,
1281                                  T->getCanonicalTypeInternal(),
1282                                  clang::VK_RValue,
1283                                  NULL);
1284 
1285   if (T->isArrayType()) {
1286     return ClearArrayRSObject(C, DC, RefRSVar, StartLoc, Loc);
1287   }
1288 
1289   RSExportPrimitiveType::DataType DT =
1290       RSExportPrimitiveType::GetRSSpecificType(T);
1291 
1292   if (DT == RSExportPrimitiveType::DataTypeUnknown ||
1293       DT == RSExportPrimitiveType::DataTypeIsStruct) {
1294     return ClearStructRSObject(C, DC, RefRSVar, StartLoc, Loc);
1295   }
1296 
1297   slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
1298               "Should be RS object");
1299 
1300   return ClearSingleRSObject(C, RefRSVar, Loc);
1301 }
1302 
InitializeRSObject(clang::VarDecl * VD,RSExportPrimitiveType::DataType * DT,clang::Expr ** InitExpr)1303 bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
1304                                           RSExportPrimitiveType::DataType *DT,
1305                                           clang::Expr **InitExpr) {
1306   slangAssert(VD && DT && InitExpr);
1307   const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
1308 
1309   // Loop through array types to get to base type
1310   while (T && T->isArrayType()) {
1311     T = T->getArrayElementTypeNoTypeQual();
1312   }
1313 
1314   bool DataTypeIsStructWithRSObject = false;
1315   *DT = RSExportPrimitiveType::GetRSSpecificType(T);
1316 
1317   if (*DT == RSExportPrimitiveType::DataTypeUnknown) {
1318     if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
1319       *DT = RSExportPrimitiveType::DataTypeIsStruct;
1320       DataTypeIsStructWithRSObject = true;
1321     } else {
1322       return false;
1323     }
1324   }
1325 
1326   bool DataTypeIsRSObject = false;
1327   if (DataTypeIsStructWithRSObject) {
1328     DataTypeIsRSObject = true;
1329   } else {
1330     DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
1331   }
1332   *InitExpr = VD->getInit();
1333 
1334   if (!DataTypeIsRSObject && *InitExpr) {
1335     // If we already have an initializer for a matrix type, we are done.
1336     return DataTypeIsRSObject;
1337   }
1338 
1339   clang::Expr *ZeroInitializer =
1340       CreateZeroInitializerForRSSpecificType(*DT,
1341                                              VD->getASTContext(),
1342                                              VD->getLocation());
1343 
1344   if (ZeroInitializer) {
1345     ZeroInitializer->setType(T->getCanonicalTypeInternal());
1346     VD->setInit(ZeroInitializer);
1347   }
1348 
1349   return DataTypeIsRSObject;
1350 }
1351 
CreateZeroInitializerForRSSpecificType(RSExportPrimitiveType::DataType DT,clang::ASTContext & C,const clang::SourceLocation & Loc)1352 clang::Expr *RSObjectRefCount::CreateZeroInitializerForRSSpecificType(
1353     RSExportPrimitiveType::DataType DT,
1354     clang::ASTContext &C,
1355     const clang::SourceLocation &Loc) {
1356   clang::Expr *Res = NULL;
1357   switch (DT) {
1358     case RSExportPrimitiveType::DataTypeIsStruct:
1359     case RSExportPrimitiveType::DataTypeRSElement:
1360     case RSExportPrimitiveType::DataTypeRSType:
1361     case RSExportPrimitiveType::DataTypeRSAllocation:
1362     case RSExportPrimitiveType::DataTypeRSSampler:
1363     case RSExportPrimitiveType::DataTypeRSScript:
1364     case RSExportPrimitiveType::DataTypeRSMesh:
1365     case RSExportPrimitiveType::DataTypeRSPath:
1366     case RSExportPrimitiveType::DataTypeRSProgramFragment:
1367     case RSExportPrimitiveType::DataTypeRSProgramVertex:
1368     case RSExportPrimitiveType::DataTypeRSProgramRaster:
1369     case RSExportPrimitiveType::DataTypeRSProgramStore:
1370     case RSExportPrimitiveType::DataTypeRSFont: {
1371       //    (ImplicitCastExpr 'nullptr_t'
1372       //      (IntegerLiteral 0)))
1373       llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
1374       clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
1375       clang::Expr *CastToNull =
1376           clang::ImplicitCastExpr::Create(C,
1377                                           C.NullPtrTy,
1378                                           clang::CK_IntegralToPointer,
1379                                           Int0,
1380                                           NULL,
1381                                           clang::VK_RValue);
1382 
1383       Res = new(C) clang::InitListExpr(C, Loc, &CastToNull, 1, Loc);
1384       break;
1385     }
1386     case RSExportPrimitiveType::DataTypeRSMatrix2x2:
1387     case RSExportPrimitiveType::DataTypeRSMatrix3x3:
1388     case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
1389       // RS matrix is not completely an RS object. They hold data by themselves.
1390       // (InitListExpr rs_matrix2x2
1391       //   (InitListExpr float[4]
1392       //     (FloatingLiteral 0)
1393       //     (FloatingLiteral 0)
1394       //     (FloatingLiteral 0)
1395       //     (FloatingLiteral 0)))
1396       clang::QualType FloatTy = C.FloatTy;
1397       // Constructor sets value to 0.0f by default
1398       llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
1399       clang::FloatingLiteral *Float0Val =
1400           clang::FloatingLiteral::Create(C,
1401                                          Val,
1402                                          /* isExact = */true,
1403                                          FloatTy,
1404                                          Loc);
1405 
1406       unsigned N = 0;
1407       if (DT == RSExportPrimitiveType::DataTypeRSMatrix2x2)
1408         N = 2;
1409       else if (DT == RSExportPrimitiveType::DataTypeRSMatrix3x3)
1410         N = 3;
1411       else if (DT == RSExportPrimitiveType::DataTypeRSMatrix4x4)
1412         N = 4;
1413 
1414       // Directly allocate 16 elements instead of dynamically allocate N*N
1415       clang::Expr *InitVals[16];
1416       for (unsigned i = 0; i < sizeof(InitVals) / sizeof(InitVals[0]); i++)
1417         InitVals[i] = Float0Val;
1418       clang::Expr *InitExpr =
1419           new(C) clang::InitListExpr(C, Loc, InitVals, N * N, Loc);
1420       InitExpr->setType(C.getConstantArrayType(FloatTy,
1421                                                llvm::APInt(32, N * N),
1422                                                clang::ArrayType::Normal,
1423                                                /* EltTypeQuals = */0));
1424 
1425       Res = new(C) clang::InitListExpr(C, Loc, &InitExpr, 1, Loc);
1426       break;
1427     }
1428     case RSExportPrimitiveType::DataTypeUnknown:
1429     case RSExportPrimitiveType::DataTypeFloat16:
1430     case RSExportPrimitiveType::DataTypeFloat32:
1431     case RSExportPrimitiveType::DataTypeFloat64:
1432     case RSExportPrimitiveType::DataTypeSigned8:
1433     case RSExportPrimitiveType::DataTypeSigned16:
1434     case RSExportPrimitiveType::DataTypeSigned32:
1435     case RSExportPrimitiveType::DataTypeSigned64:
1436     case RSExportPrimitiveType::DataTypeUnsigned8:
1437     case RSExportPrimitiveType::DataTypeUnsigned16:
1438     case RSExportPrimitiveType::DataTypeUnsigned32:
1439     case RSExportPrimitiveType::DataTypeUnsigned64:
1440     case RSExportPrimitiveType::DataTypeBoolean:
1441     case RSExportPrimitiveType::DataTypeUnsigned565:
1442     case RSExportPrimitiveType::DataTypeUnsigned5551:
1443     case RSExportPrimitiveType::DataTypeUnsigned4444:
1444     case RSExportPrimitiveType::DataTypeMax: {
1445       slangAssert(false && "Not RS object type!");
1446     }
1447     // No default case will enable compiler detecting the missing cases
1448   }
1449 
1450   return Res;
1451 }
1452 
VisitDeclStmt(clang::DeclStmt * DS)1453 void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
1454   for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
1455        I != E;
1456        I++) {
1457     clang::Decl *D = *I;
1458     if (D->getKind() == clang::Decl::Var) {
1459       clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
1460       RSExportPrimitiveType::DataType DT =
1461           RSExportPrimitiveType::DataTypeUnknown;
1462       clang::Expr *InitExpr = NULL;
1463       if (InitializeRSObject(VD, &DT, &InitExpr)) {
1464         getCurrentScope()->addRSObject(VD);
1465         getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
1466       }
1467     }
1468   }
1469   return;
1470 }
1471 
VisitCompoundStmt(clang::CompoundStmt * CS)1472 void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
1473   if (!CS->body_empty()) {
1474     // Push a new scope
1475     Scope *S = new Scope(CS);
1476     mScopeStack.push(S);
1477 
1478     VisitStmt(CS);
1479 
1480     // Destroy the scope
1481     slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
1482     S->InsertLocalVarDestructors();
1483     mScopeStack.pop();
1484     delete S;
1485   }
1486   return;
1487 }
1488 
VisitBinAssign(clang::BinaryOperator * AS)1489 void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
1490   clang::QualType QT = AS->getType();
1491 
1492   if (CountRSObjectTypes(mCtx, QT.getTypePtr(), AS->getExprLoc())) {
1493     getCurrentScope()->ReplaceRSObjectAssignment(AS);
1494   }
1495 
1496   return;
1497 }
1498 
VisitStmt(clang::Stmt * S)1499 void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
1500   for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
1501        I != E;
1502        I++) {
1503     if (clang::Stmt *Child = *I) {
1504       Visit(Child);
1505     }
1506   }
1507   return;
1508 }
1509 
1510 // This function walks the list of global variables and (potentially) creates
1511 // a single global static destructor function that properly decrements
1512 // reference counts on the contained RS object types.
CreateStaticGlobalDtor()1513 clang::FunctionDecl *RSObjectRefCount::CreateStaticGlobalDtor() {
1514   Init();
1515 
1516   clang::DeclContext *DC = mCtx.getTranslationUnitDecl();
1517   clang::SourceLocation loc;
1518 
1519   llvm::StringRef SR(".rs.dtor");
1520   clang::IdentifierInfo &II = mCtx.Idents.get(SR);
1521   clang::DeclarationName N(&II);
1522   clang::FunctionProtoType::ExtProtoInfo EPI;
1523   clang::QualType T = mCtx.getFunctionType(mCtx.VoidTy, NULL, 0, EPI);
1524   clang::FunctionDecl *FD = NULL;
1525 
1526   // Generate rsClearObject() call chains for every global variable
1527   // (whether static or extern).
1528   std::list<clang::Stmt *> StmtList;
1529   for (clang::DeclContext::decl_iterator I = DC->decls_begin(),
1530           E = DC->decls_end(); I != E; I++) {
1531     clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*I);
1532     if (VD) {
1533       if (CountRSObjectTypes(mCtx, VD->getType().getTypePtr(), loc)) {
1534         if (!FD) {
1535           // Only create FD if we are going to use it.
1536           FD = clang::FunctionDecl::Create(mCtx, DC, loc, loc, N, T, NULL);
1537         }
1538         // Make sure to create any helpers within the function's DeclContext,
1539         // not the one associated with the global translation unit.
1540         clang::Stmt *RSClearObjectCall = Scope::ClearRSObject(VD, FD);
1541         StmtList.push_back(RSClearObjectCall);
1542       }
1543     }
1544   }
1545 
1546   // Nothing needs to be destroyed, so don't emit a dtor.
1547   if (StmtList.empty()) {
1548     return NULL;
1549   }
1550 
1551   clang::CompoundStmt *CS = BuildCompoundStmt(mCtx, StmtList, loc);
1552 
1553   FD->setBody(CS);
1554 
1555   return FD;
1556 }
1557 
1558 }  // namespace slang
1559