1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.textclassifier.downloader; 18 19 import android.content.Context; 20 import android.util.ArrayMap; 21 import androidx.annotation.GuardedBy; 22 import androidx.room.Room; 23 import com.android.textclassifier.common.ModelType; 24 import com.android.textclassifier.common.ModelType.ModelTypeDef; 25 import com.android.textclassifier.common.TextClassifierServiceExecutors; 26 import com.android.textclassifier.common.TextClassifierSettings; 27 import com.android.textclassifier.common.base.TcLog; 28 import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest; 29 import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment; 30 import com.android.textclassifier.downloader.DownloadedModelDatabase.Model; 31 import com.android.textclassifier.downloader.DownloadedModelDatabase.ModelView; 32 import com.android.textclassifier.utils.IndentingPrintWriter; 33 import com.google.common.annotations.VisibleForTesting; 34 import com.google.common.collect.ImmutableList; 35 import com.google.common.collect.ImmutableMap; 36 import com.google.common.collect.Iterables; 37 import java.io.File; 38 import java.util.ArrayList; 39 import java.util.List; 40 import java.util.Map; 41 import java.util.Optional; 42 import java.util.Set; 43 import java.util.stream.Collectors; 44 import javax.annotation.Nullable; 45 46 /** A singleton implementation of DownloadedModelManager. */ 47 public final class DownloadedModelManagerImpl implements DownloadedModelManager { 48 private static final String TAG = "DownloadedModelManagerImpl"; 49 private static final String DOWNLOAD_SUB_DIR_NAME = "textclassifier/downloads/models"; 50 private static final String DOWNLOADED_MODEL_DATABASE_NAME = "tcs-downloaded-model-db"; 51 52 private static final Object staticLock = new Object(); 53 54 @GuardedBy("staticLock") 55 private static DownloadedModelManagerImpl instance; 56 57 private final File modelDownloaderDir; 58 private final DownloadedModelDatabase db; 59 private final TextClassifierSettings settings; 60 61 private final Object cacheLock = new Object(); 62 63 // modeltype -> downloaded model files 64 @GuardedBy("cacheLock") 65 private final ArrayMap<String, List<Model>> modelLookupCache; 66 67 @GuardedBy("cacheLock") 68 private boolean cacheInitialized; 69 70 @Nullable getInstance(Context context)71 public static DownloadedModelManager getInstance(Context context) { 72 synchronized (staticLock) { 73 if (instance == null) { 74 DownloadedModelDatabase db = 75 Room.databaseBuilder( 76 context, DownloadedModelDatabase.class, DOWNLOADED_MODEL_DATABASE_NAME) 77 .build(); 78 File modelDownloaderDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME); 79 instance = 80 new DownloadedModelManagerImpl( 81 db, modelDownloaderDir, new TextClassifierSettings(context)); 82 } 83 return instance; 84 } 85 } 86 87 @VisibleForTesting getInstanceForTesting( DownloadedModelDatabase db, File modelDownloaderDir, TextClassifierSettings settings)88 static DownloadedModelManagerImpl getInstanceForTesting( 89 DownloadedModelDatabase db, File modelDownloaderDir, TextClassifierSettings settings) { 90 return new DownloadedModelManagerImpl(db, modelDownloaderDir, settings); 91 } 92 DownloadedModelManagerImpl( DownloadedModelDatabase db, File modelDownloaderDir, TextClassifierSettings settings)93 private DownloadedModelManagerImpl( 94 DownloadedModelDatabase db, File modelDownloaderDir, TextClassifierSettings settings) { 95 this.db = db; 96 this.modelDownloaderDir = modelDownloaderDir; 97 this.modelLookupCache = new ArrayMap<>(); 98 for (String modelType : ModelType.values()) { 99 this.modelLookupCache.put(modelType, new ArrayList<>()); 100 } 101 this.settings = settings; 102 this.cacheInitialized = false; 103 } 104 105 @Override getModelDownloaderDir()106 public File getModelDownloaderDir() { 107 if (!modelDownloaderDir.exists()) { 108 modelDownloaderDir.mkdirs(); 109 } 110 return modelDownloaderDir; 111 } 112 113 @Override 114 @Nullable listModels(@odelTypeDef String modelType)115 public ImmutableList<File> listModels(@ModelTypeDef String modelType) { 116 synchronized (cacheLock) { 117 if (!cacheInitialized) { 118 updateCache(); 119 } 120 ImmutableList.Builder<File> builder = ImmutableList.builder(); 121 ImmutableList<String> blockedModels = settings.getModelUrlBlocklist(); 122 for (Model model : modelLookupCache.get(modelType)) { 123 if (blockedModels.contains(model.getModelUrl())) { 124 TcLog.d(TAG, "Model is blocklisted: " + model); 125 continue; 126 } 127 builder.add(new File(model.getModelPath())); 128 } 129 return builder.build(); 130 } 131 } 132 133 @Override 134 @Nullable getModel(String modelUrl)135 public Model getModel(String modelUrl) { 136 List<Model> models = db.dao().queryModelWithModelUrl(modelUrl); 137 return Iterables.getFirst(models, null); 138 } 139 140 @Override 141 @Nullable getManifest(String manifestUrl)142 public Manifest getManifest(String manifestUrl) { 143 List<Manifest> manifests = db.dao().queryManifestWithManifestUrl(manifestUrl); 144 return Iterables.getFirst(manifests, null); 145 } 146 147 @Override 148 @Nullable getManifestEnrollment( @odelTypeDef String modelType, String localeTag)149 public ManifestEnrollment getManifestEnrollment( 150 @ModelTypeDef String modelType, String localeTag) { 151 List<ManifestEnrollment> manifestEnrollments = 152 db.dao().queryManifestEnrollmentWithModelTypeAndLocaleTag(modelType, localeTag); 153 return Iterables.getFirst(manifestEnrollments, null); 154 } 155 156 @Override registerModel(String modelUrl, String modelPath)157 public void registerModel(String modelUrl, String modelPath) { 158 db.dao().insert(Model.create(modelUrl, modelPath)); 159 } 160 161 @Override registerManifest(String manifestUrl, String modelUrl)162 public void registerManifest(String manifestUrl, String modelUrl) { 163 db.dao().insertManifestAndModelCrossRef(manifestUrl, modelUrl); 164 } 165 166 @Override registerManifestDownloadFailure(String manifestUrl)167 public void registerManifestDownloadFailure(String manifestUrl) { 168 db.dao().increaseManifestFailureCounts(manifestUrl); 169 } 170 171 @Override registerManifestEnrollment( @odelTypeDef String modelType, String localeTag, String manifestUrl)172 public void registerManifestEnrollment( 173 @ModelTypeDef String modelType, String localeTag, String manifestUrl) { 174 db.dao().insert(ManifestEnrollment.create(modelType, localeTag, manifestUrl)); 175 } 176 177 @Override dump(IndentingPrintWriter printWriter)178 public void dump(IndentingPrintWriter printWriter) { 179 printWriter.println("DownloadedModelManagerImpl:"); 180 printWriter.increaseIndent(); 181 db.dump(printWriter, TextClassifierServiceExecutors.getDownloaderExecutor()); 182 printWriter.println("ModelLookupCache:"); 183 synchronized (cacheLock) { 184 for (Map.Entry<String, List<Model>> entry : modelLookupCache.entrySet()) { 185 printWriter.println(entry.getKey()); 186 printWriter.increaseIndent(); 187 for (Model model : entry.getValue()) { 188 printWriter.println(model.toString()); 189 } 190 printWriter.decreaseIndent(); 191 } 192 } 193 printWriter.decreaseIndent(); 194 } 195 196 @Override onDownloadCompleted( ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload)197 public void onDownloadCompleted( 198 ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload) { 199 TcLog.d(TAG, "Start to clean up models and update model lookup cache..."); 200 // Step 1: Clean up ManifestEnrollment table 201 List<ManifestEnrollment> allManifestEnrollments = db.dao().queryAllManifestEnrollments(); 202 List<ManifestEnrollment> manifestEnrollmentsToDelete = new ArrayList<>(); 203 for (String modelType : ModelType.values()) { 204 List<ManifestEnrollment> manifestEnrollmentsByType = 205 allManifestEnrollments.stream() 206 .filter(modelEnrollment -> modelEnrollment.getModelType().equals(modelType)) 207 .collect(Collectors.toList()); 208 ManifestsToDownloadByType manifestsToDownloadByType = manifestsToDownload.get(modelType); 209 210 if (manifestsToDownloadByType == null) { 211 // No suitable manifests configured for this model type. Delete everything. 212 manifestEnrollmentsToDelete.addAll(manifestEnrollmentsByType); 213 continue; 214 } 215 ImmutableMap<String, String> localeTagToManifestUrl = 216 manifestsToDownloadByType.localeTagToManifestUrl(); 217 218 boolean allModelsDownloaded = true; 219 for (Map.Entry<String, String> entry : localeTagToManifestUrl.entrySet()) { 220 String localeTag = entry.getKey(); 221 String manifestUrl = entry.getValue(); 222 Optional<ManifestEnrollment> manifestEnrollmentForLocaleTagAndManifestUrl = 223 manifestEnrollmentsByType.stream() 224 .filter( 225 manifestEnrollment -> 226 manifestEnrollment.getLocaleTag().equals(localeTag) 227 && manifestEnrollment.getManifestUrl().equals(manifestUrl)) 228 .findAny(); 229 if (!manifestEnrollmentForLocaleTagAndManifestUrl.isPresent()) { 230 // The desired manifest failed to be downloaded. 231 TcLog.w( 232 TAG, 233 String.format( 234 "Desired manifest is missing on download completed: %s, %s, %s", 235 modelType, localeTag, manifestUrl)); 236 allModelsDownloaded = false; 237 } 238 } 239 if (allModelsDownloaded) { 240 // Delete unused manifest enrollments. 241 manifestEnrollmentsToDelete.addAll( 242 manifestEnrollmentsByType.stream() 243 .filter( 244 manifestEnrollment -> 245 !manifestEnrollment 246 .getManifestUrl() 247 .equals(localeTagToManifestUrl.get(manifestEnrollment.getLocaleTag()))) 248 .collect(Collectors.toList())); 249 } else { 250 // TODO(licha): We may still need to delete models here. E.g. we are switching from en to 251 // zh. Although we fail to download zh model, we still want to delete en models. 252 TcLog.w( 253 TAG, "Unused models were not deleted because downloading of at least one model failed"); 254 } 255 } 256 db.dao().deleteManifestEnrollments(manifestEnrollmentsToDelete); 257 // Step 2: Clean up Manifests and Models that are not linked to any ManifestEnrollment 258 db.dao().deleteUnusedManifestsAndModels(); 259 // Step 3: Clean up Manifest failure records 260 // We only keep a failure record if the worker stills trys to download it 261 // We restrict the deletion to failure records only because although some manifest urls are not 262 // in allAttemptedManifestUrls, they can still be useful (e.g. current manifest is v901, and we 263 // failed to download v902. v901 will not be in the map, but it should be kept.) 264 List<String> allAttemptedManifestUrls = 265 manifestsToDownload.entrySet().stream() 266 .flatMap( 267 entry -> 268 entry.getValue().localeTagToManifestUrl().entrySet().stream() 269 .map(Map.Entry::getValue)) 270 .collect(Collectors.toList()); 271 db.dao().deleteUnusedManifestFailureRecords(allAttemptedManifestUrls); 272 // Step 4: Update lookup cache 273 updateCache(); 274 // Step 5: Clean up unused model files. 275 Set<String> modelPathsToKeep = 276 db.dao().queryAllModels().stream().map(Model::getModelPath).collect(Collectors.toSet()); 277 for (File modelFile : getModelDownloaderDir().listFiles()) { 278 if (!modelPathsToKeep.contains(modelFile.getAbsolutePath())) { 279 TcLog.d(TAG, "Delete model file: " + modelFile.getAbsolutePath()); 280 if (!modelFile.delete()) { 281 TcLog.e(TAG, "Failed to delete model file: " + modelFile.getAbsolutePath()); 282 } 283 } 284 } 285 } 286 287 // Clear the cache table and rebuild the cache based on ModelView table updateCache()288 private void updateCache() { 289 synchronized (cacheLock) { 290 TcLog.d(TAG, "Updating model lookup cache..."); 291 for (String modelType : ModelType.values()) { 292 modelLookupCache.get(modelType).clear(); 293 } 294 for (ModelView modelView : db.dao().queryAllModelViews()) { 295 modelLookupCache 296 .get(modelView.getManifestEnrollment().getModelType()) 297 .add(modelView.getModel()); 298 } 299 cacheInitialized = true; 300 } 301 } 302 } 303