1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.textclassifier; 18 19 import static com.google.common.truth.Truth.assertThat; 20 import static org.hamcrest.CoreMatchers.not; 21 import static org.junit.Assert.assertArrayEquals; 22 import static org.junit.Assert.assertEquals; 23 import static org.junit.Assert.assertThat; 24 import static org.junit.Assert.assertTrue; 25 import static org.mockito.Mockito.any; 26 import static org.mockito.Mockito.eq; 27 import static org.mockito.Mockito.verify; 28 import static org.mockito.Mockito.when; 29 import static org.testng.Assert.expectThrows; 30 31 import android.app.RemoteAction; 32 import android.content.Context; 33 import android.content.Intent; 34 import android.net.Uri; 35 import android.os.Bundle; 36 import android.os.LocaleList; 37 import android.text.Spannable; 38 import android.text.SpannableString; 39 import android.view.textclassifier.ConversationAction; 40 import android.view.textclassifier.ConversationActions; 41 import android.view.textclassifier.TextClassification; 42 import android.view.textclassifier.TextClassifier; 43 import android.view.textclassifier.TextLanguage; 44 import android.view.textclassifier.TextLinks; 45 import android.view.textclassifier.TextSelection; 46 import androidx.collection.LruCache; 47 import androidx.test.ext.junit.runners.AndroidJUnit4; 48 import androidx.test.filters.SdkSuppress; 49 import androidx.test.filters.SmallTest; 50 import com.android.textclassifier.common.ModelFile; 51 import com.android.textclassifier.common.ModelType; 52 import com.android.textclassifier.common.TextClassifierSettings; 53 import com.android.textclassifier.testing.FakeContextBuilder; 54 import com.android.textclassifier.testing.TestingDeviceConfig; 55 import com.google.android.textclassifier.AnnotatorModel; 56 import com.google.common.collect.ImmutableList; 57 import java.io.IOException; 58 import java.util.ArrayList; 59 import java.util.Arrays; 60 import java.util.Collections; 61 import java.util.List; 62 import org.hamcrest.BaseMatcher; 63 import org.hamcrest.Description; 64 import org.hamcrest.Matcher; 65 import org.junit.Before; 66 import org.junit.Test; 67 import org.junit.runner.RunWith; 68 import org.mockito.Mock; 69 import org.mockito.MockitoAnnotations; 70 71 @SmallTest 72 @RunWith(AndroidJUnit4.class) 73 public class TextClassifierImplTest { 74 75 private static final String TYPE_COPY = "copy"; 76 private static final LocaleList LOCALES = LocaleList.forLanguageTags("en-US"); 77 private static final String NO_TYPE = null; 78 79 @Mock private ModelFileManager modelFileManager; 80 81 private Context context; 82 private TestingDeviceConfig deviceConfig; 83 private TextClassifierSettings settings; 84 private LruCache<ModelFile, AnnotatorModel> annotatorModelCache; 85 private TextClassifierImpl classifier; 86 87 @Before setup()88 public void setup() throws IOException { 89 MockitoAnnotations.initMocks(this); 90 this.context = 91 new FakeContextBuilder() 92 .setAllIntentComponent(FakeContextBuilder.DEFAULT_COMPONENT) 93 .setAppLabel(FakeContextBuilder.DEFAULT_COMPONENT.getPackageName(), "Test app") 94 .build(); 95 this.deviceConfig = new TestingDeviceConfig(); 96 this.settings = new TextClassifierSettings(deviceConfig, /* isWear= */ false); 97 this.annotatorModelCache = new LruCache<>(2); 98 this.classifier = 99 new TextClassifierImpl(context, settings, modelFileManager, annotatorModelCache); 100 101 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 102 .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped()); 103 when(modelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any())) 104 .thenReturn(TestDataUtils.getLangIdModelFileWrapped()); 105 when(modelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any())) 106 .thenReturn(TestDataUtils.getTestActionsModelFileWrapped()); 107 } 108 109 @Test testSuggestSelection()110 public void testSuggestSelection() throws IOException { 111 String text = "Contact me at droid@android.com"; 112 String selected = "droid"; 113 String suggested = "droid@android.com"; 114 int startIndex = text.indexOf(selected); 115 int endIndex = startIndex + selected.length(); 116 int smartStartIndex = text.indexOf(suggested); 117 int smartEndIndex = smartStartIndex + suggested.length(); 118 TextSelection.Request request = 119 new TextSelection.Request.Builder(text, startIndex, endIndex).build(); 120 121 TextSelection selection = classifier.suggestSelection(null, null, request); 122 assertThat( 123 selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_EMAIL)); 124 } 125 126 @Test testSuggestSelection_localePreferenceIsPassedToModelFileManager()127 public void testSuggestSelection_localePreferenceIsPassedToModelFileManager() throws IOException { 128 String text = "Contact me at droid@android.com"; 129 String selected = "droid"; 130 String suggested = "droid@android.com"; 131 int startIndex = text.indexOf(selected); 132 int endIndex = startIndex + selected.length(); 133 int smartStartIndex = text.indexOf(suggested); 134 int smartEndIndex = smartStartIndex + suggested.length(); 135 TextSelection.Request request = 136 new TextSelection.Request.Builder(text, startIndex, endIndex) 137 .setDefaultLocales(LOCALES) 138 .build(); 139 140 classifier.suggestSelection(null, null, request); 141 verify(modelFileManager).findBestModelFile(eq(ModelType.ANNOTATOR), eq(LOCALES), any()); 142 } 143 144 @Test testSuggestSelection_url()145 public void testSuggestSelection_url() throws IOException { 146 String text = "Visit http://www.android.com for more information"; 147 String selected = "http"; 148 String suggested = "http://www.android.com"; 149 int startIndex = text.indexOf(selected); 150 int endIndex = startIndex + selected.length(); 151 int smartStartIndex = text.indexOf(suggested); 152 int smartEndIndex = smartStartIndex + suggested.length(); 153 TextSelection.Request request = 154 new TextSelection.Request.Builder(text, startIndex, endIndex).build(); 155 156 TextSelection selection = classifier.suggestSelection(null, null, request); 157 assertThat(selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_URL)); 158 } 159 160 @Test testSmartSelection_withEmoji()161 public void testSmartSelection_withEmoji() throws IOException { 162 String text = "\uD83D\uDE02 Hello."; 163 String selected = "Hello"; 164 int startIndex = text.indexOf(selected); 165 int endIndex = startIndex + selected.length(); 166 TextSelection.Request request = 167 new TextSelection.Request.Builder(text, startIndex, endIndex).build(); 168 169 TextSelection selection = classifier.suggestSelection(null, null, request); 170 assertThat(selection, isTextSelection(startIndex, endIndex, NO_TYPE)); 171 } 172 173 @SdkSuppress(minSdkVersion = 31, codeName = "S") 174 @Test testSuggestSelection_includeTextClassification()175 public void testSuggestSelection_includeTextClassification() throws IOException { 176 String text = "Visit http://www.android.com for more information"; 177 String suggested = "http://www.android.com"; 178 int startIndex = text.indexOf(suggested); 179 TextSelection.Request request = 180 new TextSelection.Request.Builder(text, startIndex, /* endIndex= */ startIndex + 1) 181 .setIncludeTextClassification(true) 182 .build(); 183 184 TextSelection selection = classifier.suggestSelection(null, null, request); 185 186 assertThat( 187 selection.getTextClassification(), 188 isTextClassification(suggested, TextClassifier.TYPE_URL)); 189 assertThat(selection.getTextClassification(), containsIntentWithAction(Intent.ACTION_VIEW)); 190 } 191 192 @SdkSuppress(minSdkVersion = 31, codeName = "S") 193 @Test testSuggestSelection_notIncludeTextClassification()194 public void testSuggestSelection_notIncludeTextClassification() throws IOException { 195 String text = "Visit http://www.android.com for more information"; 196 TextSelection.Request request = 197 new TextSelection.Request.Builder(text, /* startIndex= */ 0, /* endIndex= */ 4) 198 .setIncludeTextClassification(false) 199 .build(); 200 201 TextSelection selection = classifier.suggestSelection(null, null, request); 202 203 assertThat(selection.getTextClassification()).isNull(); 204 } 205 206 @Test testClassifyText()207 public void testClassifyText() throws IOException { 208 String text = "Contact me at droid@android.com"; 209 String classifiedText = "droid@android.com"; 210 int startIndex = text.indexOf(classifiedText); 211 int endIndex = startIndex + classifiedText.length(); 212 TextClassification.Request request = 213 new TextClassification.Request.Builder(text, startIndex, endIndex).build(); 214 215 TextClassification classification = 216 classifier.classifyText(/* sessionId= */ null, null, request); 217 assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_EMAIL)); 218 } 219 220 @Test testClassifyText_url()221 public void testClassifyText_url() throws IOException { 222 String text = "Visit www.android.com for more information"; 223 String classifiedText = "www.android.com"; 224 int startIndex = text.indexOf(classifiedText); 225 int endIndex = startIndex + classifiedText.length(); 226 TextClassification.Request request = 227 new TextClassification.Request.Builder(text, startIndex, endIndex).build(); 228 229 TextClassification classification = classifier.classifyText(null, null, request); 230 assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL)); 231 assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW)); 232 } 233 234 @Test testClassifyText_address()235 public void testClassifyText_address() throws IOException { 236 String text = "Brandschenkestrasse 110, Zürich, Switzerland"; 237 TextClassification.Request request = 238 new TextClassification.Request.Builder(text, 0, text.length()).build(); 239 240 TextClassification classification = classifier.classifyText(null, null, request); 241 assertThat(classification, isTextClassification(text, TextClassifier.TYPE_ADDRESS)); 242 } 243 244 @Test testClassifyText_url_inCaps()245 public void testClassifyText_url_inCaps() throws IOException { 246 String text = "Visit HTTP://ANDROID.COM for more information"; 247 String classifiedText = "HTTP://ANDROID.COM"; 248 int startIndex = text.indexOf(classifiedText); 249 int endIndex = startIndex + classifiedText.length(); 250 TextClassification.Request request = 251 new TextClassification.Request.Builder(text, startIndex, endIndex).build(); 252 253 TextClassification classification = classifier.classifyText(null, null, request); 254 assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL)); 255 assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW)); 256 } 257 258 @Test testClassifyText_date()259 public void testClassifyText_date() throws IOException { 260 String text = "Let's meet on January 9, 2018."; 261 String classifiedText = "January 9, 2018"; 262 int startIndex = text.indexOf(classifiedText); 263 int endIndex = startIndex + classifiedText.length(); 264 TextClassification.Request request = 265 new TextClassification.Request.Builder(text, startIndex, endIndex).build(); 266 267 TextClassification classification = classifier.classifyText(null, null, request); 268 assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE)); 269 Bundle extras = classification.getExtras(); 270 List<Bundle> entities = ExtrasUtils.getEntities(extras); 271 assertThat(entities).hasSize(1); 272 assertThat(ExtrasUtils.getEntityType(entities.get(0))).isEqualTo(TextClassifier.TYPE_DATE); 273 ArrayList<Intent> actionsIntents = ExtrasUtils.getActionsIntents(classification); 274 actionsIntents.forEach(TextClassifierImplTest::assertNoPackageInfoInExtras); 275 } 276 277 @Test testClassifyText_datetime()278 public void testClassifyText_datetime() throws IOException { 279 String text = "Let's meet 2018/01/01 10:30:20."; 280 String classifiedText = "2018/01/01 10:30:20"; 281 int startIndex = text.indexOf(classifiedText); 282 int endIndex = startIndex + classifiedText.length(); 283 TextClassification.Request request = 284 new TextClassification.Request.Builder(text, startIndex, endIndex).build(); 285 286 TextClassification classification = classifier.classifyText(null, null, request); 287 assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME)); 288 } 289 290 @Test testClassifyText_foreignText()291 public void testClassifyText_foreignText() throws IOException { 292 LocaleList originalLocales = LocaleList.getDefault(); 293 LocaleList.setDefault(LocaleList.forLanguageTags("en")); 294 String japaneseText = "これは日本語のテキストです"; 295 TextClassification.Request request = 296 new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length()).build(); 297 298 TextClassification classification = classifier.classifyText(null, null, request); 299 RemoteAction translateAction = classification.getActions().get(0); 300 assertEquals(1, classification.getActions().size()); 301 assertEquals(Intent.ACTION_TRANSLATE, classification.getIntent().getAction()); 302 303 assertEquals(translateAction, ExtrasUtils.findTranslateAction(classification)); 304 Intent intent = ExtrasUtils.getActionsIntents(classification).get(0); 305 assertNoPackageInfoInExtras(intent); 306 assertEquals(Intent.ACTION_TRANSLATE, intent.getAction()); 307 Bundle foreignLanguageInfo = ExtrasUtils.getForeignLanguageExtra(classification); 308 assertEquals("ja", ExtrasUtils.getEntityType(foreignLanguageInfo)); 309 assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) >= 0); 310 assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) <= 1); 311 assertTrue(intent.hasExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER)); 312 assertEquals("ja", ExtrasUtils.getTopLanguage(intent).first); 313 314 LocaleList.setDefault(originalLocales); 315 } 316 317 @Test testGenerateLinks_phone()318 public void testGenerateLinks_phone() throws IOException { 319 String text = "The number is +12122537077. See you tonight!"; 320 TextLinks.Request request = new TextLinks.Request.Builder(text).build(); 321 assertThat( 322 classifier.generateLinks(null, null, request), 323 isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE)); 324 } 325 326 @Test testGenerateLinks_exclude()327 public void testGenerateLinks_exclude() throws IOException { 328 String text = "The number is +12122537077. See you tonight!"; 329 List<String> hints = ImmutableList.of(); 330 List<String> included = ImmutableList.of(); 331 List<String> excluded = Arrays.asList(TextClassifier.TYPE_PHONE); 332 TextLinks.Request request = 333 new TextLinks.Request.Builder(text) 334 .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded)) 335 .build(); 336 assertThat( 337 classifier.generateLinks(null, null, request), 338 not(isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE))); 339 } 340 341 @Test testGenerateLinks_explicit_address()342 public void testGenerateLinks_explicit_address() throws IOException { 343 String text = "The address is 1600 Amphitheater Parkway, Mountain View, CA. See you!"; 344 List<String> explicit = Arrays.asList(TextClassifier.TYPE_ADDRESS); 345 TextLinks.Request request = 346 new TextLinks.Request.Builder(text) 347 .setEntityConfig(TextClassifier.EntityConfig.createWithExplicitEntityList(explicit)) 348 .build(); 349 assertThat( 350 classifier.generateLinks(null, null, request), 351 isTextLinksContaining( 352 text, "1600 Amphitheater Parkway, Mountain View, CA", TextClassifier.TYPE_ADDRESS)); 353 } 354 355 @Test testGenerateLinks_exclude_override()356 public void testGenerateLinks_exclude_override() throws IOException { 357 String text = "You want apple@banana.com. See you tonight!"; 358 List<String> hints = ImmutableList.of(); 359 List<String> included = Arrays.asList(TextClassifier.TYPE_EMAIL); 360 List<String> excluded = Arrays.asList(TextClassifier.TYPE_EMAIL); 361 TextLinks.Request request = 362 new TextLinks.Request.Builder(text) 363 .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded)) 364 .build(); 365 assertThat( 366 classifier.generateLinks(null, null, request), 367 not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL))); 368 } 369 370 @Test testGenerateLinks_maxLength()371 public void testGenerateLinks_maxLength() throws IOException { 372 char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength()]; 373 Arrays.fill(manySpaces, ' '); 374 TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build(); 375 TextLinks links = classifier.generateLinks(null, null, request); 376 assertTrue(links.getLinks().isEmpty()); 377 } 378 379 @Test testApplyLinks_unsupportedCharacter()380 public void testApplyLinks_unsupportedCharacter() throws IOException { 381 Spannable url = new SpannableString("\u202Emoc.diordna.com"); 382 TextLinks.Request request = new TextLinks.Request.Builder(url).build(); 383 assertEquals( 384 TextLinks.STATUS_UNSUPPORTED_CHARACTER, 385 classifier.generateLinks(null, null, request).apply(url, 0, null)); 386 } 387 388 @Test testGenerateLinks_tooLong()389 public void testGenerateLinks_tooLong() { 390 char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength() + 1]; 391 Arrays.fill(manySpaces, ' '); 392 TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build(); 393 expectThrows( 394 IllegalArgumentException.class, () -> classifier.generateLinks(null, null, request)); 395 } 396 397 @Test testGenerateLinks_entityData()398 public void testGenerateLinks_entityData() throws IOException { 399 String text = "The number is +12122537077."; 400 Bundle extras = new Bundle(); 401 ExtrasUtils.putIsSerializedEntityDataEnabled(extras, true); 402 TextLinks.Request request = new TextLinks.Request.Builder(text).setExtras(extras).build(); 403 404 TextLinks textLinks = classifier.generateLinks(null, null, request); 405 406 assertThat(textLinks.getLinks()).hasSize(1); 407 TextLinks.TextLink textLink = textLinks.getLinks().iterator().next(); 408 List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras()); 409 assertThat(entities).hasSize(1); 410 Bundle entity = entities.get(0); 411 assertThat(ExtrasUtils.getEntityType(entity)).isEqualTo(TextClassifier.TYPE_PHONE); 412 } 413 414 @Test testGenerateLinks_entityData_disabled()415 public void testGenerateLinks_entityData_disabled() throws IOException { 416 String text = "The number is +12122537077."; 417 TextLinks.Request request = new TextLinks.Request.Builder(text).build(); 418 419 TextLinks textLinks = classifier.generateLinks(null, null, request); 420 421 assertThat(textLinks.getLinks()).hasSize(1); 422 TextLinks.TextLink textLink = textLinks.getLinks().iterator().next(); 423 List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras()); 424 assertThat(entities).isNull(); 425 } 426 427 @Test testDetectLanguage()428 public void testDetectLanguage() throws IOException { 429 String text = "This is English text"; 430 TextLanguage.Request request = new TextLanguage.Request.Builder(text).build(); 431 TextLanguage textLanguage = classifier.detectLanguage(null, null, request); 432 assertThat(textLanguage, isTextLanguage("en")); 433 } 434 435 @Test testDetectLanguage_japanese()436 public void testDetectLanguage_japanese() throws IOException { 437 String text = "これは日本語のテキストです"; 438 TextLanguage.Request request = new TextLanguage.Request.Builder(text).build(); 439 TextLanguage textLanguage = classifier.detectLanguage(null, null, request); 440 assertThat(textLanguage, isTextLanguage("ja")); 441 } 442 443 @Test testSuggestConversationActions_textReplyOnly_maxOne()444 public void testSuggestConversationActions_textReplyOnly_maxOne() throws IOException { 445 ConversationActions.Message message = 446 new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS) 447 .setText("Where are you?") 448 .build(); 449 TextClassifier.EntityConfig typeConfig = 450 new TextClassifier.EntityConfig.Builder() 451 .includeTypesFromTextClassifier(false) 452 .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY)) 453 .build(); 454 ConversationActions.Request request = 455 new ConversationActions.Request.Builder(Collections.singletonList(message)) 456 .setMaxSuggestions(1) 457 .setTypeConfig(typeConfig) 458 .build(); 459 460 ConversationActions conversationActions = 461 classifier.suggestConversationActions(null, null, request); 462 assertThat(conversationActions.getConversationActions()).hasSize(1); 463 ConversationAction conversationAction = conversationActions.getConversationActions().get(0); 464 assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_TEXT_REPLY); 465 assertThat(conversationAction.getTextReply()).isNotNull(); 466 } 467 468 @Test testSuggestConversationActions_textReplyOnly_noMax()469 public void testSuggestConversationActions_textReplyOnly_noMax() throws IOException { 470 ConversationActions.Message message = 471 new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS) 472 .setText("Where are you?") 473 .build(); 474 TextClassifier.EntityConfig typeConfig = 475 new TextClassifier.EntityConfig.Builder() 476 .includeTypesFromTextClassifier(false) 477 .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY)) 478 .build(); 479 ConversationActions.Request request = 480 new ConversationActions.Request.Builder(Collections.singletonList(message)) 481 .setTypeConfig(typeConfig) 482 .build(); 483 484 ConversationActions conversationActions = 485 classifier.suggestConversationActions(null, null, request); 486 assertTrue(conversationActions.getConversationActions().size() > 1); 487 for (ConversationAction conversationAction : conversationActions.getConversationActions()) { 488 assertThat(conversationAction, isConversationAction(ConversationAction.TYPE_TEXT_REPLY)); 489 } 490 } 491 492 @Test testSuggestConversationActions_openUrl()493 public void testSuggestConversationActions_openUrl() throws IOException { 494 ConversationActions.Message message = 495 new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS) 496 .setText("Check this out: https://www.android.com") 497 .build(); 498 TextClassifier.EntityConfig typeConfig = 499 new TextClassifier.EntityConfig.Builder() 500 .includeTypesFromTextClassifier(false) 501 .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_OPEN_URL)) 502 .build(); 503 ConversationActions.Request request = 504 new ConversationActions.Request.Builder(Collections.singletonList(message)) 505 .setMaxSuggestions(1) 506 .setTypeConfig(typeConfig) 507 .build(); 508 509 ConversationActions conversationActions = 510 classifier.suggestConversationActions(null, null, request); 511 assertThat(conversationActions.getConversationActions()).hasSize(1); 512 ConversationAction conversationAction = conversationActions.getConversationActions().get(0); 513 assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL); 514 Intent actionIntent = ExtrasUtils.getActionIntent(conversationAction.getExtras()); 515 assertThat(actionIntent.getAction()).isEqualTo(Intent.ACTION_VIEW); 516 assertThat(actionIntent.getData()).isEqualTo(Uri.parse("https://www.android.com")); 517 assertNoPackageInfoInExtras(actionIntent); 518 } 519 520 @Test testSuggestConversationActions_copy()521 public void testSuggestConversationActions_copy() throws IOException { 522 ConversationActions.Message message = 523 new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS) 524 .setText("Authentication code: 12345") 525 .build(); 526 TextClassifier.EntityConfig typeConfig = 527 new TextClassifier.EntityConfig.Builder() 528 .includeTypesFromTextClassifier(false) 529 .setIncludedTypes(Collections.singletonList(TYPE_COPY)) 530 .build(); 531 ConversationActions.Request request = 532 new ConversationActions.Request.Builder(Collections.singletonList(message)) 533 .setMaxSuggestions(1) 534 .setTypeConfig(typeConfig) 535 .build(); 536 537 ConversationActions conversationActions = 538 classifier.suggestConversationActions(null, null, request); 539 assertThat(conversationActions.getConversationActions()).hasSize(1); 540 ConversationAction conversationAction = conversationActions.getConversationActions().get(0); 541 assertThat(conversationAction.getType()).isEqualTo(TYPE_COPY); 542 assertThat(conversationAction.getTextReply()).isAnyOf(null, ""); 543 assertThat(conversationAction.getAction()).isNull(); 544 String code = ExtrasUtils.getCopyText(conversationAction.getExtras()); 545 assertThat(code).isEqualTo("12345"); 546 assertThat(ExtrasUtils.getSerializedEntityData(conversationAction.getExtras())).isNotEmpty(); 547 } 548 549 @Test testSuggestConversationActions_deduplicate()550 public void testSuggestConversationActions_deduplicate() throws IOException { 551 ConversationActions.Message message = 552 new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS) 553 .setText("a@android.com b@android.com") 554 .build(); 555 ConversationActions.Request request = 556 new ConversationActions.Request.Builder(Collections.singletonList(message)) 557 .setMaxSuggestions(3) 558 .build(); 559 560 ConversationActions conversationActions = 561 classifier.suggestConversationActions(null, null, request); 562 563 assertThat(conversationActions.getConversationActions()).isEmpty(); 564 } 565 566 @Test testUseCachedAnnotatorModelDisabled()567 public void testUseCachedAnnotatorModelDisabled() throws IOException { 568 deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true); 569 570 String annotatorFilePath = TestDataUtils.getTestAnnotatorModelFile().getPath(); 571 ModelFile annotatorModelA = 572 new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false); 573 ModelFile annotatorModelB = 574 new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false); 575 576 String englishText = "You can reach me on +12122537077."; 577 String classifiedText = "+12122537077"; 578 TextClassification.Request request = 579 new TextClassification.Request.Builder(englishText, 0, englishText.length()).build(); 580 581 // Check modelFileA v701 582 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 583 .thenReturn(annotatorModelA); 584 TextClassification classificationA = classifier.classifyText(null, null, request); 585 586 assertThat(classificationA.getId()).contains("v701"); 587 assertThat(classificationA.getText()).contains(classifiedText); 588 assertArrayEquals( 589 new int[] {0, 0, 0, 0}, 590 new int[] { 591 annotatorModelCache.putCount(), 592 annotatorModelCache.evictionCount(), 593 annotatorModelCache.hitCount(), 594 annotatorModelCache.missCount() 595 }); 596 597 // Check modelFileB v801 598 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 599 .thenReturn(annotatorModelB); 600 TextClassification classificationB = classifier.classifyText(null, null, request); 601 602 assertThat(classificationB.getId()).contains("v801"); 603 assertThat(classificationB.getText()).contains(classifiedText); 604 assertArrayEquals( 605 new int[] {0, 0, 0, 0}, 606 new int[] { 607 annotatorModelCache.putCount(), 608 annotatorModelCache.evictionCount(), 609 annotatorModelCache.hitCount(), 610 annotatorModelCache.missCount() 611 }); 612 613 // Reload modelFileA v701 614 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 615 .thenReturn(annotatorModelA); 616 TextClassification classificationAcached = classifier.classifyText(null, null, request); 617 618 assertThat(classificationAcached.getId()).contains("v701"); 619 assertThat(classificationAcached.getText()).contains(classifiedText); 620 assertArrayEquals( 621 new int[] {0, 0, 0, 0}, 622 new int[] { 623 annotatorModelCache.putCount(), 624 annotatorModelCache.evictionCount(), 625 annotatorModelCache.hitCount(), 626 annotatorModelCache.missCount() 627 }); 628 } 629 630 @Test testUseCachedAnnotatorModelEnabled()631 public void testUseCachedAnnotatorModelEnabled() throws IOException { 632 deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true); 633 deviceConfig.setConfig(TextClassifierSettings.MULTI_ANNOTATOR_CACHE_ENABLED, true); 634 635 String annotatorFilePath = TestDataUtils.getTestAnnotatorModelFile().getPath(); 636 ModelFile annotatorModelA = 637 new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false); 638 ModelFile annotatorModelB = 639 new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false); 640 641 String englishText = "You can reach me on +12122537077."; 642 String classifiedText = "+12122537077"; 643 TextClassification.Request request = 644 new TextClassification.Request.Builder(englishText, 0, englishText.length()).build(); 645 646 // Check modelFileA v701 647 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 648 .thenReturn(annotatorModelA); 649 TextClassification classification = classifier.classifyText(null, null, request); 650 651 assertThat(classification.getId()).contains("v701"); 652 assertThat(classification.getText()).contains(classifiedText); 653 assertArrayEquals( 654 new int[] {1, 0, 0, 1}, 655 new int[] { 656 annotatorModelCache.putCount(), 657 annotatorModelCache.evictionCount(), 658 annotatorModelCache.hitCount(), 659 annotatorModelCache.missCount() 660 }); 661 662 // Check modelFileB v801 663 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 664 .thenReturn(annotatorModelB); 665 TextClassification classificationB = classifier.classifyText(null, null, request); 666 667 assertThat(classificationB.getId()).contains("v801"); 668 assertThat(classificationB.getText()).contains(classifiedText); 669 assertArrayEquals( 670 new int[] {2, 0, 0, 2}, 671 new int[] { 672 annotatorModelCache.putCount(), 673 annotatorModelCache.evictionCount(), 674 annotatorModelCache.hitCount(), 675 annotatorModelCache.missCount() 676 }); 677 678 // Reload modelFileA v701 679 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 680 .thenReturn(annotatorModelA); 681 TextClassification classificationAcached = classifier.classifyText(null, null, request); 682 683 assertThat(classificationAcached.getId()).contains("v701"); 684 assertThat(classificationAcached.getText()).contains(classifiedText); 685 assertArrayEquals( 686 new int[] {2, 0, 1, 2}, 687 new int[] { 688 annotatorModelCache.putCount(), 689 annotatorModelCache.evictionCount(), 690 annotatorModelCache.hitCount(), 691 annotatorModelCache.missCount() 692 }); 693 } 694 assertNoPackageInfoInExtras(Intent intent)695 private static void assertNoPackageInfoInExtras(Intent intent) { 696 assertThat(intent.getComponent()).isNull(); 697 assertThat(intent.getPackage()).isNull(); 698 } 699 isTextSelection( final int startIndex, final int endIndex, final String type)700 private static Matcher<TextSelection> isTextSelection( 701 final int startIndex, final int endIndex, final String type) { 702 return new BaseMatcher<TextSelection>() { 703 @Override 704 public boolean matches(Object o) { 705 if (o instanceof TextSelection) { 706 TextSelection selection = (TextSelection) o; 707 return startIndex == selection.getSelectionStartIndex() 708 && endIndex == selection.getSelectionEndIndex() 709 && typeMatches(selection, type); 710 } 711 return false; 712 } 713 714 private boolean typeMatches(TextSelection selection, String type) { 715 return type == null 716 || (selection.getEntityCount() > 0 717 && type.trim().equalsIgnoreCase(selection.getEntity(0))); 718 } 719 720 @Override 721 public void describeTo(Description description) { 722 description.appendValue(String.format("%d, %d, %s", startIndex, endIndex, type)); 723 } 724 }; 725 } 726 727 private static Matcher<TextLinks> isTextLinksContaining( 728 final String text, final String substring, final String type) { 729 return new BaseMatcher<TextLinks>() { 730 731 @Override 732 public void describeTo(Description description) { 733 description 734 .appendText("text=") 735 .appendValue(text) 736 .appendText(", substring=") 737 .appendValue(substring) 738 .appendText(", type=") 739 .appendValue(type); 740 } 741 742 @Override 743 public boolean matches(Object o) { 744 if (o instanceof TextLinks) { 745 for (TextLinks.TextLink link : ((TextLinks) o).getLinks()) { 746 if (text.subSequence(link.getStart(), link.getEnd()).toString().equals(substring)) { 747 return type.equals(link.getEntity(0)); 748 } 749 } 750 } 751 return false; 752 } 753 }; 754 } 755 756 private static Matcher<TextClassification> isTextClassification( 757 final String text, final String type) { 758 return new BaseMatcher<TextClassification>() { 759 @Override 760 public boolean matches(Object o) { 761 if (o instanceof TextClassification) { 762 TextClassification result = (TextClassification) o; 763 return text.equals(result.getText()) 764 && result.getEntityCount() > 0 765 && type.equals(result.getEntity(0)); 766 } 767 return false; 768 } 769 770 @Override 771 public void describeTo(Description description) { 772 description.appendText("text=").appendValue(text).appendText(", type=").appendValue(type); 773 } 774 }; 775 } 776 777 private static Matcher<TextClassification> containsIntentWithAction(final String action) { 778 return new BaseMatcher<TextClassification>() { 779 @Override 780 public boolean matches(Object o) { 781 if (o instanceof TextClassification) { 782 TextClassification result = (TextClassification) o; 783 return ExtrasUtils.findAction(result, action) != null; 784 } 785 return false; 786 } 787 788 @Override 789 public void describeTo(Description description) { 790 description.appendText("intent action=").appendValue(action); 791 } 792 }; 793 } 794 795 private static Matcher<TextLanguage> isTextLanguage(final String languageTag) { 796 return new BaseMatcher<TextLanguage>() { 797 @Override 798 public boolean matches(Object o) { 799 if (o instanceof TextLanguage) { 800 TextLanguage result = (TextLanguage) o; 801 return result.getLocaleHypothesisCount() > 0 802 && languageTag.equals(result.getLocale(0).toLanguageTag()); 803 } 804 return false; 805 } 806 807 @Override 808 public void describeTo(Description description) { 809 description.appendText("locale=").appendValue(languageTag); 810 } 811 }; 812 } 813 814 private static Matcher<ConversationAction> isConversationAction(String actionType) { 815 return new BaseMatcher<ConversationAction>() { 816 @Override 817 public boolean matches(Object o) { 818 if (!(o instanceof ConversationAction)) { 819 return false; 820 } 821 ConversationAction conversationAction = (ConversationAction) o; 822 if (!actionType.equals(conversationAction.getType())) { 823 return false; 824 } 825 if (ConversationAction.TYPE_TEXT_REPLY.equals(actionType)) { 826 if (conversationAction.getTextReply() == null) { 827 return false; 828 } 829 } 830 if (conversationAction.getConfidenceScore() < 0 831 || conversationAction.getConfidenceScore() > 1) { 832 return false; 833 } 834 return true; 835 } 836 837 @Override 838 public void describeTo(Description description) { 839 description.appendText("actionType=").appendValue(actionType); 840 } 841 }; 842 } 843 } 844