• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.textclassifier.downloader;
18 
19 import static android.content.Context.BIND_AUTO_CREATE;
20 import static android.content.Context.BIND_NOT_FOREGROUND;
21 
22 import android.content.ComponentName;
23 import android.content.Context;
24 import android.content.Intent;
25 import android.content.ServiceConnection;
26 import android.os.IBinder;
27 import androidx.concurrent.futures.CallbackToFutureAdapter;
28 import com.android.textclassifier.common.base.TcLog;
29 import com.android.textclassifier.protobuf.ExtensionRegistryLite;
30 import com.google.common.annotations.VisibleForTesting;
31 import com.google.common.base.Preconditions;
32 import com.google.common.hash.HashCode;
33 import com.google.common.hash.Hashing;
34 import com.google.common.io.Files;
35 import com.google.common.util.concurrent.FutureCallback;
36 import com.google.common.util.concurrent.Futures;
37 import com.google.common.util.concurrent.ListenableFuture;
38 import java.io.File;
39 import java.io.FileInputStream;
40 import java.io.IOException;
41 import java.net.URI;
42 import java.util.concurrent.ExecutorService;
43 
44 /**
45  * ModelDownloader implementation that forwards requests to ModelDownloaderService. This is to
46  * restrict the INTERNET permission to the service process only (instead of the whole ExtServices).
47  */
48 final class ModelDownloaderImpl implements ModelDownloader {
49   private static final String TAG = "ModelDownloaderImpl";
50 
51   private final Context context;
52   private final ExecutorService bgExecutorService;
53   private final Class<?> downloaderServiceClass;
54 
ModelDownloaderImpl(Context context, ExecutorService bgExecutorService)55   public ModelDownloaderImpl(Context context, ExecutorService bgExecutorService) {
56     this(context, bgExecutorService, ModelDownloaderService.class);
57   }
58 
59   @VisibleForTesting
ModelDownloaderImpl( Context context, ExecutorService bgExecutorService, Class<?> downloaderServiceClass)60   ModelDownloaderImpl(
61       Context context, ExecutorService bgExecutorService, Class<?> downloaderServiceClass) {
62     this.context = context.getApplicationContext();
63     this.bgExecutorService = bgExecutorService;
64     this.downloaderServiceClass = downloaderServiceClass;
65   }
66 
67   @Override
downloadManifest(String manifestUrl)68   public ListenableFuture<ModelManifest> downloadManifest(String manifestUrl) {
69     File manifestFile =
70         new File(context.getCacheDir(), manifestUrl.replaceAll("[^A-Za-z0-9]", "_") + ".manifest");
71     return Futures.transform(
72         download(URI.create(manifestUrl), manifestFile),
73         bytesWritten -> {
74           try {
75             return ModelManifest.parseFrom(
76                 new FileInputStream(manifestFile), ExtensionRegistryLite.getEmptyRegistry());
77           } catch (Throwable t) {
78             throw new ModelDownloadException(ModelDownloadException.FAILED_TO_PARSE_MANIFEST, t);
79           } finally {
80             manifestFile.delete();
81           }
82         },
83         bgExecutorService);
84   }
85 
86   @Override
87   public ListenableFuture<File> downloadModel(File targetDir, ModelManifest.Model model) {
88     File modelFile = new File(targetDir, model.getUrl().replaceAll("[^A-Za-z0-9]", "_") + ".model");
89     ListenableFuture<File> modelFileFuture =
90         Futures.transform(
91             download(URI.create(model.getUrl()), modelFile),
92             bytesWritten -> {
93               validateModel(modelFile, model.getSizeInBytes(), model.getFingerprint());
94               return modelFile;
95             },
96             bgExecutorService);
97     Futures.addCallback(
98         modelFileFuture,
99         new FutureCallback<File>() {
100           @Override
101           public void onSuccess(File pendingModelFile) {
102             TcLog.d(TAG, "Download model successfully: " + pendingModelFile.getAbsolutePath());
103           }
104 
105           @Override
106           public void onFailure(Throwable t) {
107             modelFile.delete();
108             TcLog.e(TAG, "Failed to download: " + modelFile.getAbsolutePath(), t);
109           }
110         },
111         bgExecutorService);
112     return modelFileFuture;
113   }
114 
115   // TODO(licha): Make this visible for testing. So we can avoid some duplicated test cases.
116   /**
117    * Downloads the file from uri to the targetFile. If the targetFile already exists, it will be
118    * deleted. Return bytes written if succeeds.
119    */
120   private ListenableFuture<Long> download(URI uri, File targetFile) {
121     if (targetFile.exists()) {
122       TcLog.w(
123           TAG,
124           "Target file already exists. Delete it before downloading: "
125               + targetFile.getAbsolutePath());
126       targetFile.delete();
127     }
128     DownloaderServiceConnection conn = new DownloaderServiceConnection();
129     ListenableFuture<IModelDownloaderService> downloaderServiceFuture = connect(conn);
130     ListenableFuture<Long> bytesWrittenFuture =
131         Futures.transformAsync(
132             downloaderServiceFuture,
133             service -> scheduleDownload(service, uri, targetFile),
134             bgExecutorService);
135     bytesWrittenFuture.addListener(
136         () -> {
137           try {
138             context.unbindService(conn);
139           } catch (IllegalArgumentException e) {
140             TcLog.e(TAG, "Error when unbind", e);
141           }
142         },
143         bgExecutorService);
144     return bytesWrittenFuture;
145   }
146 
147   /** Model verification. Throws unchecked Exceptions if validation fails. */
148   private static void validateModel(File pendingModelFile, long sizeInBytes, String fingerprint) {
149     if (!pendingModelFile.exists()) {
150       throw new ModelDownloadException(
151           ModelDownloadException.DOWNLOADED_FILE_MISSING, "PendingModelFile does not exist.");
152     }
153     if (pendingModelFile.length() != sizeInBytes) {
154       throw new ModelDownloadException(
155           ModelDownloadException.FAILED_TO_VALIDATE_MODEL,
156           String.format(
157               "PendingModelFile size does not match: expected [%d] actual [%d]",
158               sizeInBytes, pendingModelFile.length()));
159     }
160     try {
161       HashCode pendingModelFingerprint =
162           Files.asByteSource(pendingModelFile).hash(Hashing.sha384());
163       if (!pendingModelFingerprint.equals(HashCode.fromString(fingerprint))) {
164         throw new ModelDownloadException(
165             ModelDownloadException.FAILED_TO_VALIDATE_MODEL,
166             String.format(
167                 "PendingModelFile fingerprint does not match: expected [%s] actual [%s]",
168                 fingerprint, pendingModelFingerprint));
169       }
170     } catch (IOException e) {
171       throw new ModelDownloadException(ModelDownloadException.FAILED_TO_VALIDATE_MODEL, e);
172     }
173     TcLog.d(TAG, "Pending model file passed validation.");
174   }
175 
176   private ListenableFuture<IModelDownloaderService> connect(DownloaderServiceConnection conn) {
177     TcLog.d(TAG, "Starting a new connection to ModelDownloaderService");
178     return CallbackToFutureAdapter.getFuture(
179         completer -> {
180           conn.attachCompleter(completer);
181           Intent intent = new Intent(context, downloaderServiceClass);
182           if (context.bindService(intent, conn, BIND_AUTO_CREATE | BIND_NOT_FOREGROUND)) {
183             return "Binding to service";
184           } else {
185             completer.setException(
186                 new ModelDownloadException(
187                     ModelDownloadException.FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN,
188                     "Unable to bind to service"));
189             return "Binding failed";
190           }
191         });
192   }
193 
194   // Here the returned download result future can be set by: 1) the service can invoke the callback
195   // and set the result/exception; 2) If the service crashed, the CallbackToFutureAdapter will try
196   // to fail the future when the callback is garbage collected. If somehow none of them worked, the
197   // restult future will hang there until time out. (WorkManager forces a 10-min running time.)
198   private static ListenableFuture<Long> scheduleDownload(
199       IModelDownloaderService service, URI uri, File targetFile) {
200     TcLog.d(TAG, "Scheduling a new download task with ModelDownloaderService");
201     return CallbackToFutureAdapter.getFuture(
202         completer -> {
203           service.download(
204               uri.toString(),
205               targetFile.getAbsolutePath(),
206               new IModelDownloaderCallback.Stub() {
207                 @Override
208                 public void onSuccess(long bytesWritten) {
209                   completer.set(bytesWritten);
210                 }
211 
212                 @Override
213                 public void onFailure(int downloaderLibErrorCode, String errorMsg) {
214                   completer.setException(
215                       new ModelDownloadException(
216                           ModelDownloadException.FAILED_TO_DOWNLOAD_OTHER,
217                           downloaderLibErrorCode,
218                           errorMsg));
219                 }
220               });
221           return "downlaoderService.download";
222         });
223   }
224 
225   /** The implementation of {@link ServiceConnection} that handles changes in the connection. */
226   @VisibleForTesting
227   static class DownloaderServiceConnection implements ServiceConnection {
228     private static final String TAG = "ModelDownloaderImpl.DownloaderServiceConnection";
229 
230     private CallbackToFutureAdapter.Completer<IModelDownloaderService> completer;
231 
232     public void attachCompleter(
233         CallbackToFutureAdapter.Completer<IModelDownloaderService> completer) {
234       this.completer = completer;
235     }
236 
237     @Override
238     public void onServiceConnected(ComponentName componentName, IBinder iBinder) {
239       TcLog.d(TAG, "DownloaderService connected");
240       completer.set(Preconditions.checkNotNull(IModelDownloaderService.Stub.asInterface(iBinder)));
241     }
242 
243     @Override
244     public void onServiceDisconnected(ComponentName componentName) {
245       // If this is invoked after onServiceConnected, it will be ignored by the completer.
246       completer.setException(
247           new ModelDownloadException(
248               ModelDownloadException.FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN,
249               "Service disconnected"));
250     }
251 
252     @Override
253     public void onBindingDied(ComponentName name) {
254       completer.setException(
255           new ModelDownloadException(
256               ModelDownloadException.FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN, "Binding died"));
257     }
258 
259     @Override
260     public void onNullBinding(ComponentName name) {
261       completer.setException(
262           new ModelDownloadException(
263               ModelDownloadException.FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN,
264               "Unable to bind to DownloaderService"));
265     }
266   }
267 }
268