/*
 * Copyright 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package androidx.navigation.safe.args.generator.kotlin

import androidx.annotation.CheckResult
import androidx.navigation.safe.args.generator.NavWriter
import androidx.navigation.safe.args.generator.ObjectArrayType
import androidx.navigation.safe.args.generator.ObjectType
import androidx.navigation.safe.args.generator.ext.toCamelCase
import androidx.navigation.safe.args.generator.ext.toCamelCaseAsVar
import androidx.navigation.safe.args.generator.models.Action
import androidx.navigation.safe.args.generator.models.Destination
import com.squareup.kotlinpoet.AnnotationSpec
import com.squareup.kotlinpoet.ClassName
import com.squareup.kotlinpoet.FileSpec
import com.squareup.kotlinpoet.FunSpec
import com.squareup.kotlinpoet.KModifier
import com.squareup.kotlinpoet.ParameterSpec
import com.squareup.kotlinpoet.PropertySpec
import com.squareup.kotlinpoet.TypeSpec
import com.squareup.kotlinpoet.asTypeName

/**
 * [NavWriter] that uses KotlinPoet to emit Kotlin source for Safe Args `Directions` and `Args`
 * classes.
 *
 * @param useAndroidX NOTE(review): this flag is not read anywhere in this class; confirm whether
 *   it is meant to switch between AndroidX and support-library class names.
 */
class KotlinNavWriter(private val useAndroidX: Boolean = true) : NavWriter<KotlinCodeFile> {

    /**
     * Generates the `<Destination>Directions` file for [destination].
     *
     * The generated class has a private constructor. Its companion object holds one
     * `@CheckResult` factory function per action:
     * - Actions without arguments return `ACTION_ONLY_NAV_DIRECTION_CLASSNAME`, built from the
     *   action id.
     * - Actions with arguments return an instance of a private nested data class, produced by
     *   [generateDirectionTypeSpec].
     *
     * Companion functions of the files in [parentDirectionsFileList] are re-exposed as functions
     * that delegate to the parent class. A parent function is skipped when a local action, or a
     * closer parent, already contributes a function with the same name.
     *
     * @param destination the destination whose actions are generated; must have a name.
     * @param parentDirectionsFileList previously generated parent `Directions` files, ordered
     *   from the closest to the farthest parent.
     * @return the generated `Directions` file.
     * @throws IllegalStateException if [destination] has no name.
     */
    override fun generateDirectionsCodeFile(
        destination: Destination,
        parentDirectionsFileList: List<KotlinCodeFile>
    ): KotlinCodeFile {
        val destName =
            destination.name
                ?: throw IllegalStateException("Destination with actions must have name")
        val className = ClassName(destName.packageName(), "${destName.simpleName()}Directions")

        val actionTypes =
            destination.actions.map { action -> action to generateDirectionTypeSpec(action) }

        val actionsFunSpec =
            actionTypes.map { (action, actionTypeSpec) ->
                val typeName =
                    ClassName(
                        "",
                        checkNotNull(actionTypeSpec.name) { "Generated action type has no name" }
                    )
                // Parameters with default values go last so callers can omit them; the nested
                // action class sorts its constructor parameters the same way, so the positional
                // call emitted below lines up with it.
                val parameters =
                    action.args
                        .map { arg ->
                            ParameterSpec.builder(
                                    name = arg.sanitizedName,
                                    type = arg.type.typeName().copy(nullable = arg.isNullable)
                                )
                                .apply { arg.defaultValue?.let { defaultValue(it.write()) } }
                                .build()
                        }
                        .sortedBy { it.defaultValue != null }
                FunSpec.builder(action.id.javaIdentifier.toCamelCaseAsVar())
                    .apply {
                        returns(NAV_DIRECTION_CLASSNAME)
                        addAnnotation(CHECK_RESULT)
                        addParameters(parameters)
                        if (action.args.isEmpty()) {
                            addStatement(
                                "return %T(%L)",
                                ACTION_ONLY_NAV_DIRECTION_CLASSNAME,
                                action.id.accessor()
                            )
                        } else {
                            addStatement(
                                "return %T(${parameters.joinToString(", ") { it.name }})",
                                typeName
                            )
                        }
                    }
                    .build()
            }

        // The parent destination list is ordered from the closest to the farthest parent of the
        // processing destination in the graph hierarchy, so the closest parent declaring a given
        // function name wins.
        val localActionNames = actionsFunSpec.map { it.name }.toSet()
        val parentActionNames = mutableSetOf<String>()
        val parentActionsFunSpec = mutableListOf<FunSpec>()
        parentDirectionsFileList.forEach { parentFile ->
            val parentPackageName = parentFile.wrapped.packageName
            val parentTypeSpec = parentFile.wrapped.members.filterIsInstance<TypeSpec>().first()
            val parentClassName =
                ClassName(
                    parentPackageName,
                    checkNotNull(parentTypeSpec.name) { "Parent directions type has no name" }
                )
            val parentCompanionTypeSpec = parentTypeSpec.typeSpecs.first { it.isCompanion }
            parentCompanionTypeSpec.funSpecs
                // `filter` runs to completion before `forEach`, so names are checked against the
                // functions collected from *earlier* parents only.
                .filter { function ->
                    function.name !in localActionNames && // de-dupe local actions
                        function.name !in parentActionNames // de-dupe parent actions
                }
                .forEach { functionSpec ->
                    val params =
                        functionSpec.parameters.joinToString(", ") { param -> param.name }
                    val methodSpec =
                        FunSpec.builder(functionSpec.name)
                            .addAnnotation(CHECK_RESULT)
                            .addParameters(functionSpec.parameters)
                            .returns(NAV_DIRECTION_CLASSNAME)
                            .addStatement(
                                "return %T.%L($params)",
                                parentClassName,
                                functionSpec.name
                            )
                            .build()
                    parentActionsFunSpec.add(methodSpec)
                    parentActionNames.add(functionSpec.name)
                }
        }

        val typeSpec =
            TypeSpec.classBuilder(className)
                .primaryConstructor(
                    FunSpec.constructorBuilder().addModifiers(KModifier.PRIVATE).build()
                )
                .addTypes(
                    // Argument-less actions need no nested type.
                    actionTypes
                        .filter { (action, _) -> action.args.isNotEmpty() }
                        .map { (_, type) -> type }
                )
                .addType(
                    TypeSpec.companionObjectBuilder()
                        .addFunctions(actionsFunSpec + parentActionsFunSpec)
                        .build()
                )
                .build()

        return FileSpec.builder(className.packageName, className.simpleName)
            .addType(typeSpec)
            .build()
            .toCodeFile()
    }

    /**
     * Builds the private nested type that implements `NAV_DIRECTION_CLASSNAME` for [action].
     *
     * - Actions without arguments become an `object`.
     * - Actions with arguments become a `data class`. It has one property per argument, and
     *   constructor parameters with default values are ordered last.
     *
     * Both forms override `actionId` and `arguments`. The `arguments` getter creates a new
     * `BUNDLE_CLASSNAME` instance and puts every argument into it.
     */
    internal fun generateDirectionTypeSpec(action: Action): TypeSpec {
        val className = ClassName("", action.id.javaIdentifier.toCamelCase())

        val actionIdPropSpec =
            PropertySpec.builder("actionId", Int::class, KModifier.PUBLIC, KModifier.OVERRIDE)
                .initializer("%L", action.id.accessor())
                .build()

        val argumentsPropSpec =
            PropertySpec.builder(
                    "arguments",
                    BUNDLE_CLASSNAME,
                    KModifier.PUBLIC,
                    KModifier.OVERRIDE
                )
                .getter(
                    FunSpec.getterBuilder()
                        .apply {
                            // Put statements for object types may contain cast branches that can
                            // never succeed; silence that warning in the generated code.
                            if (action.args.any { it.type is ObjectType }) {
                                addAnnotation(CAST_NEVER_SUCCEEDS)
                            }
                            val resultVal = "result"
                            addStatement("val %L = %T()", resultVal, BUNDLE_CLASSNAME)
                            action.args.forEach { arg ->
                                arg.type.addBundlePutStatement(
                                    this,
                                    arg,
                                    resultVal,
                                    "this.${arg.sanitizedName}"
                                )
                            }
                            addStatement("return %L", resultVal)
                        }
                        .build()
                )
                .build()

        val constructorFunSpec =
            FunSpec.constructorBuilder()
                .addParameters(
                    action.args
                        .map { arg ->
                            ParameterSpec.builder(
                                    name = arg.sanitizedName,
                                    type = arg.type.typeName().copy(nullable = arg.isNullable)
                                )
                                .apply { arg.defaultValue?.let { defaultValue(it.write()) } }
                                .build()
                        }
                        .sortedBy { it.defaultValue != null }
                )
                .build()

        return if (action.args.isEmpty()) {
                TypeSpec.objectBuilder(className)
            } else {
                TypeSpec.classBuilder(className)
                    .addModifiers(KModifier.DATA)
                    .primaryConstructor(constructorFunSpec)
                    .addProperties(
                        action.args.map { arg ->
                            PropertySpec.builder(
                                    arg.sanitizedName,
                                    arg.type.typeName().copy(nullable = arg.isNullable)
                                )
                                .initializer(arg.sanitizedName)
                                .build()
                        }
                    )
            }
            .addSuperinterface(NAV_DIRECTION_CLASSNAME)
            .addModifiers(KModifier.PRIVATE)
            .addProperty(actionIdPropSpec)
            .addProperty(argumentsPropSpec)
            .build()
    }

    /**
     * Generates the `<Destination>Args` file for [destination].
     *
     * The generated data class implements `NAV_ARGS_CLASSNAME` and has one property per
     * argument. It provides `toBundle()` and `toSavedStateHandle()`, plus `@JvmStatic`
     * companion factories `fromBundle()` and `fromSavedStateHandle()`. At runtime the factories
     * throw `IllegalArgumentException` when:
     * - a required argument without a default value is missing, or
     * - a `null` value is found for an argument that is not declared nullable.
     *
     * @param destination the destination whose arguments are generated; must have a name.
     * @return the generated `Args` file.
     * @throws IllegalStateException if [destination] has no name.
     */
    override fun generateArgsCodeFile(destination: Destination): KotlinCodeFile {
        val destName =
            destination.name
                ?: throw IllegalStateException("Destination with arguments must have name")
        val className = ClassName(destName.packageName(), "${destName.simpleName()}Args")

        val constructorFunSpec =
            FunSpec.constructorBuilder()
                .addParameters(
                    destination.args
                        .map { arg ->
                            ParameterSpec.builder(
                                    name = arg.sanitizedName,
                                    type = arg.type.typeName().copy(nullable = arg.isNullable)
                                )
                                .apply { arg.defaultValue?.let { defaultValue(it.write()) } }
                                .build()
                        }
                        .sortedBy { it.defaultValue != null }
                )
                .build()

        val toBundleFunSpec =
            FunSpec.builder("toBundle")
                .apply {
                    if (destination.args.any { it.type is ObjectType }) {
                        addAnnotation(CAST_NEVER_SUCCEEDS)
                    }
                    returns(BUNDLE_CLASSNAME)
                    val resultVal = "result"
                    addStatement("val %L = %T()", resultVal, BUNDLE_CLASSNAME)
                    destination.args.forEach { arg ->
                        arg.type.addBundlePutStatement(
                            this,
                            arg,
                            resultVal,
                            "this.${arg.sanitizedName}"
                        )
                    }
                    addStatement("return %L", resultVal)
                }
                .build()

        val fromBundleFunSpec =
            FunSpec.builder("fromBundle")
                .apply {
                    addAnnotation(JvmStatic::class)
                    if (destination.args.any { it.type is ObjectArrayType }) {
                        addAnnotation(
                            AnnotationSpec.builder(Suppress::class)
                                .addMember("%S,%S", "UNCHECKED_CAST", "DEPRECATION")
                                .build()
                        )
                    } else if (destination.args.any { it.type is ObjectType }) {
                        addAnnotation(
                            AnnotationSpec.builder(Suppress::class)
                                .addMember("%S", "DEPRECATION")
                                .build()
                        )
                    }
                    returns(className)
                    val bundleParamName = "bundle"
                    addParameter(bundleParamName, BUNDLE_CLASSNAME)
                    addStatement(
                        "%L.setClassLoader(%T::class.java.classLoader)",
                        bundleParamName,
                        className
                    )
                    // One temp `__<name>` per argument, emitted in declaration order; the final
                    // constructor call passes them in the constructor's (defaults-last) order.
                    val tempVariables =
                        destination.args
                            .map { arg ->
                                val tempVal = "__${arg.sanitizedName}"
                                addStatement(
                                    "val %L : %T",
                                    tempVal,
                                    arg.type.typeName().copy(nullable = arg.type.allowsNullable())
                                )
                                beginControlFlow(
                                    "if (%L.containsKey(%S))",
                                    bundleParamName,
                                    arg.name
                                )
                                arg.type.addBundleGetStatement(this, arg, tempVal, bundleParamName)
                                if (arg.type.allowsNullable() && !arg.isNullable) {
                                    beginControlFlow("if (%L == null)", tempVal)
                                    addStatement(
                                        "throw·%T(%S)",
                                        IllegalArgumentException::class.asTypeName(),
                                        "Argument \"${arg.name}\" is marked as non-null but " +
                                            "was passed a null value."
                                    )
                                    endControlFlow()
                                }
                                nextControlFlow("else")
                                val defaultValue = arg.defaultValue
                                if (defaultValue != null) {
                                    addStatement("%L = %L", tempVal, defaultValue.write())
                                } else {
                                    addStatement(
                                        "throw·%T(%S)",
                                        IllegalArgumentException::class.asTypeName(),
                                        "Required argument \"${arg.name}\" is missing and does " +
                                            "not have an android:defaultValue"
                                    )
                                }
                                endControlFlow()
                                arg
                            }
                            .sortedBy { it.defaultValue != null }
                    addStatement(
                        "return·%T(${tempVariables.joinToString(", ") { "__${it.sanitizedName}" }})",
                        className
                    )
                }
                .build()

        val toSavedStateHandleFunSpec =
            FunSpec.builder("toSavedStateHandle")
                .apply {
                    if (destination.args.any { it.type is ObjectType }) {
                        addAnnotation(CAST_NEVER_SUCCEEDS)
                    }
                    returns(SAVED_STATE_HANDLE_CLASSNAME)
                    val resultVal = "result"
                    addStatement("val %L = %T()", resultVal, SAVED_STATE_HANDLE_CLASSNAME)
                    destination.args.forEach { arg ->
                        arg.type.addSavedStateSetStatement(
                            this,
                            arg,
                            resultVal,
                            "this.${arg.sanitizedName}"
                        )
                    }
                    addStatement("return %L", resultVal)
                }
                .build()

        val fromSavedStateHandleFunSpec =
            FunSpec.builder("fromSavedStateHandle")
                .apply {
                    addAnnotation(JvmStatic::class)
                    returns(className)
                    val savedStateParamName = "savedStateHandle"
                    addParameter(savedStateParamName, SAVED_STATE_HANDLE_CLASSNAME)
                    // Temps are always declared nullable here; non-nullable arguments are
                    // null-checked explicitly after the read.
                    val tempVariables =
                        destination.args
                            .map { arg ->
                                val tempVal = "__${arg.sanitizedName}"
                                addStatement(
                                    "val %L : %T",
                                    tempVal,
                                    arg.type.typeName().copy(nullable = true)
                                )
                                beginControlFlow(
                                    "if (%L.contains(%S))",
                                    savedStateParamName,
                                    arg.name
                                )
                                arg.type.addSavedStateGetStatement(
                                    this,
                                    arg,
                                    tempVal,
                                    savedStateParamName
                                )
                                if (!arg.isNullable) {
                                    beginControlFlow("if (%L == null)", tempVal)
                                    val errorMessage =
                                        if (arg.type.allowsNullable()) {
                                            "Argument \"${arg.name}\" is marked as non-null but " +
                                                "was passed a null value"
                                        } else {
                                            "Argument \"${arg.name}\" of type ${arg.type} does " +
                                                "not support null values"
                                        }
                                    addStatement(
                                        "throw·%T(%S)",
                                        IllegalArgumentException::class.asTypeName(),
                                        errorMessage
                                    )
                                    endControlFlow()
                                }
                                nextControlFlow("else")
                                val defaultValue = arg.defaultValue
                                if (defaultValue != null) {
                                    addStatement("%L = %L", tempVal, defaultValue.write())
                                } else {
                                    addStatement(
                                        "throw·%T(%S)",
                                        IllegalArgumentException::class.asTypeName(),
                                        "Required argument \"${arg.name}\" is missing and does " +
                                            "not have an android:defaultValue"
                                    )
                                }
                                endControlFlow()
                                arg
                            }
                            .sortedBy { it.defaultValue != null }
                    addStatement(
                        "return·%T(${tempVariables.joinToString(", ") { "__${it.sanitizedName}" }})",
                        className
                    )
                }
                .build()

        val typeSpec =
            TypeSpec.classBuilder(className)
                .addSuperinterface(NAV_ARGS_CLASSNAME)
                .addModifiers(KModifier.DATA)
                .primaryConstructor(constructorFunSpec)
                .addProperties(
                    destination.args.map { arg ->
                        PropertySpec.builder(
                                arg.sanitizedName,
                                arg.type.typeName().copy(nullable = arg.isNullable)
                            )
                            .initializer(arg.sanitizedName)
                            .build()
                    }
                )
                .addFunction(toBundleFunSpec)
                .addFunction(toSavedStateHandleFunSpec)
                .addType(
                    TypeSpec.companionObjectBuilder()
                        .addFunction(fromBundleFunSpec)
                        .addFunction(fromSavedStateHandleFunSpec)
                        .build()
                )
                .build()

        return FileSpec.builder(className.packageName, className.simpleName)
            .addType(typeSpec)
            .build()
            .toCodeFile()
    }

    companion object {
        /**
         * `@Suppress("CAST_NEVER_SUCCEEDS")` for generated members that contain cast branches.
         *
         * The generated code may contain cast branches that can never succeed for a given
         * argument type. Safe Args is not an annotation processor, so it cannot inspect the
         * class hierarchy and emit only the applicable branch.
         */
        val CAST_NEVER_SUCCEEDS =
            AnnotationSpec.builder(Suppress::class).addMember("%S", "CAST_NEVER_SUCCEEDS").build()

        /** `@CheckResult` annotation added to every generated directions factory function. */
        val CHECK_RESULT = AnnotationSpec.builder(CheckResult::class).build()
    }
}