• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.textclassifier;
18 
19 import static java.util.stream.Collectors.toCollection;
20 
21 import android.app.PendingIntent;
22 import android.app.RemoteAction;
23 import android.content.Context;
24 import android.content.Intent;
25 import android.content.res.AssetFileDescriptor;
26 import android.icu.util.ULocale;
27 import android.os.Bundle;
28 import android.os.LocaleList;
29 import android.os.Looper;
30 import android.util.ArrayMap;
31 import android.view.View.OnClickListener;
32 import android.view.textclassifier.ConversationAction;
33 import android.view.textclassifier.ConversationActions;
34 import android.view.textclassifier.SelectionEvent;
35 import android.view.textclassifier.TextClassification;
36 import android.view.textclassifier.TextClassification.Request;
37 import android.view.textclassifier.TextClassificationContext;
38 import android.view.textclassifier.TextClassificationSessionId;
39 import android.view.textclassifier.TextClassifier;
40 import android.view.textclassifier.TextClassifierEvent;
41 import android.view.textclassifier.TextLanguage;
42 import android.view.textclassifier.TextLinks;
43 import android.view.textclassifier.TextSelection;
44 import androidx.annotation.GuardedBy;
45 import androidx.annotation.WorkerThread;
46 import androidx.collection.LruCache;
47 import androidx.core.util.Pair;
48 import com.android.textclassifier.common.ModelFile;
49 import com.android.textclassifier.common.ModelType;
50 import com.android.textclassifier.common.TextClassifierSettings;
51 import com.android.textclassifier.common.TextSelectionCompat;
52 import com.android.textclassifier.common.base.TcLog;
53 import com.android.textclassifier.common.intent.LabeledIntent;
54 import com.android.textclassifier.common.intent.TemplateIntentFactory;
55 import com.android.textclassifier.common.logging.ResultIdUtils;
56 import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
57 import com.android.textclassifier.common.statsd.GenerateLinksLogger;
58 import com.android.textclassifier.common.statsd.SelectionEventConverter;
59 import com.android.textclassifier.common.statsd.TextClassificationSessionIdConverter;
60 import com.android.textclassifier.common.statsd.TextClassifierEventConverter;
61 import com.android.textclassifier.common.statsd.TextClassifierEventLogger;
62 import com.android.textclassifier.utils.IndentingPrintWriter;
63 import com.google.android.textclassifier.ActionsSuggestionsModel;
64 import com.google.android.textclassifier.ActionsSuggestionsModel.ActionSuggestions;
65 import com.google.android.textclassifier.AnnotatorModel;
66 import com.google.android.textclassifier.LangIdModel;
67 import com.google.common.annotations.VisibleForTesting;
68 import com.google.common.base.Optional;
69 import com.google.common.base.Preconditions;
70 import com.google.common.collect.FluentIterable;
71 import com.google.common.collect.ImmutableList;
72 import java.io.IOException;
73 import java.time.ZoneId;
74 import java.time.ZonedDateTime;
75 import java.util.ArrayList;
76 import java.util.Collection;
77 import java.util.List;
78 import java.util.Map;
79 import java.util.Objects;
80 import javax.annotation.Nullable;
81 
82 /**
83  * A text classifier that is running locally.
84  *
85  * <p>This class uses machine learning to recognize entities in text. Unless otherwise stated,
86  * methods of this class are blocking operations and should most likely not be called on the UI
87  * thread.
88  */
89 final class TextClassifierImpl {
90 
91   private static final String TAG = "TextClassifierImpl";
92 
93   private final Context context;
94   private final ModelFileManager modelFileManager;
95   private final GenerateLinksLogger generateLinksLogger;
96 
97   private final Object lock = new Object();
98 
99   @GuardedBy("lock")
100   private ModelFile annotatorModelInUse;
101 
102   @GuardedBy("lock")
103   private AnnotatorModel annotatorImpl;
104 
105   @GuardedBy("lock")
106   private ModelFile langIdModelInUse;
107 
108   @GuardedBy("lock")
109   private LangIdModel langIdImpl;
110 
111   @GuardedBy("lock")
112   private ModelFile actionModelInUse;
113 
114   @GuardedBy("lock")
115   private ActionsSuggestionsModel actionsImpl;
116 
117   @GuardedBy("lock")
118   private final LruCache<ModelFile, AnnotatorModel> annotatorModelCache;
119 
120   private final TextClassifierEventLogger textClassifierEventLogger =
121       new TextClassifierEventLogger();
122 
123   private final TextClassifierSettings settings;
124 
125   private final TemplateIntentFactory templateIntentFactory;
126 
/**
 * Creates a classifier that uses a fresh annotator model cache whose maximum size is taken from
 * {@code settings.getMultiAnnotatorCacheSize()}.
 *
 * @param context context used for asset access and intent resolution
 * @param settings source of all tunable classifier parameters
 * @param modelFileManager locator for the model files to load
 */
TextClassifierImpl(
    Context context, TextClassifierSettings settings, ModelFileManager modelFileManager) {
  this(
      context,
      settings,
      modelFileManager,
      new LruCache<ModelFile, AnnotatorModel>(settings.getMultiAnnotatorCacheSize()));
}
132 
/**
 * Creates a classifier with an explicitly supplied annotator model cache.
 *
 * @param context must not be null
 * @param settings must not be null
 * @param modelFileManager must not be null
 * @param annotatorModelCache cache of loaded annotator models keyed by model file; stored as
 *     given, without a null check
 * @throws NullPointerException if {@code context}, {@code settings} or {@code modelFileManager}
 *     is null
 */
@VisibleForTesting
public TextClassifierImpl(
    Context context,
    TextClassifierSettings settings,
    ModelFileManager modelFileManager,
    LruCache<ModelFile, AnnotatorModel> annotatorModelCache) {
  // Validate the mandatory collaborators up front, in declaration order.
  Preconditions.checkNotNull(context);
  Preconditions.checkNotNull(settings);
  Preconditions.checkNotNull(modelFileManager);

  this.context = context;
  this.settings = settings;
  this.modelFileManager = modelFileManager;
  this.annotatorModelCache = annotatorModelCache;
  this.generateLinksLogger = new GenerateLinksLogger(settings.getGenerateLinksLogSampleRate());
  this.templateIntentFactory = new TemplateIntentFactory();
}
146 
147   @WorkerThread
suggestSelection( @ullable TextClassificationSessionId sessionId, @Nullable TextClassificationContext textClassificationContext, TextSelection.Request request)148   TextSelection suggestSelection(
149       @Nullable TextClassificationSessionId sessionId,
150       @Nullable TextClassificationContext textClassificationContext,
151       TextSelection.Request request)
152       throws IOException {
153     Preconditions.checkNotNull(request);
154     checkMainThread();
155     final int rangeLength = request.getEndIndex() - request.getStartIndex();
156     final String string = request.getText().toString();
157     Preconditions.checkArgument(!string.isEmpty(), "input string should not be empty");
158     Preconditions.checkArgument(
159         rangeLength <= settings.getClassifyTextMaxRangeLength(), "range is too large");
160     final String localesString = concatenateLocales(request.getDefaultLocales());
161     final LangIdModel langIdModel = getLangIdImpl();
162     final String detectLanguageTags =
163         String.join(",", detectLanguageTags(langIdModel, request.getText()));
164     final ZonedDateTime refTime = ZonedDateTime.now(ZoneId.systemDefault());
165     final LocaleList detectedLocaleList = LocaleList.forLanguageTags(detectLanguageTags);
166     final ModelFile annotatorModelInUse =
167         getAnnotatorModelFile(request.getDefaultLocales(), detectedLocaleList);
168     final AnnotatorModel annotatorImpl = loadAnnotatorModelFile(annotatorModelInUse);
169     final int[] startEnd =
170         annotatorImpl.suggestSelection(
171             string,
172             request.getStartIndex(),
173             request.getEndIndex(),
174             AnnotatorModel.SelectionOptions.builder()
175                 .setLocales(localesString)
176                 .setDetectedTextLanguageTags(detectLanguageTags)
177                 .build());
178     final int start = startEnd[0];
179     final int end = startEnd[1];
180     if (start >= end
181         || start < 0
182         || start > request.getStartIndex()
183         || end > string.length()
184         || end < request.getEndIndex()) {
185       throw new IllegalArgumentException("Got bad indices for input text. Ignoring result.");
186     }
187     final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
188     final boolean shouldIncludeTextClassification =
189         TextSelectionCompat.shouldIncludeTextClassification(request);
190     final AnnotatorModel.ClassificationResult[] results =
191         annotatorImpl.classifyText(
192             string,
193             start,
194             end,
195             AnnotatorModel.ClassificationOptions.builder()
196                 .setReferenceTimeMsUtc(refTime.toInstant().toEpochMilli())
197                 .setReferenceTimezone(refTime.getZone().getId())
198                 .setLocales(localesString)
199                 .setDetectedTextLanguageTags(detectLanguageTags)
200                 .setAnnotationUsecase(AnnotatorModel.AnnotationUsecase.SMART.getValue())
201                 .setUserFamiliarLanguageTags(LocaleList.getDefault().toLanguageTags())
202                 .setEnableAddContactIntent(false)
203                 .setEnableSearchIntent(shouldEnableSearchIntent(textClassificationContext))
204                 .build(),
205             // Passing null here to suppress intent generation.
206             // TODO: Use an explicit flag to suppress it.
207             shouldIncludeTextClassification ? context : null,
208             getResourceLocalesString());
209     final int size = results.length;
210     for (int i = 0; i < size; i++) {
211       tsBuilder.setEntityType(results[i].getCollection(), results[i].getScore());
212     }
213     final String resultId =
214         createAnnotatorId(string, request.getStartIndex(), request.getEndIndex());
215     if (shouldIncludeTextClassification) {
216       TextClassification textClassification =
217           createClassificationResult(results, string, start, end, langIdModel);
218       TextSelectionCompat.setTextClassification(tsBuilder, textClassification);
219     }
220     return tsBuilder.setId(resultId).build();
221   }
222 
223   @WorkerThread
classifyText( @ullable TextClassificationSessionId sessionId, @Nullable TextClassificationContext textClassificationContext, Request request)224   TextClassification classifyText(
225       @Nullable TextClassificationSessionId sessionId,
226       @Nullable TextClassificationContext textClassificationContext,
227       Request request)
228       throws IOException {
229     Preconditions.checkNotNull(request);
230     checkMainThread();
231     LangIdModel langId = getLangIdImpl();
232     List<String> detectLanguageTags = detectLanguageTags(langId, request.getText());
233     final int rangeLength = request.getEndIndex() - request.getStartIndex();
234     final String string = request.getText().toString();
235     Preconditions.checkArgument(!string.isEmpty(), "input string should not be empty");
236     Preconditions.checkArgument(
237         rangeLength <= settings.getClassifyTextMaxRangeLength(), "range is too large");
238 
239     final String localesString = concatenateLocales(request.getDefaultLocales());
240     final ZonedDateTime refTime =
241         request.getReferenceTime() != null
242             ? request.getReferenceTime()
243             : ZonedDateTime.now(ZoneId.systemDefault());
244     final LocaleList detectedLocaleList =
245         LocaleList.forLanguageTags(String.join(",", detectLanguageTags));
246     final AnnotatorModel.ClassificationResult[] results =
247         getAnnotatorImpl(request.getDefaultLocales(), detectedLocaleList)
248             .classifyText(
249                 string,
250                 request.getStartIndex(),
251                 request.getEndIndex(),
252                 AnnotatorModel.ClassificationOptions.builder()
253                     .setReferenceTimeMsUtc(refTime.toInstant().toEpochMilli())
254                     .setReferenceTimezone(refTime.getZone().getId())
255                     .setLocales(localesString)
256                     .setDetectedTextLanguageTags(String.join(",", detectLanguageTags))
257                     .setAnnotationUsecase(AnnotatorModel.AnnotationUsecase.SMART.getValue())
258                     .setUserFamiliarLanguageTags(LocaleList.getDefault().toLanguageTags())
259                     .setEnableAddContactIntent(false)
260                     .setEnableSearchIntent(shouldEnableSearchIntent(textClassificationContext))
261                     .build(),
262                 context,
263                 getResourceLocalesString());
264     if (results.length == 0) {
265       throw new IllegalStateException("Empty text classification. Something went wrong.");
266     }
267     return createClassificationResult(
268         results, string, request.getStartIndex(), request.getEndIndex(), langId);
269   }
270 
271   @WorkerThread
generateLinks( @ullable TextClassificationSessionId sessionId, @Nullable TextClassificationContext textClassificationContext, TextLinks.Request request)272   TextLinks generateLinks(
273       @Nullable TextClassificationSessionId sessionId,
274       @Nullable TextClassificationContext textClassificationContext,
275       TextLinks.Request request)
276       throws IOException {
277     Preconditions.checkNotNull(request);
278     Preconditions.checkArgument(
279         request.getText().length() <= getMaxGenerateLinksTextLength(),
280         "text.length() cannot be greater than %s",
281         getMaxGenerateLinksTextLength());
282     checkMainThread();
283 
284     final String textString = request.getText().toString();
285     final TextLinks.Builder builder = new TextLinks.Builder(textString);
286 
287     final long startTimeMs = System.currentTimeMillis();
288     final ZonedDateTime refTime = ZonedDateTime.now(ZoneId.systemDefault());
289     final Collection<String> entitiesToIdentify =
290         request.getEntityConfig() != null
291             ? request
292                 .getEntityConfig()
293                 .resolveEntityListModifications(
294                     getEntitiesForHints(request.getEntityConfig().getHints()))
295             : settings.getEntityListDefault();
296     final String localesString = concatenateLocales(request.getDefaultLocales());
297     LangIdModel langId = getLangIdImpl();
298     ImmutableList<String> detectLanguageTags = detectLanguageTags(langId, request.getText());
299     final LocaleList detectedLocaleList =
300         LocaleList.forLanguageTags(String.join(",", detectLanguageTags));
301     final AnnotatorModel annotatorImpl =
302         getAnnotatorImpl(request.getDefaultLocales(), detectedLocaleList);
303     final boolean isSerializedEntityDataEnabled =
304         ExtrasUtils.isSerializedEntityDataEnabled(request);
305     final AnnotatorModel.AnnotatedSpan[] annotations =
306         annotatorImpl.annotate(
307             textString,
308             AnnotatorModel.AnnotationOptions.builder()
309                 .setReferenceTimeMsUtc(refTime.toInstant().toEpochMilli())
310                 .setReferenceTimezone(refTime.getZone().getId())
311                 .setLocales(localesString)
312                 .setDetectedTextLanguageTags(String.join(",", detectLanguageTags))
313                 .setEntityTypes(entitiesToIdentify)
314                 .setAnnotationUsecase(AnnotatorModel.AnnotationUsecase.SMART.getValue())
315                 .setIsSerializedEntityDataEnabled(isSerializedEntityDataEnabled)
316                 .build());
317     for (AnnotatorModel.AnnotatedSpan span : annotations) {
318       final AnnotatorModel.ClassificationResult[] results = span.getClassification();
319       if (results.length == 0 || !entitiesToIdentify.contains(results[0].getCollection())) {
320         continue;
321       }
322       final Map<String, Float> entityScores = new ArrayMap<>();
323       for (AnnotatorModel.ClassificationResult result : results) {
324         entityScores.put(result.getCollection(), result.getScore());
325       }
326       Bundle extras = new Bundle();
327       if (isSerializedEntityDataEnabled) {
328         ExtrasUtils.putEntities(extras, results);
329       }
330       builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores, extras);
331     }
332     final TextLinks links = builder.build();
333     final long endTimeMs = System.currentTimeMillis();
334     final String callingPackageName =
335         request.getCallingPackageName() == null
336             ? context.getPackageName() // local (in process) TC.
337             : request.getCallingPackageName();
338     Optional<ModelInfo> annotatorModelInfo;
339     Optional<ModelInfo> langIdModelInfo;
340     synchronized (lock) {
341       annotatorModelInfo =
342           Optional.fromNullable(annotatorModelInUse).transform(ModelFile::toModelInfo);
343       langIdModelInfo = Optional.fromNullable(langIdModelInUse).transform(ModelFile::toModelInfo);
344     }
345     generateLinksLogger.logGenerateLinks(
346         sessionId,
347         textClassificationContext,
348         request.getText(),
349         links,
350         callingPackageName,
351         endTimeMs - startTimeMs,
352         annotatorModelInfo,
353         langIdModelInfo);
354     return links;
355   }
356 
getMaxGenerateLinksTextLength()357   int getMaxGenerateLinksTextLength() {
358     return settings.getGenerateLinksMaxTextLength();
359   }
360 
getEntitiesForHints(Collection<String> hints)361   private Collection<String> getEntitiesForHints(Collection<String> hints) {
362     final boolean editable = hints.contains(TextClassifier.HINT_TEXT_IS_EDITABLE);
363     final boolean notEditable = hints.contains(TextClassifier.HINT_TEXT_IS_NOT_EDITABLE);
364 
365     // Use the default if there is no hint, or conflicting ones.
366     final boolean useDefault = editable == notEditable;
367     if (useDefault) {
368       return settings.getEntityListDefault();
369     } else if (editable) {
370       return settings.getEntityListEditable();
371     } else { // notEditable
372       return settings.getEntityListNotEditable();
373     }
374   }
375 
onSelectionEvent(@ullable TextClassificationSessionId sessionId, SelectionEvent event)376   void onSelectionEvent(@Nullable TextClassificationSessionId sessionId, SelectionEvent event) {
377     TextClassifierEvent textClassifierEvent = SelectionEventConverter.toTextClassifierEvent(event);
378     if (textClassifierEvent == null) {
379       return;
380     }
381     onTextClassifierEvent(event.getSessionId(), textClassifierEvent);
382   }
383 
onTextClassifierEvent( @ullable TextClassificationSessionId sessionId, TextClassifierEvent event)384   void onTextClassifierEvent(
385       @Nullable TextClassificationSessionId sessionId, TextClassifierEvent event) {
386     textClassifierEventLogger.writeEvent(
387         TextClassificationSessionIdConverter.fromPlatform(sessionId),
388         TextClassifierEventConverter.fromPlatform(event));
389   }
390 
detectLanguage( @ullable TextClassificationSessionId sessionId, @Nullable TextClassificationContext textClassificationContext, TextLanguage.Request request)391   TextLanguage detectLanguage(
392       @Nullable TextClassificationSessionId sessionId,
393       @Nullable TextClassificationContext textClassificationContext,
394       TextLanguage.Request request)
395       throws IOException {
396     Preconditions.checkNotNull(request);
397     checkMainThread();
398     final TextLanguage.Builder builder = new TextLanguage.Builder();
399     LangIdModel langIdImpl = getLangIdImpl();
400     final LangIdModel.LanguageResult[] langResults =
401         langIdImpl.detectLanguages(request.getText().toString());
402     for (LangIdModel.LanguageResult langResult : langResults) {
403       builder.putLocale(ULocale.forLanguageTag(langResult.getLanguage()), langResult.getScore());
404     }
405     return builder.build();
406   }
407 
suggestConversationActions( @ullable TextClassificationSessionId sessionId, @Nullable TextClassificationContext textClassificationContext, ConversationActions.Request request)408   ConversationActions suggestConversationActions(
409       @Nullable TextClassificationSessionId sessionId,
410       @Nullable TextClassificationContext textClassificationContext,
411       ConversationActions.Request request)
412       throws IOException {
413     Preconditions.checkNotNull(request);
414     checkMainThread();
415     ActionsSuggestionsModel actionsImpl = getActionsImpl();
416     LangIdModel langId = getLangIdImpl();
417     ActionsSuggestionsModel.ConversationMessage[] nativeMessages =
418         ActionsSuggestionsHelper.toNativeMessages(
419             request.getConversation(), text -> detectLanguageTags(langId, text));
420     if (nativeMessages.length == 0) {
421       return new ConversationActions(ImmutableList.of(), /* id= */ null);
422     }
423     ActionsSuggestionsModel.Conversation nativeConversation =
424         new ActionsSuggestionsModel.Conversation(nativeMessages);
425 
426     ActionSuggestions nativeSuggestions =
427         actionsImpl.suggestActionsWithIntents(
428             nativeConversation,
429             null,
430             context,
431             getResourceLocalesString(),
432             getAnnotatorImpl(LocaleList.getDefault(), /* detectedLocaleList= */ null));
433     return createConversationActionResult(request, nativeSuggestions);
434   }
435 
436   /**
437    * Returns the {@link ConversationAction} result, with a non-null extras.
438    *
439    * <p>Whenever the RemoteAction is non-null, you can expect its corresponding intent with a
440    * non-null component name is in the extras.
441    */
createConversationActionResult( ConversationActions.Request request, ActionSuggestions nativeSuggestions)442   private ConversationActions createConversationActionResult(
443       ConversationActions.Request request, ActionSuggestions nativeSuggestions) {
444     Collection<String> expectedTypes = resolveActionTypesFromRequest(request);
445     List<ConversationAction> conversationActions = new ArrayList<>();
446     for (ActionsSuggestionsModel.ActionSuggestion nativeSuggestion :
447         nativeSuggestions.actionSuggestions) {
448       String actionType = nativeSuggestion.getActionType();
449       if (!expectedTypes.contains(actionType)) {
450         continue;
451       }
452       LabeledIntent.Result labeledIntentResult =
453           ActionsSuggestionsHelper.createLabeledIntentResult(
454               context, templateIntentFactory, nativeSuggestion);
455       RemoteAction remoteAction = null;
456       Bundle extras = new Bundle();
457       if (labeledIntentResult != null) {
458         remoteAction = labeledIntentResult.remoteAction.toRemoteAction();
459         ExtrasUtils.putActionIntent(
460             extras, stripPackageInfoFromIntent(labeledIntentResult.resolvedIntent));
461       }
462       ExtrasUtils.putSerializedEntityData(extras, nativeSuggestion.getSerializedEntityData());
463       ExtrasUtils.putEntitiesExtras(
464           extras, TemplateIntentFactory.nameVariantsToBundle(nativeSuggestion.getEntityData()));
465       conversationActions.add(
466           new ConversationAction.Builder(actionType)
467               .setConfidenceScore(nativeSuggestion.getScore())
468               .setTextReply(nativeSuggestion.getResponseText())
469               .setAction(remoteAction)
470               .setExtras(extras)
471               .build());
472     }
473     conversationActions = ActionsSuggestionsHelper.removeActionsWithDuplicates(conversationActions);
474     if (request.getMaxSuggestions() >= 0
475         && conversationActions.size() > request.getMaxSuggestions()) {
476       conversationActions = conversationActions.subList(0, request.getMaxSuggestions());
477     }
478     synchronized (lock) {
479       String resultId =
480           ActionsSuggestionsHelper.createResultId(
481               context,
482               request.getConversation(),
483               Optional.fromNullable(actionModelInUse),
484               Optional.fromNullable(annotatorModelInUse),
485               Optional.fromNullable(langIdModelInUse));
486       return new ConversationActions(conversationActions, resultId);
487     }
488   }
489 
resolveActionTypesFromRequest(ConversationActions.Request request)490   private Collection<String> resolveActionTypesFromRequest(ConversationActions.Request request) {
491     List<String> defaultActionTypes =
492         request.getHints().contains(ConversationActions.Request.HINT_FOR_NOTIFICATION)
493             ? settings.getNotificationConversationActionTypes()
494             : settings.getInAppConversationActionTypes();
495     return request.getTypeConfig().resolveEntityListModifications(defaultActionTypes);
496   }
497 
getAnnotatorModelFile( LocaleList requestLocaleList, LocaleList detectedLocaleList)498   private ModelFile getAnnotatorModelFile(
499       LocaleList requestLocaleList, LocaleList detectedLocaleList) throws IOException {
500     final ModelFile bestModel =
501         modelFileManager.findBestModelFile(
502             ModelType.ANNOTATOR, requestLocaleList, detectedLocaleList);
503     if (bestModel == null) {
504       throw new IllegalStateException("Failed to find the best annotator model");
505     }
506     return bestModel;
507   }
508 
loadAnnotatorModelFile(ModelFile annotatorModelFile)509   private AnnotatorModel loadAnnotatorModelFile(ModelFile annotatorModelFile) throws IOException {
510     synchronized (lock) {
511       if (settings.getMultiAnnotatorCacheEnabled()
512           && !Objects.equals(annotatorModelInUse, annotatorModelFile)) {
513         TcLog.v(TAG, "Attempting to reload cached annotator model....");
514         annotatorImpl = annotatorModelCache.get(annotatorModelFile);
515         if (annotatorImpl != null) {
516           annotatorModelInUse = annotatorModelFile;
517           TcLog.v(TAG, "Successfully reloaded cached annotator model: " + annotatorModelFile);
518         }
519       }
520       if (annotatorImpl == null || !Objects.equals(annotatorModelInUse, annotatorModelFile)) {
521         TcLog.d(TAG, "Loading " + annotatorModelFile);
522         // The current annotator model may be still used by another thread / model.
523         // Do not call close() here, and let the GC to clean it up when no one else
524         // is using it.
525         try (AssetFileDescriptor afd = annotatorModelFile.open(context.getAssets())) {
526           annotatorImpl = new AnnotatorModel(afd);
527           annotatorImpl.setLangIdModel(getLangIdImpl());
528           annotatorModelInUse = annotatorModelFile;
529           if (settings.getMultiAnnotatorCacheEnabled()) {
530             annotatorModelCache.put(annotatorModelFile, annotatorImpl);
531           }
532         }
533       }
534       return annotatorImpl;
535     }
536   }
537 
getAnnotatorImpl( LocaleList requestLocaleList, LocaleList detectedLocaleList)538   private AnnotatorModel getAnnotatorImpl(
539       LocaleList requestLocaleList, LocaleList detectedLocaleList) throws IOException {
540     ModelFile annotatorModelFile = getAnnotatorModelFile(requestLocaleList, detectedLocaleList);
541     return loadAnnotatorModelFile(annotatorModelFile);
542   }
543 
getLangIdImpl()544   private LangIdModel getLangIdImpl() throws IOException {
545     synchronized (lock) {
546       final ModelFile bestModel =
547           modelFileManager.findBestModelFile(
548               ModelType.LANG_ID, /* localePreferences= */ null, /* detectedLocales= */ null);
549       if (bestModel == null) {
550         throw new IllegalStateException("Failed to find the best LangID model.");
551       }
552       if (langIdImpl == null || !Objects.equals(langIdModelInUse, bestModel)) {
553         TcLog.d(TAG, "Loading " + bestModel);
554         try (AssetFileDescriptor afd = bestModel.open(context.getAssets())) {
555           langIdImpl = new LangIdModel(afd);
556           langIdModelInUse = bestModel;
557         }
558       }
559       return langIdImpl;
560     }
561   }
562 
getActionsImpl()563   private ActionsSuggestionsModel getActionsImpl() throws IOException {
564     synchronized (lock) {
565       // TODO: Use LangID to determine the locale we should use here?
566       final ModelFile bestModel =
567           modelFileManager.findBestModelFile(
568               ModelType.ACTIONS_SUGGESTIONS, LocaleList.getDefault(), /* detectedLocales= */ null);
569       if (bestModel == null) {
570         throw new IllegalStateException("Failed to find the best actions model");
571       }
572       if (actionsImpl == null || !Objects.equals(actionModelInUse, bestModel)) {
573         TcLog.d(TAG, "Loading " + bestModel);
574         try (AssetFileDescriptor afd = bestModel.open(context.getAssets())) {
575           actionsImpl = new ActionsSuggestionsModel(afd);
576           actionModelInUse = bestModel;
577         }
578       }
579       return actionsImpl;
580     }
581   }
582 
createAnnotatorId(String text, int start, int end)583   private String createAnnotatorId(String text, int start, int end) {
584     synchronized (lock) {
585       return ResultIdUtils.createId(
586           context,
587           text,
588           start,
589           end,
590           ModelFile.toModelInfos(
591               Optional.fromNullable(annotatorModelInUse), Optional.fromNullable(langIdModelInUse)));
592     }
593   }
594 
concatenateLocales(@ullable LocaleList locales)595   private static String concatenateLocales(@Nullable LocaleList locales) {
596     return (locales == null) ? "" : locales.toLanguageTags();
597   }
598 
createClassificationResult( AnnotatorModel.ClassificationResult[] classifications, String text, int start, int end, LangIdModel langId)599   private TextClassification createClassificationResult(
600       AnnotatorModel.ClassificationResult[] classifications,
601       String text,
602       int start,
603       int end,
604       LangIdModel langId) {
605     final String classifiedText = text.substring(start, end);
606     final TextClassification.Builder builder =
607         new TextClassification.Builder().setText(classifiedText);
608 
609     final int typeCount = classifications.length;
610     AnnotatorModel.ClassificationResult highestScoringResult =
611         typeCount > 0 ? classifications[0] : null;
612     for (int i = 0; i < typeCount; i++) {
613       builder.setEntityType(classifications[i].getCollection(), classifications[i].getScore());
614       if (classifications[i].getScore() > highestScoringResult.getScore()) {
615         highestScoringResult = classifications[i];
616       }
617     }
618 
619     boolean isPrimaryAction = true;
620     final ImmutableList<LabeledIntent> labeledIntents =
621         highestScoringResult == null
622             ? ImmutableList.of()
623             : templateIntentFactory.create(highestScoringResult.getRemoteActionTemplates());
624     final LabeledIntent.TitleChooser titleChooser =
625         (labeledIntent, resolveInfo) -> labeledIntent.titleWithoutEntity;
626 
627     ArrayList<Intent> actionIntents = new ArrayList<>();
628     for (LabeledIntent labeledIntent : labeledIntents) {
629       final LabeledIntent.Result result = labeledIntent.resolve(context, titleChooser);
630       if (result == null) {
631         continue;
632       }
633 
634       final Intent intent = result.resolvedIntent;
635       final RemoteAction action = result.remoteAction.toRemoteAction();
636       if (isPrimaryAction) {
637         // For O backwards compatibility, the first RemoteAction is also written to the
638         // legacy API fields.
639         builder.setIcon(action.getIcon().loadDrawable(context));
640         builder.setLabel(action.getTitle().toString());
641         builder.setIntent(intent);
642         builder.setOnClickListener(
643             createIntentOnClickListener(
644                 createPendingIntent(context, intent, labeledIntent.requestCode)));
645         isPrimaryAction = false;
646       }
647       builder.addAction(action);
648       actionIntents.add(intent);
649     }
650     Bundle extras = new Bundle();
651     Optional<Bundle> foreignLanguageExtra = maybeCreateExtrasForTranslate(actionIntents, langId);
652     if (foreignLanguageExtra.isPresent()) {
653       ExtrasUtils.putForeignLanguageExtra(extras, foreignLanguageExtra.get());
654     }
655     if (actionIntents.stream().anyMatch(Objects::nonNull)) {
656       ArrayList<Intent> strippedIntents =
657           actionIntents.stream()
658               .map(TextClassifierImpl::stripPackageInfoFromIntent)
659               .collect(toCollection(ArrayList::new));
660       ExtrasUtils.putActionsIntents(extras, strippedIntents);
661     }
662     ExtrasUtils.putEntities(extras, classifications);
663     builder.setExtras(extras);
664     String resultId = createAnnotatorId(text, start, end);
665     return builder.setId(resultId).build();
666   }
667 
createIntentOnClickListener(final PendingIntent intent)668   private static OnClickListener createIntentOnClickListener(final PendingIntent intent) {
669     Preconditions.checkNotNull(intent);
670     return v -> {
671       try {
672         intent.send();
673       } catch (PendingIntent.CanceledException e) {
674         TcLog.e(TAG, "Error sending PendingIntent", e);
675       }
676     };
677   }
678 
679   private static Optional<Bundle> maybeCreateExtrasForTranslate(
680       List<Intent> intents, LangIdModel langId) {
681     Optional<Intent> translateIntent =
682         FluentIterable.from(intents)
683             .filter(Objects::nonNull)
684             .filter(intent -> Intent.ACTION_TRANSLATE.equals(intent.getAction()))
685             .first();
686     if (!translateIntent.isPresent()) {
687       return Optional.absent();
688     }
689     Pair<String, Float> topLanguageWithScore = ExtrasUtils.getTopLanguage(translateIntent.get());
690     if (topLanguageWithScore == null) {
691       return Optional.absent();
692     }
693     return Optional.of(
694         ExtrasUtils.createForeignLanguageExtra(
695             topLanguageWithScore.first, topLanguageWithScore.second, langId.getVersion()));
696   }
697 
698   private ImmutableList<String> detectLanguageTags(LangIdModel langId, CharSequence text) {
699     float threshold = getLangIdThreshold(langId);
700     EntityConfidence languagesConfidence = detectLanguages(langId, text, threshold);
701     return ImmutableList.copyOf(languagesConfidence.getEntities());
702   }
703 
704   /**
705    * Detects languages for the specified text. Only returns languages with score that is higher than
706    * or equal to the specified threshold.
707    */
708   private static EntityConfidence detectLanguages(
709       LangIdModel langId, CharSequence text, float threshold) {
710     final LangIdModel.LanguageResult[] langResults = langId.detectLanguages(text.toString());
711     final Map<String, Float> languagesMap = new ArrayMap<>();
712     for (LangIdModel.LanguageResult langResult : langResults) {
713       if (langResult.getScore() >= threshold) {
714         languagesMap.put(langResult.getLanguage(), langResult.getScore());
715       }
716     }
717     return new EntityConfidence(languagesMap);
718   }
719 
720   private float getLangIdThreshold(LangIdModel langId) {
721     return settings.getLangIdThresholdOverride() >= 0
722         ? settings.getLangIdThresholdOverride()
723         : langId.getLangIdThreshold();
724   }
725 
  /**
   * Writes a description of this classifier to {@code printWriter}: a header line, the model file
   * manager's dump (indented one level), and the settings dump, separated by blank lines.
   *
   * <p>The whole dump is produced while holding {@code lock}.
   */
  void dump(IndentingPrintWriter printWriter) {
    synchronized (lock) {
      printWriter.println("TextClassifierImpl:");

      // Model file details are nested one indent level under the header.
      printWriter.increaseIndent();
      modelFileManager.dump(printWriter);
      printWriter.decreaseIndent();

      // Settings are printed back at the header's indent level, set off by blank lines.
      printWriter.println();
      settings.dump(printWriter);
      printWriter.println();
    }
  }
739 
740   /** Returns the locales string for the current resources configuration. */
741   private String getResourceLocalesString() {
742     try {
743       return context.getResources().getConfiguration().getLocales().toLanguageTags();
744     } catch (NullPointerException e) {
745 
746       // NPE is unexpected. Erring on the side of caution.
747       return LocaleList.getDefault().toLanguageTags();
748     }
749   }
750 
751   private static void checkMainThread() {
752     if (Looper.myLooper() == Looper.getMainLooper()) {
753       TcLog.e(TAG, "TCS TextClassifier called on main thread", new Exception());
754     }
755   }
756 
757   private static PendingIntent createPendingIntent(
758       final Context context, final Intent intent, int requestCode) {
759     return PendingIntent.getActivity(
760         context,
761         requestCode,
762         intent,
763         PendingIntent.FLAG_UPDATE_CURRENT | PendingIntent.FLAG_IMMUTABLE);
764   }
765 
766   @Nullable
767   private static Intent stripPackageInfoFromIntent(@Nullable Intent intent) {
768     if (intent == null) {
769       return null;
770     }
771     Intent strippedIntent = new Intent(intent);
772     strippedIntent.setPackage(null);
773     strippedIntent.setComponent(null);
774     return strippedIntent;
775   }
776 
777   private static boolean shouldEnableSearchIntent(
778       @Nullable TextClassificationContext textClassificationContext) {
779     if (textClassificationContext == null) {
780       return false;
781     }
782     String widgetType = textClassificationContext.getWidgetType();
783     // Exclude WebView because there is already a *Web Search* chip there.
784     return !(TextClassifier.WIDGET_TYPE_WEBVIEW.equals(widgetType)
785         || TextClassifier.WIDGET_TYPE_EDIT_WEBVIEW.equals(widgetType));
786   }
787 }
788