1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.textclassifier.downloader; 18 19 import android.content.Context; 20 import android.util.ArrayMap; 21 import androidx.annotation.GuardedBy; 22 import androidx.room.Room; 23 import com.android.textclassifier.common.ModelType; 24 import com.android.textclassifier.common.ModelType.ModelTypeDef; 25 import com.android.textclassifier.common.TextClassifierServiceExecutors; 26 import com.android.textclassifier.common.TextClassifierSettings; 27 import com.android.textclassifier.common.base.TcLog; 28 import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest; 29 import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment; 30 import com.android.textclassifier.downloader.DownloadedModelDatabase.Model; 31 import com.android.textclassifier.downloader.DownloadedModelDatabase.ModelView; 32 import com.android.textclassifier.utils.IndentingPrintWriter; 33 import com.google.common.annotations.VisibleForTesting; 34 import com.google.common.collect.ImmutableList; 35 import com.google.common.collect.ImmutableMap; 36 import com.google.common.collect.Iterables; 37 import java.io.File; 38 import java.util.ArrayList; 39 import java.util.List; 40 import java.util.Map; 41 import java.util.Optional; 42 import java.util.Set; 43 import java.util.stream.Collectors; 44 import javax.annotation.Nullable; 45 46 /** A singleton implementation of DownloadedModelManager. */ 47 public final class DownloadedModelManagerImpl implements DownloadedModelManager { 48 private static final String TAG = "DownloadedModelManagerImpl"; 49 private static final String DOWNLOAD_SUB_DIR_NAME = "textclassifier/downloads/models"; 50 private static final String DOWNLOADED_MODEL_DATABASE_NAME = "tcs-downloaded-model-db"; 51 52 private static final Object staticLock = new Object(); 53 54 @GuardedBy("staticLock") 55 private static DownloadedModelManagerImpl instance; 56 57 private final File modelDownloaderDir; 58 private final DownloadedModelDatabase db; 59 private final TextClassifierSettings settings; 60 61 private final Object cacheLock = new Object(); 62 63 // modeltype -> downloaded model files 64 @GuardedBy("cacheLock") 65 private final ArrayMap<String, List<Model>> modelLookupCache; 66 67 @GuardedBy("cacheLock") 68 private boolean cacheInitialized; 69 70 @Nullable getInstance(Context context)71 public static DownloadedModelManager getInstance(Context context) { 72 synchronized (staticLock) { 73 if (instance == null) { 74 DownloadedModelDatabase db = 75 Room.databaseBuilder( 76 context, DownloadedModelDatabase.class, DOWNLOADED_MODEL_DATABASE_NAME) 77 .build(); 78 File modelDownloaderDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME); 79 instance = 80 new DownloadedModelManagerImpl(db, modelDownloaderDir, new TextClassifierSettings()); 81 } 82 return instance; 83 } 84 } 85 86 @VisibleForTesting getInstanceForTesting( DownloadedModelDatabase db, File modelDownloaderDir, TextClassifierSettings settings)87 static DownloadedModelManagerImpl getInstanceForTesting( 88 DownloadedModelDatabase db, File modelDownloaderDir, TextClassifierSettings settings) { 89 return new DownloadedModelManagerImpl(db, modelDownloaderDir, settings); 90 } 91 DownloadedModelManagerImpl( DownloadedModelDatabase db, File modelDownloaderDir, TextClassifierSettings settings)92 private DownloadedModelManagerImpl( 93 DownloadedModelDatabase db, File modelDownloaderDir, TextClassifierSettings settings) { 94 this.db = db; 95 this.modelDownloaderDir = modelDownloaderDir; 96 this.modelLookupCache = new ArrayMap<>(); 97 for (String modelType : ModelType.values()) { 98 this.modelLookupCache.put(modelType, new ArrayList<>()); 99 } 100 this.settings = settings; 101 this.cacheInitialized = false; 102 } 103 104 @Override getModelDownloaderDir()105 public File getModelDownloaderDir() { 106 if (!modelDownloaderDir.exists()) { 107 modelDownloaderDir.mkdirs(); 108 } 109 return modelDownloaderDir; 110 } 111 112 @Override 113 @Nullable listModels(@odelTypeDef String modelType)114 public ImmutableList<File> listModels(@ModelTypeDef String modelType) { 115 synchronized (cacheLock) { 116 if (!cacheInitialized) { 117 updateCache(); 118 } 119 ImmutableList.Builder<File> builder = ImmutableList.builder(); 120 ImmutableList<String> blockedModels = settings.getModelUrlBlocklist(); 121 for (Model model : modelLookupCache.get(modelType)) { 122 if (blockedModels.contains(model.getModelUrl())) { 123 TcLog.d(TAG, "Model is blocklisted: " + model); 124 continue; 125 } 126 builder.add(new File(model.getModelPath())); 127 } 128 return builder.build(); 129 } 130 } 131 132 @Override 133 @Nullable getModel(String modelUrl)134 public Model getModel(String modelUrl) { 135 List<Model> models = db.dao().queryModelWithModelUrl(modelUrl); 136 return Iterables.getFirst(models, null); 137 } 138 139 @Override 140 @Nullable getManifest(String manifestUrl)141 public Manifest getManifest(String manifestUrl) { 142 List<Manifest> manifests = db.dao().queryManifestWithManifestUrl(manifestUrl); 143 return Iterables.getFirst(manifests, null); 144 } 145 146 @Override 147 @Nullable getManifestEnrollment( @odelTypeDef String modelType, String localeTag)148 public ManifestEnrollment getManifestEnrollment( 149 @ModelTypeDef String modelType, String localeTag) { 150 List<ManifestEnrollment> manifestEnrollments = 151 db.dao().queryManifestEnrollmentWithModelTypeAndLocaleTag(modelType, localeTag); 152 return Iterables.getFirst(manifestEnrollments, null); 153 } 154 155 @Override registerModel(String modelUrl, String modelPath)156 public void registerModel(String modelUrl, String modelPath) { 157 db.dao().insert(Model.create(modelUrl, modelPath)); 158 } 159 160 @Override registerManifest(String manifestUrl, String modelUrl)161 public void registerManifest(String manifestUrl, String modelUrl) { 162 db.dao().insertManifestAndModelCrossRef(manifestUrl, modelUrl); 163 } 164 165 @Override registerManifestDownloadFailure(String manifestUrl)166 public void registerManifestDownloadFailure(String manifestUrl) { 167 db.dao().increaseManifestFailureCounts(manifestUrl); 168 } 169 170 @Override registerManifestEnrollment( @odelTypeDef String modelType, String localeTag, String manifestUrl)171 public void registerManifestEnrollment( 172 @ModelTypeDef String modelType, String localeTag, String manifestUrl) { 173 db.dao().insert(ManifestEnrollment.create(modelType, localeTag, manifestUrl)); 174 } 175 176 @Override dump(IndentingPrintWriter printWriter)177 public void dump(IndentingPrintWriter printWriter) { 178 printWriter.println("DownloadedModelManagerImpl:"); 179 printWriter.increaseIndent(); 180 db.dump(printWriter, TextClassifierServiceExecutors.getDownloaderExecutor()); 181 printWriter.println("ModelLookupCache:"); 182 synchronized (cacheLock) { 183 for (Map.Entry<String, List<Model>> entry : modelLookupCache.entrySet()) { 184 printWriter.println(entry.getKey()); 185 printWriter.increaseIndent(); 186 for (Model model : entry.getValue()) { 187 printWriter.println(model.toString()); 188 } 189 printWriter.decreaseIndent(); 190 } 191 } 192 printWriter.decreaseIndent(); 193 } 194 195 @Override onDownloadCompleted( ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload)196 public void onDownloadCompleted( 197 ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload) { 198 TcLog.d(TAG, "Start to clean up models and update model lookup cache..."); 199 // Step 1: Clean up ManifestEnrollment table 200 List<ManifestEnrollment> allManifestEnrollments = db.dao().queryAllManifestEnrollments(); 201 List<ManifestEnrollment> manifestEnrollmentsToDelete = new ArrayList<>(); 202 for (String modelType : ModelType.values()) { 203 List<ManifestEnrollment> manifestEnrollmentsByType = 204 allManifestEnrollments.stream() 205 .filter(modelEnrollment -> modelEnrollment.getModelType().equals(modelType)) 206 .collect(Collectors.toList()); 207 ManifestsToDownloadByType manifestsToDownloadByType = manifestsToDownload.get(modelType); 208 209 if (manifestsToDownloadByType == null) { 210 // No suitable manifests configured for this model type. Delete everything. 211 manifestEnrollmentsToDelete.addAll(manifestEnrollmentsByType); 212 continue; 213 } 214 ImmutableMap<String, String> localeTagToManifestUrl = 215 manifestsToDownloadByType.localeTagToManifestUrl(); 216 217 boolean allModelsDownloaded = true; 218 for (Map.Entry<String, String> entry : localeTagToManifestUrl.entrySet()) { 219 String localeTag = entry.getKey(); 220 String manifestUrl = entry.getValue(); 221 Optional<ManifestEnrollment> manifestEnrollmentForLocaleTagAndManifestUrl = 222 manifestEnrollmentsByType.stream() 223 .filter( 224 manifestEnrollment -> 225 manifestEnrollment.getLocaleTag().equals(localeTag) 226 && manifestEnrollment.getManifestUrl().equals(manifestUrl)) 227 .findAny(); 228 if (!manifestEnrollmentForLocaleTagAndManifestUrl.isPresent()) { 229 // The desired manifest failed to be downloaded. 230 TcLog.w( 231 TAG, 232 String.format( 233 "Desired manifest is missing on download completed: %s, %s, %s", 234 modelType, localeTag, manifestUrl)); 235 allModelsDownloaded = false; 236 } 237 } 238 if (allModelsDownloaded) { 239 // Delete unused manifest enrollments. 240 manifestEnrollmentsToDelete.addAll( 241 manifestEnrollmentsByType.stream() 242 .filter( 243 manifestEnrollment -> 244 !manifestEnrollment 245 .getManifestUrl() 246 .equals(localeTagToManifestUrl.get(manifestEnrollment.getLocaleTag()))) 247 .collect(Collectors.toList())); 248 } else { 249 // TODO(licha): We may still need to delete models here. E.g. we are switching from en to 250 // zh. Although we fail to download zh model, we still want to delete en models. 251 TcLog.w( 252 TAG, "Unused models were not deleted because downloading of at least one model failed"); 253 } 254 } 255 db.dao().deleteManifestEnrollments(manifestEnrollmentsToDelete); 256 // Step 2: Clean up Manifests and Models that are not linked to any ManifestEnrollment 257 db.dao().deleteUnusedManifestsAndModels(); 258 // Step 3: Clean up Manifest failure records 259 // We only keep a failure record if the worker stills trys to download it 260 // We restrict the deletion to failure records only because although some manifest urls are not 261 // in allAttemptedManifestUrls, they can still be useful (e.g. current manifest is v901, and we 262 // failed to download v902. v901 will not be in the map, but it should be kept.) 263 List<String> allAttemptedManifestUrls = 264 manifestsToDownload.entrySet().stream() 265 .flatMap( 266 entry -> 267 entry.getValue().localeTagToManifestUrl().entrySet().stream() 268 .map(Map.Entry::getValue)) 269 .collect(Collectors.toList()); 270 db.dao().deleteUnusedManifestFailureRecords(allAttemptedManifestUrls); 271 // Step 4: Update lookup cache 272 updateCache(); 273 // Step 5: Clean up unused model files. 274 Set<String> modelPathsToKeep = 275 db.dao().queryAllModels().stream().map(Model::getModelPath).collect(Collectors.toSet()); 276 for (File modelFile : getModelDownloaderDir().listFiles()) { 277 if (!modelPathsToKeep.contains(modelFile.getAbsolutePath())) { 278 TcLog.d(TAG, "Delete model file: " + modelFile.getAbsolutePath()); 279 if (!modelFile.delete()) { 280 TcLog.e(TAG, "Failed to delete model file: " + modelFile.getAbsolutePath()); 281 } 282 } 283 } 284 } 285 286 // Clear the cache table and rebuild the cache based on ModelView table updateCache()287 private void updateCache() { 288 synchronized (cacheLock) { 289 TcLog.d(TAG, "Updating model lookup cache..."); 290 for (String modelType : ModelType.values()) { 291 modelLookupCache.get(modelType).clear(); 292 } 293 for (ModelView modelView : db.dao().queryAllModelViews()) { 294 modelLookupCache 295 .get(modelView.getManifestEnrollment().getModelType()) 296 .add(modelView.getModel()); 297 } 298 cacheInitialized = true; 299 } 300 } 301 } 302