1 /*
2  * Copyright 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package androidx.sqlite.inspection
18 
19 import android.database.sqlite.SQLiteDatabase
20 import androidx.inspection.ArtTooling
21 import androidx.inspection.testing.DefaultTestInspectorEnvironment
22 import androidx.inspection.testing.InspectorTester
23 import androidx.inspection.testing.TestInspectorExecutors
24 import androidx.sqlite.db.SupportSQLiteDatabase
25 import androidx.sqlite.inspection.SqliteInspectorProtocol.Command
26 import androidx.sqlite.inspection.SqliteInspectorProtocol.QueryCommand
27 import androidx.sqlite.inspection.SqliteInspectorProtocol.TrackDatabasesCommand
28 import androidx.sqlite.inspection.sqldeligttestapp.Database
29 import androidx.sqlite.inspection.test.TestEntity
30 import androidx.test.core.app.ApplicationProvider
31 import androidx.test.ext.junit.runners.AndroidJUnit4
32 import com.google.common.truth.Truth.assertThat
33 import com.google.common.truth.Truth.assertWithMessage
34 import com.squareup.sqldelight.Query
35 import com.squareup.sqldelight.android.AndroidSqliteDriver
36 import com.squareup.sqldelight.db.SqlDriver
37 import com.squareup.sqldelight.runtime.coroutines.asFlow
38 import com.squareup.sqldelight.runtime.coroutines.mapToList
39 import kotlinx.coroutines.ExperimentalCoroutinesApi
40 import kotlinx.coroutines.FlowPreview
41 import kotlinx.coroutines.Job
42 import kotlinx.coroutines.flow.produceIn
43 import kotlinx.coroutines.flow.take
44 import kotlinx.coroutines.runBlocking
45 import org.junit.After
46 import org.junit.Before
47 import org.junit.Test
48 import org.junit.runner.RunWith
49 
/**
 * Checks that a write performed through the SQLite inspector's query command causes a
 * SqlDelight [Query] flow to re-emit, i.e. that the inspector notifies SqlDelight queries of
 * data changes it makes.
 */
@ExperimentalCoroutinesApi
@RunWith(AndroidJUnit4::class)
@FlowPreview
class SqlDelightInvalidationTest {

    lateinit var driver: SqlDriver

    // Captured in the driver's onCreate callback; only valid once the database has been opened.
    lateinit var openedDb: SupportSQLiteDatabase

    @Before
    fun setup() {
        driver =
            AndroidSqliteDriver(
                schema = Database.Schema,
                context = ApplicationProvider.getApplicationContext(),
                callback =
                    object : AndroidSqliteDriver.Callback(Database.Schema) {
                        override fun onCreate(db: SupportSQLiteDatabase) {
                            openedDb = db
                            super.onCreate(db)
                        }
                    }
            )
    }

    /**
     * Inserts a row through the inspector and expects the observed `selectAll` flow to emit a
     * second result list containing both the original and the inspector-inserted row.
     */
    @Test
    fun test() {
        runBlocking {
            val dao = Database(driver).testEntityQueries
            // This insert is the first database access; openedDb is expected to be populated by
            // the onCreate callback by the time it returns.
            dao.insertOrReplace("one")
            val sqliteDb = openedDb.getSqliteDb()
            val query = dao.selectAll()
            val job = checkNotNull(this.coroutineContext[Job]) { "runBlocking scope has no Job" }
            val tester =
                InspectorTester(
                    inspectorId = "androidx.sqlite.inspection",
                    environment =
                        DefaultTestInspectorEnvironment(
                            TestInspectorExecutors(job),
                            TestArtTooling(sqliteDb, listOf(query))
                        )
                )
            try {
                // take(2): the initial result plus exactly one re-emission after the insert.
                val updates = query.asFlow().mapToList().take(2).produceIn(this)

                val firstExpected = TestEntity.Impl(1, "one")
                val secondExpected = TestEntity.Impl(2, "foo")
                assertThat(updates.receive()).isEqualTo(listOf(firstExpected))

                val startTrackingCommand =
                    Command.newBuilder()
                        .setTrackDatabases(TrackDatabasesCommand.getDefaultInstance())
                        .build()

                tester.sendCommand(startTrackingCommand.toByteArray())

                val insertQuery = """INSERT INTO TestEntity VALUES(2, "foo")"""
                // NOTE(review): assumes the single database exposed by TestArtTooling is assigned
                // id 1 by the inspector once tracking starts.
                val insertCommand =
                    Command.newBuilder()
                        .setQuery(
                            QueryCommand.newBuilder()
                                .setDatabaseId(1)
                                .setQuery(insertQuery)
                                .build()
                        )
                        .build()
                val responseBytes = tester.sendCommand(insertCommand.toByteArray())
                val response = SqliteInspectorProtocol.Response.parseFrom(responseBytes)
                assertWithMessage("test sanity, insert query should succeed")
                    .that(response.hasErrorOccurred())
                    .isFalse()

                assertThat(updates.receive()).isEqualTo(listOf(firstExpected, secondExpected))
            } finally {
                // Release the inspector even when an assertion above fails.
                tester.dispose()
            }
        }
    }

    @After
    fun tearDown() {
        driver.close()
    }
}
126 
/**
 * Reflectively unwraps the framework [SQLiteDatabase] held by this [SupportSQLiteDatabase].
 *
 * Reads the field named `mDelegate` declared on the receiver's runtime class. Throws if that
 * field does not exist or does not hold a [SQLiteDatabase].
 */
private fun SupportSQLiteDatabase.getSqliteDb(): SQLiteDatabase {
    // The driver runs with default settings, so the framework database can be pulled out of the
    // support wrapper; needed until inspection accepts support instances directly.
    val delegateField = javaClass.getDeclaredField("mDelegate")
    delegateField.isAccessible = true
    val delegate = delegateField.get(this)
    return delegate as SQLiteDatabase
}
136 
/**
 * Fake [ArtTooling] that exposes a fixed set of objects to the inspector under test.
 *
 * [findInstances] returns every known object ([sqliteDb] and each element of [queries]) that
 * is an instance of the requested class. Hook registration is a no-op, so registered hooks are
 * never invoked.
 *
 * @param sqliteDb framework database reported by [findInstances].
 * @param queries SqlDelight queries reported by [findInstances].
 */
class TestArtTooling(private val sqliteDb: SQLiteDatabase, private val queries: List<Query<*>>) :
    ArtTooling {
    override fun registerEntryHook(
        originClass: Class<*>,
        originMethod: String,
        entryHook: ArtTooling.EntryHook
    ) {
        // no-op: instances are supplied directly through findInstances.
    }

    override fun <T : Any?> findInstances(clazz: Class<T>): List<T> {
        // Test each known object against the requested type instead of comparing classes, so
        // that shared supertypes (e.g. Any) yield all matching objects and subtypes of Query or
        // SQLiteDatabase still match the objects that actually are of that subtype.
        val candidates: List<Any> = queries + sqliteDb
        return candidates.filter(clazz::isInstance).map(clazz::cast)
    }

    override fun <T : Any?> registerExitHook(
        originClass: Class<*>,
        originMethod: String,
        exitHook: ArtTooling.ExitHook<T>
    ) {
        // no-op: instances are supplied directly through findInstances.
    }
}
165