1 /*
<lambda>null2  * Copyright 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package androidx.room.testing
18 
19 import androidx.room.BaseRoomConnectionManager
20 import androidx.room.DatabaseConfiguration
21 import androidx.room.Room
22 import androidx.room.RoomDatabase
23 import androidx.room.RoomOpenDelegate
24 import androidx.room.Transactor
25 import androidx.room.migration.AutoMigrationSpec
26 import androidx.room.migration.Migration
27 import androidx.room.migration.bundle.DatabaseBundle
28 import androidx.room.migration.bundle.EntityBundle
29 import androidx.room.migration.bundle.FtsEntityBundle
30 import androidx.room.util.FtsTableInfo
31 import androidx.room.util.TableInfo
32 import androidx.room.util.ViewInfo
33 import androidx.sqlite.SQLiteConnection
34 import androidx.sqlite.execSQL
35 import kotlin.reflect.KClass
36 import kotlin.reflect.safeCast
37 
/**
 * A class that can help test and verify database creation and migration at different versions with
 * different schemas.
 *
 * Common usage of this helper is to create a database at an older version first and then attempt a
 * migration and validation:
 * ```
 * @Test
 * fun migrationTest() {
 *   val migrationTestHelper = getMigrationTestHelper()
 *   // Create the database at version 1
 *   val newConnection = migrationTestHelper.createDatabase(1)
 *   // Insert some data that should be preserved
 *   newConnection.execSQL("INSERT INTO Pet (id, name) VALUES (1, 'Tom')")
 *   newConnection.close()
 *
 *   // Migrate the database to version 2
 *   val migratedConnection =
 *       migrationTestHelper.runMigrationsAndValidate(2, listOf(MIGRATION_1_2))
 *   migratedConnection.prepare("SELECT * FROM Pet").use { stmt ->
 *     // Validates data is preserved between migrations.
 *     assertThat(stmt.step()).isTrue()
 *     assertThat(stmt.getText(1)).isEqualTo("Tom")
 *   }
 *   migratedConnection.close()
 * }
 * ```
 *
 * The helper relies on exported schemas so [androidx.room.Database.exportSchema] should be enabled.
 * Schema location should be configured via Room's Gradle Plugin (id 'androidx.room'):
 * ```
 * room {
 *   schemaDirectory("$projectDir/schemas")
 * }
 * ```
 *
 * The helper should then be instantiated with the same schema location that the schemas are
 * exported to. See platform-specific documentation for further configuration.
 */
expect class MigrationTestHelper {
    /**
     * Creates the database at the given version.
     *
     * Once the database is created, it can be migrated and validated with
     * [runMigrationsAndValidate].
     *
     * @param version The version of the schema at which the database should be created.
     * @return A database connection of the newly created database.
     * @throws IllegalStateException If a new database was not created.
     */
    fun createDatabase(version: Int): SQLiteConnection

    /**
     * Runs the given set of migrations on the existing database once created via [createDatabase].
     *
     * This function uses the same algorithm that Room performs to choose migrations such that the
     * [migrations] instances provided must be sufficient to bring the database from the current
     * version to the desired version. If the database contains [androidx.room.AutoMigration]s,
     * then those are already included in the list of migrations to execute if necessary. Note that
     * provided manual migrations take precedence over auto migrations if they overlap in migration
     * paths.
     *
     * Once migrations are done, this function validates the database schema to ensure the
     * migration performed resulted in the expected schema.
     *
     * @param version The final version the database should migrate to.
     * @param migrations The list of migrations used to attempt the database migration.
     * @return A database connection of the migrated database.
     * @throws IllegalStateException If the schema validation fails.
     */
    fun runMigrationsAndValidate(
        version: Int,
        migrations: List<Migration> = emptyList()
    ): SQLiteConnection
}
111 
/** Produces the [TestConnectionManager] that opens a test connection using the given delegate. */
internal typealias ConnectionManagerFactory =
    (configuration: DatabaseConfiguration, openDelegate: RoomOpenDelegate) -> TestConnectionManager

/** Builds the [DatabaseConfiguration] of a test connection from the given migration container. */
internal typealias ConfigurationFactory =
    (migrationContainer: RoomDatabase.MigrationContainer) -> DatabaseConfiguration
116 
/**
 * Common logic for [MigrationTestHelper.createDatabase].
 *
 * The configuration is built from an empty migration container, and the connection is opened
 * with a [CreateOpenDelegate] that runs the create queries of [schema].
 *
 * @param schema The exported schema of the database at the version to create.
 * @param configurationFactory Builds the configuration from the (empty) migration container.
 * @param connectionManagerFactory Builds the connection manager used to open the connection.
 * @return The opened connection to the newly created database.
 */
internal fun createDatabaseCommon(
    schema: DatabaseBundle,
    configurationFactory: ConfigurationFactory,
    connectionManagerFactory: ConnectionManagerFactory = { config, openDelegate ->
        DefaultTestConnectionManager(config, openDelegate)
    }
): SQLiteConnection {
    val configuration = configurationFactory(RoomDatabase.MigrationContainer())
    val connectionManager = connectionManagerFactory(configuration, CreateOpenDelegate(schema))
    return connectionManager.openConnection()
}
131 
/**
 * Common logic for [MigrationTestHelper.runMigrationsAndValidate].
 *
 * The provided [migrations] are registered first. Each auto migration of [databaseInstance] is
 * then registered only if no migration with the same start and end versions is already present,
 * which gives manual migrations precedence. The connection is opened with a [MigrateOpenDelegate]
 * that validates the resulting schema against [schema].
 *
 * @param databaseInstance Database whose auto migrations are included.
 * @param schema The exported schema the migrated database is validated against.
 * @param migrations Manual migrations to register.
 * @param autoMigrationSpecs Spec instances for auto migrations that require one.
 * @param validateUnknownTables Whether tables not described by [schema] fail the validation.
 * @param configurationFactory Builds the configuration from the populated migration container.
 * @param connectionManagerFactory Builds the connection manager used to open the connection.
 * @return The opened connection to the migrated database.
 */
internal fun runMigrationsAndValidateCommon(
    databaseInstance: RoomDatabase,
    schema: DatabaseBundle,
    migrations: List<Migration>,
    autoMigrationSpecs: List<AutoMigrationSpec>,
    validateUnknownTables: Boolean,
    configurationFactory: ConfigurationFactory,
    connectionManagerFactory: ConnectionManagerFactory = { config, openDelegate ->
        DefaultTestConnectionManager(config, openDelegate)
    }
): SQLiteConnection {
    val migrationContainer = RoomDatabase.MigrationContainer()
    migrationContainer.addMigrations(migrations)
    for (autoMigration in getAutoMigrations(databaseInstance, autoMigrationSpecs)) {
        // A migration covering the same version range was already registered; keep it.
        if (migrationContainer.contains(autoMigration.startVersion, autoMigration.endVersion)) {
            continue
        }
        migrationContainer.addMigration(autoMigration)
    }
    val configuration = configurationFactory(migrationContainer)
    val connectionManager =
        connectionManagerFactory(configuration, MigrateOpenDelegate(schema, validateUnknownTables))
    return connectionManager.openConnection()
}
162 
/**
 * Creates the auto migrations of [databaseInstance].
 *
 * The spec classes the database requires are matched against [providedSpecs] before the auto
 * migrations are created.
 */
private fun getAutoMigrations(
    databaseInstance: RoomDatabase,
    providedSpecs: List<AutoMigrationSpec>
): List<Migration> {
    val requiredSpecClasses = databaseInstance.getRequiredAutoMigrationSpecClasses()
    val specsByClass = createAutoMigrationSpecMap(requiredSpecClasses, providedSpecs)
    return databaseInstance.createAutoMigrations(specsByClass)
}
174 
/**
 * Maps each required auto migration spec class to the first entry of [providedSpecs] that is an
 * instance of it.
 *
 * @throws IllegalArgumentException If a required spec class has no matching provided instance.
 */
private fun createAutoMigrationSpecMap(
    requiredAutoMigrationSpecs: Set<KClass<out AutoMigrationSpec>>,
    providedSpecs: List<AutoMigrationSpec>
): Map<KClass<out AutoMigrationSpec>, AutoMigrationSpec> =
    if (requiredAutoMigrationSpecs.isEmpty()) {
        emptyMap()
    } else {
        requiredAutoMigrationSpecs.associateWith { specClass ->
            val instance =
                providedSpecs.firstOrNull { candidate -> specClass.safeCast(candidate) != null }
            requireNotNull(instance) {
                "A required auto migration spec (${specClass.qualifiedName}) has not been provided."
            }
        }
    }
192 
/**
 * Base connection manager used by the migration test helper.
 *
 * It has no [RoomDatabase.Callback]s. Connections are obtained through [openConnection];
 * [useConnection] throws if it is ever called.
 */
internal abstract class TestConnectionManager : BaseRoomConnectionManager() {
    override val callbacks: List<RoomDatabase.Callback> = emptyList()

    /** Opens a connection to the database under test. */
    abstract fun openConnection(): SQLiteConnection

    override suspend fun <R> useConnection(
        isReadOnly: Boolean,
        block: suspend (Transactor) -> R
    ): R = error("Function should never be invoked during tests.")
}
205 
/**
 * Default [TestConnectionManager] that opens connections through a [DriverWrapper] around the
 * [configuration]'s SQLite driver.
 *
 * When the configuration has no database name, SQLite's in-memory name `:memory:` is used.
 *
 * @throws IllegalArgumentException If [configuration] has no SQLite driver.
 */
private class DefaultTestConnectionManager(
    override val configuration: DatabaseConfiguration,
    override val openDelegate: RoomOpenDelegate
) : TestConnectionManager() {

    // A descriptive message replaces the generic "Required value was null." so a missing driver
    // in the helper's configuration is easy to diagnose.
    private val driverWrapper =
        DriverWrapper(
            requireNotNull(configuration.sqliteDriver) {
                "Cannot open a test connection: the database configuration has no SQLiteDriver."
            }
        )

    override fun openConnection(): SQLiteConnection =
        driverWrapper.open(configuration.name ?: ":memory:")
}
215 
/**
 * Base [RoomOpenDelegate] for the migration test helper.
 *
 * The version and identity hashes come from the given [DatabaseBundle]. The lifecycle hooks do
 * nothing, and [dropAllTables] throws because dropping tables is not allowed in a test.
 */
private sealed class TestOpenDelegate(databaseBundle: DatabaseBundle) :
    RoomOpenDelegate(
        version = databaseBundle.version,
        identityHash = databaseBundle.identityHash,
        legacyIdentityHash = databaseBundle.identityHash
    ) {
    override fun onCreate(connection: SQLiteConnection) = Unit

    override fun onPreMigrate(connection: SQLiteConnection) = Unit

    override fun onPostMigrate(connection: SQLiteConnection) = Unit

    override fun onOpen(connection: SQLiteConnection) = Unit

    override fun dropAllTables(connection: SQLiteConnection): Unit =
        error("Can't drop all tables during a test.")
}
234 
/**
 * Open delegate used when creating a new database from [databaseBundle].
 *
 * Tables are created by running the bundle's create queries. Opening fails if that creation never
 * happened, which likely means a database already existed at the configured path. Schema
 * validation throws in this mode.
 */
private class CreateOpenDelegate(val databaseBundle: DatabaseBundle) :
    TestOpenDelegate(databaseBundle) {
    // Set once createAllTables has executed every create query.
    private var tablesCreated = false

    override fun createAllTables(connection: SQLiteConnection) {
        for (createSql in databaseBundle.buildCreateQueries()) {
            connection.execSQL(createSql)
        }
        tablesCreated = true
    }

    override fun onOpen(connection: SQLiteConnection) =
        check(tablesCreated) {
            "Creation of tables didn't occur while creating a new database. A database at the " +
                "driver configured path likely already exists. Did you forget to delete it?"
        }

    override fun onValidateSchema(connection: SQLiteConnection): ValidationResult =
        error("Validation of schemas should never occur while creating a new database.")
}
255 
/**
 * Open delegate used when migrating an existing database and validating the result.
 *
 * [onValidateSchema] compares every entity, FTS entity and view described by [databaseBundle]
 * against what is read from the database. When [validateUnknownTables] is `true`, it also fails
 * if the database contains a table that the bundle does not describe. Room's master table,
 * `sqlite_sequence`, `android_metadata` and FTS shadow tables are exempt from that check.
 *
 * [createAllTables] throws because the database must already exist in this mode.
 */
private class MigrateOpenDelegate(
    val databaseBundle: DatabaseBundle,
    val validateUnknownTables: Boolean
) : TestOpenDelegate(databaseBundle) {
    override fun onValidateSchema(connection: SQLiteConnection): ValidationResult {
        val tables = databaseBundle.entitiesByTableName
        tables.values.forEach { entity ->
            when (entity) {
                is EntityBundle -> {
                    val expected = entity.toTableInfo()
                    val found = TableInfo.read(connection, entity.tableName)
                    if (expected != found) {
                        return mismatchResult(expected.name, expected, found)
                    }
                }
                is FtsEntityBundle -> {
                    val expected = entity.toFtsTableInfo()
                    val found = FtsTableInfo.read(connection, entity.tableName)
                    if (expected != found) {
                        return mismatchResult(expected.name, expected, found)
                    }
                }
            }
        }
        databaseBundle.views.forEach { view ->
            val expected = view.toViewInfo()
            val found = ViewInfo.read(connection, view.viewName)
            if (expected != found) {
                // Views print the expected/found values on the same line as their labels.
                return mismatchResult(expected.name, expected, found, valueSeparator = " ")
            }
        }
        if (validateUnknownTables) {
            val expectedTables = buildSet {
                tables.values.forEach { entity ->
                    add(entity.tableName)
                    if (entity is FtsEntityBundle) {
                        // FTS tables come with shadow tables, which are expected as well.
                        addAll(entity.shadowTableNames)
                    }
                }
            }
            connection
                .prepare(
                    """
                SELECT name FROM sqlite_master
                WHERE type = 'table' AND name NOT IN (?, ?, ?)
                """
                        .trimIndent()
                )
                .use { statement ->
                    statement.bindText(1, Room.MASTER_TABLE_NAME)
                    statement.bindText(2, "sqlite_sequence")
                    statement.bindText(3, "android_metadata")
                    while (statement.step()) {
                        val tableName = statement.getText(0)
                        if (!expectedTables.contains(tableName)) {
                            return ValidationResult(
                                isValid = false,
                                expectedFoundMsg = "Unexpected table $tableName"
                            )
                        }
                    }
                }
        }
        return ValidationResult(true, null)
    }

    override fun createAllTables(connection: SQLiteConnection) {
        error(
            "Creation of tables should never occur while validating migrations. Did you forget " +
                "to first create the database?"
        )
    }

    /**
     * Builds a failed [ValidationResult] for a schema object whose expected definition differs
     * from the one found in the database.
     *
     * The message has the form `" <name>\n\nExpected:<sep><expected>\n\nFound:<sep><found>"`.
     * [expected] and [found] are appended verbatim instead of being interpolated into a
     * `trimMargin()` template. `trimMargin()` would also strip any of their `toString()` lines
     * that start with whitespace followed by `|`, such as a view SQL line beginning with `||`.
     *
     * @param name Name of the table, FTS table or view; trailing whitespace is trimmed.
     * @param expected The definition described by the exported schema.
     * @param found The definition read from the database.
     * @param valueSeparator Text placed between each label and its value.
     */
    private fun mismatchResult(
        name: String,
        expected: Any?,
        found: Any?,
        valueSeparator: String = "\n\n"
    ): ValidationResult =
        ValidationResult(
            isValid = false,
            expectedFoundMsg =
                buildString {
                    append(' ').append(name.trimEnd())
                    append("\n\nExpected:").append(valueSeparator).append(expected)
                    append("\n\nFound:").append(valueSeparator).append(found)
                }
        )
}
367