• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2010-2012, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "slang_rs_export_type.h"
18 
19 #include <list>
20 #include <vector>
21 
22 #include "clang/AST/ASTContext.h"
23 #include "clang/AST/Attr.h"
24 #include "clang/AST/RecordLayout.h"
25 
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/IR/DataLayout.h"
28 #include "llvm/IR/DerivedTypes.h"
29 #include "llvm/IR/Type.h"
30 
31 #include "slang_assert.h"
32 #include "slang_rs_context.h"
33 #include "slang_rs_export_element.h"
34 #include "slang_version.h"
35 
// Makes the enclosing function return false when ParentClass::equals(E) does
// not hold. The do/while(0) wrapper makes the expansion a single statement,
// so it is safe as the body of an unbraced if/else; call sites end it with a
// semicolon, e.g. `CHECK_PARENT_EQUALITY(RSExportable, E);`.
#define CHECK_PARENT_EQUALITY(ParentClass, E) \
  do {                                        \
    if (!ParentClass::equals(E))              \
      return false;                           \
  } while (0)
39 
40 namespace slang {
41 
42 namespace {
43 
44 // For the data types we support:
45 //  Category      - data type category
46 //  SName         - "common name" in script (C99)
47 //  RsType        - element name in RenderScript
48 //  RsShortType   - short element name in RenderScript
49 //  SizeInBits    - size in bits
50 //  CName         - reflected C name
51 //  JavaName      - reflected Java name
52 //  JavaArrayElementName - reflected name in Java arrays
53 //  CVecName      - prefix for C vector types
54 //  JavaVecName   - prefix for Java vector type
55 //  JavaPromotion - unsigned type undergoing Java promotion
56 //
57 // IMPORTANT: The data types in this table should be at the same index as
58 // specified by the corresponding DataType enum.
59 //
60 // TODO: Pull this information out into a separate file.
61 static RSReflectionType gReflectionTypes[] = {
62 #define _ nullptr
63   //      Category     SName              RsType       RsST           CName         JN      JAEN       CVN       JVN     JP
64 {PrimitiveDataType,   "half",         "FLOAT_16",     "F16", 16,     "half",   "short",  "short",   "Half",  "Short", false},
65 {PrimitiveDataType,  "float",         "FLOAT_32",     "F32", 32,    "float",   "float",  "float",  "Float",  "Float", false},
66 {PrimitiveDataType, "double",         "FLOAT_64",     "F64", 64,   "double",  "double", "double", "Double", "Double", false},
67 {PrimitiveDataType,   "char",         "SIGNED_8",      "I8",  8,   "int8_t",    "byte",   "byte",   "Byte",   "Byte", false},
68 {PrimitiveDataType,  "short",        "SIGNED_16",     "I16", 16,  "int16_t",   "short",  "short",  "Short",  "Short", false},
69 {PrimitiveDataType,    "int",        "SIGNED_32",     "I32", 32,  "int32_t",     "int",    "int",    "Int",    "Int", false},
70 {PrimitiveDataType,   "long",        "SIGNED_64",     "I64", 64,  "int64_t",    "long",   "long",   "Long",   "Long", false},
71 {PrimitiveDataType,  "uchar",       "UNSIGNED_8",      "U8",  8,  "uint8_t",   "short",   "byte",  "UByte",  "Short",  true},
72 {PrimitiveDataType, "ushort",      "UNSIGNED_16",     "U16", 16, "uint16_t",     "int",  "short", "UShort",    "Int",  true},
73 {PrimitiveDataType,   "uint",      "UNSIGNED_32",     "U32", 32, "uint32_t",    "long",    "int",   "UInt",   "Long",  true},
74 {PrimitiveDataType,  "ulong",      "UNSIGNED_64",     "U64", 64, "uint64_t",    "long",   "long",  "ULong",   "Long", false},
75 {PrimitiveDataType,   "bool",          "BOOLEAN", "BOOLEAN",  8,     "bool", "boolean",   "byte",        _,        _, false},
76 {PrimitiveDataType,        _,   "UNSIGNED_5_6_5",         _, 16,          _,         _,        _,        _,        _, false},
77 {PrimitiveDataType,        _, "UNSIGNED_5_5_5_1",         _, 16,          _,         _,        _,        _,        _, false},
78 {PrimitiveDataType,        _, "UNSIGNED_4_4_4_4",         _, 16,          _,         _,        _,        _,        _, false},
79 
80 {MatrixDataType, "rs_matrix2x2", "MATRIX_2X2", _,  4*32, "rs_matrix2x2", "Matrix2f", _, _, _, false},
81 {MatrixDataType, "rs_matrix3x3", "MATRIX_3X3", _,  9*32, "rs_matrix3x3", "Matrix3f", _, _, _, false},
82 {MatrixDataType, "rs_matrix4x4", "MATRIX_4X4", _, 16*32, "rs_matrix4x4", "Matrix4f", _, _, _, false},
83 
84 // RS object types are 32 bits in 32-bit RS, but 256 bits in 64-bit RS.
85 // This is handled specially by the GetElementSizeInBits() method.
86 {ObjectDataType, _,          "RS_ELEMENT",          "ELEMENT", 32,         "Element",         "Element", _, _, _, false},
87 {ObjectDataType, _,             "RS_TYPE",             "TYPE", 32,            "Type",            "Type", _, _, _, false},
88 {ObjectDataType, _,       "RS_ALLOCATION",       "ALLOCATION", 32,      "Allocation",      "Allocation", _, _, _, false},
89 {ObjectDataType, _,          "RS_SAMPLER",          "SAMPLER", 32,         "Sampler",         "Sampler", _, _, _, false},
90 {ObjectDataType, _,           "RS_SCRIPT",           "SCRIPT", 32,          "Script",          "Script", _, _, _, false},
91 {ObjectDataType, _,             "RS_MESH",             "MESH", 32,            "Mesh",            "Mesh", _, _, _, false},
92 {ObjectDataType, _,             "RS_PATH",             "PATH", 32,            "Path",            "Path", _, _, _, false},
93 {ObjectDataType, _, "RS_PROGRAM_FRAGMENT", "PROGRAM_FRAGMENT", 32, "ProgramFragment", "ProgramFragment", _, _, _, false},
94 {ObjectDataType, _,   "RS_PROGRAM_VERTEX",   "PROGRAM_VERTEX", 32,   "ProgramVertex",   "ProgramVertex", _, _, _, false},
95 {ObjectDataType, _,   "RS_PROGRAM_RASTER",   "PROGRAM_RASTER", 32,   "ProgramRaster",   "ProgramRaster", _, _, _, false},
96 {ObjectDataType, _,    "RS_PROGRAM_STORE",    "PROGRAM_STORE", 32,    "ProgramStore",    "ProgramStore", _, _, _, false},
97 {ObjectDataType, _,             "RS_FONT",             "FONT", 32,            "Font",            "Font", _, _, _, false},
98 #undef _
99 };
100 
101 const int kMaxVectorSize = 4;
102 
103 struct BuiltinInfo {
104   clang::BuiltinType::Kind builtinTypeKind;
105   DataType type;
106   /* TODO If we return std::string instead of llvm::StringRef, we could build
107    * the name instead of duplicating the entries.
108    */
109   const char *cname[kMaxVectorSize];
110 };
111 
112 
113 BuiltinInfo BuiltinInfoTable[] = {
114     {clang::BuiltinType::Bool, DataTypeBoolean,
115      {"bool", "bool2", "bool3", "bool4"}},
116     {clang::BuiltinType::Char_U, DataTypeUnsigned8,
117      {"uchar", "uchar2", "uchar3", "uchar4"}},
118     {clang::BuiltinType::UChar, DataTypeUnsigned8,
119      {"uchar", "uchar2", "uchar3", "uchar4"}},
120     {clang::BuiltinType::Char16, DataTypeSigned16,
121      {"short", "short2", "short3", "short4"}},
122     {clang::BuiltinType::Char32, DataTypeSigned32,
123      {"int", "int2", "int3", "int4"}},
124     {clang::BuiltinType::UShort, DataTypeUnsigned16,
125      {"ushort", "ushort2", "ushort3", "ushort4"}},
126     {clang::BuiltinType::UInt, DataTypeUnsigned32,
127      {"uint", "uint2", "uint3", "uint4"}},
128     {clang::BuiltinType::ULong, DataTypeUnsigned64,
129      {"ulong", "ulong2", "ulong3", "ulong4"}},
130     {clang::BuiltinType::ULongLong, DataTypeUnsigned64,
131      {"ulong", "ulong2", "ulong3", "ulong4"}},
132 
133     {clang::BuiltinType::Char_S, DataTypeSigned8,
134      {"char", "char2", "char3", "char4"}},
135     {clang::BuiltinType::SChar, DataTypeSigned8,
136      {"char", "char2", "char3", "char4"}},
137     {clang::BuiltinType::Short, DataTypeSigned16,
138      {"short", "short2", "short3", "short4"}},
139     {clang::BuiltinType::Int, DataTypeSigned32,
140      {"int", "int2", "int3", "int4"}},
141     {clang::BuiltinType::Long, DataTypeSigned64,
142      {"long", "long2", "long3", "long4"}},
143     {clang::BuiltinType::LongLong, DataTypeSigned64,
144      {"long", "long2", "long3", "long4"}},
145     {clang::BuiltinType::Half, DataTypeFloat16,
146      {"half", "half2", "half3", "half4"}},
147     {clang::BuiltinType::Float, DataTypeFloat32,
148      {"float", "float2", "float3", "float4"}},
149     {clang::BuiltinType::Double, DataTypeFloat64,
150      {"double", "double2", "double3", "double4"}},
151 };
152 const int BuiltinInfoTableCount = sizeof(BuiltinInfoTable) / sizeof(BuiltinInfoTable[0]);
153 
154 struct NameAndPrimitiveType {
155   const char *name;
156   DataType dataType;
157 };
158 
159 static NameAndPrimitiveType MatrixAndObjectDataTypes[] = {
160     {"rs_matrix2x2", DataTypeRSMatrix2x2},
161     {"rs_matrix3x3", DataTypeRSMatrix3x3},
162     {"rs_matrix4x4", DataTypeRSMatrix4x4},
163     {"rs_element", DataTypeRSElement},
164     {"rs_type", DataTypeRSType},
165     {"rs_allocation", DataTypeRSAllocation},
166     {"rs_sampler", DataTypeRSSampler},
167     {"rs_script", DataTypeRSScript},
168     {"rs_mesh", DataTypeRSMesh},
169     {"rs_path", DataTypeRSPath},
170     {"rs_program_fragment", DataTypeRSProgramFragment},
171     {"rs_program_vertex", DataTypeRSProgramVertex},
172     {"rs_program_raster", DataTypeRSProgramRaster},
173     {"rs_program_store", DataTypeRSProgramStore},
174     {"rs_font", DataTypeRSFont},
175 };
176 
177 const int MatrixAndObjectDataTypesCount =
178     sizeof(MatrixAndObjectDataTypes) / sizeof(MatrixAndObjectDataTypes[0]);
179 
180 static const clang::Type *TypeExportableHelper(
181     const clang::Type *T,
182     llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
183     slang::RSContext *Context,
184     const clang::VarDecl *VD,
185     const clang::RecordDecl *TopLevelRecord,
186     ExportKind EK);
187 
188 template <unsigned N>
ReportTypeError(slang::RSContext * Context,const clang::NamedDecl * ND,const clang::RecordDecl * TopLevelRecord,const char (& Message)[N],unsigned int TargetAPI=0)189 static void ReportTypeError(slang::RSContext *Context,
190                             const clang::NamedDecl *ND,
191                             const clang::RecordDecl *TopLevelRecord,
192                             const char (&Message)[N],
193                             unsigned int TargetAPI = 0) {
194   // Attempt to use the type declaration first (if we have one).
195   // Fall back to the variable definition, if we are looking at something
196   // like an array declaration that can't be exported.
197   if (TopLevelRecord) {
198     Context->ReportError(TopLevelRecord->getLocation(), Message)
199         << TopLevelRecord->getName() << TargetAPI;
200   } else if (ND) {
201     Context->ReportError(ND->getLocation(), Message) << ND->getName()
202                                                      << TargetAPI;
203   } else {
204     slangAssert(false && "Variables should be validated before exporting");
205   }
206 }
207 
// Decides whether the constant array type CAT can be exported, returning CAT
// when it can and nullptr when it cannot.
//
// Arrays of arrays, arrays of vectors with a non-primitive base type, and
// arrays holding more than one 3-element vector are rejected with a
// diagnostic. Otherwise the element type is checked recursively through
// TypeExportableHelper().
static const clang::Type *ConstantArrayTypeExportableHelper(
    const clang::ConstantArrayType *CAT,
    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
    slang::RSContext *Context,
    const clang::VarDecl *VD,
    const clang::RecordDecl *TopLevelRecord,
    ExportKind EK) {
  const clang::Type *Elem = GetConstantArrayElementType(CAT);

  if (Elem->isArrayType()) {
    ReportTypeError(Context, VD, TopLevelRecord,
                    "multidimensional arrays cannot be exported: '%0'");
    return nullptr;
  }

  if (Elem->isExtVectorType()) {
    const auto *Vec = static_cast<const clang::ExtVectorType *>(Elem);
    const bool IsVec3 = Vec->getNumElements() == 3;

    if (!RSExportPrimitiveType::IsPrimitiveType(GetExtVectorElementType(Vec))) {
      ReportTypeError(Context, VD, TopLevelRecord,
        "vectors of non-primitive types cannot be exported: '%0'");
      return nullptr;
    }

    // A 3-element vector may only appear as an array of exactly one.
    if (IsVec3 && CAT->getSize() != 1) {
      ReportTypeError(Context, VD, TopLevelRecord,
        "arrays of width 3 vector types cannot be exported: '%0'");
      return nullptr;
    }
  }

  const clang::Type *Checked =
      TypeExportableHelper(Elem, SPS, Context, VD, TopLevelRecord, EK);
  return Checked != nullptr ? CAT : nullptr;
}
247 
FindBuiltinType(clang::BuiltinType::Kind builtinTypeKind)248 BuiltinInfo *FindBuiltinType(clang::BuiltinType::Kind builtinTypeKind) {
249   for (int i = 0; i < BuiltinInfoTableCount; i++) {
250     if (builtinTypeKind == BuiltinInfoTable[i].builtinTypeKind) {
251       return &BuiltinInfoTable[i];
252     }
253   }
254   return nullptr;
255 }
256 
TypeExportableHelper(clang::Type const * T,llvm::SmallPtrSet<clang::Type const *,8> & SPS,slang::RSContext * Context,clang::VarDecl const * VD,clang::RecordDecl const * TopLevelRecord,ExportKind EK)257 static const clang::Type *TypeExportableHelper(
258     clang::Type const *T,
259     llvm::SmallPtrSet<clang::Type const *, 8> &SPS,
260     slang::RSContext *Context,
261     clang::VarDecl const *VD,
262     clang::RecordDecl const *TopLevelRecord,
263     ExportKind EK) {
264   // Normalize first
265   if ((T = GetCanonicalType(T)) == nullptr)
266     return nullptr;
267 
268   if (SPS.count(T))
269     return T;
270 
271   const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
272 
273   switch (T->getTypeClass()) {
274     case clang::Type::Builtin: {
275       const clang::BuiltinType *BT = static_cast<const clang::BuiltinType*>(CTI);
276       return FindBuiltinType(BT->getKind()) == nullptr ? nullptr : T;
277     }
278     case clang::Type::Record: {
279       if (RSExportPrimitiveType::GetRSSpecificType(T) != DataTypeUnknown) {
280         return T;  // RS object type, no further checks are needed
281       }
282 
283       // Check internal struct
284       if (T->isUnionType()) {
285         ReportTypeError(Context, VD, T->getAsUnionType()->getDecl(),
286                         "unions cannot be exported: '%0'");
287         return nullptr;
288       } else if (!T->isStructureType()) {
289         slangAssert(false && "Unknown type cannot be exported");
290         return nullptr;
291       }
292 
293       clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
294       if (RD != nullptr) {
295         RD = RD->getDefinition();
296         if (RD == nullptr) {
297           ReportTypeError(Context, nullptr, T->getAsStructureType()->getDecl(),
298                           "struct is not defined in this module");
299           return nullptr;
300         }
301       }
302 
303       if (!TopLevelRecord) {
304         TopLevelRecord = RD;
305       }
306       if (RD->getName().empty()) {
307         ReportTypeError(Context, nullptr, RD,
308                         "anonymous structures cannot be exported");
309         return nullptr;
310       }
311 
312       // Fast check
313       if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
314         return nullptr;
315 
316       // Insert myself into checking set
317       SPS.insert(T);
318 
319       // Check all element
320       for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
321                FE = RD->field_end();
322            FI != FE;
323            FI++) {
324         const clang::FieldDecl *FD = *FI;
325         const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
326         FT = GetCanonicalType(FT);
327 
328         if (!TypeExportableHelper(FT, SPS, Context, VD, TopLevelRecord,
329                                   EK)) {
330           return nullptr;
331         }
332 
333         // We don't support bit fields yet
334         //
335         // TODO(zonr/srhines): allow bit fields of size 8, 16, 32
336         if (FD->isBitField()) {
337           Context->ReportError(
338               FD->getLocation(),
339               "bit fields are not able to be exported: '%0.%1'")
340               << RD->getName() << FD->getName();
341           return nullptr;
342         }
343       }
344 
345       return T;
346     }
347     case clang::Type::FunctionProto:
348       ReportTypeError(Context, VD, TopLevelRecord,
349                       "function types cannot be exported: '%0'");
350       return nullptr;
351     case clang::Type::Pointer: {
352       if (TopLevelRecord) {
353         ReportTypeError(Context, VD, TopLevelRecord,
354             "structures containing pointers cannot be used as the type of "
355             "an exported global variable or the parameter to an exported "
356             "function: '%0'");
357         return nullptr;
358       }
359 
360       const clang::PointerType *PT = static_cast<const clang::PointerType*>(CTI);
361       const clang::Type *PointeeType = GetPointeeType(PT);
362 
363       if (PointeeType->getTypeClass() == clang::Type::Pointer) {
364         ReportTypeError(Context, VD, TopLevelRecord,
365             "multiple levels of pointers cannot be exported: '%0'");
366         return nullptr;
367       }
368 
369       // Void pointers are forbidden for export, although we must accept
370       // void pointers that come in as arguments to a legacy kernel.
371       if (PointeeType->isVoidType() && EK != LegacyKernelArgument) {
372         ReportTypeError(Context, VD, TopLevelRecord,
373             "void pointers cannot be exported: '%0'");
374         return nullptr;
375       }
376 
377       // We don't support pointer with array-type pointee or unsupported pointee
378       // type
379       if (PointeeType->isArrayType() ||
380           (TypeExportableHelper(PointeeType, SPS, Context, VD,
381                                 TopLevelRecord, EK) == nullptr))
382         return nullptr;
383       else
384         return T;
385     }
386     case clang::Type::ExtVector: {
387       const clang::ExtVectorType *EVT =
388               static_cast<const clang::ExtVectorType*>(CTI);
389       // Only vector with size 2, 3 and 4 are supported.
390       if (EVT->getNumElements() < 2 || EVT->getNumElements() > 4)
391         return nullptr;
392 
393       // Check base element type
394       const clang::Type *ElementType = GetExtVectorElementType(EVT);
395 
396       if ((ElementType->getTypeClass() != clang::Type::Builtin) ||
397           (TypeExportableHelper(ElementType, SPS, Context, VD,
398                                 TopLevelRecord, EK) == nullptr))
399         return nullptr;
400       else
401         return T;
402     }
403     case clang::Type::ConstantArray: {
404       const clang::ConstantArrayType *CAT =
405               static_cast<const clang::ConstantArrayType*>(CTI);
406 
407       return ConstantArrayTypeExportableHelper(CAT, SPS, Context, VD,
408                                                TopLevelRecord, EK);
409     }
410     case clang::Type::Enum: {
411       // FIXME: We currently convert enums to integers, rather than reflecting
412       // a more complete (and nicer type-safe Java version).
413       return Context->getASTContext().IntTy.getTypePtr();
414     }
415     default: {
416       slangAssert(false && "Unknown type cannot be validated");
417       return nullptr;
418     }
419   }
420 }
421 
422 // Return the type that can be used to create RSExportType, will always return
423 // the canonical type.
424 //
425 // If the Type T is not exportable, this function returns nullptr. DiagEngine is
426 // used to generate proper Clang diagnostic messages when a non-exportable type
427 // is detected. TopLevelRecord is used to capture the highest struct (in the
428 // case of a nested hierarchy) for detecting other types that cannot be exported
429 // (mostly pointers within a struct).
TypeExportable(const clang::Type * T,slang::RSContext * Context,const clang::VarDecl * VD,ExportKind EK)430 static const clang::Type *TypeExportable(const clang::Type *T,
431                                          slang::RSContext *Context,
432                                          const clang::VarDecl *VD,
433                                          ExportKind EK) {
434   llvm::SmallPtrSet<const clang::Type*, 8> SPS =
435       llvm::SmallPtrSet<const clang::Type*, 8>();
436 
437   return TypeExportableHelper(T, SPS, Context, VD, nullptr, EK);
438 }
439 
ValidateRSObjectInVarDecl(slang::RSContext * Context,const clang::VarDecl * VD,bool InCompositeType,unsigned int TargetAPI)440 static bool ValidateRSObjectInVarDecl(slang::RSContext *Context,
441                                       const clang::VarDecl *VD, bool InCompositeType,
442                                       unsigned int TargetAPI) {
443   if (TargetAPI < SLANG_JB_TARGET_API) {
444     // Only if we are already in a composite type (like an array or structure).
445     if (InCompositeType) {
446       // Only if we are actually exported (i.e. non-static).
447       if (VD->hasLinkage() &&
448           (VD->getFormalLinkage() == clang::ExternalLinkage)) {
449         // Only if we are not a pointer to an object.
450         const clang::Type *T = GetCanonicalType(VD->getType().getTypePtr());
451         if (T->getTypeClass() != clang::Type::Pointer) {
452           ReportTypeError(Context, VD, nullptr,
453                           "arrays/structures containing RS object types "
454                           "cannot be exported in target API < %1: '%0'",
455                           SLANG_JB_TARGET_API);
456           return false;
457         }
458       }
459     }
460   }
461 
462   return true;
463 }
464 
465 // Helper function for ValidateType(). We do a recursive descent on the
466 // type hierarchy to ensure that we can properly export/handle the
467 // declaration.
468 // \return true if the variable declaration is valid,
469 //         false if it is invalid (along with proper diagnostics).
470 //
471 // C - ASTContext (for diagnostics + builtin types).
472 // T - sub-type that we are validating.
473 // ND - (optional) top-level named declaration that we are validating.
474 // SPS - set of types we have already seen/validated.
475 // InCompositeType - true if we are within an outer composite type.
476 // UnionDecl - set if we are in a sub-type of a union.
477 // TargetAPI - target SDK API level.
478 // IsFilterscript - whether or not we are compiling for Filterscript
479 // IsExtern - is this type externally visible (i.e. extern global or parameter
480 //                                             to an extern function)
ValidateTypeHelper(slang::RSContext * Context,clang::ASTContext & C,const clang::Type * & T,const clang::NamedDecl * ND,clang::SourceLocation Loc,llvm::SmallPtrSet<const clang::Type *,8> & SPS,bool InCompositeType,clang::RecordDecl * UnionDecl,unsigned int TargetAPI,bool IsFilterscript,bool IsExtern)481 static bool ValidateTypeHelper(
482     slang::RSContext *Context,
483     clang::ASTContext &C,
484     const clang::Type *&T,
485     const clang::NamedDecl *ND,
486     clang::SourceLocation Loc,
487     llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
488     bool InCompositeType,
489     clang::RecordDecl *UnionDecl,
490     unsigned int TargetAPI,
491     bool IsFilterscript,
492     bool IsExtern) {
493   if ((T = GetCanonicalType(T)) == nullptr)
494     return true;
495 
496   if (SPS.count(T))
497     return true;
498 
499   const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
500 
501   switch (T->getTypeClass()) {
502     case clang::Type::Record: {
503       if (RSExportPrimitiveType::IsRSObjectType(T)) {
504         const clang::VarDecl *VD = (ND ? llvm::dyn_cast<clang::VarDecl>(ND) : nullptr);
505         if (VD && !ValidateRSObjectInVarDecl(Context, VD, InCompositeType,
506                                              TargetAPI)) {
507           return false;
508         }
509       }
510 
511       if (RSExportPrimitiveType::GetRSSpecificType(T) != DataTypeUnknown) {
512         if (!UnionDecl) {
513           return true;
514         } else if (RSExportPrimitiveType::IsRSObjectType(T)) {
515           ReportTypeError(Context, nullptr, UnionDecl,
516               "unions containing RS object types are not allowed");
517           return false;
518         }
519       }
520 
521       clang::RecordDecl *RD = nullptr;
522 
523       // Check internal struct
524       if (T->isUnionType()) {
525         RD = T->getAsUnionType()->getDecl();
526         UnionDecl = RD;
527       } else if (T->isStructureType()) {
528         RD = T->getAsStructureType()->getDecl();
529       } else {
530         slangAssert(false && "Unknown type cannot be exported");
531         return false;
532       }
533 
534       if (RD != nullptr) {
535         RD = RD->getDefinition();
536         if (RD == nullptr) {
537           // FIXME
538           return true;
539         }
540       }
541 
542       // Fast check
543       if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
544         return false;
545 
546       // Insert myself into checking set
547       SPS.insert(T);
548 
549       // Check all elements
550       for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
551                FE = RD->field_end();
552            FI != FE;
553            FI++) {
554         const clang::FieldDecl *FD = *FI;
555         const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
556         FT = GetCanonicalType(FT);
557 
558         if (!ValidateTypeHelper(Context, C, FT, ND, Loc, SPS, true, UnionDecl,
559                                 TargetAPI, IsFilterscript, IsExtern)) {
560           return false;
561         }
562       }
563 
564       return true;
565     }
566 
567     case clang::Type::Builtin: {
568       if (IsFilterscript) {
569         clang::QualType QT = T->getCanonicalTypeInternal();
570         if (QT == C.DoubleTy ||
571             QT == C.LongDoubleTy ||
572             QT == C.LongTy ||
573             QT == C.LongLongTy) {
574           if (ND) {
575             Context->ReportError(
576                 Loc,
577                 "Builtin types > 32 bits in size are forbidden in "
578                 "Filterscript: '%0'")
579                 << ND->getName();
580           } else {
581             Context->ReportError(
582                 Loc,
583                 "Builtin types > 32 bits in size are forbidden in "
584                 "Filterscript");
585           }
586           return false;
587         }
588       }
589       break;
590     }
591 
592     case clang::Type::Pointer: {
593       if (IsFilterscript) {
594         if (ND) {
595           Context->ReportError(Loc,
596                                "Pointers are forbidden in Filterscript: '%0'")
597               << ND->getName();
598           return false;
599         } else {
600           // TODO(srhines): Find a better way to handle expressions (i.e. no
601           // NamedDecl) involving pointers in FS that should be allowed.
602           // An example would be calls to library functions like
603           // rsMatrixMultiply() that take rs_matrixNxN * types.
604         }
605       }
606 
607       // Forbid pointers in structures that are externally visible.
608       if (InCompositeType && IsExtern) {
609         if (ND) {
610           Context->ReportError(Loc,
611               "structures containing pointers cannot be used as the type of "
612               "an exported global variable or the parameter to an exported "
613               "function: '%0'")
614             << ND->getName();
615         } else {
616           Context->ReportError(Loc,
617               "structures containing pointers cannot be used as the type of "
618               "an exported global variable or the parameter to an exported "
619               "function");
620         }
621         return false;
622       }
623 
624       const clang::PointerType *PT = static_cast<const clang::PointerType*>(CTI);
625       const clang::Type *PointeeType = GetPointeeType(PT);
626 
627       return ValidateTypeHelper(Context, C, PointeeType, ND, Loc, SPS,
628                                 InCompositeType, UnionDecl, TargetAPI,
629                                 IsFilterscript, IsExtern);
630     }
631 
632     case clang::Type::ExtVector: {
633       const clang::ExtVectorType *EVT =
634               static_cast<const clang::ExtVectorType*>(CTI);
635       const clang::Type *ElementType = GetExtVectorElementType(EVT);
636       if (TargetAPI < SLANG_ICS_TARGET_API &&
637           InCompositeType &&
638           EVT->getNumElements() == 3 &&
639           ND &&
640           ND->getFormalLinkage() == clang::ExternalLinkage) {
641         ReportTypeError(Context, ND, nullptr,
642                         "structs containing vectors of dimension 3 cannot "
643                         "be exported at this API level: '%0'");
644         return false;
645       }
646       return ValidateTypeHelper(Context, C, ElementType, ND, Loc, SPS, true,
647                                 UnionDecl, TargetAPI, IsFilterscript, IsExtern);
648     }
649 
650     case clang::Type::ConstantArray: {
651       const clang::ConstantArrayType *CAT = static_cast<const clang::ConstantArrayType*>(CTI);
652       const clang::Type *ElementType = GetConstantArrayElementType(CAT);
653       return ValidateTypeHelper(Context, C, ElementType, ND, Loc, SPS, true,
654                                 UnionDecl, TargetAPI, IsFilterscript, IsExtern);
655     }
656 
657     default: {
658       break;
659     }
660   }
661 
662   return true;
663 }
664 
665 }  // namespace
666 
// Builds a placeholder name: "<type>" when name is empty, otherwise
// "<type:name>".
//
// Appending to a std::string keeps this function independent of <sstream>,
// which this file never includes directly, and avoids constructing a stream
// for a short string.
std::string CreateDummyName(const char *type, const std::string &name) {
  std::string Result = "<";
  Result += type;
  if (!name.empty()) {
    Result += ':';
    Result += name;
  }
  Result += '>';
  return Result;
}
676 
677 /****************************** RSExportType ******************************/
NormalizeType(const clang::Type * & T,llvm::StringRef & TypeName,RSContext * Context,const clang::VarDecl * VD,ExportKind EK)678 bool RSExportType::NormalizeType(const clang::Type *&T,
679                                  llvm::StringRef &TypeName,
680                                  RSContext *Context,
681                                  const clang::VarDecl *VD,
682                                  ExportKind EK) {
683   if ((T = TypeExportable(T, Context, VD, EK)) == nullptr) {
684     return false;
685   }
686   // Get type name
687   TypeName = RSExportType::GetTypeName(T);
688   if (Context && TypeName.empty()) {
689     if (VD) {
690       Context->ReportError(VD->getLocation(),
691                            "anonymous types cannot be exported");
692     } else {
693       Context->ReportError("anonymous types cannot be exported");
694     }
695     return false;
696   }
697 
698   return true;
699 }
700 
ValidateType(slang::RSContext * Context,clang::ASTContext & C,clang::QualType QT,const clang::NamedDecl * ND,clang::SourceLocation Loc,unsigned int TargetAPI,bool IsFilterscript,bool IsExtern)701 bool RSExportType::ValidateType(slang::RSContext *Context, clang::ASTContext &C,
702                                 clang::QualType QT, const clang::NamedDecl *ND,
703                                 clang::SourceLocation Loc,
704                                 unsigned int TargetAPI, bool IsFilterscript,
705                                 bool IsExtern) {
706   const clang::Type *T = QT.getTypePtr();
707   llvm::SmallPtrSet<const clang::Type*, 8> SPS =
708       llvm::SmallPtrSet<const clang::Type*, 8>();
709 
710   // If this is an externally visible variable declaration, we check if the
711   // type is able to be exported first.
712   if (auto VD = llvm::dyn_cast_or_null<clang::VarDecl>(ND)) {
713     if (VD->getFormalLinkage() == clang::ExternalLinkage) {
714       if (!TypeExportable(T, Context, VD, NotLegacyKernelArgument)) {
715         return false;
716       }
717     }
718   }
719   return ValidateTypeHelper(Context, C, T, ND, Loc, SPS, false, nullptr, TargetAPI,
720                             IsFilterscript, IsExtern);
721 }
722 
ValidateVarDecl(slang::RSContext * Context,clang::VarDecl * VD,unsigned int TargetAPI,bool IsFilterscript)723 bool RSExportType::ValidateVarDecl(slang::RSContext *Context,
724                                    clang::VarDecl *VD, unsigned int TargetAPI,
725                                    bool IsFilterscript) {
726   return ValidateType(Context, VD->getASTContext(), VD->getType(), VD,
727                       VD->getLocation(), TargetAPI, IsFilterscript,
728                       (VD->getFormalLinkage() == clang::ExternalLinkage));
729 }
730 
731 const clang::Type
GetTypeOfDecl(const clang::DeclaratorDecl * DD)732 *RSExportType::GetTypeOfDecl(const clang::DeclaratorDecl *DD) {
733   if (DD) {
734     clang::QualType T = DD->getType();
735 
736     if (T.isNull())
737       return nullptr;
738     else
739       return T.getTypePtr();
740   }
741   return nullptr;
742 }
743 
GetTypeName(const clang::Type * T)744 llvm::StringRef RSExportType::GetTypeName(const clang::Type* T) {
745   T = GetCanonicalType(T);
746   if (T == nullptr)
747     return llvm::StringRef();
748 
749   const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
750 
751   switch (T->getTypeClass()) {
752     case clang::Type::Builtin: {
753       const clang::BuiltinType *BT = static_cast<const clang::BuiltinType*>(CTI);
754       BuiltinInfo *info = FindBuiltinType(BT->getKind());
755       if (info != nullptr) {
756         return info->cname[0];
757       }
758       slangAssert(false && "Unknown data type of the builtin");
759       break;
760     }
761     case clang::Type::Record: {
762       clang::RecordDecl *RD;
763       if (T->isStructureType()) {
764         RD = T->getAsStructureType()->getDecl();
765       } else {
766         break;
767       }
768 
769       llvm::StringRef Name = RD->getName();
770       if (Name.empty()) {
771         if (RD->getTypedefNameForAnonDecl() != nullptr) {
772           Name = RD->getTypedefNameForAnonDecl()->getName();
773         }
774 
775         if (Name.empty()) {
776           // Try to find a name from redeclaration (i.e. typedef)
777           for (clang::TagDecl::redecl_iterator RI = RD->redecls_begin(),
778                    RE = RD->redecls_end();
779                RI != RE;
780                RI++) {
781             slangAssert(*RI != nullptr && "cannot be NULL object");
782 
783             Name = (*RI)->getName();
784             if (!Name.empty())
785               break;
786           }
787         }
788       }
789       return Name;
790     }
791     case clang::Type::Pointer: {
792       // "*" plus pointee name
793       const clang::PointerType *P = static_cast<const clang::PointerType*>(CTI);
794       const clang::Type *PT = GetPointeeType(P);
795       llvm::StringRef PointeeName;
796       if (NormalizeType(PT, PointeeName, nullptr, nullptr,
797                         NotLegacyKernelArgument)) {
798         char *Name = new char[ 1 /* * */ + PointeeName.size() + 1 ];
799         Name[0] = '*';
800         memcpy(Name + 1, PointeeName.data(), PointeeName.size());
801         Name[PointeeName.size() + 1] = '\0';
802         return Name;
803       }
804       break;
805     }
806     case clang::Type::ExtVector: {
807       const clang::ExtVectorType *EVT =
808               static_cast<const clang::ExtVectorType*>(CTI);
809       return RSExportVectorType::GetTypeName(EVT);
810       break;
811     }
812     case clang::Type::ConstantArray : {
813       // Construct name for a constant array is too complicated.
814       return "<ConstantArray>";
815     }
816     default: {
817       break;
818     }
819   }
820 
821   return llvm::StringRef();
822 }
823 
824 
Create(RSContext * Context,const clang::Type * T,const llvm::StringRef & TypeName,ExportKind EK)825 RSExportType *RSExportType::Create(RSContext *Context,
826                                    const clang::Type *T,
827                                    const llvm::StringRef &TypeName,
828                                    ExportKind EK) {
829   // Lookup the context to see whether the type was processed before.
830   // Newly created RSExportType will insert into context
831   // in RSExportType::RSExportType()
832   RSContext::export_type_iterator ETI = Context->findExportType(TypeName);
833 
834   if (ETI != Context->export_types_end())
835     return ETI->second;
836 
837   const clang::Type *CTI = T->getCanonicalTypeInternal().getTypePtr();
838 
839   RSExportType *ET = nullptr;
840   switch (T->getTypeClass()) {
841     case clang::Type::Record: {
842       DataType dt = RSExportPrimitiveType::GetRSSpecificType(TypeName);
843       switch (dt) {
844         case DataTypeUnknown: {
845           // User-defined types
846           ET = RSExportRecordType::Create(Context,
847                                           T->getAsStructureType(),
848                                           TypeName);
849           break;
850         }
851         case DataTypeRSMatrix2x2: {
852           // 2 x 2 Matrix type
853           ET = RSExportMatrixType::Create(Context,
854                                           T->getAsStructureType(),
855                                           TypeName,
856                                           2);
857           break;
858         }
859         case DataTypeRSMatrix3x3: {
860           // 3 x 3 Matrix type
861           ET = RSExportMatrixType::Create(Context,
862                                           T->getAsStructureType(),
863                                           TypeName,
864                                           3);
865           break;
866         }
867         case DataTypeRSMatrix4x4: {
868           // 4 x 4 Matrix type
869           ET = RSExportMatrixType::Create(Context,
870                                           T->getAsStructureType(),
871                                           TypeName,
872                                           4);
873           break;
874         }
875         default: {
876           // Others are primitive types
877           ET = RSExportPrimitiveType::Create(Context, T, TypeName);
878           break;
879         }
880       }
881       break;
882     }
883     case clang::Type::Builtin: {
884       ET = RSExportPrimitiveType::Create(Context, T, TypeName);
885       break;
886     }
887     case clang::Type::Pointer: {
888       ET = RSExportPointerType::Create(Context,
889                                        static_cast<const clang::PointerType*>(CTI),
890                                        TypeName);
891       // FIXME: free the name (allocated in RSExportType::GetTypeName)
892       delete [] TypeName.data();
893       break;
894     }
895     case clang::Type::ExtVector: {
896       ET = RSExportVectorType::Create(Context,
897                                       static_cast<const clang::ExtVectorType*>(CTI),
898                                       TypeName);
899       break;
900     }
901     case clang::Type::ConstantArray: {
902       ET = RSExportConstantArrayType::Create(
903               Context,
904               static_cast<const clang::ConstantArrayType*>(CTI));
905       break;
906     }
907     default: {
908       Context->ReportError("unknown type cannot be exported: '%0'")
909           << T->getTypeClassName();
910       break;
911     }
912   }
913 
914   return ET;
915 }
916 
Create(RSContext * Context,const clang::Type * T,ExportKind EK,const clang::VarDecl * VD)917 RSExportType *RSExportType::Create(RSContext *Context, const clang::Type *T,
918                                    ExportKind EK, const clang::VarDecl *VD) {
919   llvm::StringRef TypeName;
920   if (NormalizeType(T, TypeName, Context, VD, EK)) {
921     return Create(Context, T, TypeName, EK);
922   } else {
923     return nullptr;
924   }
925 }
926 
CreateFromDecl(RSContext * Context,const clang::VarDecl * VD)927 RSExportType *RSExportType::CreateFromDecl(RSContext *Context,
928                                            const clang::VarDecl *VD) {
929   return RSExportType::Create(Context, GetTypeOfDecl(VD),
930                               NotLegacyKernelArgument, VD);
931 }
932 
getStoreSize() const933 size_t RSExportType::getStoreSize() const {
934   return getRSContext()->getDataLayout()->getTypeStoreSize(getLLVMType());
935 }
936 
getAllocSize() const937 size_t RSExportType::getAllocSize() const {
938     return getRSContext()->getDataLayout()->getTypeAllocSize(getLLVMType());
939 }
940 
RSExportType(RSContext * Context,ExportClass Class,const llvm::StringRef & Name)941 RSExportType::RSExportType(RSContext *Context,
942                            ExportClass Class,
943                            const llvm::StringRef &Name)
944     : RSExportable(Context, RSExportable::EX_TYPE),
945       mClass(Class),
946       // Make a copy on Name since memory stored @Name is either allocated in
947       // ASTContext or allocated in GetTypeName which will be destroyed later.
948       mName(Name.data(), Name.size()),
949       mLLVMType(nullptr) {
950   // Don't cache the type whose name start with '<'. Those type failed to
951   // get their name since constructing their name in GetTypeName() requiring
952   // complicated work.
953   if (!IsDummyName(Name)) {
954     // TODO(zonr): Need to check whether the insertion is successful or not.
955     Context->insertExportType(llvm::StringRef(Name), this);
956   }
957 
958 }
959 
keep()960 bool RSExportType::keep() {
961   if (!RSExportable::keep())
962     return false;
963   // Invalidate converted LLVM type.
964   mLLVMType = nullptr;
965   return true;
966 }
967 
equals(const RSExportable * E) const968 bool RSExportType::equals(const RSExportable *E) const {
969   CHECK_PARENT_EQUALITY(RSExportable, E);
970   return (static_cast<const RSExportType*>(E)->getClass() == getClass());
971 }
972 
~RSExportType()973 RSExportType::~RSExportType() {
974 }
975 
976 /************************** RSExportPrimitiveType **************************/
977 llvm::ManagedStatic<RSExportPrimitiveType::RSSpecificTypeMapTy>
978 RSExportPrimitiveType::RSSpecificTypeMap;
979 
IsPrimitiveType(const clang::Type * T)980 bool RSExportPrimitiveType::IsPrimitiveType(const clang::Type *T) {
981   if ((T != nullptr) && (T->getTypeClass() == clang::Type::Builtin))
982     return true;
983   else
984     return false;
985 }
986 
987 DataType
GetRSSpecificType(const llvm::StringRef & TypeName)988 RSExportPrimitiveType::GetRSSpecificType(const llvm::StringRef &TypeName) {
989   if (TypeName.empty())
990     return DataTypeUnknown;
991 
992   if (RSSpecificTypeMap->empty()) {
993     for (int i = 0; i < MatrixAndObjectDataTypesCount; i++) {
994       (*RSSpecificTypeMap)[MatrixAndObjectDataTypes[i].name] =
995           MatrixAndObjectDataTypes[i].dataType;
996     }
997   }
998 
999   RSSpecificTypeMapTy::const_iterator I = RSSpecificTypeMap->find(TypeName);
1000   if (I == RSSpecificTypeMap->end())
1001     return DataTypeUnknown;
1002   else
1003     return I->getValue();
1004 }
1005 
GetRSSpecificType(const clang::Type * T)1006 DataType RSExportPrimitiveType::GetRSSpecificType(const clang::Type *T) {
1007   T = GetCanonicalType(T);
1008   if ((T == nullptr) || (T->getTypeClass() != clang::Type::Record))
1009     return DataTypeUnknown;
1010 
1011   return GetRSSpecificType( RSExportType::GetTypeName(T) );
1012 }
1013 
IsRSMatrixType(DataType DT)1014 bool RSExportPrimitiveType::IsRSMatrixType(DataType DT) {
1015     if (DT < 0 || DT >= DataTypeMax) {
1016         return false;
1017     }
1018     return gReflectionTypes[DT].category == MatrixDataType;
1019 }
1020 
IsRSObjectType(DataType DT)1021 bool RSExportPrimitiveType::IsRSObjectType(DataType DT) {
1022     if (DT < 0 || DT >= DataTypeMax) {
1023         return false;
1024     }
1025     return gReflectionTypes[DT].category == ObjectDataType;
1026 }
1027 
IsStructureTypeWithRSObject(const clang::Type * T)1028 bool RSExportPrimitiveType::IsStructureTypeWithRSObject(const clang::Type *T) {
1029   bool RSObjectTypeSeen = false;
1030   while (T && T->isArrayType()) {
1031     T = T->getArrayElementTypeNoTypeQual();
1032   }
1033 
1034   const clang::RecordType *RT = T->getAsStructureType();
1035   if (!RT) {
1036     return false;
1037   }
1038 
1039   const clang::RecordDecl *RD = RT->getDecl();
1040   if (RD) {
1041     RD = RD->getDefinition();
1042   }
1043   if (!RD) {
1044     return false;
1045   }
1046 
1047   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1048          FE = RD->field_end();
1049        FI != FE;
1050        FI++) {
1051     // We just look through all field declarations to see if we find a
1052     // declaration for an RS object type (or an array of one).
1053     const clang::FieldDecl *FD = *FI;
1054     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1055     while (FT && FT->isArrayType()) {
1056       FT = FT->getArrayElementTypeNoTypeQual();
1057     }
1058 
1059     DataType DT = GetRSSpecificType(FT);
1060     if (IsRSObjectType(DT)) {
1061       // RS object types definitely need to be zero-initialized
1062       RSObjectTypeSeen = true;
1063     } else {
1064       switch (DT) {
1065         case DataTypeRSMatrix2x2:
1066         case DataTypeRSMatrix3x3:
1067         case DataTypeRSMatrix4x4:
1068           // Matrix types should get zero-initialized as well
1069           RSObjectTypeSeen = true;
1070           break;
1071         default:
1072           // Ignore all other primitive types
1073           break;
1074       }
1075       while (FT && FT->isArrayType()) {
1076         FT = FT->getArrayElementTypeNoTypeQual();
1077       }
1078       if (FT->isStructureType()) {
1079         // Recursively handle structs of structs (even though these can't
1080         // be exported, it is possible for a user to have them internally).
1081         RSObjectTypeSeen |= IsStructureTypeWithRSObject(FT);
1082       }
1083     }
1084   }
1085 
1086   return RSObjectTypeSeen;
1087 }
1088 
GetElementSizeInBits(const RSExportPrimitiveType * EPT)1089 size_t RSExportPrimitiveType::GetElementSizeInBits(const RSExportPrimitiveType *EPT) {
1090   int type = EPT->getType();
1091   slangAssert((type > DataTypeUnknown && type < DataTypeMax) &&
1092               "RSExportPrimitiveType::GetElementSizeInBits : unknown data type");
1093   // All RS object types are 256 bits in 64-bit RS.
1094   if (EPT->isRSObjectType() && EPT->getRSContext()->is64Bit()) {
1095     return 256;
1096   }
1097   return gReflectionTypes[type].size_in_bits;
1098 }
1099 
1100 DataType
GetDataType(RSContext * Context,const clang::Type * T)1101 RSExportPrimitiveType::GetDataType(RSContext *Context, const clang::Type *T) {
1102   if (T == nullptr)
1103     return DataTypeUnknown;
1104 
1105   switch (T->getTypeClass()) {
1106     case clang::Type::Builtin: {
1107       const clang::BuiltinType *BT =
1108               static_cast<const clang::BuiltinType*>(T->getCanonicalTypeInternal().getTypePtr());
1109       BuiltinInfo *info = FindBuiltinType(BT->getKind());
1110       if (info != nullptr) {
1111         return info->type;
1112       }
1113       // The size of type WChar depend on platform so we abandon the support
1114       // to them.
1115       Context->ReportError("built-in type cannot be exported: '%0'")
1116           << T->getTypeClassName();
1117       break;
1118     }
1119     case clang::Type::Record: {
1120       // must be RS object type
1121       return RSExportPrimitiveType::GetRSSpecificType(T);
1122     }
1123     default: {
1124       Context->ReportError("primitive type cannot be exported: '%0'")
1125           << T->getTypeClassName();
1126       break;
1127     }
1128   }
1129 
1130   return DataTypeUnknown;
1131 }
1132 
1133 RSExportPrimitiveType
Create(RSContext * Context,const clang::Type * T,const llvm::StringRef & TypeName,bool Normalized)1134 *RSExportPrimitiveType::Create(RSContext *Context,
1135                                const clang::Type *T,
1136                                const llvm::StringRef &TypeName,
1137                                bool Normalized) {
1138   DataType DT = GetDataType(Context, T);
1139 
1140   if ((DT == DataTypeUnknown) || TypeName.empty())
1141     return nullptr;
1142   else
1143     return new RSExportPrimitiveType(Context, ExportClassPrimitive, TypeName,
1144                                      DT, Normalized);
1145 }
1146 
Create(RSContext * Context,const clang::Type * T)1147 RSExportPrimitiveType *RSExportPrimitiveType::Create(RSContext *Context,
1148                                                      const clang::Type *T) {
1149   llvm::StringRef TypeName;
1150   if (RSExportType::NormalizeType(T, TypeName, Context, nullptr,
1151                                   NotLegacyKernelArgument) &&
1152       IsPrimitiveType(T)) {
1153     return Create(Context, T, TypeName);
1154   } else {
1155     return nullptr;
1156   }
1157 }
1158 
convertToLLVMType() const1159 llvm::Type *RSExportPrimitiveType::convertToLLVMType() const {
1160   llvm::LLVMContext &C = getRSContext()->getLLVMContext();
1161 
1162   if (isRSObjectType()) {
1163     // struct {
1164     //   int *p;
1165     // } __attribute__((packed, aligned(pointer_size)))
1166     //
1167     // which is
1168     //
1169     // <{ [1 x i32] }> in LLVM
1170     //
1171     std::vector<llvm::Type *> Elements;
1172     if (getRSContext()->is64Bit()) {
1173       // 64-bit path
1174       Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt64Ty(C), 4));
1175       return llvm::StructType::get(C, Elements, true);
1176     } else {
1177       // 32-bit legacy path
1178       Elements.push_back(llvm::ArrayType::get(llvm::Type::getInt32Ty(C), 1));
1179       return llvm::StructType::get(C, Elements, true);
1180     }
1181   }
1182 
1183   switch (mType) {
1184     case DataTypeFloat16: {
1185       return llvm::Type::getHalfTy(C);
1186       break;
1187     }
1188     case DataTypeFloat32: {
1189       return llvm::Type::getFloatTy(C);
1190       break;
1191     }
1192     case DataTypeFloat64: {
1193       return llvm::Type::getDoubleTy(C);
1194       break;
1195     }
1196     case DataTypeBoolean: {
1197       return llvm::Type::getInt1Ty(C);
1198       break;
1199     }
1200     case DataTypeSigned8:
1201     case DataTypeUnsigned8: {
1202       return llvm::Type::getInt8Ty(C);
1203       break;
1204     }
1205     case DataTypeSigned16:
1206     case DataTypeUnsigned16:
1207     case DataTypeUnsigned565:
1208     case DataTypeUnsigned5551:
1209     case DataTypeUnsigned4444: {
1210       return llvm::Type::getInt16Ty(C);
1211       break;
1212     }
1213     case DataTypeSigned32:
1214     case DataTypeUnsigned32: {
1215       return llvm::Type::getInt32Ty(C);
1216       break;
1217     }
1218     case DataTypeSigned64:
1219     case DataTypeUnsigned64: {
1220       return llvm::Type::getInt64Ty(C);
1221       break;
1222     }
1223     default: {
1224       slangAssert(false && "Unknown data type");
1225     }
1226   }
1227 
1228   return nullptr;
1229 }
1230 
equals(const RSExportable * E) const1231 bool RSExportPrimitiveType::equals(const RSExportable *E) const {
1232   CHECK_PARENT_EQUALITY(RSExportType, E);
1233   return (static_cast<const RSExportPrimitiveType*>(E)->getType() == getType());
1234 }
1235 
getRSReflectionType(DataType DT)1236 RSReflectionType *RSExportPrimitiveType::getRSReflectionType(DataType DT) {
1237   if (DT > DataTypeUnknown && DT < DataTypeMax) {
1238     return &gReflectionTypes[DT];
1239   } else {
1240     return nullptr;
1241   }
1242 }
1243 
1244 /**************************** RSExportPointerType ****************************/
1245 
1246 RSExportPointerType
Create(RSContext * Context,const clang::PointerType * PT,const llvm::StringRef & TypeName)1247 *RSExportPointerType::Create(RSContext *Context,
1248                              const clang::PointerType *PT,
1249                              const llvm::StringRef &TypeName) {
1250   const clang::Type *PointeeType = GetPointeeType(PT);
1251   const RSExportType *PointeeET;
1252 
1253   if (PointeeType->getTypeClass() != clang::Type::Pointer) {
1254     PointeeET = RSExportType::Create(Context, PointeeType,
1255                                      NotLegacyKernelArgument);
1256   } else {
1257     // Double or higher dimension of pointer, export as int*
1258     PointeeET = RSExportPrimitiveType::Create(Context,
1259                     Context->getASTContext().IntTy.getTypePtr());
1260   }
1261 
1262   if (PointeeET == nullptr) {
1263     // Error diagnostic is emitted for corresponding pointee type
1264     return nullptr;
1265   }
1266 
1267   return new RSExportPointerType(Context, TypeName, PointeeET);
1268 }
1269 
convertToLLVMType() const1270 llvm::Type *RSExportPointerType::convertToLLVMType() const {
1271   llvm::Type *PointeeType = mPointeeType->getLLVMType();
1272   return llvm::PointerType::getUnqual(PointeeType);
1273 }
1274 
keep()1275 bool RSExportPointerType::keep() {
1276   if (!RSExportType::keep())
1277     return false;
1278   const_cast<RSExportType*>(mPointeeType)->keep();
1279   return true;
1280 }
1281 
equals(const RSExportable * E) const1282 bool RSExportPointerType::equals(const RSExportable *E) const {
1283   CHECK_PARENT_EQUALITY(RSExportType, E);
1284   return (static_cast<const RSExportPointerType*>(E)
1285               ->getPointeeType()->equals(getPointeeType()));
1286 }
1287 
1288 /***************************** RSExportVectorType *****************************/
1289 llvm::StringRef
GetTypeName(const clang::ExtVectorType * EVT)1290 RSExportVectorType::GetTypeName(const clang::ExtVectorType *EVT) {
1291   const clang::Type *ElementType = GetExtVectorElementType(EVT);
1292   llvm::StringRef name;
1293 
1294   if ((ElementType->getTypeClass() != clang::Type::Builtin))
1295     return name;
1296 
1297   const clang::BuiltinType *BT =
1298           static_cast<const clang::BuiltinType*>(
1299               ElementType->getCanonicalTypeInternal().getTypePtr());
1300 
1301   if ((EVT->getNumElements() < 1) ||
1302       (EVT->getNumElements() > 4))
1303     return name;
1304 
1305   BuiltinInfo *info = FindBuiltinType(BT->getKind());
1306   if (info != nullptr) {
1307     int I = EVT->getNumElements() - 1;
1308     if (I < kMaxVectorSize) {
1309       name = info->cname[I];
1310     } else {
1311       slangAssert(false && "Max vector is 4");
1312     }
1313   }
1314   return name;
1315 }
1316 
Create(RSContext * Context,const clang::ExtVectorType * EVT,const llvm::StringRef & TypeName,bool Normalized)1317 RSExportVectorType *RSExportVectorType::Create(RSContext *Context,
1318                                                const clang::ExtVectorType *EVT,
1319                                                const llvm::StringRef &TypeName,
1320                                                bool Normalized) {
1321   slangAssert(EVT != nullptr && EVT->getTypeClass() == clang::Type::ExtVector);
1322 
1323   const clang::Type *ElementType = GetExtVectorElementType(EVT);
1324   DataType DT = RSExportPrimitiveType::GetDataType(Context, ElementType);
1325 
1326   if (DT != DataTypeUnknown)
1327     return new RSExportVectorType(Context,
1328                                   TypeName,
1329                                   DT,
1330                                   Normalized,
1331                                   EVT->getNumElements());
1332   else
1333     return nullptr;
1334 }
1335 
convertToLLVMType() const1336 llvm::Type *RSExportVectorType::convertToLLVMType() const {
1337   llvm::Type *ElementType = RSExportPrimitiveType::convertToLLVMType();
1338   return llvm::VectorType::get(ElementType, getNumElement());
1339 }
1340 
equals(const RSExportable * E) const1341 bool RSExportVectorType::equals(const RSExportable *E) const {
1342   CHECK_PARENT_EQUALITY(RSExportPrimitiveType, E);
1343   return (static_cast<const RSExportVectorType*>(E)->getNumElement()
1344               == getNumElement());
1345 }
1346 
1347 /***************************** RSExportMatrixType *****************************/
Create(RSContext * Context,const clang::RecordType * RT,const llvm::StringRef & TypeName,unsigned Dim)1348 RSExportMatrixType *RSExportMatrixType::Create(RSContext *Context,
1349                                                const clang::RecordType *RT,
1350                                                const llvm::StringRef &TypeName,
1351                                                unsigned Dim) {
1352   slangAssert((RT != nullptr) && (RT->getTypeClass() == clang::Type::Record));
1353   slangAssert((Dim > 1) && "Invalid dimension of matrix");
1354 
1355   // Check whether the struct rs_matrix is in our expected form (but assume it's
1356   // correct if we're not sure whether it's correct or not)
1357   const clang::RecordDecl* RD = RT->getDecl();
1358   RD = RD->getDefinition();
1359   if (RD != nullptr) {
1360     // Find definition, perform further examination
1361     if (RD->field_empty()) {
1362       Context->ReportError(
1363           RD->getLocation(),
1364           "invalid matrix struct: must have 1 field for saving values: '%0'")
1365           << RD->getName();
1366       return nullptr;
1367     }
1368 
1369     clang::RecordDecl::field_iterator FIT = RD->field_begin();
1370     const clang::FieldDecl *FD = *FIT;
1371     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
1372     if ((FT == nullptr) || (FT->getTypeClass() != clang::Type::ConstantArray)) {
1373       Context->ReportError(RD->getLocation(),
1374                            "invalid matrix struct: first field should"
1375                            " be an array with constant size: '%0'")
1376           << RD->getName();
1377       return nullptr;
1378     }
1379     const clang::ConstantArrayType *CAT =
1380       static_cast<const clang::ConstantArrayType *>(FT);
1381     const clang::Type *ElementType = GetConstantArrayElementType(CAT);
1382     if ((ElementType == nullptr) ||
1383         (ElementType->getTypeClass() != clang::Type::Builtin) ||
1384         (static_cast<const clang::BuiltinType *>(ElementType)->getKind() !=
1385          clang::BuiltinType::Float)) {
1386       Context->ReportError(RD->getLocation(),
1387                            "invalid matrix struct: first field "
1388                            "should be a float array: '%0'")
1389           << RD->getName();
1390       return nullptr;
1391     }
1392 
1393     if (CAT->getSize() != Dim * Dim) {
1394       Context->ReportError(RD->getLocation(),
1395                            "invalid matrix struct: first field "
1396                            "should be an array with size %0: '%1'")
1397           << (Dim * Dim) << (RD->getName());
1398       return nullptr;
1399     }
1400 
1401     FIT++;
1402     if (FIT != RD->field_end()) {
1403       Context->ReportError(RD->getLocation(),
1404                            "invalid matrix struct: must have "
1405                            "exactly 1 field: '%0'")
1406           << RD->getName();
1407       return nullptr;
1408     }
1409   }
1410 
1411   return new RSExportMatrixType(Context, TypeName, Dim);
1412 }
1413 
convertToLLVMType() const1414 llvm::Type *RSExportMatrixType::convertToLLVMType() const {
1415   // Construct LLVM type:
1416   // struct {
1417   //  float X[mDim * mDim];
1418   // }
1419 
1420   llvm::LLVMContext &C = getRSContext()->getLLVMContext();
1421   llvm::ArrayType *X = llvm::ArrayType::get(llvm::Type::getFloatTy(C),
1422                                             mDim * mDim);
1423   return llvm::StructType::get(C, X, false);
1424 }
1425 
equals(const RSExportable * E) const1426 bool RSExportMatrixType::equals(const RSExportable *E) const {
1427   CHECK_PARENT_EQUALITY(RSExportType, E);
1428   return (static_cast<const RSExportMatrixType*>(E)->getDim() == getDim());
1429 }
1430 
1431 /************************* RSExportConstantArrayType *************************/
1432 RSExportConstantArrayType
Create(RSContext * Context,const clang::ConstantArrayType * CAT)1433 *RSExportConstantArrayType::Create(RSContext *Context,
1434                                    const clang::ConstantArrayType *CAT) {
1435   slangAssert(CAT != nullptr && CAT->getTypeClass() == clang::Type::ConstantArray);
1436 
1437   slangAssert((CAT->getSize().getActiveBits() < 32) && "array too large");
1438 
1439   unsigned Size = static_cast<unsigned>(CAT->getSize().getZExtValue());
1440   slangAssert((Size > 0) && "Constant array should have size greater than 0");
1441 
1442   const clang::Type *ElementType = GetConstantArrayElementType(CAT);
1443   RSExportType *ElementET = RSExportType::Create(Context, ElementType,
1444                                                  NotLegacyKernelArgument);
1445 
1446   if (ElementET == nullptr) {
1447     return nullptr;
1448   }
1449 
1450   return new RSExportConstantArrayType(Context,
1451                                        ElementET,
1452                                        Size);
1453 }
1454 
convertToLLVMType() const1455 llvm::Type *RSExportConstantArrayType::convertToLLVMType() const {
1456   return llvm::ArrayType::get(mElementType->getLLVMType(), getNumElement());
1457 }
1458 
keep()1459 bool RSExportConstantArrayType::keep() {
1460   if (!RSExportType::keep())
1461     return false;
1462   const_cast<RSExportType*>(mElementType)->keep();
1463   return true;
1464 }
1465 
equals(const RSExportable * E) const1466 bool RSExportConstantArrayType::equals(const RSExportable *E) const {
1467   CHECK_PARENT_EQUALITY(RSExportType, E);
1468   const RSExportConstantArrayType *RHS =
1469       static_cast<const RSExportConstantArrayType*>(E);
1470   return ((getNumElement() == RHS->getNumElement()) &&
1471           (getElementType()->equals(RHS->getElementType())));
1472 }
1473 
1474 /**************************** RSExportRecordType ****************************/
Create(RSContext * Context,const clang::RecordType * RT,const llvm::StringRef & TypeName,bool mIsArtificial)1475 RSExportRecordType *RSExportRecordType::Create(RSContext *Context,
1476                                                const clang::RecordType *RT,
1477                                                const llvm::StringRef &TypeName,
1478                                                bool mIsArtificial) {
1479   slangAssert(RT != nullptr && RT->getTypeClass() == clang::Type::Record);
1480 
1481   const clang::RecordDecl *RD = RT->getDecl();
1482   slangAssert(RD->isStruct());
1483 
1484   RD = RD->getDefinition();
1485   if (RD == nullptr) {
1486     slangAssert(false && "struct is not defined in this module");
1487     return nullptr;
1488   }
1489 
1490   // Struct layout construct by clang. We rely on this for obtaining the
1491   // alloc size of a struct and offset of every field in that struct.
1492   const clang::ASTRecordLayout *RL =
1493       &Context->getASTContext().getASTRecordLayout(RD);
1494   slangAssert((RL != nullptr) &&
1495       "Failed to retrieve the struct layout from Clang.");
1496 
1497   RSExportRecordType *ERT =
1498       new RSExportRecordType(Context,
1499                              TypeName,
1500                              RD->hasAttr<clang::PackedAttr>(),
1501                              mIsArtificial,
1502                              RL->getDataSize().getQuantity(),
1503                              RL->getSize().getQuantity());
1504   unsigned int Index = 0;
1505 
1506   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
1507            FE = RD->field_end();
1508        FI != FE;
1509        FI++, Index++) {
1510 
1511     // FIXME: All fields should be primitive type
1512     slangAssert(FI->getKind() == clang::Decl::Field);
1513     clang::FieldDecl *FD = *FI;
1514 
1515     if (FD->isBitField()) {
1516       return nullptr;
1517     }
1518 
1519     // Type
1520     RSExportType *ET = RSExportElement::CreateFromDecl(Context, FD);
1521 
1522     if (ET != nullptr) {
1523       ERT->mFields.push_back(
1524           new Field(ET, FD->getName(), ERT,
1525                     static_cast<size_t>(RL->getFieldOffset(Index) >> 3)));
1526     } else {
1527       Context->ReportError(RD->getLocation(),
1528                            "field type cannot be exported: '%0.%1'")
1529           << RD->getName() << FD->getName();
1530       return nullptr;
1531     }
1532   }
1533 
1534   return ERT;
1535 }
1536 
convertToLLVMType() const1537 llvm::Type *RSExportRecordType::convertToLLVMType() const {
1538   // Create an opaque type since struct may reference itself recursively.
1539 
1540   // TODO(sliao): LLVM took out the OpaqueType. Any other to migrate to?
1541   std::vector<llvm::Type*> FieldTypes;
1542 
1543   for (const_field_iterator FI = fields_begin(), FE = fields_end();
1544        FI != FE;
1545        FI++) {
1546     const Field *F = *FI;
1547     const RSExportType *FET = F->getType();
1548 
1549     FieldTypes.push_back(FET->getLLVMType());
1550   }
1551 
1552   llvm::StructType *ST = llvm::StructType::get(getRSContext()->getLLVMContext(),
1553                                                FieldTypes,
1554                                                mIsPacked);
1555   if (ST != nullptr) {
1556     return ST;
1557   } else {
1558     return nullptr;
1559   }
1560 }
1561 
keep()1562 bool RSExportRecordType::keep() {
1563   if (!RSExportType::keep())
1564     return false;
1565   for (std::list<const Field*>::iterator I = mFields.begin(),
1566           E = mFields.end();
1567        I != E;
1568        I++) {
1569     const_cast<RSExportType*>((*I)->getType())->keep();
1570   }
1571   return true;
1572 }
1573 
equals(const RSExportable * E) const1574 bool RSExportRecordType::equals(const RSExportable *E) const {
1575   CHECK_PARENT_EQUALITY(RSExportType, E);
1576 
1577   const RSExportRecordType *ERT = static_cast<const RSExportRecordType*>(E);
1578 
1579   if (ERT->getFields().size() != getFields().size())
1580     return false;
1581 
1582   const_field_iterator AI = fields_begin(), BI = ERT->fields_begin();
1583 
1584   for (unsigned i = 0, e = getFields().size(); i != e; i++) {
1585     if (!(*AI)->getType()->equals((*BI)->getType()))
1586       return false;
1587     AI++;
1588     BI++;
1589   }
1590 
1591   return true;
1592 }
1593 
convertToRTD(RSReflectionTypeData * rtd) const1594 void RSExportType::convertToRTD(RSReflectionTypeData *rtd) const {
1595     memset(rtd, 0, sizeof(*rtd));
1596     rtd->vecSize = 1;
1597 
1598     switch(getClass()) {
1599     case RSExportType::ExportClassPrimitive: {
1600             const RSExportPrimitiveType *EPT = static_cast<const RSExportPrimitiveType*>(this);
1601             rtd->type = RSExportPrimitiveType::getRSReflectionType(EPT);
1602             return;
1603         }
1604     case RSExportType::ExportClassPointer: {
1605             const RSExportPointerType *EPT = static_cast<const RSExportPointerType*>(this);
1606             const RSExportType *PointeeType = EPT->getPointeeType();
1607             PointeeType->convertToRTD(rtd);
1608             rtd->isPointer = true;
1609             return;
1610         }
1611     case RSExportType::ExportClassVector: {
1612             const RSExportVectorType *EVT = static_cast<const RSExportVectorType*>(this);
1613             rtd->type = EVT->getRSReflectionType(EVT);
1614             rtd->vecSize = EVT->getNumElement();
1615             return;
1616         }
1617     case RSExportType::ExportClassMatrix: {
1618             const RSExportMatrixType *EMT = static_cast<const RSExportMatrixType*>(this);
1619             unsigned Dim = EMT->getDim();
1620             slangAssert((Dim >= 2) && (Dim <= 4));
1621             rtd->type = &gReflectionTypes[15 + Dim-2];
1622             return;
1623         }
1624     case RSExportType::ExportClassConstantArray: {
1625             const RSExportConstantArrayType* CAT =
1626               static_cast<const RSExportConstantArrayType*>(this);
1627             CAT->getElementType()->convertToRTD(rtd);
1628             rtd->arraySize = CAT->getNumElement();
1629             return;
1630         }
1631     case RSExportType::ExportClassRecord: {
1632             slangAssert(!"RSExportType::ExportClassRecord not implemented");
1633             return;// RS_TYPE_CLASS_NAME_PREFIX + ET->getName() + ".Item";
1634         }
1635     default: {
1636             slangAssert(false && "Unknown class of type");
1637         }
1638     }
1639 }
1640 
1641 
1642 }  // namespace slang
1643