1 /*
<lambda>null2  * Copyright 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package androidx.room.coroutines
18 
19 import androidx.collection.CircularArray
20 import androidx.room.TransactionScope
21 import androidx.room.Transactor
22 import androidx.room.Transactor.SQLiteTransactionType
23 import androidx.room.concurrent.AtomicBoolean
24 import androidx.room.concurrent.ReentrantLock
25 import androidx.room.concurrent.ThreadLocal
26 import androidx.room.concurrent.asContextElement
27 import androidx.room.concurrent.currentThreadId
28 import androidx.room.concurrent.withLock
29 import androidx.room.util.SQLiteResultCode.SQLITE_BUSY
30 import androidx.room.util.SQLiteResultCode.SQLITE_ERROR
31 import androidx.room.util.SQLiteResultCode.SQLITE_MISUSE
32 import androidx.sqlite.SQLiteConnection
33 import androidx.sqlite.SQLiteDriver
34 import androidx.sqlite.SQLiteException
35 import androidx.sqlite.SQLiteStatement
36 import androidx.sqlite.execSQL
37 import androidx.sqlite.throwSQLiteException
38 import kotlin.collections.removeLast as removeLastKt
39 import kotlin.coroutines.CoroutineContext
40 import kotlin.coroutines.coroutineContext
41 import kotlin.time.Duration.Companion.seconds
42 import kotlinx.coroutines.TimeoutCancellationException
43 import kotlinx.coroutines.sync.Mutex
44 import kotlinx.coroutines.sync.Semaphore
45 import kotlinx.coroutines.sync.withLock
46 import kotlinx.coroutines.withContext
47 import kotlinx.coroutines.withTimeout
48 
/**
 * A [ConnectionPool] whose connections are opened through a [SQLiteDriver] and handed out from
 * [Pool]s.
 *
 * Depending on the constructor used, the pool either has a single connection shared by readers
 * and writers, or separate reader and writer pools. While a `useConnection` block runs, its
 * connection is published through both a thread local and a coroutine context element so that
 * nested `useConnection` calls re-use that connection instead of acquiring another one.
 */
internal class ConnectionPoolImpl : ConnectionPool {
    private val driver: SQLiteDriver
    private val readers: Pool
    private val writers: Pool

    // Carries the confined connection across coroutines that run on the same thread.
    private val threadLocal = ThreadLocal<PooledConnectionImpl>()

    private val _isClosed = AtomicBoolean(false)
    private val isClosed: Boolean
        get() = _isClosed.get()

    // Amount of time to wait to acquire a connection before throwing, Android uses 30 seconds in
    // its pool, so we do too here, but IDK if that is a good number. This timeout is unrelated to
    // the busy handler.
    // TODO(b/404380974): Allow public configuration
    internal var timeout = 30.seconds

    /**
     * Creates a pool with exactly one connection that serves both reads and writes.
     *
     * @param driver the driver used to open the connection
     * @param fileName the database file name passed to [SQLiteDriver.open]
     */
    constructor(driver: SQLiteDriver, fileName: String) {
        this.driver = driver
        this.readers = Pool(capacity = 1, connectionFactory = { driver.open(fileName) })
        // Readers and writers are the very same pool instance in this configuration.
        this.writers = readers
    }

    /**
     * Creates a pool with separate reader and writer connections.
     *
     * @param driver the driver used to open connections
     * @param fileName the database file name passed to [SQLiteDriver.open]
     * @param maxNumOfReaders maximum number of reader connections, must be greater than 0
     * @param maxNumOfWriters maximum number of writer connections, must be greater than 0
     * @throws IllegalArgumentException if either maximum is not greater than 0
     */
    constructor(
        driver: SQLiteDriver,
        fileName: String,
        maxNumOfReaders: Int,
        maxNumOfWriters: Int,
    ) {
        require(maxNumOfReaders > 0) { "Maximum number of readers must be greater than 0" }
        require(maxNumOfWriters > 0) { "Maximum number of writers must be greater than 0" }
        this.driver = driver
        this.readers =
            Pool(
                capacity = maxNumOfReaders,
                connectionFactory = {
                    driver.open(fileName).also { newConnection ->
                        // Enforce to be read only (might be disabled by a YOLO developer)
                        newConnection.execSQL("PRAGMA query_only = 1")
                    }
                }
            )
        this.writers =
            Pool(capacity = maxNumOfWriters, connectionFactory = { driver.open(fileName) })
    }

    /**
     * Runs [block] with a connection from this pool.
     *
     * If a connection is already confined to the caller (found in the thread local or in the
     * coroutine context) it is re-used; asking for a writer while confined to a read-only
     * connection is reported as `SQLITE_ERROR`. Otherwise a connection is acquired from the
     * reader or writer pool, waiting at most [timeout], and is returned to that pool when
     * [block] completes, normally or exceptionally.
     *
     * Failures are reported through [throwSQLiteException]: `SQLITE_MISUSE` when the pool is
     * closed and `SQLITE_BUSY` when acquiring a connection times out.
     *
     * @param isReadOnly whether a reader connection is sufficient
     * @param block the work to perform with the connection
     * @return the value returned by [block]
     */
    override suspend fun <R> useConnection(
        isReadOnly: Boolean,
        block: suspend (Transactor) -> R
    ): R {
        if (isClosed) {
            throwSQLiteException(SQLITE_MISUSE, "Connection pool is closed")
        }
        val confinedConnection =
            threadLocal.get() ?: coroutineContext[ConnectionElement]?.connectionWrapper
        if (confinedConnection != null) {
            if (!isReadOnly && confinedConnection.isReadOnly) {
                throwSQLiteException(
                    SQLITE_ERROR,
                    "Cannot upgrade connection from reader to writer"
                )
            }
            return if (coroutineContext[ConnectionElement] == null) {
                // Reinstall the connection context element if it is missing. We are likely in
                // a new coroutine but were able to transfer the connection via the thread local.
                withContext(createConnectionContext(confinedConnection)) {
                    block.invoke(confinedConnection)
                }
            } else {
                block.invoke(confinedConnection)
            }
        }
        val pool =
            if (isReadOnly) {
                readers
            } else {
                writers
            }
        val result: R
        var exception: Throwable? = null
        var connection: PooledConnectionImpl? = null
        try {
            val currentContext = coroutineContext
            val (acquiredConnection, acquireError) = pool.acquireWithTimeout()
            // Always try to create a wrapper even if an error occurs, so it can be recycled.
            connection =
                acquiredConnection?.let {
                    PooledConnectionImpl(
                        delegate = it.markAcquired(currentContext),
                        // With a single shared pool every connection is allowed to write.
                        isReadOnly = readers !== writers && isReadOnly
                    )
                }
            if (acquireError is TimeoutCancellationException) {
                throwTimeoutException(isReadOnly)
            } else if (acquireError != null) {
                throw acquireError
            }
            requireNotNull(connection)
            result = withContext(createConnectionContext(connection)) { block.invoke(connection) }
        } catch (ex: Throwable) {
            exception = ex
            throw ex
        } finally {
            try {
                connection?.let { usedConnection ->
                    try {
                        usedConnection.markRecycled()
                    } finally {
                        // Hand the connection back even if marking it recycled failed with
                        // something other than the SQLiteException it already tolerates;
                        // otherwise the connection and its permit would be lost for good.
                        usedConnection.delegate.markReleased()
                        pool.recycle(usedConnection.delegate)
                    }
                }
            } catch (error: Throwable) {
                // Best effort: a clean-up failure is only attached to an in-flight exception
                // and never replaces the outcome of the block.
                exception?.addSuppressed(error)
            }
        }
        return result
    }

    /**
     * Acquires a connection from this pool, waiting at most [timeout].
     *
     * Never throws: the returned pair carries the acquired connection (if any) and the failure
     * (if any). Both can be non-null at once, which is why the caller still has to recycle a
     * returned connection when an error is reported.
     */
    private suspend inline fun Pool.acquireWithTimeout(): Pair<ConnectionWithLock?, Throwable?> {
        // Following async timeout with resources recommendation:
        // https://kotlinlang.org/docs/cancellation-and-timeouts.html#asynchronous-timeout-and-resources
        var connection: ConnectionWithLock? = null
        var exceptionThrown: Throwable? = null
        try {
            withTimeout(timeout) { connection = this@acquireWithTimeout.acquire() }
        } catch (ex: Throwable) {
            exceptionThrown = ex
        }
        return connection to exceptionThrown
    }

    /** Builds the context that confines [connection] to a coroutine and to its running thread. */
    private fun createConnectionContext(connection: PooledConnectionImpl) =
        ConnectionElement(connection) + threadLocal.asContextElement(connection)

    /**
     * Reports an acquisition timeout as `SQLITE_BUSY`, including a dump of both pools in the
     * message.
     */
    private fun throwTimeoutException(isReadOnly: Boolean): Nothing {
        val readOrWrite = if (isReadOnly) "reader" else "writer"
        val message = buildString {
            appendLine("Timed out attempting to acquire a $readOrWrite connection.")
            appendLine()
            appendLine("Writer pool:")
            writers.dump(this)
            appendLine("Reader pool:")
            readers.dump(this)
        }
        throwSQLiteException(SQLITE_BUSY, message)
    }

    /** Closes the pool and its connections. Only the first call has any effect. */
    // TODO: (b/319657104): Make suspending so pool closes when all connections are recycled.
    override fun close() {
        if (_isClosed.compareAndSet(expect = false, update = true)) {
            readers.close()
            // The single-connection constructor shares one Pool between readers and writers;
            // do not close that pool (and therefore its connections) a second time.
            if (writers !== readers) {
                writers.close()
            }
        }
    }
}
202 
/**
 * A pool of at most [capacity] connections that are opened on demand by [connectionFactory].
 *
 * A semaphore with [capacity] permits bounds how many connections can be handed out at once,
 * while [lock] guards the bookkeeping state (the created connections, the queue of available
 * ones and the closed flag).
 */
private class Pool(val capacity: Int, val connectionFactory: () -> SQLiteConnection) {
    private val lock = ReentrantLock()
    // Number of connections created so far, also the next free index in `connections`.
    private var size = 0
    private var isClosed = false
    // Every connection ever created by this pool, whether currently handed out or not.
    private val connections = arrayOfNulls<ConnectionWithLock>(capacity)
    private val connectionPermits = Semaphore(permits = capacity)
    // Connections that are currently free to be handed out.
    private val availableConnections = CircularArray<ConnectionWithLock>(capacity)

    /**
     * Suspends until a permit is available and then returns a free connection, opening a new one
     * when none is available and the capacity has not been reached.
     *
     * The permit is given back if anything fails after it was obtained, including the
     * `SQLITE_MISUSE` failure reported when the pool is closed.
     */
    suspend fun acquire(): ConnectionWithLock {
        connectionPermits.acquire()
        try {
            return lock.withLock {
                if (isClosed) {
                    throwSQLiteException(SQLITE_MISUSE, "Connection pool is closed")
                }
                if (availableConnections.isEmpty()) {
                    tryOpenNewConnectionLocked()
                }
                availableConnections.popFirst()
            }
        } catch (ex: Throwable) {
            connectionPermits.release()
            throw ex
        }
    }

    /** Opens and enqueues a new connection unless [capacity] is reached. Call with [lock] held. */
    private fun tryOpenNewConnectionLocked() {
        if (size >= capacity) {
            // Capacity reached
            return
        }
        val newConnection = ConnectionWithLock(connectionFactory.invoke())
        connections[size++] = newConnection
        availableConnections.addLast(newConnection)
    }

    /** Puts [connection] back into the available queue and releases its permit. */
    fun recycle(connection: ConnectionWithLock) {
        lock.withLock { availableConnections.addLast(connection) }
        connectionPermits.release()
    }

    /**
     * Marks the pool as closed and closes every connection it created.
     *
     * Calling this more than once is harmless: only the first call closes the connections, so
     * an owner that holds the same pool under two references cannot double-close them.
     */
    fun close() {
        lock.withLock {
            if (!isClosed) {
                isClosed = true
                connections.forEach { it?.close() }
            }
        }
    }

    /* Dumps debug information */
    fun dump(builder: StringBuilder) =
        lock.withLock {
            // CircularArray is not iterable, so copy its content by index.
            val availableQueue = buildList {
                for (i in 0 until availableConnections.size()) {
                    add(availableConnections[i])
                }
            }
            builder.append("\t" + super.toString() + " (")
            builder.append("capacity=$capacity, ")
            builder.append("permits=${connectionPermits.availablePermits}, ")
            builder.append(
                "queue=(size=${availableQueue.size})[${availableQueue.joinToString()}], "
            )
            builder.appendLine(")")
            connections.forEachIndexed { index, connection ->
                builder.appendLine("\t\t[${index + 1}] - ${connection?.toString()}")
                connection?.dump(builder)
            }
        }
}
272 
/**
 * A [SQLiteConnection] paired with a [Mutex], both exposed through delegation, that also records
 * who acquired it for debugging purposes.
 */
private class ConnectionWithLock(
    private val delegate: SQLiteConnection,
    private val lock: Mutex = Mutex()
) : SQLiteConnection by delegate, Mutex by lock {

    // Debug-only bookkeeping: the context that acquired this connection and a throwable whose
    // stack trace captures where the acquisition happened. Both are null while the connection
    // is free.
    private var acquireCoroutineContext: CoroutineContext? = null
    private var acquireThrowable: Throwable? = null

    /** Records [context] and the current stack trace as the acquirer, then returns `this`. */
    fun markAcquired(context: CoroutineContext): ConnectionWithLock {
        acquireCoroutineContext = context
        acquireThrowable = Throwable()
        return this
    }

    /** Clears the acquirer information, then returns `this`. */
    fun markReleased(): ConnectionWithLock {
        acquireCoroutineContext = null
        acquireThrowable = null
        return this
    }

    /* Dumps debug information */
    fun dump(builder: StringBuilder) {
        val acquirerContext = acquireCoroutineContext
        val acquirerTrace = acquireThrowable
        if (acquirerContext == null && acquirerTrace == null) {
            builder.appendLine("\t\tStatus: Free connection")
            return
        }
        builder.appendLine("\t\tStatus: Acquired connection")
        if (acquirerContext != null) {
            builder.appendLine("\t\tCoroutine: $acquirerContext")
        }
        if (acquirerTrace != null) {
            builder.appendLine("\t\tAcquired:")
            // The first line of the trace is the throwable's own description; skip it.
            for (traceLine in acquirerTrace.stackTraceToString().lines().drop(1)) {
                builder.appendLine("\t\t$traceLine")
            }
        }
    }

    override fun toString(): String = delegate.toString()
}
311 
/** Coroutine context element that carries the [PooledConnectionImpl] confined to a coroutine. */
private class ConnectionElement(val connectionWrapper: PooledConnectionImpl) :
    CoroutineContext.Element {

    override val key: CoroutineContext.Key<ConnectionElement>
        get() = Key

    companion object Key : CoroutineContext.Key<ConnectionElement>
}
319 
/**
 * A connection wrapper to enforce pool contract and implement transactions.
 *
 * Interactions with the underlying connection are serialized through the mutex of [delegate].
 * In particular, compiling a statement and using it happen under a single lock acquisition, and
 * the statement handed out additionally rejects use from a thread other than the one that
 * created it.
 */
private class PooledConnectionImpl(
    val delegate: ConnectionWithLock,
    val isReadOnly: Boolean,
) : Transactor, RawConnectionAccessor {
    // Currently open transactions, outermost first. Entries after the first map to savepoints.
    private val transactionStack = ArrayDeque<TransactionItem>()

    private val _isRecycled = AtomicBoolean(false)
    private val isRecycled: Boolean
        get() = _isRecycled.get()

    override val rawConnection: SQLiteConnection
        get() = delegate

    /** Prepares [sql], runs [block] with the statement and closes it, all under the lock. */
    override suspend fun <R> usePrepared(sql: String, block: (SQLiteStatement) -> R): R =
        withStateCheck {
            delegate.withLock {
                val statement = StatementWrapper(delegate.prepare(sql))
                statement.use { wrapped -> block.invoke(wrapped) }
            }
        }

    /** Runs [block] inside a transaction of the given [type], or a savepoint when nested. */
    override suspend fun <R> withTransaction(
        type: SQLiteTransactionType,
        block: suspend TransactionScope<R>.() -> R
    ): R = withStateCheck { transaction(type, block) }

    /** Whether this wrapper currently tracks at least one open transaction. */
    override suspend fun inTransaction(): Boolean = withStateCheck {
        transactionStack.isNotEmpty()
    }

    /**
     * Flags this wrapper as recycled so later use is rejected. Only the first call performs the
     * clean-up rollback.
     */
    fun markRecycled() {
        val firstCall = _isRecycled.compareAndSet(expect = false, update = true)
        if (!firstCall) {
            return
        }
        // Perform a rollback in case there is an active transaction so that the connection
        // is in a clean state when it is recycled. We don't know for sure if there is an
        // unfinished transaction, hence we always try the rollback.
        // TODO(b/319627988): Try to *really* check if there is an active transaction with the
        //     C APIs sqlite3_txn_state or sqlite3_get_autocommit and possibly throw an error
        //     if there is an unfinished transaction.
        try {
            delegate.execSQL("ROLLBACK TRANSACTION")
        } catch (_: SQLiteException) {
            // ignored
        }
    }

    /**
     * Begins a transaction (deferred when [type] is null), runs [block] and then ends the
     * transaction, committing only if [block] returned normally and no rollback was requested.
     *
     * A [ConnectionPool.RollbackException] thrown by [block] is not propagated; its result
     * becomes the return value instead.
     */
    private suspend fun <R> transaction(
        type: SQLiteTransactionType?,
        block: suspend TransactionScope<R>.() -> R
    ): R {
        beginTransaction(type ?: SQLiteTransactionType.DEFERRED)
        var succeeded = true
        var failure: Throwable? = null
        try {
            return TransactionImpl<R>().block()
        } catch (rollback: ConnectionPool.RollbackException) {
            succeeded = false
            // Type arguments in exception subclasses is not allowed but the exception is always
            // created with the correct type.
            @Suppress("UNCHECKED_CAST") return (rollback.result as R)
        } catch (ex: Throwable) {
            succeeded = false
            failure = ex
            throw ex
        } finally {
            try {
                endTransaction(succeeded)
            } catch (endError: SQLiteException) {
                // Never mask the block's own failure; otherwise surface the failure to end.
                val primary = failure
                if (primary != null) {
                    primary.addSuppressed(endError)
                } else {
                    throw endError
                }
            }
        }
    }

    /** Issues `BEGIN` for the outermost transaction or `SAVEPOINT` for a nested one. */
    private suspend fun beginTransaction(type: SQLiteTransactionType) =
        delegate.withLock {
            // The nesting depth doubles as the identifier of the new transaction.
            val depth = transactionStack.size
            val sql =
                if (depth == 0) {
                    when (type) {
                        SQLiteTransactionType.DEFERRED -> "BEGIN DEFERRED TRANSACTION"
                        SQLiteTransactionType.IMMEDIATE -> "BEGIN IMMEDIATE TRANSACTION"
                        SQLiteTransactionType.EXCLUSIVE -> "BEGIN EXCLUSIVE TRANSACTION"
                    }
                } else {
                    "SAVEPOINT '$depth'"
                }
            delegate.execSQL(sql)
            transactionStack.addLast(TransactionItem(id = depth, shouldRollback = false))
        }

    /**
     * Pops the innermost transaction and commits it when [success] is true and no rollback was
     * requested for it; otherwise rolls it back.
     *
     * @throws IllegalStateException if no transaction is open
     */
    private suspend fun endTransaction(success: Boolean) =
        delegate.withLock {
            check(transactionStack.isNotEmpty()) { "Not in a transaction" }
            val finished = transactionStack.removeLastKt()
            val isOutermost = transactionStack.isEmpty()
            val shouldCommit = success && !finished.shouldRollback
            val sql =
                when {
                    shouldCommit && isOutermost -> "END TRANSACTION"
                    shouldCommit -> "RELEASE SAVEPOINT '${finished.id}'"
                    isOutermost -> "ROLLBACK TRANSACTION"
                    else -> "ROLLBACK TRANSACTION TO SAVEPOINT '${finished.id}'"
                }
            delegate.execSQL(sql)
        }

    /** Bookkeeping entry for one open transaction or savepoint. */
    private class TransactionItem(val id: Int, var shouldRollback: Boolean)

    /** The scope handed to transaction blocks; forwards to the enclosing connection. */
    private inner class TransactionImpl<T> : TransactionScope<T>, RawConnectionAccessor {

        override val rawConnection: SQLiteConnection
            get() = this@PooledConnectionImpl.rawConnection

        override suspend fun <R> usePrepared(sql: String, block: (SQLiteStatement) -> R): R =
            this@PooledConnectionImpl.usePrepared(sql, block)

        override suspend fun <R> withNestedTransaction(
            block: suspend (TransactionScope<R>) -> R
        ): R = withStateCheck { transaction(null, block) }

        /**
         * Flags the innermost open transaction for rollback and unwinds by throwing a
         * [ConnectionPool.RollbackException] carrying [result].
         */
        override suspend fun rollback(result: T): Nothing = withStateCheck {
            check(transactionStack.isNotEmpty()) { "Not in a transaction" }
            delegate.withLock { transactionStack.last().shouldRollback = true }
            throw ConnectionPool.RollbackException(result)
        }
    }

    /**
     * Runs [block] only if this wrapper has not been recycled and is the connection installed
     * in the calling coroutine's context; otherwise reports `SQLITE_MISUSE`.
     */
    private suspend inline fun <R> withStateCheck(block: () -> R): R {
        if (isRecycled) {
            throwSQLiteException(SQLITE_MISUSE, "Connection is recycled")
        }
        val confinedWrapper = coroutineContext[ConnectionElement]?.connectionWrapper
        if (confinedWrapper !== this) {
            throwSQLiteException(
                SQLITE_MISUSE,
                "Attempted to use connection on a different coroutine"
            )
        }
        return block()
    }

    /**
     * A statement that forwards to [statement] after verifying that the connection has not been
     * recycled and that the caller is on the thread that created this wrapper.
     */
    private inner class StatementWrapper(
        private val statement: SQLiteStatement,
    ) : SQLiteStatement {

        // Thread on which the statement was created; the only one allowed to use it.
        private val threadId = currentThreadId()

        override fun bindBlob(index: Int, value: ByteArray): Unit = withStateCheck {
            statement.bindBlob(index, value)
        }

        override fun bindDouble(index: Int, value: Double): Unit = withStateCheck {
            statement.bindDouble(index, value)
        }

        override fun bindLong(index: Int, value: Long): Unit = withStateCheck {
            statement.bindLong(index, value)
        }

        override fun bindText(index: Int, value: String): Unit = withStateCheck {
            statement.bindText(index, value)
        }

        override fun bindNull(index: Int): Unit = withStateCheck { statement.bindNull(index) }

        override fun getBlob(index: Int): ByteArray = withStateCheck { statement.getBlob(index) }

        override fun getDouble(index: Int): Double = withStateCheck { statement.getDouble(index) }

        override fun getLong(index: Int): Long = withStateCheck { statement.getLong(index) }

        override fun getText(index: Int): String = withStateCheck { statement.getText(index) }

        override fun isNull(index: Int): Boolean = withStateCheck { statement.isNull(index) }

        override fun getColumnCount(): Int = withStateCheck { statement.getColumnCount() }

        override fun getColumnName(index: Int) = withStateCheck { statement.getColumnName(index) }

        override fun getColumnType(index: Int) = withStateCheck { statement.getColumnType(index) }

        override fun step(): Boolean = withStateCheck { statement.step() }

        override fun reset() = withStateCheck { statement.reset() }

        override fun clearBindings() = withStateCheck { statement.clearBindings() }

        override fun close() = withStateCheck { statement.close() }

        /** Runs [block] after the recycled and same-thread checks pass. */
        private inline fun <R> withStateCheck(block: () -> R): R {
            if (isRecycled) {
                throwSQLiteException(SQLITE_MISUSE, "Statement is recycled")
            }
            val callerThreadId = currentThreadId()
            if (callerThreadId != threadId) {
                throwSQLiteException(
                    SQLITE_MISUSE,
                    "Attempted to use statement on a different thread"
                )
            }
            return block()
        }
    }
}
537