• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.federatedcompute.services.data;
18 
19 import static com.android.federatedcompute.services.data.FederatedTraningTaskContract.FEDERATED_TRAINING_TASKS_TABLE;
20 
21 import android.annotation.NonNull;
22 import android.annotation.Nullable;
23 import android.content.Context;
24 import android.database.sqlite.SQLiteDatabase;
25 import android.database.sqlite.SQLiteException;
26 import android.database.sqlite.SQLiteOpenHelper;
27 import android.util.Log;
28 
29 import com.android.federatedcompute.services.data.FederatedTraningTaskContract.FederatedTrainingTaskColumns;
30 
31 import com.google.common.annotations.VisibleForTesting;
32 import com.google.common.collect.Iterables;
33 
34 import java.util.List;
35 
36 /** DAO for accessing training task table. */
37 public class FederatedTrainingTaskDao {
38 
39     private static final String TAG = "FederatedTrainingTaskDao";
40 
41     private final SQLiteOpenHelper mDbHelper;
42     private static FederatedTrainingTaskDao sSingletonInstance;
43 
FederatedTrainingTaskDao(SQLiteOpenHelper dbHelper)44     private FederatedTrainingTaskDao(SQLiteOpenHelper dbHelper) {
45         this.mDbHelper = dbHelper;
46     }
47 
48     /** Returns an instance of the FederatedTrainingTaskDao given a context. */
49     @NonNull
getInstance(Context context)50     public static FederatedTrainingTaskDao getInstance(Context context) {
51         synchronized (FederatedTrainingTaskDao.class) {
52             if (sSingletonInstance == null) {
53                 sSingletonInstance =
54                         new FederatedTrainingTaskDao(
55                                 FederatedTrainingTaskDbHelper.getInstance(context));
56             }
57             return sSingletonInstance;
58         }
59     }
60 
61     /** It's only public to unit test. */
62     @VisibleForTesting
getInstanceForTest(Context context)63     public static FederatedTrainingTaskDao getInstanceForTest(Context context) {
64         synchronized (FederatedTrainingTaskDao.class) {
65             if (sSingletonInstance == null) {
66                 FederatedTrainingTaskDbHelper dbHelper =
67                         FederatedTrainingTaskDbHelper.getInstanceForTest(context);
68                 sSingletonInstance = new FederatedTrainingTaskDao(dbHelper);
69             }
70             return sSingletonInstance;
71         }
72     }
73 
74     /** Deletes a training task in FederatedTrainingTask table. */
deleteFederatedTrainingTask(String selection, String[] selectionArgs)75     private void deleteFederatedTrainingTask(String selection, String[] selectionArgs) {
76         SQLiteDatabase db = getWritableDatabase();
77         if (db == null) {
78             return;
79         }
80         db.delete(FEDERATED_TRAINING_TASKS_TABLE, selection, selectionArgs);
81     }
82 
83     /** Insert a training task or update it if task already exists. */
updateOrInsertFederatedTrainingTask(FederatedTrainingTask trainingTask)84     public boolean updateOrInsertFederatedTrainingTask(FederatedTrainingTask trainingTask) {
85         SQLiteDatabase db = getWritableDatabase();
86         if (db == null) {
87             throw new SQLiteException("Failed to open database.");
88         }
89         return trainingTask.addToDatabase(db);
90     }
91 
92     /** Get the list of tasks that match select conditions. */
93     @Nullable
getFederatedTrainingTask( String selection, String[] selectionArgs)94     public List<FederatedTrainingTask> getFederatedTrainingTask(
95             String selection, String[] selectionArgs) {
96         SQLiteDatabase db = mDbHelper.getReadableDatabase();
97         if (db == null) {
98             return null;
99         }
100         return FederatedTrainingTask.readFederatedTrainingTasksFromDatabase(
101                 db, selection, selectionArgs);
102     }
103 
104     /** Delete a task from table based on job scheduler id. */
findAndRemoveTaskByJobId(int jobId)105     public FederatedTrainingTask findAndRemoveTaskByJobId(int jobId) {
106         String selection = FederatedTrainingTaskColumns.JOB_SCHEDULER_JOB_ID + " = ?";
107         String[] selectionArgs = selectionArgs(jobId);
108         FederatedTrainingTask task =
109                 Iterables.getOnlyElement(getFederatedTrainingTask(selection, selectionArgs), null);
110         if (task != null) {
111             deleteFederatedTrainingTask(selection, selectionArgs);
112         }
113         return task;
114     }
115 
116     /** Delete a task from table based on population name. */
findAndRemoveTaskByPopulationName(String populationName)117     public FederatedTrainingTask findAndRemoveTaskByPopulationName(String populationName) {
118         String selection = FederatedTrainingTaskColumns.POPULATION_NAME + " = ?";
119         String[] selectionArgs = {populationName};
120         FederatedTrainingTask task =
121                 Iterables.getOnlyElement(getFederatedTrainingTask(selection, selectionArgs), null);
122         if (task != null) {
123             deleteFederatedTrainingTask(selection, selectionArgs);
124         }
125         return task;
126     }
127 
128     /** Delete a task from table based on population name and job scheduler id. */
findAndRemoveTaskByPopulationAndJobId( String populationName, int jobId)129     public FederatedTrainingTask findAndRemoveTaskByPopulationAndJobId(
130             String populationName, int jobId) {
131         String selection =
132                 FederatedTrainingTaskColumns.POPULATION_NAME
133                         + " = ? AND "
134                         + FederatedTrainingTaskColumns.JOB_SCHEDULER_JOB_ID
135                         + " = ?";
136         String[] selectionArgs = {populationName, String.valueOf(jobId)};
137         FederatedTrainingTask task =
138                 Iterables.getOnlyElement(getFederatedTrainingTask(selection, selectionArgs), null);
139         if (task != null) {
140             deleteFederatedTrainingTask(selection, selectionArgs);
141         }
142         return task;
143     }
144 
selectionArgs(Number... args)145     private String[] selectionArgs(Number... args) {
146         String[] values = new String[args.length];
147         for (int i = 0; i < args.length; i++) {
148             values[i] = String.valueOf(args[i]);
149         }
150         return values;
151     }
152 
153     /** It's only public to unit test. Clears all records in task table. */
154     @VisibleForTesting
clearDatabase()155     public boolean clearDatabase() {
156         SQLiteDatabase db = getWritableDatabase();
157         if (db == null) {
158             return false;
159         }
160         db.delete(FEDERATED_TRAINING_TASKS_TABLE, null, null);
161         return true;
162     }
163 
164     /* Returns a writable database object or null if error occurs. */
165     @Nullable
getWritableDatabase()166     private SQLiteDatabase getWritableDatabase() {
167         try {
168             return mDbHelper.getWritableDatabase();
169         } catch (SQLiteException e) {
170             Log.e(TAG, "Failed to open the database.", e);
171         }
172         return null;
173     }
174 }
175