• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.textclassifier.downloader;
18 
import androidx.annotation.IntDef;
import androidx.annotation.NonNull;
import androidx.room.ColumnInfo;
import androidx.room.Dao;
import androidx.room.Database;
import androidx.room.DatabaseView;
import androidx.room.Delete;
import androidx.room.Embedded;
import androidx.room.Entity;
import androidx.room.ForeignKey;
import androidx.room.Index;
import androidx.room.Insert;
import androidx.room.OnConflictStrategy;
import androidx.room.Query;
import androidx.room.RoomDatabase;
import androidx.room.Transaction;
import com.android.textclassifier.common.ModelType.ModelTypeDef;
import com.android.textclassifier.utils.IndentingPrintWriter;
import com.google.auto.value.AutoValue;
import com.google.common.collect.Iterables;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.RejectedExecutionException;
43 
44 /** Database storing info about models downloaded by model downloader */
45 @Database(
46     entities = {
47       DownloadedModelDatabase.Model.class,
48       DownloadedModelDatabase.Manifest.class,
49       DownloadedModelDatabase.ManifestModelCrossRef.class,
50       DownloadedModelDatabase.ManifestEnrollment.class
51     },
52     views = {DownloadedModelDatabase.ModelView.class},
53     version = 1,
54     exportSchema = true)
55 abstract class DownloadedModelDatabase extends RoomDatabase {
56   public static final String TAG = "DownloadedModelDatabase";
57 
58   /** Rpresents a downloaded model file. */
59   @AutoValue
60   @Entity(
61       tableName = "model",
62       primaryKeys = {"model_url"})
63   abstract static class Model {
64     @AutoValue.CopyAnnotations
65     @ColumnInfo(name = "model_url")
66     @NonNull
getModelUrl()67     public abstract String getModelUrl();
68 
69     @AutoValue.CopyAnnotations
70     @ColumnInfo(name = "model_path")
71     @NonNull
getModelPath()72     public abstract String getModelPath();
73 
create(String modelUrl, String modelPath)74     public static Model create(String modelUrl, String modelPath) {
75       return new AutoValue_DownloadedModelDatabase_Model(modelUrl, modelPath);
76     }
77   }
78 
79   /** Rpresents a manifest we processed. */
80   @AutoValue
81   @Entity(
82       tableName = "manifest",
83       primaryKeys = {"manifest_url"})
84   abstract static class Manifest {
85     // TODO(licha): Consider using Enum here
86     @Retention(RetentionPolicy.SOURCE)
87     @IntDef({STATUS_UNKNOWN, STATUS_FAILED, STATUS_SUCCEEDED})
88     @interface StatusDef {}
89 
90     public static final int STATUS_UNKNOWN = 0;
91     /** Failed to download this manifest. Could be retried in the future. */
92     public static final int STATUS_FAILED = 1;
93     /** Downloaded this manifest successfully and it's currently in storage. */
94     public static final int STATUS_SUCCEEDED = 2;
95 
96     @AutoValue.CopyAnnotations
97     @ColumnInfo(name = "manifest_url")
98     @NonNull
getManifestUrl()99     public abstract String getManifestUrl();
100 
101     @AutoValue.CopyAnnotations
102     @ColumnInfo(name = "status")
103     @StatusDef
getStatus()104     public abstract int getStatus();
105 
106     @AutoValue.CopyAnnotations
107     @ColumnInfo(name = "failure_counts")
getFailureCounts()108     public abstract int getFailureCounts();
109 
create(String manifestUrl, @StatusDef int status, int failureCounts)110     public static Manifest create(String manifestUrl, @StatusDef int status, int failureCounts) {
111       return new AutoValue_DownloadedModelDatabase_Manifest(manifestUrl, status, failureCounts);
112     }
113   }
114 
115   /**
116    * Represents the relationship between manfiests and downloaded models.
117    *
118    * <p>A manifest can include multiple models, a model can also be included in multiple manifests.
119    * In different manifests, a model may have different configurations (e.g. primary model in
120    * manfiest A but dark model in B).
121    */
122   @AutoValue
123   @Entity(
124       tableName = "manifest_model_cross_ref",
125       primaryKeys = {"manifest_url", "model_url"},
126       foreignKeys = {
127         @ForeignKey(
128             entity = Manifest.class,
129             parentColumns = "manifest_url",
130             childColumns = "manifest_url",
131             onDelete = ForeignKey.CASCADE),
132         @ForeignKey(
133             entity = Model.class,
134             parentColumns = "model_url",
135             childColumns = "model_url",
136             onDelete = ForeignKey.CASCADE),
137       },
138       indices = {
139         @Index(value = {"manifest_url"}),
140         @Index(value = {"model_url"}),
141       })
142   abstract static class ManifestModelCrossRef {
143     @AutoValue.CopyAnnotations
144     @ColumnInfo(name = "manifest_url")
145     @NonNull
getManifestUrl()146     public abstract String getManifestUrl();
147 
148     @AutoValue.CopyAnnotations
149     @ColumnInfo(name = "model_url")
150     @NonNull
getModelUrl()151     public abstract String getModelUrl();
152 
create(String manifestUrl, String modelUrl)153     public static ManifestModelCrossRef create(String manifestUrl, String modelUrl) {
154       return new AutoValue_DownloadedModelDatabase_ManifestModelCrossRef(manifestUrl, modelUrl);
155     }
156   }
157 
158   /**
159    * Represents the relationship between user scenarios and manifests.
160    *
161    * <p>For each unique user scenario (i.e. modelType + localTag), we store the manifest we should
162    * use. The same manifest can be used for different scenarios.
163    */
164   @AutoValue
165   @Entity(
166       tableName = "manifest_enrollment",
167       primaryKeys = {"model_type", "locale_tag"},
168       foreignKeys = {
169         @ForeignKey(
170             entity = Manifest.class,
171             parentColumns = "manifest_url",
172             childColumns = "manifest_url",
173             onDelete = ForeignKey.CASCADE)
174       },
175       indices = {@Index(value = {"manifest_url"})})
176   abstract static class ManifestEnrollment {
177     @AutoValue.CopyAnnotations
178     @ColumnInfo(name = "model_type")
179     @NonNull
180     @ModelTypeDef
getModelType()181     public abstract String getModelType();
182 
183     @AutoValue.CopyAnnotations
184     @ColumnInfo(name = "locale_tag")
185     @NonNull
getLocaleTag()186     public abstract String getLocaleTag();
187 
188     @AutoValue.CopyAnnotations
189     @ColumnInfo(name = "manifest_url")
190     @NonNull
getManifestUrl()191     public abstract String getManifestUrl();
192 
create( @odelTypeDef String modelType, String localeTag, String manifestUrl)193     public static ManifestEnrollment create(
194         @ModelTypeDef String modelType, String localeTag, String manifestUrl) {
195       return new AutoValue_DownloadedModelDatabase_ManifestEnrollment(
196           modelType, localeTag, manifestUrl);
197     }
198   }
199 
200   /** Represents the mapping from manfiest enrollments to models. */
201   @AutoValue
202   @DatabaseView(
203       value =
204           "SELECT manifest_enrollment.*, model.* "
205               + "FROM manifest_enrollment "
206               + "INNER JOIN manifest_model_cross_ref "
207               + "ON manifest_enrollment.manifest_url = manifest_model_cross_ref.manifest_url "
208               + "INNER JOIN model "
209               + "ON manifest_model_cross_ref.model_url = model.model_url",
210       viewName = "model_view")
211   abstract static class ModelView {
212     @AutoValue.CopyAnnotations
213     @Embedded
214     @NonNull
getManifestEnrollment()215     public abstract ManifestEnrollment getManifestEnrollment();
216 
217     @AutoValue.CopyAnnotations
218     @Embedded
219     @NonNull
getModel()220     public abstract Model getModel();
221 
create(ManifestEnrollment manifestEnrollment, Model model)222     public static ModelView create(ManifestEnrollment manifestEnrollment, Model model) {
223       return new AutoValue_DownloadedModelDatabase_ModelView(manifestEnrollment, model);
224     }
225   }
226 
227   @Dao
228   abstract static class DownloadedModelDatabaseDao {
229     // Full table scan
230     @Query("SELECT * FROM model")
queryAllModels()231     abstract List<Model> queryAllModels();
232 
233     @Query("SELECT * FROM manifest")
queryAllManifests()234     abstract List<Manifest> queryAllManifests();
235 
236     @Query("SELECT * FROM manifest_model_cross_ref")
queryAllManifestModelCrossRefs()237     abstract List<ManifestModelCrossRef> queryAllManifestModelCrossRefs();
238 
239     @Query("SELECT * FROM manifest_enrollment")
queryAllManifestEnrollments()240     abstract List<ManifestEnrollment> queryAllManifestEnrollments();
241 
242     @Query("SELECT * FROM model_view")
queryAllModelViews()243     abstract List<ModelView> queryAllModelViews();
244 
245     // Single table query
246     @Query("SELECT * FROM model WHERE model_url = :modelUrl")
queryModelWithModelUrl(String modelUrl)247     abstract List<Model> queryModelWithModelUrl(String modelUrl);
248 
249     @Query("SELECT * FROM manifest WHERE manifest_url = :manifestUrl")
queryManifestWithManifestUrl(String manifestUrl)250     abstract List<Manifest> queryManifestWithManifestUrl(String manifestUrl);
251 
252     @Query(
253         "SELECT * FROM manifest_enrollment WHERE model_type = :modelType "
254             + "AND locale_tag = :localeTag")
queryManifestEnrollmentWithModelTypeAndLocaleTag( @odelTypeDef String modelType, String localeTag)255     abstract List<ManifestEnrollment> queryManifestEnrollmentWithModelTypeAndLocaleTag(
256         @ModelTypeDef String modelType, String localeTag);
257 
258     // Helpers for clean up
259     @Query(
260         "SELECT manifest.* FROM manifest "
261             + "LEFT JOIN model_view "
262             + "ON manifest.manifest_url = model_view.manifest_url "
263             + "WHERE model_view.manifest_url IS NULL "
264             + "AND manifest.status = 2")
queryUnusedManifests()265     abstract List<Manifest> queryUnusedManifests();
266 
267     @Query(
268         "SELECT * FROM manifest WHERE manifest.status = 1 "
269             + "AND manifest.manifest_url NOT IN (:manifestUrlsToKeep)")
queryUnusedManifestFailureRecords(List<String> manifestUrlsToKeep)270     abstract List<Manifest> queryUnusedManifestFailureRecords(List<String> manifestUrlsToKeep);
271 
272     @Query(
273         "SELECT model.* FROM model LEFT JOIN model_view "
274             + "ON model.model_url = model_view.model_url "
275             + "WHERE model_view.model_url IS NULL")
queryUnusedModels()276     abstract List<Model> queryUnusedModels();
277 
278     // Insertion
279     @Insert(onConflict = OnConflictStrategy.REPLACE)
insert(Model model)280     abstract void insert(Model model);
281 
282     @Insert(onConflict = OnConflictStrategy.REPLACE)
insert(Manifest manifest)283     abstract void insert(Manifest manifest);
284 
285     @Insert(onConflict = OnConflictStrategy.REPLACE)
insert(ManifestModelCrossRef manifestModelCrossRef)286     abstract void insert(ManifestModelCrossRef manifestModelCrossRef);
287 
288     @Insert(onConflict = OnConflictStrategy.REPLACE)
insert(ManifestEnrollment manifestEnrollment)289     abstract void insert(ManifestEnrollment manifestEnrollment);
290 
291     @Transaction
insertManifestAndModelCrossRef(String manifestUrl, String modelUrl)292     void insertManifestAndModelCrossRef(String manifestUrl, String modelUrl) {
293       insert(Manifest.create(manifestUrl, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0));
294       insert(ManifestModelCrossRef.create(manifestUrl, modelUrl));
295     }
296 
297     @Transaction
increaseManifestFailureCounts(String manifestUrl)298     void increaseManifestFailureCounts(String manifestUrl) {
299       List<Manifest> manifests = queryManifestWithManifestUrl(manifestUrl);
300       if (manifests.isEmpty()) {
301         insert(Manifest.create(manifestUrl, Manifest.STATUS_FAILED, /* failureCounts= */ 1));
302       } else {
303         Manifest prevManifest = Iterables.getOnlyElement(manifests);
304         insert(
305             Manifest.create(
306                 manifestUrl, Manifest.STATUS_FAILED, prevManifest.getFailureCounts() + 1));
307       }
308     }
309 
310     // Deletion
311     @Delete
deleteModels(List<Model> models)312     abstract void deleteModels(List<Model> models);
313 
314     @Delete
deleteManifests(List<Manifest> manifests)315     abstract void deleteManifests(List<Manifest> manifests);
316 
317     @Delete
deleteManifestModelCrossRefs(List<ManifestModelCrossRef> manifestModelCrossRefs)318     abstract void deleteManifestModelCrossRefs(List<ManifestModelCrossRef> manifestModelCrossRefs);
319 
320     @Delete
deleteManifestEnrollments(List<ManifestEnrollment> manifestEnrollments)321     abstract void deleteManifestEnrollments(List<ManifestEnrollment> manifestEnrollments);
322 
323     @Transaction
deleteUnusedManifestsAndModels()324     void deleteUnusedManifestsAndModels() {
325       // Because Manifest table is the parent table of ManifestModelCrossRef table, related cross
326       // ref row in that table will be deleted automatically
327       deleteManifests(queryUnusedManifests());
328       deleteModels(queryUnusedModels());
329     }
330 
331     @Transaction
deleteUnusedManifestFailureRecords(List<String> manifestUrlsToKeep)332     void deleteUnusedManifestFailureRecords(List<String> manifestUrlsToKeep) {
333       deleteManifests(queryUnusedManifestFailureRecords(manifestUrlsToKeep));
334     }
335   }
336 
dao()337   abstract DownloadedModelDatabaseDao dao();
338 
339   /** Dump the database for debugging. */
dump(IndentingPrintWriter printWriter, ExecutorService executorService)340   void dump(IndentingPrintWriter printWriter, ExecutorService executorService) {
341     printWriter.println("DownloadedModelDatabase");
342     printWriter.increaseIndent();
343     printWriter.println("Model Table:");
344     printWriter.increaseIndent();
345     List<Model> models = dao().queryAllModels();
346     for (Model model : models) {
347       printWriter.println(model.toString());
348     }
349     printWriter.decreaseIndent();
350     printWriter.println("Manifest Table:");
351     printWriter.increaseIndent();
352     List<Manifest> manifests = dao().queryAllManifests();
353     for (Manifest manifest : manifests) {
354       printWriter.println(manifest.toString());
355     }
356     printWriter.decreaseIndent();
357     printWriter.println("ManifestModelCrossRef Table:");
358     printWriter.increaseIndent();
359     List<ManifestModelCrossRef> manifestModelCrossRefs = dao().queryAllManifestModelCrossRefs();
360     for (ManifestModelCrossRef manifestModelCrossRef : manifestModelCrossRefs) {
361       printWriter.println(manifestModelCrossRef.toString());
362     }
363     printWriter.decreaseIndent();
364     printWriter.println("ManifestEnrollment Table:");
365     printWriter.increaseIndent();
366     List<ManifestEnrollment> manifestEnrollments = dao().queryAllManifestEnrollments();
367     for (ManifestEnrollment manifestEnrollment : manifestEnrollments) {
368       printWriter.println(manifestEnrollment.toString());
369     }
370     printWriter.decreaseIndent();
371     printWriter.decreaseIndent();
372   }
373 }
374