1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.textclassifier.downloader; 18 19 import static android.content.Context.BIND_AUTO_CREATE; 20 import static android.content.Context.BIND_NOT_FOREGROUND; 21 22 import android.content.ComponentName; 23 import android.content.Context; 24 import android.content.Intent; 25 import android.content.ServiceConnection; 26 import android.os.IBinder; 27 import androidx.concurrent.futures.CallbackToFutureAdapter; 28 import com.android.textclassifier.common.base.TcLog; 29 import com.android.textclassifier.protobuf.ExtensionRegistryLite; 30 import com.google.common.annotations.VisibleForTesting; 31 import com.google.common.base.Preconditions; 32 import com.google.common.hash.HashCode; 33 import com.google.common.hash.Hashing; 34 import com.google.common.io.Files; 35 import com.google.common.util.concurrent.FutureCallback; 36 import com.google.common.util.concurrent.Futures; 37 import com.google.common.util.concurrent.ListenableFuture; 38 import java.io.File; 39 import java.io.FileInputStream; 40 import java.io.IOException; 41 import java.net.URI; 42 import java.util.concurrent.ExecutorService; 43 44 /** 45 * ModelDownloader implementation that forwards requests to ModelDownloaderService. This is to 46 * restrict the INTERNET permission to the service process only (instead of the whole ExtServices). 47 */ 48 final class ModelDownloaderImpl implements ModelDownloader { 49 private static final String TAG = "ModelDownloaderImpl"; 50 51 private final Context context; 52 private final ExecutorService bgExecutorService; 53 private final Class<?> downloaderServiceClass; 54 ModelDownloaderImpl(Context context, ExecutorService bgExecutorService)55 public ModelDownloaderImpl(Context context, ExecutorService bgExecutorService) { 56 this(context, bgExecutorService, ModelDownloaderService.class); 57 } 58 59 @VisibleForTesting ModelDownloaderImpl( Context context, ExecutorService bgExecutorService, Class<?> downloaderServiceClass)60 ModelDownloaderImpl( 61 Context context, ExecutorService bgExecutorService, Class<?> downloaderServiceClass) { 62 this.context = context.getApplicationContext(); 63 this.bgExecutorService = bgExecutorService; 64 this.downloaderServiceClass = downloaderServiceClass; 65 } 66 67 @Override downloadManifest(String manifestUrl)68 public ListenableFuture<ModelManifest> downloadManifest(String manifestUrl) { 69 File manifestFile = 70 new File(context.getCacheDir(), manifestUrl.replaceAll("[^A-Za-z0-9]", "_") + ".manifest"); 71 return Futures.transform( 72 download(URI.create(manifestUrl), manifestFile), 73 bytesWritten -> { 74 try { 75 return ModelManifest.parseFrom( 76 new FileInputStream(manifestFile), ExtensionRegistryLite.getEmptyRegistry()); 77 } catch (Throwable t) { 78 throw new ModelDownloadException(ModelDownloadException.FAILED_TO_PARSE_MANIFEST, t); 79 } finally { 80 manifestFile.delete(); 81 } 82 }, 83 bgExecutorService); 84 } 85 86 @Override 87 public ListenableFuture<File> downloadModel(File targetDir, ModelManifest.Model model) { 88 File modelFile = new File(targetDir, model.getUrl().replaceAll("[^A-Za-z0-9]", "_") + ".model"); 89 ListenableFuture<File> modelFileFuture = 90 Futures.transform( 91 download(URI.create(model.getUrl()), modelFile), 92 bytesWritten -> { 93 validateModel(modelFile, model.getSizeInBytes(), model.getFingerprint()); 94 return modelFile; 95 }, 96 bgExecutorService); 97 Futures.addCallback( 98 modelFileFuture, 99 new FutureCallback<File>() { 100 @Override 101 public void onSuccess(File pendingModelFile) { 102 TcLog.d(TAG, "Download model successfully: " + pendingModelFile.getAbsolutePath()); 103 } 104 105 @Override 106 public void onFailure(Throwable t) { 107 modelFile.delete(); 108 TcLog.e(TAG, "Failed to download: " + modelFile.getAbsolutePath(), t); 109 } 110 }, 111 bgExecutorService); 112 return modelFileFuture; 113 } 114 115 // TODO(licha): Make this visible for testing. So we can avoid some duplicated test cases. 116 /** 117 * Downloads the file from uri to the targetFile. If the targetFile already exists, it will be 118 * deleted. Return bytes written if succeeds. 119 */ 120 private ListenableFuture<Long> download(URI uri, File targetFile) { 121 if (targetFile.exists()) { 122 TcLog.w( 123 TAG, 124 "Target file already exists. Delete it before downloading: " 125 + targetFile.getAbsolutePath()); 126 targetFile.delete(); 127 } 128 DownloaderServiceConnection conn = new DownloaderServiceConnection(); 129 ListenableFuture<IModelDownloaderService> downloaderServiceFuture = connect(conn); 130 ListenableFuture<Long> bytesWrittenFuture = 131 Futures.transformAsync( 132 downloaderServiceFuture, 133 service -> scheduleDownload(service, uri, targetFile), 134 bgExecutorService); 135 bytesWrittenFuture.addListener( 136 () -> { 137 try { 138 context.unbindService(conn); 139 } catch (IllegalArgumentException e) { 140 TcLog.e(TAG, "Error when unbind", e); 141 } 142 }, 143 bgExecutorService); 144 return bytesWrittenFuture; 145 } 146 147 /** Model verification. Throws unchecked Exceptions if validation fails. */ 148 private static void validateModel(File pendingModelFile, long sizeInBytes, String fingerprint) { 149 if (!pendingModelFile.exists()) { 150 throw new ModelDownloadException( 151 ModelDownloadException.DOWNLOADED_FILE_MISSING, "PendingModelFile does not exist."); 152 } 153 if (pendingModelFile.length() != sizeInBytes) { 154 throw new ModelDownloadException( 155 ModelDownloadException.FAILED_TO_VALIDATE_MODEL, 156 String.format( 157 "PendingModelFile size does not match: expected [%d] actual [%d]", 158 sizeInBytes, pendingModelFile.length())); 159 } 160 try { 161 HashCode pendingModelFingerprint = 162 Files.asByteSource(pendingModelFile).hash(Hashing.sha384()); 163 if (!pendingModelFingerprint.equals(HashCode.fromString(fingerprint))) { 164 throw new ModelDownloadException( 165 ModelDownloadException.FAILED_TO_VALIDATE_MODEL, 166 String.format( 167 "PendingModelFile fingerprint does not match: expected [%s] actual [%s]", 168 fingerprint, pendingModelFingerprint)); 169 } 170 } catch (IOException e) { 171 throw new ModelDownloadException(ModelDownloadException.FAILED_TO_VALIDATE_MODEL, e); 172 } 173 TcLog.d(TAG, "Pending model file passed validation."); 174 } 175 176 private ListenableFuture<IModelDownloaderService> connect(DownloaderServiceConnection conn) { 177 TcLog.d(TAG, "Starting a new connection to ModelDownloaderService"); 178 return CallbackToFutureAdapter.getFuture( 179 completer -> { 180 conn.attachCompleter(completer); 181 Intent intent = new Intent(context, downloaderServiceClass); 182 if (context.bindService(intent, conn, BIND_AUTO_CREATE | BIND_NOT_FOREGROUND)) { 183 return "Binding to service"; 184 } else { 185 completer.setException( 186 new ModelDownloadException( 187 ModelDownloadException.FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN, 188 "Unable to bind to service")); 189 return "Binding failed"; 190 } 191 }); 192 } 193 194 // Here the returned download result future can be set by: 1) the service can invoke the callback 195 // and set the result/exception; 2) If the service crashed, the CallbackToFutureAdapter will try 196 // to fail the future when the callback is garbage collected. If somehow none of them worked, the 197 // restult future will hang there until time out. (WorkManager forces a 10-min running time.) 198 private static ListenableFuture<Long> scheduleDownload( 199 IModelDownloaderService service, URI uri, File targetFile) { 200 TcLog.d(TAG, "Scheduling a new download task with ModelDownloaderService"); 201 return CallbackToFutureAdapter.getFuture( 202 completer -> { 203 service.download( 204 uri.toString(), 205 targetFile.getAbsolutePath(), 206 new IModelDownloaderCallback.Stub() { 207 @Override 208 public void onSuccess(long bytesWritten) { 209 completer.set(bytesWritten); 210 } 211 212 @Override 213 public void onFailure(int downloaderLibErrorCode, String errorMsg) { 214 completer.setException( 215 new ModelDownloadException( 216 ModelDownloadException.FAILED_TO_DOWNLOAD_OTHER, 217 downloaderLibErrorCode, 218 errorMsg)); 219 } 220 }); 221 return "downlaoderService.download"; 222 }); 223 } 224 225 /** The implementation of {@link ServiceConnection} that handles changes in the connection. */ 226 @VisibleForTesting 227 static class DownloaderServiceConnection implements ServiceConnection { 228 private static final String TAG = "ModelDownloaderImpl.DownloaderServiceConnection"; 229 230 private CallbackToFutureAdapter.Completer<IModelDownloaderService> completer; 231 232 public void attachCompleter( 233 CallbackToFutureAdapter.Completer<IModelDownloaderService> completer) { 234 this.completer = completer; 235 } 236 237 @Override 238 public void onServiceConnected(ComponentName componentName, IBinder iBinder) { 239 TcLog.d(TAG, "DownloaderService connected"); 240 completer.set(Preconditions.checkNotNull(IModelDownloaderService.Stub.asInterface(iBinder))); 241 } 242 243 @Override 244 public void onServiceDisconnected(ComponentName componentName) { 245 // If this is invoked after onServiceConnected, it will be ignored by the completer. 246 completer.setException( 247 new ModelDownloadException( 248 ModelDownloadException.FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN, 249 "Service disconnected")); 250 } 251 252 @Override 253 public void onBindingDied(ComponentName name) { 254 completer.setException( 255 new ModelDownloadException( 256 ModelDownloadException.FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN, "Binding died")); 257 } 258 259 @Override 260 public void onNullBinding(ComponentName name) { 261 completer.setException( 262 new ModelDownloadException( 263 ModelDownloadException.FAILED_TO_DOWNLOAD_SERVICE_CONN_BROKEN, 264 "Unable to bind to DownloaderService")); 265 } 266 } 267 } 268