• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.adservices.data;
18 
19 import static org.junit.Assert.assertArrayEquals;
20 import static org.junit.Assert.assertEquals;
21 
22 import android.content.Context;
23 import android.database.Cursor;
24 import android.database.sqlite.SQLiteDatabase;
25 import android.util.Log;
26 
27 import androidx.test.core.app.ApplicationProvider;
28 
29 import com.android.adservices.data.measurement.MeasurementDbHelper;
30 import com.android.adservices.data.shared.SharedDbHelper;
31 
32 import com.google.common.collect.ImmutableSet;
33 
34 import java.util.ArrayList;
35 import java.util.Collections;
36 import java.util.List;
37 import java.util.Set;
38 
39 public final class DbTestUtil {
40     private static final Context sContext = ApplicationProvider.getApplicationContext();
41     private static final String DATABASE_NAME_FOR_TEST = "adservices_test.db";
42     private static final String MSMT_DATABASE_NAME_FOR_TEST = "adservices_msmt_test.db";
43     private static final String SHARED_DATABASE_NAME_FOR_TEST = "adservices_shared_test.db";
44 
45     private static DbHelper sSingleton;
46     private static MeasurementDbHelper sMsmtSingleton;
47     private static SharedDbHelper sSharedSingleton;
48 
49     /** Erases all data from the table rows */
deleteTable(String tableName)50     public static void deleteTable(String tableName) {
51         SQLiteDatabase db = getDbHelperForTest().safeGetWritableDatabase();
52         if (db == null) {
53             return;
54         }
55 
56         db.delete(tableName, /* whereClause= */ null, /* whereArgs= */ null);
57     }
58 
59     /**
60      * Create an instance of database instance for testing.
61      *
62      * @return a test database
63      */
getDbHelperForTest()64     public static DbHelper getDbHelperForTest() {
65         synchronized (DbHelper.class) {
66             if (sSingleton == null) {
67                 sSingleton =
68                         new DbHelper(sContext, DATABASE_NAME_FOR_TEST, DbHelper.DATABASE_VERSION);
69             }
70             return sSingleton;
71         }
72     }
73 
getMeasurementDbHelperForTest()74     public static MeasurementDbHelper getMeasurementDbHelperForTest() {
75         synchronized (MeasurementDbHelper.class) {
76             if (sMsmtSingleton == null) {
77                 sMsmtSingleton =
78                         new MeasurementDbHelper(
79                                 sContext,
80                                 MSMT_DATABASE_NAME_FOR_TEST,
81                                 MeasurementDbHelper.CURRENT_DATABASE_VERSION,
82                                 getDbHelperForTest());
83             }
84             return sMsmtSingleton;
85         }
86     }
87 
getSharedDbHelperForTest()88     public static SharedDbHelper getSharedDbHelperForTest() {
89         synchronized (SharedDbHelper.class) {
90             if (sSharedSingleton == null) {
91                 sSharedSingleton =
92                         new SharedDbHelper(
93                                 sContext,
94                                 SHARED_DATABASE_NAME_FOR_TEST,
95                                 SharedDbHelper.CURRENT_DATABASE_VERSION,
96                                 getDbHelperForTest());
97             }
98             return sSharedSingleton;
99         }
100     }
101 
102     /** Return true if table exists in the DB and column count matches. */
doesTableExistAndColumnCountMatch( SQLiteDatabase db, String tableName, int columnCount)103     public static boolean doesTableExistAndColumnCountMatch(
104             SQLiteDatabase db, String tableName, int columnCount) {
105         final Set<String> tableColumns = getTableColumns(db, tableName);
106         int actualCol = tableColumns.size();
107         Log.d("DbTestUtil_log_test,", " table name: " + tableName + " column count: " + actualCol);
108         return tableColumns.size() == columnCount;
109     }
110 
111     /** Returns column names of the table. */
getTableColumns(SQLiteDatabase db, String tableName)112     public static Set<String> getTableColumns(SQLiteDatabase db, String tableName) {
113         String query =
114                 "select p.name from sqlite_master s "
115                         + "join pragma_table_info(s.name) p "
116                         + "where s.tbl_name = '"
117                         + tableName
118                         + "'";
119         Cursor cursor = db.rawQuery(query, null);
120         if (cursor == null) {
121             throw new IllegalArgumentException("Cursor is null.");
122         }
123 
124         ImmutableSet.Builder<String> tableColumns = ImmutableSet.builder();
125         while (cursor.moveToNext()) {
126             tableColumns.add(cursor.getString(0));
127         }
128 
129         return tableColumns.build();
130     }
131 
132     /** Return true if the given index exists in the DB. */
doesIndexExist(SQLiteDatabase db, String index)133     public static boolean doesIndexExist(SQLiteDatabase db, String index) {
134         String query = "SELECT * FROM sqlite_master WHERE type='index' and name='" + index + "'";
135         Cursor cursor = db.rawQuery(query, null);
136         return cursor != null && cursor.getCount() > 0;
137     }
138 
doesTableExist(SQLiteDatabase db, String table)139     public static boolean doesTableExist(SQLiteDatabase db, String table) {
140         String query = "SELECT * FROM sqlite_master WHERE type='table' and name='" + table + "'";
141         Cursor cursor = db.rawQuery(query, null);
142         return cursor != null && cursor.getCount() > 0;
143     }
144 
145     /** Return test database name */
getDatabaseNameForTest()146     public static String getDatabaseNameForTest() {
147         return DATABASE_NAME_FOR_TEST;
148     }
149 
assertDatabasesEqual(SQLiteDatabase expectedDb, SQLiteDatabase actualDb)150     public static void assertDatabasesEqual(SQLiteDatabase expectedDb, SQLiteDatabase actualDb) {
151         List<String> expectedTables = getTables(expectedDb);
152         List<String> actualTables = getTables(actualDb);
153         assertArrayEquals(expectedTables.toArray(), actualTables.toArray());
154         assertTableSchemaEqual(expectedDb, actualDb, expectedTables);
155         assertIndexesEqual(expectedDb, actualDb, expectedTables);
156     }
157 
assertTableSchemaEqual( SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tableNames)158     private static void assertTableSchemaEqual(
159             SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tableNames) {
160         for (String tableName : tableNames) {
161             Cursor columnsCursorExpected =
162                     expectedDb.rawQuery("PRAGMA TABLE_INFO(" + tableName + ")", null);
163             Cursor columnsCursorActual =
164                     actualDb.rawQuery("PRAGMA TABLE_INFO(" + tableName + ")", null);
165             assertEquals(
166                     "Table columns mismatch for " + tableName,
167                     columnsCursorExpected.getCount(),
168                     columnsCursorActual.getCount());
169 
170             // Checks the columns in order. Newly created columns should be inserted as the end.
171             while (columnsCursorExpected.moveToNext() && columnsCursorActual.moveToNext()) {
172                 assertEquals(
173                         "Column mismatch for " + tableName,
174                         columnsCursorExpected.getString(
175                                 columnsCursorExpected.getColumnIndex("name")),
176                         columnsCursorActual.getString(columnsCursorActual.getColumnIndex("name")));
177                 assertEquals(
178                         "Column mismatch for " + tableName,
179                         columnsCursorExpected.getString(
180                                 columnsCursorExpected.getColumnIndex("type")),
181                         columnsCursorActual.getString(columnsCursorActual.getColumnIndex("type")));
182                 assertEquals(
183                         "Column mismatch for " + tableName,
184                         columnsCursorExpected.getInt(
185                                 columnsCursorExpected.getColumnIndex("notnull")),
186                         columnsCursorActual.getInt(columnsCursorActual.getColumnIndex("notnull")));
187                 assertEquals(
188                         "Column mismatch for " + tableName,
189                         columnsCursorExpected.getString(
190                                 columnsCursorExpected.getColumnIndex("dflt_value")),
191                         columnsCursorActual.getString(
192                                 columnsCursorActual.getColumnIndex("dflt_value")));
193                 assertEquals(
194                         "Column mismatch for " + tableName,
195                         columnsCursorExpected.getInt(columnsCursorExpected.getColumnIndex("pk")),
196                         columnsCursorActual.getInt(columnsCursorActual.getColumnIndex("pk")));
197             }
198 
199             columnsCursorExpected.close();
200             columnsCursorActual.close();
201         }
202     }
203 
getTables(SQLiteDatabase db)204     private static List<String> getTables(SQLiteDatabase db) {
205         String listTableQuery = "SELECT name FROM sqlite_master where type = 'table'";
206         List<String> tables = new ArrayList<>();
207         try (Cursor cursor = db.rawQuery(listTableQuery, null)) {
208             while (cursor.moveToNext()) {
209                 tables.add(cursor.getString(cursor.getColumnIndex("name")));
210             }
211         }
212         Collections.sort(tables);
213         return tables;
214     }
215 
assertIndexesEqual( SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tables)216     private static void assertIndexesEqual(
217             SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tables) {
218         for (String tableName : tables) {
219             String indexListQuery =
220                     "SELECT name FROM sqlite_master where type = 'index' AND tbl_name = '"
221                             + tableName
222                             + "' ORDER BY name ASC";
223             Cursor indexListCursorExpected = expectedDb.rawQuery(indexListQuery, null);
224             Cursor indexListCursorActual = actualDb.rawQuery(indexListQuery, null);
225             assertEquals(
226                     "Table indexes mismatch for " + tableName,
227                     indexListCursorExpected.getCount(),
228                     indexListCursorActual.getCount());
229 
230             while (indexListCursorExpected.moveToNext() && indexListCursorActual.moveToNext()) {
231                 String expectedIndexName =
232                         indexListCursorExpected.getString(
233                                 indexListCursorExpected.getColumnIndex("name"));
234                 assertEquals(
235                         "Index mismatch for " + tableName,
236                         expectedIndexName,
237                         indexListCursorActual.getString(
238                                 indexListCursorActual.getColumnIndex("name")));
239 
240                 assertIndexInfoEqual(expectedDb, actualDb, expectedIndexName);
241             }
242 
243             indexListCursorExpected.close();
244             indexListCursorActual.close();
245         }
246     }
247 
assertIndexInfoEqual( SQLiteDatabase expectedDb, SQLiteDatabase actualDb, String indexName)248     private static void assertIndexInfoEqual(
249             SQLiteDatabase expectedDb, SQLiteDatabase actualDb, String indexName) {
250         Cursor indexInfoCursorExpected =
251                 expectedDb.rawQuery("PRAGMA main.INDEX_INFO (" + indexName + ")", null);
252         Cursor indexInfoCursorActual =
253                 actualDb.rawQuery("PRAGMA main.INDEX_INFO (" + indexName + ")", null);
254         assertEquals(
255                 "Index columns count mismatch for " + indexName,
256                 indexInfoCursorExpected.getCount(),
257                 indexInfoCursorActual.getCount());
258 
259         while (indexInfoCursorExpected.moveToNext() && indexInfoCursorActual.moveToNext()) {
260             assertEquals(
261                     "Index info mismatch for " + indexName,
262                     indexInfoCursorExpected.getInt(indexInfoCursorExpected.getColumnIndex("seqno")),
263                     indexInfoCursorActual.getInt(indexInfoCursorActual.getColumnIndex("seqno")));
264             assertEquals(
265                     "Index info mismatch for " + indexName,
266                     indexInfoCursorExpected.getInt(indexInfoCursorExpected.getColumnIndex("cid")),
267                     indexInfoCursorActual.getInt(indexInfoCursorActual.getColumnIndex("cid")));
268             assertEquals(
269                     "Index info mismatch for " + indexName,
270                     indexInfoCursorExpected.getString(
271                             indexInfoCursorExpected.getColumnIndex("name")),
272                     indexInfoCursorActual.getString(indexInfoCursorActual.getColumnIndex("name")));
273         }
274 
275         indexInfoCursorExpected.close();
276         indexInfoCursorActual.close();
277     }
278 }
279