• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.textclassifier;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static org.hamcrest.CoreMatchers.not;
21 import static org.junit.Assert.assertArrayEquals;
22 import static org.junit.Assert.assertEquals;
23 import static org.junit.Assert.assertThat;
24 import static org.junit.Assert.assertTrue;
25 import static org.mockito.Mockito.any;
26 import static org.mockito.Mockito.eq;
27 import static org.mockito.Mockito.verify;
28 import static org.mockito.Mockito.when;
29 import static org.testng.Assert.expectThrows;
30 
31 import android.app.RemoteAction;
32 import android.content.Context;
33 import android.content.Intent;
34 import android.net.Uri;
35 import android.os.Bundle;
36 import android.os.LocaleList;
37 import android.text.Spannable;
38 import android.text.SpannableString;
39 import android.view.textclassifier.ConversationAction;
40 import android.view.textclassifier.ConversationActions;
41 import android.view.textclassifier.TextClassification;
42 import android.view.textclassifier.TextClassifier;
43 import android.view.textclassifier.TextLanguage;
44 import android.view.textclassifier.TextLinks;
45 import android.view.textclassifier.TextSelection;
46 import androidx.collection.LruCache;
47 import androidx.test.ext.junit.runners.AndroidJUnit4;
48 import androidx.test.filters.SdkSuppress;
49 import androidx.test.filters.SmallTest;
50 import com.android.textclassifier.common.ModelFile;
51 import com.android.textclassifier.common.ModelType;
52 import com.android.textclassifier.common.TextClassifierSettings;
53 import com.android.textclassifier.testing.FakeContextBuilder;
54 import com.android.textclassifier.testing.TestingDeviceConfig;
55 import com.google.android.textclassifier.AnnotatorModel;
56 import com.google.common.collect.ImmutableList;
57 import java.io.IOException;
58 import java.util.ArrayList;
59 import java.util.Arrays;
60 import java.util.Collections;
61 import java.util.List;
62 import org.hamcrest.BaseMatcher;
63 import org.hamcrest.Description;
64 import org.hamcrest.Matcher;
65 import org.junit.Before;
66 import org.junit.Test;
67 import org.junit.runner.RunWith;
68 import org.mockito.Mock;
69 import org.mockito.MockitoAnnotations;
70 
71 @SmallTest
72 @RunWith(AndroidJUnit4.class)
73 public class TextClassifierImplTest {
74 
75   private static final String TYPE_COPY = "copy";
76   private static final LocaleList LOCALES = LocaleList.forLanguageTags("en-US");
77   private static final String NO_TYPE = null;
78 
79   @Mock private ModelFileManager modelFileManager;
80 
81   private Context context;
82   private TestingDeviceConfig deviceConfig;
83   private TextClassifierSettings settings;
84   private LruCache<ModelFile, AnnotatorModel> annotatorModelCache;
85   private TextClassifierImpl classifier;
86 
87   @Before
setup()88   public void setup() throws IOException {
89     MockitoAnnotations.initMocks(this);
90     this.context =
91         new FakeContextBuilder()
92             .setAllIntentComponent(FakeContextBuilder.DEFAULT_COMPONENT)
93             .setAppLabel(FakeContextBuilder.DEFAULT_COMPONENT.getPackageName(), "Test app")
94             .build();
95     this.deviceConfig = new TestingDeviceConfig();
96     this.settings = new TextClassifierSettings(deviceConfig, /* isWear= */ false);
97     this.annotatorModelCache = new LruCache<>(2);
98     this.classifier =
99         new TextClassifierImpl(context, settings, modelFileManager, annotatorModelCache);
100 
101     when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
102         .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped());
103     when(modelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any()))
104         .thenReturn(TestDataUtils.getLangIdModelFileWrapped());
105     when(modelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any()))
106         .thenReturn(TestDataUtils.getTestActionsModelFileWrapped());
107   }
108 
109   @Test
testSuggestSelection()110   public void testSuggestSelection() throws IOException {
111     String text = "Contact me at droid@android.com";
112     String selected = "droid";
113     String suggested = "droid@android.com";
114     int startIndex = text.indexOf(selected);
115     int endIndex = startIndex + selected.length();
116     int smartStartIndex = text.indexOf(suggested);
117     int smartEndIndex = smartStartIndex + suggested.length();
118     TextSelection.Request request =
119         new TextSelection.Request.Builder(text, startIndex, endIndex).build();
120 
121     TextSelection selection = classifier.suggestSelection(null, null, request);
122     assertThat(
123         selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_EMAIL));
124   }
125 
126   @Test
testSuggestSelection_localePreferenceIsPassedToModelFileManager()127   public void testSuggestSelection_localePreferenceIsPassedToModelFileManager() throws IOException {
128     String text = "Contact me at droid@android.com";
129     String selected = "droid";
130     String suggested = "droid@android.com";
131     int startIndex = text.indexOf(selected);
132     int endIndex = startIndex + selected.length();
133     int smartStartIndex = text.indexOf(suggested);
134     int smartEndIndex = smartStartIndex + suggested.length();
135     TextSelection.Request request =
136         new TextSelection.Request.Builder(text, startIndex, endIndex)
137             .setDefaultLocales(LOCALES)
138             .build();
139 
140     classifier.suggestSelection(null, null, request);
141     verify(modelFileManager).findBestModelFile(eq(ModelType.ANNOTATOR), eq(LOCALES), any());
142   }
143 
144   @Test
testSuggestSelection_url()145   public void testSuggestSelection_url() throws IOException {
146     String text = "Visit http://www.android.com for more information";
147     String selected = "http";
148     String suggested = "http://www.android.com";
149     int startIndex = text.indexOf(selected);
150     int endIndex = startIndex + selected.length();
151     int smartStartIndex = text.indexOf(suggested);
152     int smartEndIndex = smartStartIndex + suggested.length();
153     TextSelection.Request request =
154         new TextSelection.Request.Builder(text, startIndex, endIndex).build();
155 
156     TextSelection selection = classifier.suggestSelection(null, null, request);
157     assertThat(selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_URL));
158   }
159 
160   @Test
testSmartSelection_withEmoji()161   public void testSmartSelection_withEmoji() throws IOException {
162     String text = "\uD83D\uDE02 Hello.";
163     String selected = "Hello";
164     int startIndex = text.indexOf(selected);
165     int endIndex = startIndex + selected.length();
166     TextSelection.Request request =
167         new TextSelection.Request.Builder(text, startIndex, endIndex).build();
168 
169     TextSelection selection = classifier.suggestSelection(null, null, request);
170     assertThat(selection, isTextSelection(startIndex, endIndex, NO_TYPE));
171   }
172 
173   @SdkSuppress(minSdkVersion = 31, codeName = "S")
174   @Test
testSuggestSelection_includeTextClassification()175   public void testSuggestSelection_includeTextClassification() throws IOException {
176     String text = "Visit http://www.android.com for more information";
177     String suggested = "http://www.android.com";
178     int startIndex = text.indexOf(suggested);
179     TextSelection.Request request =
180         new TextSelection.Request.Builder(text, startIndex, /* endIndex= */ startIndex + 1)
181             .setIncludeTextClassification(true)
182             .build();
183 
184     TextSelection selection = classifier.suggestSelection(null, null, request);
185 
186     assertThat(
187         selection.getTextClassification(),
188         isTextClassification(suggested, TextClassifier.TYPE_URL));
189     assertThat(selection.getTextClassification(), containsIntentWithAction(Intent.ACTION_VIEW));
190   }
191 
192   @SdkSuppress(minSdkVersion = 31, codeName = "S")
193   @Test
testSuggestSelection_notIncludeTextClassification()194   public void testSuggestSelection_notIncludeTextClassification() throws IOException {
195     String text = "Visit http://www.android.com for more information";
196     TextSelection.Request request =
197         new TextSelection.Request.Builder(text, /* startIndex= */ 0, /* endIndex= */ 4)
198             .setIncludeTextClassification(false)
199             .build();
200 
201     TextSelection selection = classifier.suggestSelection(null, null, request);
202 
203     assertThat(selection.getTextClassification()).isNull();
204   }
205 
206   @Test
testClassifyText()207   public void testClassifyText() throws IOException {
208     String text = "Contact me at droid@android.com";
209     String classifiedText = "droid@android.com";
210     int startIndex = text.indexOf(classifiedText);
211     int endIndex = startIndex + classifiedText.length();
212     TextClassification.Request request =
213         new TextClassification.Request.Builder(text, startIndex, endIndex).build();
214 
215     TextClassification classification =
216         classifier.classifyText(/* sessionId= */ null, null, request);
217     assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_EMAIL));
218   }
219 
220   @Test
testClassifyText_url()221   public void testClassifyText_url() throws IOException {
222     String text = "Visit www.android.com for more information";
223     String classifiedText = "www.android.com";
224     int startIndex = text.indexOf(classifiedText);
225     int endIndex = startIndex + classifiedText.length();
226     TextClassification.Request request =
227         new TextClassification.Request.Builder(text, startIndex, endIndex).build();
228 
229     TextClassification classification = classifier.classifyText(null, null, request);
230     assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
231     assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW));
232   }
233 
234   @Test
testClassifyText_address()235   public void testClassifyText_address() throws IOException {
236     String text = "Brandschenkestrasse 110, Zürich, Switzerland";
237     TextClassification.Request request =
238         new TextClassification.Request.Builder(text, 0, text.length()).build();
239 
240     TextClassification classification = classifier.classifyText(null, null, request);
241     assertThat(classification, isTextClassification(text, TextClassifier.TYPE_ADDRESS));
242   }
243 
244   @Test
testClassifyText_url_inCaps()245   public void testClassifyText_url_inCaps() throws IOException {
246     String text = "Visit HTTP://ANDROID.COM for more information";
247     String classifiedText = "HTTP://ANDROID.COM";
248     int startIndex = text.indexOf(classifiedText);
249     int endIndex = startIndex + classifiedText.length();
250     TextClassification.Request request =
251         new TextClassification.Request.Builder(text, startIndex, endIndex).build();
252 
253     TextClassification classification = classifier.classifyText(null, null, request);
254     assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
255     assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW));
256   }
257 
258   @Test
testClassifyText_date()259   public void testClassifyText_date() throws IOException {
260     String text = "Let's meet on January 9, 2018.";
261     String classifiedText = "January 9, 2018";
262     int startIndex = text.indexOf(classifiedText);
263     int endIndex = startIndex + classifiedText.length();
264     TextClassification.Request request =
265         new TextClassification.Request.Builder(text, startIndex, endIndex).build();
266 
267     TextClassification classification = classifier.classifyText(null, null, request);
268     assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE));
269     Bundle extras = classification.getExtras();
270     List<Bundle> entities = ExtrasUtils.getEntities(extras);
271     assertThat(entities).hasSize(1);
272     assertThat(ExtrasUtils.getEntityType(entities.get(0))).isEqualTo(TextClassifier.TYPE_DATE);
273     ArrayList<Intent> actionsIntents = ExtrasUtils.getActionsIntents(classification);
274     actionsIntents.forEach(TextClassifierImplTest::assertNoPackageInfoInExtras);
275   }
276 
277   @Test
testClassifyText_datetime()278   public void testClassifyText_datetime() throws IOException {
279     String text = "Let's meet 2018/01/01 10:30:20.";
280     String classifiedText = "2018/01/01 10:30:20";
281     int startIndex = text.indexOf(classifiedText);
282     int endIndex = startIndex + classifiedText.length();
283     TextClassification.Request request =
284         new TextClassification.Request.Builder(text, startIndex, endIndex).build();
285 
286     TextClassification classification = classifier.classifyText(null, null, request);
287     assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME));
288   }
289 
290   @Test
testClassifyText_foreignText()291   public void testClassifyText_foreignText() throws IOException {
292     LocaleList originalLocales = LocaleList.getDefault();
293     LocaleList.setDefault(LocaleList.forLanguageTags("en"));
294     String japaneseText = "これは日本語のテキストです";
295     TextClassification.Request request =
296         new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length()).build();
297 
298     TextClassification classification = classifier.classifyText(null, null, request);
299     RemoteAction translateAction = classification.getActions().get(0);
300     assertEquals(1, classification.getActions().size());
301     assertEquals(Intent.ACTION_TRANSLATE, classification.getIntent().getAction());
302 
303     assertEquals(translateAction, ExtrasUtils.findTranslateAction(classification));
304     Intent intent = ExtrasUtils.getActionsIntents(classification).get(0);
305     assertNoPackageInfoInExtras(intent);
306     assertEquals(Intent.ACTION_TRANSLATE, intent.getAction());
307     Bundle foreignLanguageInfo = ExtrasUtils.getForeignLanguageExtra(classification);
308     assertEquals("ja", ExtrasUtils.getEntityType(foreignLanguageInfo));
309     assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) >= 0);
310     assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) <= 1);
311     assertTrue(intent.hasExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER));
312     assertEquals("ja", ExtrasUtils.getTopLanguage(intent).first);
313 
314     LocaleList.setDefault(originalLocales);
315   }
316 
317   @Test
testGenerateLinks_phone()318   public void testGenerateLinks_phone() throws IOException {
319     String text = "The number is +12122537077. See you tonight!";
320     TextLinks.Request request = new TextLinks.Request.Builder(text).build();
321     assertThat(
322         classifier.generateLinks(null, null, request),
323         isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE));
324   }
325 
326   @Test
testGenerateLinks_exclude()327   public void testGenerateLinks_exclude() throws IOException {
328     String text = "The number is +12122537077. See you tonight!";
329     List<String> hints = ImmutableList.of();
330     List<String> included = ImmutableList.of();
331     List<String> excluded = Arrays.asList(TextClassifier.TYPE_PHONE);
332     TextLinks.Request request =
333         new TextLinks.Request.Builder(text)
334             .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded))
335             .build();
336     assertThat(
337         classifier.generateLinks(null, null, request),
338         not(isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE)));
339   }
340 
341   @Test
testGenerateLinks_explicit_address()342   public void testGenerateLinks_explicit_address() throws IOException {
343     String text = "The address is 1600 Amphitheater Parkway, Mountain View, CA. See you!";
344     List<String> explicit = Arrays.asList(TextClassifier.TYPE_ADDRESS);
345     TextLinks.Request request =
346         new TextLinks.Request.Builder(text)
347             .setEntityConfig(TextClassifier.EntityConfig.createWithExplicitEntityList(explicit))
348             .build();
349     assertThat(
350         classifier.generateLinks(null, null, request),
351         isTextLinksContaining(
352             text, "1600 Amphitheater Parkway, Mountain View, CA", TextClassifier.TYPE_ADDRESS));
353   }
354 
355   @Test
testGenerateLinks_exclude_override()356   public void testGenerateLinks_exclude_override() throws IOException {
357     String text = "You want apple@banana.com. See you tonight!";
358     List<String> hints = ImmutableList.of();
359     List<String> included = Arrays.asList(TextClassifier.TYPE_EMAIL);
360     List<String> excluded = Arrays.asList(TextClassifier.TYPE_EMAIL);
361     TextLinks.Request request =
362         new TextLinks.Request.Builder(text)
363             .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded))
364             .build();
365     assertThat(
366         classifier.generateLinks(null, null, request),
367         not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL)));
368   }
369 
370   @Test
testGenerateLinks_maxLength()371   public void testGenerateLinks_maxLength() throws IOException {
372     char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength()];
373     Arrays.fill(manySpaces, ' ');
374     TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();
375     TextLinks links = classifier.generateLinks(null, null, request);
376     assertTrue(links.getLinks().isEmpty());
377   }
378 
379   @Test
testApplyLinks_unsupportedCharacter()380   public void testApplyLinks_unsupportedCharacter() throws IOException {
381     Spannable url = new SpannableString("\u202Emoc.diordna.com");
382     TextLinks.Request request = new TextLinks.Request.Builder(url).build();
383     assertEquals(
384         TextLinks.STATUS_UNSUPPORTED_CHARACTER,
385         classifier.generateLinks(null, null, request).apply(url, 0, null));
386   }
387 
388   @Test
testGenerateLinks_tooLong()389   public void testGenerateLinks_tooLong() {
390     char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength() + 1];
391     Arrays.fill(manySpaces, ' ');
392     TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();
393     expectThrows(
394         IllegalArgumentException.class, () -> classifier.generateLinks(null, null, request));
395   }
396 
397   @Test
testGenerateLinks_entityData()398   public void testGenerateLinks_entityData() throws IOException {
399     String text = "The number is +12122537077.";
400     Bundle extras = new Bundle();
401     ExtrasUtils.putIsSerializedEntityDataEnabled(extras, true);
402     TextLinks.Request request = new TextLinks.Request.Builder(text).setExtras(extras).build();
403 
404     TextLinks textLinks = classifier.generateLinks(null, null, request);
405 
406     assertThat(textLinks.getLinks()).hasSize(1);
407     TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
408     List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras());
409     assertThat(entities).hasSize(1);
410     Bundle entity = entities.get(0);
411     assertThat(ExtrasUtils.getEntityType(entity)).isEqualTo(TextClassifier.TYPE_PHONE);
412   }
413 
414   @Test
testGenerateLinks_entityData_disabled()415   public void testGenerateLinks_entityData_disabled() throws IOException {
416     String text = "The number is +12122537077.";
417     TextLinks.Request request = new TextLinks.Request.Builder(text).build();
418 
419     TextLinks textLinks = classifier.generateLinks(null, null, request);
420 
421     assertThat(textLinks.getLinks()).hasSize(1);
422     TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
423     List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras());
424     assertThat(entities).isNull();
425   }
426 
427   @Test
testDetectLanguage()428   public void testDetectLanguage() throws IOException {
429     String text = "This is English text";
430     TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
431     TextLanguage textLanguage = classifier.detectLanguage(null, null, request);
432     assertThat(textLanguage, isTextLanguage("en"));
433   }
434 
435   @Test
testDetectLanguage_japanese()436   public void testDetectLanguage_japanese() throws IOException {
437     String text = "これは日本語のテキストです";
438     TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
439     TextLanguage textLanguage = classifier.detectLanguage(null, null, request);
440     assertThat(textLanguage, isTextLanguage("ja"));
441   }
442 
443   @Test
testSuggestConversationActions_textReplyOnly_maxOne()444   public void testSuggestConversationActions_textReplyOnly_maxOne() throws IOException {
445     ConversationActions.Message message =
446         new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
447             .setText("Where are you?")
448             .build();
449     TextClassifier.EntityConfig typeConfig =
450         new TextClassifier.EntityConfig.Builder()
451             .includeTypesFromTextClassifier(false)
452             .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY))
453             .build();
454     ConversationActions.Request request =
455         new ConversationActions.Request.Builder(Collections.singletonList(message))
456             .setMaxSuggestions(1)
457             .setTypeConfig(typeConfig)
458             .build();
459 
460     ConversationActions conversationActions =
461         classifier.suggestConversationActions(null, null, request);
462     assertThat(conversationActions.getConversationActions()).hasSize(1);
463     ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
464     assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_TEXT_REPLY);
465     assertThat(conversationAction.getTextReply()).isNotNull();
466   }
467 
468   @Test
testSuggestConversationActions_textReplyOnly_noMax()469   public void testSuggestConversationActions_textReplyOnly_noMax() throws IOException {
470     ConversationActions.Message message =
471         new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
472             .setText("Where are you?")
473             .build();
474     TextClassifier.EntityConfig typeConfig =
475         new TextClassifier.EntityConfig.Builder()
476             .includeTypesFromTextClassifier(false)
477             .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY))
478             .build();
479     ConversationActions.Request request =
480         new ConversationActions.Request.Builder(Collections.singletonList(message))
481             .setTypeConfig(typeConfig)
482             .build();
483 
484     ConversationActions conversationActions =
485         classifier.suggestConversationActions(null, null, request);
486     assertTrue(conversationActions.getConversationActions().size() > 1);
487     for (ConversationAction conversationAction : conversationActions.getConversationActions()) {
488       assertThat(conversationAction, isConversationAction(ConversationAction.TYPE_TEXT_REPLY));
489     }
490   }
491 
492   @Test
testSuggestConversationActions_openUrl()493   public void testSuggestConversationActions_openUrl() throws IOException {
494     ConversationActions.Message message =
495         new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
496             .setText("Check this out: https://www.android.com")
497             .build();
498     TextClassifier.EntityConfig typeConfig =
499         new TextClassifier.EntityConfig.Builder()
500             .includeTypesFromTextClassifier(false)
501             .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_OPEN_URL))
502             .build();
503     ConversationActions.Request request =
504         new ConversationActions.Request.Builder(Collections.singletonList(message))
505             .setMaxSuggestions(1)
506             .setTypeConfig(typeConfig)
507             .build();
508 
509     ConversationActions conversationActions =
510         classifier.suggestConversationActions(null, null, request);
511     assertThat(conversationActions.getConversationActions()).hasSize(1);
512     ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
513     assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL);
514     Intent actionIntent = ExtrasUtils.getActionIntent(conversationAction.getExtras());
515     assertThat(actionIntent.getAction()).isEqualTo(Intent.ACTION_VIEW);
516     assertThat(actionIntent.getData()).isEqualTo(Uri.parse("https://www.android.com"));
517     assertNoPackageInfoInExtras(actionIntent);
518   }
519 
520   @Test
testSuggestConversationActions_copy()521   public void testSuggestConversationActions_copy() throws IOException {
522     ConversationActions.Message message =
523         new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
524             .setText("Authentication code: 12345")
525             .build();
526     TextClassifier.EntityConfig typeConfig =
527         new TextClassifier.EntityConfig.Builder()
528             .includeTypesFromTextClassifier(false)
529             .setIncludedTypes(Collections.singletonList(TYPE_COPY))
530             .build();
531     ConversationActions.Request request =
532         new ConversationActions.Request.Builder(Collections.singletonList(message))
533             .setMaxSuggestions(1)
534             .setTypeConfig(typeConfig)
535             .build();
536 
537     ConversationActions conversationActions =
538         classifier.suggestConversationActions(null, null, request);
539     assertThat(conversationActions.getConversationActions()).hasSize(1);
540     ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
541     assertThat(conversationAction.getType()).isEqualTo(TYPE_COPY);
542     assertThat(conversationAction.getTextReply()).isAnyOf(null, "");
543     assertThat(conversationAction.getAction()).isNull();
544     String code = ExtrasUtils.getCopyText(conversationAction.getExtras());
545     assertThat(code).isEqualTo("12345");
546     assertThat(ExtrasUtils.getSerializedEntityData(conversationAction.getExtras())).isNotEmpty();
547   }
548 
549   @Test
testSuggestConversationActions_deduplicate()550   public void testSuggestConversationActions_deduplicate() throws IOException {
551     ConversationActions.Message message =
552         new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
553             .setText("a@android.com b@android.com")
554             .build();
555     ConversationActions.Request request =
556         new ConversationActions.Request.Builder(Collections.singletonList(message))
557             .setMaxSuggestions(3)
558             .build();
559 
560     ConversationActions conversationActions =
561         classifier.suggestConversationActions(null, null, request);
562 
563     assertThat(conversationActions.getConversationActions()).isEmpty();
564   }
565 
566   @Test
testUseCachedAnnotatorModelDisabled()567   public void testUseCachedAnnotatorModelDisabled() throws IOException {
568     deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
569 
570     String annotatorFilePath = TestDataUtils.getTestAnnotatorModelFile().getPath();
571     ModelFile annotatorModelA =
572         new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false);
573     ModelFile annotatorModelB =
574         new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false);
575 
576     String englishText = "You can reach me on +12122537077.";
577     String classifiedText = "+12122537077";
578     TextClassification.Request request =
579         new TextClassification.Request.Builder(englishText, 0, englishText.length()).build();
580 
581     // Check modelFileA v701
582     when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
583         .thenReturn(annotatorModelA);
584     TextClassification classificationA = classifier.classifyText(null, null, request);
585 
586     assertThat(classificationA.getId()).contains("v701");
587     assertThat(classificationA.getText()).contains(classifiedText);
588     assertArrayEquals(
589         new int[] {0, 0, 0, 0},
590         new int[] {
591           annotatorModelCache.putCount(),
592           annotatorModelCache.evictionCount(),
593           annotatorModelCache.hitCount(),
594           annotatorModelCache.missCount()
595         });
596 
597     // Check modelFileB v801
598     when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
599         .thenReturn(annotatorModelB);
600     TextClassification classificationB = classifier.classifyText(null, null, request);
601 
602     assertThat(classificationB.getId()).contains("v801");
603     assertThat(classificationB.getText()).contains(classifiedText);
604     assertArrayEquals(
605         new int[] {0, 0, 0, 0},
606         new int[] {
607           annotatorModelCache.putCount(),
608           annotatorModelCache.evictionCount(),
609           annotatorModelCache.hitCount(),
610           annotatorModelCache.missCount()
611         });
612 
613     // Reload modelFileA v701
614     when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
615         .thenReturn(annotatorModelA);
616     TextClassification classificationAcached = classifier.classifyText(null, null, request);
617 
618     assertThat(classificationAcached.getId()).contains("v701");
619     assertThat(classificationAcached.getText()).contains(classifiedText);
620     assertArrayEquals(
621         new int[] {0, 0, 0, 0},
622         new int[] {
623           annotatorModelCache.putCount(),
624           annotatorModelCache.evictionCount(),
625           annotatorModelCache.hitCount(),
626           annotatorModelCache.missCount()
627         });
628   }
629 
630   @Test
testUseCachedAnnotatorModelEnabled()631   public void testUseCachedAnnotatorModelEnabled() throws IOException {
632     deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
633     deviceConfig.setConfig(TextClassifierSettings.MULTI_ANNOTATOR_CACHE_ENABLED, true);
634 
635     String annotatorFilePath = TestDataUtils.getTestAnnotatorModelFile().getPath();
636     ModelFile annotatorModelA =
637         new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false);
638     ModelFile annotatorModelB =
639         new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false);
640 
641     String englishText = "You can reach me on +12122537077.";
642     String classifiedText = "+12122537077";
643     TextClassification.Request request =
644         new TextClassification.Request.Builder(englishText, 0, englishText.length()).build();
645 
646     // Check modelFileA v701
647     when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
648         .thenReturn(annotatorModelA);
649     TextClassification classification = classifier.classifyText(null, null, request);
650 
651     assertThat(classification.getId()).contains("v701");
652     assertThat(classification.getText()).contains(classifiedText);
653     assertArrayEquals(
654         new int[] {1, 0, 0, 1},
655         new int[] {
656           annotatorModelCache.putCount(),
657           annotatorModelCache.evictionCount(),
658           annotatorModelCache.hitCount(),
659           annotatorModelCache.missCount()
660         });
661 
662     // Check modelFileB v801
663     when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
664         .thenReturn(annotatorModelB);
665     TextClassification classificationB = classifier.classifyText(null, null, request);
666 
667     assertThat(classificationB.getId()).contains("v801");
668     assertThat(classificationB.getText()).contains(classifiedText);
669     assertArrayEquals(
670         new int[] {2, 0, 0, 2},
671         new int[] {
672           annotatorModelCache.putCount(),
673           annotatorModelCache.evictionCount(),
674           annotatorModelCache.hitCount(),
675           annotatorModelCache.missCount()
676         });
677 
678     // Reload modelFileA v701
679     when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
680         .thenReturn(annotatorModelA);
681     TextClassification classificationAcached = classifier.classifyText(null, null, request);
682 
683     assertThat(classificationAcached.getId()).contains("v701");
684     assertThat(classificationAcached.getText()).contains(classifiedText);
685     assertArrayEquals(
686         new int[] {2, 0, 1, 2},
687         new int[] {
688           annotatorModelCache.putCount(),
689           annotatorModelCache.evictionCount(),
690           annotatorModelCache.hitCount(),
691           annotatorModelCache.missCount()
692         });
693   }
694 
assertNoPackageInfoInExtras(Intent intent)695   private static void assertNoPackageInfoInExtras(Intent intent) {
696     assertThat(intent.getComponent()).isNull();
697     assertThat(intent.getPackage()).isNull();
698   }
699 
isTextSelection( final int startIndex, final int endIndex, final String type)700   private static Matcher<TextSelection> isTextSelection(
701       final int startIndex, final int endIndex, final String type) {
702     return new BaseMatcher<TextSelection>() {
703       @Override
704       public boolean matches(Object o) {
705         if (o instanceof TextSelection) {
706           TextSelection selection = (TextSelection) o;
707           return startIndex == selection.getSelectionStartIndex()
708               && endIndex == selection.getSelectionEndIndex()
709               && typeMatches(selection, type);
710         }
711         return false;
712       }
713 
714       private boolean typeMatches(TextSelection selection, String type) {
715         return type == null
716             || (selection.getEntityCount() > 0
717                 && type.trim().equalsIgnoreCase(selection.getEntity(0)));
718       }
719 
720       @Override
721       public void describeTo(Description description) {
722         description.appendValue(String.format("%d, %d, %s", startIndex, endIndex, type));
723       }
724     };
725   }
726 
727   private static Matcher<TextLinks> isTextLinksContaining(
728       final String text, final String substring, final String type) {
729     return new BaseMatcher<TextLinks>() {
730 
731       @Override
732       public void describeTo(Description description) {
733         description
734             .appendText("text=")
735             .appendValue(text)
736             .appendText(", substring=")
737             .appendValue(substring)
738             .appendText(", type=")
739             .appendValue(type);
740       }
741 
742       @Override
743       public boolean matches(Object o) {
744         if (o instanceof TextLinks) {
745           for (TextLinks.TextLink link : ((TextLinks) o).getLinks()) {
746             if (text.subSequence(link.getStart(), link.getEnd()).toString().equals(substring)) {
747               return type.equals(link.getEntity(0));
748             }
749           }
750         }
751         return false;
752       }
753     };
754   }
755 
756   private static Matcher<TextClassification> isTextClassification(
757       final String text, final String type) {
758     return new BaseMatcher<TextClassification>() {
759       @Override
760       public boolean matches(Object o) {
761         if (o instanceof TextClassification) {
762           TextClassification result = (TextClassification) o;
763           return text.equals(result.getText())
764               && result.getEntityCount() > 0
765               && type.equals(result.getEntity(0));
766         }
767         return false;
768       }
769 
770       @Override
771       public void describeTo(Description description) {
772         description.appendText("text=").appendValue(text).appendText(", type=").appendValue(type);
773       }
774     };
775   }
776 
777   private static Matcher<TextClassification> containsIntentWithAction(final String action) {
778     return new BaseMatcher<TextClassification>() {
779       @Override
780       public boolean matches(Object o) {
781         if (o instanceof TextClassification) {
782           TextClassification result = (TextClassification) o;
783           return ExtrasUtils.findAction(result, action) != null;
784         }
785         return false;
786       }
787 
788       @Override
789       public void describeTo(Description description) {
790         description.appendText("intent action=").appendValue(action);
791       }
792     };
793   }
794 
795   private static Matcher<TextLanguage> isTextLanguage(final String languageTag) {
796     return new BaseMatcher<TextLanguage>() {
797       @Override
798       public boolean matches(Object o) {
799         if (o instanceof TextLanguage) {
800           TextLanguage result = (TextLanguage) o;
801           return result.getLocaleHypothesisCount() > 0
802               && languageTag.equals(result.getLocale(0).toLanguageTag());
803         }
804         return false;
805       }
806 
807       @Override
808       public void describeTo(Description description) {
809         description.appendText("locale=").appendValue(languageTag);
810       }
811     };
812   }
813 
814   private static Matcher<ConversationAction> isConversationAction(String actionType) {
815     return new BaseMatcher<ConversationAction>() {
816       @Override
817       public boolean matches(Object o) {
818         if (!(o instanceof ConversationAction)) {
819           return false;
820         }
821         ConversationAction conversationAction = (ConversationAction) o;
822         if (!actionType.equals(conversationAction.getType())) {
823           return false;
824         }
825         if (ConversationAction.TYPE_TEXT_REPLY.equals(actionType)) {
826           if (conversationAction.getTextReply() == null) {
827             return false;
828           }
829         }
830         if (conversationAction.getConfidenceScore() < 0
831             || conversationAction.getConfidenceScore() > 1) {
832           return false;
833         }
834         return true;
835       }
836 
837       @Override
838       public void describeTo(Description description) {
839         description.appendText("actionType=").appendValue(actionType);
840       }
841     };
842   }
843 }
844