• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.textclassifier.downloader;
18 
import static java.lang.Math.min;

import android.content.Context;
import android.os.LocaleList;
import android.util.ArrayMap;
import android.util.Pair;
import androidx.work.ListenableWorker;
import androidx.work.WorkerParameters;
import com.android.textclassifier.common.ModelType;
import com.android.textclassifier.common.ModelType.ModelTypeDef;
import com.android.textclassifier.common.TextClassifierServiceExecutors;
import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest;
import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment;
import com.android.textclassifier.downloader.DownloadedModelDatabase.Model;
import com.google.auto.value.AutoValue;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.FluentFuture;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import java.time.Clock;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;
48 
49 /** The WorkManager worker to download models for TextClassifierService. */
50 public final class ModelDownloadWorker extends ListenableWorker {
51   private static final String TAG = "ModelDownloadWorker";
52 
53   public static final String INPUT_DATA_KEY_WORK_ID = "ModelDownloadWorker_workId";
54   public static final String INPUT_DATA_KEY_SCHEDULED_TIMESTAMP =
55       "ModelDownloadWorker_scheduledTimestamp";
56 
57   private final ListeningExecutorService executorService;
58   private final ModelDownloader downloader;
59   private final DownloadedModelManager downloadedModelManager;
60   private final TextClassifierSettings settings;
61 
62   private final long workId;
63 
64   private final Clock clock;
65   private final long workScheduledTimeMillis;
66 
67   private final Object lock = new Object();
68 
69   private long workStartedTimeMillis = 0;
70 
71   @GuardedBy("lock")
72   private final ArrayMap<String, ListenableFuture<Void>> pendingDownloads;
73 
74   private ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload;
75 
ModelDownloadWorker(Context context, WorkerParameters workerParams)76   public ModelDownloadWorker(Context context, WorkerParameters workerParams) {
77     super(context, workerParams);
78     this.executorService = TextClassifierServiceExecutors.getDownloaderExecutor();
79     this.downloader = new ModelDownloaderImpl(context, executorService);
80     this.downloadedModelManager = DownloadedModelManagerImpl.getInstance(context);
81     this.settings = new TextClassifierSettings(context);
82     this.pendingDownloads = new ArrayMap<>();
83     this.manifestsToDownload = null;
84 
85     this.workId = workerParams.getInputData().getLong(INPUT_DATA_KEY_WORK_ID, 0);
86     this.workScheduledTimeMillis =
87         workerParams.getInputData().getLong(INPUT_DATA_KEY_SCHEDULED_TIMESTAMP, 0);
88     this.clock = Clock.systemUTC();
89   }
90 
91   @VisibleForTesting
ModelDownloadWorker( Context context, WorkerParameters workerParams, ListeningExecutorService executorService, ModelDownloader modelDownloader, DownloadedModelManager downloadedModelManager, TextClassifierSettings settings, long workId, Clock clock, long workScheduledTimeMillis)92   ModelDownloadWorker(
93       Context context,
94       WorkerParameters workerParams,
95       ListeningExecutorService executorService,
96       ModelDownloader modelDownloader,
97       DownloadedModelManager downloadedModelManager,
98       TextClassifierSettings settings,
99       long workId,
100       Clock clock,
101       long workScheduledTimeMillis) {
102     super(context, workerParams);
103     this.executorService = executorService;
104     this.downloader = modelDownloader;
105     this.downloadedModelManager = downloadedModelManager;
106     this.settings = settings;
107     this.pendingDownloads = new ArrayMap<>();
108     this.manifestsToDownload = null;
109     this.workId = workId;
110     this.clock = clock;
111     this.workScheduledTimeMillis = workScheduledTimeMillis;
112   }
113 
114   @Override
startWork()115   public final ListenableFuture<ListenableWorker.Result> startWork() {
116     TcLog.d(TAG, "Start download work...");
117     workStartedTimeMillis = getCurrentTimeMillis();
118     // Notice: startWork() is invoked on the main thread
119     if (!settings.isModelDownloadManagerEnabled()) {
120       TcLog.e(TAG, "Model Downloader is disabled. Abort the work.");
121       logDownloadWorkCompleted(
122           TextClassifierDownloadLogger.WORK_RESULT_FAILURE_MODEL_DOWNLOADER_DISABLED);
123       return Futures.immediateFuture(ListenableWorker.Result.failure());
124     }
125     if (getRunAttemptCount() >= settings.getModelDownloadWorkerMaxAttempts()) {
126       TcLog.d(TAG, "Max attempt reached. Abort download work.");
127       logDownloadWorkCompleted(
128           TextClassifierDownloadLogger.WORK_RESULT_FAILURE_MAX_RUN_ATTEMPT_REACHED);
129       return Futures.immediateFuture(ListenableWorker.Result.failure());
130     }
131 
132     return FluentFuture.from(Futures.submitAsync(this::checkAndDownloadModels, executorService))
133         .transform(
134             downloadResult -> {
135               Preconditions.checkNotNull(manifestsToDownload);
136               downloadedModelManager.onDownloadCompleted(manifestsToDownload);
137               TcLog.d(TAG, "Download work completed: " + downloadResult);
138               if (downloadResult.failureCount() == 0) {
139                 logDownloadWorkCompleted(
140                     downloadResult.successCount() > 0
141                         ? TextClassifierDownloadLogger.WORK_RESULT_SUCCESS_MODEL_DOWNLOADED
142                         : TextClassifierDownloadLogger.WORK_RESULT_SUCCESS_NO_UPDATE_AVAILABLE);
143                 return ListenableWorker.Result.success();
144               } else {
145                 logDownloadWorkCompleted(
146                     TextClassifierDownloadLogger.WORK_RESULT_RETRY_MODEL_DOWNLOAD_FAILED);
147                 return ListenableWorker.Result.retry();
148               }
149             },
150             executorService)
151         .catching(
152             Throwable.class,
153             t -> {
154               TcLog.e(TAG, "Unexpected Exception during downloading: ", t);
155               logDownloadWorkCompleted(
156                   TextClassifierDownloadLogger.WORK_RESULT_RETRY_RUNTIME_EXCEPTION);
157               return ListenableWorker.Result.retry();
158             },
159             executorService);
160   }
161 
162   /**
163    * Checks device settings and returns the list of locales to download according to multi language
164    * support settings. Guarantees that the primary locale goes first.
165    */
getLocalesToDownload()166   private ImmutableList<Locale> getLocalesToDownload() {
167     LocaleList localeList = LocaleList.getAdjustedDefault();
168     Locale primaryLocale = localeList.get(0);
169     if (!settings.isMultiLanguageSupportEnabled()) {
170       return ImmutableList.of(primaryLocale);
171     }
172     ImmutableList.Builder<Locale> localesToDownloadBuilder = ImmutableList.builder();
173     int size = min(settings.getMultiLanguageModelsLimit(), localeList.size());
174     for (int i = 0; i < size; i++) {
175       localesToDownloadBuilder.add(localeList.get(i));
176     }
177     return localesToDownloadBuilder.build();
178   }
179 
180   /**
181    * Returns list of locales to download from {@code localeList} for the given {@code modelType}.
182    */
getLocalesToDownloadByType( ImmutableList<Locale> localeList, @ModelTypeDef String modelType)183   private ImmutableList<Locale> getLocalesToDownloadByType(
184       ImmutableList<Locale> localeList, @ModelTypeDef String modelType) {
185     if (!settings.getEnabledModelTypesForMultiLanguageSupport().contains(modelType)) {
186       return ImmutableList.of(Locale.getDefault());
187     }
188     return localeList;
189   }
190 
191   /**
192    * Check device config and dispatch download tasks for all modelTypes.
193    *
194    * <p>Download tasks will be combined and logged after completion. Return true if all tasks
195    * succeeded
196    */
checkAndDownloadModels()197   private ListenableFuture<DownloadResult> checkAndDownloadModels() {
198     ImmutableList<Locale> localesToDownload = getLocalesToDownload();
199     ArrayList<ListenableFuture<Boolean>> downloadResultFutures = new ArrayList<>();
200     ImmutableMap.Builder<String, ManifestsToDownloadByType> manifestsToDownloadBuilder =
201         ImmutableMap.builder();
202     for (String modelType : ModelType.values()) {
203       ImmutableList<Locale> localesToDownloadByType =
204           getLocalesToDownloadByType(localesToDownload, modelType);
205       ImmutableMap.Builder<String, String> localeTagToManifestUrlBuilder = ImmutableMap.builder();
206       for (Locale locale : localesToDownloadByType) {
207         Pair<String, String> bestLocaleTagAndManifestUrl =
208             LocaleUtils.lookupBestLocaleTagAndManifestUrl(modelType, locale, settings);
209         if (bestLocaleTagAndManifestUrl == null) {
210           TcLog.w(
211               TAG,
212               String.format(
213                   Locale.US, "No suitable manifest for %s, %s", modelType, locale.toLanguageTag()));
214           continue;
215         }
216         String bestLocaleTag = bestLocaleTagAndManifestUrl.first;
217         String manifestUrl = bestLocaleTagAndManifestUrl.second;
218         localeTagToManifestUrlBuilder.put(bestLocaleTag, manifestUrl);
219         TcLog.d(
220             TAG,
221             String.format(
222                 Locale.US,
223                 "model type: %s, current locale tag: %s, best locale tag: %s, manifest url: %s",
224                 modelType,
225                 locale.toLanguageTag(),
226                 bestLocaleTag,
227                 manifestUrl));
228         if (!shouldDownloadManifest(modelType, bestLocaleTag, manifestUrl)) {
229           continue;
230         }
231         downloadResultFutures.add(
232             downloadManifestAndRegister(modelType, bestLocaleTag, manifestUrl));
233       }
234       manifestsToDownloadBuilder.put(
235           modelType,
236           ManifestsToDownloadByType.create(localeTagToManifestUrlBuilder.buildOrThrow()));
237     }
238     manifestsToDownload = manifestsToDownloadBuilder.buildOrThrow();
239 
240     return Futures.whenAllComplete(downloadResultFutures)
241         .call(
242             () -> {
243               TcLog.d(TAG, "All Download Tasks Completed");
244               int successCount = 0;
245               int failureCount = 0;
246               for (ListenableFuture<Boolean> downloadResultFuture : downloadResultFutures) {
247                 if (Futures.getDone(downloadResultFuture)) {
248                   successCount += 1;
249                 } else {
250                   failureCount += 1;
251                 }
252               }
253               return DownloadResult.create(successCount, failureCount);
254             },
255             executorService);
256   }
257 
258   private boolean shouldDownloadManifest(
259       @ModelTypeDef String modelType, String localeTag, String manifestUrl) {
260     Manifest downloadedManifest = downloadedModelManager.getManifest(manifestUrl);
261     if (downloadedManifest == null) {
262       return true;
263     }
264     if (downloadedManifest.getStatus() == Manifest.STATUS_FAILED) {
265       if (downloadedManifest.getFailureCounts() >= settings.getManifestDownloadMaxAttempts()) {
266         TcLog.w(
267             TAG,
268             String.format(
269                 Locale.US,
270                 "Manifest failed too many times, stop retrying: %s %d",
271                 manifestUrl,
272                 downloadedManifest.getFailureCounts()));
273         return false;
274       } else {
275         return true;
276       }
277     }
278     ManifestEnrollment manifestEnrollment =
279         downloadedModelManager.getManifestEnrollment(modelType, localeTag);
280     return manifestEnrollment == null || !manifestUrl.equals(manifestEnrollment.getManifestUrl());
281   }
282 
283   /**
284    * Downloads a single manifest and models configured inside it.
285    *
286    * <p>The returned future should always resolve to a ManifestDownloadResult as we catch all
287    * exceptions.
288    */
289   private ListenableFuture<Boolean> downloadManifestAndRegister(
290       @ModelTypeDef String modelType, String localeTag, String manifestUrl) {
291     long downloadStartTimestamp = getCurrentTimeMillis();
292     return FluentFuture.from(downloadManifest(manifestUrl))
293         .transform(
294             unused -> {
295               downloadedModelManager.registerManifestEnrollment(modelType, localeTag, manifestUrl);
296               TextClassifierDownloadLogger.downloadSucceeded(
297                   workId,
298                   modelType,
299                   manifestUrl,
300                   getRunAttemptCount(),
301                   getCurrentTimeMillis() - downloadStartTimestamp);
302               TcLog.d(TAG, "Manifest downloaded and registered: " + manifestUrl);
303               return true;
304             },
305             executorService)
306         .catching(
307             Throwable.class,
308             t -> {
309               downloadedModelManager.registerManifestDownloadFailure(manifestUrl);
310               int errorCode = ModelDownloadException.UNKNOWN_FAILURE_REASON;
311               int downloaderLibErrorCode = 0;
312               if (t instanceof ModelDownloadException) {
313                 ModelDownloadException mde = (ModelDownloadException) t;
314                 errorCode = mde.getErrorCode();
315                 downloaderLibErrorCode = mde.getDownloaderLibErrorCode();
316               }
317               TcLog.e(TAG, "Failed to download manfiest: " + manifestUrl, t);
318               TextClassifierDownloadLogger.downloadFailed(
319                   workId,
320                   modelType,
321                   manifestUrl,
322                   errorCode,
323                   getRunAttemptCount(),
324                   downloaderLibErrorCode,
325                   getCurrentTimeMillis() - downloadStartTimestamp);
326               return false;
327             },
328             executorService);
329   }
330 
331   // Download a manifest and its models, and register it to Manifest table.
332   private ListenableFuture<Void> downloadManifest(String manifestUrl) {
333     synchronized (lock) {
334       Manifest downloadedManifest = downloadedModelManager.getManifest(manifestUrl);
335       if (downloadedManifest != null
336           && downloadedManifest.getStatus() == Manifest.STATUS_SUCCEEDED) {
337         TcLog.d(TAG, "Manifest already downloaded: " + manifestUrl);
338         return Futures.immediateVoidFuture();
339       }
340       if (pendingDownloads.containsKey(manifestUrl)) {
341         return pendingDownloads.get(manifestUrl);
342       }
343       ListenableFuture<Void> manfiestDownloadFuture =
344           FluentFuture.from(downloader.downloadManifest(manifestUrl))
345               .transformAsync(
346                   manifest -> {
347                     ModelManifest.Model modelInfo = manifest.getModels(0);
348                     return Futures.transform(
349                         downloadModel(modelInfo), unused -> modelInfo, executorService);
350                   },
351                   executorService)
352               .transform(
353                   modelInfo -> {
354                     downloadedModelManager.registerManifest(manifestUrl, modelInfo.getUrl());
355                     return null;
356                   },
357                   executorService);
358       pendingDownloads.put(manifestUrl, manfiestDownloadFuture);
359       return manfiestDownloadFuture;
360     }
361   }
362 
363   // Download a model and register it into Model table.
364   private ListenableFuture<Void> downloadModel(ModelManifest.Model modelInfo) {
365     String modelUrl = modelInfo.getUrl();
366     synchronized (lock) {
367       Model downloadedModel = downloadedModelManager.getModel(modelUrl);
368       if (downloadedModel != null) {
369         TcLog.d(TAG, "Model file already exists: " + downloadedModel.getModelPath());
370         return Futures.immediateVoidFuture();
371       }
372       if (pendingDownloads.containsKey(modelUrl)) {
373         return pendingDownloads.get(modelUrl);
374       }
375       ListenableFuture<Void> modelDownloadFuture =
376           FluentFuture.from(
377                   downloader.downloadModel(
378                       downloadedModelManager.getModelDownloaderDir(), modelInfo))
379               .transform(
380                   modelFile -> {
381                     downloadedModelManager.registerModel(modelUrl, modelFile.getAbsolutePath());
382                     TcLog.d(TAG, "Model File downloaded: " + modelUrl);
383                     return null;
384                   },
385                   executorService);
386       pendingDownloads.put(modelUrl, modelDownloadFuture);
387       return modelDownloadFuture;
388     }
389   }
390 
391   /**
392    * This method will be called when we our work gets interrupted by the system. Result future
393    * should have already been cancelled in that case. Unless it's because the REPLACE policy of
394    * WorkManager unique queue, the interrupted work will be rescheduled later.
395    */
396   @Override
397   public final void onStopped() {
398     TcLog.d(TAG, String.format(Locale.US, "Stop download. Attempt:%d", getRunAttemptCount()));
399     logDownloadWorkCompleted(TextClassifierDownloadLogger.WORK_RESULT_RETRY_STOPPED_BY_OS);
400   }
401 
402   private long getCurrentTimeMillis() {
403     return clock.instant().toEpochMilli();
404   }
405 
406   private void logDownloadWorkCompleted(int workResult) {
407     if (workStartedTimeMillis < workScheduledTimeMillis) {
408       TcLog.w(
409           TAG,
410           String.format(
411               Locale.US,
412               "Bad workStartedTimeMillis: %d, workScheduledTimeMillis: %d",
413               workStartedTimeMillis,
414               workScheduledTimeMillis));
415       workStartedTimeMillis = workScheduledTimeMillis;
416     }
417     TextClassifierDownloadLogger.downloadWorkCompleted(
418         workId,
419         workResult,
420         getRunAttemptCount(),
421         workStartedTimeMillis - workScheduledTimeMillis,
422         getCurrentTimeMillis() - workStartedTimeMillis);
423   }
424 
425   @AutoValue
426   abstract static class DownloadResult {
427     public abstract int successCount();
428 
429     public abstract int failureCount();
430 
431     public static DownloadResult create(int successCount, int failureCount) {
432       return new AutoValue_ModelDownloadWorker_DownloadResult(successCount, failureCount);
433     }
434   }
435 }
436