1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.textclassifier; 18 19 import static java.util.stream.Collectors.toCollection; 20 21 import android.app.PendingIntent; 22 import android.app.RemoteAction; 23 import android.content.Context; 24 import android.content.Intent; 25 import android.content.res.AssetFileDescriptor; 26 import android.icu.util.ULocale; 27 import android.os.Bundle; 28 import android.os.LocaleList; 29 import android.os.Looper; 30 import android.util.ArrayMap; 31 import android.view.View.OnClickListener; 32 import android.view.textclassifier.ConversationAction; 33 import android.view.textclassifier.ConversationActions; 34 import android.view.textclassifier.SelectionEvent; 35 import android.view.textclassifier.TextClassification; 36 import android.view.textclassifier.TextClassification.Request; 37 import android.view.textclassifier.TextClassificationContext; 38 import android.view.textclassifier.TextClassificationSessionId; 39 import android.view.textclassifier.TextClassifier; 40 import android.view.textclassifier.TextClassifierEvent; 41 import android.view.textclassifier.TextLanguage; 42 import android.view.textclassifier.TextLinks; 43 import android.view.textclassifier.TextSelection; 44 import androidx.annotation.GuardedBy; 45 import androidx.annotation.WorkerThread; 46 import androidx.collection.LruCache; 47 import androidx.core.util.Pair; 48 import com.android.textclassifier.common.ModelFile; 49 import com.android.textclassifier.common.ModelType; 50 import com.android.textclassifier.common.TextClassifierSettings; 51 import com.android.textclassifier.common.TextSelectionCompat; 52 import com.android.textclassifier.common.base.TcLog; 53 import com.android.textclassifier.common.intent.LabeledIntent; 54 import com.android.textclassifier.common.intent.TemplateIntentFactory; 55 import com.android.textclassifier.common.logging.ResultIdUtils; 56 import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo; 57 import com.android.textclassifier.common.statsd.GenerateLinksLogger; 58 import com.android.textclassifier.common.statsd.SelectionEventConverter; 59 import com.android.textclassifier.common.statsd.TextClassificationSessionIdConverter; 60 import com.android.textclassifier.common.statsd.TextClassifierEventConverter; 61 import com.android.textclassifier.common.statsd.TextClassifierEventLogger; 62 import com.android.textclassifier.utils.IndentingPrintWriter; 63 import com.google.android.textclassifier.ActionsSuggestionsModel; 64 import com.google.android.textclassifier.ActionsSuggestionsModel.ActionSuggestions; 65 import com.google.android.textclassifier.AnnotatorModel; 66 import com.google.android.textclassifier.LangIdModel; 67 import com.google.common.annotations.VisibleForTesting; 68 import com.google.common.base.Optional; 69 import com.google.common.base.Preconditions; 70 import com.google.common.collect.FluentIterable; 71 import com.google.common.collect.ImmutableList; 72 import java.io.IOException; 73 import java.time.ZoneId; 74 import java.time.ZonedDateTime; 75 import java.util.ArrayList; 76 import java.util.Collection; 77 import java.util.List; 78 import java.util.Map; 79 import java.util.Objects; 80 import javax.annotation.Nullable; 81 82 /** 83 * A text classifier that is running locally. 84 * 85 * <p>This class uses machine learning to recognize entities in text. Unless otherwise stated, 86 * methods of this class are blocking operations and should most likely not be called on the UI 87 * thread. 88 */ 89 final class TextClassifierImpl { 90 91 private static final String TAG = "TextClassifierImpl"; 92 93 private final Context context; 94 private final ModelFileManager modelFileManager; 95 private final GenerateLinksLogger generateLinksLogger; 96 97 private final Object lock = new Object(); 98 99 @GuardedBy("lock") 100 private ModelFile annotatorModelInUse; 101 102 @GuardedBy("lock") 103 private AnnotatorModel annotatorImpl; 104 105 @GuardedBy("lock") 106 private ModelFile langIdModelInUse; 107 108 @GuardedBy("lock") 109 private LangIdModel langIdImpl; 110 111 @GuardedBy("lock") 112 private ModelFile actionModelInUse; 113 114 @GuardedBy("lock") 115 private ActionsSuggestionsModel actionsImpl; 116 117 @GuardedBy("lock") 118 private final LruCache<ModelFile, AnnotatorModel> annotatorModelCache; 119 120 private final TextClassifierEventLogger textClassifierEventLogger = 121 new TextClassifierEventLogger(); 122 123 private final TextClassifierSettings settings; 124 125 private final TemplateIntentFactory templateIntentFactory; 126 TextClassifierImpl( Context context, TextClassifierSettings settings, ModelFileManager modelFileManager)127 TextClassifierImpl( 128 Context context, TextClassifierSettings settings, ModelFileManager modelFileManager) { 129 this( 130 context, settings, modelFileManager, new LruCache<>(settings.getMultiAnnotatorCacheSize())); 131 } 132 133 @VisibleForTesting TextClassifierImpl( Context context, TextClassifierSettings settings, ModelFileManager modelFileManager, LruCache<ModelFile, AnnotatorModel> annotatorModelCache)134 public TextClassifierImpl( 135 Context context, 136 TextClassifierSettings settings, 137 ModelFileManager modelFileManager, 138 LruCache<ModelFile, AnnotatorModel> annotatorModelCache) { 139 this.context = Preconditions.checkNotNull(context); 140 this.settings = Preconditions.checkNotNull(settings); 141 this.modelFileManager = Preconditions.checkNotNull(modelFileManager); 142 this.annotatorModelCache = annotatorModelCache; 143 generateLinksLogger = new GenerateLinksLogger(this.settings.getGenerateLinksLogSampleRate()); 144 templateIntentFactory = new TemplateIntentFactory(); 145 } 146 147 @WorkerThread suggestSelection( @ullable TextClassificationSessionId sessionId, @Nullable TextClassificationContext textClassificationContext, TextSelection.Request request)148 TextSelection suggestSelection( 149 @Nullable TextClassificationSessionId sessionId, 150 @Nullable TextClassificationContext textClassificationContext, 151 TextSelection.Request request) 152 throws IOException { 153 Preconditions.checkNotNull(request); 154 checkMainThread(); 155 final int rangeLength = request.getEndIndex() - request.getStartIndex(); 156 final String string = request.getText().toString(); 157 Preconditions.checkArgument(!string.isEmpty(), "input string should not be empty"); 158 Preconditions.checkArgument( 159 rangeLength <= settings.getClassifyTextMaxRangeLength(), "range is too large"); 160 final String localesString = concatenateLocales(request.getDefaultLocales()); 161 final LangIdModel langIdModel = getLangIdImpl(); 162 final String detectLanguageTags = 163 String.join(",", detectLanguageTags(langIdModel, request.getText())); 164 final ZonedDateTime refTime = ZonedDateTime.now(ZoneId.systemDefault()); 165 final LocaleList detectedLocaleList = LocaleList.forLanguageTags(detectLanguageTags); 166 final ModelFile annotatorModelInUse = 167 getAnnotatorModelFile(request.getDefaultLocales(), detectedLocaleList); 168 final AnnotatorModel annotatorImpl = loadAnnotatorModelFile(annotatorModelInUse); 169 final int[] startEnd = 170 annotatorImpl.suggestSelection( 171 string, 172 request.getStartIndex(), 173 request.getEndIndex(), 174 AnnotatorModel.SelectionOptions.builder() 175 .setLocales(localesString) 176 .setDetectedTextLanguageTags(detectLanguageTags) 177 .build()); 178 final int start = startEnd[0]; 179 final int end = startEnd[1]; 180 if (start >= end 181 || start < 0 182 || start > request.getStartIndex() 183 || end > string.length() 184 || end < request.getEndIndex()) { 185 throw new IllegalArgumentException("Got bad indices for input text. Ignoring result."); 186 } 187 final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end); 188 final boolean shouldIncludeTextClassification = 189 TextSelectionCompat.shouldIncludeTextClassification(request); 190 final AnnotatorModel.ClassificationResult[] results = 191 annotatorImpl.classifyText( 192 string, 193 start, 194 end, 195 AnnotatorModel.ClassificationOptions.builder() 196 .setReferenceTimeMsUtc(refTime.toInstant().toEpochMilli()) 197 .setReferenceTimezone(refTime.getZone().getId()) 198 .setLocales(localesString) 199 .setDetectedTextLanguageTags(detectLanguageTags) 200 .setAnnotationUsecase(AnnotatorModel.AnnotationUsecase.SMART.getValue()) 201 .setUserFamiliarLanguageTags(LocaleList.getDefault().toLanguageTags()) 202 .setEnableAddContactIntent(false) 203 .setEnableSearchIntent(shouldEnableSearchIntent(textClassificationContext)) 204 .build(), 205 // Passing null here to suppress intent generation. 206 // TODO: Use an explicit flag to suppress it. 207 shouldIncludeTextClassification ? context : null, 208 getResourceLocalesString()); 209 final int size = results.length; 210 for (int i = 0; i < size; i++) { 211 tsBuilder.setEntityType(results[i].getCollection(), results[i].getScore()); 212 } 213 final String resultId = 214 createAnnotatorId(string, request.getStartIndex(), request.getEndIndex()); 215 if (shouldIncludeTextClassification) { 216 TextClassification textClassification = 217 createClassificationResult(results, string, start, end, langIdModel); 218 TextSelectionCompat.setTextClassification(tsBuilder, textClassification); 219 } 220 return tsBuilder.setId(resultId).build(); 221 } 222 223 @WorkerThread classifyText( @ullable TextClassificationSessionId sessionId, @Nullable TextClassificationContext textClassificationContext, Request request)224 TextClassification classifyText( 225 @Nullable TextClassificationSessionId sessionId, 226 @Nullable TextClassificationContext textClassificationContext, 227 Request request) 228 throws IOException { 229 Preconditions.checkNotNull(request); 230 checkMainThread(); 231 LangIdModel langId = getLangIdImpl(); 232 List<String> detectLanguageTags = detectLanguageTags(langId, request.getText()); 233 final int rangeLength = request.getEndIndex() - request.getStartIndex(); 234 final String string = request.getText().toString(); 235 Preconditions.checkArgument(!string.isEmpty(), "input string should not be empty"); 236 Preconditions.checkArgument( 237 rangeLength <= settings.getClassifyTextMaxRangeLength(), "range is too large"); 238 239 final String localesString = concatenateLocales(request.getDefaultLocales()); 240 final ZonedDateTime refTime = 241 request.getReferenceTime() != null 242 ? request.getReferenceTime() 243 : ZonedDateTime.now(ZoneId.systemDefault()); 244 final LocaleList detectedLocaleList = 245 LocaleList.forLanguageTags(String.join(",", detectLanguageTags)); 246 final AnnotatorModel.ClassificationResult[] results = 247 getAnnotatorImpl(request.getDefaultLocales(), detectedLocaleList) 248 .classifyText( 249 string, 250 request.getStartIndex(), 251 request.getEndIndex(), 252 AnnotatorModel.ClassificationOptions.builder() 253 .setReferenceTimeMsUtc(refTime.toInstant().toEpochMilli()) 254 .setReferenceTimezone(refTime.getZone().getId()) 255 .setLocales(localesString) 256 .setDetectedTextLanguageTags(String.join(",", detectLanguageTags)) 257 .setAnnotationUsecase(AnnotatorModel.AnnotationUsecase.SMART.getValue()) 258 .setUserFamiliarLanguageTags(LocaleList.getDefault().toLanguageTags()) 259 .setEnableAddContactIntent(false) 260 .setEnableSearchIntent(shouldEnableSearchIntent(textClassificationContext)) 261 .build(), 262 context, 263 getResourceLocalesString()); 264 if (results.length == 0) { 265 throw new IllegalStateException("Empty text classification. Something went wrong."); 266 } 267 return createClassificationResult( 268 results, string, request.getStartIndex(), request.getEndIndex(), langId); 269 } 270 271 @WorkerThread generateLinks( @ullable TextClassificationSessionId sessionId, @Nullable TextClassificationContext textClassificationContext, TextLinks.Request request)272 TextLinks generateLinks( 273 @Nullable TextClassificationSessionId sessionId, 274 @Nullable TextClassificationContext textClassificationContext, 275 TextLinks.Request request) 276 throws IOException { 277 Preconditions.checkNotNull(request); 278 Preconditions.checkArgument( 279 request.getText().length() <= getMaxGenerateLinksTextLength(), 280 "text.length() cannot be greater than %s", 281 getMaxGenerateLinksTextLength()); 282 checkMainThread(); 283 284 final String textString = request.getText().toString(); 285 final TextLinks.Builder builder = new TextLinks.Builder(textString); 286 287 final long startTimeMs = System.currentTimeMillis(); 288 final ZonedDateTime refTime = ZonedDateTime.now(ZoneId.systemDefault()); 289 final Collection<String> entitiesToIdentify = 290 request.getEntityConfig() != null 291 ? request 292 .getEntityConfig() 293 .resolveEntityListModifications( 294 getEntitiesForHints(request.getEntityConfig().getHints())) 295 : settings.getEntityListDefault(); 296 final String localesString = concatenateLocales(request.getDefaultLocales()); 297 LangIdModel langId = getLangIdImpl(); 298 ImmutableList<String> detectLanguageTags = detectLanguageTags(langId, request.getText()); 299 final LocaleList detectedLocaleList = 300 LocaleList.forLanguageTags(String.join(",", detectLanguageTags)); 301 final AnnotatorModel annotatorImpl = 302 getAnnotatorImpl(request.getDefaultLocales(), detectedLocaleList); 303 final boolean isSerializedEntityDataEnabled = 304 ExtrasUtils.isSerializedEntityDataEnabled(request); 305 final AnnotatorModel.AnnotatedSpan[] annotations = 306 annotatorImpl.annotate( 307 textString, 308 AnnotatorModel.AnnotationOptions.builder() 309 .setReferenceTimeMsUtc(refTime.toInstant().toEpochMilli()) 310 .setReferenceTimezone(refTime.getZone().getId()) 311 .setLocales(localesString) 312 .setDetectedTextLanguageTags(String.join(",", detectLanguageTags)) 313 .setEntityTypes(entitiesToIdentify) 314 .setAnnotationUsecase(AnnotatorModel.AnnotationUsecase.SMART.getValue()) 315 .setIsSerializedEntityDataEnabled(isSerializedEntityDataEnabled) 316 .build()); 317 for (AnnotatorModel.AnnotatedSpan span : annotations) { 318 final AnnotatorModel.ClassificationResult[] results = span.getClassification(); 319 if (results.length == 0 || !entitiesToIdentify.contains(results[0].getCollection())) { 320 continue; 321 } 322 final Map<String, Float> entityScores = new ArrayMap<>(); 323 for (AnnotatorModel.ClassificationResult result : results) { 324 entityScores.put(result.getCollection(), result.getScore()); 325 } 326 Bundle extras = new Bundle(); 327 if (isSerializedEntityDataEnabled) { 328 ExtrasUtils.putEntities(extras, results); 329 } 330 builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores, extras); 331 } 332 final TextLinks links = builder.build(); 333 final long endTimeMs = System.currentTimeMillis(); 334 final String callingPackageName = 335 request.getCallingPackageName() == null 336 ? context.getPackageName() // local (in process) TC. 337 : request.getCallingPackageName(); 338 Optional<ModelInfo> annotatorModelInfo; 339 Optional<ModelInfo> langIdModelInfo; 340 synchronized (lock) { 341 annotatorModelInfo = 342 Optional.fromNullable(annotatorModelInUse).transform(ModelFile::toModelInfo); 343 langIdModelInfo = Optional.fromNullable(langIdModelInUse).transform(ModelFile::toModelInfo); 344 } 345 generateLinksLogger.logGenerateLinks( 346 sessionId, 347 textClassificationContext, 348 request.getText(), 349 links, 350 callingPackageName, 351 endTimeMs - startTimeMs, 352 annotatorModelInfo, 353 langIdModelInfo); 354 return links; 355 } 356 getMaxGenerateLinksTextLength()357 int getMaxGenerateLinksTextLength() { 358 return settings.getGenerateLinksMaxTextLength(); 359 } 360 getEntitiesForHints(Collection<String> hints)361 private Collection<String> getEntitiesForHints(Collection<String> hints) { 362 final boolean editable = hints.contains(TextClassifier.HINT_TEXT_IS_EDITABLE); 363 final boolean notEditable = hints.contains(TextClassifier.HINT_TEXT_IS_NOT_EDITABLE); 364 365 // Use the default if there is no hint, or conflicting ones. 366 final boolean useDefault = editable == notEditable; 367 if (useDefault) { 368 return settings.getEntityListDefault(); 369 } else if (editable) { 370 return settings.getEntityListEditable(); 371 } else { // notEditable 372 return settings.getEntityListNotEditable(); 373 } 374 } 375 onSelectionEvent(@ullable TextClassificationSessionId sessionId, SelectionEvent event)376 void onSelectionEvent(@Nullable TextClassificationSessionId sessionId, SelectionEvent event) { 377 TextClassifierEvent textClassifierEvent = SelectionEventConverter.toTextClassifierEvent(event); 378 if (textClassifierEvent == null) { 379 return; 380 } 381 onTextClassifierEvent(event.getSessionId(), textClassifierEvent); 382 } 383 onTextClassifierEvent( @ullable TextClassificationSessionId sessionId, TextClassifierEvent event)384 void onTextClassifierEvent( 385 @Nullable TextClassificationSessionId sessionId, TextClassifierEvent event) { 386 textClassifierEventLogger.writeEvent( 387 TextClassificationSessionIdConverter.fromPlatform(sessionId), 388 TextClassifierEventConverter.fromPlatform(event)); 389 } 390 detectLanguage( @ullable TextClassificationSessionId sessionId, @Nullable TextClassificationContext textClassificationContext, TextLanguage.Request request)391 TextLanguage detectLanguage( 392 @Nullable TextClassificationSessionId sessionId, 393 @Nullable TextClassificationContext textClassificationContext, 394 TextLanguage.Request request) 395 throws IOException { 396 Preconditions.checkNotNull(request); 397 checkMainThread(); 398 final TextLanguage.Builder builder = new TextLanguage.Builder(); 399 LangIdModel langIdImpl = getLangIdImpl(); 400 final LangIdModel.LanguageResult[] langResults = 401 langIdImpl.detectLanguages(request.getText().toString()); 402 for (LangIdModel.LanguageResult langResult : langResults) { 403 builder.putLocale(ULocale.forLanguageTag(langResult.getLanguage()), langResult.getScore()); 404 } 405 return builder.build(); 406 } 407 suggestConversationActions( @ullable TextClassificationSessionId sessionId, @Nullable TextClassificationContext textClassificationContext, ConversationActions.Request request)408 ConversationActions suggestConversationActions( 409 @Nullable TextClassificationSessionId sessionId, 410 @Nullable TextClassificationContext textClassificationContext, 411 ConversationActions.Request request) 412 throws IOException { 413 Preconditions.checkNotNull(request); 414 checkMainThread(); 415 ActionsSuggestionsModel actionsImpl = getActionsImpl(); 416 LangIdModel langId = getLangIdImpl(); 417 ActionsSuggestionsModel.ConversationMessage[] nativeMessages = 418 ActionsSuggestionsHelper.toNativeMessages( 419 request.getConversation(), text -> detectLanguageTags(langId, text)); 420 if (nativeMessages.length == 0) { 421 return new ConversationActions(ImmutableList.of(), /* id= */ null); 422 } 423 ActionsSuggestionsModel.Conversation nativeConversation = 424 new ActionsSuggestionsModel.Conversation(nativeMessages); 425 426 ActionSuggestions nativeSuggestions = 427 actionsImpl.suggestActionsWithIntents( 428 nativeConversation, 429 null, 430 context, 431 getResourceLocalesString(), 432 getAnnotatorImpl(LocaleList.getDefault(), /* detectedLocaleList= */ null)); 433 return createConversationActionResult(request, nativeSuggestions); 434 } 435 436 /** 437 * Returns the {@link ConversationAction} result, with a non-null extras. 438 * 439 * <p>Whenever the RemoteAction is non-null, you can expect its corresponding intent with a 440 * non-null component name is in the extras. 441 */ createConversationActionResult( ConversationActions.Request request, ActionSuggestions nativeSuggestions)442 private ConversationActions createConversationActionResult( 443 ConversationActions.Request request, ActionSuggestions nativeSuggestions) { 444 Collection<String> expectedTypes = resolveActionTypesFromRequest(request); 445 List<ConversationAction> conversationActions = new ArrayList<>(); 446 for (ActionsSuggestionsModel.ActionSuggestion nativeSuggestion : 447 nativeSuggestions.actionSuggestions) { 448 String actionType = nativeSuggestion.getActionType(); 449 if (!expectedTypes.contains(actionType)) { 450 continue; 451 } 452 LabeledIntent.Result labeledIntentResult = 453 ActionsSuggestionsHelper.createLabeledIntentResult( 454 context, templateIntentFactory, nativeSuggestion); 455 RemoteAction remoteAction = null; 456 Bundle extras = new Bundle(); 457 if (labeledIntentResult != null) { 458 remoteAction = labeledIntentResult.remoteAction.toRemoteAction(); 459 ExtrasUtils.putActionIntent( 460 extras, stripPackageInfoFromIntent(labeledIntentResult.resolvedIntent)); 461 } 462 ExtrasUtils.putSerializedEntityData(extras, nativeSuggestion.getSerializedEntityData()); 463 ExtrasUtils.putEntitiesExtras( 464 extras, TemplateIntentFactory.nameVariantsToBundle(nativeSuggestion.getEntityData())); 465 conversationActions.add( 466 new ConversationAction.Builder(actionType) 467 .setConfidenceScore(nativeSuggestion.getScore()) 468 .setTextReply(nativeSuggestion.getResponseText()) 469 .setAction(remoteAction) 470 .setExtras(extras) 471 .build()); 472 } 473 conversationActions = ActionsSuggestionsHelper.removeActionsWithDuplicates(conversationActions); 474 if (request.getMaxSuggestions() >= 0 475 && conversationActions.size() > request.getMaxSuggestions()) { 476 conversationActions = conversationActions.subList(0, request.getMaxSuggestions()); 477 } 478 synchronized (lock) { 479 String resultId = 480 ActionsSuggestionsHelper.createResultId( 481 context, 482 request.getConversation(), 483 Optional.fromNullable(actionModelInUse), 484 Optional.fromNullable(annotatorModelInUse), 485 Optional.fromNullable(langIdModelInUse)); 486 return new ConversationActions(conversationActions, resultId); 487 } 488 } 489 resolveActionTypesFromRequest(ConversationActions.Request request)490 private Collection<String> resolveActionTypesFromRequest(ConversationActions.Request request) { 491 List<String> defaultActionTypes = 492 request.getHints().contains(ConversationActions.Request.HINT_FOR_NOTIFICATION) 493 ? settings.getNotificationConversationActionTypes() 494 : settings.getInAppConversationActionTypes(); 495 return request.getTypeConfig().resolveEntityListModifications(defaultActionTypes); 496 } 497 getAnnotatorModelFile( LocaleList requestLocaleList, LocaleList detectedLocaleList)498 private ModelFile getAnnotatorModelFile( 499 LocaleList requestLocaleList, LocaleList detectedLocaleList) throws IOException { 500 final ModelFile bestModel = 501 modelFileManager.findBestModelFile( 502 ModelType.ANNOTATOR, requestLocaleList, detectedLocaleList); 503 if (bestModel == null) { 504 throw new IllegalStateException("Failed to find the best annotator model"); 505 } 506 return bestModel; 507 } 508 loadAnnotatorModelFile(ModelFile annotatorModelFile)509 private AnnotatorModel loadAnnotatorModelFile(ModelFile annotatorModelFile) throws IOException { 510 synchronized (lock) { 511 if (settings.getMultiAnnotatorCacheEnabled() 512 && !Objects.equals(annotatorModelInUse, annotatorModelFile)) { 513 TcLog.v(TAG, "Attempting to reload cached annotator model...."); 514 annotatorImpl = annotatorModelCache.get(annotatorModelFile); 515 if (annotatorImpl != null) { 516 annotatorModelInUse = annotatorModelFile; 517 TcLog.v(TAG, "Successfully reloaded cached annotator model: " + annotatorModelFile); 518 } 519 } 520 if (annotatorImpl == null || !Objects.equals(annotatorModelInUse, annotatorModelFile)) { 521 TcLog.d(TAG, "Loading " + annotatorModelFile); 522 // The current annotator model may be still used by another thread / model. 523 // Do not call close() here, and let the GC to clean it up when no one else 524 // is using it. 525 try (AssetFileDescriptor afd = annotatorModelFile.open(context.getAssets())) { 526 annotatorImpl = new AnnotatorModel(afd); 527 annotatorImpl.setLangIdModel(getLangIdImpl()); 528 annotatorModelInUse = annotatorModelFile; 529 if (settings.getMultiAnnotatorCacheEnabled()) { 530 annotatorModelCache.put(annotatorModelFile, annotatorImpl); 531 } 532 } 533 } 534 return annotatorImpl; 535 } 536 } 537 getAnnotatorImpl( LocaleList requestLocaleList, LocaleList detectedLocaleList)538 private AnnotatorModel getAnnotatorImpl( 539 LocaleList requestLocaleList, LocaleList detectedLocaleList) throws IOException { 540 ModelFile annotatorModelFile = getAnnotatorModelFile(requestLocaleList, detectedLocaleList); 541 return loadAnnotatorModelFile(annotatorModelFile); 542 } 543 getLangIdImpl()544 private LangIdModel getLangIdImpl() throws IOException { 545 synchronized (lock) { 546 final ModelFile bestModel = 547 modelFileManager.findBestModelFile( 548 ModelType.LANG_ID, /* localePreferences= */ null, /* detectedLocales= */ null); 549 if (bestModel == null) { 550 throw new IllegalStateException("Failed to find the best LangID model."); 551 } 552 if (langIdImpl == null || !Objects.equals(langIdModelInUse, bestModel)) { 553 TcLog.d(TAG, "Loading " + bestModel); 554 try (AssetFileDescriptor afd = bestModel.open(context.getAssets())) { 555 langIdImpl = new LangIdModel(afd); 556 langIdModelInUse = bestModel; 557 } 558 } 559 return langIdImpl; 560 } 561 } 562 getActionsImpl()563 private ActionsSuggestionsModel getActionsImpl() throws IOException { 564 synchronized (lock) { 565 // TODO: Use LangID to determine the locale we should use here? 566 final ModelFile bestModel = 567 modelFileManager.findBestModelFile( 568 ModelType.ACTIONS_SUGGESTIONS, LocaleList.getDefault(), /* detectedLocales= */ null); 569 if (bestModel == null) { 570 throw new IllegalStateException("Failed to find the best actions model"); 571 } 572 if (actionsImpl == null || !Objects.equals(actionModelInUse, bestModel)) { 573 TcLog.d(TAG, "Loading " + bestModel); 574 try (AssetFileDescriptor afd = bestModel.open(context.getAssets())) { 575 actionsImpl = new ActionsSuggestionsModel(afd); 576 actionModelInUse = bestModel; 577 } 578 } 579 return actionsImpl; 580 } 581 } 582 createAnnotatorId(String text, int start, int end)583 private String createAnnotatorId(String text, int start, int end) { 584 synchronized (lock) { 585 return ResultIdUtils.createId( 586 context, 587 text, 588 start, 589 end, 590 ModelFile.toModelInfos( 591 Optional.fromNullable(annotatorModelInUse), Optional.fromNullable(langIdModelInUse))); 592 } 593 } 594 concatenateLocales(@ullable LocaleList locales)595 private static String concatenateLocales(@Nullable LocaleList locales) { 596 return (locales == null) ? "" : locales.toLanguageTags(); 597 } 598 createClassificationResult( AnnotatorModel.ClassificationResult[] classifications, String text, int start, int end, LangIdModel langId)599 private TextClassification createClassificationResult( 600 AnnotatorModel.ClassificationResult[] classifications, 601 String text, 602 int start, 603 int end, 604 LangIdModel langId) { 605 final String classifiedText = text.substring(start, end); 606 final TextClassification.Builder builder = 607 new TextClassification.Builder().setText(classifiedText); 608 609 final int typeCount = classifications.length; 610 AnnotatorModel.ClassificationResult highestScoringResult = 611 typeCount > 0 ? classifications[0] : null; 612 for (int i = 0; i < typeCount; i++) { 613 builder.setEntityType(classifications[i].getCollection(), classifications[i].getScore()); 614 if (classifications[i].getScore() > highestScoringResult.getScore()) { 615 highestScoringResult = classifications[i]; 616 } 617 } 618 619 boolean isPrimaryAction = true; 620 final ImmutableList<LabeledIntent> labeledIntents = 621 highestScoringResult == null 622 ? ImmutableList.of() 623 : templateIntentFactory.create(highestScoringResult.getRemoteActionTemplates()); 624 final LabeledIntent.TitleChooser titleChooser = 625 (labeledIntent, resolveInfo) -> labeledIntent.titleWithoutEntity; 626 627 ArrayList<Intent> actionIntents = new ArrayList<>(); 628 for (LabeledIntent labeledIntent : labeledIntents) { 629 final LabeledIntent.Result result = labeledIntent.resolve(context, titleChooser); 630 if (result == null) { 631 continue; 632 } 633 634 final Intent intent = result.resolvedIntent; 635 final RemoteAction action = result.remoteAction.toRemoteAction(); 636 if (isPrimaryAction) { 637 // For O backwards compatibility, the first RemoteAction is also written to the 638 // legacy API fields. 639 builder.setIcon(action.getIcon().loadDrawable(context)); 640 builder.setLabel(action.getTitle().toString()); 641 builder.setIntent(intent); 642 builder.setOnClickListener( 643 createIntentOnClickListener( 644 createPendingIntent(context, intent, labeledIntent.requestCode))); 645 isPrimaryAction = false; 646 } 647 builder.addAction(action); 648 actionIntents.add(intent); 649 } 650 Bundle extras = new Bundle(); 651 Optional<Bundle> foreignLanguageExtra = maybeCreateExtrasForTranslate(actionIntents, langId); 652 if (foreignLanguageExtra.isPresent()) { 653 ExtrasUtils.putForeignLanguageExtra(extras, foreignLanguageExtra.get()); 654 } 655 if (actionIntents.stream().anyMatch(Objects::nonNull)) { 656 ArrayList<Intent> strippedIntents = 657 actionIntents.stream() 658 .map(TextClassifierImpl::stripPackageInfoFromIntent) 659 .collect(toCollection(ArrayList::new)); 660 ExtrasUtils.putActionsIntents(extras, strippedIntents); 661 } 662 ExtrasUtils.putEntities(extras, classifications); 663 builder.setExtras(extras); 664 String resultId = createAnnotatorId(text, start, end); 665 return builder.setId(resultId).build(); 666 } 667 createIntentOnClickListener(final PendingIntent intent)668 private static OnClickListener createIntentOnClickListener(final PendingIntent intent) { 669 Preconditions.checkNotNull(intent); 670 return v -> { 671 try { 672 intent.send(); 673 } catch (PendingIntent.CanceledException e) { 674 TcLog.e(TAG, "Error sending PendingIntent", e); 675 } 676 }; 677 } 678 679 private static Optional<Bundle> maybeCreateExtrasForTranslate( 680 List<Intent> intents, LangIdModel langId) { 681 Optional<Intent> translateIntent = 682 FluentIterable.from(intents) 683 .filter(Objects::nonNull) 684 .filter(intent -> Intent.ACTION_TRANSLATE.equals(intent.getAction())) 685 .first(); 686 if (!translateIntent.isPresent()) { 687 return Optional.absent(); 688 } 689 Pair<String, Float> topLanguageWithScore = ExtrasUtils.getTopLanguage(translateIntent.get()); 690 if (topLanguageWithScore == null) { 691 return Optional.absent(); 692 } 693 return Optional.of( 694 ExtrasUtils.createForeignLanguageExtra( 695 topLanguageWithScore.first, topLanguageWithScore.second, langId.getVersion())); 696 } 697 698 private ImmutableList<String> detectLanguageTags(LangIdModel langId, CharSequence text) { 699 float threshold = getLangIdThreshold(langId); 700 EntityConfidence languagesConfidence = detectLanguages(langId, text, threshold); 701 return ImmutableList.copyOf(languagesConfidence.getEntities()); 702 } 703 704 /** 705 * Detects languages for the specified text. Only returns languages with score that is higher than 706 * or equal to the specified threshold. 707 */ 708 private static EntityConfidence detectLanguages( 709 LangIdModel langId, CharSequence text, float threshold) { 710 final LangIdModel.LanguageResult[] langResults = langId.detectLanguages(text.toString()); 711 final Map<String, Float> languagesMap = new ArrayMap<>(); 712 for (LangIdModel.LanguageResult langResult : langResults) { 713 if (langResult.getScore() >= threshold) { 714 languagesMap.put(langResult.getLanguage(), langResult.getScore()); 715 } 716 } 717 return new EntityConfidence(languagesMap); 718 } 719 720 private float getLangIdThreshold(LangIdModel langId) { 721 return settings.getLangIdThresholdOverride() >= 0 722 ? settings.getLangIdThresholdOverride() 723 : langId.getLangIdThreshold(); 724 } 725 726 void dump(IndentingPrintWriter printWriter) { 727 synchronized (lock) { 728 printWriter.println("TextClassifierImpl:"); 729 730 printWriter.increaseIndent(); 731 modelFileManager.dump(printWriter); 732 printWriter.decreaseIndent(); 733 734 printWriter.println(); 735 settings.dump(printWriter); 736 printWriter.println(); 737 } 738 } 739 740 /** Returns the locales string for the current resources configuration. */ 741 private String getResourceLocalesString() { 742 try { 743 return context.getResources().getConfiguration().getLocales().toLanguageTags(); 744 } catch (NullPointerException e) { 745 746 // NPE is unexpected. Erring on the side of caution. 747 return LocaleList.getDefault().toLanguageTags(); 748 } 749 } 750 751 private static void checkMainThread() { 752 if (Looper.myLooper() == Looper.getMainLooper()) { 753 TcLog.e(TAG, "TCS TextClassifier called on main thread", new Exception()); 754 } 755 } 756 757 private static PendingIntent createPendingIntent( 758 final Context context, final Intent intent, int requestCode) { 759 return PendingIntent.getActivity( 760 context, 761 requestCode, 762 intent, 763 PendingIntent.FLAG_UPDATE_CURRENT | PendingIntent.FLAG_IMMUTABLE); 764 } 765 766 @Nullable 767 private static Intent stripPackageInfoFromIntent(@Nullable Intent intent) { 768 if (intent == null) { 769 return null; 770 } 771 Intent strippedIntent = new Intent(intent); 772 strippedIntent.setPackage(null); 773 strippedIntent.setComponent(null); 774 return strippedIntent; 775 } 776 777 private static boolean shouldEnableSearchIntent( 778 @Nullable TextClassificationContext textClassificationContext) { 779 if (textClassificationContext == null) { 780 return false; 781 } 782 String widgetType = textClassificationContext.getWidgetType(); 783 // Exclude WebView because there is already a *Web Search* chip there. 784 return !(TextClassifier.WIDGET_TYPE_WEBVIEW.equals(widgetType) 785 || TextClassifier.WIDGET_TYPE_EDIT_WEBVIEW.equals(widgetType)); 786 } 787 } 788