/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.android.textclassifier; import android.content.Context; import android.os.CancellationSignal; import android.service.textclassifier.TextClassifierService; import android.view.textclassifier.ConversationActions; import android.view.textclassifier.SelectionEvent; import android.view.textclassifier.TextClassification; import android.view.textclassifier.TextClassificationContext; import android.view.textclassifier.TextClassificationSessionId; import android.view.textclassifier.TextClassifierEvent; import android.view.textclassifier.TextLanguage; import android.view.textclassifier.TextLinks; import android.view.textclassifier.TextSelection; import androidx.annotation.NonNull; import androidx.collection.LruCache; import com.android.textclassifier.common.TextClassifierServiceExecutors; import com.android.textclassifier.common.TextClassifierSettings; import com.android.textclassifier.common.base.TcLog; import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger; import com.android.textclassifier.downloader.ModelDownloadManager; import com.android.textclassifier.utils.IndentingPrintWriter; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; import java.io.FileDescriptor; import java.io.PrintWriter; import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import javax.annotation.Nullable; /** An implementation of a TextClassifierService. */ public final class DefaultTextClassifierService extends TextClassifierService { private static final String TAG = "default_tcs"; private final Injector injector; // TODO: Figure out do we need more concurrency. private ListeningExecutorService normPriorityExecutor; private ListeningExecutorService lowPriorityExecutor; @Nullable private ModelDownloadManager modelDownloadManager; private TextClassifierImpl textClassifier; private TextClassifierSettings settings; private ModelFileManager modelFileManager; private LruCache sessionIdToContext; public DefaultTextClassifierService() { this.injector = new InjectorImpl(this); } @VisibleForTesting DefaultTextClassifierService(Injector injector) { this.injector = Preconditions.checkNotNull(injector); } private TextClassifierApiUsageLogger textClassifierApiUsageLogger; @Override public void onCreate() { super.onCreate(); settings = injector.createTextClassifierSettings(); modelDownloadManager = new ModelDownloadManager( injector.getContext().getApplicationContext(), settings, TextClassifierServiceExecutors.getDownloaderExecutor()); modelDownloadManager.onTextClassifierServiceCreated(); modelFileManager = injector.createModelFileManager(settings, modelDownloadManager); normPriorityExecutor = injector.createNormPriorityExecutor(); lowPriorityExecutor = injector.createLowPriorityExecutor(); textClassifier = injector.createTextClassifierImpl(settings, modelFileManager); sessionIdToContext = new LruCache<>(settings.getSessionIdToContextCacheSize()); textClassifierApiUsageLogger = injector.createTextClassifierApiUsageLogger(settings, lowPriorityExecutor); } @Override public void onDestroy() { super.onDestroy(); modelDownloadManager.destroy(); } @Override public void onCreateTextClassificationSession( @NonNull TextClassificationContext context, @NonNull TextClassificationSessionId sessionId) { sessionIdToContext.put(sessionId, context); } @Override public void onDestroyTextClassificationSession(@NonNull TextClassificationSessionId sessionId) { sessionIdToContext.remove(sessionId); } @Override public void onSuggestSelection( TextClassificationSessionId sessionId, TextSelection.Request request, CancellationSignal cancellationSignal, Callback callback) { handleRequestAsync( () -> textClassifier.suggestSelection( sessionId, sessionIdToTextClassificationContext(sessionId), request), callback, textClassifierApiUsageLogger.createSession( TextClassifierApiUsageLogger.API_TYPE_SUGGEST_SELECTION, sessionId), cancellationSignal); } @Override public void onClassifyText( TextClassificationSessionId sessionId, TextClassification.Request request, CancellationSignal cancellationSignal, Callback callback) { handleRequestAsync( () -> textClassifier.classifyText( sessionId, sessionIdToTextClassificationContext(sessionId), request), callback, textClassifierApiUsageLogger.createSession( TextClassifierApiUsageLogger.API_TYPE_CLASSIFY_TEXT, sessionId), cancellationSignal); } @Override public void onGenerateLinks( TextClassificationSessionId sessionId, TextLinks.Request request, CancellationSignal cancellationSignal, Callback callback) { handleRequestAsync( () -> textClassifier.generateLinks( sessionId, sessionIdToTextClassificationContext(sessionId), request), callback, textClassifierApiUsageLogger.createSession( TextClassifierApiUsageLogger.API_TYPE_GENERATE_LINKS, sessionId), cancellationSignal); } @Override public void onSuggestConversationActions( TextClassificationSessionId sessionId, ConversationActions.Request request, CancellationSignal cancellationSignal, Callback callback) { handleRequestAsync( () -> textClassifier.suggestConversationActions( sessionId, sessionIdToTextClassificationContext(sessionId), request), callback, textClassifierApiUsageLogger.createSession( TextClassifierApiUsageLogger.API_TYPE_SUGGEST_CONVERSATION_ACTIONS, sessionId), cancellationSignal); } @Override public void onDetectLanguage( TextClassificationSessionId sessionId, TextLanguage.Request request, CancellationSignal cancellationSignal, Callback callback) { handleRequestAsync( () -> textClassifier.detectLanguage( sessionId, sessionIdToTextClassificationContext(sessionId), request), callback, textClassifierApiUsageLogger.createSession( TextClassifierApiUsageLogger.API_TYPE_DETECT_LANGUAGES, sessionId), cancellationSignal); } @Override public void onSelectionEvent(TextClassificationSessionId sessionId, SelectionEvent event) { handleEvent(() -> textClassifier.onSelectionEvent(sessionId, event)); } @Override public void onTextClassifierEvent( TextClassificationSessionId sessionId, TextClassifierEvent event) { handleEvent(() -> textClassifier.onTextClassifierEvent(sessionId, event)); } @Override protected void dump(FileDescriptor fd, PrintWriter writer, String[] args) { // Dump in a background thread b/c we may need to query Room db (e.g. to init model cache) try { TextClassifierServiceExecutors.getLowPriorityExecutor() .submit( () -> { IndentingPrintWriter indentingPrintWriter = new IndentingPrintWriter(writer); textClassifier.dump(indentingPrintWriter); modelDownloadManager.dump(indentingPrintWriter); dumpImpl(indentingPrintWriter); indentingPrintWriter.flush(); }) .get(); } catch (ExecutionException | InterruptedException e) { TcLog.e(TAG, "Failed to dump Default TextClassifierService", e); } } private void dumpImpl(IndentingPrintWriter printWriter) { printWriter.println("DefaultTextClassifierService:"); printWriter.increaseIndent(); printWriter.println("sessionIdToContext:"); printWriter.increaseIndent(); for (Map.Entry entry : sessionIdToContext.snapshot().entrySet()) { printWriter.printPair(entry.getKey().getValue(), entry.getValue()); } printWriter.decreaseIndent(); printWriter.decreaseIndent(); printWriter.println(); } private void handleRequestAsync( Callable callable, Callback callback, TextClassifierApiUsageLogger.Session apiLoggerSession, CancellationSignal cancellationSignal) { ListenableFuture result = normPriorityExecutor.submit(callable); Futures.addCallback( result, new FutureCallback() { @Override public void onSuccess(T result) { callback.onSuccess(result); apiLoggerSession.reportSuccess(); } @Override public void onFailure(Throwable t) { TcLog.e(TAG, "onFailure: ", t); callback.onFailure(t.getMessage()); apiLoggerSession.reportFailure(); } }, MoreExecutors.directExecutor()); cancellationSignal.setOnCancelListener(() -> result.cancel(/* mayInterruptIfRunning= */ true)); } private void handleEvent(Runnable runnable) { ListenableFuture result = lowPriorityExecutor.submit( () -> { runnable.run(); return null; }); Futures.addCallback( result, new FutureCallback() { @Override public void onSuccess(Void result) {} @Override public void onFailure(Throwable t) { TcLog.e(TAG, "onFailure: ", t); } }, MoreExecutors.directExecutor()); } @Nullable private TextClassificationContext sessionIdToTextClassificationContext( @Nullable TextClassificationSessionId sessionId) { if (sessionId == null) { return null; } return sessionIdToContext.get(sessionId); } // Do not call any of these methods, except the constructor, before Service.onCreate is called. private static class InjectorImpl implements Injector { // Do not access the context object before Service.onCreate is invoked. private final Context context; private InjectorImpl(Context context) { this.context = Preconditions.checkNotNull(context); } @Override public Context getContext() { return context; } @Override public ModelFileManager createModelFileManager( TextClassifierSettings settings, ModelDownloadManager modelDownloadManager) { return new ModelFileManagerImpl(context, modelDownloadManager, settings); } @Override public TextClassifierSettings createTextClassifierSettings() { return new TextClassifierSettings(); } @Override public TextClassifierImpl createTextClassifierImpl( TextClassifierSettings settings, ModelFileManager modelFileManager) { return new TextClassifierImpl(context, settings, modelFileManager); } @Override public ListeningExecutorService createNormPriorityExecutor() { return TextClassifierServiceExecutors.getNormhPriorityExecutor(); } @Override public ListeningExecutorService createLowPriorityExecutor() { return TextClassifierServiceExecutors.getLowPriorityExecutor(); } @Override public TextClassifierApiUsageLogger createTextClassifierApiUsageLogger( TextClassifierSettings settings, Executor executor) { return new TextClassifierApiUsageLogger( settings::getTextClassifierApiLogSampleRate, executor); } } /* * Provides dependencies to the {@link DefaultTextClassifierService}. This makes the service * class testable. */ interface Injector { Context getContext(); ModelFileManager createModelFileManager( TextClassifierSettings settings, ModelDownloadManager modelDownloadManager); TextClassifierSettings createTextClassifierSettings(); TextClassifierImpl createTextClassifierImpl( TextClassifierSettings settings, ModelFileManager modelFileManager); ListeningExecutorService createNormPriorityExecutor(); ListeningExecutorService createLowPriorityExecutor(); TextClassifierApiUsageLogger createTextClassifierApiUsageLogger( TextClassifierSettings settings, Executor executor); } }