1 /*
<lambda>null2  * Copyright 2020 Google LLC
3  * Copyright 2010-2020 JetBrains s.r.o. and Kotlin Programming Language contributors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 package com.google.devtools.ksp
18 
19 import com.google.devtools.ksp.processing.Resolver
20 import com.google.devtools.ksp.symbol.*
21 import com.google.devtools.ksp.visitor.KSValidateVisitor
22 import java.lang.reflect.InvocationHandler
23 import java.lang.reflect.Method
24 import java.lang.reflect.Proxy
25 import java.util.concurrent.ConcurrentHashMap
26 import kotlin.reflect.KClass
27 
/**
 * Resolves the [KSClassDeclaration] of [T] by looking up its fully qualified name.
 *
 * When [T] has no qualified name (for example a local or anonymous class), nothing can be looked up
 * and `null` is returned.
 *
 * @param T the class whose declaration should be resolved.
 * @return the resolved [KSClassDeclaration], or `null` if [T] has no qualified name or no declaration is found.
 * @see Resolver.getClassDeclarationByName
 */
inline fun <reified T> Resolver.getClassDeclarationByName(): KSClassDeclaration? {
    val fqcn = T::class.qualifiedName ?: return null
    return getClassDeclarationByName(getKSNameFromString(fqcn))
}
41 
/**
 * Looks up a class on the compilation classpath by its fully qualified name.
 *
 * Convenience overload that converts [name] with [Resolver.getKSNameFromString] first.
 *
 * @param name the fully qualified class name, using '.' as the separator.
 * @return the matching [KSClassDeclaration], or `null` if none is found.
 */
fun Resolver.getClassDeclarationByName(name: String): KSClassDeclaration? {
    val ksName = getKSNameFromString(name)
    return getClassDeclarationByName(ksName)
}
50 
/**
 * Finds functions on the compilation classpath with the given fully qualified name.
 *
 * @param name fully qualified name of the function(s), using '.' as the separator.
 * @param includeTopLevel whether top-level functions are searched as well (default `false`).
 * Including top-level functions can make this lookup expensive.
 * @return a sequence of the matching [KSFunctionDeclaration]s.
 */
fun Resolver.getFunctionDeclarationsByName(
    name: String,
    includeTopLevel: Boolean = false,
): Sequence<KSFunctionDeclaration> {
    val ksName = getKSNameFromString(name)
    return getFunctionDeclarationsByName(ksName, includeTopLevel)
}
62 
/**
 * Finds a property on the compilation classpath with the given fully qualified name.
 *
 * @param name fully qualified name of the property, using '.' as the separator.
 * @param includeTopLevel whether top-level properties are searched as well (default `false`).
 * Including top-level properties can make this lookup expensive.
 * @return the matching [KSPropertyDeclaration], or `null` if none is found.
 */
fun Resolver.getPropertyDeclarationByName(name: String, includeTopLevel: Boolean = false): KSPropertyDeclaration? {
    val ksName = getKSNameFromString(name)
    return getPropertyDeclarationByName(ksName, includeTopLevel)
}
72 
/**
 * The [KSFile] that (transitively) contains this node, or `null` if there is none.
 *
 * The lookup walks up the [KSNode.parent] chain, starting at this node's parent, and returns the first
 * [KSFile] it meets. The search starts at the parent, so the node itself is never returned.
 *
 * Examples of symbols without a containing file: symbols from class files, and synthetic symbols
 * created by the user.
 */
val KSNode.containingFile: KSFile?
    get() = generateSequence(parent) { node -> node.parent }
        .filterIsInstance<KSFile>()
        .firstOrNull()
86 
/**
 * Returns the functions declared directly in the body of this class.
 *
 * Included: member functions, constructors, extension functions declared inside the class, etc.
 * Excluded: inherited functions and extension functions declared outside the class.
 */
fun KSClassDeclaration.getDeclaredFunctions(): Sequence<KSFunctionDeclaration> =
    declarations.filterIsInstance<KSFunctionDeclaration>()
96 
/**
 * Returns the properties declared directly in the body of this class.
 *
 * Included: member properties, extension properties declared inside the class, etc.
 * Excluded: inherited properties and extension properties declared outside the class.
 */
fun KSClassDeclaration.getDeclaredProperties(): Sequence<KSPropertyDeclaration> =
    declarations.filterIsInstance<KSPropertyDeclaration>()
106 
/**
 * Returns the constructors declared directly in this class: the functions from [getDeclaredFunctions]
 * for which [isConstructor] is `true`.
 */
fun KSClassDeclaration.getConstructors(): Sequence<KSFunctionDeclaration> =
    getDeclaredFunctions().filter(KSFunctionDeclaration::isConstructor)
112 
/**
 * Returns `true` if this is a local declaration: it has an enclosing declaration, and that enclosing
 * declaration is not a class (i.e. it is declared inside a function or similar).
 */
fun KSDeclaration.isLocal(): Boolean {
    val enclosing = parentDeclaration ?: return false
    return enclosing !is KSClassDeclaration
}
119 
/**
 * Validates this symbol, checking that the types of interest in its enclosed scope are valid,
 * i.e. resolvable.
 *
 * @param predicate filter deciding which symbols are checked, passed straight to [KSValidateVisitor];
 * use it to prune the traversal for performance. By default every symbol is checked.
 * @return the verdict computed by [KSValidateVisitor] (`true` when the checked types are valid).
 */
fun KSNode.validate(predicate: (KSNode?, KSNode) -> Boolean = { _, _ -> true }): Boolean {
    val visitor = KSValidateVisitor(predicate)
    return accept(visitor, null)
}
127 
/**
 * Follows this type alias, and any alias it expands to, until a class declaration is reached.
 *
 * @return the [KSClassDeclaration] at the end of the alias chain.
 * @throws ClassCastException if the chain ends in a declaration that is not a class.
 */
fun KSTypeAlias.findActualType(): KSClassDeclaration =
    when (val target = type.resolve().declaration) {
        is KSTypeAlias -> target.findActualType()
        else -> target as KSClassDeclaration
    }
139 
/**
 * Determines the [Visibility] of a [KSDeclaration].
 *
 * Rules, applied in order:
 * 1. an explicit `public` modifier yields [Visibility.PUBLIC];
 * 2. an `override` takes the visibility of the declaration it overrides (functions and properties only),
 *    falling back to [Visibility.PUBLIC] when no overridee is found;
 * 3. local declarations are [Visibility.LOCAL];
 * 4. explicit `private`, `protected` and `internal` modifiers map to their counterparts;
 * 5. otherwise, declarations of [Origin.JAVA] / [Origin.JAVA_LIB] origin are [Visibility.JAVA_PACKAGE]
 *    (Java package-private), and everything else defaults to [Visibility.PUBLIC].
 */
fun KSDeclaration.getVisibility(): Visibility {
    return when {
        Modifier.PUBLIC in modifiers -> Visibility.PUBLIC
        Modifier.OVERRIDE in modifiers -> {
            when (this) {
                is KSFunctionDeclaration -> this.findOverridee()?.getVisibility()
                is KSPropertyDeclaration -> this.findOverridee()?.getVisibility()
                else -> null
            } ?: Visibility.PUBLIC
        }
        this.isLocal() -> Visibility.LOCAL
        Modifier.PRIVATE in modifiers -> Visibility.PRIVATE
        // `override` is fully handled above, so only an explicit `protected` can reach this branch.
        Modifier.PROTECTED in modifiers -> Visibility.PROTECTED
        Modifier.INTERNAL in modifiers -> Visibility.INTERNAL
        origin == Origin.JAVA || origin == Origin.JAVA_LIB -> Visibility.JAVA_PACKAGE
        else -> Visibility.PUBLIC
    }
}
162 
/**
 * Returns all super types of this class declaration, direct and transitive, without duplicates.
 *
 * Direct super types come first, in declaration order. They are followed by the super types reachable
 * through each of them:
 * - type aliases are expanded with [findActualType];
 * - for a super type that resolves to a type parameter, the super types of its class upper bounds are
 *   included.
 *
 * Calling [getAllSuperTypes] requires type resolution and is therefore expensive; avoid it when possible.
 * The returned sequence is lazy. Each direct super type reference is resolved once per iteration of the
 * sequence.
 *
 * @throws IllegalStateException (during iteration) if a super type or type-parameter bound resolves to an
 * unexpected kind of declaration.
 */
fun KSClassDeclaration.getAllSuperTypes(): Sequence<KSType> {

    // Class declarations bounding this type parameter, following bounds that are themselves type parameters.
    fun KSTypeParameter.getTypesUpperBound(): Sequence<KSClassDeclaration> =
        this.bounds.asSequence().flatMap { bound ->
            when (val resolvedDeclaration = bound.resolve().declaration) {
                is KSClassDeclaration -> sequenceOf(resolvedDeclaration)
                is KSTypeAlias -> sequenceOf(resolvedDeclaration.findActualType())
                is KSTypeParameter -> resolvedDeclaration.getTypesUpperBound()
                else -> throw IllegalStateException("unhandled type parameter bound, $ExceptionMessage")
            }
        }

    // Transitive super types reachable through the declaration of one direct super type.
    fun superTypesThrough(declaration: KSDeclaration): Sequence<KSType> =
        when (declaration) {
            is KSClassDeclaration -> declaration.getAllSuperTypes()
            is KSTypeAlias -> declaration.findActualType().getAllSuperTypes()
            is KSTypeParameter -> declaration.getTypesUpperBound().flatMap { upper -> upper.getAllSuperTypes() }
            else -> throw IllegalStateException("unhandled super type kind, $ExceptionMessage")
        }

    return sequence {
        // Resolve each direct super type reference once and reuse the result for the transitive walk.
        val direct = mutableListOf<KSType>()
        for (reference in superTypes) {
            val resolved = reference.resolve()
            direct += resolved
            yield(resolved)
        }
        for (type in direct) {
            yieldAll(superTypesThrough(type.declaration))
        }
    }.distinct()
}
197 
/**
 * Returns `true` if this class is an interface or is explicitly marked `abstract`.
 */
fun KSClassDeclaration.isAbstract(): Boolean =
    classKind == ClassKind.INTERFACE || Modifier.ABSTRACT in modifiers
200 
/**
 * Returns `true` if this property is abstract.
 *
 * A property with the `abstract` modifier is always abstract. Without the modifier, a property is
 * abstract only if it is declared in an interface and each of its accessors is either absent or itself
 * marked `abstract`.
 */
fun KSPropertyDeclaration.isAbstract(): Boolean {
    if (Modifier.ABSTRACT in modifiers) return true
    val owner = parentDeclaration as? KSClassDeclaration
    if (owner?.classKind != ClassKind.INTERFACE) return false
    // Abstract unless an accessor exists without being marked abstract itself.
    return (getter?.let { Modifier.ABSTRACT in it.modifiers } ?: true) &&
        (setter?.let { Modifier.ABSTRACT in it.modifiers } ?: true)
}
211 
/**
 * Returns `true` if this declaration can be overridden or subclassed.
 *
 * Local declarations are never open. Otherwise a declaration is open if any of the following holds:
 * - it is an interface;
 * - it carries `override`, `abstract`, `open` or `sealed`;
 * - it is a non-class member of an interface;
 * - it comes from Java source ([Origin.JAVA]) and is not marked `final`.
 */
fun KSDeclaration.isOpen(): Boolean {
    if (isLocal()) return false
    val inheritanceModifiers = listOf(Modifier.OVERRIDE, Modifier.ABSTRACT, Modifier.OPEN, Modifier.SEALED)
    return when {
        this is KSClassDeclaration && this.classKind == ClassKind.INTERFACE -> true
        inheritanceModifiers.any { it in modifiers } -> true
        this !is KSClassDeclaration &&
            (parentDeclaration as? KSClassDeclaration)?.classKind == ClassKind.INTERFACE -> true
        else -> Modifier.FINAL !in modifiers && origin == Origin.JAVA
    }
}
225 
/** Returns `true` if [getVisibility] resolves to [Visibility.PUBLIC]. */
fun KSDeclaration.isPublic(): Boolean = getVisibility() == Visibility.PUBLIC

/** Returns `true` if [getVisibility] resolves to [Visibility.PROTECTED]. */
fun KSDeclaration.isProtected(): Boolean = getVisibility() == Visibility.PROTECTED

/** Returns `true` if this declaration carries an explicit `internal` modifier (no visibility inference). */
fun KSDeclaration.isInternal(): Boolean = Modifier.INTERNAL in modifiers

/** Returns `true` if this declaration carries an explicit `private` modifier (no visibility inference). */
fun KSDeclaration.isPrivate(): Boolean = Modifier.PRIVATE in modifiers

/** Returns `true` if [getVisibility] resolves to [Visibility.JAVA_PACKAGE] (Java package-private). */
fun KSDeclaration.isJavaPackagePrivate(): Boolean = getVisibility() == Visibility.JAVA_PACKAGE
235 
/**
 * Returns the nearest class declaration, starting with this declaration itself and then walking up its
 * parent declarations, or `null` if none of them is a class.
 */
fun KSDeclaration.closestClassDeclaration(): KSClassDeclaration? =
    generateSequence(this) { declaration -> declaration.parentDeclaration }
        .filterIsInstance<KSClassDeclaration>()
        .firstOrNull()
243 
/**
 * Returns `true` if this declaration is visible from [other].
 *
 * Rules, checked in order:
 * - local declarations are visible only within their lexical scope, i.e. from their enclosing declarations;
 * - `private` members are visible from their enclosing declaration, from its other members, and from locals
 *   nested inside those. Top-level `private` declarations are visible within their own source file; when the
 *   declaration has no source file, this cannot be established and `false` is returned;
 * - public declarations are visible everywhere;
 * - `internal` declarations are treated as visible when both declarations have a containing file;
 * - Java package-private declarations are visible from the same package;
 * - `protected` declarations are visible wherever a `private` one would be, from the same package, and
 *   from classes whose star-projected type is assignable to the declaring class's star-projected type.
 *
 * TODO: cross module visibility is not handled.
 * TODO: nested classes are not addressed for private members.
 */
fun KSDeclaration.isVisibleFrom(other: KSDeclaration): Boolean {
    fun KSDeclaration.isSamePackage(other: KSDeclaration): Boolean =
        this.packageName == other.packageName

    // Lexical scope of a local declaration: its enclosing local declarations followed by the first
    // non-local enclosing declaration. Only called on local declarations, whose parent is non-null.
    fun KSDeclaration.parentDeclarationsForLocal(): List<KSDeclaration> =
        generateSequence(this.parentDeclaration!!) { parent ->
            if (parent.isLocal()) parent.parentDeclaration else null
        }.toList()

    fun KSDeclaration.isVisibleInPrivate(other: KSDeclaration): Boolean {
        val ownParent = this.parentDeclaration
        if (ownParent == null) {
            // Top-level private declarations are file-private. Comparing parents here would wrongly treat
            // any two top-level declarations (both parents null) as sharing a scope.
            val file = this.containingFile
            return file != null && file == other.containingFile
        }
        return (other.isLocal() && ownParent in other.parentDeclarationsForLocal()) ||
            ownParent == other.parentDeclaration ||
            ownParent == other
    }

    // Protected members are also visible from subclasses of the declaring class.
    fun KSDeclaration.isVisibleFromSubclass(other: KSDeclaration): Boolean {
        val declaringClass = this.closestClassDeclaration() ?: return false
        val otherClass = other.closestClassDeclaration() ?: return false
        return declaringClass.asStarProjectedType().isAssignableFrom(otherClass.asStarProjectedType())
    }

    return when {
        // locals are limited to lexical scope
        this.isLocal() -> this.parentDeclarationsForLocal().contains(other)
        // file visibility or member
        this.isPrivate() -> this.isVisibleInPrivate(other)
        this.isPublic() -> true
        this.isInternal() && other.containingFile != null && this.containingFile != null -> true
        this.isJavaPackagePrivate() -> this.isSamePackage(other)
        this.isProtected() ->
            this.isVisibleInPrivate(other) || this.isSamePackage(other) || this.isVisibleFromSubclass(other)
        else -> false
    }
}
290 
/**
 * Returns `true` if this function is a constructor, i.e. its simple name is `<init>`.
 */
fun KSFunctionDeclaration.isConstructor(): Boolean = simpleName.asString() == "<init>"
295 
/** Suffix for internal-error messages in this file, asking users to report the problem upstream. */
const val ExceptionMessage = "please file a bug at https://github.com/google/ksp/issues/new"
297 
/**
 * The type of the outer class for an `inner` class type, or `null` otherwise.
 *
 * This assumes [KSType.arguments] lists the inner declaration's own type arguments first, followed by the
 * outer class's arguments; the trailing part is used to build the outer type.
 * Returns `null` if the declaration is not marked `inner` or its parent declaration is not a class.
 */
val KSType.outerType: KSType?
    get() {
        val decl = declaration
        if (Modifier.INNER !in decl.modifiers) return null
        val outerClass = decl.parentDeclaration as? KSClassDeclaration ?: return null
        val ownArgumentCount = decl.typeParameters.size
        return outerClass.asType(arguments.subList(ownArgumentCount, arguments.size))
    }
305 
/**
 * The leading type arguments of this type, one per type parameter of its own declaration.
 * Any trailing arguments (those belonging to an outer class, see [outerType]) are excluded.
 */
val KSType.innerArguments: List<KSTypeArgument>
    get() {
        val ownArgumentCount = declaration.typeParameters.size
        return arguments.subList(0, ownArgumentCount)
    }
308 
/**
 * Looks up a class by [name], first translating it with [Resolver.mapJavaNameToKotlin].
 * Names without a Kotlin mapping are used unchanged.
 */
@KspExperimental
fun Resolver.getKotlinClassByName(name: KSName): KSClassDeclaration? =
    getClassDeclarationByName(mapJavaNameToKotlin(name) ?: name)
314 
/**
 * String overload of [getKotlinClassByName]; [name] is converted with [Resolver.getKSNameFromString].
 */
@KspExperimental
fun Resolver.getKotlinClassByName(name: String): KSClassDeclaration? {
    val ksName = getKSNameFromString(name)
    return getKotlinClassByName(ksName)
}
318 
/**
 * Looks up a class by [name], first translating it with [Resolver.mapKotlinNameToJava].
 * Names without a Java mapping are used unchanged.
 */
@KspExperimental
fun Resolver.getJavaClassByName(name: KSName): KSClassDeclaration? =
    getClassDeclarationByName(mapKotlinNameToJava(name) ?: name)
324 
/**
 * String overload of [getJavaClassByName]; [name] is converted with [Resolver.getKSNameFromString].
 */
@KspExperimental
fun Resolver.getJavaClassByName(name: String): KSClassDeclaration? {
    val ksName = getKSNameFromString(name)
    return getJavaClassByName(ksName)
}
328 
/**
 * Returns the annotations on this symbol whose type is [annotationKClass]. Each one is materialized as a
 * reflective proxy implementing [T].
 *
 * The cheap short-name comparison runs first, so only candidates with a matching simple name pay for
 * type resolution.
 */
@KspExperimental
fun <T : Annotation> KSAnnotated.getAnnotationsByType(annotationKClass: KClass<T>): Sequence<T> {
    val expectedShortName = annotationKClass.simpleName
    val expectedQualifiedName = annotationKClass.qualifiedName
    return annotations
        .filter { annotation ->
            annotation.shortName.getShortName() == expectedShortName &&
                annotation.annotationType.resolve().declaration.qualifiedName?.asString() == expectedQualifiedName
        }
        .map { annotation -> annotation.toAnnotation(annotationKClass.java) }
}
336 
/**
 * Returns `true` if this symbol carries at least one annotation of type [annotationKClass].
 *
 * Uses [Sequence.any] instead of materializing the first element. Because [getAnnotationsByType] builds
 * its proxies in a lazy `map`, presence is detected without creating an annotation proxy.
 */
@KspExperimental
fun <T : Annotation> KSAnnotated.isAnnotationPresent(annotationKClass: KClass<T>): Boolean =
    getAnnotationsByType(annotationKClass).any()
340 
/**
 * Materializes this [KSAnnotation] as a dynamic proxy implementing [annotationClass]. The proxy's
 * methods are answered by the handler from [createInvocationHandler].
 */
@KspExperimental
@Suppress("UNCHECKED_CAST")
private fun <T : Annotation> KSAnnotation.toAnnotation(annotationClass: Class<T>): T {
    val handler = createInvocationHandler(annotationClass)
    val instance = Proxy.newProxyInstance(annotationClass.classLoader, arrayOf(annotationClass), handler)
    return instance as T
}
350 
/**
 * Creates the [InvocationHandler] backing an annotation proxy of type [clazz] built from this [KSAnnotation].
 *
 * Calls to annotation members are answered from [KSAnnotation.arguments] (see [convertArgumentValue]).
 * Calls to methods that are not annotation members are answered by the handler itself:
 * - `toString` renders the canonical class name followed by the `name=value` list of all arguments;
 * - `annotationType` returns [clazz];
 * - `hashCode` and `equals` use the identity of the proxy. This does NOT implement the structural contract
 *   of [java.lang.annotation.Annotation.equals] / [java.lang.annotation.Annotation.hashCode].
 *
 * Any other non-member method throws [NoSuchElementException].
 */
@KspExperimental
@Suppress("TooGenericExceptionCaught")
private fun KSAnnotation.createInvocationHandler(clazz: Class<*>): InvocationHandler {
    // Converted member values keyed by (member return type, raw value), shared by all calls on this proxy.
    val cache = ConcurrentHashMap<Pair<Class<*>, Any>, Any>(arguments.size)
    return InvocationHandler { proxy, method, args ->
        val argument = arguments.firstOrNull { it.name?.asString() == method.name }
        when {
            argument != null -> convertArgumentValue(argument, method, clazz, cache)
            method.name == "toString" ->
                clazz.canonicalName +
                    arguments.map { arg: KSValueArgument ->
                        // handles default values for enums otherwise returns null
                        val methodName = arg.name?.asString()
                        val value = proxy.javaClass.methods.find { m -> m.name == methodName }?.invoke(proxy)
                        "$methodName=$value"
                    }.toList()
            // Object/Annotation methods that are not annotation members (e.g. invoked when the proxy is
            // stored in a HashSet) must not go through the argument lookup, which has nothing to return.
            method.name == "annotationType" -> clazz
            method.name == "hashCode" -> System.identityHashCode(proxy)
            method.name == "equals" -> proxy === args?.getOrNull(0)
            else -> throw NoSuchElementException(
                "No argument named ${method.name} on annotation ${clazz.canonicalName}"
            )
        }
    }
}

/**
 * Converts the value of [argument] to [method]'s declared return type. The member's Java default value
 * ([Method.getDefaultValue]) is used when the argument has no value.
 *
 * Lists become arrays; enums, nested annotations, classes and narrowed numeric primitives are converted.
 * Converted values are memoized in [cache] per (return type, raw value). Proxies and all other values are
 * returned unchanged.
 */
@KspExperimental
private fun convertArgumentValue(
    argument: KSValueArgument,
    method: Method,
    clazz: Class<*>,
    cache: ConcurrentHashMap<Pair<Class<*>, Any>, Any>,
): Any? {
    val result = argument.value ?: method.defaultValue
    val returnType = method.returnType
    fun cached(convert: () -> Any): Any = cache.getOrPut(Pair(returnType, result), convert)
    return when {
        result is Proxy -> result
        result is List<*> -> cached { result.asArray(method, clazz) }
        returnType.isEnum -> cached { result.asEnum(returnType) }
        returnType.isAnnotation -> cached { (result as KSAnnotation).asAnnotation(returnType) }
        returnType.name == "java.lang.Class" -> cached {
            when (result) {
                is KSType -> result.asClass(clazz)
                // Java default values (Method.getDefaultValue) are already loaded Class objects.
                is Class<*> -> result
                // Handles com.intellij.psi.impl.source.PsiImmediateClassType using reflection
                // since api doesn't contain a reference to this
                else -> Class.forName(
                    result.javaClass.methods
                        .first { it.name == "getCanonicalText" }
                        .invoke(result, false) as String
                )
            }
        }
        returnType.name == "byte" -> cached { result.asByte() }
        returnType.name == "short" -> cached { result.asShort() }
        returnType.name == "long" -> cached { result.asLong() }
        returnType.name == "float" -> cached { result.asFloat() }
        returnType.name == "double" -> cached { result.asDouble() }
        else -> result // original value
    }
}
423 
/**
 * Materializes this nested [KSAnnotation] as a proxy implementing [annotationInterface]. It is used for
 * annotation members (and array elements) whose type is itself an annotation.
 */
@KspExperimental
@Suppress("UNCHECKED_CAST")
private fun KSAnnotation.asAnnotation(
    annotationInterface: Class<*>,
): Any {
    val handler = this.createInvocationHandler(annotationInterface)
    val loader = annotationInterface.classLoader
    return Proxy.newProxyInstance(loader, arrayOf(annotationInterface), handler) as Proxy
}
434 
/**
 * Converts a list-valued annotation argument into the Java array type declared by [method]'s return type.
 *
 * - Primitive component types produce primitive arrays.
 * - Byte/short/long/float/double elements use the same conversions as scalar members ([asByte], [asShort],
 *   [asLong], [asFloat], [asDouble]), which also accept `Int` values.
 * - `Class` components are loaded with [asClasses].
 * - Enum and annotation components are built with [toArray].
 *
 * @throws IllegalStateException if the component type is not supported.
 */
@KspExperimental
@Suppress("UNCHECKED_CAST")
private fun List<*>.asArray(method: Method, proxyClass: Class<*>) =
    when (method.returnType.componentType.name) {
        "boolean" -> (this as List<Boolean>).toBooleanArray()
        // A null element throws NullPointerException, just as unboxing it from a typed list would.
        "byte" -> map { it!!.asByte() }.toByteArray()
        "short" -> map { it!!.asShort() }.toShortArray()
        "char" -> (this as List<Char>).toCharArray()
        "double" -> map { it!!.asDouble() }.toDoubleArray()
        "float" -> map { it!!.asFloat() }.toFloatArray()
        "int" -> (this as List<Int>).toIntArray()
        "long" -> map { it!!.asLong() }.toLongArray()
        "java.lang.Class" -> (this as List<KSType>).asClasses(proxyClass).toTypedArray()
        "java.lang.String" -> (this as List<String>).toTypedArray()
        else -> { // arrays of enums or annotations
            val componentType = method.returnType.componentType
            when {
                componentType.isEnum -> {
                    this.toArray(method) { result -> result.asEnum(componentType) }
                }
                componentType.isAnnotation -> {
                    this.toArray(method) { result ->
                        (result as KSAnnotation).asAnnotation(componentType)
                    }
                }
                else -> throw IllegalStateException("Unable to process type ${componentType.name}")
            }
        }
    }
463 
/**
 * Builds a Java array whose runtime component type is [method]'s return component type. Each non-null
 * element of this list is passed through [valueProvider]; null elements stay null.
 */
@Suppress("UNCHECKED_CAST")
private fun List<*>.toArray(method: Method, valueProvider: (Any) -> Any): Array<Any?> {
    val componentType = method.returnType.componentType
    val target = java.lang.reflect.Array.newInstance(componentType, size) as Array<Any?>
    forEachIndexed { index, element ->
        target[index] = element?.let(valueProvider)
    }
    return target
}
475 
/**
 * Converts an annotation value to the enum constant of [returnType].
 *
 * Accepted inputs:
 * - a constant that already is an instance of [returnType] (as returned by [java.lang.reflect.Method.getDefaultValue]);
 * - a [KSType] whose declaration's simple name is the entry name;
 * - any other value whose `toString()` is the entry name.
 */
@Suppress("UNCHECKED_CAST")
private fun <T> Any.asEnum(returnType: Class<T>): T {
    // Real enum constants are returned as-is: `toString()` may be overridden (e.g. java.time.temporal.ChronoUnit),
    // so looking them up again by that string would fail.
    if (returnType.isInstance(this)) return this as T
    val entryName = if (this is KSType) this.declaration.simpleName.getShortName() else this.toString()
    return returnType.getDeclaredMethod("valueOf", String::class.java).invoke(null, entryName) as T
}
487 
/** Reads an annotation value as [Byte]: an `Int` is narrowed, anything else must already be a [Byte]. */
private fun Any.asByte(): Byte = when (this) {
    is Int -> toByte()
    else -> this as Byte
}

/** Reads an annotation value as [Short]: an `Int` is narrowed, anything else must already be a [Short]. */
private fun Any.asShort(): Short = when (this) {
    is Int -> toShort()
    else -> this as Short
}

/** Reads an annotation value as [Long]: an `Int` is widened, anything else must already be a [Long]. */
private fun Any.asLong(): Long = when (this) {
    is Int -> toLong()
    else -> this as Long
}

/** Reads an annotation value as [Float]: an `Int` is converted, anything else must already be a [Float]. */
private fun Any.asFloat(): Float = when (this) {
    is Int -> toFloat()
    else -> this as Float
}

/** Reads an annotation value as [Double]: an `Int` is converted, anything else must already be a [Double]. */
private fun Any.asDouble(): Double = when (this) {
    is Int -> toDouble()
    else -> this as Double
}
497 
/**
 * Thrown when a `Class`/`KClass` annotation member refers to [ksType] and that class cannot be loaded.
 * The loading failure is attached as the cause.
 */
@KspExperimental
class KSTypeNotPresentException(val ksType: KSType, cause: Throwable) : RuntimeException(cause)

/**
 * Thrown when a `Class[]`/`Array<KClass<*>>` annotation member contains a class that cannot be loaded.
 * [ksTypes] holds all the types of the member; the loading failure is attached as the cause.
 */
@KspExperimental
class KSTypesNotPresentException(val ksTypes: List<KSType>, cause: Throwable) : RuntimeException(cause)
504 
/**
 * Loads the JVM class for this type using [proxyClass]'s class loader.
 *
 * Nested classes are loaded by their binary name (`pkg.Outer$Inner`), because [Class.forName] does not
 * accept the dotted source name (`pkg.Outer.Inner`).
 *
 * @throws KSTypeNotPresentException if the class cannot be loaded or the declaration has no qualified name.
 */
@KspExperimental
private fun KSType.asClass(proxyClass: Class<*>) = try {
    Class.forName(this.declaration.toJavaBinaryName(), true, proxyClass.classLoader)
} catch (e: Exception) {
    throw KSTypeNotPresentException(this, e)
}

/**
 * Binary (JVM) name of this declaration: the package prefix followed by the nested class names
 * separated by `$`. Throws [NullPointerException] when the declaration has no qualified name.
 */
private fun KSDeclaration.toJavaBinaryName(): String {
    val qualified = qualifiedName!!.asString()
    val pkg = packageName.asString()
    val prefix = if (pkg.isEmpty()) "" else "$pkg."
    // Unexpected shape: fall back to the qualified name unchanged.
    if (!qualified.startsWith(prefix)) return qualified
    return prefix + qualified.substring(prefix.length).replace('.', '$')
}
511 
/**
 * Loads the JVM class for every type in this list (see [asClass]).
 *
 * @throws KSTypesNotPresentException carrying the whole list if any class cannot be loaded.
 */
@KspExperimental
private fun List<KSType>.asClasses(proxyClass: Class<*>) =
    try {
        map { type -> type.asClass(proxyClass) }
    } catch (e: Exception) {
        throw KSTypesNotPresentException(this, e)
    }
518 
/**
 * Returns `true` if this argument's [origin] is [Origin.SYNTHETIC], i.e. it was not written explicitly
 * and is treated as a default value.
 */
fun KSValueArgument.isDefault(): Boolean = Origin.SYNTHETIC == origin
520