/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.android.textclassifier; import static com.google.common.truth.Truth.assertThat; import static org.hamcrest.CoreMatchers.not; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.any; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.testng.Assert.expectThrows; import android.app.RemoteAction; import android.content.Context; import android.content.Intent; import android.net.Uri; import android.os.Bundle; import android.os.LocaleList; import android.text.Spannable; import android.text.SpannableString; import android.view.textclassifier.ConversationAction; import android.view.textclassifier.ConversationActions; import android.view.textclassifier.TextClassification; import android.view.textclassifier.TextClassifier; import android.view.textclassifier.TextLanguage; import android.view.textclassifier.TextLinks; import android.view.textclassifier.TextSelection; import androidx.collection.LruCache; import androidx.test.core.app.ApplicationProvider; import androidx.test.ext.junit.runners.AndroidJUnit4; import androidx.test.filters.SdkSuppress; import androidx.test.filters.SmallTest; import com.android.textclassifier.common.ModelFile; import com.android.textclassifier.common.ModelType; import com.android.textclassifier.common.TextClassifierSettings; import com.android.textclassifier.testing.FakeContextBuilder; import com.android.textclassifier.testing.TestingDeviceConfig; import com.google.android.textclassifier.AnnotatorModel; import com.google.common.collect.ImmutableList; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import org.hamcrest.BaseMatcher; import org.hamcrest.Description; import org.hamcrest.Matcher; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @SmallTest @RunWith(AndroidJUnit4.class) public class TextClassifierImplTest { private static final String TYPE_COPY = "copy"; private static final LocaleList LOCALES = LocaleList.forLanguageTags("en-US"); private static final String NO_TYPE = null; @Mock private ModelFileManager modelFileManager; private Context context; private TestingDeviceConfig deviceConfig; private TextClassifierSettings settings; private LruCache annotatorModelCache; private TextClassifierImpl classifier; @Before public void setup() throws IOException { MockitoAnnotations.initMocks(this); this.context = new FakeContextBuilder() .setAllIntentComponent(FakeContextBuilder.DEFAULT_COMPONENT) .setAppLabel(FakeContextBuilder.DEFAULT_COMPONENT.getPackageName(), "Test app") .build(); this.deviceConfig = new TestingDeviceConfig(); this.settings = new TextClassifierSettings(deviceConfig); this.annotatorModelCache = new LruCache<>(2); this.classifier = new TextClassifierImpl(context, settings, modelFileManager, annotatorModelCache); when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped()); when(modelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any())) .thenReturn(TestDataUtils.getLangIdModelFileWrapped()); when(modelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any())) .thenReturn(TestDataUtils.getTestActionsModelFileWrapped()); } @Test public void testSuggestSelection() throws IOException { String text = "Contact me at droid@android.com"; String selected = "droid"; String suggested = "droid@android.com"; int startIndex = text.indexOf(selected); int endIndex = startIndex + selected.length(); int smartStartIndex = text.indexOf(suggested); int smartEndIndex = smartStartIndex + suggested.length(); TextSelection.Request request = new TextSelection.Request.Builder(text, startIndex, endIndex).build(); TextSelection selection = classifier.suggestSelection(null, null, request); assertThat( selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_EMAIL)); } @Test public void testSuggestSelection_localePreferenceIsPassedToModelFileManager() throws IOException { String text = "Contact me at droid@android.com"; String selected = "droid"; String suggested = "droid@android.com"; int startIndex = text.indexOf(selected); int endIndex = startIndex + selected.length(); int smartStartIndex = text.indexOf(suggested); int smartEndIndex = smartStartIndex + suggested.length(); TextSelection.Request request = new TextSelection.Request.Builder(text, startIndex, endIndex) .setDefaultLocales(LOCALES) .build(); classifier.suggestSelection(null, null, request); verify(modelFileManager).findBestModelFile(eq(ModelType.ANNOTATOR), eq(LOCALES), any()); } @Test public void testSuggestSelection_url() throws IOException { String text = "Visit http://www.android.com for more information"; String selected = "http"; String suggested = "http://www.android.com"; int startIndex = text.indexOf(selected); int endIndex = startIndex + selected.length(); int smartStartIndex = text.indexOf(suggested); int smartEndIndex = smartStartIndex + suggested.length(); TextSelection.Request request = new TextSelection.Request.Builder(text, startIndex, endIndex).build(); TextSelection selection = classifier.suggestSelection(null, null, request); assertThat(selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_URL)); } @Test public void testSmartSelection_withEmoji() throws IOException { String text = "\uD83D\uDE02 Hello."; String selected = "Hello"; int startIndex = text.indexOf(selected); int endIndex = startIndex + selected.length(); TextSelection.Request request = new TextSelection.Request.Builder(text, startIndex, endIndex).build(); TextSelection selection = classifier.suggestSelection(null, null, request); assertThat(selection, isTextSelection(startIndex, endIndex, NO_TYPE)); } @SdkSuppress(minSdkVersion = 31, codeName = "S") @Test public void testSuggestSelection_includeTextClassification() throws IOException { String text = "Visit http://www.android.com for more information"; String suggested = "http://www.android.com"; int startIndex = text.indexOf(suggested); TextSelection.Request request = new TextSelection.Request.Builder(text, startIndex, /*endIndex=*/ startIndex + 1) .setIncludeTextClassification(true) .build(); TextSelection selection = classifier.suggestSelection(null, null, request); assertThat( selection.getTextClassification(), isTextClassification(suggested, TextClassifier.TYPE_URL)); assertThat(selection.getTextClassification(), containsIntentWithAction(Intent.ACTION_VIEW)); } @SdkSuppress(minSdkVersion = 31, codeName = "S") @Test public void testSuggestSelection_notIncludeTextClassification() throws IOException { String text = "Visit http://www.android.com for more information"; TextSelection.Request request = new TextSelection.Request.Builder(text, /*startIndex=*/ 0, /*endIndex=*/ 4) .setIncludeTextClassification(false) .build(); TextSelection selection = classifier.suggestSelection(null, null, request); assertThat(selection.getTextClassification()).isNull(); } @Test public void testClassifyText() throws IOException { String text = "Contact me at droid@android.com"; String classifiedText = "droid@android.com"; int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = new TextClassification.Request.Builder(text, startIndex, endIndex).build(); TextClassification classification = classifier.classifyText(/* sessionId= */ null, null, request); assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_EMAIL)); } @Test public void testClassifyText_url() throws IOException { String text = "Visit www.android.com for more information"; String classifiedText = "www.android.com"; int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = new TextClassification.Request.Builder(text, startIndex, endIndex).build(); TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL)); assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW)); } @Test public void testClassifyText_address() throws IOException { String text = "Brandschenkestrasse 110, Zürich, Switzerland"; TextClassification.Request request = new TextClassification.Request.Builder(text, 0, text.length()).build(); TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(text, TextClassifier.TYPE_ADDRESS)); } @Test public void testClassifyText_url_inCaps() throws IOException { String text = "Visit HTTP://ANDROID.COM for more information"; String classifiedText = "HTTP://ANDROID.COM"; int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = new TextClassification.Request.Builder(text, startIndex, endIndex).build(); TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL)); assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW)); } @Test public void testClassifyText_date() throws IOException { String text = "Let's meet on January 9, 2018."; String classifiedText = "January 9, 2018"; int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = new TextClassification.Request.Builder(text, startIndex, endIndex).build(); TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE)); Bundle extras = classification.getExtras(); List entities = ExtrasUtils.getEntities(extras); assertThat(entities).hasSize(1); assertThat(ExtrasUtils.getEntityType(entities.get(0))).isEqualTo(TextClassifier.TYPE_DATE); ArrayList actionsIntents = ExtrasUtils.getActionsIntents(classification); actionsIntents.forEach(TextClassifierImplTest::assertNoPackageInfoInExtras); } @Test public void testClassifyText_datetime() throws IOException { String text = "Let's meet 2018/01/01 10:30:20."; String classifiedText = "2018/01/01 10:30:20"; int startIndex = text.indexOf(classifiedText); int endIndex = startIndex + classifiedText.length(); TextClassification.Request request = new TextClassification.Request.Builder(text, startIndex, endIndex).build(); TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME)); } @Test public void testClassifyText_foreignText() throws IOException { LocaleList originalLocales = LocaleList.getDefault(); LocaleList.setDefault(LocaleList.forLanguageTags("en")); String japaneseText = "これは日本語のテキストです"; TextClassification.Request request = new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length()).build(); TextClassification classification = classifier.classifyText(null, null, request); RemoteAction translateAction = classification.getActions().get(0); assertEquals(1, classification.getActions().size()); assertEquals(Intent.ACTION_TRANSLATE, classification.getIntent().getAction()); assertEquals(translateAction, ExtrasUtils.findTranslateAction(classification)); Intent intent = ExtrasUtils.getActionsIntents(classification).get(0); assertNoPackageInfoInExtras(intent); assertEquals(Intent.ACTION_TRANSLATE, intent.getAction()); Bundle foreignLanguageInfo = ExtrasUtils.getForeignLanguageExtra(classification); assertEquals("ja", ExtrasUtils.getEntityType(foreignLanguageInfo)); assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) >= 0); assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) <= 1); assertTrue(intent.hasExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER)); assertEquals("ja", ExtrasUtils.getTopLanguage(intent).first); LocaleList.setDefault(originalLocales); } @Test public void testGenerateLinks_phone() throws IOException { String text = "The number is +12122537077. See you tonight!"; TextLinks.Request request = new TextLinks.Request.Builder(text).build(); assertThat( classifier.generateLinks(null, null, request), isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE)); } @Test public void testGenerateLinks_exclude() throws IOException { String text = "The number is +12122537077. See you tonight!"; List hints = ImmutableList.of(); List included = ImmutableList.of(); List excluded = Arrays.asList(TextClassifier.TYPE_PHONE); TextLinks.Request request = new TextLinks.Request.Builder(text) .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded)) .build(); assertThat( classifier.generateLinks(null, null, request), not(isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE))); } @Test public void testGenerateLinks_explicit_address() throws IOException { String text = "The address is 1600 Amphitheater Parkway, Mountain View, CA. See you!"; List explicit = Arrays.asList(TextClassifier.TYPE_ADDRESS); TextLinks.Request request = new TextLinks.Request.Builder(text) .setEntityConfig(TextClassifier.EntityConfig.createWithExplicitEntityList(explicit)) .build(); assertThat( classifier.generateLinks(null, null, request), isTextLinksContaining( text, "1600 Amphitheater Parkway, Mountain View, CA", TextClassifier.TYPE_ADDRESS)); } @Test public void testGenerateLinks_exclude_override() throws IOException { String text = "You want apple@banana.com. See you tonight!"; List hints = ImmutableList.of(); List included = Arrays.asList(TextClassifier.TYPE_EMAIL); List excluded = Arrays.asList(TextClassifier.TYPE_EMAIL); TextLinks.Request request = new TextLinks.Request.Builder(text) .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded)) .build(); assertThat( classifier.generateLinks(null, null, request), not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL))); } @Test public void testGenerateLinks_maxLength() throws IOException { char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength()]; Arrays.fill(manySpaces, ' '); TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build(); TextLinks links = classifier.generateLinks(null, null, request); assertTrue(links.getLinks().isEmpty()); } @Test public void testApplyLinks_unsupportedCharacter() throws IOException { Spannable url = new SpannableString("\u202Emoc.diordna.com"); TextLinks.Request request = new TextLinks.Request.Builder(url).build(); assertEquals( TextLinks.STATUS_UNSUPPORTED_CHARACTER, classifier.generateLinks(null, null, request).apply(url, 0, null)); } @Test public void testGenerateLinks_tooLong() { char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength() + 1]; Arrays.fill(manySpaces, ' '); TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build(); expectThrows( IllegalArgumentException.class, () -> classifier.generateLinks(null, null, request)); } @Test public void testGenerateLinks_entityData() throws IOException { String text = "The number is +12122537077."; Bundle extras = new Bundle(); ExtrasUtils.putIsSerializedEntityDataEnabled(extras, true); TextLinks.Request request = new TextLinks.Request.Builder(text).setExtras(extras).build(); TextLinks textLinks = classifier.generateLinks(null, null, request); assertThat(textLinks.getLinks()).hasSize(1); TextLinks.TextLink textLink = textLinks.getLinks().iterator().next(); List entities = ExtrasUtils.getEntities(textLink.getExtras()); assertThat(entities).hasSize(1); Bundle entity = entities.get(0); assertThat(ExtrasUtils.getEntityType(entity)).isEqualTo(TextClassifier.TYPE_PHONE); } @Test public void testGenerateLinks_entityData_disabled() throws IOException { String text = "The number is +12122537077."; TextLinks.Request request = new TextLinks.Request.Builder(text).build(); TextLinks textLinks = classifier.generateLinks(null, null, request); assertThat(textLinks.getLinks()).hasSize(1); TextLinks.TextLink textLink = textLinks.getLinks().iterator().next(); List entities = ExtrasUtils.getEntities(textLink.getExtras()); assertThat(entities).isNull(); } @Test public void testDetectLanguage() throws IOException { String text = "This is English text"; TextLanguage.Request request = new TextLanguage.Request.Builder(text).build(); TextLanguage textLanguage = classifier.detectLanguage(null, null, request); assertThat(textLanguage, isTextLanguage("en")); } @Test public void testDetectLanguage_japanese() throws IOException { String text = "これは日本語のテキストです"; TextLanguage.Request request = new TextLanguage.Request.Builder(text).build(); TextLanguage textLanguage = classifier.detectLanguage(null, null, request); assertThat(textLanguage, isTextLanguage("ja")); } @Test public void testSuggestConversationActions_textReplyOnly_maxOne() throws IOException { ConversationActions.Message message = new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS) .setText("Where are you?") .build(); TextClassifier.EntityConfig typeConfig = new TextClassifier.EntityConfig.Builder() .includeTypesFromTextClassifier(false) .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY)) .build(); ConversationActions.Request request = new ConversationActions.Request.Builder(Collections.singletonList(message)) .setMaxSuggestions(1) .setTypeConfig(typeConfig) .build(); ConversationActions conversationActions = classifier.suggestConversationActions(null, null, request); assertThat(conversationActions.getConversationActions()).hasSize(1); ConversationAction conversationAction = conversationActions.getConversationActions().get(0); assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_TEXT_REPLY); assertThat(conversationAction.getTextReply()).isNotNull(); } @Test public void testSuggestConversationActions_textReplyOnly_noMax() throws IOException { ConversationActions.Message message = new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS) .setText("Where are you?") .build(); TextClassifier.EntityConfig typeConfig = new TextClassifier.EntityConfig.Builder() .includeTypesFromTextClassifier(false) .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY)) .build(); ConversationActions.Request request = new ConversationActions.Request.Builder(Collections.singletonList(message)) .setTypeConfig(typeConfig) .build(); ConversationActions conversationActions = classifier.suggestConversationActions(null, null, request); assertTrue(conversationActions.getConversationActions().size() > 1); for (ConversationAction conversationAction : conversationActions.getConversationActions()) { assertThat(conversationAction, isConversationAction(ConversationAction.TYPE_TEXT_REPLY)); } } @Test public void testSuggestConversationActions_openUrl() throws IOException { ConversationActions.Message message = new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS) .setText("Check this out: https://www.android.com") .build(); TextClassifier.EntityConfig typeConfig = new TextClassifier.EntityConfig.Builder() .includeTypesFromTextClassifier(false) .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_OPEN_URL)) .build(); ConversationActions.Request request = new ConversationActions.Request.Builder(Collections.singletonList(message)) .setMaxSuggestions(1) .setTypeConfig(typeConfig) .build(); ConversationActions conversationActions = classifier.suggestConversationActions(null, null, request); assertThat(conversationActions.getConversationActions()).hasSize(1); ConversationAction conversationAction = conversationActions.getConversationActions().get(0); assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL); Intent actionIntent = ExtrasUtils.getActionIntent(conversationAction.getExtras()); assertThat(actionIntent.getAction()).isEqualTo(Intent.ACTION_VIEW); assertThat(actionIntent.getData()).isEqualTo(Uri.parse("https://www.android.com")); assertNoPackageInfoInExtras(actionIntent); } @Test public void testSuggestConversationActions_copy() throws IOException { ConversationActions.Message message = new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS) .setText("Authentication code: 12345") .build(); TextClassifier.EntityConfig typeConfig = new TextClassifier.EntityConfig.Builder() .includeTypesFromTextClassifier(false) .setIncludedTypes(Collections.singletonList(TYPE_COPY)) .build(); ConversationActions.Request request = new ConversationActions.Request.Builder(Collections.singletonList(message)) .setMaxSuggestions(1) .setTypeConfig(typeConfig) .build(); ConversationActions conversationActions = classifier.suggestConversationActions(null, null, request); assertThat(conversationActions.getConversationActions()).hasSize(1); ConversationAction conversationAction = conversationActions.getConversationActions().get(0); assertThat(conversationAction.getType()).isEqualTo(TYPE_COPY); assertThat(conversationAction.getTextReply()).isAnyOf(null, ""); assertThat(conversationAction.getAction()).isNull(); String code = ExtrasUtils.getCopyText(conversationAction.getExtras()); assertThat(code).isEqualTo("12345"); assertThat(ExtrasUtils.getSerializedEntityData(conversationAction.getExtras())).isNotEmpty(); } @Test public void testSuggestConversationActions_deduplicate() throws IOException { ConversationActions.Message message = new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS) .setText("a@android.com b@android.com") .build(); ConversationActions.Request request = new ConversationActions.Request.Builder(Collections.singletonList(message)) .setMaxSuggestions(3) .build(); ConversationActions conversationActions = classifier.suggestConversationActions(null, null, request); assertThat(conversationActions.getConversationActions()).isEmpty(); } @Test public void testUseCachedAnnotatorModelDisabled() throws IOException { deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true); String annotatorFilePath = TestDataUtils.getTestAnnotatorModelFile().getPath(); ModelFile annotatorModelA = new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false); ModelFile annotatorModelB = new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false); String englishText = "You can reach me on +12122537077."; String classifiedText = "+12122537077"; TextClassification.Request request = new TextClassification.Request.Builder(englishText, 0, englishText.length()).build(); // Check modelFileA v701 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) .thenReturn(annotatorModelA); TextClassification classificationA = classifier.classifyText(null, null, request); assertThat(classificationA.getId()).contains("v701"); assertThat(classificationA.getText()).contains(classifiedText); assertArrayEquals( new int[] {0, 0, 0, 0}, new int[] { annotatorModelCache.putCount(), annotatorModelCache.evictionCount(), annotatorModelCache.hitCount(), annotatorModelCache.missCount() }); // Check modelFileB v801 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) .thenReturn(annotatorModelB); TextClassification classificationB = classifier.classifyText(null, null, request); assertThat(classificationB.getId()).contains("v801"); assertThat(classificationB.getText()).contains(classifiedText); assertArrayEquals( new int[] {0, 0, 0, 0}, new int[] { annotatorModelCache.putCount(), annotatorModelCache.evictionCount(), annotatorModelCache.hitCount(), annotatorModelCache.missCount() }); // Reload modelFileA v701 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) .thenReturn(annotatorModelA); TextClassification classificationAcached = classifier.classifyText(null, null, request); assertThat(classificationAcached.getId()).contains("v701"); assertThat(classificationAcached.getText()).contains(classifiedText); assertArrayEquals( new int[] {0, 0, 0, 0}, new int[] { annotatorModelCache.putCount(), annotatorModelCache.evictionCount(), annotatorModelCache.hitCount(), annotatorModelCache.missCount() }); } @Test public void testUseCachedAnnotatorModelEnabled() throws IOException { deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true); deviceConfig.setConfig(TextClassifierSettings.MULTI_ANNOTATOR_CACHE_ENABLED, true); String annotatorFilePath = TestDataUtils.getTestAnnotatorModelFile().getPath(); ModelFile annotatorModelA = new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false); ModelFile annotatorModelB = new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false); String englishText = "You can reach me on +12122537077."; String classifiedText = "+12122537077"; TextClassification.Request request = new TextClassification.Request.Builder(englishText, 0, englishText.length()).build(); // Check modelFileA v701 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) .thenReturn(annotatorModelA); TextClassification classification = classifier.classifyText(null, null, request); assertThat(classification.getId()).contains("v701"); assertThat(classification.getText()).contains(classifiedText); assertArrayEquals( new int[] {1, 0, 0, 1}, new int[] { annotatorModelCache.putCount(), annotatorModelCache.evictionCount(), annotatorModelCache.hitCount(), annotatorModelCache.missCount() }); // Check modelFileB v801 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) .thenReturn(annotatorModelB); TextClassification classificationB = classifier.classifyText(null, null, request); assertThat(classificationB.getId()).contains("v801"); assertThat(classificationB.getText()).contains(classifiedText); assertArrayEquals( new int[] {2, 0, 0, 2}, new int[] { annotatorModelCache.putCount(), annotatorModelCache.evictionCount(), annotatorModelCache.hitCount(), annotatorModelCache.missCount() }); // Reload modelFileA v701 when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) .thenReturn(annotatorModelA); TextClassification classificationAcached = classifier.classifyText(null, null, request); assertThat(classificationAcached.getId()).contains("v701"); assertThat(classificationAcached.getText()).contains(classifiedText); assertArrayEquals( new int[] {2, 0, 1, 2}, new int[] { annotatorModelCache.putCount(), annotatorModelCache.evictionCount(), annotatorModelCache.hitCount(), annotatorModelCache.missCount() }); } private static void assertNoPackageInfoInExtras(Intent intent) { assertThat(intent.getComponent()).isNull(); assertThat(intent.getPackage()).isNull(); } private static Matcher isTextSelection( final int startIndex, final int endIndex, final String type) { return new BaseMatcher() { @Override public boolean matches(Object o) { if (o instanceof TextSelection) { TextSelection selection = (TextSelection) o; return startIndex == selection.getSelectionStartIndex() && endIndex == selection.getSelectionEndIndex() && typeMatches(selection, type); } return false; } private boolean typeMatches(TextSelection selection, String type) { return type == null || (selection.getEntityCount() > 0 && type.trim().equalsIgnoreCase(selection.getEntity(0))); } @Override public void describeTo(Description description) { description.appendValue(String.format("%d, %d, %s", startIndex, endIndex, type)); } }; } private static Matcher isTextLinksContaining( final String text, final String substring, final String type) { return new BaseMatcher() { @Override public void describeTo(Description description) { description .appendText("text=") .appendValue(text) .appendText(", substring=") .appendValue(substring) .appendText(", type=") .appendValue(type); } @Override public boolean matches(Object o) { if (o instanceof TextLinks) { for (TextLinks.TextLink link : ((TextLinks) o).getLinks()) { if (text.subSequence(link.getStart(), link.getEnd()).toString().equals(substring)) { return type.equals(link.getEntity(0)); } } } return false; } }; } private static Matcher isTextClassification( final String text, final String type) { return new BaseMatcher() { @Override public boolean matches(Object o) { if (o instanceof TextClassification) { TextClassification result = (TextClassification) o; return text.equals(result.getText()) && result.getEntityCount() > 0 && type.equals(result.getEntity(0)); } return false; } @Override public void describeTo(Description description) { description.appendText("text=").appendValue(text).appendText(", type=").appendValue(type); } }; } private static Matcher containsIntentWithAction(final String action) { return new BaseMatcher() { @Override public boolean matches(Object o) { if (o instanceof TextClassification) { TextClassification result = (TextClassification) o; return ExtrasUtils.findAction(result, action) != null; } return false; } @Override public void describeTo(Description description) { description.appendText("intent action=").appendValue(action); } }; } private static Matcher isTextLanguage(final String languageTag) { return new BaseMatcher() { @Override public boolean matches(Object o) { if (o instanceof TextLanguage) { TextLanguage result = (TextLanguage) o; return result.getLocaleHypothesisCount() > 0 && languageTag.equals(result.getLocale(0).toLanguageTag()); } return false; } @Override public void describeTo(Description description) { description.appendText("locale=").appendValue(languageTag); } }; } private static Matcher isConversationAction(String actionType) { return new BaseMatcher() { @Override public boolean matches(Object o) { if (!(o instanceof ConversationAction)) { return false; } ConversationAction conversationAction = (ConversationAction) o; if (!actionType.equals(conversationAction.getType())) { return false; } if (ConversationAction.TYPE_TEXT_REPLY.equals(actionType)) { if (conversationAction.getTextReply() == null) { return false; } } if (conversationAction.getConfidenceScore() < 0 || conversationAction.getConfidenceScore() > 1) { return false; } return true; } @Override public void describeTo(Description description) { description.appendText("actionType=").appendValue(actionType); } }; } }