• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.textclassifier.downloader;
18 
19 import android.content.Context;
20 import android.util.ArrayMap;
21 import androidx.annotation.GuardedBy;
22 import androidx.room.Room;
23 import com.android.textclassifier.common.ModelType;
24 import com.android.textclassifier.common.ModelType.ModelTypeDef;
25 import com.android.textclassifier.common.TextClassifierServiceExecutors;
26 import com.android.textclassifier.common.TextClassifierSettings;
27 import com.android.textclassifier.common.base.TcLog;
28 import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest;
29 import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment;
30 import com.android.textclassifier.downloader.DownloadedModelDatabase.Model;
31 import com.android.textclassifier.downloader.DownloadedModelDatabase.ModelView;
32 import com.android.textclassifier.utils.IndentingPrintWriter;
33 import com.google.common.annotations.VisibleForTesting;
34 import com.google.common.collect.ImmutableList;
35 import com.google.common.collect.ImmutableMap;
36 import com.google.common.collect.Iterables;
37 import java.io.File;
38 import java.util.ArrayList;
39 import java.util.List;
40 import java.util.Map;
41 import java.util.Optional;
42 import java.util.Set;
43 import java.util.stream.Collectors;
44 import javax.annotation.Nullable;
45 
46 /** A singleton implementation of DownloadedModelManager. */
47 public final class DownloadedModelManagerImpl implements DownloadedModelManager {
48   private static final String TAG = "DownloadedModelManagerImpl";
49   private static final String DOWNLOAD_SUB_DIR_NAME = "textclassifier/downloads/models";
50   private static final String DOWNLOADED_MODEL_DATABASE_NAME = "tcs-downloaded-model-db";
51 
52   private static final Object staticLock = new Object();
53 
54   @GuardedBy("staticLock")
55   private static DownloadedModelManagerImpl instance;
56 
57   private final File modelDownloaderDir;
58   private final DownloadedModelDatabase db;
59   private final TextClassifierSettings settings;
60 
61   private final Object cacheLock = new Object();
62 
63   // modeltype -> downloaded model files
64   @GuardedBy("cacheLock")
65   private final ArrayMap<String, List<Model>> modelLookupCache;
66 
67   @GuardedBy("cacheLock")
68   private boolean cacheInitialized;
69 
70   @Nullable
getInstance(Context context)71   public static DownloadedModelManager getInstance(Context context) {
72     synchronized (staticLock) {
73       if (instance == null) {
74         DownloadedModelDatabase db =
75             Room.databaseBuilder(
76                     context, DownloadedModelDatabase.class, DOWNLOADED_MODEL_DATABASE_NAME)
77                 .build();
78         File modelDownloaderDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
79         instance =
80             new DownloadedModelManagerImpl(db, modelDownloaderDir, new TextClassifierSettings());
81       }
82       return instance;
83     }
84   }
85 
86   @VisibleForTesting
getInstanceForTesting( DownloadedModelDatabase db, File modelDownloaderDir, TextClassifierSettings settings)87   static DownloadedModelManagerImpl getInstanceForTesting(
88       DownloadedModelDatabase db, File modelDownloaderDir, TextClassifierSettings settings) {
89     return new DownloadedModelManagerImpl(db, modelDownloaderDir, settings);
90   }
91 
DownloadedModelManagerImpl( DownloadedModelDatabase db, File modelDownloaderDir, TextClassifierSettings settings)92   private DownloadedModelManagerImpl(
93       DownloadedModelDatabase db, File modelDownloaderDir, TextClassifierSettings settings) {
94     this.db = db;
95     this.modelDownloaderDir = modelDownloaderDir;
96     this.modelLookupCache = new ArrayMap<>();
97     for (String modelType : ModelType.values()) {
98       this.modelLookupCache.put(modelType, new ArrayList<>());
99     }
100     this.settings = settings;
101     this.cacheInitialized = false;
102   }
103 
104   @Override
getModelDownloaderDir()105   public File getModelDownloaderDir() {
106     if (!modelDownloaderDir.exists()) {
107       modelDownloaderDir.mkdirs();
108     }
109     return modelDownloaderDir;
110   }
111 
112   @Override
113   @Nullable
listModels(@odelTypeDef String modelType)114   public ImmutableList<File> listModels(@ModelTypeDef String modelType) {
115     synchronized (cacheLock) {
116       if (!cacheInitialized) {
117         updateCache();
118       }
119       ImmutableList.Builder<File> builder = ImmutableList.builder();
120       ImmutableList<String> blockedModels = settings.getModelUrlBlocklist();
121       for (Model model : modelLookupCache.get(modelType)) {
122         if (blockedModels.contains(model.getModelUrl())) {
123           TcLog.d(TAG, "Model is blocklisted: " + model);
124           continue;
125         }
126         builder.add(new File(model.getModelPath()));
127       }
128       return builder.build();
129     }
130   }
131 
132   @Override
133   @Nullable
getModel(String modelUrl)134   public Model getModel(String modelUrl) {
135     List<Model> models = db.dao().queryModelWithModelUrl(modelUrl);
136     return Iterables.getFirst(models, null);
137   }
138 
139   @Override
140   @Nullable
getManifest(String manifestUrl)141   public Manifest getManifest(String manifestUrl) {
142     List<Manifest> manifests = db.dao().queryManifestWithManifestUrl(manifestUrl);
143     return Iterables.getFirst(manifests, null);
144   }
145 
146   @Override
147   @Nullable
getManifestEnrollment( @odelTypeDef String modelType, String localeTag)148   public ManifestEnrollment getManifestEnrollment(
149       @ModelTypeDef String modelType, String localeTag) {
150     List<ManifestEnrollment> manifestEnrollments =
151         db.dao().queryManifestEnrollmentWithModelTypeAndLocaleTag(modelType, localeTag);
152     return Iterables.getFirst(manifestEnrollments, null);
153   }
154 
155   @Override
registerModel(String modelUrl, String modelPath)156   public void registerModel(String modelUrl, String modelPath) {
157     db.dao().insert(Model.create(modelUrl, modelPath));
158   }
159 
160   @Override
registerManifest(String manifestUrl, String modelUrl)161   public void registerManifest(String manifestUrl, String modelUrl) {
162     db.dao().insertManifestAndModelCrossRef(manifestUrl, modelUrl);
163   }
164 
165   @Override
registerManifestDownloadFailure(String manifestUrl)166   public void registerManifestDownloadFailure(String manifestUrl) {
167     db.dao().increaseManifestFailureCounts(manifestUrl);
168   }
169 
170   @Override
registerManifestEnrollment( @odelTypeDef String modelType, String localeTag, String manifestUrl)171   public void registerManifestEnrollment(
172       @ModelTypeDef String modelType, String localeTag, String manifestUrl) {
173     db.dao().insert(ManifestEnrollment.create(modelType, localeTag, manifestUrl));
174   }
175 
176   @Override
dump(IndentingPrintWriter printWriter)177   public void dump(IndentingPrintWriter printWriter) {
178     printWriter.println("DownloadedModelManagerImpl:");
179     printWriter.increaseIndent();
180     db.dump(printWriter, TextClassifierServiceExecutors.getDownloaderExecutor());
181     printWriter.println("ModelLookupCache:");
182     synchronized (cacheLock) {
183       for (Map.Entry<String, List<Model>> entry : modelLookupCache.entrySet()) {
184         printWriter.println(entry.getKey());
185         printWriter.increaseIndent();
186         for (Model model : entry.getValue()) {
187           printWriter.println(model.toString());
188         }
189         printWriter.decreaseIndent();
190       }
191     }
192     printWriter.decreaseIndent();
193   }
194 
195   @Override
onDownloadCompleted( ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload)196   public void onDownloadCompleted(
197       ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload) {
198     TcLog.d(TAG, "Start to clean up models and update model lookup cache...");
199     // Step 1: Clean up ManifestEnrollment table
200     List<ManifestEnrollment> allManifestEnrollments = db.dao().queryAllManifestEnrollments();
201     List<ManifestEnrollment> manifestEnrollmentsToDelete = new ArrayList<>();
202     for (String modelType : ModelType.values()) {
203       List<ManifestEnrollment> manifestEnrollmentsByType =
204           allManifestEnrollments.stream()
205               .filter(modelEnrollment -> modelEnrollment.getModelType().equals(modelType))
206               .collect(Collectors.toList());
207       ManifestsToDownloadByType manifestsToDownloadByType = manifestsToDownload.get(modelType);
208 
209       if (manifestsToDownloadByType == null) {
210         // No suitable manifests configured for this model type. Delete everything.
211         manifestEnrollmentsToDelete.addAll(manifestEnrollmentsByType);
212         continue;
213       }
214       ImmutableMap<String, String> localeTagToManifestUrl =
215           manifestsToDownloadByType.localeTagToManifestUrl();
216 
217       boolean allModelsDownloaded = true;
218       for (Map.Entry<String, String> entry : localeTagToManifestUrl.entrySet()) {
219         String localeTag = entry.getKey();
220         String manifestUrl = entry.getValue();
221         Optional<ManifestEnrollment> manifestEnrollmentForLocaleTagAndManifestUrl =
222             manifestEnrollmentsByType.stream()
223                 .filter(
224                     manifestEnrollment ->
225                         manifestEnrollment.getLocaleTag().equals(localeTag)
226                             && manifestEnrollment.getManifestUrl().equals(manifestUrl))
227                 .findAny();
228         if (!manifestEnrollmentForLocaleTagAndManifestUrl.isPresent()) {
229           // The desired manifest failed to be downloaded.
230           TcLog.w(
231               TAG,
232               String.format(
233                   "Desired manifest is missing on download completed: %s, %s, %s",
234                   modelType, localeTag, manifestUrl));
235           allModelsDownloaded = false;
236         }
237       }
238       if (allModelsDownloaded) {
239         // Delete unused manifest enrollments.
240         manifestEnrollmentsToDelete.addAll(
241             manifestEnrollmentsByType.stream()
242                 .filter(
243                     manifestEnrollment ->
244                         !manifestEnrollment
245                             .getManifestUrl()
246                             .equals(localeTagToManifestUrl.get(manifestEnrollment.getLocaleTag())))
247                 .collect(Collectors.toList()));
248       } else {
249         // TODO(licha): We may still need to delete models here. E.g. we are switching from en to
250         // zh. Although we fail to download zh model, we still want to delete en models.
251         TcLog.w(
252             TAG, "Unused models were not deleted because downloading of at least one model failed");
253       }
254     }
255     db.dao().deleteManifestEnrollments(manifestEnrollmentsToDelete);
256     // Step 2: Clean up Manifests and Models that are not linked to any ManifestEnrollment
257     db.dao().deleteUnusedManifestsAndModels();
258     // Step 3: Clean up Manifest failure records
259     // We only keep a failure record if the worker stills trys to download it
260     // We restrict the deletion to failure records only because although some manifest urls are not
261     // in allAttemptedManifestUrls, they can still be useful (e.g. current manifest is v901, and we
262     // failed to download v902. v901 will not be in the map, but it should be kept.)
263     List<String> allAttemptedManifestUrls =
264         manifestsToDownload.entrySet().stream()
265             .flatMap(
266                 entry ->
267                     entry.getValue().localeTagToManifestUrl().entrySet().stream()
268                         .map(Map.Entry::getValue))
269             .collect(Collectors.toList());
270     db.dao().deleteUnusedManifestFailureRecords(allAttemptedManifestUrls);
271     // Step 4: Update lookup cache
272     updateCache();
273     // Step 5: Clean up unused model files.
274     Set<String> modelPathsToKeep =
275         db.dao().queryAllModels().stream().map(Model::getModelPath).collect(Collectors.toSet());
276     for (File modelFile : getModelDownloaderDir().listFiles()) {
277       if (!modelPathsToKeep.contains(modelFile.getAbsolutePath())) {
278         TcLog.d(TAG, "Delete model file: " + modelFile.getAbsolutePath());
279         if (!modelFile.delete()) {
280           TcLog.e(TAG, "Failed to delete model file: " + modelFile.getAbsolutePath());
281         }
282       }
283     }
284   }
285 
286   // Clear the cache table and rebuild the cache based on ModelView table
updateCache()287   private void updateCache() {
288     synchronized (cacheLock) {
289       TcLog.d(TAG, "Updating model lookup cache...");
290       for (String modelType : ModelType.values()) {
291         modelLookupCache.get(modelType).clear();
292       }
293       for (ModelView modelView : db.dao().queryAllModelViews()) {
294         modelLookupCache
295             .get(modelView.getManifestEnrollment().getModelType())
296             .add(modelView.getModel());
297       }
298       cacheInitialized = true;
299     }
300   }
301 }
302