1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.textclassifier.downloader; 18 19 import androidx.annotation.IntDef; 20 import androidx.annotation.NonNull; 21 import androidx.room.ColumnInfo; 22 import androidx.room.Dao; 23 import androidx.room.Database; 24 import androidx.room.DatabaseView; 25 import androidx.room.Delete; 26 import androidx.room.Embedded; 27 import androidx.room.Entity; 28 import androidx.room.ForeignKey; 29 import androidx.room.Index; 30 import androidx.room.Insert; 31 import androidx.room.OnConflictStrategy; 32 import androidx.room.Query; 33 import androidx.room.RoomDatabase; 34 import androidx.room.Transaction; 35 import com.android.textclassifier.common.ModelType.ModelTypeDef; 36 import com.android.textclassifier.utils.IndentingPrintWriter; 37 import com.google.auto.value.AutoValue; 38 import com.google.common.collect.Iterables; 39 import java.lang.annotation.Retention; 40 import java.lang.annotation.RetentionPolicy; 41 import java.util.List; 42 import java.util.concurrent.ExecutorService; 43 44 /** Database storing info about models downloaded by model downloader */ 45 @Database( 46 entities = { 47 DownloadedModelDatabase.Model.class, 48 DownloadedModelDatabase.Manifest.class, 49 DownloadedModelDatabase.ManifestModelCrossRef.class, 50 DownloadedModelDatabase.ManifestEnrollment.class 51 }, 52 views = {DownloadedModelDatabase.ModelView.class}, 53 version = 1, 54 exportSchema = true) 55 abstract class DownloadedModelDatabase extends RoomDatabase { 56 public static final String TAG = "DownloadedModelDatabase"; 57 58 /** Rpresents a downloaded model file. */ 59 @AutoValue 60 @Entity( 61 tableName = "model", 62 primaryKeys = {"model_url"}) 63 abstract static class Model { 64 @AutoValue.CopyAnnotations 65 @ColumnInfo(name = "model_url") 66 @NonNull getModelUrl()67 public abstract String getModelUrl(); 68 69 @AutoValue.CopyAnnotations 70 @ColumnInfo(name = "model_path") 71 @NonNull getModelPath()72 public abstract String getModelPath(); 73 create(String modelUrl, String modelPath)74 public static Model create(String modelUrl, String modelPath) { 75 return new AutoValue_DownloadedModelDatabase_Model(modelUrl, modelPath); 76 } 77 } 78 79 /** Rpresents a manifest we processed. */ 80 @AutoValue 81 @Entity( 82 tableName = "manifest", 83 primaryKeys = {"manifest_url"}) 84 abstract static class Manifest { 85 // TODO(licha): Consider using Enum here 86 @Retention(RetentionPolicy.SOURCE) 87 @IntDef({STATUS_UNKNOWN, STATUS_FAILED, STATUS_SUCCEEDED}) 88 @interface StatusDef {} 89 90 public static final int STATUS_UNKNOWN = 0; 91 /** Failed to download this manifest. Could be retried in the future. */ 92 public static final int STATUS_FAILED = 1; 93 /** Downloaded this manifest successfully and it's currently in storage. */ 94 public static final int STATUS_SUCCEEDED = 2; 95 96 @AutoValue.CopyAnnotations 97 @ColumnInfo(name = "manifest_url") 98 @NonNull getManifestUrl()99 public abstract String getManifestUrl(); 100 101 @AutoValue.CopyAnnotations 102 @ColumnInfo(name = "status") 103 @StatusDef getStatus()104 public abstract int getStatus(); 105 106 @AutoValue.CopyAnnotations 107 @ColumnInfo(name = "failure_counts") getFailureCounts()108 public abstract int getFailureCounts(); 109 create(String manifestUrl, @StatusDef int status, int failureCounts)110 public static Manifest create(String manifestUrl, @StatusDef int status, int failureCounts) { 111 return new AutoValue_DownloadedModelDatabase_Manifest(manifestUrl, status, failureCounts); 112 } 113 } 114 115 /** 116 * Represents the relationship between manfiests and downloaded models. 117 * 118 * <p>A manifest can include multiple models, a model can also be included in multiple manifests. 119 * In different manifests, a model may have different configurations (e.g. primary model in 120 * manfiest A but dark model in B). 121 */ 122 @AutoValue 123 @Entity( 124 tableName = "manifest_model_cross_ref", 125 primaryKeys = {"manifest_url", "model_url"}, 126 foreignKeys = { 127 @ForeignKey( 128 entity = Manifest.class, 129 parentColumns = "manifest_url", 130 childColumns = "manifest_url", 131 onDelete = ForeignKey.CASCADE), 132 @ForeignKey( 133 entity = Model.class, 134 parentColumns = "model_url", 135 childColumns = "model_url", 136 onDelete = ForeignKey.CASCADE), 137 }, 138 indices = { 139 @Index(value = {"manifest_url"}), 140 @Index(value = {"model_url"}), 141 }) 142 abstract static class ManifestModelCrossRef { 143 @AutoValue.CopyAnnotations 144 @ColumnInfo(name = "manifest_url") 145 @NonNull getManifestUrl()146 public abstract String getManifestUrl(); 147 148 @AutoValue.CopyAnnotations 149 @ColumnInfo(name = "model_url") 150 @NonNull getModelUrl()151 public abstract String getModelUrl(); 152 create(String manifestUrl, String modelUrl)153 public static ManifestModelCrossRef create(String manifestUrl, String modelUrl) { 154 return new AutoValue_DownloadedModelDatabase_ManifestModelCrossRef(manifestUrl, modelUrl); 155 } 156 } 157 158 /** 159 * Represents the relationship between user scenarios and manifests. 160 * 161 * <p>For each unique user scenario (i.e. modelType + localTag), we store the manifest we should 162 * use. The same manifest can be used for different scenarios. 163 */ 164 @AutoValue 165 @Entity( 166 tableName = "manifest_enrollment", 167 primaryKeys = {"model_type", "locale_tag"}, 168 foreignKeys = { 169 @ForeignKey( 170 entity = Manifest.class, 171 parentColumns = "manifest_url", 172 childColumns = "manifest_url", 173 onDelete = ForeignKey.CASCADE) 174 }, 175 indices = {@Index(value = {"manifest_url"})}) 176 abstract static class ManifestEnrollment { 177 @AutoValue.CopyAnnotations 178 @ColumnInfo(name = "model_type") 179 @NonNull 180 @ModelTypeDef getModelType()181 public abstract String getModelType(); 182 183 @AutoValue.CopyAnnotations 184 @ColumnInfo(name = "locale_tag") 185 @NonNull getLocaleTag()186 public abstract String getLocaleTag(); 187 188 @AutoValue.CopyAnnotations 189 @ColumnInfo(name = "manifest_url") 190 @NonNull getManifestUrl()191 public abstract String getManifestUrl(); 192 create( @odelTypeDef String modelType, String localeTag, String manifestUrl)193 public static ManifestEnrollment create( 194 @ModelTypeDef String modelType, String localeTag, String manifestUrl) { 195 return new AutoValue_DownloadedModelDatabase_ManifestEnrollment( 196 modelType, localeTag, manifestUrl); 197 } 198 } 199 200 /** Represents the mapping from manfiest enrollments to models. */ 201 @AutoValue 202 @DatabaseView( 203 value = 204 "SELECT manifest_enrollment.*, model.* " 205 + "FROM manifest_enrollment " 206 + "INNER JOIN manifest_model_cross_ref " 207 + "ON manifest_enrollment.manifest_url = manifest_model_cross_ref.manifest_url " 208 + "INNER JOIN model " 209 + "ON manifest_model_cross_ref.model_url = model.model_url", 210 viewName = "model_view") 211 abstract static class ModelView { 212 @AutoValue.CopyAnnotations 213 @Embedded 214 @NonNull getManifestEnrollment()215 public abstract ManifestEnrollment getManifestEnrollment(); 216 217 @AutoValue.CopyAnnotations 218 @Embedded 219 @NonNull getModel()220 public abstract Model getModel(); 221 create(ManifestEnrollment manifestEnrollment, Model model)222 public static ModelView create(ManifestEnrollment manifestEnrollment, Model model) { 223 return new AutoValue_DownloadedModelDatabase_ModelView(manifestEnrollment, model); 224 } 225 } 226 227 @Dao 228 abstract static class DownloadedModelDatabaseDao { 229 // Full table scan 230 @Query("SELECT * FROM model") queryAllModels()231 abstract List<Model> queryAllModels(); 232 233 @Query("SELECT * FROM manifest") queryAllManifests()234 abstract List<Manifest> queryAllManifests(); 235 236 @Query("SELECT * FROM manifest_model_cross_ref") queryAllManifestModelCrossRefs()237 abstract List<ManifestModelCrossRef> queryAllManifestModelCrossRefs(); 238 239 @Query("SELECT * FROM manifest_enrollment") queryAllManifestEnrollments()240 abstract List<ManifestEnrollment> queryAllManifestEnrollments(); 241 242 @Query("SELECT * FROM model_view") queryAllModelViews()243 abstract List<ModelView> queryAllModelViews(); 244 245 // Single table query 246 @Query("SELECT * FROM model WHERE model_url = :modelUrl") queryModelWithModelUrl(String modelUrl)247 abstract List<Model> queryModelWithModelUrl(String modelUrl); 248 249 @Query("SELECT * FROM manifest WHERE manifest_url = :manifestUrl") queryManifestWithManifestUrl(String manifestUrl)250 abstract List<Manifest> queryManifestWithManifestUrl(String manifestUrl); 251 252 @Query( 253 "SELECT * FROM manifest_enrollment WHERE model_type = :modelType " 254 + "AND locale_tag = :localeTag") queryManifestEnrollmentWithModelTypeAndLocaleTag( @odelTypeDef String modelType, String localeTag)255 abstract List<ManifestEnrollment> queryManifestEnrollmentWithModelTypeAndLocaleTag( 256 @ModelTypeDef String modelType, String localeTag); 257 258 // Helpers for clean up 259 @Query( 260 "SELECT manifest.* FROM manifest " 261 + "LEFT JOIN model_view " 262 + "ON manifest.manifest_url = model_view.manifest_url " 263 + "WHERE model_view.manifest_url IS NULL " 264 + "AND manifest.status = 2") queryUnusedManifests()265 abstract List<Manifest> queryUnusedManifests(); 266 267 @Query( 268 "SELECT * FROM manifest WHERE manifest.status = 1 " 269 + "AND manifest.manifest_url NOT IN (:manifestUrlsToKeep)") queryUnusedManifestFailureRecords(List<String> manifestUrlsToKeep)270 abstract List<Manifest> queryUnusedManifestFailureRecords(List<String> manifestUrlsToKeep); 271 272 @Query( 273 "SELECT model.* FROM model LEFT JOIN model_view " 274 + "ON model.model_url = model_view.model_url " 275 + "WHERE model_view.model_url IS NULL") queryUnusedModels()276 abstract List<Model> queryUnusedModels(); 277 278 // Insertion 279 @Insert(onConflict = OnConflictStrategy.REPLACE) insert(Model model)280 abstract void insert(Model model); 281 282 @Insert(onConflict = OnConflictStrategy.REPLACE) insert(Manifest manifest)283 abstract void insert(Manifest manifest); 284 285 @Insert(onConflict = OnConflictStrategy.REPLACE) insert(ManifestModelCrossRef manifestModelCrossRef)286 abstract void insert(ManifestModelCrossRef manifestModelCrossRef); 287 288 @Insert(onConflict = OnConflictStrategy.REPLACE) insert(ManifestEnrollment manifestEnrollment)289 abstract void insert(ManifestEnrollment manifestEnrollment); 290 291 @Transaction insertManifestAndModelCrossRef(String manifestUrl, String modelUrl)292 void insertManifestAndModelCrossRef(String manifestUrl, String modelUrl) { 293 insert(Manifest.create(manifestUrl, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0)); 294 insert(ManifestModelCrossRef.create(manifestUrl, modelUrl)); 295 } 296 297 @Transaction increaseManifestFailureCounts(String manifestUrl)298 void increaseManifestFailureCounts(String manifestUrl) { 299 List<Manifest> manifests = queryManifestWithManifestUrl(manifestUrl); 300 if (manifests.isEmpty()) { 301 insert(Manifest.create(manifestUrl, Manifest.STATUS_FAILED, /* failureCounts= */ 1)); 302 } else { 303 Manifest prevManifest = Iterables.getOnlyElement(manifests); 304 insert( 305 Manifest.create( 306 manifestUrl, Manifest.STATUS_FAILED, prevManifest.getFailureCounts() + 1)); 307 } 308 } 309 310 // Deletion 311 @Delete deleteModels(List<Model> models)312 abstract void deleteModels(List<Model> models); 313 314 @Delete deleteManifests(List<Manifest> manifests)315 abstract void deleteManifests(List<Manifest> manifests); 316 317 @Delete deleteManifestModelCrossRefs(List<ManifestModelCrossRef> manifestModelCrossRefs)318 abstract void deleteManifestModelCrossRefs(List<ManifestModelCrossRef> manifestModelCrossRefs); 319 320 @Delete deleteManifestEnrollments(List<ManifestEnrollment> manifestEnrollments)321 abstract void deleteManifestEnrollments(List<ManifestEnrollment> manifestEnrollments); 322 323 @Transaction deleteUnusedManifestsAndModels()324 void deleteUnusedManifestsAndModels() { 325 // Because Manifest table is the parent table of ManifestModelCrossRef table, related cross 326 // ref row in that table will be deleted automatically 327 deleteManifests(queryUnusedManifests()); 328 deleteModels(queryUnusedModels()); 329 } 330 331 @Transaction deleteUnusedManifestFailureRecords(List<String> manifestUrlsToKeep)332 void deleteUnusedManifestFailureRecords(List<String> manifestUrlsToKeep) { 333 deleteManifests(queryUnusedManifestFailureRecords(manifestUrlsToKeep)); 334 } 335 } 336 dao()337 abstract DownloadedModelDatabaseDao dao(); 338 339 /** Dump the database for debugging. */ dump(IndentingPrintWriter printWriter, ExecutorService executorService)340 void dump(IndentingPrintWriter printWriter, ExecutorService executorService) { 341 printWriter.println("DownloadedModelDatabase"); 342 printWriter.increaseIndent(); 343 printWriter.println("Model Table:"); 344 printWriter.increaseIndent(); 345 List<Model> models = dao().queryAllModels(); 346 for (Model model : models) { 347 printWriter.println(model.toString()); 348 } 349 printWriter.decreaseIndent(); 350 printWriter.println("Manifest Table:"); 351 printWriter.increaseIndent(); 352 List<Manifest> manifests = dao().queryAllManifests(); 353 for (Manifest manifest : manifests) { 354 printWriter.println(manifest.toString()); 355 } 356 printWriter.decreaseIndent(); 357 printWriter.println("ManifestModelCrossRef Table:"); 358 printWriter.increaseIndent(); 359 List<ManifestModelCrossRef> manifestModelCrossRefs = dao().queryAllManifestModelCrossRefs(); 360 for (ManifestModelCrossRef manifestModelCrossRef : manifestModelCrossRefs) { 361 printWriter.println(manifestModelCrossRef.toString()); 362 } 363 printWriter.decreaseIndent(); 364 printWriter.println("ManifestEnrollment Table:"); 365 printWriter.increaseIndent(); 366 List<ManifestEnrollment> manifestEnrollments = dao().queryAllManifestEnrollments(); 367 for (ManifestEnrollment manifestEnrollment : manifestEnrollments) { 368 printWriter.println(manifestEnrollment.toString()); 369 } 370 printWriter.decreaseIndent(); 371 printWriter.decreaseIndent(); 372 } 373 } 374