/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17 import com.github.javaparser.JavaParser
18 import com.github.javaparser.ParseProblemException
19 import com.github.javaparser.ParseResult
20 import com.github.javaparser.ParserConfiguration
21 import com.github.javaparser.ast.Node
22 import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration
23 import com.github.javaparser.ast.body.FieldDeclaration
24 import com.github.javaparser.ast.body.TypeDeclaration
25 import com.github.javaparser.ast.expr.AnnotationExpr
26 import com.github.javaparser.ast.expr.Expression
27 import com.github.javaparser.ast.expr.NormalAnnotationExpr
28 import com.github.javaparser.ast.expr.SingleMemberAnnotationExpr
29 import com.github.javaparser.ast.expr.StringLiteralExpr
30 import com.github.javaparser.resolution.declarations.ResolvedReferenceTypeDeclaration
31 import com.github.javaparser.resolution.types.ResolvedPrimitiveType
32 import com.github.javaparser.resolution.types.ResolvedReferenceType
33 import com.github.javaparser.symbolsolver.JavaSymbolSolver
34 import com.github.javaparser.symbolsolver.javaparsermodel.declarations.JavaParserClassDeclaration
35 import com.github.javaparser.symbolsolver.resolution.typesolvers.CombinedTypeSolver
36 import com.github.javaparser.symbolsolver.resolution.typesolvers.MemoryTypeSolver
37 import com.github.javaparser.symbolsolver.resolution.typesolvers.ReflectionTypeSolver
38 import com.squareup.javapoet.ClassName
39 import com.squareup.javapoet.ParameterizedTypeName
40 import com.squareup.javapoet.TypeName
41 import java.nio.file.Path
42 import java.util.Optional
43
/**
 * Result of parsing one class annotated with `@XmlPersistence`.
 *
 * @property name persistence class name: the annotation's `value`, or the annotated class's
 *     simple name followed by `Persistence`.
 * @property root description of the annotated class and its fields.
 * @property path `$name.java` path located next to the annotated class's source file.
 */
class PersistenceInfo(val name: String, val root: ClassFieldInfo, val path: Path)
49
/**
 * A persisted member of a class: a primitive (or boxed primitive), a string, a nested class,
 * or a list of classes.
 */
sealed class FieldInfo {
    /**
     * Java field name for ordinary fields; the root class uses its XML name or simple class
     * name, and list elements use `(element)`.
     */
    abstract val name: String

    /** Value of the `@XmlName` annotation, or null when it is absent. */
    abstract val xmlName: String?

    /** JavaPoet type of the member. */
    abstract val type: TypeName

    /** Whether a value must always be present for this member. */
    abstract val isRequired: Boolean
}
56
/** A field of a Java primitive type or of its boxed counterpart (e.g. `int` or `Integer`). */
class PrimitiveFieldInfo(
    override val name: String,
    override val xmlName: String?,
    override val type: TypeName,
    override val isRequired: Boolean
) : FieldInfo()
63
/** A `java.lang.String` field. */
class StringFieldInfo(
    override val name: String,
    override val xmlName: String?,
    override val isRequired: Boolean
) : FieldInfo() {
    /** Always `java.lang.String`; created once per instance. */
    override val type: TypeName = ClassName.get(String::class.java)
}
71
/**
 * A field (or the persistence root) whose type is a class declared in the parsed sources.
 *
 * @property fields the class's non-static fields, described recursively.
 */
class ClassFieldInfo(
    override val name: String,
    override val xmlName: String?,
    override val type: ClassName,
    override val isRequired: Boolean,
    val fields: List<FieldInfo>
) : FieldInfo()
79
/**
 * A `java.util.List` field whose elements are classes.
 *
 * @property element description of the element class.
 */
class ListFieldInfo(
    override val name: String,
    override val xmlName: String?,
    override val type: ParameterizedTypeName,
    val element: ClassFieldInfo
) : FieldInfo() {
    /** Lists are always treated as required. */
    override val isRequired: Boolean
        get() = true
}
88
/**
 * Parses the given Java source [files] and returns a [PersistenceInfo] for every non-interface,
 * top-level or static nested class annotated with `@XmlPersistence`.
 *
 * Symbols are resolved against the JDK (via reflection) and against every named type declared
 * in [files], so fields may reference classes from any of the input files.
 *
 * @throws ParseProblemException if any file fails to parse.
 */
fun parse(files: List<Path>): List<PersistenceInfo> {
    val typeSolver = CombinedTypeSolver()
    typeSolver.add(ReflectionTypeSolver())
    val configuration = ParserConfiguration().setSymbolResolver(JavaSymbolSolver(typeSolver))
    val javaParser = JavaParser(configuration)
    val compilationUnits = files.map { file -> javaParser.parse(file).getOrThrow() }

    // Register every type that has a fully qualified name (local classes have none) so that
    // cross-file references resolve once the memory solver is added.
    val memoryTypeSolver = MemoryTypeSolver()
    compilationUnits.forEach { unit ->
        unit.getNodesByClass<TypeDeclaration<*>>().forEach { declaration ->
            declaration.fullyQualifiedName.getOrNull()?.let { qualifiedName ->
                memoryTypeSolver.addDeclaration(qualifiedName, declaration.resolve())
            }
        }
    }
    typeSolver.add(memoryTypeSolver)

    return compilationUnits.flatMap { unit ->
        unit.getNodesByClass<ClassOrInterfaceDeclaration>()
            .filter { declaration ->
                !declaration.isInterface && (declaration.isStatic || !declaration.isNestedType)
            }
            .mapNotNull { declaration -> parsePersistenceInfo(declaration) }
    }
}
112
/**
 * Builds the [PersistenceInfo] for [classDeclaration], or returns null when the class is not
 * annotated with `@XmlPersistence`.
 *
 * The persistence name defaults to `<ClassName>Persistence`. The root element uses the class's
 * `@XmlName` value when present and otherwise its simple name.
 */
private fun parsePersistenceInfo(classDeclaration: ClassOrInterfaceDeclaration): PersistenceInfo? {
    val persistenceAnnotation =
        classDeclaration.getAnnotationByName("XmlPersistence").getOrNull() ?: return null
    val className = classDeclaration.nameAsString
    val persistenceName = persistenceAnnotation.getMemberValue("value")?.stringLiteralValue
        ?: "${className}Persistence"
    val xmlNameAnnotation = classDeclaration.getAnnotationByName("XmlName").getOrNull()
    val rootXmlName = xmlNameAnnotation?.getMemberValue("value")?.stringLiteralValue
    val rootFieldInfo = parseClassFieldInfo(
        name = rootXmlName ?: className,
        xmlName = rootXmlName,
        isRequired = true,
        classDeclaration = classDeclaration
    )
    // The compilation unit was parsed from a file, so both the unit and its storage are present.
    val sourcePath = classDeclaration.findCompilationUnit().get().storage.get().path
    val outputPath = sourcePath.resolveSibling("$persistenceName.java")
    return PersistenceInfo(persistenceName, rootFieldInfo, outputPath)
}
128
/**
 * Describes [classDeclaration] as a [ClassFieldInfo] named [name], recursively parsing each of
 * its non-static fields with [parseFieldInfo].
 */
private fun parseClassFieldInfo(
    name: String,
    xmlName: String?,
    isRequired: Boolean,
    classDeclaration: ClassOrInterfaceDeclaration
): ClassFieldInfo {
    val instanceFields = classDeclaration.fields.filter { field -> !field.isStatic }
    val fieldInfos = instanceFields.map(::parseFieldInfo)
    return ClassFieldInfo(
        name = name,
        xmlName = xmlName,
        type = classDeclaration.resolve().typeName,
        isRequired = isRequired,
        fields = fieldInfos
    )
}
139
/**
 * Converts one field declaration into a [FieldInfo].
 *
 * The field must be `public final` and declare exactly one variable. `@XmlName` and `@NonNull`
 * are looked up on both the declaration and its type.
 *
 * - Java primitives are always required.
 * - Boxed primitives, strings and nested classes are required only when annotated `@NonNull`.
 * - Lists must carry `@XmlName` and have a single class element type.
 *
 * @throws IllegalArgumentException if the field violates one of the constraints above.
 * @throws IllegalStateException if the field's type is neither primitive nor a reference type.
 */
private fun parseFieldInfo(field: FieldDeclaration): FieldInfo {
    require(field.isPublic && field.isFinal) { "Field must be public and final: $field" }
    val variable = requireNotNull(field.variables.singleOrNull()) {
        "Field must declare exactly one variable: $field"
    }
    val name = variable.nameAsString
    // Annotations may sit on the declaration itself or on its type (type-use annotations).
    val annotations = field.annotations + variable.type.annotations
    val xmlName = annotations.getByName("XmlName")?.getMemberValue("value")?.stringLiteralValue
    val isRequired = annotations.getByName("NonNull") != null
    return when (val type = variable.type.resolve()) {
        // A Java primitive can never be null, so it is always required.
        is ResolvedPrimitiveType -> PrimitiveFieldInfo(name, xmlName, type.typeName, isRequired = true)
        is ResolvedReferenceType -> when (type.qualifiedName) {
            Boolean::class.javaObjectType.name, Byte::class.javaObjectType.name,
            Short::class.javaObjectType.name, Char::class.javaObjectType.name,
            Int::class.javaObjectType.name, Long::class.javaObjectType.name,
            Float::class.javaObjectType.name, Double::class.javaObjectType.name ->
                PrimitiveFieldInfo(name, xmlName, type.typeName, isRequired)
            String::class.java.name -> StringFieldInfo(name, xmlName, isRequired)
            List::class.java.name -> parseListFieldInfo(name, xmlName, type)
            else -> parseClassFieldInfo(name, xmlName, isRequired, type.classDeclaration)
        }
        else -> error("Unsupported type for field $name: $type")
    }
}

/**
 * Builds a [ListFieldInfo] for a `java.util.List<E>` field named [name], where `E` must be a
 * class declared in the parsed sources.
 */
private fun parseListFieldInfo(
    name: String,
    xmlName: String?,
    type: ResolvedReferenceType
): ListFieldInfo {
    requireNotNull(xmlName) { "List field $name must be annotated with @XmlName" }
    val elementType = requireNotNull(type.typeParametersValues().singleOrNull()) {
        "List field $name must declare exactly one element type"
    }
    require(elementType is ResolvedReferenceType) {
        "List field $name must have a class element type, got: $elementType"
    }
    val listType = ParameterizedTypeName.get(ClassName.get(List::class.java), elementType.typeName)
    // Each element is modeled as a required class field that shares the list's XML name.
    val element = parseClassFieldInfo(
        "(element)", xmlName, isRequired = true, classDeclaration = elementType.classDeclaration
    )
    return ListFieldInfo(name, xmlName, listType, element)
}
179
/**
 * Returns the parsed value, or throws [ParseProblemException] carrying the reported problems
 * when parsing was not successful.
 */
private fun <T> ParseResult<T>.getOrThrow(): T {
    if (!isSuccessful) {
        throw ParseProblemException(problems)
    }
    return result.get()
}
186
/** Reified convenience overload: this node and all descendants that are instances of [T]. */
private inline fun <reified T : Node> Node.getNodesByClass(): List<T> {
    val nodeClass = T::class.java
    return getNodesByClass(nodeClass)
}
189
/**
 * Returns this node and all of its descendants that are instances of [klass], in pre-order:
 * a node comes before its children, and children are visited in `childNodes` order.
 */
private fun <T : Node> Node.getNodesByClass(klass: Class<T>): List<T> =
    mutableListOf<T>().also { result -> collectNodesByClass(klass, result) }

/**
 * Appends the matching nodes of this subtree to [result]. A single shared list avoids copying
 * every descendant's matches once per ancestor level.
 */
private fun <T : Node> Node.collectNodesByClass(klass: Class<T>, result: MutableList<T>) {
    if (klass.isInstance(this)) {
        result += klass.cast(this)
    }
    for (childNode in childNodes) {
        childNode.collectNodesByClass(klass, result)
    }
}
198
/** Returns the contained value, or null when this [Optional] is empty. */
private fun <T> Optional<T>.getOrNull(): T? = if (isPresent) get() else null
200
/**
 * Returns the first annotation whose name matches [name], or null. Only `name.identifier` (the
 * unqualified part) is compared.
 */
private fun List<AnnotationExpr>.getByName(name: String): AnnotationExpr? =
    firstOrNull { annotation -> annotation.name.identifier == name }
203
/**
 * Returns the value expression of annotation member [name].
 *
 * For `@A(key = value, ...)` the matching pair's value is returned. For `@A(value)` the single
 * value is returned only when [name] is `"value"`. Otherwise, or when the member is missing,
 * the result is null.
 */
private fun AnnotationExpr.getMemberValue(name: String): Expression? {
    if (this is SingleMemberAnnotationExpr) {
        return if (name != "value") null else memberValue
    }
    if (this is NormalAnnotationExpr) {
        return pairs.firstOrNull { pair -> pair.nameAsString == name }?.value
    }
    return null
}
210
/**
 * The value of this expression, which must be a string literal (`StringLiteralExpr.value`).
 *
 * @throws IllegalArgumentException naming the offending expression when it is anything else,
 *     e.g. a constant reference such as `Constants.NAME`.
 */
private val Expression.stringLiteralValue: String
    get() {
        require(this is StringLiteralExpr) { "Expected a string literal but found: $this" }
        return value
    }
216
/**
 * The source [ClassOrInterfaceDeclaration] behind this type.
 *
 * @throws IllegalArgumentException naming the type when it does not resolve to a class parsed
 *     from the input sources (e.g. an interface, or a JDK type resolved via reflection).
 */
private val ResolvedReferenceType.classDeclaration: ClassOrInterfaceDeclaration
    get() {
        val resolvedClassDeclaration = typeDeclaration
        require(resolvedClassDeclaration is JavaParserClassDeclaration) {
            "$qualifiedName must be a class declared in the parsed sources"
        }
        return resolvedClassDeclaration.wrappedNode
    }
223
/** The JavaPoet primitive [TypeName] corresponding to this resolved primitive type. */
private val ResolvedPrimitiveType.typeName: TypeName
    get() = when (this) {
        ResolvedPrimitiveType.BYTE -> TypeName.BYTE
        ResolvedPrimitiveType.SHORT -> TypeName.SHORT
        ResolvedPrimitiveType.CHAR -> TypeName.CHAR
        ResolvedPrimitiveType.INT -> TypeName.INT
        ResolvedPrimitiveType.LONG -> TypeName.LONG
        ResolvedPrimitiveType.BOOLEAN -> TypeName.BOOLEAN
        ResolvedPrimitiveType.FLOAT -> TypeName.FLOAT
        ResolvedPrimitiveType.DOUBLE -> TypeName.DOUBLE
    }
236
/**
 * The JavaPoet type of this type's declaration.
 *
 * Type arguments are not supported and are dropped, so `List<Foo>` maps to the raw `List`.
 */
private val ResolvedReferenceType.typeName: TypeName
    get() {
        val declaration = typeDeclaration
        return declaration.typeName
    }
240
/**
 * The JavaPoet [ClassName] for this declaration. Each dot-separated part of `className` is
 * treated as one nesting level (e.g. `Outer.Inner`), and the first part is the top-level class
 * in `packageName`.
 */
private val ResolvedReferenceTypeDeclaration.typeName: ClassName
    get() {
        val declaringPackage = packageName
        val simpleNames = className.split(".")
        val nestedNames = simpleNames.subList(1, simpleNames.size)
        return ClassName.get(declaringPackage, simpleNames[0], *nestedNames.toTypedArray())
    }
249