• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 import com.squareup.javapoet.ClassName
18 import com.squareup.javapoet.FieldSpec
19 import com.squareup.javapoet.JavaFile
20 import com.squareup.javapoet.MethodSpec
21 import com.squareup.javapoet.NameAllocator
22 import com.squareup.javapoet.ParameterSpec
23 import com.squareup.javapoet.TypeSpec
24 import java.io.File
25 import java.io.FileInputStream
26 import java.io.FileNotFoundException
27 import java.io.FileOutputStream
28 import java.io.IOException
29 import java.nio.charset.StandardCharsets
30 import java.time.Year
31 import java.util.Objects
32 import javax.lang.model.element.Modifier
33 
/**
 * Header text for generated Java files: the AOSP Apache 2.0 license block followed by markers that flag the file
 * as generated and switch off checkstyle and the formatter.
 *
 * This is plain text rather than a JavaPoet file comment because JavaPoet only supports line comments, and can't
 * add a newline after file level comments. The copyright year comes from [Year.now] at the time this property is
 * initialized. The value ends with a blank line.
 */
val FILE_HEADER: String = buildString {
    append(
        """
            /*
             * Copyright (C) ${Year.now().value} The Android Open Source Project
             *
             * Licensed under the Apache License, Version 2.0 (the "License");
             * you may not use this file except in compliance with the License.
             * You may obtain a copy of the License at
             *
             *      http://www.apache.org/licenses/LICENSE-2.0
             *
             * Unless required by applicable law or agreed to in writing, software
             * distributed under the License is distributed on an "AS IS" BASIS,
             * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
             * See the License for the specific language governing permissions and
             * limitations under the License.
             */

            // Generated by xmlpersistence. DO NOT MODIFY!
            // CHECKSTYLE:OFF Generated code
            // @formatter:off
        """.trimIndent()
    )
    append("\n\n")
}
56 
/** `android.util.AtomicFile`, the type of the generated `mFile` field that performs the actual file I/O. */
private val atomicFileType: ClassName = ClassName.get("android.util", "AtomicFile")
58 
/**
 * Generates the Java file implementing XML persistence for [persistence].
 *
 * The generated class is public and final, is named [PersistenceInfo.name], and is placed in the package of the
 * root class. Its members are added in this order:
 * 1. the `mFile` field;
 * 2. the constructor;
 * 3. `read` and `parse`, plus one parse helper per distinct class type reachable from the root;
 * 4. `write` and `serialize`, plus one serialize helper per distinct class type;
 * 5. `delete`.
 */
fun generate(persistence: PersistenceInfo): JavaFile {
    val root = persistence.root
    // One parse/serialize helper per class type, even if the type is used by several fields.
    val uniqueClassFields = root.allClassFields.distinctBy { it.type }
    val classJavadoc = """
        Generated class implementing XML persistence for${'$'}W{@link $1T}.
        <p>
        This class provides atomicity for persistence via {@link $2T}, however it does not provide
        thread safety, so please bring your own synchronization mechanism.
    """.trimIndent()
    val typeSpec = TypeSpec.classBuilder(persistence.name).apply {
        addJavadoc(classJavadoc, root.type, atomicFileType)
        addModifiers(Modifier.PUBLIC, Modifier.FINAL)
        addField(generateFileField())
        addMethod(generateConstructor())
        addMethod(generateReadMethod(root))
        addMethod(generateParseMethod(root))
        uniqueClassFields.forEach { addMethod(generateParseClassMethod(it)) }
        addMethod(generateWriteMethod(root))
        addMethod(generateSerializeMethod(root))
        uniqueClassFields.forEach { addMethod(generateSerializeClassMethod(it)) }
        addMethod(generateDeleteMethod())
    }.build()
    val fileBuilder = JavaFile.builder(root.type.packageName(), typeSpec)
    fileBuilder.skipJavaLangImports(true)
    fileBuilder.indent("    ")
    return fileBuilder.build()
}
86 
/** `android.annotation.NonNull`, applied to generated fields, parameters and return values that are never null. */
private val nonNullType: ClassName = ClassName.get("android.annotation", "NonNull")
88 
/**
 * Builds the `@NonNull private final AtomicFile mFile` field. The generated read, write and delete methods all
 * go through it.
 */
private fun generateFileField(): FieldSpec {
    val fieldBuilder = FieldSpec.builder(atomicFileType, "mFile", Modifier.PRIVATE, Modifier.FINAL)
    fieldBuilder.addAnnotation(nonNullType)
    return fieldBuilder.build()
}
93 
/**
 * Builds the public constructor. It takes the XML `File` and wraps it in a new `AtomicFile` stored in `mFile`.
 */
private fun generateConstructor(): MethodSpec {
    val fileParameter = ParameterSpec.builder(File::class.java, "file")
        .addAnnotation(nonNullType)
        .build()
    return with(MethodSpec.constructorBuilder()) {
        addJavadoc("Create an instance of this class.\n\n@param file the XML file for persistence")
        addModifiers(Modifier.PUBLIC)
        addParameter(fileParameter)
        addStatement("mFile = new \$1T(file)", atomicFileType)
        build()
    }
}
109 
/** `android.annotation.Nullable`, the nullability annotation used in generated code for values that may be null. */
private val nullableType: ClassName = ClassName.get("android.annotation", "Nullable")

/** `org.xmlpull.v1.XmlPullParser`, the parser type the generated parse methods read from. */
private val xmlPullParserType: ClassName = ClassName.get("org.xmlpull.v1", "XmlPullParser")

/** `android.util.Xml`, used by generated code to create pull parsers and serializers. */
private val xmlType: ClassName = ClassName.get("android.util", "Xml")

/** `org.xmlpull.v1.XmlPullParserException`, declared or caught by the generated parsing code. */
private val xmlPullParserExceptionType: ClassName =
    ClassName.get("org.xmlpull.v1", "XmlPullParserException")
117 
/**
 * Builds the public `read()` method.
 *
 * The generated method opens `mFile`, creates a pull parser over it and delegates to `parse(parser)`. It returns
 * `null` when opening the file throws `FileNotFoundException`. It rethrows `IOException` and
 * `XmlPullParserException` wrapped in `IllegalArgumentException`.
 */
private fun generateReadMethod(rootField: ClassFieldInfo): MethodSpec {
    val rootType = rootField.type
    val readJavadoc = """
        Read${'$'}W{@link $1T}${'$'}Wfrom${'$'}Wthe${'$'}WXML${'$'}Wfile.

        @return the persisted${'$'}W{@link $1T},${'$'}Wor${'$'}W{@code null}${'$'}Wif${'$'}Wthe${'$'}WXML${'$'}Wfile${'$'}Wdoesn't${'$'}Wexist
        @throws IllegalArgumentException if an error occurred while reading
    """.trimIndent()
    val builder = MethodSpec.methodBuilder("read")
    with(builder) {
        addJavadoc(readJavadoc, rootType)
        addAnnotation(nullableType)
        addModifiers(Modifier.PUBLIC)
        returns(rootType)
        // try-with-resources closes the stream on every path.
        beginControlFlow("try (\$1T inputStream = mFile.openRead())", FileInputStream::class.java)
        addStatement("final \$1T parser = \$2T.newPullParser()", xmlPullParserType, xmlType)
        addStatement("parser.setInput(inputStream, null)")
        addStatement("return parse(parser)")
        nextControlFlow("catch (\$1T e)", FileNotFoundException::class.java)
        addStatement("return null")
        nextControlFlow(
            "catch (\$1T | \$2T e)", IOException::class.java, xmlPullParserExceptionType
        )
        addStatement("throw new IllegalArgumentException(e)")
        endControlFlow()
    }
    return builder.build()
}
143 
/**
 * This class field followed by every class field reachable from it, depth first.
 *
 * Traversal descends into nested class fields and into the element type of list fields, in field order.
 * Duplicates are kept.
 */
private val ClassFieldInfo.allClassFields: List<ClassFieldInfo>
    get() = listOf(this) + fields.flatMap { child ->
        when (child) {
            is ClassFieldInfo -> child.allClassFields
            is ListFieldInfo -> child.element.allClassFields
            else -> emptyList<ClassFieldInfo>()
        }
    }
156 
/**
 * Builds the private static `parse(parser)` method.
 *
 * The generated code scans top-level start tags until it finds the root tag of [rootField], then returns the
 * result of that class's parse helper. If the document ends without the root tag, it throws
 * `IllegalArgumentException("Missing root tag <...>")`.
 */
private fun generateParseMethod(rootField: ClassFieldInfo): MethodSpec {
    val parserParameter =
        ParameterSpec.builder(xmlPullParserType, "parser").addAnnotation(nonNullType).build()
    val checkedExceptions = listOf(ClassName.get(IOException::class.java), xmlPullParserExceptionType)
    // Loop until END_DOCUMENT, or until the END_TAG that closes the depth we started at.
    val loopCondition = "while ((type = parser.next()) != \$1T.END_DOCUMENT\$W" +
        "&& ((depth = parser.getDepth()) >= innerDepth || type != \$1T.END_TAG))"
    val builder = MethodSpec.methodBuilder("parse")
        .addAnnotation(nonNullType)
        .addModifiers(Modifier.PRIVATE, Modifier.STATIC)
        .returns(rootField.type)
        .addParameter(parserParameter)
        .addExceptions(checkedExceptions)
    with(builder) {
        addStatement("int type")
        addStatement("int depth")
        addStatement("int innerDepth = parser.getDepth() + 1")
        beginControlFlow(loopCondition, xmlPullParserType)
        // Only direct child start tags are of interest.
        beginControlFlow("if (depth > innerDepth || type != \$1T.START_TAG)", xmlPullParserType)
        addStatement("continue")
        endControlFlow()
        beginControlFlow(
            "if (\$1T.equals(parser.getName(),\$W\$2S))", Objects::class.java, rootField.tagName
        )
        addStatement("return \$1L(parser)", rootField.parseMethodName)
        endControlFlow()
        endControlFlow()
        addStatement(
            "throw new IllegalArgumentException(\$1S)",
            "Missing root tag <${rootField.tagName}>"
        )
    }
    return builder.build()
}
193 
/**
 * Builds the private static `parseX(parser)` method that reads one instance of [classField]'s type from the
 * parser's current tag.
 *
 * Where each field comes from:
 * - Primitive and string fields are read from the current tag's attributes.
 * - Class and list fields are read from direct child tags, dispatched by tag name.
 *
 * The generated code throws `IllegalArgumentException` for a missing required attribute, a missing required class
 * tag, or a duplicated class tag. Absent optional attributes and optional class tags are left `null`. List tags
 * accumulate into a new `ArrayList`. The result is built by passing every field, in the order of
 * `classField.fields`, to the type's constructor.
 *
 * `IOException` and `XmlPullParserException` are declared only when there are child tags to read, because only
 * that path calls `parser.next()`.
 */
private fun generateParseClassMethod(classField: ClassFieldInfo): MethodSpec =
    MethodSpec.methodBuilder(classField.parseMethodName)
        .addAnnotation(nonNullType)
        .addModifiers(Modifier.PRIVATE, Modifier.STATIC)
        .returns(classField.type)
        .addParameter(
            ParameterSpec.builder(xmlPullParserType, "parser").addAnnotation(nonNullType).build()
        )
        .apply {
            val (attributeFields, tagFields) = classField.fields
                .partition { it is PrimitiveFieldInfo || it is StringFieldInfo }
            if (tagFields.isNotEmpty()) {
                addExceptions(
                    listOf(ClassName.get(IOException::class.java), xmlPullParserExceptionType)
                )
            }
            // Reserve the parameter and loop-local names so generated field variables never clash with them.
            val nameAllocator = NameAllocator().apply {
                newName("parser")
                newName("type")
                newName("depth")
                newName("innerDepth")
            }
            for (field in attributeFields) {
                val variableName = nameAllocator.newName(field.variableName, field)
                when (field) {
                    is PrimitiveFieldInfo -> {
                        val stringVariableName =
                            nameAllocator.newName("${field.variableName}String")
                        addStatement(
                            "final String \$1L =\$Wparser.getAttributeValue(null,\$W\$2S)",
                            stringVariableName, field.attributeName
                        )
                        if (field.isRequired) {
                            addControlFlow("if (\$1L == null)", stringVariableName) {
                                addStatement(
                                    "throw new IllegalArgumentException(\$1S)",
                                    "Missing attribute \"${field.attributeName}\""
                                )
                            }
                        }
                        val boxedType = field.type.box()
                        // Primitive types parse via e.g. Integer.parseInt(); boxed types via e.g. Integer.valueOf().
                        val parseTypeMethodName = if (field.type.isPrimitive) {
                            "parse${field.type.toString().capitalize()}"
                        } else {
                            "valueOf"
                        }
                        // Every '$' below is escaped so none of them can ever be taken as a Kotlin template.
                        if (field.isRequired) {
                            addStatement(
                                "final \$1T \$2L =\$W\$3T.\$4L(\$5L)", field.type, variableName,
                                boxedType, parseTypeMethodName, stringVariableName
                            )
                        } else {
                            // An absent optional attribute becomes null.
                            addStatement(
                                "final \$1T \$2L =\$W\$3L != null ?\$W\$4T.\$5L(\$3L)\$W: null",
                                field.type, variableName, stringVariableName, boxedType,
                                parseTypeMethodName
                            )
                        }
                    }
                    is StringFieldInfo ->
                        addStatement(
                            "final String \$1L =\$Wparser.getAttributeValue(null,\$W\$2S)",
                            variableName, field.attributeName
                        )
                    else -> error(field)
                }
            }
            if (tagFields.isNotEmpty()) {
                // Declare one local per tag field before the loop; class fields start null, lists start empty.
                for (field in tagFields) {
                    val variableName = nameAllocator.newName(field.variableName, field)
                    when (field) {
                        is ClassFieldInfo ->
                            addStatement("\$1T \$2L =\$Wnull", field.type, variableName)
                        is ListFieldInfo ->
                            addStatement(
                                "final \$1T \$2L =\$Wnew \$3T<>()", field.type, variableName,
                                ArrayList::class.java
                            )
                        else -> error(field)
                    }
                }
                addStatement("int type")
                addStatement("int depth")
                addStatement("int innerDepth = parser.getDepth() + 1")
                // Loop until END_DOCUMENT, or until the END_TAG that closes the current tag.
                addControlFlow(
                    "while ((type = parser.next()) != \$1T.END_DOCUMENT\$W"
                        + "&& ((depth = parser.getDepth()) >= innerDepth || type != \$1T.END_TAG))",
                    xmlPullParserType
                ) {
                    // Only direct child start tags are dispatched; everything else is skipped.
                    addControlFlow(
                        "if (depth > innerDepth || type != \$1T.START_TAG)", xmlPullParserType
                    ) {
                        addStatement("continue")
                    }
                    addControlFlow("switch (parser.getName())") {
                        for (field in tagFields) {
                            // Each case body is braced, so element locals from different cases don't collide.
                            addControlFlow("case \$1S:", field.tagName) {
                                val variableName = nameAllocator.get(field)
                                when (field) {
                                    is ClassFieldInfo -> {
                                        addControlFlow("if (\$1L != null)", variableName) {
                                            addStatement(
                                                "throw new IllegalArgumentException(\$1S)",
                                                "Duplicate tag \"${field.tagName}\""
                                            )
                                        }
                                        addStatement(
                                            "\$1L =\$W\$2L(parser)", variableName,
                                            field.parseMethodName
                                        )
                                        addStatement("break")
                                    }
                                    is ListFieldInfo -> {
                                        val elementNameAllocator = nameAllocator.clone()
                                        val elementVariableName = elementNameAllocator.newName(
                                            field.element.xmlName!!.toLowerCamelCase()
                                        )
                                        addStatement(
                                            "final \$1T \$2L =\$W\$3L(parser)", field.element.type,
                                            elementVariableName, field.element.parseMethodName
                                        )
                                        addStatement(
                                            "\$1L.add(\$2L)", variableName, elementVariableName
                                        )
                                        addStatement("break")
                                    }
                                    else -> error(field)
                                }
                            }
                        }
                    }
                }
            }
            for (field in tagFields.filter { it is ClassFieldInfo && it.isRequired }) {
                addControlFlow("if (\$1L == null)", nameAllocator.get(field)) {
                    addStatement(
                        "throw new IllegalArgumentException(\$1S)", "Missing tag <${field.tagName}>"
                    )
                }
            }
            addStatement(
                classField.fields.joinToString(",\$W", "return new \$1T(", ")") {
                    nameAllocator.get(it)
                }, classField.type
            )
        }
        .build()
341 
/** Name of the generated static parse helper for this class's type: `parse` plus the upper-camel simple name. */
private val ClassFieldInfo.parseMethodName: String
    get() = "parse" + type.simpleName().toUpperCamelCase()
344 
/** `org.xmlpull.v1.XmlSerializer`, the writer type taken by the generated serialize methods. */
private val xmlSerializerType: ClassName = ClassName.get("org.xmlpull.v1", "XmlSerializer")
346 
/**
 * Builds the public `write(root)` method.
 *
 * The generated method writes the root object of [rootField]'s type to the XML file. It does this through
 * `mFile.startWrite()`, an indenting UTF-8 `XmlSerializer`, `serialize(...)` and `mFile.finishWrite(...)`.
 * Any `Exception` thrown inside the try block is caught: its stack trace is printed and
 * `mFile.failWrite(outputStream)` is called.
 *
 * The generated method returns `void`, so it carries no nullability annotation.
 */
private fun generateWriteMethod(rootField: ClassFieldInfo): MethodSpec =
    MethodSpec.methodBuilder("write")
        .apply {
            // Reserve the generated locals so the parameter name can never shadow them.
            val nameAllocator = NameAllocator().apply {
                newName("outputStream")
                newName("serializer")
            }
            val parameterName = nameAllocator.newName(rootField.variableName)
            addJavadoc(
                """
                    Write${'$'}W{@link $1T}${'$'}Wto${'$'}Wthe${'$'}WXML${'$'}Wfile.

                    @param $2L the${'$'}W{@link ${'$'}1T}${'$'}Wto${'$'}Wpersist
                """.trimIndent(), rootField.type, parameterName
            )
            // No @Nullable here: annotating a void method's return value is meaningless.
            addModifiers(Modifier.PUBLIC)
            addParameter(
                ParameterSpec.builder(rootField.type, parameterName)
                    .addAnnotation(nonNullType)
                    .build()
            )
            // Declared outside the try so the catch block can hand it to failWrite().
            addStatement("\$1T outputStream = null", FileOutputStream::class.java)
            addControlFlow("try") {
                addStatement("outputStream = mFile.startWrite()")
                addStatement(
                    "final \$1T serializer =\$W\$2T.newSerializer()", xmlSerializerType, xmlType
                )
                addStatement(
                    "serializer.setOutput(outputStream, \$1T.UTF_8.name())",
                    StandardCharsets::class.java
                )
                addStatement(
                    "serializer.setFeature(\$1S, true)",
                    "http://xmlpull.org/v1/doc/features.html#indent-output"
                )
                addStatement("serializer.startDocument(null, true)")
                addStatement("serialize(serializer,\$W\$1L)", parameterName)
                addStatement("serializer.endDocument()")
                addStatement("mFile.finishWrite(outputStream)")
                nextControlFlow("catch (Exception e)")
                addStatement("e.printStackTrace()")
                addStatement("mFile.failWrite(outputStream)")
            }
        }
        .build()
393 
/**
 * Builds the private static `serialize(serializer, root)` method.
 *
 * The generated method wraps a call to the root class's serialize helper in the root tag. It declares
 * `IOException`.
 */
private fun generateSerializeMethod(rootField: ClassFieldInfo): MethodSpec {
    val allocator = NameAllocator()
    // "serializer" is the first parameter, so the root parameter must get a different name.
    allocator.newName("serializer")
    val rootParameterName = allocator.newName(rootField.variableName)
    val serializerParameter = ParameterSpec.builder(xmlSerializerType, "serializer")
        .addAnnotation(nonNullType)
        .build()
    val rootParameter = ParameterSpec.builder(rootField.type, rootParameterName)
        .addAnnotation(nonNullType)
        .build()
    return MethodSpec.methodBuilder("serialize")
        .addModifiers(Modifier.PRIVATE, Modifier.STATIC)
        .addParameter(serializerParameter)
        .addParameter(rootParameter)
        .addException(IOException::class.java)
        .addStatement("serializer.startTag(null, \$1S)", rootField.tagName)
        .addStatement("\$1L(serializer, \$2L)", rootField.serializeMethodName, rootParameterName)
        .addStatement("serializer.endTag(null, \$1S)", rootField.tagName)
        .build()
}
416 
/**
 * Builds the private static `serializeX(serializer, x)` method that writes the fields of [classField]'s type.
 *
 * Primitive and string fields are written as attributes of the current tag. Class and list fields are written as
 * child tags, each element wrapped in its own start/end tag. The tag for the instance itself is opened and closed
 * by the caller (see [generateSerializeMethod]).
 *
 * Null handling:
 * - A required field that is null makes the generated code throw `IllegalArgumentException`.
 * - A null element inside a list does the same.
 * - An optional field that is null is skipped. This applies to child tags as well as attributes, and matches the
 *   parser, which leaves an absent optional class tag as `null`.
 */
private fun generateSerializeClassMethod(classField: ClassFieldInfo): MethodSpec =
    MethodSpec.methodBuilder(classField.serializeMethodName)
        .addModifiers(Modifier.PRIVATE, Modifier.STATIC)
        .addParameter(
            ParameterSpec.builder(xmlSerializerType, "serializer")
                .addAnnotation(nonNullType)
                .build()
        )
        .apply {
            // Reserve the serializer parameter and the list loop index.
            val nameAllocator = NameAllocator().apply {
                newName("serializer")
                newName("i")
            }
            val parameterName = nameAllocator.newName(classField.serializeParameterName)
            addParameter(
                ParameterSpec.builder(classField.type, parameterName)
                    .addAnnotation(nonNullType)
                    .build()
            )
            addException(IOException::class.java)
            val (attributeFields, tagFields) = classField.fields
                .partition { it is PrimitiveFieldInfo || it is StringFieldInfo }
            for (field in attributeFields) {
                val variableName = "$parameterName.${field.name}"
                if (!field.isRequired) {
                    beginControlFlow("if (\$1L != null)", variableName)
                }
                when (field) {
                    is PrimitiveFieldInfo -> {
                        // Only boxed required fields can be null; Java primitives never are.
                        if (field.isRequired && !field.type.isPrimitive) {
                            addControlFlow("if (\$1L == null)", variableName) {
                                addStatement(
                                    "throw new IllegalArgumentException(\$1S)",
                                    "Field \"${field.name}\" is null"
                                )
                            }
                        }
                        val stringVariableName =
                            nameAllocator.newName("${field.variableName}String")
                        addStatement(
                            "final String \$1L =\$WString.valueOf(\$2L)", stringVariableName,
                            variableName
                        )
                        addStatement(
                            "serializer.attribute(null, \$1S, \$2L)", field.attributeName,
                            stringVariableName
                        )
                    }
                    is StringFieldInfo -> {
                        if (field.isRequired) {
                            addControlFlow("if (\$1L == null)", variableName) {
                                addStatement(
                                    "throw new IllegalArgumentException(\$1S)",
                                    "Field \"${field.name}\" is null"
                                )
                            }
                        }
                        addStatement(
                            "serializer.attribute(null, \$1S, \$2L)", field.attributeName,
                            variableName
                        )
                    }
                    else -> error(field)
                }
                if (!field.isRequired) {
                    endControlFlow()
                }
            }
            for (field in tagFields) {
                val variableName = "$parameterName.${field.name}"
                if (field.isRequired) {
                    addControlFlow("if (\$1L == null)", variableName) {
                        addStatement(
                            "throw new IllegalArgumentException(\$1S)",
                            "Field \"${field.name}\" is null"
                        )
                    }
                } else {
                    // An optional child may legitimately be null (the parser leaves an absent optional tag null).
                    // Skip it rather than passing null to a @NonNull serialize helper or calling size() on null.
                    beginControlFlow("if (\$1L != null)", variableName)
                }
                when (field) {
                    is ClassFieldInfo -> {
                        addStatement("serializer.startTag(null, \$1S)", field.tagName)
                        addStatement(
                            "\$1L(serializer, \$2L)", field.serializeMethodName, variableName
                        )
                        addStatement("serializer.endTag(null, \$1S)", field.tagName)
                    }
                    is ListFieldInfo -> {
                        val sizeVariableName = nameAllocator.newName("${field.variableName}Size")
                        addStatement(
                            "final int \$1L =\$W\$2L.size()", sizeVariableName, variableName
                        )
                        addControlFlow("for (int i = 0;\$Wi < \$1L;\$Wi++)", sizeVariableName) {
                            val elementNameAllocator = nameAllocator.clone()
                            val elementVariableName = elementNameAllocator.newName(
                                field.element.xmlName!!.toLowerCamelCase()
                            )
                            addStatement(
                                "final \$1T \$2L =\$W\$3L.get(i)", field.element.type,
                                elementVariableName, variableName
                            )
                            addControlFlow("if (\$1L == null)", elementVariableName) {
                                addStatement(
                                    "throw new IllegalArgumentException(\$1S\$W+ i\$W+ \$2S)",
                                    "Field element \"${field.name}[", "]\" is null"
                                )
                            }
                            addStatement("serializer.startTag(null, \$1S)", field.element.tagName)
                            addStatement(
                                "\$1L(serializer,\$W\$2L)", field.element.serializeMethodName,
                                elementVariableName
                            )
                            addStatement("serializer.endTag(null, \$1S)", field.element.tagName)
                        }
                    }
                    else -> error(field)
                }
                if (!field.isRequired) {
                    endControlFlow()
                }
            }
        }
        .build()
536 
/** Name of the generated serialize helper for this class's type: `serialize` plus the upper-camel simple name. */
private val ClassFieldInfo.serializeMethodName: String
    get() = "serialize" + type.simpleName().toUpperCamelCase()

/** Name of the instance parameter of the generated serialize helper: the type's simple name in lower camel case. */
private val ClassFieldInfo.serializeParameterName: String
    get() {
        val simpleName = type.simpleName()
        return simpleName.toLowerCamelCase()
    }
542 
/** Preferred name of the generated local variable for this field: the field name in lower camel case. */
private val FieldInfo.variableName: String
    get() {
        val fieldName = name
        return fieldName.toLowerCamelCase()
    }

/**
 * XML attribute name for a primitive or string field: the XML name (or field name) in lower camel case.
 *
 * @throws IllegalStateException if this field is not a primitive or string field
 */
private val FieldInfo.attributeName: String
    get() {
        val isAttributeField = this is PrimitiveFieldInfo || this is StringFieldInfo
        check(isAttributeField)
        return xmlNameOrName.toLowerCamelCase()
    }

/**
 * XML tag name for a class or list field: the XML name (or field name) in lower kebab case.
 *
 * @throws IllegalStateException if this field is not a class or list field
 */
private val FieldInfo.tagName: String
    get() {
        val isTagField = this is ClassFieldInfo || this is ListFieldInfo
        check(isTagField)
        return xmlNameOrName.toLowerKebabCase()
    }

/** The explicit XML name when one is set, otherwise the field name. */
private val FieldInfo.xmlNameOrName: String
    get() = xmlName ?: name
560 
/** Builds the public `delete()` method, which calls `mFile.delete()`. */
private fun generateDeleteMethod(): MethodSpec = with(MethodSpec.methodBuilder("delete")) {
    addJavadoc("Delete the XML file, if any.")
    addModifiers(Modifier.PUBLIC)
    addStatement("mFile.delete()")
    build()
}
567 
/**
 * Emits a control-flow block: [MethodSpec.Builder.beginControlFlow] with [controlFlow] and [args], then [block],
 * then [MethodSpec.Builder.endControlFlow].
 *
 * [block] may itself call `nextControlFlow` to add `else`/`catch` sections before the block is closed.
 *
 * @return this builder, for chaining
 */
private inline fun MethodSpec.Builder.addControlFlow(
    controlFlow: String,
    vararg args: Any,
    block: MethodSpec.Builder.() -> Unit
): MethodSpec.Builder = apply {
    beginControlFlow(controlFlow, *args)
    block()
    endControlFlow()
}
578