/*
 * Copyright 2020 Google LLC
 * Copyright 2010-2020 JetBrains s.r.o. and Kotlin Programming Language contributors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.google.devtools.ksp

import com.google.devtools.ksp.processing.Resolver
import com.google.devtools.ksp.symbol.*
import com.google.devtools.ksp.visitor.KSValidateVisitor
import java.lang.reflect.InvocationHandler
import java.lang.reflect.Method
import java.lang.reflect.Proxy
import java.util.concurrent.ConcurrentHashMap
import kotlin.reflect.KClass

/**
 * Try to resolve the [KSClassDeclaration] for a class using its fully qualified name.
 *
 * @param T The class to resolve a [KSClassDeclaration] for.
 * @return Resolved [KSClassDeclaration] if found, `null` otherwise
 * (including when [T] has no qualified name).
 *
 * @see [Resolver.getClassDeclarationByName]
 */
inline fun <reified T> Resolver.getClassDeclarationByName(): KSClassDeclaration? {
    return T::class.qualifiedName?.let { fqcn ->
        getClassDeclarationByName(getKSNameFromString(fqcn))
    }
}

/**
 * Find a class in the compilation classpath for the given name.
 *
 * @param name fully qualified name of the class to be loaded; using '.' as separator.
 * @return a KSClassDeclaration, or null if not found.
 */
fun Resolver.getClassDeclarationByName(name: String): KSClassDeclaration? =
    getClassDeclarationByName(getKSNameFromString(name))

/**
 * Find functions in the compilation classpath for the given name.
 *
 * @param name fully qualified name of the function to be loaded; using '.' as separator.
 * @param includeTopLevel a boolean value indicating whether top level functions should be searched.
 * Default false. Note that if top level functions are included, this operation can be expensive.
 * @return a Sequence of KSFunctionDeclaration.
 */
fun Resolver.getFunctionDeclarationsByName(
    name: String,
    includeTopLevel: Boolean = false
): Sequence<KSFunctionDeclaration> = getFunctionDeclarationsByName(getKSNameFromString(name), includeTopLevel)

/**
 * Find a property in the compilation classpath for the given name.
 *
 * @param name fully qualified name of the property to be loaded; using '.' as separator.
 * @param includeTopLevel a boolean value indicating whether top level properties should be searched.
 * Default false. Note that if top level properties are included, this operation can be expensive.
 * @return a KSPropertyDeclaration, or null if not found.
 */
fun Resolver.getPropertyDeclarationByName(name: String, includeTopLevel: Boolean = false): KSPropertyDeclaration? =
    getPropertyDeclarationByName(getKSNameFromString(name), includeTopLevel)

/**
 * Find the containing file of a KSNode by walking up the [KSNode.parent] chain.
 *
 * @return KSFile if the given KSNode has a containing file, `null` otherwise.
 * Examples of symbols without a containing file: symbols from class files, synthetic symbols created by user.
 */
val KSNode.containingFile: KSFile?
    get() {
        var parent = this.parent
        while (parent != null && parent !is KSFile) {
            parent = parent.parent
        }
        return parent as? KSFile
    }

/**
 * Get functions directly declared inside the class declaration.
 *
 * What is included: member functions, constructors, extension functions declared inside it, etc.
 * What is NOT included: inherited functions, extension functions declared outside it.
 */
fun KSClassDeclaration.getDeclaredFunctions(): Sequence<KSFunctionDeclaration> {
    return this.declarations.filterIsInstance<KSFunctionDeclaration>()
}

/**
 * Get properties directly declared inside the class declaration.
 *
 * What is included: member properties, extension properties declared inside it, etc.
 * What is NOT included: inherited properties, extension properties declared outside it.
 */
fun KSClassDeclaration.getDeclaredProperties(): Sequence<KSPropertyDeclaration> {
    return this.declarations.filterIsInstance<KSPropertyDeclaration>()
}

/**
 * Get the constructors of this class, i.e. the declared functions for which [isConstructor] is `true`.
 */
fun KSClassDeclaration.getConstructors(): Sequence<KSFunctionDeclaration> {
    return getDeclaredFunctions().filter {
        it.isConstructor()
    }
}

/**
 * Check whether this is a local declaration, or namely, declared in a function.
 *
 * A declaration is treated as local when it has a parent declaration that is not a [KSClassDeclaration].
 */
fun KSDeclaration.isLocal(): Boolean {
    return this.parentDeclaration != null && this.parentDeclaration !is KSClassDeclaration
}

/**
 * Perform a validation on a given symbol to check if all interested types in symbols enclosed scope are valid,
 * i.e. resolvable.
 *
 * @param predicate A lambda for filtering interested symbols for performance purpose. Default checks all.
 * @return the result of visiting this node with a [KSValidateVisitor] built from [predicate].
 */
fun KSNode.validate(predicate: (KSNode?, KSNode) -> Boolean = { _, _ -> true }): Boolean {
    return this.accept(KSValidateVisitor(predicate), null)
}

/**
 * Find the KSClassDeclaration that the alias points to recursively.
 *
 * Aliases of aliases are followed until a non-alias declaration is reached; that declaration is then cast to
 * [KSClassDeclaration], so a [ClassCastException] is thrown if the alias ultimately resolves to something else.
 */
fun KSTypeAlias.findActualType(): KSClassDeclaration {
    val resolvedType = this.type.resolve().declaration
    return if (resolvedType is KSTypeAlias) {
        resolvedType.findActualType()
    } else {
        resolvedType as KSClassDeclaration
    }
}

/**
 * Determine [Visibility] of a [KSDeclaration].
 *
 * Branches are evaluated in order: an explicit `public` modifier wins; an overriding function or property takes
 * the visibility of its overridee (falling back to public when none is found); then local, private, protected
 * and internal are checked; finally the default is public, except for declarations of Java origin, which
 * default to Java package visibility.
 */
fun KSDeclaration.getVisibility(): Visibility {
    return when {
        this.modifiers.contains(Modifier.PUBLIC) -> Visibility.PUBLIC
        this.modifiers.contains(Modifier.OVERRIDE) -> {
            when (this) {
                is KSFunctionDeclaration -> this.findOverridee()?.getVisibility()
                is KSPropertyDeclaration -> this.findOverridee()?.getVisibility()
                else -> null
            } ?: Visibility.PUBLIC
        }
        this.isLocal() -> Visibility.LOCAL
        this.modifiers.contains(Modifier.PRIVATE) -> Visibility.PRIVATE
        // OVERRIDE needs no check here: every overriding declaration is already handled by the branch above.
        this.modifiers.contains(Modifier.PROTECTED) -> Visibility.PROTECTED
        this.modifiers.contains(Modifier.INTERNAL) -> Visibility.INTERNAL
        else -> if (this.origin != Origin.JAVA && this.origin != Origin.JAVA_LIB) {
            Visibility.PUBLIC
        } else {
            Visibility.JAVA_PACKAGE
        }
    }
}

/**
 * Get all super types for a class declaration.
 *
 * The result contains the resolved direct super types followed by the super types reachable through them
 * (following type aliases and the upper bounds of type parameters), with duplicates removed.
 *
 * Calling [getAllSuperTypes] requires type resolution therefore is expensive and should be avoided if possible.
 *
 * @throws IllegalStateException if a super type or a type parameter bound resolves to an unexpected kind of
 * declaration.
 */
fun KSClassDeclaration.getAllSuperTypes(): Sequence<KSType> {

    // Class declarations acting as upper bounds of a type parameter, looking through aliases and
    // through bounds that are themselves type parameters.
    fun KSTypeParameter.getTypesUpperBound(): Sequence<KSClassDeclaration> =
        this.bounds.asSequence().flatMap {
            when (val resolvedDeclaration = it.resolve().declaration) {
                is KSClassDeclaration -> sequenceOf(resolvedDeclaration)
                is KSTypeAlias -> sequenceOf(resolvedDeclaration.findActualType())
                is KSTypeParameter -> resolvedDeclaration.getTypesUpperBound()
                else -> throw IllegalStateException("unhandled type parameter bound, $ExceptionMessage")
            }
        }

    return this.superTypes
        .asSequence()
        .map { it.resolve() }
        .plus(
            this.superTypes
                .asSequence()
                .mapNotNull { it.resolve().declaration }
                .flatMap {
                    when (it) {
                        is KSClassDeclaration -> it.getAllSuperTypes()
                        is KSTypeAlias -> it.findActualType().getAllSuperTypes()
                        is KSTypeParameter -> it.getTypesUpperBound().flatMap { it.getAllSuperTypes() }
                        else -> throw IllegalStateException("unhandled super type kind, $ExceptionMessage")
                    }
                }
        )
        .distinct()
}

/**
 * Returns `true` if this class is an interface or carries the `abstract` modifier.
 */
fun KSClassDeclaration.isAbstract() =
    this.classKind == ClassKind.INTERFACE || this.modifiers.contains(Modifier.ABSTRACT)

/**
 * Returns `true` if this property is abstract: either it carries the `abstract` modifier, or it is declared in
 * an interface and neither its getter nor its setter provides a non-abstract implementation.
 */
fun KSPropertyDeclaration.isAbstract(): Boolean {
    if (modifiers.contains(Modifier.ABSTRACT)) {
        return true
    }
    val parentClass = parentDeclaration as? KSClassDeclaration ?: return false
    if (parentClass.classKind != ClassKind.INTERFACE) return false
    // this is abstract if it does not have setter/getter or setter/getter have abstract modifiers
    return (getter?.modifiers?.contains(Modifier.ABSTRACT) ?: true) &&
        (setter?.modifiers?.contains(Modifier.ABSTRACT) ?: true)
}

/**
 * Returns `true` if this declaration is not local and any of the following holds:
 * it is an interface; it has an `override`, `abstract`, `open` or `sealed` modifier; it is a non-class member
 * of an interface; or it originates from Java source and is not `final`.
 */
fun KSDeclaration.isOpen() = !this.isLocal() &&
    (
        (this as? KSClassDeclaration)?.classKind == ClassKind.INTERFACE ||
            this.modifiers.contains(Modifier.OVERRIDE) ||
            this.modifiers.contains(Modifier.ABSTRACT) ||
            this.modifiers.contains(Modifier.OPEN) ||
            this.modifiers.contains(Modifier.SEALED) ||
            (
                this !is KSClassDeclaration &&
                    (this.parentDeclaration as? KSClassDeclaration)?.classKind == ClassKind.INTERFACE
                ) ||
            (!this.modifiers.contains(Modifier.FINAL) && this.origin == Origin.JAVA)
        )

/** Returns `true` if [getVisibility] is [Visibility.PUBLIC]. */
fun KSDeclaration.isPublic() = this.getVisibility() == Visibility.PUBLIC

/** Returns `true` if [getVisibility] is [Visibility.PROTECTED]. */
fun KSDeclaration.isProtected() = this.getVisibility() == Visibility.PROTECTED

/** Returns `true` if this declaration carries the `internal` modifier (checked on the modifiers directly). */
fun KSDeclaration.isInternal() = this.modifiers.contains(Modifier.INTERNAL)

/** Returns `true` if this declaration carries the `private` modifier (checked on the modifiers directly). */
fun KSDeclaration.isPrivate() = this.modifiers.contains(Modifier.PRIVATE)

/** Returns `true` if [getVisibility] is [Visibility.JAVA_PACKAGE]. */
fun KSDeclaration.isJavaPackagePrivate() = this.getVisibility() == Visibility.JAVA_PACKAGE

/**
 * Returns this declaration if it is a [KSClassDeclaration], otherwise the nearest enclosing
 * [KSClassDeclaration] found through [KSDeclaration.parentDeclaration], or `null` if there is none.
 */
fun KSDeclaration.closestClassDeclaration(): KSClassDeclaration? {
    return if (this is KSClassDeclaration) {
        this
    } else {
        this.parentDeclaration?.closestClassDeclaration()
    }
}

/**
 * Check whether this declaration is visible from [other].
 *
 * @param other the declaration from which visibility of the receiver is checked.
 * @return `true` if the receiver is considered visible from [other], `false` otherwise.
 */
// TODO: cross module visibility is not handled
fun KSDeclaration.isVisibleFrom(other: KSDeclaration): Boolean {
    fun KSDeclaration.isSamePackage(other: KSDeclaration): Boolean =
        this.packageName == other.packageName

    // lexical scope for local declaration.
    // Only called on local declarations, whose parentDeclaration is non-null by definition of isLocal().
    fun KSDeclaration.parentDeclarationsForLocal(): List<KSDeclaration> {
        val parents = mutableListOf<KSDeclaration>()

        var parentDeclaration = this.parentDeclaration!!

        while (parentDeclaration.isLocal()) {
            parents.add(parentDeclaration)
            parentDeclaration = parentDeclaration.parentDeclaration!!
        }

        // Also include the first non-local ancestor that encloses the local scope.
        parents.add(parentDeclaration)

        return parents
    }

    fun KSDeclaration.isVisibleInPrivate(other: KSDeclaration) =
        (other.isLocal() && other.parentDeclarationsForLocal().contains(this.parentDeclaration)) ||
            this.parentDeclaration == other.parentDeclaration ||
            this.parentDeclaration == other ||
            (
                this.parentDeclaration == null &&
                    other.parentDeclaration == null &&
                    this.containingFile == other.containingFile
                )

    return when {
        // locals are limited to lexical scope
        this.isLocal() -> this.parentDeclarationsForLocal().contains(other)
        // file visibility or member
        // TODO: address nested class.
        this.isPrivate() -> this.isVisibleInPrivate(other)
        this.isPublic() -> true
        this.isInternal() && other.containingFile != null && this.containingFile != null -> true
        this.isJavaPackagePrivate() -> this.isSamePackage(other)
        this.isProtected() -> this.isVisibleInPrivate(other) || this.isSamePackage(other) ||
            other.closestClassDeclaration()?.let {
                this.closestClassDeclaration()!!.asStarProjectedType().isAssignableFrom(it.asStarProjectedType())
            } ?: false
        else -> false
    }
}

/**
 * Returns `true` if this is a constructor function, i.e. its simple name is the JVM constructor name `<init>`.
 */
fun KSFunctionDeclaration.isConstructor() = this.simpleName.asString() == "<init>"

/** Suffix appended to the messages of exceptions thrown for unexpected internal states. */
const val ExceptionMessage = "please file a bug at https://github.com/google/ksp/issues/new"

/**
 * The type of the enclosing class for a type whose declaration carries the `inner` modifier.
 *
 * The outer type is built from the arguments that remain after this declaration's own type parameters.
 * Returns `null` if the declaration is not `inner` or its parent declaration is not a class.
 */
val KSType.outerType: KSType?
    get() {
        if (Modifier.INNER !in declaration.modifiers) {
            return null
        }
        val outerDecl = declaration.parentDeclaration as? KSClassDeclaration ?: return null
        return outerDecl.asType(arguments.subList(declaration.typeParameters.size, arguments.size))
    }

/**
 * The leading type arguments that correspond to this declaration's own type parameters.
 */
val KSType.innerArguments: List<KSTypeArgument>
    get() = arguments.subList(0, declaration.typeParameters.size)

/**
 * Look up a class by [name] after mapping it with [Resolver.mapJavaNameToKotlin];
 * [name] itself is used when no mapping exists.
 */
@KspExperimental
fun Resolver.getKotlinClassByName(name: KSName): KSClassDeclaration? {
    val kotlinName = mapJavaNameToKotlin(name) ?: name
    return getClassDeclarationByName(kotlinName)
}

/**
 * String overload of [getKotlinClassByName]; [name] is converted with [Resolver.getKSNameFromString].
 */
@KspExperimental
fun Resolver.getKotlinClassByName(name: String): KSClassDeclaration? =
    getKotlinClassByName(getKSNameFromString(name))

/**
 * Look up a class by [name] after mapping it with [Resolver.mapKotlinNameToJava];
 * [name] itself is used when no mapping exists.
 */
@KspExperimental
fun Resolver.getJavaClassByName(name: KSName): KSClassDeclaration? {
    val javaName = mapKotlinNameToJava(name) ?: name
    return getClassDeclarationByName(javaName)
}

/**
 * String overload of [getJavaClassByName]; [name] is converted with [Resolver.getKSNameFromString].
 */
@KspExperimental
fun Resolver.getJavaClassByName(name: String): KSClassDeclaration? =
    getJavaClassByName(getKSNameFromString(name))

/**
 * Get the annotations of type [annotationKClass] on this symbol as instances of the annotation class.
 *
 * An annotation matches when its short name equals the class's simple name and its resolved annotation type has
 * the same qualified name as [annotationKClass]. Each match is exposed through a dynamic proxy implementing the
 * annotation interface.
 *
 * @param annotationKClass the annotation class to look for.
 * @return a Sequence of proxy instances, one per matching annotation.
 */
@KspExperimental
fun <T : Annotation> KSAnnotated.getAnnotationsByType(annotationKClass: KClass<T>): Sequence<T> {
    return this.annotations.filter {
        it.shortName.getShortName() == annotationKClass.simpleName && it.annotationType.resolve().declaration
            .qualifiedName?.asString() == annotationKClass.qualifiedName
    }.map { it.toAnnotation(annotationKClass.java) }
}

/**
 * Returns `true` if [getAnnotationsByType] yields at least one annotation for [annotationKClass].
 */
@KspExperimental
fun <T : Annotation> KSAnnotated.isAnnotationPresent(annotationKClass: KClass<T>): Boolean =
    getAnnotationsByType(annotationKClass).firstOrNull() != null

/**
 * Create a dynamic proxy implementing [annotationClass] whose members are answered from this [KSAnnotation].
 */
@KspExperimental
@Suppress("UNCHECKED_CAST")
private fun <T : Annotation> KSAnnotation.toAnnotation(annotationClass: Class<T>): T {
    return Proxy.newProxyInstance(
        annotationClass.classLoader,
        arrayOf(annotationClass),
        createInvocationHandler(annotationClass)
    ) as T
}

/**
 * Build the [InvocationHandler] backing an annotation proxy for [clazz].
 *
 * `toString` (unless the annotation itself declares a `toString` argument) returns the canonical class name
 * followed by the list of `name=value` pairs. Any other call looks up the argument named after the invoked
 * method (throwing [NoSuchElementException] if there is none), falls back to the method's default value when the
 * argument has no value, and converts the raw value to the method's return type. Converted values are cached
 * per (return type, raw value) pair.
 */
@KspExperimental
@Suppress("TooGenericExceptionCaught")
private fun KSAnnotation.createInvocationHandler(clazz: Class<*>): InvocationHandler {
    val cache = ConcurrentHashMap<Pair<Class<*>, Any>, Any>(arguments.size)
    return InvocationHandler { proxy, method, _ ->
        if (method.name == "toString" && arguments.none { it.name?.asString() == "toString" }) {
            clazz.canonicalName +
                arguments.map { argument: KSValueArgument ->
                    // handles default values for enums otherwise returns null
                    val methodName = argument.name?.asString()
                    val value = proxy.javaClass.methods.find { m -> m.name == methodName }?.invoke(proxy)
                    "$methodName=$value"
                }.toList()
        } else {
            val argument = arguments.first { it.name?.asString() == method.name }
            when (val result = argument.value ?: method.defaultValue) {
                is Proxy -> result
                is List<*> -> {
                    val value = { result.asArray(method, clazz) }
                    cache.getOrPut(Pair(method.returnType, result), value)
                }
                else -> {
                    when {
                        method.returnType.isEnum -> {
                            val value = { result.asEnum(method.returnType) }
                            cache.getOrPut(Pair(method.returnType, result), value)
                        }
                        method.returnType.isAnnotation -> {
                            val value = { (result as KSAnnotation).asAnnotation(method.returnType) }
                            cache.getOrPut(Pair(method.returnType, result), value)
                        }
                        method.returnType.name == "java.lang.Class" -> {
                            cache.getOrPut(Pair(method.returnType, result)) {
                                when (result) {
                                    is KSType -> result.asClass(clazz)
                                    // Handles com.intellij.psi.impl.source.PsiImmediateClassType using reflection
                                    // since api doesn't contain a reference to this
                                    else -> Class.forName(
                                        result.javaClass.methods
                                            .first { it.name == "getCanonicalText" }
                                            .invoke(result, false) as String
                                    )
                                }
                            }
                        }
                        method.returnType.name == "byte" -> {
                            val value = { result.asByte() }
                            cache.getOrPut(Pair(method.returnType, result), value)
                        }
                        method.returnType.name == "short" -> {
                            val value = { result.asShort() }
                            cache.getOrPut(Pair(method.returnType, result), value)
                        }
                        method.returnType.name == "long" -> {
                            val value = { result.asLong() }
                            cache.getOrPut(Pair(method.returnType, result), value)
                        }
                        method.returnType.name == "float" -> {
                            val value = { result.asFloat() }
                            cache.getOrPut(Pair(method.returnType, result), value)
                        }
                        method.returnType.name == "double" -> {
                            val value = { result.asDouble() }
                            cache.getOrPut(Pair(method.returnType, result), value)
                        }
                        else -> result // original value
                    }
                }
            }
        }
    }
}

/**
 * Create a dynamic proxy implementing [annotationInterface] for a nested annotation value.
 */
@KspExperimental
@Suppress("UNCHECKED_CAST")
private fun KSAnnotation.asAnnotation(
    annotationInterface: Class<*>,
): Any {
    return Proxy.newProxyInstance(
        annotationInterface.classLoader,
        arrayOf(annotationInterface),
        this.createInvocationHandler(annotationInterface)
    ) as Proxy
}

/**
 * Convert this list of raw annotation values to an array matching the component type of [method]'s return type.
 *
 * @param method the annotation member whose (array) return type determines the target component type.
 * @param proxyClass the annotation interface; its class loader is used to load `Class` elements.
 * @throws IllegalStateException if the component type is not a supported primitive, `Class`, `String`, enum or
 * annotation type.
 */
@KspExperimental
@Suppress("UNCHECKED_CAST")
private fun List<*>.asArray(method: Method, proxyClass: Class<*>) =
    when (method.returnType.componentType.name) {
        "boolean" -> (this as List<Boolean>).toBooleanArray()
        "byte" -> (this as List<Byte>).toByteArray()
        "short" -> (this as List<Short>).toShortArray()
        "char" -> (this as List<Char>).toCharArray()
        "double" -> (this as List<Double>).toDoubleArray()
        "float" -> (this as List<Float>).toFloatArray()
        "int" -> (this as List<Int>).toIntArray()
        "long" -> (this as List<Long>).toLongArray()
        "java.lang.Class" -> (this as List<KSType>).asClasses(proxyClass).toTypedArray()
        "java.lang.String" -> (this as List<String>).toTypedArray()
        else -> { // arrays of enums or annotations
            when {
                method.returnType.componentType.isEnum -> {
                    this.toArray(method) { result -> result.asEnum(method.returnType.componentType) }
                }
                method.returnType.componentType.isAnnotation -> {
                    this.toArray(method) { result ->
                        (result as KSAnnotation).asAnnotation(method.returnType.componentType)
                    }
                }
                else -> throw IllegalStateException("Unable to process type ${method.returnType.componentType.name}")
            }
        }
    }

/**
 * Allocate an array of [method]'s return component type and fill it by applying [valueProvider] to each
 * non-null element of this list; `null` elements stay `null`.
 */
@Suppress("UNCHECKED_CAST")
private fun List<*>.toArray(method: Method, valueProvider: (Any) -> Any): Array<Any?> {
    val array: Array<Any?> = java.lang.reflect.Array.newInstance(
        method.returnType.componentType,
        this.size
    ) as Array<Any?>
    for (r in 0 until this.size) {
        array[r] = this[r]?.let { valueProvider.invoke(it) }
    }
    return array
}

/**
 * Resolve this raw value to a constant of the enum class [returnType] by reflectively calling its `valueOf`.
 *
 * The constant name is the declaration's short name when the value is a [KSType], otherwise `toString()`.
 */
@Suppress("UNCHECKED_CAST")
private fun <T> Any.asEnum(returnType: Class<T>): T =
    returnType.getDeclaredMethod("valueOf", String::class.java)
        .invoke(
            null,
            if (this is KSType) {
                this.declaration.simpleName.getShortName()
            } else {
                this.toString()
            }
        ) as T

// The converters below accept a value that is either an Int (which is converted) or already of the target type.

private fun Any.asByte(): Byte = if (this is Int) this.toByte() else this as Byte

private fun Any.asShort(): Short = if (this is Int) this.toShort() else this as Short

private fun Any.asLong(): Long = if (this is Int) this.toLong() else this as Long

private fun Any.asFloat(): Float = if (this is Int) this.toFloat() else this as Float

private fun Any.asDouble(): Double = if (this is Int) this.toDouble() else this as Double

/**
 * Thrown when the class for a `Class`/`KClass` annotation member cannot be loaded.
 *
 * @property ksType the type whose class could not be loaded.
 */
// for Class/KClass member
@KspExperimental
class KSTypeNotPresentException(val ksType: KSType, cause: Throwable) : RuntimeException(cause)

/**
 * Thrown when the classes for a `Class[]`/`Array<KClass<*>>` annotation member cannot be loaded.
 *
 * @property ksTypes the types of the member being converted.
 */
// for Class[]/Array<KClass<*>> member.
@KspExperimental
class KSTypesNotPresentException(val ksTypes: List<KSType>, cause: Throwable) : RuntimeException(cause)

/**
 * Load the class named by this type's declaration using [proxyClass]'s class loader.
 *
 * @throws KSTypeNotPresentException wrapping any exception raised while loading the class.
 */
@KspExperimental
private fun KSType.asClass(proxyClass: Class<*>) = try {
    Class.forName(this.declaration.qualifiedName!!.asString(), true, proxyClass.classLoader)
} catch (e: Exception) {
    throw KSTypeNotPresentException(this, e)
}

/**
 * Load the classes for every type in this list via [asClass].
 *
 * @throws KSTypesNotPresentException wrapping any exception raised while loading one of the classes.
 */
@KspExperimental
private fun List<KSType>.asClasses(proxyClass: Class<*>) = try {
    this.map { type -> type.asClass(proxyClass) }
} catch (e: Exception) {
    throw KSTypesNotPresentException(this, e)
}

/**
 * Returns `true` if this argument's origin is [Origin.SYNTHETIC].
 */
fun KSValueArgument.isDefault() = origin == Origin.SYNTHETIC