• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.textclassifier.downloader;
18 
19 import android.content.Context;
20 import android.util.ArrayMap;
21 import androidx.annotation.GuardedBy;
22 import androidx.room.Room;
23 import com.android.textclassifier.common.ModelType;
24 import com.android.textclassifier.common.ModelType.ModelTypeDef;
25 import com.android.textclassifier.common.TextClassifierServiceExecutors;
26 import com.android.textclassifier.common.TextClassifierSettings;
27 import com.android.textclassifier.common.base.TcLog;
28 import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest;
29 import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment;
30 import com.android.textclassifier.downloader.DownloadedModelDatabase.Model;
31 import com.android.textclassifier.downloader.DownloadedModelDatabase.ModelView;
32 import com.android.textclassifier.utils.IndentingPrintWriter;
33 import com.google.common.annotations.VisibleForTesting;
34 import com.google.common.collect.ImmutableList;
35 import com.google.common.collect.ImmutableMap;
36 import com.google.common.collect.Iterables;
37 import java.io.File;
38 import java.util.ArrayList;
39 import java.util.List;
40 import java.util.Map;
41 import java.util.Optional;
42 import java.util.Set;
43 import java.util.stream.Collectors;
44 import javax.annotation.Nullable;
45 
46 /** A singleton implementation of DownloadedModelManager. */
47 public final class DownloadedModelManagerImpl implements DownloadedModelManager {
48   private static final String TAG = "DownloadedModelManagerImpl";
49   private static final String DOWNLOAD_SUB_DIR_NAME = "textclassifier/downloads/models";
50   private static final String DOWNLOADED_MODEL_DATABASE_NAME = "tcs-downloaded-model-db";
51 
52   private static final Object staticLock = new Object();
53 
54   @GuardedBy("staticLock")
55   private static DownloadedModelManagerImpl instance;
56 
57   private final File modelDownloaderDir;
58   private final DownloadedModelDatabase db;
59   private final TextClassifierSettings settings;
60 
61   private final Object cacheLock = new Object();
62 
63   // modeltype -> downloaded model files
64   @GuardedBy("cacheLock")
65   private final ArrayMap<String, List<Model>> modelLookupCache;
66 
67   @GuardedBy("cacheLock")
68   private boolean cacheInitialized;
69 
70   @Nullable
getInstance(Context context)71   public static DownloadedModelManager getInstance(Context context) {
72     synchronized (staticLock) {
73       if (instance == null) {
74         DownloadedModelDatabase db =
75             Room.databaseBuilder(
76                     context, DownloadedModelDatabase.class, DOWNLOADED_MODEL_DATABASE_NAME)
77                 .build();
78         File modelDownloaderDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
79         instance =
80             new DownloadedModelManagerImpl(
81                 db, modelDownloaderDir, new TextClassifierSettings(context));
82       }
83       return instance;
84     }
85   }
86 
87   @VisibleForTesting
getInstanceForTesting( DownloadedModelDatabase db, File modelDownloaderDir, TextClassifierSettings settings)88   static DownloadedModelManagerImpl getInstanceForTesting(
89       DownloadedModelDatabase db, File modelDownloaderDir, TextClassifierSettings settings) {
90     return new DownloadedModelManagerImpl(db, modelDownloaderDir, settings);
91   }
92 
DownloadedModelManagerImpl( DownloadedModelDatabase db, File modelDownloaderDir, TextClassifierSettings settings)93   private DownloadedModelManagerImpl(
94       DownloadedModelDatabase db, File modelDownloaderDir, TextClassifierSettings settings) {
95     this.db = db;
96     this.modelDownloaderDir = modelDownloaderDir;
97     this.modelLookupCache = new ArrayMap<>();
98     for (String modelType : ModelType.values()) {
99       this.modelLookupCache.put(modelType, new ArrayList<>());
100     }
101     this.settings = settings;
102     this.cacheInitialized = false;
103   }
104 
105   @Override
getModelDownloaderDir()106   public File getModelDownloaderDir() {
107     if (!modelDownloaderDir.exists()) {
108       modelDownloaderDir.mkdirs();
109     }
110     return modelDownloaderDir;
111   }
112 
113   @Override
114   @Nullable
listModels(@odelTypeDef String modelType)115   public ImmutableList<File> listModels(@ModelTypeDef String modelType) {
116     synchronized (cacheLock) {
117       if (!cacheInitialized) {
118         updateCache();
119       }
120       ImmutableList.Builder<File> builder = ImmutableList.builder();
121       ImmutableList<String> blockedModels = settings.getModelUrlBlocklist();
122       for (Model model : modelLookupCache.get(modelType)) {
123         if (blockedModels.contains(model.getModelUrl())) {
124           TcLog.d(TAG, "Model is blocklisted: " + model);
125           continue;
126         }
127         builder.add(new File(model.getModelPath()));
128       }
129       return builder.build();
130     }
131   }
132 
133   @Override
134   @Nullable
getModel(String modelUrl)135   public Model getModel(String modelUrl) {
136     List<Model> models = db.dao().queryModelWithModelUrl(modelUrl);
137     return Iterables.getFirst(models, null);
138   }
139 
140   @Override
141   @Nullable
getManifest(String manifestUrl)142   public Manifest getManifest(String manifestUrl) {
143     List<Manifest> manifests = db.dao().queryManifestWithManifestUrl(manifestUrl);
144     return Iterables.getFirst(manifests, null);
145   }
146 
147   @Override
148   @Nullable
getManifestEnrollment( @odelTypeDef String modelType, String localeTag)149   public ManifestEnrollment getManifestEnrollment(
150       @ModelTypeDef String modelType, String localeTag) {
151     List<ManifestEnrollment> manifestEnrollments =
152         db.dao().queryManifestEnrollmentWithModelTypeAndLocaleTag(modelType, localeTag);
153     return Iterables.getFirst(manifestEnrollments, null);
154   }
155 
156   @Override
registerModel(String modelUrl, String modelPath)157   public void registerModel(String modelUrl, String modelPath) {
158     db.dao().insert(Model.create(modelUrl, modelPath));
159   }
160 
161   @Override
registerManifest(String manifestUrl, String modelUrl)162   public void registerManifest(String manifestUrl, String modelUrl) {
163     db.dao().insertManifestAndModelCrossRef(manifestUrl, modelUrl);
164   }
165 
166   @Override
registerManifestDownloadFailure(String manifestUrl)167   public void registerManifestDownloadFailure(String manifestUrl) {
168     db.dao().increaseManifestFailureCounts(manifestUrl);
169   }
170 
171   @Override
registerManifestEnrollment( @odelTypeDef String modelType, String localeTag, String manifestUrl)172   public void registerManifestEnrollment(
173       @ModelTypeDef String modelType, String localeTag, String manifestUrl) {
174     db.dao().insert(ManifestEnrollment.create(modelType, localeTag, manifestUrl));
175   }
176 
177   @Override
dump(IndentingPrintWriter printWriter)178   public void dump(IndentingPrintWriter printWriter) {
179     printWriter.println("DownloadedModelManagerImpl:");
180     printWriter.increaseIndent();
181     db.dump(printWriter, TextClassifierServiceExecutors.getDownloaderExecutor());
182     printWriter.println("ModelLookupCache:");
183     synchronized (cacheLock) {
184       for (Map.Entry<String, List<Model>> entry : modelLookupCache.entrySet()) {
185         printWriter.println(entry.getKey());
186         printWriter.increaseIndent();
187         for (Model model : entry.getValue()) {
188           printWriter.println(model.toString());
189         }
190         printWriter.decreaseIndent();
191       }
192     }
193     printWriter.decreaseIndent();
194   }
195 
196   @Override
onDownloadCompleted( ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload)197   public void onDownloadCompleted(
198       ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload) {
199     TcLog.d(TAG, "Start to clean up models and update model lookup cache...");
200     // Step 1: Clean up ManifestEnrollment table
201     List<ManifestEnrollment> allManifestEnrollments = db.dao().queryAllManifestEnrollments();
202     List<ManifestEnrollment> manifestEnrollmentsToDelete = new ArrayList<>();
203     for (String modelType : ModelType.values()) {
204       List<ManifestEnrollment> manifestEnrollmentsByType =
205           allManifestEnrollments.stream()
206               .filter(modelEnrollment -> modelEnrollment.getModelType().equals(modelType))
207               .collect(Collectors.toList());
208       ManifestsToDownloadByType manifestsToDownloadByType = manifestsToDownload.get(modelType);
209 
210       if (manifestsToDownloadByType == null) {
211         // No suitable manifests configured for this model type. Delete everything.
212         manifestEnrollmentsToDelete.addAll(manifestEnrollmentsByType);
213         continue;
214       }
215       ImmutableMap<String, String> localeTagToManifestUrl =
216           manifestsToDownloadByType.localeTagToManifestUrl();
217 
218       boolean allModelsDownloaded = true;
219       for (Map.Entry<String, String> entry : localeTagToManifestUrl.entrySet()) {
220         String localeTag = entry.getKey();
221         String manifestUrl = entry.getValue();
222         Optional<ManifestEnrollment> manifestEnrollmentForLocaleTagAndManifestUrl =
223             manifestEnrollmentsByType.stream()
224                 .filter(
225                     manifestEnrollment ->
226                         manifestEnrollment.getLocaleTag().equals(localeTag)
227                             && manifestEnrollment.getManifestUrl().equals(manifestUrl))
228                 .findAny();
229         if (!manifestEnrollmentForLocaleTagAndManifestUrl.isPresent()) {
230           // The desired manifest failed to be downloaded.
231           TcLog.w(
232               TAG,
233               String.format(
234                   "Desired manifest is missing on download completed: %s, %s, %s",
235                   modelType, localeTag, manifestUrl));
236           allModelsDownloaded = false;
237         }
238       }
239       if (allModelsDownloaded) {
240         // Delete unused manifest enrollments.
241         manifestEnrollmentsToDelete.addAll(
242             manifestEnrollmentsByType.stream()
243                 .filter(
244                     manifestEnrollment ->
245                         !manifestEnrollment
246                             .getManifestUrl()
247                             .equals(localeTagToManifestUrl.get(manifestEnrollment.getLocaleTag())))
248                 .collect(Collectors.toList()));
249       } else {
250         // TODO(licha): We may still need to delete models here. E.g. we are switching from en to
251         // zh. Although we fail to download zh model, we still want to delete en models.
252         TcLog.w(
253             TAG, "Unused models were not deleted because downloading of at least one model failed");
254       }
255     }
256     db.dao().deleteManifestEnrollments(manifestEnrollmentsToDelete);
257     // Step 2: Clean up Manifests and Models that are not linked to any ManifestEnrollment
258     db.dao().deleteUnusedManifestsAndModels();
259     // Step 3: Clean up Manifest failure records
260     // We only keep a failure record if the worker stills trys to download it
261     // We restrict the deletion to failure records only because although some manifest urls are not
262     // in allAttemptedManifestUrls, they can still be useful (e.g. current manifest is v901, and we
263     // failed to download v902. v901 will not be in the map, but it should be kept.)
264     List<String> allAttemptedManifestUrls =
265         manifestsToDownload.entrySet().stream()
266             .flatMap(
267                 entry ->
268                     entry.getValue().localeTagToManifestUrl().entrySet().stream()
269                         .map(Map.Entry::getValue))
270             .collect(Collectors.toList());
271     db.dao().deleteUnusedManifestFailureRecords(allAttemptedManifestUrls);
272     // Step 4: Update lookup cache
273     updateCache();
274     // Step 5: Clean up unused model files.
275     Set<String> modelPathsToKeep =
276         db.dao().queryAllModels().stream().map(Model::getModelPath).collect(Collectors.toSet());
277     for (File modelFile : getModelDownloaderDir().listFiles()) {
278       if (!modelPathsToKeep.contains(modelFile.getAbsolutePath())) {
279         TcLog.d(TAG, "Delete model file: " + modelFile.getAbsolutePath());
280         if (!modelFile.delete()) {
281           TcLog.e(TAG, "Failed to delete model file: " + modelFile.getAbsolutePath());
282         }
283       }
284     }
285   }
286 
287   // Clear the cache table and rebuild the cache based on ModelView table
updateCache()288   private void updateCache() {
289     synchronized (cacheLock) {
290       TcLog.d(TAG, "Updating model lookup cache...");
291       for (String modelType : ModelType.values()) {
292         modelLookupCache.get(modelType).clear();
293       }
294       for (ModelView modelView : db.dao().queryAllModelViews()) {
295         modelLookupCache
296             .get(modelView.getManifestEnrollment().getModelType())
297             .add(modelView.getModel());
298       }
299       cacheInitialized = true;
300     }
301   }
302 }
303