1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.textclassifier.downloader; 18 19 import static java.lang.Math.min; 20 21 import android.content.Context; 22 import android.os.LocaleList; 23 import android.util.ArrayMap; 24 import android.util.Pair; 25 import androidx.work.ListenableWorker; 26 import androidx.work.WorkerParameters; 27 import com.android.textclassifier.common.ModelType; 28 import com.android.textclassifier.common.ModelType.ModelTypeDef; 29 import com.android.textclassifier.common.TextClassifierServiceExecutors; 30 import com.android.textclassifier.common.TextClassifierSettings; 31 import com.android.textclassifier.common.base.TcLog; 32 import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest; 33 import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment; 34 import com.android.textclassifier.downloader.DownloadedModelDatabase.Model; 35 import com.google.auto.value.AutoValue; 36 import com.google.common.annotations.VisibleForTesting; 37 import com.google.common.base.Preconditions; 38 import com.google.common.collect.ImmutableList; 39 import com.google.common.collect.ImmutableMap; 40 import com.google.common.util.concurrent.FluentFuture; 41 import com.google.common.util.concurrent.Futures; 42 import com.google.common.util.concurrent.ListenableFuture; 43 import com.google.common.util.concurrent.ListeningExecutorService; 44 import com.google.errorprone.annotations.concurrent.GuardedBy; 45 import java.time.Clock; 46 import java.util.ArrayList; 47 import java.util.Locale; 48 49 /** The WorkManager worker to download models for TextClassifierService. */ 50 public final class ModelDownloadWorker extends ListenableWorker { 51 private static final String TAG = "ModelDownloadWorker"; 52 53 public static final String INPUT_DATA_KEY_WORK_ID = "ModelDownloadWorker_workId"; 54 public static final String INPUT_DATA_KEY_SCHEDULED_TIMESTAMP = 55 "ModelDownloadWorker_scheduledTimestamp"; 56 57 private final ListeningExecutorService executorService; 58 private final ModelDownloader downloader; 59 private final DownloadedModelManager downloadedModelManager; 60 private final TextClassifierSettings settings; 61 62 private final long workId; 63 64 private final Clock clock; 65 private final long workScheduledTimeMillis; 66 67 private final Object lock = new Object(); 68 69 private long workStartedTimeMillis = 0; 70 71 @GuardedBy("lock") 72 private final ArrayMap<String, ListenableFuture<Void>> pendingDownloads; 73 74 private ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload; 75 ModelDownloadWorker(Context context, WorkerParameters workerParams)76 public ModelDownloadWorker(Context context, WorkerParameters workerParams) { 77 super(context, workerParams); 78 this.executorService = TextClassifierServiceExecutors.getDownloaderExecutor(); 79 this.downloader = new ModelDownloaderImpl(context, executorService); 80 this.downloadedModelManager = DownloadedModelManagerImpl.getInstance(context); 81 this.settings = new TextClassifierSettings(context); 82 this.pendingDownloads = new ArrayMap<>(); 83 this.manifestsToDownload = null; 84 85 this.workId = workerParams.getInputData().getLong(INPUT_DATA_KEY_WORK_ID, 0); 86 this.workScheduledTimeMillis = 87 workerParams.getInputData().getLong(INPUT_DATA_KEY_SCHEDULED_TIMESTAMP, 0); 88 this.clock = Clock.systemUTC(); 89 } 90 91 @VisibleForTesting ModelDownloadWorker( Context context, WorkerParameters workerParams, ListeningExecutorService executorService, ModelDownloader modelDownloader, DownloadedModelManager downloadedModelManager, TextClassifierSettings settings, long workId, Clock clock, long workScheduledTimeMillis)92 ModelDownloadWorker( 93 Context context, 94 WorkerParameters workerParams, 95 ListeningExecutorService executorService, 96 ModelDownloader modelDownloader, 97 DownloadedModelManager downloadedModelManager, 98 TextClassifierSettings settings, 99 long workId, 100 Clock clock, 101 long workScheduledTimeMillis) { 102 super(context, workerParams); 103 this.executorService = executorService; 104 this.downloader = modelDownloader; 105 this.downloadedModelManager = downloadedModelManager; 106 this.settings = settings; 107 this.pendingDownloads = new ArrayMap<>(); 108 this.manifestsToDownload = null; 109 this.workId = workId; 110 this.clock = clock; 111 this.workScheduledTimeMillis = workScheduledTimeMillis; 112 } 113 114 @Override startWork()115 public final ListenableFuture<ListenableWorker.Result> startWork() { 116 TcLog.d(TAG, "Start download work..."); 117 workStartedTimeMillis = getCurrentTimeMillis(); 118 // Notice: startWork() is invoked on the main thread 119 if (!settings.isModelDownloadManagerEnabled()) { 120 TcLog.e(TAG, "Model Downloader is disabled. Abort the work."); 121 logDownloadWorkCompleted( 122 TextClassifierDownloadLogger.WORK_RESULT_FAILURE_MODEL_DOWNLOADER_DISABLED); 123 return Futures.immediateFuture(ListenableWorker.Result.failure()); 124 } 125 if (getRunAttemptCount() >= settings.getModelDownloadWorkerMaxAttempts()) { 126 TcLog.d(TAG, "Max attempt reached. Abort download work."); 127 logDownloadWorkCompleted( 128 TextClassifierDownloadLogger.WORK_RESULT_FAILURE_MAX_RUN_ATTEMPT_REACHED); 129 return Futures.immediateFuture(ListenableWorker.Result.failure()); 130 } 131 132 return FluentFuture.from(Futures.submitAsync(this::checkAndDownloadModels, executorService)) 133 .transform( 134 downloadResult -> { 135 Preconditions.checkNotNull(manifestsToDownload); 136 downloadedModelManager.onDownloadCompleted(manifestsToDownload); 137 TcLog.d(TAG, "Download work completed: " + downloadResult); 138 if (downloadResult.failureCount() == 0) { 139 logDownloadWorkCompleted( 140 downloadResult.successCount() > 0 141 ? TextClassifierDownloadLogger.WORK_RESULT_SUCCESS_MODEL_DOWNLOADED 142 : TextClassifierDownloadLogger.WORK_RESULT_SUCCESS_NO_UPDATE_AVAILABLE); 143 return ListenableWorker.Result.success(); 144 } else { 145 logDownloadWorkCompleted( 146 TextClassifierDownloadLogger.WORK_RESULT_RETRY_MODEL_DOWNLOAD_FAILED); 147 return ListenableWorker.Result.retry(); 148 } 149 }, 150 executorService) 151 .catching( 152 Throwable.class, 153 t -> { 154 TcLog.e(TAG, "Unexpected Exception during downloading: ", t); 155 logDownloadWorkCompleted( 156 TextClassifierDownloadLogger.WORK_RESULT_RETRY_RUNTIME_EXCEPTION); 157 return ListenableWorker.Result.retry(); 158 }, 159 executorService); 160 } 161 162 /** 163 * Checks device settings and returns the list of locales to download according to multi language 164 * support settings. Guarantees that the primary locale goes first. 165 */ getLocalesToDownload()166 private ImmutableList<Locale> getLocalesToDownload() { 167 LocaleList localeList = LocaleList.getAdjustedDefault(); 168 Locale primaryLocale = localeList.get(0); 169 if (!settings.isMultiLanguageSupportEnabled()) { 170 return ImmutableList.of(primaryLocale); 171 } 172 ImmutableList.Builder<Locale> localesToDownloadBuilder = ImmutableList.builder(); 173 int size = min(settings.getMultiLanguageModelsLimit(), localeList.size()); 174 for (int i = 0; i < size; i++) { 175 localesToDownloadBuilder.add(localeList.get(i)); 176 } 177 return localesToDownloadBuilder.build(); 178 } 179 180 /** 181 * Returns list of locales to download from {@code localeList} for the given {@code modelType}. 182 */ getLocalesToDownloadByType( ImmutableList<Locale> localeList, @ModelTypeDef String modelType)183 private ImmutableList<Locale> getLocalesToDownloadByType( 184 ImmutableList<Locale> localeList, @ModelTypeDef String modelType) { 185 if (!settings.getEnabledModelTypesForMultiLanguageSupport().contains(modelType)) { 186 return ImmutableList.of(Locale.getDefault()); 187 } 188 return localeList; 189 } 190 191 /** 192 * Check device config and dispatch download tasks for all modelTypes. 193 * 194 * <p>Download tasks will be combined and logged after completion. Return true if all tasks 195 * succeeded 196 */ checkAndDownloadModels()197 private ListenableFuture<DownloadResult> checkAndDownloadModels() { 198 ImmutableList<Locale> localesToDownload = getLocalesToDownload(); 199 ArrayList<ListenableFuture<Boolean>> downloadResultFutures = new ArrayList<>(); 200 ImmutableMap.Builder<String, ManifestsToDownloadByType> manifestsToDownloadBuilder = 201 ImmutableMap.builder(); 202 for (String modelType : ModelType.values()) { 203 ImmutableList<Locale> localesToDownloadByType = 204 getLocalesToDownloadByType(localesToDownload, modelType); 205 ImmutableMap.Builder<String, String> localeTagToManifestUrlBuilder = ImmutableMap.builder(); 206 for (Locale locale : localesToDownloadByType) { 207 Pair<String, String> bestLocaleTagAndManifestUrl = 208 LocaleUtils.lookupBestLocaleTagAndManifestUrl(modelType, locale, settings); 209 if (bestLocaleTagAndManifestUrl == null) { 210 TcLog.w( 211 TAG, 212 String.format( 213 Locale.US, "No suitable manifest for %s, %s", modelType, locale.toLanguageTag())); 214 continue; 215 } 216 String bestLocaleTag = bestLocaleTagAndManifestUrl.first; 217 String manifestUrl = bestLocaleTagAndManifestUrl.second; 218 localeTagToManifestUrlBuilder.put(bestLocaleTag, manifestUrl); 219 TcLog.d( 220 TAG, 221 String.format( 222 Locale.US, 223 "model type: %s, current locale tag: %s, best locale tag: %s, manifest url: %s", 224 modelType, 225 locale.toLanguageTag(), 226 bestLocaleTag, 227 manifestUrl)); 228 if (!shouldDownloadManifest(modelType, bestLocaleTag, manifestUrl)) { 229 continue; 230 } 231 downloadResultFutures.add( 232 downloadManifestAndRegister(modelType, bestLocaleTag, manifestUrl)); 233 } 234 manifestsToDownloadBuilder.put( 235 modelType, 236 ManifestsToDownloadByType.create(localeTagToManifestUrlBuilder.buildOrThrow())); 237 } 238 manifestsToDownload = manifestsToDownloadBuilder.buildOrThrow(); 239 240 return Futures.whenAllComplete(downloadResultFutures) 241 .call( 242 () -> { 243 TcLog.d(TAG, "All Download Tasks Completed"); 244 int successCount = 0; 245 int failureCount = 0; 246 for (ListenableFuture<Boolean> downloadResultFuture : downloadResultFutures) { 247 if (Futures.getDone(downloadResultFuture)) { 248 successCount += 1; 249 } else { 250 failureCount += 1; 251 } 252 } 253 return DownloadResult.create(successCount, failureCount); 254 }, 255 executorService); 256 } 257 258 private boolean shouldDownloadManifest( 259 @ModelTypeDef String modelType, String localeTag, String manifestUrl) { 260 Manifest downloadedManifest = downloadedModelManager.getManifest(manifestUrl); 261 if (downloadedManifest == null) { 262 return true; 263 } 264 if (downloadedManifest.getStatus() == Manifest.STATUS_FAILED) { 265 if (downloadedManifest.getFailureCounts() >= settings.getManifestDownloadMaxAttempts()) { 266 TcLog.w( 267 TAG, 268 String.format( 269 Locale.US, 270 "Manifest failed too many times, stop retrying: %s %d", 271 manifestUrl, 272 downloadedManifest.getFailureCounts())); 273 return false; 274 } else { 275 return true; 276 } 277 } 278 ManifestEnrollment manifestEnrollment = 279 downloadedModelManager.getManifestEnrollment(modelType, localeTag); 280 return manifestEnrollment == null || !manifestUrl.equals(manifestEnrollment.getManifestUrl()); 281 } 282 283 /** 284 * Downloads a single manifest and models configured inside it. 285 * 286 * <p>The returned future should always resolve to a ManifestDownloadResult as we catch all 287 * exceptions. 288 */ 289 private ListenableFuture<Boolean> downloadManifestAndRegister( 290 @ModelTypeDef String modelType, String localeTag, String manifestUrl) { 291 long downloadStartTimestamp = getCurrentTimeMillis(); 292 return FluentFuture.from(downloadManifest(manifestUrl)) 293 .transform( 294 unused -> { 295 downloadedModelManager.registerManifestEnrollment(modelType, localeTag, manifestUrl); 296 TextClassifierDownloadLogger.downloadSucceeded( 297 workId, 298 modelType, 299 manifestUrl, 300 getRunAttemptCount(), 301 getCurrentTimeMillis() - downloadStartTimestamp); 302 TcLog.d(TAG, "Manifest downloaded and registered: " + manifestUrl); 303 return true; 304 }, 305 executorService) 306 .catching( 307 Throwable.class, 308 t -> { 309 downloadedModelManager.registerManifestDownloadFailure(manifestUrl); 310 int errorCode = ModelDownloadException.UNKNOWN_FAILURE_REASON; 311 int downloaderLibErrorCode = 0; 312 if (t instanceof ModelDownloadException) { 313 ModelDownloadException mde = (ModelDownloadException) t; 314 errorCode = mde.getErrorCode(); 315 downloaderLibErrorCode = mde.getDownloaderLibErrorCode(); 316 } 317 TcLog.e(TAG, "Failed to download manfiest: " + manifestUrl, t); 318 TextClassifierDownloadLogger.downloadFailed( 319 workId, 320 modelType, 321 manifestUrl, 322 errorCode, 323 getRunAttemptCount(), 324 downloaderLibErrorCode, 325 getCurrentTimeMillis() - downloadStartTimestamp); 326 return false; 327 }, 328 executorService); 329 } 330 331 // Download a manifest and its models, and register it to Manifest table. 332 private ListenableFuture<Void> downloadManifest(String manifestUrl) { 333 synchronized (lock) { 334 Manifest downloadedManifest = downloadedModelManager.getManifest(manifestUrl); 335 if (downloadedManifest != null 336 && downloadedManifest.getStatus() == Manifest.STATUS_SUCCEEDED) { 337 TcLog.d(TAG, "Manifest already downloaded: " + manifestUrl); 338 return Futures.immediateVoidFuture(); 339 } 340 if (pendingDownloads.containsKey(manifestUrl)) { 341 return pendingDownloads.get(manifestUrl); 342 } 343 ListenableFuture<Void> manfiestDownloadFuture = 344 FluentFuture.from(downloader.downloadManifest(manifestUrl)) 345 .transformAsync( 346 manifest -> { 347 ModelManifest.Model modelInfo = manifest.getModels(0); 348 return Futures.transform( 349 downloadModel(modelInfo), unused -> modelInfo, executorService); 350 }, 351 executorService) 352 .transform( 353 modelInfo -> { 354 downloadedModelManager.registerManifest(manifestUrl, modelInfo.getUrl()); 355 return null; 356 }, 357 executorService); 358 pendingDownloads.put(manifestUrl, manfiestDownloadFuture); 359 return manfiestDownloadFuture; 360 } 361 } 362 363 // Download a model and register it into Model table. 364 private ListenableFuture<Void> downloadModel(ModelManifest.Model modelInfo) { 365 String modelUrl = modelInfo.getUrl(); 366 synchronized (lock) { 367 Model downloadedModel = downloadedModelManager.getModel(modelUrl); 368 if (downloadedModel != null) { 369 TcLog.d(TAG, "Model file already exists: " + downloadedModel.getModelPath()); 370 return Futures.immediateVoidFuture(); 371 } 372 if (pendingDownloads.containsKey(modelUrl)) { 373 return pendingDownloads.get(modelUrl); 374 } 375 ListenableFuture<Void> modelDownloadFuture = 376 FluentFuture.from( 377 downloader.downloadModel( 378 downloadedModelManager.getModelDownloaderDir(), modelInfo)) 379 .transform( 380 modelFile -> { 381 downloadedModelManager.registerModel(modelUrl, modelFile.getAbsolutePath()); 382 TcLog.d(TAG, "Model File downloaded: " + modelUrl); 383 return null; 384 }, 385 executorService); 386 pendingDownloads.put(modelUrl, modelDownloadFuture); 387 return modelDownloadFuture; 388 } 389 } 390 391 /** 392 * This method will be called when we our work gets interrupted by the system. Result future 393 * should have already been cancelled in that case. Unless it's because the REPLACE policy of 394 * WorkManager unique queue, the interrupted work will be rescheduled later. 395 */ 396 @Override 397 public final void onStopped() { 398 TcLog.d(TAG, String.format(Locale.US, "Stop download. Attempt:%d", getRunAttemptCount())); 399 logDownloadWorkCompleted(TextClassifierDownloadLogger.WORK_RESULT_RETRY_STOPPED_BY_OS); 400 } 401 402 private long getCurrentTimeMillis() { 403 return clock.instant().toEpochMilli(); 404 } 405 406 private void logDownloadWorkCompleted(int workResult) { 407 if (workStartedTimeMillis < workScheduledTimeMillis) { 408 TcLog.w( 409 TAG, 410 String.format( 411 Locale.US, 412 "Bad workStartedTimeMillis: %d, workScheduledTimeMillis: %d", 413 workStartedTimeMillis, 414 workScheduledTimeMillis)); 415 workStartedTimeMillis = workScheduledTimeMillis; 416 } 417 TextClassifierDownloadLogger.downloadWorkCompleted( 418 workId, 419 workResult, 420 getRunAttemptCount(), 421 workStartedTimeMillis - workScheduledTimeMillis, 422 getCurrentTimeMillis() - workStartedTimeMillis); 423 } 424 425 @AutoValue 426 abstract static class DownloadResult { 427 public abstract int successCount(); 428 429 public abstract int failureCount(); 430 431 public static DownloadResult create(int successCount, int failureCount) { 432 return new AutoValue_ModelDownloadWorker_DownloadResult(successCount, failureCount); 433 } 434 } 435 } 436