1 /*
<lambda>null2 * Copyright 2024 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 package androidx.room.coroutines
18
19 import androidx.collection.CircularArray
20 import androidx.room.TransactionScope
21 import androidx.room.Transactor
22 import androidx.room.Transactor.SQLiteTransactionType
23 import androidx.room.concurrent.AtomicBoolean
24 import androidx.room.concurrent.ReentrantLock
25 import androidx.room.concurrent.ThreadLocal
26 import androidx.room.concurrent.asContextElement
27 import androidx.room.concurrent.currentThreadId
28 import androidx.room.concurrent.withLock
29 import androidx.room.util.SQLiteResultCode.SQLITE_BUSY
30 import androidx.room.util.SQLiteResultCode.SQLITE_ERROR
31 import androidx.room.util.SQLiteResultCode.SQLITE_MISUSE
32 import androidx.sqlite.SQLiteConnection
33 import androidx.sqlite.SQLiteDriver
34 import androidx.sqlite.SQLiteException
35 import androidx.sqlite.SQLiteStatement
36 import androidx.sqlite.execSQL
37 import androidx.sqlite.throwSQLiteException
38 import kotlin.collections.removeLast as removeLastKt
39 import kotlin.coroutines.CoroutineContext
40 import kotlin.coroutines.coroutineContext
41 import kotlin.time.Duration.Companion.seconds
42 import kotlinx.coroutines.TimeoutCancellationException
43 import kotlinx.coroutines.sync.Mutex
44 import kotlinx.coroutines.sync.Semaphore
45 import kotlinx.coroutines.sync.withLock
46 import kotlinx.coroutines.withContext
47 import kotlinx.coroutines.withTimeout
48
/**
 * [ConnectionPool] implementation that hands out connections from a reader [Pool] and a writer
 * [Pool].
 *
 * A connection obtained through [useConnection] is published to nested [useConnection] calls
 * through both a coroutine context element and a thread local, so nested calls reuse the
 * already acquired connection instead of acquiring another one.
 */
internal class ConnectionPoolImpl : ConnectionPool {
    private val driver: SQLiteDriver
    private val readers: Pool
    private val writers: Pool

    private val threadLocal = ThreadLocal<PooledConnectionImpl>()

    private val _isClosed = AtomicBoolean(false)
    private val isClosed: Boolean
        get() = _isClosed.get()

    // Amount of time to wait to acquire a connection before throwing. Android uses 30 seconds in
    // its pool, so the same value is used here, although it is unclear whether that is a good
    // number. This timeout is unrelated to the busy handler.
    // TODO(b/404380974): Allow public configuration
    internal var timeout = 30.seconds

    /**
     * Creates a pool with a single connection that serves both reads and writes: [readers] and
     * [writers] refer to the very same [Pool] instance.
     */
    constructor(driver: SQLiteDriver, fileName: String) {
        this.driver = driver
        this.readers = Pool(capacity = 1, connectionFactory = { driver.open(fileName) })
        this.writers = readers
    }

    /**
     * Creates a pool with separate reader and writer pools.
     *
     * Reader connections execute `PRAGMA query_only = 1` right after being opened.
     *
     * @throws IllegalArgumentException if [maxNumOfReaders] or [maxNumOfWriters] is not positive.
     */
    constructor(
        driver: SQLiteDriver,
        fileName: String,
        maxNumOfReaders: Int,
        maxNumOfWriters: Int,
    ) {
        require(maxNumOfReaders > 0) { "Maximum number of readers must be greater than 0" }
        require(maxNumOfWriters > 0) { "Maximum number of writers must be greater than 0" }
        this.driver = driver
        this.readers =
            Pool(
                capacity = maxNumOfReaders,
                connectionFactory = {
                    driver.open(fileName).also { newConnection ->
                        // Enforce to be read only (might be disabled by a YOLO developer)
                        newConnection.execSQL("PRAGMA query_only = 1")
                    }
                }
            )
        this.writers =
            Pool(capacity = maxNumOfWriters, connectionFactory = { driver.open(fileName) })
    }

    /**
     * Runs [block] with a connection from the pool.
     *
     * If a connection is already confined to the current thread or coroutine it is reused;
     * asking for a writer while the confined connection is read-only is reported through
     * [throwSQLiteException] with [SQLITE_ERROR]. Otherwise a connection is acquired from the
     * reader or writer pool (waiting at most [timeout]), handed to [block] and always returned
     * to its pool afterwards.
     *
     * Failures are reported through [throwSQLiteException]: [SQLITE_MISUSE] when the pool is
     * closed and [SQLITE_BUSY] when no connection could be acquired within [timeout].
     *
     * @param isReadOnly whether a reader connection is sufficient for [block].
     * @param block the work to perform with the connection.
     * @return the value returned by [block].
     */
    override suspend fun <R> useConnection(
        isReadOnly: Boolean,
        block: suspend (Transactor) -> R
    ): R {
        if (isClosed) {
            throwSQLiteException(SQLITE_MISUSE, "Connection pool is closed")
        }
        val confinedConnection =
            threadLocal.get() ?: coroutineContext[ConnectionElement]?.connectionWrapper
        if (confinedConnection != null) {
            if (!isReadOnly && confinedConnection.isReadOnly) {
                throwSQLiteException(
                    SQLITE_ERROR,
                    "Cannot upgrade connection from reader to writer"
                )
            }
            return if (coroutineContext[ConnectionElement] == null) {
                // Reinstall the connection context element if it is missing. We are likely in
                // a new coroutine but were able to transfer the connection via the thread local.
                withContext(createConnectionContext(confinedConnection)) {
                    block.invoke(confinedConnection)
                }
            } else {
                block.invoke(confinedConnection)
            }
        }
        val pool =
            if (isReadOnly) {
                readers
            } else {
                writers
            }
        val result: R
        var exception: Throwable? = null
        var connection: PooledConnectionImpl? = null
        try {
            val currentContext = coroutineContext
            val (acquiredConnection, acquireError) = pool.acquireWithTimeout()
            // Always try to create a wrapper even if an error occurs, so it can be recycled.
            connection =
                acquiredConnection?.let {
                    PooledConnectionImpl(
                        delegate = it.markAcquired(currentContext),
                        // With a single shared pool the connection also serves writes.
                        isReadOnly = readers !== writers && isReadOnly
                    )
                }
            if (acquireError is TimeoutCancellationException) {
                throwTimeoutException(isReadOnly)
            } else if (acquireError != null) {
                throw acquireError
            }
            requireNotNull(connection)
            result = withContext(createConnectionContext(connection)) { block.invoke(connection) }
        } catch (ex: Throwable) {
            exception = ex
            throw ex
        } finally {
            try {
                connection?.let { usedConnection ->
                    try {
                        usedConnection.markRecycled()
                    } finally {
                        // Hand the connection and its permit back to the pool even if the
                        // recycle step failed unexpectedly, otherwise the pool permanently
                        // loses one unit of capacity.
                        usedConnection.delegate.markReleased()
                        pool.recycle(usedConnection.delegate)
                    }
                }
            } catch (error: Throwable) {
                // Best effort: a recycle failure is only attached to a primary failure, if any.
                exception?.addSuppressed(error)
            }
        }
        return result
    }

    /**
     * Acquires a connection from the receiver pool, waiting at most [timeout].
     *
     * Nothing is thrown from here: the result pairs the acquired connection (or `null`) with
     * the throwable that was caught (or `null`). Both are reported independently so that the
     * caller can still recycle a connection in case one was obtained while an error occurred.
     */
    private suspend inline fun Pool.acquireWithTimeout(): Pair<ConnectionWithLock?, Throwable?> {
        // Following async timeout with resources recommendation:
        // https://kotlinlang.org/docs/cancellation-and-timeouts.html#asynchronous-timeout-and-resources
        var connection: ConnectionWithLock? = null
        var exceptionThrown: Throwable? = null
        try {
            withTimeout(timeout) { connection = this@acquireWithTimeout.acquire() }
        } catch (ex: Throwable) {
            exceptionThrown = ex
        }
        return connection to exceptionThrown
    }

    /**
     * Builds the coroutine context that confines [connection]: a [ConnectionElement] plus a
     * context element that sets [threadLocal] to the same connection.
     */
    private fun createConnectionContext(connection: PooledConnectionImpl) =
        ConnectionElement(connection) + threadLocal.asContextElement(connection)

    /**
     * Reports an acquire timeout through [throwSQLiteException] with [SQLITE_BUSY], including
     * a dump of both pools in the message.
     */
    private fun throwTimeoutException(isReadOnly: Boolean): Nothing {
        val readOrWrite = if (isReadOnly) "reader" else "writer"
        val message = buildString {
            appendLine("Timed out attempting to acquire a $readOrWrite connection.")
            appendLine()
            appendLine("Writer pool:")
            writers.dump(this)
            appendLine("Reader pool:")
            readers.dump(this)
        }
        throwSQLiteException(SQLITE_BUSY, message)
    }

    /** Closes the pool(s). Only the first call has an effect. */
    // TODO: (b/319657104): Make suspending so pool closes when all connections are recycled.
    override fun close() {
        if (_isClosed.compareAndSet(expect = false, update = true)) {
            readers.close()
            // The single-connection constructor aliases both roles to one pool, which must
            // not have its connections closed a second time.
            if (writers !== readers) {
                writers.close()
            }
        }
    }
}
202
/**
 * A pool of at most [capacity] connections that are opened on demand via [connectionFactory].
 *
 * A semaphore with [capacity] permits limits how many connections can be acquired at once,
 * while [lock] guards the pool's bookkeeping state.
 */
private class Pool(val capacity: Int, val connectionFactory: () -> SQLiteConnection) {
    private val lock = ReentrantLock()
    // Number of connections opened so far, also the next free index in `connections`.
    private var size = 0
    private var isClosed = false
    // Every connection ever opened by this pool, whether currently acquired or not.
    private val connections = arrayOfNulls<ConnectionWithLock>(capacity)
    private val connectionPermits = Semaphore(permits = capacity)
    // Connections that are currently free to be handed out.
    private val availableConnections = CircularArray<ConnectionWithLock>(capacity)

    /**
     * Waits for a permit and then returns a free connection, opening a new one if none is
     * available and the capacity has not been reached.
     *
     * If anything fails after the permit was taken (including the pool being closed, which is
     * reported through [throwSQLiteException] with [SQLITE_MISUSE]) the permit is released
     * before the failure is rethrown.
     */
    suspend fun acquire(): ConnectionWithLock {
        connectionPermits.acquire()
        try {
            return lock.withLock {
                if (isClosed) {
                    throwSQLiteException(SQLITE_MISUSE, "Connection pool is closed")
                }
                if (availableConnections.isEmpty()) {
                    tryOpenNewConnectionLocked()
                }
                availableConnections.popFirst()
            }
        } catch (ex: Throwable) {
            connectionPermits.release()
            throw ex
        }
    }

    /**
     * Opens a new connection and makes it available, unless [capacity] connections have already
     * been opened. Expected to be called while holding [lock].
     */
    private fun tryOpenNewConnectionLocked() {
        if (size >= capacity) {
            // Capacity reached
            return
        }
        val newConnection = ConnectionWithLock(connectionFactory.invoke())
        connections[size++] = newConnection
        availableConnections.addLast(newConnection)
    }

    /** Puts [connection] back into the available queue and releases its permit. */
    fun recycle(connection: ConnectionWithLock) {
        lock.withLock { availableConnections.addLast(connection) }
        connectionPermits.release()
    }

    /**
     * Marks the pool as closed and closes every connection it has opened.
     *
     * Calling this more than once is safe: only the first call closes the connections.
     */
    fun close() {
        lock.withLock {
            if (!isClosed) {
                isClosed = true
                connections.forEach { it?.close() }
            }
        }
    }

    /* Dumps debug information */
    fun dump(builder: StringBuilder) =
        lock.withLock {
            val availableQueue = buildList {
                for (i in 0 until availableConnections.size()) {
                    add(availableConnections[i])
                }
            }
            builder.append("\t" + super.toString() + " (")
            builder.append("capacity=$capacity, ")
            builder.append("permits=${connectionPermits.availablePermits}, ")
            builder.append(
                "queue=(size=${availableQueue.size})[${availableQueue.joinToString()}], "
            )
            builder.appendLine(")")
            connections.forEachIndexed { index, connection ->
                builder.appendLine("\t\t[${index + 1}] - ${connection?.toString()}")
                connection?.dump(builder)
            }
        }
}
272
/**
 * A [SQLiteConnection] paired with a [Mutex], both implemented by delegation.
 *
 * It additionally remembers the coroutine context and a [Throwable] captured when it was marked
 * as acquired, so that [dump] can report who is holding the connection.
 */
private class ConnectionWithLock(
    private val delegate: SQLiteConnection,
    private val lock: Mutex = Mutex()
) : SQLiteConnection by delegate, Mutex by lock {

    private var acquireCoroutineContext: CoroutineContext? = null
    private var acquireThrowable: Throwable? = null

    /** Records [context] and the current call stack as the acquirer, then returns `this`. */
    fun markAcquired(context: CoroutineContext): ConnectionWithLock {
        acquireCoroutineContext = context
        acquireThrowable = Throwable()
        return this
    }

    /** Forgets the recorded acquirer information, then returns `this`. */
    fun markReleased(): ConnectionWithLock {
        acquireCoroutineContext = null
        acquireThrowable = null
        return this
    }

    /* Dumps debug information */
    fun dump(builder: StringBuilder) {
        val ownerContext = acquireCoroutineContext
        val acquireSite = acquireThrowable
        if (ownerContext == null && acquireSite == null) {
            builder.appendLine("\t\tStatus: Free connection")
            return
        }
        builder.appendLine("\t\tStatus: Acquired connection")
        if (ownerContext != null) {
            builder.appendLine("\t\tCoroutine: $ownerContext")
        }
        if (acquireSite != null) {
            builder.appendLine("\t\tAcquired:")
            // The first line of the trace is skipped, only the lines after it are printed.
            for (line in acquireSite.stackTraceToString().lines().drop(1)) {
                builder.appendLine("\t\t$line")
            }
        }
    }

    override fun toString(): String = delegate.toString()
}
311
/**
 * Coroutine context element holding a [PooledConnectionImpl]. The companion object [Key] is the
 * key under which the element is stored in a [CoroutineContext].
 */
private class ConnectionElement(val connectionWrapper: PooledConnectionImpl) :
    CoroutineContext.Element {

    override val key: CoroutineContext.Key<ConnectionElement>
        get() = Key

    companion object Key : CoroutineContext.Key<ConnectionElement>
}
319
/**
 * A connection wrapper to enforce pool contract and implement transactions.
 *
 * Every operation first verifies that the wrapper has not been recycled and that the calling
 * coroutine carries this very wrapper in its [ConnectionElement]. Preparing and using a
 * statement, as well as beginning and ending transactions, is done while holding the mutex of
 * [delegate].
 */
private class PooledConnectionImpl(
    val delegate: ConnectionWithLock,
    val isReadOnly: Boolean,
) : Transactor, RawConnectionAccessor {
    // One entry per open transaction level, the outermost transaction being first.
    private val transactionStack = ArrayDeque<TransactionItem>()

    private val _isRecycled = AtomicBoolean(false)
    private val isRecycled: Boolean
        get() = _isRecycled.get()

    override val rawConnection: SQLiteConnection
        get() = delegate

    /**
     * Prepares [sql] and runs [block] with the statement while holding the connection mutex.
     * The statement is closed once [block] is done.
     */
    override suspend fun <R> usePrepared(sql: String, block: (SQLiteStatement) -> R): R {
        return withStateCheck {
            delegate.withLock {
                val statement = StatementWrapper(delegate.prepare(sql))
                statement.use { block.invoke(it) }
            }
        }
    }

    /** Runs [block] inside a transaction of the given [type]. */
    override suspend fun <R> withTransaction(
        type: SQLiteTransactionType,
        block: suspend TransactionScope<R>.() -> R
    ): R {
        return withStateCheck { transaction(type, block) }
    }

    /** Returns whether at least one transaction started through this wrapper is still open. */
    override suspend fun inTransaction(): Boolean {
        return withStateCheck { transactionStack.isNotEmpty() }
    }

    /** Marks this wrapper as recycled. Only the first call has an effect. */
    fun markRecycled() {
        val firstRecycle = _isRecycled.compareAndSet(expect = false, update = true)
        if (!firstRecycle) {
            return
        }
        // Perform a rollback in case there is an active transaction so that the connection
        // is in a clean state when it is recycled. We don't know for sure if there is an
        // unfinished transaction, hence we always try the rollback.
        // TODO(b/319627988): Try to *really* check if there is an active transaction with the
        //  C APIs sqlite3_txn_state or sqlite3_get_autocommit and possibly throw an error
        //  if there is an unfinished transaction.
        try {
            delegate.execSQL("ROLLBACK TRANSACTION")
        } catch (_: SQLiteException) {
            // ignored
        }
    }

    /**
     * Begins a transaction (or a savepoint when one is already open), runs [block] and ends the
     * transaction: committed when [block] returns normally without a rollback having been
     * requested, rolled back otherwise.
     *
     * A [ConnectionPool.RollbackException] thrown by [block] is not propagated, its result is
     * returned instead. A `null` [type] falls back to [SQLiteTransactionType.DEFERRED].
     */
    private suspend fun <R> transaction(
        type: SQLiteTransactionType?,
        block: suspend TransactionScope<R>.() -> R
    ): R {
        val effectiveType = type ?: SQLiteTransactionType.DEFERRED
        beginTransaction(effectiveType)
        var committed = true
        var failure: Throwable? = null
        try {
            val scope = TransactionImpl<R>()
            return scope.block()
        } catch (ex: Throwable) {
            committed = false
            if (ex !is ConnectionPool.RollbackException) {
                failure = ex
                throw ex
            }
            // Type arguments in exception subclasses is not allowed but the exception is always
            // created with the correct type.
            @Suppress("UNCHECKED_CAST") val rollbackResult = ex.result as R
            return rollbackResult
        } finally {
            try {
                endTransaction(committed)
            } catch (ex: SQLiteException) {
                // Attach to the primary failure when there is one, otherwise surface it.
                val primary = failure
                if (primary == null) {
                    throw ex
                }
                primary.addSuppressed(ex)
            }
        }
    }

    /**
     * Starts a new transaction level: a `BEGIN` of the requested [type] when no transaction is
     * open, otherwise a savepoint named after the current nesting depth.
     */
    private suspend fun beginTransaction(type: SQLiteTransactionType) {
        delegate.withLock {
            val depth = transactionStack.size
            val beginSql =
                if (depth == 0) {
                    when (type) {
                        SQLiteTransactionType.DEFERRED -> "BEGIN DEFERRED TRANSACTION"
                        SQLiteTransactionType.IMMEDIATE -> "BEGIN IMMEDIATE TRANSACTION"
                        SQLiteTransactionType.EXCLUSIVE -> "BEGIN EXCLUSIVE TRANSACTION"
                    }
                } else {
                    "SAVEPOINT '$depth'"
                }
            delegate.execSQL(beginSql)
            transactionStack.addLast(TransactionItem(id = depth, shouldRollback = false))
        }
    }

    /**
     * Ends the innermost transaction level, committing it when [success] is true and no
     * rollback was requested for it, rolling it back otherwise.
     *
     * @throws IllegalStateException if no transaction is open.
     */
    private suspend fun endTransaction(success: Boolean) {
        delegate.withLock {
            check(transactionStack.isNotEmpty()) { "Not in a transaction" }
            val finished = transactionStack.removeLastKt()
            val isOutermost = transactionStack.isEmpty()
            val shouldCommit = success && !finished.shouldRollback
            val endSql =
                when {
                    shouldCommit && isOutermost -> "END TRANSACTION"
                    shouldCommit -> "RELEASE SAVEPOINT '${finished.id}'"
                    isOutermost -> "ROLLBACK TRANSACTION"
                    else -> "ROLLBACK TRANSACTION TO SAVEPOINT '${finished.id}'"
                }
            delegate.execSQL(endSql)
        }
    }

    /** Bookkeeping for one open transaction level. */
    private class TransactionItem(val id: Int, var shouldRollback: Boolean)

    /** The [TransactionScope] given to transaction blocks, backed by the enclosing wrapper. */
    private inner class TransactionImpl<T> : TransactionScope<T>, RawConnectionAccessor {

        override val rawConnection: SQLiteConnection
            get() = this@PooledConnectionImpl.rawConnection

        override suspend fun <R> usePrepared(sql: String, block: (SQLiteStatement) -> R): R {
            return this@PooledConnectionImpl.usePrepared(sql, block)
        }

        override suspend fun <R> withNestedTransaction(
            block: suspend (TransactionScope<R>) -> R
        ): R {
            return withStateCheck { transaction(null, block) }
        }

        /**
         * Flags the innermost transaction level for rollback and unwinds by throwing a
         * [ConnectionPool.RollbackException] carrying [result].
         */
        override suspend fun rollback(result: T): Nothing {
            withStateCheck {
                check(transactionStack.isNotEmpty()) { "Not in a transaction" }
                delegate.withLock { transactionStack.last().shouldRollback = true }
            }
            throw ConnectionPool.RollbackException(result)
        }
    }

    /**
     * Runs [block] after verifying that this wrapper is not recycled and that the current
     * coroutine context holds a [ConnectionElement] referring to this wrapper. Violations are
     * reported through [throwSQLiteException] with [SQLITE_MISUSE].
     */
    private suspend inline fun <R> withStateCheck(block: () -> R): R {
        if (isRecycled) {
            throwSQLiteException(SQLITE_MISUSE, "Connection is recycled")
        }
        val confinedWrapper = coroutineContext[ConnectionElement]?.connectionWrapper
        if (confinedWrapper !== this) {
            throwSQLiteException(
                SQLITE_MISUSE,
                "Attempted to use connection on a different coroutine"
            )
        }
        return block.invoke()
    }

    /**
     * A [SQLiteStatement] that forwards every call to [statement] after checking that the
     * enclosing connection is not recycled and that the call happens on the thread that created
     * this wrapper.
     */
    private inner class StatementWrapper(
        private val statement: SQLiteStatement,
    ) : SQLiteStatement {

        // Thread on which the wrapper was created, the only one allowed to use it.
        private val threadId = currentThreadId()

        override fun bindBlob(index: Int, value: ByteArray) {
            checkedCall { statement.bindBlob(index, value) }
        }

        override fun bindDouble(index: Int, value: Double) {
            checkedCall { statement.bindDouble(index, value) }
        }

        override fun bindLong(index: Int, value: Long) {
            checkedCall { statement.bindLong(index, value) }
        }

        override fun bindText(index: Int, value: String) {
            checkedCall { statement.bindText(index, value) }
        }

        override fun bindNull(index: Int) {
            checkedCall { statement.bindNull(index) }
        }

        override fun getBlob(index: Int): ByteArray {
            return checkedCall { statement.getBlob(index) }
        }

        override fun getDouble(index: Int): Double {
            return checkedCall { statement.getDouble(index) }
        }

        override fun getLong(index: Int): Long {
            return checkedCall { statement.getLong(index) }
        }

        override fun getText(index: Int): String {
            return checkedCall { statement.getText(index) }
        }

        override fun isNull(index: Int): Boolean {
            return checkedCall { statement.isNull(index) }
        }

        override fun getColumnCount(): Int {
            return checkedCall { statement.getColumnCount() }
        }

        override fun getColumnName(index: Int) = checkedCall { statement.getColumnName(index) }

        override fun getColumnType(index: Int) = checkedCall { statement.getColumnType(index) }

        override fun step(): Boolean {
            return checkedCall { statement.step() }
        }

        override fun reset() {
            checkedCall { statement.reset() }
        }

        override fun clearBindings() {
            checkedCall { statement.clearBindings() }
        }

        override fun close() {
            checkedCall { statement.close() }
        }

        /**
         * Runs [block] once the recycled and thread checks have passed. Violations are
         * reported through [throwSQLiteException] with [SQLITE_MISUSE].
         */
        private inline fun <R> checkedCall(block: () -> R): R {
            if (isRecycled) {
                throwSQLiteException(SQLITE_MISUSE, "Statement is recycled")
            }
            if (threadId != currentThreadId()) {
                throwSQLiteException(
                    SQLITE_MISUSE,
                    "Attempted to use statement on a different thread"
                )
            }
            return block.invoke()
        }
    }
}
537