• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 import com.github.javaparser.JavaParser
18 import com.github.javaparser.ParseProblemException
19 import com.github.javaparser.ParseResult
20 import com.github.javaparser.ParserConfiguration
21 import com.github.javaparser.ast.Node
22 import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration
23 import com.github.javaparser.ast.body.FieldDeclaration
24 import com.github.javaparser.ast.body.TypeDeclaration
25 import com.github.javaparser.ast.expr.AnnotationExpr
26 import com.github.javaparser.ast.expr.Expression
27 import com.github.javaparser.ast.expr.NormalAnnotationExpr
28 import com.github.javaparser.ast.expr.SingleMemberAnnotationExpr
29 import com.github.javaparser.ast.expr.StringLiteralExpr
30 import com.github.javaparser.resolution.declarations.ResolvedReferenceTypeDeclaration
31 import com.github.javaparser.resolution.types.ResolvedPrimitiveType
32 import com.github.javaparser.resolution.types.ResolvedReferenceType
33 import com.github.javaparser.symbolsolver.JavaSymbolSolver
34 import com.github.javaparser.symbolsolver.javaparsermodel.declarations.JavaParserClassDeclaration
35 import com.github.javaparser.symbolsolver.resolution.typesolvers.CombinedTypeSolver
36 import com.github.javaparser.symbolsolver.resolution.typesolvers.MemoryTypeSolver
37 import com.github.javaparser.symbolsolver.resolution.typesolvers.ReflectionTypeSolver
38 import com.squareup.javapoet.ClassName
39 import com.squareup.javapoet.ParameterizedTypeName
40 import com.squareup.javapoet.TypeName
41 import java.nio.file.Path
42 import java.util.Optional
43 
/**
 * One `@XmlPersistence`-annotated class found by [parse].
 *
 * @property name name of the persistence class: the annotation's `value`, or `<ClassName>Persistence`
 * @property root field tree describing the annotated class itself
 * @property path `<name>.java`, resolved as a sibling of the annotated class's source file
 */
class PersistenceInfo(val name: String, val root: ClassFieldInfo, val path: Path)
49 
/**
 * A field of a persisted class as extracted from Java source.
 *
 * Concrete kinds: [PrimitiveFieldInfo], [StringFieldInfo], [ClassFieldInfo] and [ListFieldInfo].
 */
sealed class FieldInfo {
    /** Java field name (list elements use a placeholder name). */
    abstract val name: String

    /** JavaPoet type used for the field. */
    abstract val type: TypeName

    /** Value of the field's `@XmlName` annotation, or `null` if it has none. */
    abstract val xmlName: String?

    /** Whether the field is required (for reference types: annotated `@NonNull`). */
    abstract val isRequired: Boolean
}
56 
/** A field whose type is a Java primitive (`int`, `boolean`, ...) or one of the boxed primitive classes. */
class PrimitiveFieldInfo(
    override val name: String,
    override val xmlName: String?,
    override val type: TypeName,
    override val isRequired: Boolean
) : FieldInfo()
63 
/** A `java.lang.String` field; its [type] is fixed rather than passed in. */
class StringFieldInfo(
    override val name: String,
    override val xmlName: String?,
    override val isRequired: Boolean
) : FieldInfo() {
    // Always java.lang.String, so it is not a constructor parameter.
    override val type: TypeName = ClassName.get(String::class.java)
}
71 
/**
 * A field whose value is itself a class with nested [fields].
 *
 * [type] is narrowed to a plain [ClassName], since no type arguments are tracked for classes.
 */
class ClassFieldInfo(
    override val name: String,
    override val xmlName: String?,
    override val type: ClassName,
    override val isRequired: Boolean,
    /** Non-static fields of the class, in declaration order. */
    val fields: List<FieldInfo>
) : FieldInfo()
79 
/**
 * A `java.util.List` field whose items are described by [element].
 *
 * List fields are always reported as required.
 */
class ListFieldInfo(
    override val name: String,
    override val xmlName: String?,
    override val type: ParameterizedTypeName,
    /** Description of the list's element class. */
    val element: ClassFieldInfo
) : FieldInfo() {
    override val isRequired: Boolean
        get() = true
}
88 
/**
 * Parses the Java source [files] and returns a [PersistenceInfo] for every
 * `@XmlPersistence`-annotated class among them.
 *
 * Processing happens in two passes:
 * 1. All files are parsed first.
 * 2. Every type declaration that has a fully qualified name is registered in a
 *    [MemoryTypeSolver]. That solver is added to the symbol resolver's type solver, so
 *    later resolution can also see types declared in these sources.
 *
 * Only non-interface classes that are top-level or `static` nested are considered.
 *
 * @throws ParseProblemException if a file does not parse successfully
 */
fun parse(files: List<Path>): List<PersistenceInfo> {
    val typeSolver = CombinedTypeSolver()
    typeSolver.add(ReflectionTypeSolver())
    val configuration = ParserConfiguration().setSymbolResolver(JavaSymbolSolver(typeSolver))
    val javaParser = JavaParser(configuration)
    val compilationUnits = files.map { file -> javaParser.parse(file).getOrThrow() }

    val memoryTypeSolver = MemoryTypeSolver()
    compilationUnits.forEach { unit ->
        unit.getNodesByClass<TypeDeclaration<*>>().forEach { declaration ->
            declaration.fullyQualifiedName.getOrNull()?.let { qualifiedName ->
                memoryTypeSolver.addDeclaration(qualifiedName, declaration.resolve())
            }
        }
    }
    typeSolver.add(memoryTypeSolver)

    return compilationUnits.flatMap { unit ->
        unit.getNodesByClass<ClassOrInterfaceDeclaration>()
            .filter { candidate -> !candidate.isInterface && (!candidate.isNestedType || candidate.isStatic) }
            .mapNotNull { candidate -> parsePersistenceInfo(candidate) }
    }
}
112 
/**
 * Builds a [PersistenceInfo] for [classDeclaration], or returns `null` when the class has
 * no `@XmlPersistence` annotation.
 *
 * - The persistence name is the annotation's string `value`, defaulting to `<ClassName>Persistence`.
 * - The root's name is the class's `@XmlName` value if present, otherwise its simple name.
 * - The resulting path is `<persistence name>.java` next to the class's own source file.
 */
private fun parsePersistenceInfo(classDeclaration: ClassOrInterfaceDeclaration): PersistenceInfo? {
    val persistenceAnnotation =
        classDeclaration.getAnnotationByName("XmlPersistence").getOrNull() ?: return null
    val simpleClassName = classDeclaration.nameAsString
    val explicitName = persistenceAnnotation.getMemberValue("value")?.stringLiteralValue
    val persistenceName = explicitName ?: "${simpleClassName}Persistence"
    val xmlNameAnnotation = classDeclaration.getAnnotationByName("XmlName").getOrNull()
    val rootXmlName = xmlNameAnnotation?.getMemberValue("value")?.stringLiteralValue
    val rootInfo = parseClassFieldInfo(rootXmlName ?: simpleClassName, rootXmlName, true, classDeclaration)
    val sourceFile = classDeclaration.findCompilationUnit().get().storage.get().path
    val outputFile = sourceFile.resolveSibling("$persistenceName.java")
    return PersistenceInfo(persistenceName, rootInfo, outputFile)
}
128 
/**
 * Describes [classDeclaration] as a [ClassFieldInfo] whose children are its non-static
 * fields, in declaration order.
 *
 * @param name name recorded for this node
 * @param xmlName XML name for this node, if any
 * @param isRequired whether this node is required
 */
private fun parseClassFieldInfo(
    name: String,
    xmlName: String?,
    isRequired: Boolean,
    classDeclaration: ClassOrInterfaceDeclaration
): ClassFieldInfo {
    val instanceFields = classDeclaration.fields.filter { declaration -> !declaration.isStatic }
    val children = instanceFields.map(::parseFieldInfo)
    return ClassFieldInfo(name, xmlName, classDeclaration.resolve().typeName, isRequired, children)
}
139 
/**
 * Converts one field declaration into the matching [FieldInfo] subtype.
 *
 * Requirements on the field:
 * - it must be `public final`;
 * - it must declare exactly one variable.
 *
 * The XML name comes from `@XmlName`. A reference-typed field is required when annotated
 * `@NonNull`; primitive fields are always required. Annotations are looked up on both the
 * field declaration and its type.
 *
 * Mapping:
 * - primitive or boxed primitive -> [PrimitiveFieldInfo]
 * - `java.lang.String` -> [StringFieldInfo]
 * - `java.util.List<T>` (must have `@XmlName`) -> [ListFieldInfo], with `T` parsed as a class
 * - any other reference type -> nested [ClassFieldInfo]
 *
 * @throws IllegalArgumentException if one of the requirements above is violated
 * @throws IllegalStateException if the field's type is neither primitive nor a reference type
 */
private fun parseFieldInfo(field: FieldDeclaration): FieldInfo {
    require(field.isPublic && field.isFinal) { "Field must be public final: $field" }
    require(field.variables.size == 1) { "Field must declare exactly one variable: $field" }
    val variable = field.variables.single()
    val name = variable.nameAsString
    // Annotations may be attached to the declaration or to its type (e.g. type-use annotations).
    val annotations = field.annotations + variable.type.annotations
    val xmlName = annotations.getByName("XmlName")?.getMemberValue("value")?.stringLiteralValue
    val isRequired = annotations.getByName("NonNull") != null
    return when (val type = variable.type.resolve()) {
        // A primitive can never be null, so it is always required.
        is ResolvedPrimitiveType -> PrimitiveFieldInfo(name, xmlName, type.typeName, true)
        is ResolvedReferenceType -> when (type.qualifiedName) {
            Boolean::class.javaObjectType.name, Byte::class.javaObjectType.name,
            Short::class.javaObjectType.name, Char::class.javaObjectType.name,
            Int::class.javaObjectType.name, Long::class.javaObjectType.name,
            Float::class.javaObjectType.name, Double::class.javaObjectType.name ->
                PrimitiveFieldInfo(name, xmlName, type.typeName, isRequired)
            String::class.java.name -> StringFieldInfo(name, xmlName, isRequired)
            List::class.java.name -> {
                val listXmlName = requireNotNull(xmlName) {
                    "List field must be annotated with @XmlName: $field"
                }
                val elementType = type.typeParametersValues().single()
                require(elementType is ResolvedReferenceType) {
                    "List element type must be a reference type: $field"
                }
                val listType = ParameterizedTypeName.get(
                    ClassName.get(List::class.java), elementType.typeName
                )
                val element = parseClassFieldInfo(
                    "(element)", listXmlName, true, elementType.classDeclaration
                )
                ListFieldInfo(name, listXmlName, listType, element)
            }
            else -> parseClassFieldInfo(name, xmlName, isRequired, type.classDeclaration)
        }
        else -> error("Unsupported type $type for field: $field")
    }
}
179 
/** Returns the parsed value, or throws a [ParseProblemException] carrying every reported problem. */
private fun <T> ParseResult<T>.getOrThrow(): T {
    if (!isSuccessful) {
        throw ParseProblemException(problems)
    }
    return result.get()
}
186 
/** Type-argument form of `getNodesByClass(Class)`: collects this node and its descendants of type [T]. */
private inline fun <reified T : Node> Node.getNodesByClass(): List<T> {
    val nodeClass = T::class.java
    return getNodesByClass(nodeClass)
}
189 
/**
 * Collects this node and all of its descendants that are instances of [klass].
 *
 * Results are in pre-order: each node comes before its children, and children are visited
 * in `childNodes` order.
 */
private fun <T : Node> Node.getNodesByClass(klass: Class<T>): List<T> {
    val matches = mutableListOf<T>()
    // Append into one shared list. Building a list per subtree and copying it into the
    // parent's list cost O(nodes × depth).
    fun visit(node: Node) {
        if (klass.isInstance(node)) {
            matches += klass.cast(node)
        }
        for (child in node.childNodes) {
            visit(child)
        }
    }
    visit(this)
    return matches
}
198 
/** Returns the contained value, or `null` when this [Optional] is empty. */
private fun <T> Optional<T>.getOrNull(): T? = if (isPresent) get() else null
200 
/** Returns the first annotation whose name identifier equals [name], or `null` if there is none. */
private fun List<AnnotationExpr>.getByName(name: String): AnnotationExpr? =
    firstOrNull { annotation -> annotation.name.identifier == name }
203 
/**
 * Returns the expression given for annotation member [name], or `null` if it is absent.
 *
 * - A single-member annotation (`@A(x)`) only provides the `value` member.
 * - Any annotation that is neither normal nor single-member provides no members.
 */
private fun AnnotationExpr.getMemberValue(name: String): Expression? {
    if (this is NormalAnnotationExpr) {
        val pair = pairs.firstOrNull { candidate -> candidate.nameAsString == name }
        return pair?.value
    }
    if (this is SingleMemberAnnotationExpr) {
        return if (name == "value") memberValue else null
    }
    return null
}
210 
/**
 * The `value` of this expression, which must be a [StringLiteralExpr].
 *
 * NOTE(review): `StringLiteralExpr.value` may keep escape sequences exactly as written in
 * source (e.g. `\"`). Confirm whether callers expect unescaped text.
 *
 * @throws IllegalArgumentException if this expression is not a string literal
 */
private val Expression.stringLiteralValue: String
    get() {
        require(this is StringLiteralExpr) { "Expected a string literal but found: $this" }
        return value
    }
216 
/**
 * The parsed-source class declaration behind this reference type.
 *
 * @throws IllegalArgumentException if the type does not resolve to a [JavaParserClassDeclaration]
 *   (for example, a type known only through reflection, or one that is not a class)
 */
private val ResolvedReferenceType.classDeclaration: ClassOrInterfaceDeclaration
    get() {
        val resolvedClassDeclaration = typeDeclaration
        require(resolvedClassDeclaration is JavaParserClassDeclaration) {
            "Expected $qualifiedName to be a class declared in the parsed sources, " +
                "but it resolved to $resolvedClassDeclaration"
        }
        return resolvedClassDeclaration.wrappedNode
    }
223 
/** The JavaPoet primitive [TypeName] corresponding to this resolved primitive type. */
private val ResolvedPrimitiveType.typeName: TypeName
    get() = when (this) {
        // Integral types.
        ResolvedPrimitiveType.BYTE -> TypeName.BYTE
        ResolvedPrimitiveType.SHORT -> TypeName.SHORT
        ResolvedPrimitiveType.INT -> TypeName.INT
        ResolvedPrimitiveType.LONG -> TypeName.LONG
        ResolvedPrimitiveType.CHAR -> TypeName.CHAR
        // Floating-point types.
        ResolvedPrimitiveType.FLOAT -> TypeName.FLOAT
        ResolvedPrimitiveType.DOUBLE -> TypeName.DOUBLE
        // Boolean.
        ResolvedPrimitiveType.BOOLEAN -> TypeName.BOOLEAN
    }
236 
/**
 * The JavaPoet type of this reference type's declaration.
 *
 * Type arguments are not supported: only the declaration's class name is used, so
 * e.g. `List<String>` becomes the raw `List`.
 */
private val ResolvedReferenceType.typeName: TypeName
    get() {
        val declaration = typeDeclaration
        return declaration.typeName
    }
240 
/**
 * The JavaPoet [ClassName] for this declaration.
 *
 * [className] is split on `.`, so a nested class such as `Outer.Inner` becomes a nested
 * JavaPoet class name rather than a single simple name containing a dot.
 */
private val ResolvedReferenceTypeDeclaration.typeName: ClassName
    get() {
        val pkg = packageName
        val simpleNames = className.split('.')
        val enclosingNames = simpleNames.drop(1).toTypedArray()
        return ClassName.get(pkg, simpleNames.first(), *enclosingNames)
    }
249