1 /*
<lambda>null2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package androidx.room.ext
18 
19 import androidx.room.compiler.codegen.CodeLanguage
20 import androidx.room.compiler.codegen.VisibilityModifier
21 import androidx.room.compiler.codegen.XClassName
22 import androidx.room.compiler.codegen.XCodeBlock
23 import androidx.room.compiler.codegen.XFunSpec
24 import androidx.room.compiler.codegen.XMemberName
25 import androidx.room.compiler.codegen.XMemberName.Companion.companionMember
26 import androidx.room.compiler.codegen.XMemberName.Companion.packageMember
27 import androidx.room.compiler.codegen.XTypeName
28 import androidx.room.compiler.codegen.XTypeSpec
29 import androidx.room.compiler.codegen.asClassName
30 import androidx.room.compiler.codegen.asMutableClassName
31 import androidx.room.compiler.codegen.buildCodeBlock
32 import androidx.room.compiler.codegen.compat.XConverters.applyToJavaPoet
33 import androidx.room.ext.RoomGuavaTypeNames.GUAVA_ROOM
34 import androidx.room.solver.CodeGenScope
35 import com.squareup.kotlinpoet.javapoet.JTypeName
36 import java.util.concurrent.Callable
37 
/** Class names from the `db` sub-package of [SQLITE_PACKAGE] (the `SupportSQLite*` APIs). */
object SupportDbTypeNames {
    private val DB_PACKAGE = "$SQLITE_PACKAGE.db"

    val DB: XClassName = XClassName.get(DB_PACKAGE, "SupportSQLiteDatabase")
    val SQLITE_STMT: XClassName = XClassName.get(DB_PACKAGE, "SupportSQLiteStatement")
    val SQLITE_OPEN_HELPER: XClassName = XClassName.get(DB_PACKAGE, "SupportSQLiteOpenHelper")
    val SQLITE_OPEN_HELPER_CALLBACK: XClassName =
        XClassName.get(DB_PACKAGE, "SupportSQLiteOpenHelper", "Callback")
    val SQLITE_OPEN_HELPER_CONFIG: XClassName =
        XClassName.get(DB_PACKAGE, "SupportSQLiteOpenHelper", "Configuration")
    val QUERY: XClassName = XClassName.get(DB_PACKAGE, "SupportSQLiteQuery")
}
48 
/** Class names of the SQLite driver APIs declared directly in [SQLITE_PACKAGE]. */
object SQLiteDriverTypeNames {
    val SQLITE: XClassName = XClassName.get(SQLITE_PACKAGE, "SQLite")
    val DRIVER: XClassName = XClassName.get(SQLITE_PACKAGE, "SQLiteDriver")
    val CONNECTION: XClassName = XClassName.get(SQLITE_PACKAGE, "SQLiteConnection")
    val STATEMENT: XClassName = XClassName.get(SQLITE_PACKAGE, "SQLiteStatement")
}
55 
/** Class names of Room runtime types under [ROOM_PACKAGE] and its sub-packages. */
object RoomTypeNames {
    // Sub-packages of ROOM_PACKAGE; declared first so they are initialized before use below.
    private val UTIL_PACKAGE = "$ROOM_PACKAGE.util"
    private val MIGRATION_PACKAGE = "$ROOM_PACKAGE.migration"
    private val PAGING_PACKAGE = "$ROOM_PACKAGE.paging"
    private val COROUTINES_PACKAGE = "$ROOM_PACKAGE.coroutines"

    val STRING_UTIL: XClassName = XClassName.get(UTIL_PACKAGE, "StringUtil")
    val ROOM_DB: XClassName = XClassName.get(ROOM_PACKAGE, "RoomDatabase")
    val ROOM_DB_KT: XClassName = XClassName.get(ROOM_PACKAGE, "RoomDatabaseKt")
    val ROOM_DB_CALLBACK: XClassName = XClassName.get(ROOM_PACKAGE, "RoomDatabase", "Callback")
    val ROOM_DB_CONFIG: XClassName = XClassName.get(ROOM_PACKAGE, "DatabaseConfiguration")
    val INSERT_ADAPTER: XClassName = XClassName.get(ROOM_PACKAGE, "EntityInsertAdapter")
    val UPSERT_ADAPTER: XClassName = XClassName.get(ROOM_PACKAGE, "EntityUpsertAdapter")
    val DELETE_OR_UPDATE_ADAPTER: XClassName =
        XClassName.get(ROOM_PACKAGE, "EntityDeleteOrUpdateAdapter")
    val INSERT_ADAPTER_COMPAT: XClassName = XClassName.get(ROOM_PACKAGE, "EntityInsertionAdapter")
    val UPSERT_ADAPTER_COMPAT: XClassName = XClassName.get(ROOM_PACKAGE, "EntityUpsertionAdapter")
    val DELETE_OR_UPDATE_ADAPTER_COMPAT: XClassName =
        XClassName.get(ROOM_PACKAGE, "EntityDeletionOrUpdateAdapter")
    val SHARED_SQLITE_STMT: XClassName = XClassName.get(ROOM_PACKAGE, "SharedSQLiteStatement")
    val INVALIDATION_TRACKER: XClassName = XClassName.get(ROOM_PACKAGE, "InvalidationTracker")
    val ROOM_SQL_QUERY: XClassName = XClassName.get(ROOM_PACKAGE, "RoomSQLiteQuery")
    val TABLE_INFO: XClassName = XClassName.get(UTIL_PACKAGE, "TableInfo")
    val TABLE_INFO_COLUMN: XClassName = XClassName.get(UTIL_PACKAGE, "TableInfo", "Column")
    val TABLE_INFO_FOREIGN_KEY: XClassName = XClassName.get(UTIL_PACKAGE, "TableInfo", "ForeignKey")
    val TABLE_INFO_INDEX: XClassName = XClassName.get(UTIL_PACKAGE, "TableInfo", "Index")
    val FTS_TABLE_INFO: XClassName = XClassName.get(UTIL_PACKAGE, "FtsTableInfo")
    val VIEW_INFO: XClassName = XClassName.get(UTIL_PACKAGE, "ViewInfo")
    val LIMIT_OFFSET_DATA_SOURCE: XClassName = XClassName.get(PAGING_PACKAGE, "LimitOffsetDataSource")
    val DB_UTIL: XClassName = XClassName.get(UTIL_PACKAGE, "DBUtil")
    val CURSOR_UTIL: XClassName = XClassName.get(UTIL_PACKAGE, "CursorUtil")
    val MIGRATION: XClassName = XClassName.get(MIGRATION_PACKAGE, "Migration")
    val AUTO_MIGRATION_SPEC: XClassName = XClassName.get(MIGRATION_PACKAGE, "AutoMigrationSpec")
    val UUID_UTIL: XClassName = XClassName.get(UTIL_PACKAGE, "UUIDUtil")
    val AMBIGUOUS_COLUMN_RESOLVER: XClassName =
        XClassName.get(ROOM_PACKAGE, "AmbiguousColumnResolver")
    // Derived from ROOM_PACKAGE like the other `util` classes instead of hard-coding the package.
    val RELATION_UTIL: XClassName = XClassName.get(UTIL_PACKAGE, "RelationUtil")
    val ROOM_OPEN_DELEGATE: XClassName = XClassName.get(ROOM_PACKAGE, "RoomOpenDelegate")
    val ROOM_OPEN_DELEGATE_VALIDATION_RESULT: XClassName =
        XClassName.get(ROOM_PACKAGE, "RoomOpenDelegate", "ValidationResult")
    val STATEMENT_UTIL: XClassName = XClassName.get(UTIL_PACKAGE, "SQLiteStatementUtil")
    val CONNECTION_UTIL: XClassName = XClassName.get(UTIL_PACKAGE, "SQLiteConnectionUtil")
    val FLOW_UTIL: XClassName = XClassName.get(COROUTINES_PACKAGE, "FlowUtil")
    val RAW_QUERY: XClassName = XClassName.get(ROOM_PACKAGE, "RoomRawQuery")
    val ROOM_DB_CONSTRUCTOR: XClassName = XClassName.get(ROOM_PACKAGE, "RoomDatabaseConstructor")
    val BYTE_ARRAY_WRAPPER: XClassName = XClassName.get(UTIL_PACKAGE, "ByteArrayWrapper")
}
96 
/** Class names of Room annotations declared in [ROOM_PACKAGE]. */
object RoomAnnotationTypeNames {
    val QUERY: XClassName = XClassName.get(ROOM_PACKAGE, "Query")
    val DAO: XClassName = XClassName.get(ROOM_PACKAGE, "Dao")
    val DATABASE: XClassName = XClassName.get(ROOM_PACKAGE, "Database")
    val PRIMARY_KEY: XClassName = XClassName.get(ROOM_PACKAGE, "PrimaryKey")
    val TYPE_CONVERTERS: XClassName = XClassName.get(ROOM_PACKAGE, "TypeConverters")
    val TYPE_CONVERTER: XClassName = XClassName.get(ROOM_PACKAGE, "TypeConverter")
    val ENTITY: XClassName = XClassName.get(ROOM_PACKAGE, "Entity")
}
106 
/** Class names of Paging library types under [PAGING_PACKAGE] and its Rx sub-packages. */
object PagingTypeNames {
    private val RX2_PACKAGE = "$PAGING_PACKAGE.rxjava2"
    private val RX3_PACKAGE = "$PAGING_PACKAGE.rxjava3"

    val DATA_SOURCE: XClassName = XClassName.get(PAGING_PACKAGE, "DataSource")
    val POSITIONAL_DATA_SOURCE: XClassName = XClassName.get(PAGING_PACKAGE, "PositionalDataSource")
    val DATA_SOURCE_FACTORY: XClassName = XClassName.get(PAGING_PACKAGE, "DataSource", "Factory")
    val PAGING_SOURCE: XClassName = XClassName.get(PAGING_PACKAGE, "PagingSource")
    val LISTENABLE_FUTURE_PAGING_SOURCE: XClassName =
        XClassName.get(PAGING_PACKAGE, "ListenableFuturePagingSource")
    val RX2_PAGING_SOURCE: XClassName = XClassName.get(RX2_PACKAGE, "RxPagingSource")
    val RX3_PAGING_SOURCE: XClassName = XClassName.get(RX3_PACKAGE, "RxPagingSource")
}
117 
/** Class names of Lifecycle types under [LIFECYCLE_PACKAGE]. */
object LifecyclesTypeNames {
    val LIVE_DATA: XClassName = XClassName.get(LIFECYCLE_PACKAGE, "LiveData")
    val COMPUTABLE_LIVE_DATA: XClassName = XClassName.get(LIFECYCLE_PACKAGE, "ComputableLiveData")
}
122 
/** Class names of Android framework types from `android.database` and `android.os`. */
object AndroidTypeNames {
    private const val OS_PACKAGE = "android.os"

    val CURSOR: XClassName = XClassName.get("android.database", "Cursor")
    val BUILD: XClassName = XClassName.get(OS_PACKAGE, "Build")
    val CANCELLATION_SIGNAL: XClassName = XClassName.get(OS_PACKAGE, "CancellationSignal")
}
128 
/** Class names of collection types under [COLLECTION_PACKAGE]. */
object CollectionTypeNames {
    val ARRAY_MAP: XClassName = XClassName.get(COLLECTION_PACKAGE, "ArrayMap")
    val LONG_SPARSE_ARRAY: XClassName = XClassName.get(COLLECTION_PACKAGE, "LongSparseArray")
    // Int-keyed sparse array; the class is named `SparseArrayCompat`.
    val INT_SPARSE_ARRAY: XClassName = XClassName.get(COLLECTION_PACKAGE, "SparseArrayCompat")
}
134 
/** Member names of Kotlin standard library collection factory functions. */
object KotlinCollectionMemberNames {
    val ARRAY_OF_NULLS: XMemberName =
        XClassName.get("kotlin", "LibraryKt").packageMember("arrayOfNulls")
    val MUTABLE_LIST_OF: XMemberName =
        KotlinTypeNames.COLLECTIONS_KT.packageMember("mutableListOf")
    val MUTABLE_SET_OF: XMemberName = KotlinTypeNames.SETS_KT.packageMember("mutableSetOf")
    val MUTABLE_MAP_OF: XMemberName = KotlinTypeNames.MAPS_KT.packageMember("mutableMapOf")
}
141 
/** Class names of common JDK and Kotlin standard library types. */
object CommonTypeNames {
    private const val JAVA_UTIL = "java.util"
    private const val JAVA_LANG = "java.lang"

    val VOID: XClassName = Void::class.asClassName()
    val COLLECTION: XClassName = Collection::class.asClassName()
    val COLLECTIONS: XClassName = XClassName.get(JAVA_UTIL, "Collections")
    val ARRAYS: XClassName = XClassName.get(JAVA_UTIL, "Arrays")
    val LIST: XClassName = List::class.asClassName()
    val MUTABLE_LIST: XClassName = List::class.asMutableClassName()
    val ARRAY_LIST: XClassName = XClassName.get(JAVA_UTIL, "ArrayList")
    val MAP: XClassName = Map::class.asClassName()
    val MUTABLE_MAP: XClassName = Map::class.asMutableClassName()
    val HASH_MAP: XClassName = XClassName.get(JAVA_UTIL, "HashMap")
    val QUEUE: XClassName = XClassName.get(JAVA_UTIL, "Queue")
    val LINKED_HASH_MAP: XClassName = LinkedHashMap::class.asClassName()
    val SET: XClassName = Set::class.asClassName()
    val MUTABLE_SET: XClassName = Set::class.asMutableClassName()
    val HASH_SET: XClassName = XClassName.get(JAVA_UTIL, "HashSet")
    val STRING: XClassName = String::class.asClassName()
    val STRING_BUILDER: XClassName = XClassName.get(JAVA_LANG, "StringBuilder")
    val OPTIONAL: XClassName = XClassName.get(JAVA_UTIL, "Optional")
    val UUID: XClassName = XClassName.get(JAVA_UTIL, "UUID")
    val BYTE_BUFFER: XClassName = XClassName.get("java.nio", "ByteBuffer")
    val JAVA_CLASS: XClassName = XClassName.get(JAVA_LANG, "Class")
    val KOTLIN_CLASS: XClassName = XClassName.get("kotlin.reflect", "KClass")
    val CALLABLE: XClassName = Callable::class.asClassName()
    val DATE: XClassName = XClassName.get(JAVA_UTIL, "Date")
}
168 
/** Class names of exception types, in both their `java.lang` and `kotlin` spellings. */
object ExceptionTypeNames {
    private const val JAVA_LANG = "java.lang"
    private const val KOTLIN = "kotlin"

    val JAVA_ILLEGAL_STATE_EXCEPTION: XClassName =
        XClassName.get(JAVA_LANG, "IllegalStateException")
    val JAVA_ILLEGAL_ARG_EXCEPTION: XClassName =
        XClassName.get(JAVA_LANG, "IllegalArgumentException")
    val KOTLIN_ILLEGAL_STATE_EXCEPTION: XClassName =
        XClassName.get(KOTLIN, "IllegalStateException")
    val KOTLIN_ILLEGAL_ARG_EXCEPTION: XClassName =
        XClassName.get(KOTLIN, "IllegalArgumentException")
}
175 
/** Class names of Guava `base` and `collect` types. */
object GuavaTypeNames {
    private const val COLLECT_PACKAGE = "com.google.common.collect"

    val OPTIONAL: XClassName = XClassName.get("com.google.common.base", "Optional")
    val IMMUTABLE_MULTIMAP_BUILDER: XClassName =
        XClassName.get(COLLECT_PACKAGE, "ImmutableMultimap", "Builder")
    val IMMUTABLE_SET_MULTIMAP: XClassName = XClassName.get(COLLECT_PACKAGE, "ImmutableSetMultimap")
    val IMMUTABLE_SET_MULTIMAP_BUILDER: XClassName =
        XClassName.get(COLLECT_PACKAGE, "ImmutableSetMultimap", "Builder")
    val IMMUTABLE_LIST_MULTIMAP: XClassName =
        XClassName.get(COLLECT_PACKAGE, "ImmutableListMultimap")
    val IMMUTABLE_LIST_MULTIMAP_BUILDER: XClassName =
        XClassName.get(COLLECT_PACKAGE, "ImmutableListMultimap", "Builder")
    val IMMUTABLE_MAP: XClassName = XClassName.get(COLLECT_PACKAGE, "ImmutableMap")
    val IMMUTABLE_LIST: XClassName = XClassName.get(COLLECT_PACKAGE, "ImmutableList")
    val IMMUTABLE_LIST_BUILDER: XClassName =
        XClassName.get(COLLECT_PACKAGE, "ImmutableList", "Builder")
}
192 
/** Class names of Guava `util.concurrent` types. */
object GuavaUtilConcurrentTypeNames {
    val LISTENABLE_FUTURE: XClassName =
        XClassName.get("com.google.common.util.concurrent", "ListenableFuture")
}
196 
/** Class names of RxJava 2 reactive types (package `io.reactivex`). */
object RxJava2TypeNames {
    private const val RX2_PACKAGE = "io.reactivex"

    val FLOWABLE: XClassName = XClassName.get(RX2_PACKAGE, "Flowable")
    val OBSERVABLE: XClassName = XClassName.get(RX2_PACKAGE, "Observable")
    val MAYBE: XClassName = XClassName.get(RX2_PACKAGE, "Maybe")
    val SINGLE: XClassName = XClassName.get(RX2_PACKAGE, "Single")
    val COMPLETABLE: XClassName = XClassName.get(RX2_PACKAGE, "Completable")
}
204 
/** Class names of RxJava 3 reactive types (package `io.reactivex.rxjava3.core`). */
object RxJava3TypeNames {
    private const val RX3_CORE_PACKAGE = "io.reactivex.rxjava3.core"

    val FLOWABLE: XClassName = XClassName.get(RX3_CORE_PACKAGE, "Flowable")
    val OBSERVABLE: XClassName = XClassName.get(RX3_CORE_PACKAGE, "Observable")
    val MAYBE: XClassName = XClassName.get(RX3_CORE_PACKAGE, "Maybe")
    val SINGLE: XClassName = XClassName.get(RX3_CORE_PACKAGE, "Single")
    val COMPLETABLE: XClassName = XClassName.get(RX3_CORE_PACKAGE, "Completable")
}
212 
/** Class names of Reactive Streams types. */
object ReactiveStreamsTypeNames {
    val PUBLISHER: XClassName = XClassName.get("org.reactivestreams", "Publisher")
}
216 
/** Class names from Room's `guava` sub-package. */
object RoomGuavaTypeNames {
    private val GUAVA_PACKAGE = "$ROOM_PACKAGE.guava"

    val GUAVA_ROOM: XClassName = XClassName.get(GUAVA_PACKAGE, "GuavaRoom")
    val GUAVA_ROOM_MARKER: XClassName = XClassName.get(GUAVA_PACKAGE, "GuavaRoomArtifactMarker")
}
221 
/** Member names of functions declared alongside `GuavaRoom`. */
object RoomGuavaMemberNames {
    val GUAVA_ROOM_CREATE_LISTENABLE_FUTURE: XMemberName =
        GUAVA_ROOM.packageMember("createListenableFuture")
}
225 
/** Class names of Room's RxJava 2 integration, declared directly in [ROOM_PACKAGE]. */
object RoomRxJava2TypeNames {
    val RX2_ROOM: XClassName = XClassName.get(ROOM_PACKAGE, "RxRoom")
    val RX2_EMPTY_RESULT_SET_EXCEPTION: XClassName =
        XClassName.get(ROOM_PACKAGE, "EmptyResultSetException")
}
230 
/** `@JvmStatic` companion functions of RxJava 2 `RxRoom` used to create reactive types. */
object RoomRxJava2MemberNames {
    val RX_ROOM_CREATE_FLOWABLE: XMemberName = rxRoomStatic("createFlowable")
    val RX_ROOM_CREATE_OBSERVABLE: XMemberName = rxRoomStatic("createObservable")
    val RX_ROOM_CREATE_SINGLE: XMemberName = rxRoomStatic("createSingle")
    val RX_ROOM_CREATE_MAYBE: XMemberName = rxRoomStatic("createMaybe")
    val RX_ROOM_CREATE_COMPLETABLE: XMemberName = rxRoomStatic("createCompletable")

    /** Builds the member name of a `@JvmStatic` companion function [name] of `RxRoom`. */
    private fun rxRoomStatic(name: String): XMemberName =
        RoomRxJava2TypeNames.RX2_ROOM.companionMember(name, isJvmStatic = true)
}
243 
/** Class names from Room's `rxjava3` sub-package. */
object RoomRxJava3TypeNames {
    private val RX3_PACKAGE = "$ROOM_PACKAGE.rxjava3"

    val RX3_ROOM: XClassName = XClassName.get(RX3_PACKAGE, "RxRoom")
    val RX3_ROOM_MARKER: XClassName = XClassName.get(RX3_PACKAGE, "Rx3RoomArtifactMarker")
    val RX3_EMPTY_RESULT_SET_EXCEPTION: XClassName =
        XClassName.get(RX3_PACKAGE, "EmptyResultSetException")
}
250 
/** Package-level functions declared alongside RxJava 3 `RxRoom` used to create reactive types. */
object RoomRxJava3MemberNames {
    val RX_ROOM_CREATE_FLOWABLE: XMemberName = rxRoomFunction("createFlowable")
    val RX_ROOM_CREATE_OBSERVABLE: XMemberName = rxRoomFunction("createObservable")
    val RX_ROOM_CREATE_SINGLE: XMemberName = rxRoomFunction("createSingle")
    val RX_ROOM_CREATE_MAYBE: XMemberName = rxRoomFunction("createMaybe")
    val RX_ROOM_CREATE_COMPLETABLE: XMemberName = rxRoomFunction("createCompletable")

    /** Builds the member name of package-level function [name] next to `RxRoom`. */
    private fun rxRoomFunction(name: String): XMemberName =
        RoomRxJava3TypeNames.RX3_ROOM.packageMember(name)
}
259 
/** Class names from Room's `paging` sub-package. */
object RoomPagingTypeNames {
    val LIMIT_OFFSET_PAGING_SOURCE: XClassName = XClassName.get(
        "$ROOM_PACKAGE.paging",
        "LimitOffsetPagingSource",
    )
}
264 
/** Class names from Room's `paging.guava` sub-package. */
object RoomPagingGuavaTypeNames {
    val LIMIT_OFFSET_LISTENABLE_FUTURE_PAGING_SOURCE: XClassName = XClassName.get(
        "$ROOM_PACKAGE.paging.guava",
        "LimitOffsetListenableFuturePagingSource",
    )
}
269 
/** Class names from Room's `paging.rxjava2` sub-package. */
object RoomPagingRx2TypeNames {
    val LIMIT_OFFSET_RX_PAGING_SOURCE: XClassName = XClassName.get(
        "$ROOM_PACKAGE.paging.rxjava2",
        "LimitOffsetRxPagingSource",
    )
}
274 
/** Class names from Room's `paging.rxjava3` sub-package. */
object RoomPagingRx3TypeNames {
    val LIMIT_OFFSET_RX_PAGING_SOURCE: XClassName = XClassName.get(
        "$ROOM_PACKAGE.paging.rxjava3",
        "LimitOffsetRxPagingSource",
    )
}
279 
/** Class names of Room's coroutine integration, declared directly in [ROOM_PACKAGE]. */
object RoomCoroutinesTypeNames {
    val COROUTINES_ROOM: XClassName = XClassName.get(ROOM_PACKAGE, "CoroutinesRoom")
}
283 
/** Class names of Kotlin standard library and kotlinx.coroutines types. */
object KotlinTypeNames {
    private const val KOTLIN = "kotlin"
    private const val KOTLIN_COLLECTIONS = "kotlin.collections"
    private const val COROUTINES_CHANNELS = "kotlinx.coroutines.channels"

    val ANY: XClassName = Any::class.asClassName()
    val UNIT: XClassName = XClassName.get(KOTLIN, "Unit")
    val CONTINUATION: XClassName = XClassName.get("kotlin.coroutines", "Continuation")
    val CHANNEL: XClassName = XClassName.get(COROUTINES_CHANNELS, "Channel")
    val RECEIVE_CHANNEL: XClassName = XClassName.get(COROUTINES_CHANNELS, "ReceiveChannel")
    val SEND_CHANNEL: XClassName = XClassName.get(COROUTINES_CHANNELS, "SendChannel")
    val FLOW: XClassName = XClassName.get("kotlinx.coroutines.flow", "Flow")
    val LAZY: XClassName = XClassName.get(KOTLIN, "Lazy")
    // File facade classes, used as owners of package-level collection functions.
    val COLLECTIONS_KT: XClassName = XClassName.get(KOTLIN_COLLECTIONS, "CollectionsKt")
    val SETS_KT: XClassName = XClassName.get(KOTLIN_COLLECTIONS, "SetsKt")
    val MAPS_KT: XClassName = XClassName.get(KOTLIN_COLLECTIONS, "MapsKt")
    val STRING_BUILDER: XClassName = XClassName.get("kotlin.text", "StringBuilder")
    val LINKED_HASH_MAP: XClassName = XClassName.get(KOTLIN_COLLECTIONS, "LinkedHashMap")
}
299 
/** Member names of Room runtime functions referenced from generated code. */
object RoomMemberNames {
    val DB_UTIL_QUERY: XMemberName = dbUtil("query")
    val DB_UTIL_FOREIGN_KEY_CHECK: XMemberName = dbUtil("foreignKeyCheck")
    val DB_UTIL_DROP_FTS_SYNC_TRIGGERS: XMemberName = dbUtil("dropFtsSyncTriggers")
    val DB_UTIL_PERFORM_SUSPENDING: XMemberName = dbUtil("performSuspending")
    val DB_UTIL_PERFORM_BLOCKING: XMemberName = dbUtil("performBlocking")
    val DB_UTIL_SUPPORT_DB_TO_CONNECTION: XMemberName = dbUtil("toSQLiteConnection")
    val DB_UTIL_PERFORM_IN_TRANSACTION_SUSPENDING: XMemberName =
        dbUtil("performInTransactionSuspending")
    val CURSOR_UTIL_GET_COLUMN_INDEX: XMemberName = cursorUtil("getColumnIndex")
    val CURSOR_UTIL_GET_COLUMN_INDEX_OR_THROW: XMemberName = cursorUtil("getColumnIndexOrThrow")
    val CURSOR_UTIL_WRAP_MAPPED_COLUMNS: XMemberName = cursorUtil("wrapMappedColumns")
    val ROOM_SQL_QUERY_ACQUIRE: XMemberName =
        RoomTypeNames.ROOM_SQL_QUERY.companionMember("acquire", isJvmStatic = true)
    val ROOM_DATABASE_WITH_TRANSACTION: XMemberName =
        RoomTypeNames.ROOM_DB_KT.packageMember("withTransaction")
    val TABLE_INFO_READ: XMemberName = staticRead(RoomTypeNames.TABLE_INFO)
    val FTS_TABLE_INFO_READ: XMemberName = staticRead(RoomTypeNames.FTS_TABLE_INFO)
    val VIEW_INFO_READ: XMemberName = staticRead(RoomTypeNames.VIEW_INFO)

    /** Package-level function [name] declared alongside `DBUtil`. */
    private fun dbUtil(name: String): XMemberName = RoomTypeNames.DB_UTIL.packageMember(name)

    /** Package-level function [name] declared alongside `CursorUtil`. */
    private fun cursorUtil(name: String): XMemberName =
        RoomTypeNames.CURSOR_UTIL.packageMember(name)

    /** The `@JvmStatic` companion `read` function of [owner]. */
    private fun staticRead(owner: XClassName): XMemberName =
        owner.companionMember("read", isJvmStatic = true)
}
322 
/** Member names of package-level functions declared alongside [SQLiteDriverTypeNames.SQLITE]. */
object SQLiteDriverMemberNames {
    val CONNECTION_EXEC_SQL: XMemberName =
        SQLiteDriverTypeNames.SQLITE.packageMember("execSQL")
}
326 
/** Observable / asynchronous wrapper types, listed in a fixed order. */
val DEFERRED_TYPES: List<XClassName> =
    listOf(
        // Lifecycle
        LifecyclesTypeNames.LIVE_DATA,
        LifecyclesTypeNames.COMPUTABLE_LIVE_DATA,
        // RxJava 2
        RxJava2TypeNames.FLOWABLE,
        RxJava2TypeNames.OBSERVABLE,
        RxJava2TypeNames.MAYBE,
        RxJava2TypeNames.SINGLE,
        RxJava2TypeNames.COMPLETABLE,
        // RxJava 3
        RxJava3TypeNames.FLOWABLE,
        RxJava3TypeNames.OBSERVABLE,
        RxJava3TypeNames.MAYBE,
        RxJava3TypeNames.SINGLE,
        RxJava3TypeNames.COMPLETABLE,
        // Guava, coroutines, Reactive Streams and Paging
        GuavaUtilConcurrentTypeNames.LISTENABLE_FUTURE,
        KotlinTypeNames.FLOW,
        ReactiveStreamsTypeNames.PUBLISHER,
        PagingTypeNames.PAGING_SOURCE,
    )
346 
/**
 * Returns the source literal for the default value of a variable of this type.
 *
 * The mapping is:
 * - `"null"` for any non-primitive type.
 * - `"false"` for `boolean`.
 * - `"0.0"` for `double`.
 * - `"0f"` for `float`.
 * - `"0"` for every other primitive.
 *
 * NOTE(review): `"0"` is accepted as a `char` initializer in Java but not as a Kotlin `Char`
 * initializer. Confirm callers never need this for `PRIMITIVE_CHAR` in Kotlin output.
 */
fun XTypeName.defaultValue(): String =
    when {
        !isPrimitive -> "null"
        this == XTypeName.PRIMITIVE_BOOLEAN -> "false"
        this == XTypeName.PRIMITIVE_DOUBLE -> "0.0"
        this == XTypeName.PRIMITIVE_FLOAT -> "0f"
        else -> "0"
    }
360 
/**
 * Creates a builder for an anonymous class implementing `Callable<parameterTypeName>`.
 *
 * The class overrides a public `call()` method that returns [parameterTypeName]. Its body is
 * filled in by [callBody]. In Java output, `call()` also declares `throws Exception`.
 */
fun CallableTypeSpecBuilder(parameterTypeName: XTypeName, callBody: XFunSpec.Builder.() -> Unit) =
    XTypeSpec.anonymousClassBuilder("").apply {
        addSuperinterface(CommonTypeNames.CALLABLE.parametrizedBy(parameterTypeName))
        val callFunBuilder =
            XFunSpec.builder(
                name = "call",
                visibility = VisibilityModifier.PUBLIC,
                isOverride = true
            )
        callFunBuilder.returns(parameterTypeName)
        callFunBuilder.callBody()
        // Only the Java flavor gets a `throws` clause; Kotlin has no checked exceptions.
        val callFun =
            callFunBuilder
                .applyToJavaPoet { addException(JTypeName.get(Exception::class.java)) }
                .build()
        addFunction(callFun)
    }
378 
/**
 * Builds an anonymous class whose supertype is `Function1<parameterTypeName, returnTypeName>`.
 *
 * The class overrides a public `invoke(parameterName: parameterTypeName)` method returning
 * [returnTypeName]. Its body is filled in by [callBody].
 */
fun Function1TypeSpec(
    parameterTypeName: XTypeName,
    parameterName: String,
    returnTypeName: XTypeName,
    callBody: XFunSpec.Builder.() -> Unit
) =
    XTypeSpec.anonymousClassBuilder("").run {
        val functionType =
            Function1::class.asClassName().parametrizedBy(parameterTypeName, returnTypeName)
        superclass(functionType)
        val invokeFunBuilder =
            XFunSpec.builder(
                name = "invoke",
                visibility = VisibilityModifier.PUBLIC,
                isOverride = true
            )
        invokeFunBuilder.addParameter(parameterName, parameterTypeName)
        invokeFunBuilder.returns(returnTypeName)
        invokeFunBuilder.callBody()
        addFunction(invokeFunBuilder.build())
        build()
    }
405 
/**
 * Short-hand of [InvokeWithLambdaParameter] for when the function to invoke is identified by an
 * [XMemberName] rather than an arbitrary call expression. Examples are a top-level function or a
 * companion object function.
 *
 * The member is emitted through the `%M` placeholder and then delegated to the [XCodeBlock]
 * overload. All other parameters are forwarded unchanged.
 */
fun InvokeWithLambdaParameter(
    scope: CodeGenScope,
    functionName: XMemberName,
    argFormat: List<String>,
    args: List<Any>,
    continuationParamName: String? = null,
    lambdaSpec: LambdaSpec
): XCodeBlock =
    InvokeWithLambdaParameter(
        scope = scope,
        functionCall = XCodeBlock.of("%M", functionName),
        argFormat = argFormat,
        args = args,
        continuationParamName = continuationParamName,
        lambdaSpec = lambdaSpec
    )
428 
/**
 * Generates a code block that invokes a function whose last parameter has a functional type. The
 * lambda described by [lambdaSpec] is passed as that argument.
 *
 * For Java, when [LambdaSpec.javaLambdaSyntaxAvailable] is `true`, it will generate:
 * ```
 * <functionCall>(<args>, (<lambdaSpec.parameterName>) -> { <lambdaSpec.body> }[, <continuationParamName>]);
 * ```
 *
 * For Java, when [LambdaSpec.javaLambdaSyntaxAvailable] is `false`, the lambda is generated as an
 * anonymous class via [Function1TypeSpec]:
 * ```
 * <functionCall>(<args>, new Function1<>() { <lambdaSpec.body> }[, <continuationParamName>]);
 * ```
 *
 * For Kotlin it will generate a trailing lambda. The parameter name is omitted when
 * [LambdaSpec.parameterTypeName] is a `Continuation`, and [continuationParamName] is not used:
 * ```
 * <functionCall>(<args>) { <lambdaSpec.parameterName> -> <lambdaSpec.body> }
 * ```
 *
 * When [argFormat] is empty, no leading separator is emitted before the lambda.
 *
 * The [functionCall] must only be an expression up to a function name, without the parentheses.
 * The [argFormat] and [args] describe the arguments of the function that precede the functional
 * parameter. The format strings are joined with `", "`, and each is presumably expected to consume
 * one value from [args]; the two lists are required to have the same size.
 *
 * The ideal usage of this utility function is to generate code that invokes the various
 * `DBUtil.perform*()` APIs for interacting with the database connection in DAOs.
 *
 * @throws IllegalStateException if [argFormat] and [args] have different sizes.
 */
fun InvokeWithLambdaParameter(
    scope: CodeGenScope,
    functionCall: XCodeBlock,
    argFormat: List<String>,
    args: List<Any>,
    continuationParamName: String? = null,
    lambdaSpec: LambdaSpec
) = buildCodeBlock { language ->
    check(argFormat.size == args.size) {
        "argFormat and args must have the same size, found ${argFormat.size} and ${args.size}."
    }

    // Forks [scope], renders the lambda body into the fork and returns the generated code.
    fun generateLambdaBody() =
        scope.fork().let { bodyScope ->
            with(lambdaSpec) { bodyScope.builder.body(bodyScope) }
            bodyScope.generate()
        }

    when (language) {
        CodeLanguage.JAVA -> {
            if (lambdaSpec.javaLambdaSyntaxAvailable) {
                // Emit "<args>, " only when there are leading arguments. Otherwise the call would
                // start with a dangling comma, e.g. `f(, (x) -> {...})`.
                val argsPrefix =
                    if (argFormat.isEmpty()) {
                        ""
                    } else {
                        argFormat.joinToString(separator = ", ", postfix = ", ")
                    }
                add(
                    "%L($argsPrefix(%L) -> {\n",
                    functionCall,
                    *args.toTypedArray(),
                    lambdaSpec.parameterName
                )
                indent()
                add(generateLambdaBody())
                unindent()
                add("}")
                if (continuationParamName != null) {
                    add(", %L", continuationParamName)
                }
                add(");\n")
            } else {
                val adjustedArgsFormatString =
                    buildList {
                            addAll(argFormat)
                            add("%L") // the anonymous function
                            if (continuationParamName != null) {
                                add("%L")
                            }
                        }
                        .joinToString(separator = ", ")
                val adjustedArgs = buildList {
                    addAll(args)
                    add(
                        Function1TypeSpec(
                            parameterTypeName = lambdaSpec.parameterTypeName,
                            parameterName = lambdaSpec.parameterName,
                            returnTypeName = lambdaSpec.returnTypeName,
                            callBody = { addCode(generateLambdaBody()) }
                        )
                    )
                    if (continuationParamName != null) {
                        add(continuationParamName)
                    }
                }
                add(
                    "%L($adjustedArgsFormatString);\n",
                    functionCall,
                    *adjustedArgs.toTypedArray(),
                )
            }
        }
        CodeLanguage.KOTLIN -> {
            val argsFormatString = argFormat.joinToString(separator = ", ")
            if (lambdaSpec.parameterTypeName.rawTypeName != KotlinTypeNames.CONTINUATION) {
                add(
                    "%L($argsFormatString) { %L ->\n",
                    functionCall,
                    *args.toTypedArray(),
                    lambdaSpec.parameterName
                )
            } else {
                // A Continuation-typed parameter is not named in the generated Kotlin lambda.
                add(
                    "%L($argsFormatString) {\n",
                    functionCall,
                    *args.toTypedArray(),
                )
            }
            indent()
            add(generateLambdaBody())
            unindent()
            add("}\n")
        }
    }
}
543 
/**
 * Describes the lambda generated as the last argument by [InvokeWithLambdaParameter].
 *
 * @property parameterTypeName type of the lambda's single parameter.
 * @property parameterName name given to the lambda's single parameter in generated code.
 * @property returnTypeName return type of the lambda; used when a Java anonymous class is emitted.
 * @property javaLambdaSyntaxAvailable whether Java lambda syntax (`->`) may be emitted. When `false`,
 *   an anonymous `Function1` class is generated instead.
 */
abstract class LambdaSpec(
    val parameterTypeName: XTypeName,
    val parameterName: String,
    val returnTypeName: XTypeName,
    val javaLambdaSyntaxAvailable: Boolean,
) {
    /** Adds the statements of the lambda body to this builder, generating code within [scope]. */
    abstract fun XCodeBlock.Builder.body(scope: CodeGenScope)
}
553 
/**
 * Generates an array literal containing the given [values].
 *
 * Example: `ArrayLiteral(XTypeName.PRIMITIVE_INT, 1, 2, 3)`
 *
 * For Java will produce: `new int[] {1, 2, 3}`
 *
 * For Kotlin will produce: `intArrayOf(1, 2, 3)`
 *
 * When [type] is [CommonTypeNames.STRING], each value is emitted as a string literal (`%S`).
 * Otherwise it is emitted verbatim (`%L`). In Java output, elements are separated by `,` followed
 * by a wrapping space (`%W`).
 */
fun ArrayLiteral(type: XTypeName, vararg values: Any) = buildCodeBlock { language ->
    val (separator, open, close) =
        when (language) {
            CodeLanguage.JAVA -> Triple(",%W", "{", "}")
            CodeLanguage.KOTLIN -> Triple(", ", "(", ")")
        }
    val initExpr =
        when (language) {
            CodeLanguage.JAVA -> XCodeBlock.of("new %T[] ", type)
            CodeLanguage.KOTLIN -> XCodeBlock.of(getArrayOfFunction(type))
        }
    val valueFormat = if (type == CommonTypeNames.STRING) "%S" else "%L"
    val elements = values.map { value -> XCodeBlock.of(valueFormat, value) }.toTypedArray()
    val elementsCode =
        XCodeBlock.builder()
            .apply { add(elements.joinToString(separator = separator) { "%L" }, *elements) }
            .build()
    add("%L$open%L$close", initExpr, elementsCode)
}
599 
/**
 * Generates a 2D array literal with [rowSize] rows.
 *
 * Row `i` has `columnSizeProducer(i)` elements, and the value at `i`,`j` is produced by
 * `valueProducer(i, j)`. For example:
 * ```
 * DoubleArrayLiteral(XTypeName.PRIMITIVE_INT, 2, { _ -> 3 }, { i, j -> i + j })
 * ```
 *
 * For Java will produce:
 * ```
 * new int[][] {
 *   {0, 1, 2},
 *   {1, 2, 3}
 * }
 * ```
 *
 * For Kotlin will produce:
 * ```
 * arrayOf(
 *   intArrayOf(0, 1, 2),
 *   intArrayOf(1, 2, 3)
 * )
 * ```
 *
 * The line breaks above are illustrative. Elements are separated by `", "` in Kotlin, and by `","`
 * followed by a wrapping space (`%W`) in Java. When [type] is [CommonTypeNames.STRING], values are
 * emitted as string literals (`%S`); otherwise they are emitted verbatim (`%L`).
 */
fun DoubleArrayLiteral(
    type: XTypeName,
    rowSize: Int,
    columnSizeProducer: (Int) -> Int,
    valueProducer: (Int, Int) -> Any
) = buildCodeBlock { language ->
    val separator =
        when (language) {
            CodeLanguage.JAVA -> ",%W"
            CodeLanguage.KOTLIN -> ", "
        }
    val outerInit =
        when (language) {
            CodeLanguage.JAVA -> XCodeBlock.of("new %T[][] ", type)
            CodeLanguage.KOTLIN -> XCodeBlock.of("arrayOf")
        }
    // Java nested initializers need no per-row prefix: `new int[][] {{...}, {...}}`.
    val innerInit =
        when (language) {
            CodeLanguage.JAVA -> XCodeBlock.of("")
            CodeLanguage.KOTLIN -> XCodeBlock.of(getArrayOfFunction(type))
        }
    val openingChar =
        when (language) {
            CodeLanguage.JAVA -> "{"
            CodeLanguage.KOTLIN -> "("
        }
    val closingChar =
        when (language) {
            CodeLanguage.JAVA -> "}"
            CodeLanguage.KOTLIN -> ")"
        }
    val valueFormat = if (type == CommonTypeNames.STRING) "%S" else "%L"

    // Joins the given element code blocks into one block, separated by [separator].
    fun joinElements(elements: List<XCodeBlock>): XCodeBlock {
        val placeholders = elements.joinToString(separator = separator) { "%L" }
        return XCodeBlock.builder().apply { add(placeholders, *elements.toTypedArray()) }.build()
    }

    val rows =
        List(rowSize) { i ->
            val rowValues =
                List(columnSizeProducer(i)) { j -> XCodeBlock.of(valueFormat, valueProducer(i, j)) }
            XCodeBlock.of("%L$openingChar%L$closingChar", innerInit, joinElements(rowValues))
        }
    add("%L$openingChar%L$closingChar", outerInit, joinElements(rows))
}
686 
/**
 * Name of the Kotlin factory function that builds an array of [type].
 *
 * Primitive types map to `<primitive>ArrayOf` (e.g. `intArrayOf`); every other type maps to
 * `arrayOf`.
 */
private fun getArrayOfFunction(type: XTypeName): String {
    val primitivePrefix =
        when (type) {
            XTypeName.PRIMITIVE_BOOLEAN -> "boolean"
            XTypeName.PRIMITIVE_BYTE -> "byte"
            XTypeName.PRIMITIVE_SHORT -> "short"
            XTypeName.PRIMITIVE_INT -> "int"
            XTypeName.PRIMITIVE_LONG -> "long"
            XTypeName.PRIMITIVE_CHAR -> "char"
            XTypeName.PRIMITIVE_FLOAT -> "float"
            XTypeName.PRIMITIVE_DOUBLE -> "double"
            else -> return "arrayOf"
        }
    return "${primitivePrefix}ArrayOf"
}
699 
/**
 * Call expression (including parentheses) that converts a collection of primitive [type] into its
 * primitive array, e.g. `toIntArray()`.
 *
 * @throws IllegalStateException if [type] is not a primitive type.
 */
fun getToArrayFunction(type: XTypeName): String {
    val elementName =
        when (type) {
            XTypeName.PRIMITIVE_BOOLEAN -> "Boolean"
            XTypeName.PRIMITIVE_BYTE -> "Byte"
            XTypeName.PRIMITIVE_SHORT -> "Short"
            XTypeName.PRIMITIVE_INT -> "Int"
            XTypeName.PRIMITIVE_LONG -> "Long"
            XTypeName.PRIMITIVE_CHAR -> "Char"
            XTypeName.PRIMITIVE_FLOAT -> "Float"
            XTypeName.PRIMITIVE_DOUBLE -> "Double"
            else -> error("Provided type expected to be primitive. Found: $type")
        }
    return "to${elementName}Array()"
}
712 
/**
 * Code of expression for [Collection.size] in Kotlin, and [java.util.Collection.size] for Java,
 * applied to the variable named [varName].
 */
fun CollectionsSizeExprCode(varName: String) = buildCodeBlock { language ->
    val format =
        when (language) {
            CodeLanguage.JAVA -> "%L.size()" // java.util.Collection.size()
            CodeLanguage.KOTLIN -> "%L.size" // kotlin.collections.Collection.size property
        }
    add(format, varName)
}
723 
/**
 * Code of expression for [Array.size] in Kotlin, and `arr.length` for Java, applied to the
 * variable named [varName].
 */
fun ArraySizeExprCode(varName: String) = buildCodeBlock { language ->
    val format =
        when (language) {
            // Java arrays expose their size through the `length` field.
            CodeLanguage.JAVA -> "%L.length"
            // kotlin.Array.size, also available on primitive arrays such as IntArray.
            CodeLanguage.KOTLIN -> "%L.size"
        }
    add(format, varName)
}
734 
/**
 * Code of expression for [Map.keys] in Kotlin, and [java.util.Map.keySet] for Java, applied to the
 * variable named [varName].
 */
fun MapKeySetExprCode(varName: String) = buildCodeBlock { language ->
    val format =
        when (language) {
            CodeLanguage.JAVA -> "%L.keySet()" // java.util.Map.keySet()
            CodeLanguage.KOTLIN -> "%L.keys" // kotlin.collections.Map.keys property
        }
    add(format, varName)
}
745