1 /*
 * Copyright 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package androidx.navigation.safe.args.generator.kotlin
18 
19 import androidx.annotation.CheckResult
20 import androidx.navigation.safe.args.generator.NavWriter
21 import androidx.navigation.safe.args.generator.ObjectArrayType
22 import androidx.navigation.safe.args.generator.ObjectType
23 import androidx.navigation.safe.args.generator.ext.toCamelCase
24 import androidx.navigation.safe.args.generator.ext.toCamelCaseAsVar
25 import androidx.navigation.safe.args.generator.models.Action
26 import androidx.navigation.safe.args.generator.models.Destination
27 import com.squareup.kotlinpoet.AnnotationSpec
28 import com.squareup.kotlinpoet.ClassName
29 import com.squareup.kotlinpoet.FileSpec
30 import com.squareup.kotlinpoet.FunSpec
31 import com.squareup.kotlinpoet.KModifier
32 import com.squareup.kotlinpoet.ParameterSpec
33 import com.squareup.kotlinpoet.PropertySpec
34 import com.squareup.kotlinpoet.TypeSpec
35 import com.squareup.kotlinpoet.asTypeName
36 
/**
 * [NavWriter] that emits Kotlin Safe Args sources with KotlinPoet: a `<Destination>Directions`
 * class for destinations with actions and a `<Destination>Args` data class for destinations with
 * arguments.
 *
 * @param useAndroidX NOTE(review): not referenced anywhere in this class; presumably kept for
 *   constructor compatibility — confirm before removing.
 */
class KotlinNavWriter(private val useAndroidX: Boolean = true) : NavWriter<KotlinCodeFile> {

    /**
     * Generates the `<Destination>Directions` file for [destination].
     *
     * Each action becomes a `@CheckResult` factory function in the companion object returning
     * [NAV_DIRECTION_CLASSNAME]. Actions with arguments are backed by a private nested data class
     * (see [generateDirectionTypeSpec]); actions without arguments return an
     * [ACTION_ONLY_NAV_DIRECTION_CLASSNAME] built from the action id. Functions found in the
     * companion objects of [parentDirectionsFileList] are re-exposed as functions delegating to
     * the parent Directions class, unless a function with the same name was already added.
     *
     * @param destination destination whose actions are generated; must have a name.
     * @param parentDirectionsFileList Directions files of the enclosing graphs, ordered from the
     *   closest to the farthest parent.
     * @return the generated Directions file.
     * @throws IllegalStateException if [destination] has no name.
     */
    override fun generateDirectionsCodeFile(
        destination: Destination,
        parentDirectionsFileList: List<KotlinCodeFile>
    ): KotlinCodeFile {
        val destName =
            destination.name
                ?: throw IllegalStateException("Destination with actions must have name")
        val className = ClassName(destName.packageName(), "${destName.simpleName()}Directions")

        val actionTypes =
            destination.actions.map { action -> action to generateDirectionTypeSpec(action) }

        val actionsFunSpec =
            actionTypes.map { (action, actionTypeSpec) ->
                // The action type is nested inside the Directions class, so it is referenced by
                // its simple name only.
                val typeName =
                    ClassName(
                        "",
                        checkNotNull(actionTypeSpec.name) { "Action type spec must be named" }
                    )
                // Same ordering as the constructor built by generateDirectionTypeSpec: parameters
                // with a default value go last, otherwise declaration order is kept (sortedBy is
                // stable). The positional constructor call below relies on this.
                val parameters =
                    action.args
                        .map { arg ->
                            ParameterSpec.builder(
                                    name = arg.sanitizedName,
                                    type = arg.type.typeName().copy(nullable = arg.isNullable)
                                )
                                .apply { arg.defaultValue?.let { defaultValue(it.write()) } }
                                .build()
                        }
                        .sortedBy { it.defaultValue != null }
                FunSpec.builder(action.id.javaIdentifier.toCamelCaseAsVar())
                    .apply {
                        returns(NAV_DIRECTION_CLASSNAME)
                        addAnnotation(CHECK_RESULT)
                        addParameters(parameters)
                        if (action.args.isEmpty()) {
                            addStatement(
                                "return %T(%L)",
                                ACTION_ONLY_NAV_DIRECTION_CLASSNAME,
                                action.id.accessor()
                            )
                        } else {
                            addStatement(
                                "return %T(${parameters.joinToString(", ") { it.name }})",
                                typeName
                            )
                        }
                    }
                    .build()
            }

        // The parent destination list is ordered from the closest to the farthest parent of the
        // processing destination in the graph hierarchy, so when several parents expose a function
        // with the same name, the one from the closest parent is kept.
        val parentActionsFunSpec = mutableListOf<FunSpec>()
        parentDirectionsFileList.forEach { parentFile ->
            val parentPackageName = parentFile.wrapped.packageName
            val parentTypeSpec = parentFile.wrapped.members.filterIsInstance<TypeSpec>().first()
            val parentCompanionTypeSpec = parentTypeSpec.typeSpecs.first { it.isCompanion }
            val parentClassName =
                ClassName(
                    parentPackageName,
                    checkNotNull(parentTypeSpec.name) { "Parent Directions type must be named" }
                )
            parentCompanionTypeSpec.funSpecs
                .filter { function ->
                    // De-dupe against local actions and against functions already taken from a
                    // closer parent.
                    actionsFunSpec.none { it.name == function.name } &&
                        parentActionsFunSpec.none { it.name == function.name }
                }
                .forEach { functionSpec ->
                    val params = functionSpec.parameters.joinToString(", ") { param -> param.name }
                    parentActionsFunSpec.add(
                        FunSpec.builder(functionSpec.name)
                            .addAnnotation(CHECK_RESULT)
                            .addParameters(functionSpec.parameters)
                            .returns(NAV_DIRECTION_CLASSNAME)
                            .addStatement(
                                "return %T.%L($params)",
                                parentClassName,
                                functionSpec.name
                            )
                            .build()
                    )
                }
        }

        val typeSpec =
            TypeSpec.classBuilder(className)
                .primaryConstructor(
                    FunSpec.constructorBuilder().addModifiers(KModifier.PRIVATE).build()
                )
                // Only actions with arguments need a dedicated nested type; the others are served
                // by ACTION_ONLY_NAV_DIRECTION_CLASSNAME.
                .addTypes(
                    actionTypes
                        .filter { (action, _) -> action.args.isNotEmpty() }
                        .map { (_, type) -> type }
                )
                .addType(
                    TypeSpec.companionObjectBuilder()
                        .addFunctions(actionsFunSpec + parentActionsFunSpec)
                        .build()
                )
                .build()

        return FileSpec.builder(className.packageName, className.simpleName)
            .addType(typeSpec)
            .build()
            .toCodeFile()
    }

    /**
     * Builds the private [NAV_DIRECTION_CLASSNAME] implementation for a single [action].
     *
     * The type is named after the camel-cased action id, overrides `actionId` with the action's
     * id accessor, and overrides `arguments` with a getter that writes every argument into a new
     * [BUNDLE_CLASSNAME] instance. When the action has no arguments the type is an `object`;
     * otherwise it is a `data class` whose constructor puts defaulted parameters last.
     *
     * @param action the action to generate a type for.
     * @return the (unattached) type spec.
     */
    internal fun generateDirectionTypeSpec(action: Action): TypeSpec {
        val className = ClassName("", action.id.javaIdentifier.toCamelCase())

        val actionIdPropSpec =
            PropertySpec.builder("actionId", Int::class, KModifier.PUBLIC, KModifier.OVERRIDE)
                .initializer("%L", action.id.accessor())
                .build()

        val argumentsPropSpec =
            PropertySpec.builder(
                    "arguments",
                    BUNDLE_CLASSNAME,
                    KModifier.PUBLIC,
                    KModifier.OVERRIDE
                )
                .getter(
                    FunSpec.getterBuilder()
                        .apply {
                            if (action.args.any { it.type is ObjectType }) {
                                addAnnotation(CAST_NEVER_SUCCEEDS)
                            }
                            val resultVal = "result"
                            addStatement("val %L = %T()", resultVal, BUNDLE_CLASSNAME)
                            action.args.forEach { arg ->
                                arg.type.addBundlePutStatement(
                                    this,
                                    arg,
                                    resultVal,
                                    "this.${arg.sanitizedName}"
                                )
                            }
                            addStatement("return %L", resultVal)
                        }
                        .build()
                )
                .build()

        // Defaulted parameters go last (stable sort), matching the order used by the factory
        // functions generated in generateDirectionsCodeFile.
        val constructorFunSpec =
            FunSpec.constructorBuilder()
                .addParameters(
                    action.args
                        .map { arg ->
                            ParameterSpec.builder(
                                    name = arg.sanitizedName,
                                    type = arg.type.typeName().copy(nullable = arg.isNullable)
                                )
                                .apply { arg.defaultValue?.let { defaultValue(it.write()) } }
                                .build()
                        }
                        .sortedBy { it.defaultValue != null }
                )
                .build()

        return if (action.args.isEmpty()) {
                TypeSpec.objectBuilder(className)
            } else {
                TypeSpec.classBuilder(className)
                    .addModifiers(KModifier.DATA)
                    .primaryConstructor(constructorFunSpec)
                    .addProperties(
                        action.args.map { arg ->
                            PropertySpec.builder(
                                    arg.sanitizedName,
                                    arg.type.typeName().copy(nullable = arg.isNullable)
                                )
                                .initializer(arg.sanitizedName)
                                .build()
                        }
                    )
            }
            .addSuperinterface(NAV_DIRECTION_CLASSNAME)
            .addModifiers(KModifier.PRIVATE)
            .addProperty(actionIdPropSpec)
            .addProperty(argumentsPropSpec)
            .build()
    }

    /**
     * Generates the `<Destination>Args` data class for [destination].
     *
     * The class implements [NAV_ARGS_CLASSNAME], holds one property per argument and offers
     * `toBundle()` / `toSavedStateHandle()` plus `@JvmStatic` `fromBundle(bundle)` /
     * `fromSavedStateHandle(savedStateHandle)` factories in its companion object. The generated
     * factories throw `IllegalArgumentException` when a required argument is missing or a non-null
     * argument holds `null`.
     *
     * @param destination destination whose arguments are generated; must have a name.
     * @return the generated Args file.
     * @throws IllegalStateException if [destination] has no name.
     */
    override fun generateArgsCodeFile(destination: Destination): KotlinCodeFile {
        val destName =
            destination.name
                ?: throw IllegalStateException("Destination with arguments must have name")
        val className = ClassName(destName.packageName(), "${destName.simpleName()}Args")

        // Defaulted parameters go last (stable sort); fromBundle/fromSavedStateHandle call this
        // constructor positionally using the same ordering.
        val constructorFunSpec =
            FunSpec.constructorBuilder()
                .addParameters(
                    destination.args
                        .map { arg ->
                            ParameterSpec.builder(
                                    name = arg.sanitizedName,
                                    type = arg.type.typeName().copy(nullable = arg.isNullable)
                                )
                                .apply { arg.defaultValue?.let { defaultValue(it.write()) } }
                                .build()
                        }
                        .sortedBy { it.defaultValue != null }
                )
                .build()

        val toBundleFunSpec =
            FunSpec.builder("toBundle")
                .apply {
                    if (destination.args.any { it.type is ObjectType }) {
                        addAnnotation(CAST_NEVER_SUCCEEDS)
                    }
                    returns(BUNDLE_CLASSNAME)
                    val resultVal = "result"
                    addStatement("val %L = %T()", resultVal, BUNDLE_CLASSNAME)
                    destination.args.forEach { arg ->
                        arg.type.addBundlePutStatement(
                            this,
                            arg,
                            resultVal,
                            "this.${arg.sanitizedName}"
                        )
                    }
                    addStatement("return %L", resultVal)
                }
                .build()

        val fromBundleFunSpec =
            FunSpec.builder("fromBundle")
                .apply {
                    addAnnotation(JvmStatic::class)
                    // Silence warnings emitted by the generated getters for object(-array) args.
                    if (destination.args.any { it.type is ObjectArrayType }) {
                        addAnnotation(
                            AnnotationSpec.builder(Suppress::class)
                                .addMember("%S,%S", "UNCHECKED_CAST", "DEPRECATION")
                                .build()
                        )
                    } else if (destination.args.any { it.type is ObjectType }) {
                        addAnnotation(
                            AnnotationSpec.builder(Suppress::class)
                                .addMember("%S", "DEPRECATION")
                                .build()
                        )
                    }
                    returns(className)
                    val bundleParamName = "bundle"
                    addParameter(bundleParamName, BUNDLE_CLASSNAME)
                    addStatement(
                        "%L.setClassLoader(%T::class.java.classLoader)",
                        bundleParamName,
                        className
                    )
                    // Statements are emitted in declaration order; the resulting list is then
                    // reordered to match the constructor's parameter order.
                    val tempVariables =
                        destination.args
                            .map { arg ->
                                val tempVal = "__${arg.sanitizedName}"
                                addStatement(
                                    "val %L : %T",
                                    tempVal,
                                    arg.type.typeName().copy(nullable = arg.type.allowsNullable())
                                )
                                beginControlFlow(
                                    "if (%L.containsKey(%S))",
                                    bundleParamName,
                                    arg.name
                                )
                                arg.type.addBundleGetStatement(this, arg, tempVal, bundleParamName)
                                if (arg.type.allowsNullable() && !arg.isNullable) {
                                    beginControlFlow("if (%L == null)", tempVal)
                                    addStatement(
                                        "throw·%T(%S)",
                                        IllegalArgumentException::class.asTypeName(),
                                        "Argument \"${arg.name}\" is marked as non-null but was passed a " +
                                            "null value."
                                    )
                                    endControlFlow()
                                }
                                nextControlFlow("else")
                                // Use the null-checked local rather than re-reading the property.
                                val defaultValue = arg.defaultValue
                                if (defaultValue != null) {
                                    addStatement("%L = %L", tempVal, defaultValue.write())
                                } else {
                                    addStatement(
                                        "throw·%T(%S)",
                                        IllegalArgumentException::class.asTypeName(),
                                        "Required argument \"${arg.name}\" is missing and does not have an " +
                                            "android:defaultValue"
                                    )
                                }
                                endControlFlow()
                                arg
                            }
                            .sortedBy { it.defaultValue != null }
                    addStatement(
                        "return·%T(${tempVariables.joinToString(", ") { "__${it.sanitizedName}" }})",
                        className
                    )
                }
                .build()

        val toSavedStateHandleFunSpec =
            FunSpec.builder("toSavedStateHandle")
                .apply {
                    if (destination.args.any { it.type is ObjectType }) {
                        addAnnotation(CAST_NEVER_SUCCEEDS)
                    }
                    returns(SAVED_STATE_HANDLE_CLASSNAME)
                    val resultVal = "result"
                    addStatement("val %L = %T()", resultVal, SAVED_STATE_HANDLE_CLASSNAME)
                    destination.args.forEach { arg ->
                        arg.type.addSavedStateSetStatement(
                            this,
                            arg,
                            resultVal,
                            "this.${arg.sanitizedName}"
                        )
                    }
                    addStatement("return %L", resultVal)
                }
                .build()

        val fromSavedStateHandleFunSpec =
            FunSpec.builder("fromSavedStateHandle")
                .apply {
                    addAnnotation(JvmStatic::class)
                    returns(className)
                    val savedStateParamName = "savedStateHandle"
                    addParameter(savedStateParamName, SAVED_STATE_HANDLE_CLASSNAME)
                    // Statements are emitted in declaration order; the resulting list is then
                    // reordered to match the constructor's parameter order.
                    val tempVariables =
                        destination.args
                            .map { arg ->
                                val tempVal = "__${arg.sanitizedName}"
                                // Always declared nullable here; non-null args are checked below.
                                addStatement(
                                    "val %L : %T",
                                    tempVal,
                                    arg.type.typeName().copy(nullable = true)
                                )
                                beginControlFlow(
                                    "if (%L.contains(%S))",
                                    savedStateParamName,
                                    arg.name
                                )
                                arg.type.addSavedStateGetStatement(
                                    this,
                                    arg,
                                    tempVal,
                                    savedStateParamName
                                )
                                if (!arg.isNullable) {
                                    beginControlFlow("if (%L == null)", tempVal)
                                    val errorMessage =
                                        if (arg.type.allowsNullable()) {
                                            "Argument \"${arg.name}\" is marked as non-null but was passed a null value"
                                        } else {
                                            "Argument \"${arg.name}\" of type ${arg.type} does not support null values"
                                        }
                                    addStatement(
                                        "throw·%T(%S)",
                                        IllegalArgumentException::class.asTypeName(),
                                        errorMessage
                                    )
                                    endControlFlow()
                                }
                                nextControlFlow("else")
                                // Use the null-checked local rather than re-reading the property.
                                val defaultValue = arg.defaultValue
                                if (defaultValue != null) {
                                    addStatement("%L = %L", tempVal, defaultValue.write())
                                } else {
                                    addStatement(
                                        "throw·%T(%S)",
                                        IllegalArgumentException::class.asTypeName(),
                                        "Required argument \"${arg.name}\" is missing and does not have an " +
                                            "android:defaultValue"
                                    )
                                }
                                endControlFlow()
                                arg
                            }
                            .sortedBy { it.defaultValue != null }
                    addStatement(
                        "return·%T(${tempVariables.joinToString(", ") { "__${it.sanitizedName}" }})",
                        className
                    )
                }
                .build()

        val typeSpec =
            TypeSpec.classBuilder(className)
                .addSuperinterface(NAV_ARGS_CLASSNAME)
                .addModifiers(KModifier.DATA)
                .primaryConstructor(constructorFunSpec)
                .addProperties(
                    destination.args.map { arg ->
                        PropertySpec.builder(
                                arg.sanitizedName,
                                arg.type.typeName().copy(nullable = arg.isNullable)
                            )
                            .initializer(arg.sanitizedName)
                            .build()
                    }
                )
                .addFunction(toBundleFunSpec)
                .addFunction(toSavedStateHandleFunSpec)
                .addType(
                    TypeSpec.companionObjectBuilder()
                        .addFunction(fromBundleFunSpec)
                        .addFunction(fromSavedStateHandleFunSpec)
                        .build()
                )
                .build()

        return FileSpec.builder(className.packageName, className.simpleName)
            .addType(typeSpec)
            .build()
            .toCodeFile()
    }

    companion object {
        /**
         * Annotation to suppress casts that never succeed. This is necessary since the generated
         * code will contain branches that contain a cast that will never occur and succeed. The
         * reason being that Safe Args is not an annotation processor and cannot inspect the class
         * hierarchy to generate the correct cast branch only.
         */
        val CAST_NEVER_SUCCEEDS =
            AnnotationSpec.builder(Suppress::class).addMember("%S", "CAST_NEVER_SUCCEEDS").build()

        /** `@CheckResult` annotation placed on every generated Directions factory function. */
        val CHECK_RESULT = AnnotationSpec.builder(CheckResult::class).build()
    }
}
464