1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.textclassifier; 18 19 import static com.google.common.truth.Truth.assertThat; 20 import static org.hamcrest.CoreMatchers.not; 21 import static org.junit.Assert.assertArrayEquals; 22 import static org.junit.Assert.assertEquals; 23 import static org.junit.Assert.assertThat; 24 import static org.junit.Assert.assertTrue; 25 import static org.mockito.Mockito.any; 26 import static org.mockito.Mockito.eq; 27 import static org.mockito.Mockito.verify; 28 import static org.mockito.Mockito.when; 29 import static org.testng.Assert.expectThrows; 30 31 import android.app.RemoteAction; 32 import android.content.Context; 33 import android.content.Intent; 34 import android.net.Uri; 35 import android.os.Bundle; 36 import android.os.LocaleList; 37 import android.permission.flags.Flags; 38 import android.platform.test.annotations.RequiresFlagsEnabled; 39 import android.text.Spannable; 40 import android.text.SpannableString; 41 import android.view.textclassifier.ConversationAction; 42 import android.view.textclassifier.ConversationActions; 43 import android.view.textclassifier.TextClassification; 44 import android.view.textclassifier.TextClassifier; 45 import android.view.textclassifier.TextLanguage; 46 import android.view.textclassifier.TextLinks; 47 import android.view.textclassifier.TextSelection; 48 import androidx.collection.LruCache; 49 import androidx.test.ext.junit.runners.AndroidJUnit4; 50 import androidx.test.filters.SdkSuppress; 51 import androidx.test.filters.SmallTest; 52 import com.android.textclassifier.common.ModelFile; 53 import com.android.textclassifier.common.ModelType; 54 import com.android.textclassifier.common.TextClassifierSettings; 55 import com.android.textclassifier.testing.FakeContextBuilder; 56 import com.android.textclassifier.testing.TestingDeviceConfig; 57 import com.android.textclassifier.utils.TextClassifierUtils; 58 59 import com.google.android.textclassifier.AnnotatorModel; 60 import com.google.common.collect.ImmutableList; 61 import java.io.IOException; 62 import java.util.ArrayList; 63 import java.util.Arrays; 64 import java.util.Collections; 65 import java.util.List; 66 import org.hamcrest.BaseMatcher; 67 import org.hamcrest.Description; 68 import org.hamcrest.Matcher; 69 import org.junit.Assume; 70 import org.junit.Before; 71 import org.junit.Test; 72 import org.junit.runner.RunWith; 73 import org.mockito.Mock; 74 import org.mockito.MockitoAnnotations; 75 76 @SmallTest 77 @RunWith(AndroidJUnit4.class) 78 public class TextClassifierImplTest { 79 80 private static final String TYPE_COPY = "copy"; 81 private static final LocaleList LOCALES = LocaleList.forLanguageTags("en-US"); 82 private static final String NO_TYPE = null; 83 84 @Mock private ModelFileManager modelFileManager; 85 86 private Context context; 87 private TestingDeviceConfig deviceConfig; 88 private TextClassifierSettings settings; 89 private LruCache<ModelFile, AnnotatorModel> annotatorModelCache; 90 private TextClassifierImpl classifier; 91 92 @Before setup()93 public void setup() throws IOException { 94 MockitoAnnotations.initMocks(this); 95 this.context = 96 new FakeContextBuilder() 97 .setAllIntentComponent(FakeContextBuilder.DEFAULT_COMPONENT) 98 .setAppLabel(FakeContextBuilder.DEFAULT_COMPONENT.getPackageName(), "Test app") 99 .build(); 100 this.deviceConfig = new TestingDeviceConfig(); 101 this.settings = new TextClassifierSettings(deviceConfig, /* isWear= */ false); 102 this.annotatorModelCache = new LruCache<>(2); 103 this.classifier = 104 new TextClassifierImpl(context, settings, modelFileManager, annotatorModelCache); 105 106 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 107 .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped()); 108 when(modelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any())) 109 .thenReturn(TestDataUtils.getLangIdModelFileWrapped()); 110 when(modelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any())) 111 .thenReturn(TestDataUtils.getTestActionsModelFileWrapped()); 112 } 113 114 @Test testSuggestSelection()115 public void testSuggestSelection() throws IOException { 116 String text = "Contact me at droid@android.com"; 117 String selected = "droid"; 118 String suggested = "droid@android.com"; 119 int startIndex = text.indexOf(selected); 120 int endIndex = startIndex + selected.length(); 121 int smartStartIndex = text.indexOf(suggested); 122 int smartEndIndex = smartStartIndex + suggested.length(); 123 TextSelection.Request request = 124 new TextSelection.Request.Builder(text, startIndex, endIndex).build(); 125 126 TextSelection selection = classifier.suggestSelection(null, null, request); 127 assertThat( 128 selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_EMAIL)); 129 } 130 131 @Test testSuggestSelection_localePreferenceIsPassedToModelFileManager()132 public void testSuggestSelection_localePreferenceIsPassedToModelFileManager() throws IOException { 133 String text = "Contact me at droid@android.com"; 134 String selected = "droid"; 135 String suggested = "droid@android.com"; 136 int startIndex = text.indexOf(selected); 137 int endIndex = startIndex + selected.length(); 138 int smartStartIndex = text.indexOf(suggested); 139 int smartEndIndex = smartStartIndex + suggested.length(); 140 TextSelection.Request request = 141 new TextSelection.Request.Builder(text, startIndex, endIndex) 142 .setDefaultLocales(LOCALES) 143 .build(); 144 145 classifier.suggestSelection(null, null, request); 146 verify(modelFileManager).findBestModelFile(eq(ModelType.ANNOTATOR), eq(LOCALES), any()); 147 } 148 149 @Test testSuggestSelection_url()150 public void testSuggestSelection_url() throws IOException { 151 String text = "Visit http://www.android.com for more information"; 152 String selected = "http"; 153 String suggested = "http://www.android.com"; 154 int startIndex = text.indexOf(selected); 155 int endIndex = startIndex + selected.length(); 156 int smartStartIndex = text.indexOf(suggested); 157 int smartEndIndex = smartStartIndex + suggested.length(); 158 TextSelection.Request request = 159 new TextSelection.Request.Builder(text, startIndex, endIndex).build(); 160 161 TextSelection selection = classifier.suggestSelection(null, null, request); 162 assertThat(selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_URL)); 163 } 164 165 @Test testSmartSelection_withEmoji()166 public void testSmartSelection_withEmoji() throws IOException { 167 String text = "\uD83D\uDE02 Hello."; 168 String selected = "Hello"; 169 int startIndex = text.indexOf(selected); 170 int endIndex = startIndex + selected.length(); 171 TextSelection.Request request = 172 new TextSelection.Request.Builder(text, startIndex, endIndex).build(); 173 174 TextSelection selection = classifier.suggestSelection(null, null, request); 175 assertThat(selection, isTextSelection(startIndex, endIndex, NO_TYPE)); 176 } 177 178 @SdkSuppress(minSdkVersion = 31, codeName = "S") 179 @Test testSuggestSelection_includeTextClassification()180 public void testSuggestSelection_includeTextClassification() throws IOException { 181 String text = "Visit http://www.android.com for more information"; 182 String suggested = "http://www.android.com"; 183 int startIndex = text.indexOf(suggested); 184 TextSelection.Request request = 185 new TextSelection.Request.Builder(text, startIndex, /* endIndex= */ startIndex + 1) 186 .setIncludeTextClassification(true) 187 .build(); 188 189 TextSelection selection = classifier.suggestSelection(null, null, request); 190 191 assertThat( 192 selection.getTextClassification(), 193 isTextClassification(suggested, TextClassifier.TYPE_URL)); 194 assertThat(selection.getTextClassification(), containsIntentWithAction(Intent.ACTION_VIEW)); 195 } 196 197 @SdkSuppress(minSdkVersion = 31, codeName = "S") 198 @Test testSuggestSelection_notIncludeTextClassification()199 public void testSuggestSelection_notIncludeTextClassification() throws IOException { 200 String text = "Visit http://www.android.com for more information"; 201 TextSelection.Request request = 202 new TextSelection.Request.Builder(text, /* startIndex= */ 0, /* endIndex= */ 4) 203 .setIncludeTextClassification(false) 204 .build(); 205 206 TextSelection selection = classifier.suggestSelection(null, null, request); 207 208 assertThat(selection.getTextClassification()).isNull(); 209 } 210 211 @Test testClassifyText()212 public void testClassifyText() throws IOException { 213 String text = "Contact me at droid@android.com"; 214 String classifiedText = "droid@android.com"; 215 int startIndex = text.indexOf(classifiedText); 216 int endIndex = startIndex + classifiedText.length(); 217 TextClassification.Request request = 218 new TextClassification.Request.Builder(text, startIndex, endIndex).build(); 219 220 TextClassification classification = 221 classifier.classifyText(/* sessionId= */ null, null, request); 222 assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_EMAIL)); 223 } 224 225 @Test testClassifyText_url()226 public void testClassifyText_url() throws IOException { 227 String text = "Visit www.android.com for more information"; 228 String classifiedText = "www.android.com"; 229 int startIndex = text.indexOf(classifiedText); 230 int endIndex = startIndex + classifiedText.length(); 231 TextClassification.Request request = 232 new TextClassification.Request.Builder(text, startIndex, endIndex).build(); 233 234 TextClassification classification = classifier.classifyText(null, null, request); 235 assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL)); 236 assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW)); 237 } 238 239 @Test testClassifyText_address()240 public void testClassifyText_address() throws IOException { 241 String text = "Brandschenkestrasse 110, Zürich, Switzerland"; 242 TextClassification.Request request = 243 new TextClassification.Request.Builder(text, 0, text.length()).build(); 244 245 TextClassification classification = classifier.classifyText(null, null, request); 246 assertThat(classification, isTextClassification(text, TextClassifier.TYPE_ADDRESS)); 247 } 248 249 @Test testClassifyText_url_inCaps()250 public void testClassifyText_url_inCaps() throws IOException { 251 String text = "Visit HTTP://ANDROID.COM for more information"; 252 String classifiedText = "HTTP://ANDROID.COM"; 253 int startIndex = text.indexOf(classifiedText); 254 int endIndex = startIndex + classifiedText.length(); 255 TextClassification.Request request = 256 new TextClassification.Request.Builder(text, startIndex, endIndex).build(); 257 258 TextClassification classification = classifier.classifyText(null, null, request); 259 assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL)); 260 assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW)); 261 } 262 263 @Test testClassifyText_date()264 public void testClassifyText_date() throws IOException { 265 String text = "Let's meet on January 9, 2018."; 266 String classifiedText = "January 9, 2018"; 267 int startIndex = text.indexOf(classifiedText); 268 int endIndex = startIndex + classifiedText.length(); 269 TextClassification.Request request = 270 new TextClassification.Request.Builder(text, startIndex, endIndex).build(); 271 272 TextClassification classification = classifier.classifyText(null, null, request); 273 assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE)); 274 Bundle extras = classification.getExtras(); 275 List<Bundle> entities = ExtrasUtils.getEntities(extras); 276 assertThat(entities).hasSize(1); 277 assertThat(ExtrasUtils.getEntityType(entities.get(0))).isEqualTo(TextClassifier.TYPE_DATE); 278 ArrayList<Intent> actionsIntents = ExtrasUtils.getActionsIntents(classification); 279 actionsIntents.forEach(TextClassifierImplTest::assertNoPackageInfoInExtras); 280 } 281 282 @Test testClassifyText_datetime()283 public void testClassifyText_datetime() throws IOException { 284 String text = "Let's meet 2018/01/01 10:30:20."; 285 String classifiedText = "2018/01/01 10:30:20"; 286 int startIndex = text.indexOf(classifiedText); 287 int endIndex = startIndex + classifiedText.length(); 288 TextClassification.Request request = 289 new TextClassification.Request.Builder(text, startIndex, endIndex).build(); 290 291 TextClassification classification = classifier.classifyText(null, null, request); 292 assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME)); 293 } 294 295 @Test testClassifyText_foreignText()296 public void testClassifyText_foreignText() throws IOException { 297 LocaleList originalLocales = LocaleList.getDefault(); 298 LocaleList.setDefault(LocaleList.forLanguageTags("en")); 299 String japaneseText = "これは日本語のテキストです"; 300 TextClassification.Request request = 301 new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length()).build(); 302 303 TextClassification classification = classifier.classifyText(null, null, request); 304 RemoteAction translateAction = classification.getActions().get(0); 305 assertEquals(1, classification.getActions().size()); 306 assertEquals(Intent.ACTION_TRANSLATE, classification.getIntent().getAction()); 307 308 assertEquals(translateAction, ExtrasUtils.findTranslateAction(classification)); 309 Intent intent = ExtrasUtils.getActionsIntents(classification).get(0); 310 assertNoPackageInfoInExtras(intent); 311 assertEquals(Intent.ACTION_TRANSLATE, intent.getAction()); 312 Bundle foreignLanguageInfo = ExtrasUtils.getForeignLanguageExtra(classification); 313 assertEquals("ja", ExtrasUtils.getEntityType(foreignLanguageInfo)); 314 assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) >= 0); 315 assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) <= 1); 316 assertTrue(intent.hasExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER)); 317 assertEquals("ja", ExtrasUtils.getTopLanguage(intent).first); 318 319 LocaleList.setDefault(originalLocales); 320 } 321 322 @Test testGenerateLinks_phone()323 public void testGenerateLinks_phone() throws IOException { 324 String text = "The number is +12122537077. See you tonight!"; 325 TextLinks.Request request = new TextLinks.Request.Builder(text).build(); 326 assertThat( 327 classifier.generateLinks(null, null, request), 328 isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE)); 329 } 330 331 @Test 332 @RequiresFlagsEnabled(Flags.FLAG_ENABLE_OTP_IN_TEXT_CLASSIFIERS) testGenerateLinks_otp()333 public void testGenerateLinks_otp() throws IOException { 334 Assume.assumeTrue(TextClassifierUtils.isOtpClassificationEnabled()); 335 String text = "Your OTP code is 123456"; 336 List<String> included = new ArrayList<>(List.of(TextClassifier.TYPE_OTP)); 337 TextClassifier.EntityConfig config = 338 new TextClassifier.EntityConfig.Builder() 339 .setIncludedTypes(included) 340 .includeTypesFromTextClassifier(false) 341 .build(); 342 TextLinks.Request request = new TextLinks.Request.Builder(text).setEntityConfig(config).build(); 343 assertThat( 344 classifier.generateLinks(null, null, request), 345 isTextLinksContaining(text, "", TextClassifier.TYPE_OTP)); 346 } 347 348 @Test 349 @RequiresFlagsEnabled(Flags.FLAG_ENABLE_OTP_IN_TEXT_CLASSIFIERS) testGenerateLinks_otpTypeNotInRequest()350 public void testGenerateLinks_otpTypeNotInRequest() throws IOException { 351 Assume.assumeTrue(TextClassifierUtils.isOtpClassificationEnabled()); 352 String text = "Your OTP code is 123456"; 353 TextLinks.Request request = new TextLinks.Request.Builder(text).build(); 354 assertThat( 355 classifier.generateLinks(null, null, request).getLinks()).isEmpty(); 356 } 357 358 @Test testGenerateLinks_exclude()359 public void testGenerateLinks_exclude() throws IOException { 360 String text = "The number is +12122537077. See you tonight!"; 361 List<String> hints = ImmutableList.of(); 362 List<String> included = ImmutableList.of(); 363 List<String> excluded = Arrays.asList(TextClassifier.TYPE_PHONE); 364 TextLinks.Request request = 365 new TextLinks.Request.Builder(text) 366 .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded)) 367 .build(); 368 assertThat( 369 classifier.generateLinks(null, null, request), 370 not(isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE))); 371 } 372 373 @Test testGenerateLinks_explicit_address()374 public void testGenerateLinks_explicit_address() throws IOException { 375 String text = "The address is 1600 Amphitheater Parkway, Mountain View, CA. See you!"; 376 List<String> explicit = Arrays.asList(TextClassifier.TYPE_ADDRESS); 377 TextLinks.Request request = 378 new TextLinks.Request.Builder(text) 379 .setEntityConfig(TextClassifier.EntityConfig.createWithExplicitEntityList(explicit)) 380 .build(); 381 assertThat( 382 classifier.generateLinks(null, null, request), 383 isTextLinksContaining( 384 text, "1600 Amphitheater Parkway, Mountain View, CA", TextClassifier.TYPE_ADDRESS)); 385 } 386 387 @Test testGenerateLinks_exclude_override()388 public void testGenerateLinks_exclude_override() throws IOException { 389 String text = "You want apple@banana.com. See you tonight!"; 390 List<String> hints = ImmutableList.of(); 391 List<String> included = Arrays.asList(TextClassifier.TYPE_EMAIL); 392 List<String> excluded = Arrays.asList(TextClassifier.TYPE_EMAIL); 393 TextLinks.Request request = 394 new TextLinks.Request.Builder(text) 395 .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded)) 396 .build(); 397 assertThat( 398 classifier.generateLinks(null, null, request), 399 not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL))); 400 } 401 402 @Test testGenerateLinks_maxLength()403 public void testGenerateLinks_maxLength() throws IOException { 404 char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength()]; 405 Arrays.fill(manySpaces, ' '); 406 TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build(); 407 TextLinks links = classifier.generateLinks(null, null, request); 408 assertTrue(links.getLinks().isEmpty()); 409 } 410 411 @Test testApplyLinks_unsupportedCharacter()412 public void testApplyLinks_unsupportedCharacter() throws IOException { 413 Spannable url = new SpannableString("\u202Emoc.diordna.com"); 414 TextLinks.Request request = new TextLinks.Request.Builder(url).build(); 415 assertEquals( 416 TextLinks.STATUS_UNSUPPORTED_CHARACTER, 417 classifier.generateLinks(null, null, request).apply(url, 0, null)); 418 } 419 420 @Test testGenerateLinks_tooLong()421 public void testGenerateLinks_tooLong() { 422 char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength() + 1]; 423 Arrays.fill(manySpaces, ' '); 424 TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build(); 425 expectThrows( 426 IllegalArgumentException.class, () -> classifier.generateLinks(null, null, request)); 427 } 428 429 @Test testGenerateLinks_entityData()430 public void testGenerateLinks_entityData() throws IOException { 431 String text = "The number is +12122537077."; 432 Bundle extras = new Bundle(); 433 ExtrasUtils.putIsSerializedEntityDataEnabled(extras, true); 434 TextLinks.Request request = new TextLinks.Request.Builder(text).setExtras(extras).build(); 435 436 TextLinks textLinks = classifier.generateLinks(null, null, request); 437 438 assertThat(textLinks.getLinks()).hasSize(1); 439 TextLinks.TextLink textLink = textLinks.getLinks().iterator().next(); 440 List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras()); 441 assertThat(entities).hasSize(1); 442 Bundle entity = entities.get(0); 443 assertThat(ExtrasUtils.getEntityType(entity)).isEqualTo(TextClassifier.TYPE_PHONE); 444 } 445 446 @Test testGenerateLinks_entityData_disabled()447 public void testGenerateLinks_entityData_disabled() throws IOException { 448 String text = "The number is +12122537077."; 449 TextLinks.Request request = new TextLinks.Request.Builder(text).build(); 450 451 TextLinks textLinks = classifier.generateLinks(null, null, request); 452 453 assertThat(textLinks.getLinks()).hasSize(1); 454 TextLinks.TextLink textLink = textLinks.getLinks().iterator().next(); 455 List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras()); 456 assertThat(entities).isNull(); 457 } 458 459 @Test testDetectLanguage()460 public void testDetectLanguage() throws IOException { 461 String text = "This is English text"; 462 TextLanguage.Request request = new TextLanguage.Request.Builder(text).build(); 463 TextLanguage textLanguage = classifier.detectLanguage(null, null, request); 464 assertThat(textLanguage, isTextLanguage("en")); 465 } 466 467 @Test testDetectLanguage_japanese()468 public void testDetectLanguage_japanese() throws IOException { 469 String text = "これは日本語のテキストです"; 470 TextLanguage.Request request = new TextLanguage.Request.Builder(text).build(); 471 TextLanguage textLanguage = classifier.detectLanguage(null, null, request); 472 assertThat(textLanguage, isTextLanguage("ja")); 473 } 474 475 @Test testSuggestConversationActions_textReplyOnly_maxOne()476 public void testSuggestConversationActions_textReplyOnly_maxOne() throws IOException { 477 ConversationActions.Message message = 478 new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS) 479 .setText("Where are you?") 480 .build(); 481 TextClassifier.EntityConfig typeConfig = 482 new TextClassifier.EntityConfig.Builder() 483 .includeTypesFromTextClassifier(false) 484 .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY)) 485 .build(); 486 ConversationActions.Request request = 487 new ConversationActions.Request.Builder(Collections.singletonList(message)) 488 .setMaxSuggestions(1) 489 .setTypeConfig(typeConfig) 490 .build(); 491 492 ConversationActions conversationActions = 493 classifier.suggestConversationActions(null, null, request); 494 assertThat(conversationActions.getConversationActions()).hasSize(1); 495 ConversationAction conversationAction = conversationActions.getConversationActions().get(0); 496 assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_TEXT_REPLY); 497 assertThat(conversationAction.getTextReply()).isNotNull(); 498 } 499 500 @Test testSuggestConversationActions_textReplyOnly_noMax()501 public void testSuggestConversationActions_textReplyOnly_noMax() throws IOException { 502 ConversationActions.Message message = 503 new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS) 504 .setText("Where are you?") 505 .build(); 506 TextClassifier.EntityConfig typeConfig = 507 new TextClassifier.EntityConfig.Builder() 508 .includeTypesFromTextClassifier(false) 509 .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY)) 510 .build(); 511 ConversationActions.Request request = 512 new ConversationActions.Request.Builder(Collections.singletonList(message)) 513 .setTypeConfig(typeConfig) 514 .build(); 515 516 ConversationActions conversationActions = 517 classifier.suggestConversationActions(null, null, request); 518 assertTrue(conversationActions.getConversationActions().size() > 1); 519 for (ConversationAction conversationAction : conversationActions.getConversationActions()) { 520 assertThat(conversationAction, isConversationAction(ConversationAction.TYPE_TEXT_REPLY)); 521 } 522 } 523 524 @Test testSuggestConversationActions_openUrl()525 public void testSuggestConversationActions_openUrl() throws IOException { 526 ConversationActions.Message message = 527 new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS) 528 .setText("Check this out: https://www.android.com") 529 .build(); 530 TextClassifier.EntityConfig typeConfig = 531 new TextClassifier.EntityConfig.Builder() 532 .includeTypesFromTextClassifier(false) 533 .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_OPEN_URL)) 534 .build(); 535 ConversationActions.Request request = 536 new ConversationActions.Request.Builder(Collections.singletonList(message)) 537 .setMaxSuggestions(1) 538 .setTypeConfig(typeConfig) 539 .build(); 540 541 ConversationActions conversationActions = 542 classifier.suggestConversationActions(null, null, request); 543 assertThat(conversationActions.getConversationActions()).hasSize(1); 544 ConversationAction conversationAction = conversationActions.getConversationActions().get(0); 545 assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL); 546 Intent actionIntent = ExtrasUtils.getActionIntent(conversationAction.getExtras()); 547 assertThat(actionIntent.getAction()).isEqualTo(Intent.ACTION_VIEW); 548 assertThat(actionIntent.getData()).isEqualTo(Uri.parse("https://www.android.com")); 549 assertNoPackageInfoInExtras(actionIntent); 550 } 551 552 @Test testSuggestConversationActions_copy()553 public void testSuggestConversationActions_copy() throws IOException { 554 ConversationActions.Message message = 555 new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS) 556 .setText("Authentication code: 12345") 557 .build(); 558 TextClassifier.EntityConfig typeConfig = 559 new TextClassifier.EntityConfig.Builder() 560 .includeTypesFromTextClassifier(false) 561 .setIncludedTypes(Collections.singletonList(TYPE_COPY)) 562 .build(); 563 ConversationActions.Request request = 564 new ConversationActions.Request.Builder(Collections.singletonList(message)) 565 .setMaxSuggestions(1) 566 .setTypeConfig(typeConfig) 567 .build(); 568 569 ConversationActions conversationActions = 570 classifier.suggestConversationActions(null, null, request); 571 assertThat(conversationActions.getConversationActions()).hasSize(1); 572 ConversationAction conversationAction = conversationActions.getConversationActions().get(0); 573 assertThat(conversationAction.getType()).isEqualTo(TYPE_COPY); 574 assertThat(conversationAction.getTextReply()).isAnyOf(null, ""); 575 assertThat(conversationAction.getAction()).isNull(); 576 String code = ExtrasUtils.getCopyText(conversationAction.getExtras()); 577 assertThat(code).isEqualTo("12345"); 578 assertThat(ExtrasUtils.getSerializedEntityData(conversationAction.getExtras())).isNotEmpty(); 579 } 580 581 @Test testSuggestConversationActions_deduplicate()582 public void testSuggestConversationActions_deduplicate() throws IOException { 583 ConversationActions.Message message = 584 new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS) 585 .setText("a@android.com b@android.com") 586 .build(); 587 ConversationActions.Request request = 588 new ConversationActions.Request.Builder(Collections.singletonList(message)) 589 .setMaxSuggestions(3) 590 .build(); 591 592 ConversationActions conversationActions = 593 classifier.suggestConversationActions(null, null, request); 594 595 assertThat(conversationActions.getConversationActions()).isEmpty(); 596 } 597 598 @Test testUseCachedAnnotatorModelDisabled()599 public void testUseCachedAnnotatorModelDisabled() throws IOException { 600 deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true); 601 602 String annotatorFilePath = TestDataUtils.getTestAnnotatorModelFile().getPath(); 603 ModelFile annotatorModelA = 604 new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false); 605 ModelFile annotatorModelB = 606 new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false); 607 608 String englishText = "You can reach me on +12122537077."; 609 String classifiedText = "+12122537077"; 610 TextClassification.Request request = 611 new TextClassification.Request.Builder(englishText, 0, englishText.length()).build(); 612 613 // Check modelFileA v701 614 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 615 .thenReturn(annotatorModelA); 616 TextClassification classificationA = classifier.classifyText(null, null, request); 617 618 assertThat(classificationA.getId()).contains("v701"); 619 assertThat(classificationA.getText()).contains(classifiedText); 620 assertArrayEquals( 621 new int[] {0, 0, 0, 0}, 622 new int[] { 623 annotatorModelCache.putCount(), 624 annotatorModelCache.evictionCount(), 625 annotatorModelCache.hitCount(), 626 annotatorModelCache.missCount() 627 }); 628 629 // Check modelFileB v801 630 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 631 .thenReturn(annotatorModelB); 632 TextClassification classificationB = classifier.classifyText(null, null, request); 633 634 assertThat(classificationB.getId()).contains("v801"); 635 assertThat(classificationB.getText()).contains(classifiedText); 636 assertArrayEquals( 637 new int[] {0, 0, 0, 0}, 638 new int[] { 639 annotatorModelCache.putCount(), 640 annotatorModelCache.evictionCount(), 641 annotatorModelCache.hitCount(), 642 annotatorModelCache.missCount() 643 }); 644 645 // Reload modelFileA v701 646 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 647 .thenReturn(annotatorModelA); 648 TextClassification classificationAcached = classifier.classifyText(null, null, request); 649 650 assertThat(classificationAcached.getId()).contains("v701"); 651 assertThat(classificationAcached.getText()).contains(classifiedText); 652 assertArrayEquals( 653 new int[] {0, 0, 0, 0}, 654 new int[] { 655 annotatorModelCache.putCount(), 656 annotatorModelCache.evictionCount(), 657 annotatorModelCache.hitCount(), 658 annotatorModelCache.missCount() 659 }); 660 } 661 662 @Test testUseCachedAnnotatorModelEnabled()663 public void testUseCachedAnnotatorModelEnabled() throws IOException { 664 deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true); 665 deviceConfig.setConfig(TextClassifierSettings.MULTI_ANNOTATOR_CACHE_ENABLED, true); 666 667 String annotatorFilePath = TestDataUtils.getTestAnnotatorModelFile().getPath(); 668 ModelFile annotatorModelA = 669 new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false); 670 ModelFile annotatorModelB = 671 new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false); 672 673 String englishText = "You can reach me on +12122537077."; 674 String classifiedText = "+12122537077"; 675 TextClassification.Request request = 676 new TextClassification.Request.Builder(englishText, 0, englishText.length()).build(); 677 678 // Check modelFileA v701 679 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 680 .thenReturn(annotatorModelA); 681 TextClassification classification = classifier.classifyText(null, null, request); 682 683 assertThat(classification.getId()).contains("v701"); 684 assertThat(classification.getText()).contains(classifiedText); 685 assertArrayEquals( 686 new int[] {1, 0, 0, 1}, 687 new int[] { 688 annotatorModelCache.putCount(), 689 annotatorModelCache.evictionCount(), 690 annotatorModelCache.hitCount(), 691 annotatorModelCache.missCount() 692 }); 693 694 // Check modelFileB v801 695 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 696 .thenReturn(annotatorModelB); 697 TextClassification classificationB = classifier.classifyText(null, null, request); 698 699 assertThat(classificationB.getId()).contains("v801"); 700 assertThat(classificationB.getText()).contains(classifiedText); 701 assertArrayEquals( 702 new int[] {2, 0, 0, 2}, 703 new int[] { 704 annotatorModelCache.putCount(), 705 annotatorModelCache.evictionCount(), 706 annotatorModelCache.hitCount(), 707 annotatorModelCache.missCount() 708 }); 709 710 // Reload modelFileA v701 711 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 712 .thenReturn(annotatorModelA); 713 TextClassification classificationAcached = classifier.classifyText(null, null, request); 714 715 assertThat(classificationAcached.getId()).contains("v701"); 716 assertThat(classificationAcached.getText()).contains(classifiedText); 717 assertArrayEquals( 718 new int[] {2, 0, 1, 2}, 719 new int[] { 720 annotatorModelCache.putCount(), 721 annotatorModelCache.evictionCount(), 722 annotatorModelCache.hitCount(), 723 annotatorModelCache.missCount() 724 }); 725 } 726 assertNoPackageInfoInExtras(Intent intent)727 private static void assertNoPackageInfoInExtras(Intent intent) { 728 assertThat(intent.getComponent()).isNull(); 729 assertThat(intent.getPackage()).isNull(); 730 } 731 isTextSelection( final int startIndex, final int endIndex, final String type)732 private static Matcher<TextSelection> isTextSelection( 733 final int startIndex, final int endIndex, final String type) { 734 return new BaseMatcher<TextSelection>() { 735 @Override 736 public boolean matches(Object o) { 737 if (o instanceof TextSelection) { 738 TextSelection selection = (TextSelection) o; 739 return startIndex == selection.getSelectionStartIndex() 740 && endIndex == selection.getSelectionEndIndex() 741 && typeMatches(selection, type); 742 } 743 return false; 744 } 745 746 private boolean typeMatches(TextSelection selection, String type) { 747 return type == null 748 || (selection.getEntityCount() > 0 749 && type.trim().equalsIgnoreCase(selection.getEntity(0))); 750 } 751 752 @Override 753 public void describeTo(Description description) { 754 description.appendValue(String.format("%d, %d, %s", startIndex, endIndex, type)); 755 } 756 }; 757 } 758 759 private static Matcher<TextLinks> isTextLinksContaining( 760 final String text, final String substring, final String type) { 761 return new BaseMatcher<TextLinks>() { 762 763 @Override 764 public void describeTo(Description description) { 765 description 766 .appendText("text=") 767 .appendValue(text) 768 .appendText(", substring=") 769 .appendValue(substring) 770 .appendText(", type=") 771 .appendValue(type); 772 } 773 774 @Override 775 public boolean matches(Object o) { 776 if (o instanceof TextLinks) { 777 for (TextLinks.TextLink link : ((TextLinks) o).getLinks()) { 778 if (text.subSequence(link.getStart(), link.getEnd()).toString().equals(substring)) { 779 return type.equals(link.getEntity(0)); 780 } 781 } 782 } 783 return false; 784 } 785 }; 786 } 787 788 private static Matcher<TextClassification> isTextClassification( 789 final String text, final String type) { 790 return new BaseMatcher<TextClassification>() { 791 @Override 792 public boolean matches(Object o) { 793 if (o instanceof TextClassification) { 794 TextClassification result = (TextClassification) o; 795 return text.equals(result.getText()) 796 && result.getEntityCount() > 0 797 && type.equals(result.getEntity(0)); 798 } 799 return false; 800 } 801 802 @Override 803 public void describeTo(Description description) { 804 description.appendText("text=").appendValue(text).appendText(", type=").appendValue(type); 805 } 806 }; 807 } 808 809 private static Matcher<TextClassification> containsIntentWithAction(final String action) { 810 return new BaseMatcher<TextClassification>() { 811 @Override 812 public boolean matches(Object o) { 813 if (o instanceof TextClassification) { 814 TextClassification result = (TextClassification) o; 815 return ExtrasUtils.findAction(result, action) != null; 816 } 817 return false; 818 } 819 820 @Override 821 public void describeTo(Description description) { 822 description.appendText("intent action=").appendValue(action); 823 } 824 }; 825 } 826 827 private static Matcher<TextLanguage> isTextLanguage(final String languageTag) { 828 return new BaseMatcher<TextLanguage>() { 829 @Override 830 public boolean matches(Object o) { 831 if (o instanceof TextLanguage) { 832 TextLanguage result = (TextLanguage) o; 833 return result.getLocaleHypothesisCount() > 0 834 && languageTag.equals(result.getLocale(0).toLanguageTag()); 835 } 836 return false; 837 } 838 839 @Override 840 public void describeTo(Description description) { 841 description.appendText("locale=").appendValue(languageTag); 842 } 843 }; 844 } 845 846 private static Matcher<ConversationAction> isConversationAction(String actionType) { 847 return new BaseMatcher<ConversationAction>() { 848 @Override 849 public boolean matches(Object o) { 850 if (!(o instanceof ConversationAction)) { 851 return false; 852 } 853 ConversationAction conversationAction = (ConversationAction) o; 854 if (!actionType.equals(conversationAction.getType())) { 855 return false; 856 } 857 if (ConversationAction.TYPE_TEXT_REPLY.equals(actionType)) { 858 if (conversationAction.getTextReply() == null) { 859 return false; 860 } 861 } 862 if (conversationAction.getConfidenceScore() < 0 863 || conversationAction.getConfidenceScore() > 1) { 864 return false; 865 } 866 return true; 867 } 868 869 @Override 870 public void describeTo(Description description) { 871 description.appendText("actionType=").appendValue(actionType); 872 } 873 }; 874 } 875 } 876