1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.federatedcompute.services.data; 18 19 import static com.android.federatedcompute.services.data.FederatedTraningTaskContract.FEDERATED_TRAINING_TASKS_TABLE; 20 21 import android.annotation.NonNull; 22 import android.annotation.Nullable; 23 import android.content.Context; 24 import android.database.sqlite.SQLiteDatabase; 25 import android.database.sqlite.SQLiteException; 26 import android.database.sqlite.SQLiteOpenHelper; 27 import android.util.Log; 28 29 import com.android.federatedcompute.services.data.FederatedTraningTaskContract.FederatedTrainingTaskColumns; 30 31 import com.google.common.annotations.VisibleForTesting; 32 import com.google.common.collect.Iterables; 33 34 import java.util.List; 35 36 /** DAO for accessing training task table. */ 37 public class FederatedTrainingTaskDao { 38 39 private static final String TAG = "FederatedTrainingTaskDao"; 40 41 private final SQLiteOpenHelper mDbHelper; 42 private static FederatedTrainingTaskDao sSingletonInstance; 43 FederatedTrainingTaskDao(SQLiteOpenHelper dbHelper)44 private FederatedTrainingTaskDao(SQLiteOpenHelper dbHelper) { 45 this.mDbHelper = dbHelper; 46 } 47 48 /** Returns an instance of the FederatedTrainingTaskDao given a context. */ 49 @NonNull getInstance(Context context)50 public static FederatedTrainingTaskDao getInstance(Context context) { 51 synchronized (FederatedTrainingTaskDao.class) { 52 if (sSingletonInstance == null) { 53 sSingletonInstance = 54 new FederatedTrainingTaskDao( 55 FederatedTrainingTaskDbHelper.getInstance(context)); 56 } 57 return sSingletonInstance; 58 } 59 } 60 61 /** It's only public to unit test. */ 62 @VisibleForTesting getInstanceForTest(Context context)63 public static FederatedTrainingTaskDao getInstanceForTest(Context context) { 64 synchronized (FederatedTrainingTaskDao.class) { 65 if (sSingletonInstance == null) { 66 FederatedTrainingTaskDbHelper dbHelper = 67 FederatedTrainingTaskDbHelper.getInstanceForTest(context); 68 sSingletonInstance = new FederatedTrainingTaskDao(dbHelper); 69 } 70 return sSingletonInstance; 71 } 72 } 73 74 /** Deletes a training task in FederatedTrainingTask table. */ deleteFederatedTrainingTask(String selection, String[] selectionArgs)75 private void deleteFederatedTrainingTask(String selection, String[] selectionArgs) { 76 SQLiteDatabase db = getWritableDatabase(); 77 if (db == null) { 78 return; 79 } 80 db.delete(FEDERATED_TRAINING_TASKS_TABLE, selection, selectionArgs); 81 } 82 83 /** Insert a training task or update it if task already exists. */ updateOrInsertFederatedTrainingTask(FederatedTrainingTask trainingTask)84 public boolean updateOrInsertFederatedTrainingTask(FederatedTrainingTask trainingTask) { 85 SQLiteDatabase db = getWritableDatabase(); 86 if (db == null) { 87 throw new SQLiteException("Failed to open database."); 88 } 89 return trainingTask.addToDatabase(db); 90 } 91 92 /** Get the list of tasks that match select conditions. */ 93 @Nullable getFederatedTrainingTask( String selection, String[] selectionArgs)94 public List<FederatedTrainingTask> getFederatedTrainingTask( 95 String selection, String[] selectionArgs) { 96 SQLiteDatabase db = mDbHelper.getReadableDatabase(); 97 if (db == null) { 98 return null; 99 } 100 return FederatedTrainingTask.readFederatedTrainingTasksFromDatabase( 101 db, selection, selectionArgs); 102 } 103 104 /** Delete a task from table based on job scheduler id. */ findAndRemoveTaskByJobId(int jobId)105 public FederatedTrainingTask findAndRemoveTaskByJobId(int jobId) { 106 String selection = FederatedTrainingTaskColumns.JOB_SCHEDULER_JOB_ID + " = ?"; 107 String[] selectionArgs = selectionArgs(jobId); 108 FederatedTrainingTask task = 109 Iterables.getOnlyElement(getFederatedTrainingTask(selection, selectionArgs), null); 110 if (task != null) { 111 deleteFederatedTrainingTask(selection, selectionArgs); 112 } 113 return task; 114 } 115 116 /** Delete a task from table based on population name. */ findAndRemoveTaskByPopulationName(String populationName)117 public FederatedTrainingTask findAndRemoveTaskByPopulationName(String populationName) { 118 String selection = FederatedTrainingTaskColumns.POPULATION_NAME + " = ?"; 119 String[] selectionArgs = {populationName}; 120 FederatedTrainingTask task = 121 Iterables.getOnlyElement(getFederatedTrainingTask(selection, selectionArgs), null); 122 if (task != null) { 123 deleteFederatedTrainingTask(selection, selectionArgs); 124 } 125 return task; 126 } 127 128 /** Delete a task from table based on population name and job scheduler id. */ findAndRemoveTaskByPopulationAndJobId( String populationName, int jobId)129 public FederatedTrainingTask findAndRemoveTaskByPopulationAndJobId( 130 String populationName, int jobId) { 131 String selection = 132 FederatedTrainingTaskColumns.POPULATION_NAME 133 + " = ? AND " 134 + FederatedTrainingTaskColumns.JOB_SCHEDULER_JOB_ID 135 + " = ?"; 136 String[] selectionArgs = {populationName, String.valueOf(jobId)}; 137 FederatedTrainingTask task = 138 Iterables.getOnlyElement(getFederatedTrainingTask(selection, selectionArgs), null); 139 if (task != null) { 140 deleteFederatedTrainingTask(selection, selectionArgs); 141 } 142 return task; 143 } 144 selectionArgs(Number... args)145 private String[] selectionArgs(Number... args) { 146 String[] values = new String[args.length]; 147 for (int i = 0; i < args.length; i++) { 148 values[i] = String.valueOf(args[i]); 149 } 150 return values; 151 } 152 153 /** It's only public to unit test. Clears all records in task table. */ 154 @VisibleForTesting clearDatabase()155 public boolean clearDatabase() { 156 SQLiteDatabase db = getWritableDatabase(); 157 if (db == null) { 158 return false; 159 } 160 db.delete(FEDERATED_TRAINING_TASKS_TABLE, null, null); 161 return true; 162 } 163 164 /* Returns a writable database object or null if error occurs. */ 165 @Nullable getWritableDatabase()166 private SQLiteDatabase getWritableDatabase() { 167 try { 168 return mDbHelper.getWritableDatabase(); 169 } catch (SQLiteException e) { 170 Log.e(TAG, "Failed to open the database.", e); 171 } 172 return null; 173 } 174 } 175