• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package com.android.textclassifier;
18 
19 import android.content.Context;
20 import android.os.CancellationSignal;
21 import android.service.textclassifier.TextClassifierService;
22 import android.view.textclassifier.ConversationActions;
23 import android.view.textclassifier.SelectionEvent;
24 import android.view.textclassifier.TextClassification;
25 import android.view.textclassifier.TextClassificationContext;
26 import android.view.textclassifier.TextClassificationSessionId;
27 import android.view.textclassifier.TextClassifierEvent;
28 import android.view.textclassifier.TextLanguage;
29 import android.view.textclassifier.TextLinks;
30 import android.view.textclassifier.TextSelection;
31 import androidx.annotation.NonNull;
32 import androidx.collection.LruCache;
33 import com.android.textclassifier.common.TextClassifierServiceExecutors;
34 import com.android.textclassifier.common.TextClassifierSettings;
35 import com.android.textclassifier.common.base.TcLog;
36 import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger;
37 import com.android.textclassifier.downloader.ModelDownloadManager;
38 import com.android.textclassifier.utils.IndentingPrintWriter;
39 import com.google.common.annotations.VisibleForTesting;
40 import com.google.common.base.Preconditions;
41 import com.google.common.util.concurrent.FutureCallback;
42 import com.google.common.util.concurrent.Futures;
43 import com.google.common.util.concurrent.ListenableFuture;
44 import com.google.common.util.concurrent.ListeningExecutorService;
45 import com.google.common.util.concurrent.MoreExecutors;
46 import java.io.FileDescriptor;
47 import java.io.PrintWriter;
48 import java.util.Map;
49 import java.util.concurrent.Callable;
50 import java.util.concurrent.ExecutionException;
51 import java.util.concurrent.Executor;
52 import javax.annotation.Nullable;
53 
54 /** An implementation of a TextClassifierService. */
55 public final class DefaultTextClassifierService extends TextClassifierService {
56   private static final String TAG = "default_tcs";
57 
58   private final Injector injector;
59   // TODO: Figure out do we need more concurrency.
60   private ListeningExecutorService normPriorityExecutor;
61   private ListeningExecutorService lowPriorityExecutor;
62 
63   @Nullable private ModelDownloadManager modelDownloadManager;
64 
65   private TextClassifierImpl textClassifier;
66   private TextClassifierSettings settings;
67   private ModelFileManager modelFileManager;
68   private LruCache<TextClassificationSessionId, TextClassificationContext> sessionIdToContext;
69 
DefaultTextClassifierService()70   public DefaultTextClassifierService() {
71     this.injector = new InjectorImpl(this);
72   }
73 
74   @VisibleForTesting
DefaultTextClassifierService(Injector injector)75   DefaultTextClassifierService(Injector injector) {
76     this.injector = Preconditions.checkNotNull(injector);
77   }
78 
79   private TextClassifierApiUsageLogger textClassifierApiUsageLogger;
80 
81   @Override
onCreate()82   public void onCreate() {
83     super.onCreate();
84     settings = injector.createTextClassifierSettings();
85     modelDownloadManager =
86         new ModelDownloadManager(
87             injector.getContext().getApplicationContext(),
88             settings,
89             TextClassifierServiceExecutors.getDownloaderExecutor());
90     modelDownloadManager.onTextClassifierServiceCreated();
91     modelFileManager = injector.createModelFileManager(settings, modelDownloadManager);
92     normPriorityExecutor = injector.createNormPriorityExecutor();
93     lowPriorityExecutor = injector.createLowPriorityExecutor();
94     textClassifier = injector.createTextClassifierImpl(settings, modelFileManager);
95     sessionIdToContext = new LruCache<>(settings.getSessionIdToContextCacheSize());
96     textClassifierApiUsageLogger =
97         injector.createTextClassifierApiUsageLogger(settings, lowPriorityExecutor);
98   }
99 
100   @Override
onDestroy()101   public void onDestroy() {
102     super.onDestroy();
103     modelDownloadManager.destroy();
104   }
105 
106   @Override
onCreateTextClassificationSession( @onNull TextClassificationContext context, @NonNull TextClassificationSessionId sessionId)107   public void onCreateTextClassificationSession(
108       @NonNull TextClassificationContext context, @NonNull TextClassificationSessionId sessionId) {
109     sessionIdToContext.put(sessionId, context);
110   }
111 
112   @Override
onDestroyTextClassificationSession(@onNull TextClassificationSessionId sessionId)113   public void onDestroyTextClassificationSession(@NonNull TextClassificationSessionId sessionId) {
114     sessionIdToContext.remove(sessionId);
115   }
116 
117   @Override
onSuggestSelection( TextClassificationSessionId sessionId, TextSelection.Request request, CancellationSignal cancellationSignal, Callback<TextSelection> callback)118   public void onSuggestSelection(
119       TextClassificationSessionId sessionId,
120       TextSelection.Request request,
121       CancellationSignal cancellationSignal,
122       Callback<TextSelection> callback) {
123     handleRequestAsync(
124         () ->
125             textClassifier.suggestSelection(
126                 sessionId, sessionIdToTextClassificationContext(sessionId), request),
127         callback,
128         textClassifierApiUsageLogger.createSession(
129             TextClassifierApiUsageLogger.API_TYPE_SUGGEST_SELECTION, sessionId),
130         cancellationSignal);
131   }
132 
133   @Override
onClassifyText( TextClassificationSessionId sessionId, TextClassification.Request request, CancellationSignal cancellationSignal, Callback<TextClassification> callback)134   public void onClassifyText(
135       TextClassificationSessionId sessionId,
136       TextClassification.Request request,
137       CancellationSignal cancellationSignal,
138       Callback<TextClassification> callback) {
139     handleRequestAsync(
140         () ->
141             textClassifier.classifyText(
142                 sessionId, sessionIdToTextClassificationContext(sessionId), request),
143         callback,
144         textClassifierApiUsageLogger.createSession(
145             TextClassifierApiUsageLogger.API_TYPE_CLASSIFY_TEXT, sessionId),
146         cancellationSignal);
147   }
148 
149   @Override
onGenerateLinks( TextClassificationSessionId sessionId, TextLinks.Request request, CancellationSignal cancellationSignal, Callback<TextLinks> callback)150   public void onGenerateLinks(
151       TextClassificationSessionId sessionId,
152       TextLinks.Request request,
153       CancellationSignal cancellationSignal,
154       Callback<TextLinks> callback) {
155     handleRequestAsync(
156         () ->
157             textClassifier.generateLinks(
158                 sessionId, sessionIdToTextClassificationContext(sessionId), request),
159         callback,
160         textClassifierApiUsageLogger.createSession(
161             TextClassifierApiUsageLogger.API_TYPE_GENERATE_LINKS, sessionId),
162         cancellationSignal);
163   }
164 
165   @Override
onSuggestConversationActions( TextClassificationSessionId sessionId, ConversationActions.Request request, CancellationSignal cancellationSignal, Callback<ConversationActions> callback)166   public void onSuggestConversationActions(
167       TextClassificationSessionId sessionId,
168       ConversationActions.Request request,
169       CancellationSignal cancellationSignal,
170       Callback<ConversationActions> callback) {
171     handleRequestAsync(
172         () ->
173             textClassifier.suggestConversationActions(
174                 sessionId, sessionIdToTextClassificationContext(sessionId), request),
175         callback,
176         textClassifierApiUsageLogger.createSession(
177             TextClassifierApiUsageLogger.API_TYPE_SUGGEST_CONVERSATION_ACTIONS, sessionId),
178         cancellationSignal);
179   }
180 
181   @Override
onDetectLanguage( TextClassificationSessionId sessionId, TextLanguage.Request request, CancellationSignal cancellationSignal, Callback<TextLanguage> callback)182   public void onDetectLanguage(
183       TextClassificationSessionId sessionId,
184       TextLanguage.Request request,
185       CancellationSignal cancellationSignal,
186       Callback<TextLanguage> callback) {
187     handleRequestAsync(
188         () ->
189             textClassifier.detectLanguage(
190                 sessionId, sessionIdToTextClassificationContext(sessionId), request),
191         callback,
192         textClassifierApiUsageLogger.createSession(
193             TextClassifierApiUsageLogger.API_TYPE_DETECT_LANGUAGES, sessionId),
194         cancellationSignal);
195   }
196 
197   @Override
onSelectionEvent(TextClassificationSessionId sessionId, SelectionEvent event)198   public void onSelectionEvent(TextClassificationSessionId sessionId, SelectionEvent event) {
199     handleEvent(() -> textClassifier.onSelectionEvent(sessionId, event));
200   }
201 
202   @Override
onTextClassifierEvent( TextClassificationSessionId sessionId, TextClassifierEvent event)203   public void onTextClassifierEvent(
204       TextClassificationSessionId sessionId, TextClassifierEvent event) {
205     handleEvent(() -> textClassifier.onTextClassifierEvent(sessionId, event));
206   }
207 
208   @Override
dump(FileDescriptor fd, PrintWriter writer, String[] args)209   protected void dump(FileDescriptor fd, PrintWriter writer, String[] args) {
210     // Dump in a background thread b/c we may need to query Room db (e.g. to init model cache)
211     try {
212       TextClassifierServiceExecutors.getLowPriorityExecutor()
213           .submit(
214               () -> {
215                 IndentingPrintWriter indentingPrintWriter = new IndentingPrintWriter(writer);
216                 textClassifier.dump(indentingPrintWriter);
217                 modelDownloadManager.dump(indentingPrintWriter);
218                 dumpImpl(indentingPrintWriter);
219                 indentingPrintWriter.flush();
220               })
221           .get();
222     } catch (ExecutionException | InterruptedException e) {
223       TcLog.e(TAG, "Failed to dump Default TextClassifierService", e);
224     }
225   }
226 
dumpImpl(IndentingPrintWriter printWriter)227   private void dumpImpl(IndentingPrintWriter printWriter) {
228     printWriter.println("DefaultTextClassifierService:");
229     printWriter.increaseIndent();
230     printWriter.println("sessionIdToContext:");
231     printWriter.increaseIndent();
232     for (Map.Entry<TextClassificationSessionId, TextClassificationContext> entry :
233         sessionIdToContext.snapshot().entrySet()) {
234       printWriter.printPair(entry.getKey().getValue(), entry.getValue());
235     }
236     printWriter.decreaseIndent();
237     printWriter.decreaseIndent();
238     printWriter.println();
239   }
240 
handleRequestAsync( Callable<T> callable, Callback<T> callback, TextClassifierApiUsageLogger.Session apiLoggerSession, CancellationSignal cancellationSignal)241   private <T> void handleRequestAsync(
242       Callable<T> callable,
243       Callback<T> callback,
244       TextClassifierApiUsageLogger.Session apiLoggerSession,
245       CancellationSignal cancellationSignal) {
246     ListenableFuture<T> result = normPriorityExecutor.submit(callable);
247     Futures.addCallback(
248         result,
249         new FutureCallback<T>() {
250           @Override
251           public void onSuccess(T result) {
252             callback.onSuccess(result);
253             apiLoggerSession.reportSuccess();
254           }
255 
256           @Override
257           public void onFailure(Throwable t) {
258             TcLog.e(TAG, "onFailure: ", t);
259             callback.onFailure(t.getMessage());
260             apiLoggerSession.reportFailure();
261           }
262         },
263         MoreExecutors.directExecutor());
264     cancellationSignal.setOnCancelListener(() -> result.cancel(/* mayInterruptIfRunning= */ true));
265   }
266 
handleEvent(Runnable runnable)267   private void handleEvent(Runnable runnable) {
268     ListenableFuture<Void> result =
269         lowPriorityExecutor.submit(
270             () -> {
271               runnable.run();
272               return null;
273             });
274     Futures.addCallback(
275         result,
276         new FutureCallback<Void>() {
277           @Override
278           public void onSuccess(Void result) {}
279 
280           @Override
281           public void onFailure(Throwable t) {
282             TcLog.e(TAG, "onFailure: ", t);
283           }
284         },
285         MoreExecutors.directExecutor());
286   }
287 
288   @Nullable
sessionIdToTextClassificationContext( @ullable TextClassificationSessionId sessionId)289   private TextClassificationContext sessionIdToTextClassificationContext(
290       @Nullable TextClassificationSessionId sessionId) {
291     if (sessionId == null) {
292       return null;
293     }
294     return sessionIdToContext.get(sessionId);
295   }
296 
297   // Do not call any of these methods, except the constructor, before Service.onCreate is called.
298   private static class InjectorImpl implements Injector {
299     // Do not access the context object before Service.onCreate is invoked.
300     private final Context context;
301 
InjectorImpl(Context context)302     private InjectorImpl(Context context) {
303       this.context = Preconditions.checkNotNull(context);
304     }
305 
306     @Override
getContext()307     public Context getContext() {
308       return context;
309     }
310 
311     @Override
createModelFileManager( TextClassifierSettings settings, ModelDownloadManager modelDownloadManager)312     public ModelFileManager createModelFileManager(
313         TextClassifierSettings settings, ModelDownloadManager modelDownloadManager) {
314       return new ModelFileManagerImpl(context, modelDownloadManager, settings);
315     }
316 
317     @Override
createTextClassifierSettings()318     public TextClassifierSettings createTextClassifierSettings() {
319       return new TextClassifierSettings(getContext());
320     }
321 
322     @Override
createTextClassifierImpl( TextClassifierSettings settings, ModelFileManager modelFileManager)323     public TextClassifierImpl createTextClassifierImpl(
324         TextClassifierSettings settings, ModelFileManager modelFileManager) {
325       return new TextClassifierImpl(context, settings, modelFileManager);
326     }
327 
328     @Override
createNormPriorityExecutor()329     public ListeningExecutorService createNormPriorityExecutor() {
330       return TextClassifierServiceExecutors.getNormhPriorityExecutor();
331     }
332 
333     @Override
createLowPriorityExecutor()334     public ListeningExecutorService createLowPriorityExecutor() {
335       return TextClassifierServiceExecutors.getLowPriorityExecutor();
336     }
337 
338     @Override
createTextClassifierApiUsageLogger( TextClassifierSettings settings, Executor executor)339     public TextClassifierApiUsageLogger createTextClassifierApiUsageLogger(
340         TextClassifierSettings settings, Executor executor) {
341       return new TextClassifierApiUsageLogger(
342           settings::getTextClassifierApiLogSampleRate, executor);
343     }
344   }
345 
346   /*
347    * Provides dependencies to the {@link DefaultTextClassifierService}. This makes the service
348    * class testable.
349    */
350   interface Injector {
getContext()351     Context getContext();
352 
createModelFileManager( TextClassifierSettings settings, ModelDownloadManager modelDownloadManager)353     ModelFileManager createModelFileManager(
354         TextClassifierSettings settings, ModelDownloadManager modelDownloadManager);
355 
createTextClassifierSettings()356     TextClassifierSettings createTextClassifierSettings();
357 
createTextClassifierImpl( TextClassifierSettings settings, ModelFileManager modelFileManager)358     TextClassifierImpl createTextClassifierImpl(
359         TextClassifierSettings settings, ModelFileManager modelFileManager);
360 
createNormPriorityExecutor()361     ListeningExecutorService createNormPriorityExecutor();
362 
createLowPriorityExecutor()363     ListeningExecutorService createLowPriorityExecutor();
364 
createTextClassifierApiUsageLogger( TextClassifierSettings settings, Executor executor)365     TextClassifierApiUsageLogger createTextClassifierApiUsageLogger(
366         TextClassifierSettings settings, Executor executor);
367   }
368 }
369