1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.textclassifier; 18 19 import android.content.Context; 20 import android.os.CancellationSignal; 21 import android.service.textclassifier.TextClassifierService; 22 import android.view.textclassifier.ConversationActions; 23 import android.view.textclassifier.SelectionEvent; 24 import android.view.textclassifier.TextClassification; 25 import android.view.textclassifier.TextClassificationContext; 26 import android.view.textclassifier.TextClassificationSessionId; 27 import android.view.textclassifier.TextClassifierEvent; 28 import android.view.textclassifier.TextLanguage; 29 import android.view.textclassifier.TextLinks; 30 import android.view.textclassifier.TextSelection; 31 import androidx.annotation.NonNull; 32 import androidx.collection.LruCache; 33 import com.android.textclassifier.common.TextClassifierServiceExecutors; 34 import com.android.textclassifier.common.TextClassifierSettings; 35 import com.android.textclassifier.common.base.TcLog; 36 import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger; 37 import com.android.textclassifier.downloader.ModelDownloadManager; 38 import com.android.textclassifier.utils.IndentingPrintWriter; 39 import com.google.common.annotations.VisibleForTesting; 40 import com.google.common.base.Preconditions; 41 import com.google.common.util.concurrent.FutureCallback; 42 import com.google.common.util.concurrent.Futures; 43 import com.google.common.util.concurrent.ListenableFuture; 44 import com.google.common.util.concurrent.ListeningExecutorService; 45 import com.google.common.util.concurrent.MoreExecutors; 46 import java.io.FileDescriptor; 47 import java.io.PrintWriter; 48 import java.util.Map; 49 import java.util.concurrent.Callable; 50 import java.util.concurrent.ExecutionException; 51 import java.util.concurrent.Executor; 52 import javax.annotation.Nullable; 53 54 /** An implementation of a TextClassifierService. */ 55 public final class DefaultTextClassifierService extends TextClassifierService { 56 private static final String TAG = "default_tcs"; 57 58 private final Injector injector; 59 // TODO: Figure out do we need more concurrency. 60 private ListeningExecutorService normPriorityExecutor; 61 private ListeningExecutorService lowPriorityExecutor; 62 63 @Nullable private ModelDownloadManager modelDownloadManager; 64 65 private TextClassifierImpl textClassifier; 66 private TextClassifierSettings settings; 67 private ModelFileManager modelFileManager; 68 private LruCache<TextClassificationSessionId, TextClassificationContext> sessionIdToContext; 69 DefaultTextClassifierService()70 public DefaultTextClassifierService() { 71 this.injector = new InjectorImpl(this); 72 } 73 74 @VisibleForTesting DefaultTextClassifierService(Injector injector)75 DefaultTextClassifierService(Injector injector) { 76 this.injector = Preconditions.checkNotNull(injector); 77 } 78 79 private TextClassifierApiUsageLogger textClassifierApiUsageLogger; 80 81 @Override onCreate()82 public void onCreate() { 83 super.onCreate(); 84 settings = injector.createTextClassifierSettings(); 85 modelDownloadManager = 86 new ModelDownloadManager( 87 injector.getContext().getApplicationContext(), 88 settings, 89 TextClassifierServiceExecutors.getDownloaderExecutor()); 90 modelDownloadManager.onTextClassifierServiceCreated(); 91 modelFileManager = injector.createModelFileManager(settings, modelDownloadManager); 92 normPriorityExecutor = injector.createNormPriorityExecutor(); 93 lowPriorityExecutor = injector.createLowPriorityExecutor(); 94 textClassifier = injector.createTextClassifierImpl(settings, modelFileManager); 95 sessionIdToContext = new LruCache<>(settings.getSessionIdToContextCacheSize()); 96 textClassifierApiUsageLogger = 97 injector.createTextClassifierApiUsageLogger(settings, lowPriorityExecutor); 98 } 99 100 @Override onDestroy()101 public void onDestroy() { 102 super.onDestroy(); 103 modelDownloadManager.destroy(); 104 } 105 106 @Override onCreateTextClassificationSession( @onNull TextClassificationContext context, @NonNull TextClassificationSessionId sessionId)107 public void onCreateTextClassificationSession( 108 @NonNull TextClassificationContext context, @NonNull TextClassificationSessionId sessionId) { 109 sessionIdToContext.put(sessionId, context); 110 } 111 112 @Override onDestroyTextClassificationSession(@onNull TextClassificationSessionId sessionId)113 public void onDestroyTextClassificationSession(@NonNull TextClassificationSessionId sessionId) { 114 sessionIdToContext.remove(sessionId); 115 } 116 117 @Override onSuggestSelection( TextClassificationSessionId sessionId, TextSelection.Request request, CancellationSignal cancellationSignal, Callback<TextSelection> callback)118 public void onSuggestSelection( 119 TextClassificationSessionId sessionId, 120 TextSelection.Request request, 121 CancellationSignal cancellationSignal, 122 Callback<TextSelection> callback) { 123 handleRequestAsync( 124 () -> 125 textClassifier.suggestSelection( 126 sessionId, sessionIdToTextClassificationContext(sessionId), request), 127 callback, 128 textClassifierApiUsageLogger.createSession( 129 TextClassifierApiUsageLogger.API_TYPE_SUGGEST_SELECTION, sessionId), 130 cancellationSignal); 131 } 132 133 @Override onClassifyText( TextClassificationSessionId sessionId, TextClassification.Request request, CancellationSignal cancellationSignal, Callback<TextClassification> callback)134 public void onClassifyText( 135 TextClassificationSessionId sessionId, 136 TextClassification.Request request, 137 CancellationSignal cancellationSignal, 138 Callback<TextClassification> callback) { 139 handleRequestAsync( 140 () -> 141 textClassifier.classifyText( 142 sessionId, sessionIdToTextClassificationContext(sessionId), request), 143 callback, 144 textClassifierApiUsageLogger.createSession( 145 TextClassifierApiUsageLogger.API_TYPE_CLASSIFY_TEXT, sessionId), 146 cancellationSignal); 147 } 148 149 @Override onGenerateLinks( TextClassificationSessionId sessionId, TextLinks.Request request, CancellationSignal cancellationSignal, Callback<TextLinks> callback)150 public void onGenerateLinks( 151 TextClassificationSessionId sessionId, 152 TextLinks.Request request, 153 CancellationSignal cancellationSignal, 154 Callback<TextLinks> callback) { 155 handleRequestAsync( 156 () -> 157 textClassifier.generateLinks( 158 sessionId, sessionIdToTextClassificationContext(sessionId), request), 159 callback, 160 textClassifierApiUsageLogger.createSession( 161 TextClassifierApiUsageLogger.API_TYPE_GENERATE_LINKS, sessionId), 162 cancellationSignal); 163 } 164 165 @Override onSuggestConversationActions( TextClassificationSessionId sessionId, ConversationActions.Request request, CancellationSignal cancellationSignal, Callback<ConversationActions> callback)166 public void onSuggestConversationActions( 167 TextClassificationSessionId sessionId, 168 ConversationActions.Request request, 169 CancellationSignal cancellationSignal, 170 Callback<ConversationActions> callback) { 171 handleRequestAsync( 172 () -> 173 textClassifier.suggestConversationActions( 174 sessionId, sessionIdToTextClassificationContext(sessionId), request), 175 callback, 176 textClassifierApiUsageLogger.createSession( 177 TextClassifierApiUsageLogger.API_TYPE_SUGGEST_CONVERSATION_ACTIONS, sessionId), 178 cancellationSignal); 179 } 180 181 @Override onDetectLanguage( TextClassificationSessionId sessionId, TextLanguage.Request request, CancellationSignal cancellationSignal, Callback<TextLanguage> callback)182 public void onDetectLanguage( 183 TextClassificationSessionId sessionId, 184 TextLanguage.Request request, 185 CancellationSignal cancellationSignal, 186 Callback<TextLanguage> callback) { 187 handleRequestAsync( 188 () -> 189 textClassifier.detectLanguage( 190 sessionId, sessionIdToTextClassificationContext(sessionId), request), 191 callback, 192 textClassifierApiUsageLogger.createSession( 193 TextClassifierApiUsageLogger.API_TYPE_DETECT_LANGUAGES, sessionId), 194 cancellationSignal); 195 } 196 197 @Override onSelectionEvent(TextClassificationSessionId sessionId, SelectionEvent event)198 public void onSelectionEvent(TextClassificationSessionId sessionId, SelectionEvent event) { 199 handleEvent(() -> textClassifier.onSelectionEvent(sessionId, event)); 200 } 201 202 @Override onTextClassifierEvent( TextClassificationSessionId sessionId, TextClassifierEvent event)203 public void onTextClassifierEvent( 204 TextClassificationSessionId sessionId, TextClassifierEvent event) { 205 handleEvent(() -> textClassifier.onTextClassifierEvent(sessionId, event)); 206 } 207 208 @Override dump(FileDescriptor fd, PrintWriter writer, String[] args)209 protected void dump(FileDescriptor fd, PrintWriter writer, String[] args) { 210 // Dump in a background thread b/c we may need to query Room db (e.g. to init model cache) 211 try { 212 TextClassifierServiceExecutors.getLowPriorityExecutor() 213 .submit( 214 () -> { 215 IndentingPrintWriter indentingPrintWriter = new IndentingPrintWriter(writer); 216 textClassifier.dump(indentingPrintWriter); 217 modelDownloadManager.dump(indentingPrintWriter); 218 dumpImpl(indentingPrintWriter); 219 indentingPrintWriter.flush(); 220 }) 221 .get(); 222 } catch (ExecutionException | InterruptedException e) { 223 TcLog.e(TAG, "Failed to dump Default TextClassifierService", e); 224 } 225 } 226 dumpImpl(IndentingPrintWriter printWriter)227 private void dumpImpl(IndentingPrintWriter printWriter) { 228 printWriter.println("DefaultTextClassifierService:"); 229 printWriter.increaseIndent(); 230 printWriter.println("sessionIdToContext:"); 231 printWriter.increaseIndent(); 232 for (Map.Entry<TextClassificationSessionId, TextClassificationContext> entry : 233 sessionIdToContext.snapshot().entrySet()) { 234 printWriter.printPair(entry.getKey().getValue(), entry.getValue()); 235 } 236 printWriter.decreaseIndent(); 237 printWriter.decreaseIndent(); 238 printWriter.println(); 239 } 240 handleRequestAsync( Callable<T> callable, Callback<T> callback, TextClassifierApiUsageLogger.Session apiLoggerSession, CancellationSignal cancellationSignal)241 private <T> void handleRequestAsync( 242 Callable<T> callable, 243 Callback<T> callback, 244 TextClassifierApiUsageLogger.Session apiLoggerSession, 245 CancellationSignal cancellationSignal) { 246 ListenableFuture<T> result = normPriorityExecutor.submit(callable); 247 Futures.addCallback( 248 result, 249 new FutureCallback<T>() { 250 @Override 251 public void onSuccess(T result) { 252 callback.onSuccess(result); 253 apiLoggerSession.reportSuccess(); 254 } 255 256 @Override 257 public void onFailure(Throwable t) { 258 TcLog.e(TAG, "onFailure: ", t); 259 callback.onFailure(t.getMessage()); 260 apiLoggerSession.reportFailure(); 261 } 262 }, 263 MoreExecutors.directExecutor()); 264 cancellationSignal.setOnCancelListener(() -> result.cancel(/* mayInterruptIfRunning= */ true)); 265 } 266 handleEvent(Runnable runnable)267 private void handleEvent(Runnable runnable) { 268 ListenableFuture<Void> result = 269 lowPriorityExecutor.submit( 270 () -> { 271 runnable.run(); 272 return null; 273 }); 274 Futures.addCallback( 275 result, 276 new FutureCallback<Void>() { 277 @Override 278 public void onSuccess(Void result) {} 279 280 @Override 281 public void onFailure(Throwable t) { 282 TcLog.e(TAG, "onFailure: ", t); 283 } 284 }, 285 MoreExecutors.directExecutor()); 286 } 287 288 @Nullable sessionIdToTextClassificationContext( @ullable TextClassificationSessionId sessionId)289 private TextClassificationContext sessionIdToTextClassificationContext( 290 @Nullable TextClassificationSessionId sessionId) { 291 if (sessionId == null) { 292 return null; 293 } 294 return sessionIdToContext.get(sessionId); 295 } 296 297 // Do not call any of these methods, except the constructor, before Service.onCreate is called. 298 private static class InjectorImpl implements Injector { 299 // Do not access the context object before Service.onCreate is invoked. 300 private final Context context; 301 InjectorImpl(Context context)302 private InjectorImpl(Context context) { 303 this.context = Preconditions.checkNotNull(context); 304 } 305 306 @Override getContext()307 public Context getContext() { 308 return context; 309 } 310 311 @Override createModelFileManager( TextClassifierSettings settings, ModelDownloadManager modelDownloadManager)312 public ModelFileManager createModelFileManager( 313 TextClassifierSettings settings, ModelDownloadManager modelDownloadManager) { 314 return new ModelFileManagerImpl(context, modelDownloadManager, settings); 315 } 316 317 @Override createTextClassifierSettings()318 public TextClassifierSettings createTextClassifierSettings() { 319 return new TextClassifierSettings(getContext()); 320 } 321 322 @Override createTextClassifierImpl( TextClassifierSettings settings, ModelFileManager modelFileManager)323 public TextClassifierImpl createTextClassifierImpl( 324 TextClassifierSettings settings, ModelFileManager modelFileManager) { 325 return new TextClassifierImpl(context, settings, modelFileManager); 326 } 327 328 @Override createNormPriorityExecutor()329 public ListeningExecutorService createNormPriorityExecutor() { 330 return TextClassifierServiceExecutors.getNormhPriorityExecutor(); 331 } 332 333 @Override createLowPriorityExecutor()334 public ListeningExecutorService createLowPriorityExecutor() { 335 return TextClassifierServiceExecutors.getLowPriorityExecutor(); 336 } 337 338 @Override createTextClassifierApiUsageLogger( TextClassifierSettings settings, Executor executor)339 public TextClassifierApiUsageLogger createTextClassifierApiUsageLogger( 340 TextClassifierSettings settings, Executor executor) { 341 return new TextClassifierApiUsageLogger( 342 settings::getTextClassifierApiLogSampleRate, executor); 343 } 344 } 345 346 /* 347 * Provides dependencies to the {@link DefaultTextClassifierService}. This makes the service 348 * class testable. 349 */ 350 interface Injector { getContext()351 Context getContext(); 352 createModelFileManager( TextClassifierSettings settings, ModelDownloadManager modelDownloadManager)353 ModelFileManager createModelFileManager( 354 TextClassifierSettings settings, ModelDownloadManager modelDownloadManager); 355 createTextClassifierSettings()356 TextClassifierSettings createTextClassifierSettings(); 357 createTextClassifierImpl( TextClassifierSettings settings, ModelFileManager modelFileManager)358 TextClassifierImpl createTextClassifierImpl( 359 TextClassifierSettings settings, ModelFileManager modelFileManager); 360 createNormPriorityExecutor()361 ListeningExecutorService createNormPriorityExecutor(); 362 createLowPriorityExecutor()363 ListeningExecutorService createLowPriorityExecutor(); 364 createTextClassifierApiUsageLogger( TextClassifierSettings settings, Executor executor)365 TextClassifierApiUsageLogger createTextClassifierApiUsageLogger( 366 TextClassifierSettings settings, Executor executor); 367 } 368 } 369