1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.textclassifier; 18 19 import static com.google.common.truth.Truth.assertThat; 20 import static org.hamcrest.CoreMatchers.not; 21 import static org.junit.Assert.assertArrayEquals; 22 import static org.junit.Assert.assertEquals; 23 import static org.junit.Assert.assertThat; 24 import static org.junit.Assert.assertTrue; 25 import static org.mockito.Mockito.any; 26 import static org.mockito.Mockito.eq; 27 import static org.mockito.Mockito.verify; 28 import static org.mockito.Mockito.when; 29 import static org.testng.Assert.expectThrows; 30 31 import android.app.RemoteAction; 32 import android.content.Context; 33 import android.content.Intent; 34 import android.net.Uri; 35 import android.os.Bundle; 36 import android.os.LocaleList; 37 import android.text.Spannable; 38 import android.text.SpannableString; 39 import android.view.textclassifier.ConversationAction; 40 import android.view.textclassifier.ConversationActions; 41 import android.view.textclassifier.TextClassification; 42 import android.view.textclassifier.TextClassifier; 43 import android.view.textclassifier.TextLanguage; 44 import android.view.textclassifier.TextLinks; 45 import android.view.textclassifier.TextSelection; 46 import androidx.collection.LruCache; 47 import androidx.test.core.app.ApplicationProvider; 48 import androidx.test.ext.junit.runners.AndroidJUnit4; 49 import androidx.test.filters.SdkSuppress; 50 import androidx.test.filters.SmallTest; 51 import com.android.textclassifier.common.ModelFile; 52 import com.android.textclassifier.common.ModelType; 53 import com.android.textclassifier.common.TextClassifierSettings; 54 import com.android.textclassifier.testing.FakeContextBuilder; 55 import com.android.textclassifier.testing.TestingDeviceConfig; 56 import com.google.android.textclassifier.AnnotatorModel; 57 import com.google.common.collect.ImmutableList; 58 import java.io.IOException; 59 import java.util.ArrayList; 60 import java.util.Arrays; 61 import java.util.Collections; 62 import java.util.List; 63 import org.hamcrest.BaseMatcher; 64 import org.hamcrest.Description; 65 import org.hamcrest.Matcher; 66 import org.junit.Before; 67 import org.junit.Test; 68 import org.junit.runner.RunWith; 69 import org.mockito.Mock; 70 import org.mockito.MockitoAnnotations; 71 72 @SmallTest 73 @RunWith(AndroidJUnit4.class) 74 public class TextClassifierImplTest { 75 76 private static final String TYPE_COPY = "copy"; 77 private static final LocaleList LOCALES = LocaleList.forLanguageTags("en-US"); 78 private static final String NO_TYPE = null; 79 80 @Mock private ModelFileManager modelFileManager; 81 82 private Context context; 83 private TestingDeviceConfig deviceConfig; 84 private TextClassifierSettings settings; 85 private LruCache<ModelFile, AnnotatorModel> annotatorModelCache; 86 private TextClassifierImpl classifier; 87 88 @Before setup()89 public void setup() throws IOException { 90 MockitoAnnotations.initMocks(this); 91 this.context = 92 new FakeContextBuilder() 93 .setAllIntentComponent(FakeContextBuilder.DEFAULT_COMPONENT) 94 .setAppLabel(FakeContextBuilder.DEFAULT_COMPONENT.getPackageName(), "Test app") 95 .build(); 96 this.deviceConfig = new TestingDeviceConfig(); 97 this.settings = new TextClassifierSettings(deviceConfig); 98 this.annotatorModelCache = new LruCache<>(2); 99 this.classifier = 100 new TextClassifierImpl(context, settings, modelFileManager, annotatorModelCache); 101 102 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 103 .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped()); 104 when(modelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any())) 105 .thenReturn(TestDataUtils.getLangIdModelFileWrapped()); 106 when(modelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any())) 107 .thenReturn(TestDataUtils.getTestActionsModelFileWrapped()); 108 } 109 110 @Test testSuggestSelection()111 public void testSuggestSelection() throws IOException { 112 String text = "Contact me at droid@android.com"; 113 String selected = "droid"; 114 String suggested = "droid@android.com"; 115 int startIndex = text.indexOf(selected); 116 int endIndex = startIndex + selected.length(); 117 int smartStartIndex = text.indexOf(suggested); 118 int smartEndIndex = smartStartIndex + suggested.length(); 119 TextSelection.Request request = 120 new TextSelection.Request.Builder(text, startIndex, endIndex).build(); 121 122 TextSelection selection = classifier.suggestSelection(null, null, request); 123 assertThat( 124 selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_EMAIL)); 125 } 126 127 @Test testSuggestSelection_localePreferenceIsPassedToModelFileManager()128 public void testSuggestSelection_localePreferenceIsPassedToModelFileManager() throws IOException { 129 String text = "Contact me at droid@android.com"; 130 String selected = "droid"; 131 String suggested = "droid@android.com"; 132 int startIndex = text.indexOf(selected); 133 int endIndex = startIndex + selected.length(); 134 int smartStartIndex = text.indexOf(suggested); 135 int smartEndIndex = smartStartIndex + suggested.length(); 136 TextSelection.Request request = 137 new TextSelection.Request.Builder(text, startIndex, endIndex) 138 .setDefaultLocales(LOCALES) 139 .build(); 140 141 classifier.suggestSelection(null, null, request); 142 verify(modelFileManager).findBestModelFile(eq(ModelType.ANNOTATOR), eq(LOCALES), any()); 143 } 144 145 @Test testSuggestSelection_url()146 public void testSuggestSelection_url() throws IOException { 147 String text = "Visit http://www.android.com for more information"; 148 String selected = "http"; 149 String suggested = "http://www.android.com"; 150 int startIndex = text.indexOf(selected); 151 int endIndex = startIndex + selected.length(); 152 int smartStartIndex = text.indexOf(suggested); 153 int smartEndIndex = smartStartIndex + suggested.length(); 154 TextSelection.Request request = 155 new TextSelection.Request.Builder(text, startIndex, endIndex).build(); 156 157 TextSelection selection = classifier.suggestSelection(null, null, request); 158 assertThat(selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_URL)); 159 } 160 161 @Test testSmartSelection_withEmoji()162 public void testSmartSelection_withEmoji() throws IOException { 163 String text = "\uD83D\uDE02 Hello."; 164 String selected = "Hello"; 165 int startIndex = text.indexOf(selected); 166 int endIndex = startIndex + selected.length(); 167 TextSelection.Request request = 168 new TextSelection.Request.Builder(text, startIndex, endIndex).build(); 169 170 TextSelection selection = classifier.suggestSelection(null, null, request); 171 assertThat(selection, isTextSelection(startIndex, endIndex, NO_TYPE)); 172 } 173 174 @SdkSuppress(minSdkVersion = 31, codeName = "S") 175 @Test testSuggestSelection_includeTextClassification()176 public void testSuggestSelection_includeTextClassification() throws IOException { 177 String text = "Visit http://www.android.com for more information"; 178 String suggested = "http://www.android.com"; 179 int startIndex = text.indexOf(suggested); 180 TextSelection.Request request = 181 new TextSelection.Request.Builder(text, startIndex, /*endIndex=*/ startIndex + 1) 182 .setIncludeTextClassification(true) 183 .build(); 184 185 TextSelection selection = classifier.suggestSelection(null, null, request); 186 187 assertThat( 188 selection.getTextClassification(), 189 isTextClassification(suggested, TextClassifier.TYPE_URL)); 190 assertThat(selection.getTextClassification(), containsIntentWithAction(Intent.ACTION_VIEW)); 191 } 192 193 @SdkSuppress(minSdkVersion = 31, codeName = "S") 194 @Test testSuggestSelection_notIncludeTextClassification()195 public void testSuggestSelection_notIncludeTextClassification() throws IOException { 196 String text = "Visit http://www.android.com for more information"; 197 TextSelection.Request request = 198 new TextSelection.Request.Builder(text, /*startIndex=*/ 0, /*endIndex=*/ 4) 199 .setIncludeTextClassification(false) 200 .build(); 201 202 TextSelection selection = classifier.suggestSelection(null, null, request); 203 204 assertThat(selection.getTextClassification()).isNull(); 205 } 206 207 @Test testClassifyText()208 public void testClassifyText() throws IOException { 209 String text = "Contact me at droid@android.com"; 210 String classifiedText = "droid@android.com"; 211 int startIndex = text.indexOf(classifiedText); 212 int endIndex = startIndex + classifiedText.length(); 213 TextClassification.Request request = 214 new TextClassification.Request.Builder(text, startIndex, endIndex).build(); 215 216 TextClassification classification = 217 classifier.classifyText(/* sessionId= */ null, null, request); 218 assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_EMAIL)); 219 } 220 221 @Test testClassifyText_url()222 public void testClassifyText_url() throws IOException { 223 String text = "Visit www.android.com for more information"; 224 String classifiedText = "www.android.com"; 225 int startIndex = text.indexOf(classifiedText); 226 int endIndex = startIndex + classifiedText.length(); 227 TextClassification.Request request = 228 new TextClassification.Request.Builder(text, startIndex, endIndex).build(); 229 230 TextClassification classification = classifier.classifyText(null, null, request); 231 assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL)); 232 assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW)); 233 } 234 235 @Test testClassifyText_address()236 public void testClassifyText_address() throws IOException { 237 String text = "Brandschenkestrasse 110, Zürich, Switzerland"; 238 TextClassification.Request request = 239 new TextClassification.Request.Builder(text, 0, text.length()).build(); 240 241 TextClassification classification = classifier.classifyText(null, null, request); 242 assertThat(classification, isTextClassification(text, TextClassifier.TYPE_ADDRESS)); 243 } 244 245 @Test testClassifyText_url_inCaps()246 public void testClassifyText_url_inCaps() throws IOException { 247 String text = "Visit HTTP://ANDROID.COM for more information"; 248 String classifiedText = "HTTP://ANDROID.COM"; 249 int startIndex = text.indexOf(classifiedText); 250 int endIndex = startIndex + classifiedText.length(); 251 TextClassification.Request request = 252 new TextClassification.Request.Builder(text, startIndex, endIndex).build(); 253 254 TextClassification classification = classifier.classifyText(null, null, request); 255 assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL)); 256 assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW)); 257 } 258 259 @Test testClassifyText_date()260 public void testClassifyText_date() throws IOException { 261 String text = "Let's meet on January 9, 2018."; 262 String classifiedText = "January 9, 2018"; 263 int startIndex = text.indexOf(classifiedText); 264 int endIndex = startIndex + classifiedText.length(); 265 TextClassification.Request request = 266 new TextClassification.Request.Builder(text, startIndex, endIndex).build(); 267 268 TextClassification classification = classifier.classifyText(null, null, request); 269 assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE)); 270 Bundle extras = classification.getExtras(); 271 List<Bundle> entities = ExtrasUtils.getEntities(extras); 272 assertThat(entities).hasSize(1); 273 assertThat(ExtrasUtils.getEntityType(entities.get(0))).isEqualTo(TextClassifier.TYPE_DATE); 274 ArrayList<Intent> actionsIntents = ExtrasUtils.getActionsIntents(classification); 275 actionsIntents.forEach(TextClassifierImplTest::assertNoPackageInfoInExtras); 276 } 277 278 @Test testClassifyText_datetime()279 public void testClassifyText_datetime() throws IOException { 280 String text = "Let's meet 2018/01/01 10:30:20."; 281 String classifiedText = "2018/01/01 10:30:20"; 282 int startIndex = text.indexOf(classifiedText); 283 int endIndex = startIndex + classifiedText.length(); 284 TextClassification.Request request = 285 new TextClassification.Request.Builder(text, startIndex, endIndex).build(); 286 287 TextClassification classification = classifier.classifyText(null, null, request); 288 assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME)); 289 } 290 291 @Test testClassifyText_foreignText()292 public void testClassifyText_foreignText() throws IOException { 293 LocaleList originalLocales = LocaleList.getDefault(); 294 LocaleList.setDefault(LocaleList.forLanguageTags("en")); 295 String japaneseText = "これは日本語のテキストです"; 296 TextClassification.Request request = 297 new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length()).build(); 298 299 TextClassification classification = classifier.classifyText(null, null, request); 300 RemoteAction translateAction = classification.getActions().get(0); 301 assertEquals(1, classification.getActions().size()); 302 assertEquals(Intent.ACTION_TRANSLATE, classification.getIntent().getAction()); 303 304 assertEquals(translateAction, ExtrasUtils.findTranslateAction(classification)); 305 Intent intent = ExtrasUtils.getActionsIntents(classification).get(0); 306 assertNoPackageInfoInExtras(intent); 307 assertEquals(Intent.ACTION_TRANSLATE, intent.getAction()); 308 Bundle foreignLanguageInfo = ExtrasUtils.getForeignLanguageExtra(classification); 309 assertEquals("ja", ExtrasUtils.getEntityType(foreignLanguageInfo)); 310 assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) >= 0); 311 assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) <= 1); 312 assertTrue(intent.hasExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER)); 313 assertEquals("ja", ExtrasUtils.getTopLanguage(intent).first); 314 315 LocaleList.setDefault(originalLocales); 316 } 317 318 @Test testGenerateLinks_phone()319 public void testGenerateLinks_phone() throws IOException { 320 String text = "The number is +12122537077. See you tonight!"; 321 TextLinks.Request request = new TextLinks.Request.Builder(text).build(); 322 assertThat( 323 classifier.generateLinks(null, null, request), 324 isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE)); 325 } 326 327 @Test testGenerateLinks_exclude()328 public void testGenerateLinks_exclude() throws IOException { 329 String text = "The number is +12122537077. See you tonight!"; 330 List<String> hints = ImmutableList.of(); 331 List<String> included = ImmutableList.of(); 332 List<String> excluded = Arrays.asList(TextClassifier.TYPE_PHONE); 333 TextLinks.Request request = 334 new TextLinks.Request.Builder(text) 335 .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded)) 336 .build(); 337 assertThat( 338 classifier.generateLinks(null, null, request), 339 not(isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE))); 340 } 341 342 @Test testGenerateLinks_explicit_address()343 public void testGenerateLinks_explicit_address() throws IOException { 344 String text = "The address is 1600 Amphitheater Parkway, Mountain View, CA. See you!"; 345 List<String> explicit = Arrays.asList(TextClassifier.TYPE_ADDRESS); 346 TextLinks.Request request = 347 new TextLinks.Request.Builder(text) 348 .setEntityConfig(TextClassifier.EntityConfig.createWithExplicitEntityList(explicit)) 349 .build(); 350 assertThat( 351 classifier.generateLinks(null, null, request), 352 isTextLinksContaining( 353 text, "1600 Amphitheater Parkway, Mountain View, CA", TextClassifier.TYPE_ADDRESS)); 354 } 355 356 @Test testGenerateLinks_exclude_override()357 public void testGenerateLinks_exclude_override() throws IOException { 358 String text = "You want apple@banana.com. See you tonight!"; 359 List<String> hints = ImmutableList.of(); 360 List<String> included = Arrays.asList(TextClassifier.TYPE_EMAIL); 361 List<String> excluded = Arrays.asList(TextClassifier.TYPE_EMAIL); 362 TextLinks.Request request = 363 new TextLinks.Request.Builder(text) 364 .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded)) 365 .build(); 366 assertThat( 367 classifier.generateLinks(null, null, request), 368 not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL))); 369 } 370 371 @Test testGenerateLinks_maxLength()372 public void testGenerateLinks_maxLength() throws IOException { 373 char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength()]; 374 Arrays.fill(manySpaces, ' '); 375 TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build(); 376 TextLinks links = classifier.generateLinks(null, null, request); 377 assertTrue(links.getLinks().isEmpty()); 378 } 379 380 @Test testApplyLinks_unsupportedCharacter()381 public void testApplyLinks_unsupportedCharacter() throws IOException { 382 Spannable url = new SpannableString("\u202Emoc.diordna.com"); 383 TextLinks.Request request = new TextLinks.Request.Builder(url).build(); 384 assertEquals( 385 TextLinks.STATUS_UNSUPPORTED_CHARACTER, 386 classifier.generateLinks(null, null, request).apply(url, 0, null)); 387 } 388 389 @Test testGenerateLinks_tooLong()390 public void testGenerateLinks_tooLong() { 391 char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength() + 1]; 392 Arrays.fill(manySpaces, ' '); 393 TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build(); 394 expectThrows( 395 IllegalArgumentException.class, () -> classifier.generateLinks(null, null, request)); 396 } 397 398 @Test testGenerateLinks_entityData()399 public void testGenerateLinks_entityData() throws IOException { 400 String text = "The number is +12122537077."; 401 Bundle extras = new Bundle(); 402 ExtrasUtils.putIsSerializedEntityDataEnabled(extras, true); 403 TextLinks.Request request = new TextLinks.Request.Builder(text).setExtras(extras).build(); 404 405 TextLinks textLinks = classifier.generateLinks(null, null, request); 406 407 assertThat(textLinks.getLinks()).hasSize(1); 408 TextLinks.TextLink textLink = textLinks.getLinks().iterator().next(); 409 List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras()); 410 assertThat(entities).hasSize(1); 411 Bundle entity = entities.get(0); 412 assertThat(ExtrasUtils.getEntityType(entity)).isEqualTo(TextClassifier.TYPE_PHONE); 413 } 414 415 @Test testGenerateLinks_entityData_disabled()416 public void testGenerateLinks_entityData_disabled() throws IOException { 417 String text = "The number is +12122537077."; 418 TextLinks.Request request = new TextLinks.Request.Builder(text).build(); 419 420 TextLinks textLinks = classifier.generateLinks(null, null, request); 421 422 assertThat(textLinks.getLinks()).hasSize(1); 423 TextLinks.TextLink textLink = textLinks.getLinks().iterator().next(); 424 List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras()); 425 assertThat(entities).isNull(); 426 } 427 428 @Test testDetectLanguage()429 public void testDetectLanguage() throws IOException { 430 String text = "This is English text"; 431 TextLanguage.Request request = new TextLanguage.Request.Builder(text).build(); 432 TextLanguage textLanguage = classifier.detectLanguage(null, null, request); 433 assertThat(textLanguage, isTextLanguage("en")); 434 } 435 436 @Test testDetectLanguage_japanese()437 public void testDetectLanguage_japanese() throws IOException { 438 String text = "これは日本語のテキストです"; 439 TextLanguage.Request request = new TextLanguage.Request.Builder(text).build(); 440 TextLanguage textLanguage = classifier.detectLanguage(null, null, request); 441 assertThat(textLanguage, isTextLanguage("ja")); 442 } 443 444 @Test testSuggestConversationActions_textReplyOnly_maxOne()445 public void testSuggestConversationActions_textReplyOnly_maxOne() throws IOException { 446 ConversationActions.Message message = 447 new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS) 448 .setText("Where are you?") 449 .build(); 450 TextClassifier.EntityConfig typeConfig = 451 new TextClassifier.EntityConfig.Builder() 452 .includeTypesFromTextClassifier(false) 453 .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY)) 454 .build(); 455 ConversationActions.Request request = 456 new ConversationActions.Request.Builder(Collections.singletonList(message)) 457 .setMaxSuggestions(1) 458 .setTypeConfig(typeConfig) 459 .build(); 460 461 ConversationActions conversationActions = 462 classifier.suggestConversationActions(null, null, request); 463 assertThat(conversationActions.getConversationActions()).hasSize(1); 464 ConversationAction conversationAction = conversationActions.getConversationActions().get(0); 465 assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_TEXT_REPLY); 466 assertThat(conversationAction.getTextReply()).isNotNull(); 467 } 468 469 @Test testSuggestConversationActions_textReplyOnly_noMax()470 public void testSuggestConversationActions_textReplyOnly_noMax() throws IOException { 471 ConversationActions.Message message = 472 new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS) 473 .setText("Where are you?") 474 .build(); 475 TextClassifier.EntityConfig typeConfig = 476 new TextClassifier.EntityConfig.Builder() 477 .includeTypesFromTextClassifier(false) 478 .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY)) 479 .build(); 480 ConversationActions.Request request = 481 new ConversationActions.Request.Builder(Collections.singletonList(message)) 482 .setTypeConfig(typeConfig) 483 .build(); 484 485 ConversationActions conversationActions = 486 classifier.suggestConversationActions(null, null, request); 487 assertTrue(conversationActions.getConversationActions().size() > 1); 488 for (ConversationAction conversationAction : conversationActions.getConversationActions()) { 489 assertThat(conversationAction, isConversationAction(ConversationAction.TYPE_TEXT_REPLY)); 490 } 491 } 492 493 @Test testSuggestConversationActions_openUrl()494 public void testSuggestConversationActions_openUrl() throws IOException { 495 ConversationActions.Message message = 496 new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS) 497 .setText("Check this out: https://www.android.com") 498 .build(); 499 TextClassifier.EntityConfig typeConfig = 500 new TextClassifier.EntityConfig.Builder() 501 .includeTypesFromTextClassifier(false) 502 .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_OPEN_URL)) 503 .build(); 504 ConversationActions.Request request = 505 new ConversationActions.Request.Builder(Collections.singletonList(message)) 506 .setMaxSuggestions(1) 507 .setTypeConfig(typeConfig) 508 .build(); 509 510 ConversationActions conversationActions = 511 classifier.suggestConversationActions(null, null, request); 512 assertThat(conversationActions.getConversationActions()).hasSize(1); 513 ConversationAction conversationAction = conversationActions.getConversationActions().get(0); 514 assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL); 515 Intent actionIntent = ExtrasUtils.getActionIntent(conversationAction.getExtras()); 516 assertThat(actionIntent.getAction()).isEqualTo(Intent.ACTION_VIEW); 517 assertThat(actionIntent.getData()).isEqualTo(Uri.parse("https://www.android.com")); 518 assertNoPackageInfoInExtras(actionIntent); 519 } 520 521 @Test testSuggestConversationActions_copy()522 public void testSuggestConversationActions_copy() throws IOException { 523 ConversationActions.Message message = 524 new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS) 525 .setText("Authentication code: 12345") 526 .build(); 527 TextClassifier.EntityConfig typeConfig = 528 new TextClassifier.EntityConfig.Builder() 529 .includeTypesFromTextClassifier(false) 530 .setIncludedTypes(Collections.singletonList(TYPE_COPY)) 531 .build(); 532 ConversationActions.Request request = 533 new ConversationActions.Request.Builder(Collections.singletonList(message)) 534 .setMaxSuggestions(1) 535 .setTypeConfig(typeConfig) 536 .build(); 537 538 ConversationActions conversationActions = 539 classifier.suggestConversationActions(null, null, request); 540 assertThat(conversationActions.getConversationActions()).hasSize(1); 541 ConversationAction conversationAction = conversationActions.getConversationActions().get(0); 542 assertThat(conversationAction.getType()).isEqualTo(TYPE_COPY); 543 assertThat(conversationAction.getTextReply()).isAnyOf(null, ""); 544 assertThat(conversationAction.getAction()).isNull(); 545 String code = ExtrasUtils.getCopyText(conversationAction.getExtras()); 546 assertThat(code).isEqualTo("12345"); 547 assertThat(ExtrasUtils.getSerializedEntityData(conversationAction.getExtras())).isNotEmpty(); 548 } 549 550 @Test testSuggestConversationActions_deduplicate()551 public void testSuggestConversationActions_deduplicate() throws IOException { 552 ConversationActions.Message message = 553 new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS) 554 .setText("a@android.com b@android.com") 555 .build(); 556 ConversationActions.Request request = 557 new ConversationActions.Request.Builder(Collections.singletonList(message)) 558 .setMaxSuggestions(3) 559 .build(); 560 561 ConversationActions conversationActions = 562 classifier.suggestConversationActions(null, null, request); 563 564 assertThat(conversationActions.getConversationActions()).isEmpty(); 565 } 566 567 @Test testUseCachedAnnotatorModelDisabled()568 public void testUseCachedAnnotatorModelDisabled() throws IOException { 569 deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true); 570 571 String annotatorFilePath = TestDataUtils.getTestAnnotatorModelFile().getPath(); 572 ModelFile annotatorModelA = 573 new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false); 574 ModelFile annotatorModelB = 575 new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false); 576 577 String englishText = "You can reach me on +12122537077."; 578 String classifiedText = "+12122537077"; 579 TextClassification.Request request = 580 new TextClassification.Request.Builder(englishText, 0, englishText.length()).build(); 581 582 // Check modelFileA v701 583 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 584 .thenReturn(annotatorModelA); 585 TextClassification classificationA = classifier.classifyText(null, null, request); 586 587 assertThat(classificationA.getId()).contains("v701"); 588 assertThat(classificationA.getText()).contains(classifiedText); 589 assertArrayEquals( 590 new int[] {0, 0, 0, 0}, 591 new int[] { 592 annotatorModelCache.putCount(), 593 annotatorModelCache.evictionCount(), 594 annotatorModelCache.hitCount(), 595 annotatorModelCache.missCount() 596 }); 597 598 // Check modelFileB v801 599 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 600 .thenReturn(annotatorModelB); 601 TextClassification classificationB = classifier.classifyText(null, null, request); 602 603 assertThat(classificationB.getId()).contains("v801"); 604 assertThat(classificationB.getText()).contains(classifiedText); 605 assertArrayEquals( 606 new int[] {0, 0, 0, 0}, 607 new int[] { 608 annotatorModelCache.putCount(), 609 annotatorModelCache.evictionCount(), 610 annotatorModelCache.hitCount(), 611 annotatorModelCache.missCount() 612 }); 613 614 // Reload modelFileA v701 615 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 616 .thenReturn(annotatorModelA); 617 TextClassification classificationAcached = classifier.classifyText(null, null, request); 618 619 assertThat(classificationAcached.getId()).contains("v701"); 620 assertThat(classificationAcached.getText()).contains(classifiedText); 621 assertArrayEquals( 622 new int[] {0, 0, 0, 0}, 623 new int[] { 624 annotatorModelCache.putCount(), 625 annotatorModelCache.evictionCount(), 626 annotatorModelCache.hitCount(), 627 annotatorModelCache.missCount() 628 }); 629 } 630 631 @Test testUseCachedAnnotatorModelEnabled()632 public void testUseCachedAnnotatorModelEnabled() throws IOException { 633 deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true); 634 deviceConfig.setConfig(TextClassifierSettings.MULTI_ANNOTATOR_CACHE_ENABLED, true); 635 636 String annotatorFilePath = TestDataUtils.getTestAnnotatorModelFile().getPath(); 637 ModelFile annotatorModelA = 638 new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false); 639 ModelFile annotatorModelB = 640 new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false); 641 642 String englishText = "You can reach me on +12122537077."; 643 String classifiedText = "+12122537077"; 644 TextClassification.Request request = 645 new TextClassification.Request.Builder(englishText, 0, englishText.length()).build(); 646 647 // Check modelFileA v701 648 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 649 .thenReturn(annotatorModelA); 650 TextClassification classification = classifier.classifyText(null, null, request); 651 652 assertThat(classification.getId()).contains("v701"); 653 assertThat(classification.getText()).contains(classifiedText); 654 assertArrayEquals( 655 new int[] {1, 0, 0, 1}, 656 new int[] { 657 annotatorModelCache.putCount(), 658 annotatorModelCache.evictionCount(), 659 annotatorModelCache.hitCount(), 660 annotatorModelCache.missCount() 661 }); 662 663 // Check modelFileB v801 664 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 665 .thenReturn(annotatorModelB); 666 TextClassification classificationB = classifier.classifyText(null, null, request); 667 668 assertThat(classificationB.getId()).contains("v801"); 669 assertThat(classificationB.getText()).contains(classifiedText); 670 assertArrayEquals( 671 new int[] {2, 0, 0, 2}, 672 new int[] { 673 annotatorModelCache.putCount(), 674 annotatorModelCache.evictionCount(), 675 annotatorModelCache.hitCount(), 676 annotatorModelCache.missCount() 677 }); 678 679 // Reload modelFileA v701 680 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 681 .thenReturn(annotatorModelA); 682 TextClassification classificationAcached = classifier.classifyText(null, null, request); 683 684 assertThat(classificationAcached.getId()).contains("v701"); 685 assertThat(classificationAcached.getText()).contains(classifiedText); 686 assertArrayEquals( 687 new int[] {2, 0, 1, 2}, 688 new int[] { 689 annotatorModelCache.putCount(), 690 annotatorModelCache.evictionCount(), 691 annotatorModelCache.hitCount(), 692 annotatorModelCache.missCount() 693 }); 694 } 695 assertNoPackageInfoInExtras(Intent intent)696 private static void assertNoPackageInfoInExtras(Intent intent) { 697 assertThat(intent.getComponent()).isNull(); 698 assertThat(intent.getPackage()).isNull(); 699 } 700 isTextSelection( final int startIndex, final int endIndex, final String type)701 private static Matcher<TextSelection> isTextSelection( 702 final int startIndex, final int endIndex, final String type) { 703 return new BaseMatcher<TextSelection>() { 704 @Override 705 public boolean matches(Object o) { 706 if (o instanceof TextSelection) { 707 TextSelection selection = (TextSelection) o; 708 return startIndex == selection.getSelectionStartIndex() 709 && endIndex == selection.getSelectionEndIndex() 710 && typeMatches(selection, type); 711 } 712 return false; 713 } 714 715 private boolean typeMatches(TextSelection selection, String type) { 716 return type == null 717 || (selection.getEntityCount() > 0 718 && type.trim().equalsIgnoreCase(selection.getEntity(0))); 719 } 720 721 @Override 722 public void describeTo(Description description) { 723 description.appendValue(String.format("%d, %d, %s", startIndex, endIndex, type)); 724 } 725 }; 726 } 727 728 private static Matcher<TextLinks> isTextLinksContaining( 729 final String text, final String substring, final String type) { 730 return new BaseMatcher<TextLinks>() { 731 732 @Override 733 public void describeTo(Description description) { 734 description 735 .appendText("text=") 736 .appendValue(text) 737 .appendText(", substring=") 738 .appendValue(substring) 739 .appendText(", type=") 740 .appendValue(type); 741 } 742 743 @Override 744 public boolean matches(Object o) { 745 if (o instanceof TextLinks) { 746 for (TextLinks.TextLink link : ((TextLinks) o).getLinks()) { 747 if (text.subSequence(link.getStart(), link.getEnd()).toString().equals(substring)) { 748 return type.equals(link.getEntity(0)); 749 } 750 } 751 } 752 return false; 753 } 754 }; 755 } 756 757 private static Matcher<TextClassification> isTextClassification( 758 final String text, final String type) { 759 return new BaseMatcher<TextClassification>() { 760 @Override 761 public boolean matches(Object o) { 762 if (o instanceof TextClassification) { 763 TextClassification result = (TextClassification) o; 764 return text.equals(result.getText()) 765 && result.getEntityCount() > 0 766 && type.equals(result.getEntity(0)); 767 } 768 return false; 769 } 770 771 @Override 772 public void describeTo(Description description) { 773 description.appendText("text=").appendValue(text).appendText(", type=").appendValue(type); 774 } 775 }; 776 } 777 778 private static Matcher<TextClassification> containsIntentWithAction(final String action) { 779 return new BaseMatcher<TextClassification>() { 780 @Override 781 public boolean matches(Object o) { 782 if (o instanceof TextClassification) { 783 TextClassification result = (TextClassification) o; 784 return ExtrasUtils.findAction(result, action) != null; 785 } 786 return false; 787 } 788 789 @Override 790 public void describeTo(Description description) { 791 description.appendText("intent action=").appendValue(action); 792 } 793 }; 794 } 795 796 private static Matcher<TextLanguage> isTextLanguage(final String languageTag) { 797 return new BaseMatcher<TextLanguage>() { 798 @Override 799 public boolean matches(Object o) { 800 if (o instanceof TextLanguage) { 801 TextLanguage result = (TextLanguage) o; 802 return result.getLocaleHypothesisCount() > 0 803 && languageTag.equals(result.getLocale(0).toLanguageTag()); 804 } 805 return false; 806 } 807 808 @Override 809 public void describeTo(Description description) { 810 description.appendText("locale=").appendValue(languageTag); 811 } 812 }; 813 } 814 815 private static Matcher<ConversationAction> isConversationAction(String actionType) { 816 return new BaseMatcher<ConversationAction>() { 817 @Override 818 public boolean matches(Object o) { 819 if (!(o instanceof ConversationAction)) { 820 return false; 821 } 822 ConversationAction conversationAction = (ConversationAction) o; 823 if (!actionType.equals(conversationAction.getType())) { 824 return false; 825 } 826 if (ConversationAction.TYPE_TEXT_REPLY.equals(actionType)) { 827 if (conversationAction.getTextReply() == null) { 828 return false; 829 } 830 } 831 if (conversationAction.getConfidenceScore() < 0 832 || conversationAction.getConfidenceScore() > 1) { 833 return false; 834 } 835 return true; 836 } 837 838 @Override 839 public void describeTo(Description description) { 840 description.appendText("actionType=").appendValue(actionType); 841 } 842 }; 843 } 844 } 845