1 /* 2 * Copyright (C) 2022 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.adservices.data; 18 19 import static org.junit.Assert.assertArrayEquals; 20 import static org.junit.Assert.assertEquals; 21 22 import android.content.Context; 23 import android.database.Cursor; 24 import android.database.sqlite.SQLiteDatabase; 25 import android.util.Log; 26 27 import androidx.test.core.app.ApplicationProvider; 28 29 import com.android.adservices.data.measurement.MeasurementDbHelper; 30 import com.android.adservices.data.shared.SharedDbHelper; 31 32 import com.google.common.collect.ImmutableSet; 33 34 import java.util.ArrayList; 35 import java.util.Collections; 36 import java.util.List; 37 import java.util.Set; 38 39 public final class DbTestUtil { 40 private static final Context sContext = ApplicationProvider.getApplicationContext(); 41 private static final String DATABASE_NAME_FOR_TEST = "adservices_test.db"; 42 private static final String MSMT_DATABASE_NAME_FOR_TEST = "adservices_msmt_test.db"; 43 private static final String SHARED_DATABASE_NAME_FOR_TEST = "adservices_shared_test.db"; 44 45 private static DbHelper sSingleton; 46 private static MeasurementDbHelper sMsmtSingleton; 47 private static SharedDbHelper sSharedSingleton; 48 49 /** Erases all data from the table rows */ deleteTable(String tableName)50 public static void deleteTable(String tableName) { 51 SQLiteDatabase db = getDbHelperForTest().safeGetWritableDatabase(); 52 if (db == null) { 53 return; 54 } 55 56 db.delete(tableName, /* whereClause= */ null, /* whereArgs= */ null); 57 } 58 59 /** 60 * Create an instance of database instance for testing. 61 * 62 * @return a test database 63 */ getDbHelperForTest()64 public static DbHelper getDbHelperForTest() { 65 synchronized (DbHelper.class) { 66 if (sSingleton == null) { 67 sSingleton = 68 new DbHelper(sContext, DATABASE_NAME_FOR_TEST, DbHelper.DATABASE_VERSION); 69 } 70 return sSingleton; 71 } 72 } 73 getMeasurementDbHelperForTest()74 public static MeasurementDbHelper getMeasurementDbHelperForTest() { 75 synchronized (MeasurementDbHelper.class) { 76 if (sMsmtSingleton == null) { 77 sMsmtSingleton = 78 new MeasurementDbHelper( 79 sContext, 80 MSMT_DATABASE_NAME_FOR_TEST, 81 MeasurementDbHelper.CURRENT_DATABASE_VERSION, 82 getDbHelperForTest()); 83 } 84 return sMsmtSingleton; 85 } 86 } 87 getSharedDbHelperForTest()88 public static SharedDbHelper getSharedDbHelperForTest() { 89 synchronized (SharedDbHelper.class) { 90 if (sSharedSingleton == null) { 91 sSharedSingleton = 92 new SharedDbHelper( 93 sContext, 94 SHARED_DATABASE_NAME_FOR_TEST, 95 SharedDbHelper.CURRENT_DATABASE_VERSION, 96 getDbHelperForTest()); 97 } 98 return sSharedSingleton; 99 } 100 } 101 102 /** Return true if table exists in the DB and column count matches. */ doesTableExistAndColumnCountMatch( SQLiteDatabase db, String tableName, int columnCount)103 public static boolean doesTableExistAndColumnCountMatch( 104 SQLiteDatabase db, String tableName, int columnCount) { 105 final Set<String> tableColumns = getTableColumns(db, tableName); 106 int actualCol = tableColumns.size(); 107 Log.d("DbTestUtil_log_test,", " table name: " + tableName + " column count: " + actualCol); 108 return tableColumns.size() == columnCount; 109 } 110 111 /** Returns column names of the table. */ getTableColumns(SQLiteDatabase db, String tableName)112 public static Set<String> getTableColumns(SQLiteDatabase db, String tableName) { 113 String query = 114 "select p.name from sqlite_master s " 115 + "join pragma_table_info(s.name) p " 116 + "where s.tbl_name = '" 117 + tableName 118 + "'"; 119 Cursor cursor = db.rawQuery(query, null); 120 if (cursor == null) { 121 throw new IllegalArgumentException("Cursor is null."); 122 } 123 124 ImmutableSet.Builder<String> tableColumns = ImmutableSet.builder(); 125 while (cursor.moveToNext()) { 126 tableColumns.add(cursor.getString(0)); 127 } 128 129 return tableColumns.build(); 130 } 131 132 /** Return true if the given index exists in the DB. */ doesIndexExist(SQLiteDatabase db, String index)133 public static boolean doesIndexExist(SQLiteDatabase db, String index) { 134 String query = "SELECT * FROM sqlite_master WHERE type='index' and name='" + index + "'"; 135 Cursor cursor = db.rawQuery(query, null); 136 return cursor != null && cursor.getCount() > 0; 137 } 138 doesTableExist(SQLiteDatabase db, String table)139 public static boolean doesTableExist(SQLiteDatabase db, String table) { 140 String query = "SELECT * FROM sqlite_master WHERE type='table' and name='" + table + "'"; 141 Cursor cursor = db.rawQuery(query, null); 142 return cursor != null && cursor.getCount() > 0; 143 } 144 145 /** Return test database name */ getDatabaseNameForTest()146 public static String getDatabaseNameForTest() { 147 return DATABASE_NAME_FOR_TEST; 148 } 149 assertDatabasesEqual(SQLiteDatabase expectedDb, SQLiteDatabase actualDb)150 public static void assertDatabasesEqual(SQLiteDatabase expectedDb, SQLiteDatabase actualDb) { 151 List<String> expectedTables = getTables(expectedDb); 152 List<String> actualTables = getTables(actualDb); 153 assertArrayEquals(expectedTables.toArray(), actualTables.toArray()); 154 assertTableSchemaEqual(expectedDb, actualDb, expectedTables); 155 assertIndexesEqual(expectedDb, actualDb, expectedTables); 156 } 157 assertTableSchemaEqual( SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tableNames)158 private static void assertTableSchemaEqual( 159 SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tableNames) { 160 for (String tableName : tableNames) { 161 Cursor columnsCursorExpected = 162 expectedDb.rawQuery("PRAGMA TABLE_INFO(" + tableName + ")", null); 163 Cursor columnsCursorActual = 164 actualDb.rawQuery("PRAGMA TABLE_INFO(" + tableName + ")", null); 165 assertEquals( 166 "Table columns mismatch for " + tableName, 167 columnsCursorExpected.getCount(), 168 columnsCursorActual.getCount()); 169 170 // Checks the columns in order. Newly created columns should be inserted as the end. 171 while (columnsCursorExpected.moveToNext() && columnsCursorActual.moveToNext()) { 172 assertEquals( 173 "Column mismatch for " + tableName, 174 columnsCursorExpected.getString( 175 columnsCursorExpected.getColumnIndex("name")), 176 columnsCursorActual.getString(columnsCursorActual.getColumnIndex("name"))); 177 assertEquals( 178 "Column mismatch for " + tableName, 179 columnsCursorExpected.getString( 180 columnsCursorExpected.getColumnIndex("type")), 181 columnsCursorActual.getString(columnsCursorActual.getColumnIndex("type"))); 182 assertEquals( 183 "Column mismatch for " + tableName, 184 columnsCursorExpected.getInt( 185 columnsCursorExpected.getColumnIndex("notnull")), 186 columnsCursorActual.getInt(columnsCursorActual.getColumnIndex("notnull"))); 187 assertEquals( 188 "Column mismatch for " + tableName, 189 columnsCursorExpected.getString( 190 columnsCursorExpected.getColumnIndex("dflt_value")), 191 columnsCursorActual.getString( 192 columnsCursorActual.getColumnIndex("dflt_value"))); 193 assertEquals( 194 "Column mismatch for " + tableName, 195 columnsCursorExpected.getInt(columnsCursorExpected.getColumnIndex("pk")), 196 columnsCursorActual.getInt(columnsCursorActual.getColumnIndex("pk"))); 197 } 198 199 columnsCursorExpected.close(); 200 columnsCursorActual.close(); 201 } 202 } 203 getTables(SQLiteDatabase db)204 private static List<String> getTables(SQLiteDatabase db) { 205 String listTableQuery = "SELECT name FROM sqlite_master where type = 'table'"; 206 List<String> tables = new ArrayList<>(); 207 try (Cursor cursor = db.rawQuery(listTableQuery, null)) { 208 while (cursor.moveToNext()) { 209 tables.add(cursor.getString(cursor.getColumnIndex("name"))); 210 } 211 } 212 Collections.sort(tables); 213 return tables; 214 } 215 assertIndexesEqual( SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tables)216 private static void assertIndexesEqual( 217 SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tables) { 218 for (String tableName : tables) { 219 String indexListQuery = 220 "SELECT name FROM sqlite_master where type = 'index' AND tbl_name = '" 221 + tableName 222 + "' ORDER BY name ASC"; 223 Cursor indexListCursorExpected = expectedDb.rawQuery(indexListQuery, null); 224 Cursor indexListCursorActual = actualDb.rawQuery(indexListQuery, null); 225 assertEquals( 226 "Table indexes mismatch for " + tableName, 227 indexListCursorExpected.getCount(), 228 indexListCursorActual.getCount()); 229 230 while (indexListCursorExpected.moveToNext() && indexListCursorActual.moveToNext()) { 231 String expectedIndexName = 232 indexListCursorExpected.getString( 233 indexListCursorExpected.getColumnIndex("name")); 234 assertEquals( 235 "Index mismatch for " + tableName, 236 expectedIndexName, 237 indexListCursorActual.getString( 238 indexListCursorActual.getColumnIndex("name"))); 239 240 assertIndexInfoEqual(expectedDb, actualDb, expectedIndexName); 241 } 242 243 indexListCursorExpected.close(); 244 indexListCursorActual.close(); 245 } 246 } 247 assertIndexInfoEqual( SQLiteDatabase expectedDb, SQLiteDatabase actualDb, String indexName)248 private static void assertIndexInfoEqual( 249 SQLiteDatabase expectedDb, SQLiteDatabase actualDb, String indexName) { 250 Cursor indexInfoCursorExpected = 251 expectedDb.rawQuery("PRAGMA main.INDEX_INFO (" + indexName + ")", null); 252 Cursor indexInfoCursorActual = 253 actualDb.rawQuery("PRAGMA main.INDEX_INFO (" + indexName + ")", null); 254 assertEquals( 255 "Index columns count mismatch for " + indexName, 256 indexInfoCursorExpected.getCount(), 257 indexInfoCursorActual.getCount()); 258 259 while (indexInfoCursorExpected.moveToNext() && indexInfoCursorActual.moveToNext()) { 260 assertEquals( 261 "Index info mismatch for " + indexName, 262 indexInfoCursorExpected.getInt(indexInfoCursorExpected.getColumnIndex("seqno")), 263 indexInfoCursorActual.getInt(indexInfoCursorActual.getColumnIndex("seqno"))); 264 assertEquals( 265 "Index info mismatch for " + indexName, 266 indexInfoCursorExpected.getInt(indexInfoCursorExpected.getColumnIndex("cid")), 267 indexInfoCursorActual.getInt(indexInfoCursorActual.getColumnIndex("cid"))); 268 assertEquals( 269 "Index info mismatch for " + indexName, 270 indexInfoCursorExpected.getString( 271 indexInfoCursorExpected.getColumnIndex("name")), 272 indexInfoCursorActual.getString(indexInfoCursorActual.getColumnIndex("name"))); 273 } 274 275 indexInfoCursorExpected.close(); 276 indexInfoCursorActual.close(); 277 } 278 } 279